In [1]:
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

In [2]:
import matplotlib.pyplot as plt
import numpy as np
from monai.transforms import LoadImage, EnsureChannelFirst, Orientation, Compose, SaveImage, Transform
from monai.bundle import ConfigParser, download
from monai.data.meta_tensor import MetaTensor
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD
import seaborn as sns
from ipywidgets import interact
from nibabel import Nifti1Image
import nibabel as nib
from totalsegmentator.python_api import totalsegmentator

from monai.transforms import Resize





In [3]:
print(torch.backends.mps.is_available())  # Should return True
print(torch.backends.mps.is_built())  # Should return True


True
True


In [4]:
%matplotlib inline

def show_scrollable_image(image, cmap="nipy_spectral"):
    """
    Displays a 3D array as a scrollable series of slices.
    
    Parameters:
    - image (numpy.ndarray): A 3D numpy array (e.g., (H, W, D)) representing the image.
    """
    
    # Define a function to display each slice
    def display_slice(slice_index):
        plt.figure(figsize=(8, 8))
        plt.imshow(image[:, :, slice_index], cmap=cmap)
        plt.colorbar(label='HU')
        plt.axis("off")
        plt.title(f"Slice {slice_index + 1} / {image.shape[2]}")
        plt.show()
    
    # Create an interactive slider to scroll through slices
    interact(display_slice, slice_index=(0, image.shape[2] - 1))

In [5]:
def nifti1image_to_metatensor(nifti_img: Nifti1Image) -> MetaTensor:
    """
    Converts a Nifti1Image object to a MetaTensor with metadata.
    Args:
        nifti_img (Nifti1Image): A loaded NIfTI image.
    Returns:
        MetaTensor: A MONAI MetaTensor containing the image data and metadata.
    """
    # Convert voxel data to PyTorch tensor
    image_tensor = torch.tensor(nifti_img.get_fdata(), dtype=torch.float32)

    # Extract metadata
    metadata = {
        "affine": torch.tensor(nifti_img.affine, dtype=torch.float32),  # Affine transformation matrix
        "spacing": nifti_img.header.get_zooms(),  # Voxel spacing (x, y, z)
        "original_shape": image_tensor.shape  # Store original shape for reference
    }

    # Create a MetaTensor with metadata
    meta_tensor = MetaTensor(image_tensor, meta=metadata)

    return meta_tensor


In [12]:
# input_file_path = "../data/Inputs/normal_cases/AG 519880_37F"
input_file_path = "../data/Inputs/takotsubo_cases/AG 11370442"
dicom_file_path = os.path.join(input_file_path, "DICOM")
subfolders = sorted([f for f in os.listdir(dicom_file_path) if os.path.isdir(os.path.join(dicom_file_path, f))])

if not subfolders:
    raise FileNotFoundError(f"No subfolders found in {dicom_file_path}")

dicom_file_path = os.path.join(dicom_file_path, subfolders[0])
output_folder = input_file_path.replace("Inputs", "Outputs")

output_file = f"{output_folder}/heart_resized.nii.gz"

image_loader = LoadImage(image_only=True)
original_input_image = image_loader(dicom_file_path)

original_input_image.meta

{'00200037': {'vr': 'DS', 'Value': [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]},
 '00200032': {'vr': 'DS', 'Value': [-156.223236, -166.379486, 21.399994]},
 '00280030': {'vr': 'DS', 'Value': [0.646484, 0.646484, 1.0]},
 'spacing': array([0.646484, 0.646484, 1.      ]),
 'lastImagePositionPatient': array([-156.223236, -166.379486, -246.600006]),
 spatial_shape: (512, 512, 269),
 space: RAS,
 original_affine: array([[ -0.646484,   0.      ,   0.      , 156.223236],
        [  0.      ,  -0.646484,   0.      , 166.379486],
        [  0.      ,   0.      ,  -1.      ,  21.399994],
        [  0.      ,   0.      ,   0.      ,   1.      ]]),
 affine: tensor([[ -0.6465,   0.0000,   0.0000, 156.2232],
         [  0.0000,  -0.6465,   0.0000, 166.3795],
         [  0.0000,   0.0000,  -1.0000,  21.4000],
         [  0.0000,   0.0000,   0.0000,   1.0000]], dtype=torch.float64),
 original_channel_dim: nan,
 'filename_or_obj': '../data/Inputs/takotsubo_cases/AG 11370442/DICOM/1 A STD'}

In [None]:
show_scrollable_image(original_input_image, cmap="grey")

interactive(children=(IntSlider(value=134, description='slice_index', max=268), Output()), _dom_classes=('widg…

In [8]:
# option 1: provide input and output as file paths
output_image = totalsegmentator(dicom_file_path, output_folder, license_number="aca_BWYHC6UQQFDU8A", task="heartchambers_highres", body_seg=True, device="mps")
# output_image = totalsegmentator(dicom_file_path, output_folder, license_number="aca_BWYHC6UQQFDU8A", roi_subset=["heart"], device="mps")


# option 2: provide input and output as nifti image objects
# input_img = nib.load(dicom_file_path)
# output_img = totalsegmentator(input_img)
# print(output_img)
# nib.save(output_img, data_directory)


If you use this tool please cite: https://pubs.rsna.org/doi/10.1148/ryai.230024

Generating rough segmentation for cropping...
Converting dicom to nifti...
  found image with shape (512, 512, 515)
Resampling...
  Resampled in 4.70s
Predicting...


100%|██████████| 1/1 [00:00<00:00,  2.08it/s]


  Predicted in 6.98s
Resampling...
Converting dicom to nifti...
  found image with shape (512, 512, 515)
  cropping from (512, 512, 515) to (252, 211, 248)
Predicting...


100%|██████████| 18/18 [00:35<00:00,  1.95s/it]


  Predicted in 52.07s
Saving segmentations...




Creating heart_myocardium.nii.gz
Creating heart_atrium_left.nii.gz
Creating heart_ventricle_left.nii.gz
Creating heart_ventricle_right.nii.gz
Creating heart_atrium_right.nii.gz
Creating aorta.nii.gz
Creating pulmonary_artery.nii.gz
  Saved in 14.96s


In [None]:
%matplotlib inline

show_scrollable_image(nifti1image_to_metatensor(output_image))

interactive(children=(IntSlider(value=257, description='slice_index', max=514), Output()), _dom_classes=('widg…

In [32]:
# Load the image
left_ventricle_img = image_loader(f"{output_folder}/heart_ventricle_left.nii.gz")

# Check the image type and shape
print(type(left_ventricle_img))
print(left_ventricle_img.shape)

show_scrollable_image(left_ventricle_img)

<class 'monai.data.meta_tensor.MetaTensor'>
torch.Size([512, 512, 515])


interactive(children=(IntSlider(value=257, description='slice_index', max=514), Output()), _dom_classes=('widg…

In [33]:
def resize_to_target_slices(image: MetaTensor, heart_mask: MetaTensor, target_width=512, target_height=512, target_slices=128):
    """
    Extracts the heart region from the input scan and resizes it to a target number of slices, 
    preserving metadata while correctly computing voxel spacing.

    Args:
        image (MetaTensor): The input 3D CT/MRI scan.
        heart_mask (MetaTensor): Binary mask indicating heart region.
        target_slices (int, optional): Number of slices to resize to. Default is 128.

    Returns:
        MetaTensor: Resized image with updated metadata.
    """
    if isinstance(heart_mask, MetaTensor):
        heart_mask = heart_mask.as_tensor().cpu().numpy()
    
    z_indices = np.any(heart_mask > 0, axis=(0, 1))
    heart_slices = np.where(z_indices)[0]
    # print(heart_slices)
    if len(heart_slices) == 0:
        raise ValueError("No heart region found in the mask.")
    
    start_slice = heart_slices[0]
    end_slice = heart_slices[-1] + 1
    
    sliced_image = image[:, :, start_slice:end_slice]
    if not isinstance(sliced_image, MetaTensor):
        sliced_image = MetaTensor(sliced_image, meta=image.meta)
    
    height, width, original_depth = image.shape
    # sliced_depth = sliced_image.shape[-1]

    sliced_image = sliced_image.unsqueeze(0)  # Shape becomes (1, H, W, Z)
    
    resizer = Resize(spatial_size=(target_width,target_height,target_slices), mode="trilinear", align_corners=True)
    resized_image = resizer(sliced_image)
    
    resized_image = resized_image.squeeze(0)

    # Update metadata
    new_meta = image.meta.copy()  
    
    original_spacing = image.meta.get("spacing", (1.0, 1.0, 1.0))
    new_spacing = (
        original_spacing[0],
        original_spacing[1],
        original_spacing[2] * (original_depth / target_slices)
    )
    new_meta["spacing"] = new_spacing

    # Keep affine transformation matrix but adjust slice resolution
    if "affine" in new_meta:
        new_meta["affine"][-1, -1] *= (original_depth / target_slices)  # Adjust Z scaling

    # Return resized MetaTensor with updated metadata
    return MetaTensor(resized_image, meta=new_meta)

    



In [34]:
full_image = nifti1image_to_metatensor(output_image)
resized_image = resize_to_target_slices(full_image, full_image)

print("Resized Shape:", resized_image.shape)  # (H, W, 128)
print("Updated Voxel Spacing:", resized_image.meta["spacing"])
print("Updated Affine Matrix:\n", resized_image.meta["affine"])

Resized Shape: torch.Size([512, 512, 128])
Updated Voxel Spacing: (0.738281, 0.738281, 2.5146484375)
Updated Affine Matrix:
 tensor([[  -0.7383,    0.0000,    0.0000,  182.2000],
        [  -0.0000,    0.7383,    0.0000, -193.0616],
        [   0.0000,   -0.0000,    0.6250, -262.7000],
        [   0.0000,    0.0000,    0.0000,    4.0234]], dtype=torch.float64)


In [35]:
print(resized_image.shape)

torch.Size([512, 512, 128])


In [36]:
%matplotlib inline

show_scrollable_image(resized_image, cmap="grey")

interactive(children=(IntSlider(value=63, description='slice_index', max=127), Output()), _dom_classes=('widge…

In [37]:
def custom_name_formatter(metadict: dict, saver: Transform) -> dict:
    """Returns a kwargs dict for :py:meth:`FolderLayout.filename`,
    according to the input metadata and SaveImage transform."""
    subject = "heart_resized"
    patch_index = None
    return {"subject": f"{subject}", "idx": patch_index}

In [38]:
output_image_meta_tensor = nifti1image_to_metatensor(output_image)
resized_image = resize_to_target_slices(output_image_meta_tensor, output_image_meta_tensor,target_height=256,target_width=256,target_slices=64)

print("saving")
image_saver = SaveImage(output_dir=f"{output_folder}", separate_folder=False, output_postfix="", output_name_formatter=custom_name_formatter)
image_saver(resized_image)
print("saved")

saving
2025-02-26 15:55:42,899 INFO image_writer.py:197 - writing: ../data/Outputs/normal_cases/AG 519880_37F/heart_resized.nii.gz
saved


In [39]:
def align_segmentation_to_image(segmentation):
    """
    Aligns the TotalSegmentator output (segmentation) with the original image
    by flipping the z-axis of the segmentation mask.
    
    Parameters:
        image (torch.Tensor or np.ndarray): Original image (H x W x D).
        segmentation (torch.Tensor or np.ndarray): TotalSegmentator output (H x W x D).
        
    Returns:
        torch.Tensor: Segmentation aligned with the image.
    """
   
    # Flip along the z-axis (last dimension)
    aligned_segmentation = torch.flip(segmentation, dims=[-1])
    
    return aligned_segmentation

In [40]:
aligned_mask = align_segmentation_to_image(full_image)

In [41]:
%matplotlib inline

show_scrollable_image(aligned_mask)

interactive(children=(IntSlider(value=257, description='slice_index', max=514), Output()), _dom_classes=('widg…

In [42]:
original_image_resized = resize_to_target_slices(original_input_image, aligned_mask)

In [43]:
image_loader = LoadImage(image_only=True, ensure_channel_first=True)
loaded_resized_image = image_loader(f"{output_folder}/heart_resized.nii.gz")
loaded_resized_image = resized_image.squeeze(dim=-1)
print(loaded_resized_image.shape)
show_scrollable_image(loaded_resized_image)


torch.Size([256, 256, 64])


interactive(children=(IntSlider(value=31, description='slice_index', max=63), Output()), _dom_classes=('widget…