Import:

In [108]:
import glob
import os
from monai.transforms import (
    Compose, EnsureChannelFirstd, EnsureTyped, LoadImaged,
    RandSpatialCropd, RandRotate90d, ResizeWithPadOrCropd)
import torch
import nibabel as nib
import numpy as np
import demo
from matplotlib import cm
from vedo import load, Plotter

In [109]:
total = {
        1: "spleen",
        2: "kidney_right",
        3: "kidney_left",
        4: "gallbladder",
        5: "liver",
        6: "stomach",
        7: "pancreas",
        8: "adrenal_gland_right",
        9: "adrenal_gland_left",
        10: "lung_upper_lobe_left",
        11: "lung_lower_lobe_left",
        12: "lung_upper_lobe_right",
        13: "lung_middle_lobe_right",
        14: "lung_lower_lobe_right",
        15: "esophagus",
        16: "trachea",
        17: "thyroid_gland",
        18: "small_bowel",
        19: "duodenum",
        20: "colon",
        21: "urinary_bladder",
        22: "prostate",
        23: "kidney_cyst_left",
        24: "kidney_cyst_right",
        25: "sacrum",
        26: "vertebrae_S1",
        27: "vertebrae_L5",
        28: "vertebrae_L4",
        29: "vertebrae_L3",
        30: "vertebrae_L2",
        31: "vertebrae_L1",
        32: "vertebrae_T12",
        33: "vertebrae_T11",
        34: "vertebrae_T10",
        35: "vertebrae_T9",
        36: "vertebrae_T8",
        37: "vertebrae_T7",
        38: "vertebrae_T6",
        39: "vertebrae_T5",
        40: "vertebrae_T4",
        41: "vertebrae_T3",
        42: "vertebrae_T2",
        43: "vertebrae_T1",
        44: "vertebrae_C7",
        45: "vertebrae_C6",
        46: "vertebrae_C5",
        47: "vertebrae_C4",
        48: "vertebrae_C3",
        49: "vertebrae_C2",
        50: "vertebrae_C1",
        51: "heart",
        52: "aorta",
        53: "pulmonary_vein",
        54: "brachiocephalic_trunk",
        55: "subclavian_artery_right",
        56: "subclavian_artery_left",
        57: "common_carotid_artery_right",
        58: "common_carotid_artery_left",
        59: "brachiocephalic_vein_left",
        60: "brachiocephalic_vein_right",
        61: "atrial_appendage_left",
        62: "superior_vena_cava",
        63: "inferior_vena_cava",
        64: "portal_vein_and_splenic_vein",
        65: "iliac_artery_left",
        66: "iliac_artery_right",
        67: "iliac_vena_left",
        68: "iliac_vena_right",
        69: "humerus_left",
        70: "humerus_right",
        71: "scapula_left",
        72: "scapula_right",
        73: "clavicula_left",
        74: "clavicula_right",
        75: "femur_left",
        76: "femur_right",
        77: "hip_left",
        78: "hip_right",
        79: "spinal_cord",
        80: "gluteus_maximus_left",
        81: "gluteus_maximus_right",
        82: "gluteus_medius_left",
        83: "gluteus_medius_right",
        84: "gluteus_minimus_left",
        85: "gluteus_minimus_right",
        86: "autochthon_left",
        87: "autochthon_right",
        88: "iliopsoas_left",
        89: "iliopsoas_right",
        90: "brain",
        91: "skull",
        92: "rib_left_1",
        93: "rib_left_2",
        94: "rib_left_3",
        95: "rib_left_4",
        96: "rib_left_5",
        97: "rib_left_6",
        98: "rib_left_7",
        99: "rib_left_8",
        100: "rib_left_9",
        101: "rib_left_10",
        102: "rib_left_11",
        103: "rib_left_12",
        104: "rib_right_1",
        105: "rib_right_2",
        106: "rib_right_3",
        107: "rib_right_4",
        108: "rib_right_5",
        109: "rib_right_6",
        110: "rib_right_7",
        111: "rib_right_8",
        112: "rib_right_9",
        113: "rib_right_10",
        114: "rib_right_11",
        115: "rib_right_12",
        116: "sternum",
        117: "costal_cartilages"
    }

In [110]:
class Convert_To_Binary: 
    def __call__(self, key): # key - dictionary with keys 'img' and 'label'
        key['label'] = torch.where(key['label']> 0, 1, 0)
        return key 

Setup:

In [111]:
lbl_path = "C:\\Users\\Dell\\Downloads\\Totalsegmentator_dataset_v201\\s0001\\segmentations"

transform = Compose([
            LoadImaged(keys=['label']),
            EnsureChannelFirstd(keys=['label']),
            EnsureTyped(keys=['label']),
            RandSpatialCropd(keys =['label'], roi_size = (128,128,128), random_size = False),
            RandRotate90d(keys=['label'], prob = 0.5, max_k = 3),
            ResizeWithPadOrCropd(keys=['label'], spatial_size = (128,128,128)),
            Convert_To_Binary(),
        ])

In [112]:
def combine_masks(lbl_dir): 
    out_path = os.path.join(lbl_dir, 'combined_mask.nii.gz')  # https://docs.python.org/3/library/os.path.html

    if not os.path.exists(out_path):  # check if the combined mask already exists to avoid recalculation
        mask_files = glob.glob(os.path.join(lbl_dir, '*.nii.gz')) # get all nii.gz files (mask files)
        combined_mask = None


        for class_id, organ_name in total.items():
            struct_path = os.path.join(lbl_dir, f"{organ_name}.nii.gz")
            if os.path.exists(struct_path):    
                mask = nib.load(struct_path).get_fdata() # https://nipy.org/nibabel/images_and_memory.html - load mask data            
                if combined_mask is None:
                    combined_mask = np.zeros_like(mask) # https://numpy.org/doc/stable/reference/generated/numpy.zeros_like.html - generating empty mask
                combined_mask[mask > 0] = class_id # any non-zero value in the current mask will be assigned to the class label

        if combined_mask is not None:
            affine = nib.load(mask_files[0]).affine # loading affine matrix: https://medium.com/@junfeng142857/affine-transformation-why-3d-matrix-for-a-2d-transformation-8922b08bce75
            combined_mask_img = nib.Nifti1Image(combined_mask.astype(np.uint8), affine)
            nib.save(combined_mask_img, out_path)

        else:
            combined_mask = np.zeros((128,128,128), dtype=np.uint8) # https://numpy.org/doc/2.1/reference/generated/numpy.zeros.html    
            nifti = nib.Nifti1Image(combined_mask, np.eye(4))
            nib.save(nifti, out_path)
            
    return out_path

Reading labels by the dataloader. The code is copied from the dataloader.

In [113]:
lbl_files = glob.glob(os.path.join(lbl_path, '*.nii.gz'))
print(lbl_files)

cmb_lbl_path = combine_masks(lbl_path)
lbl_files = cmb_lbl_path

data = {'label': lbl_files }

if transform is not None: # if transform is provided
    print(f"Before transform: label path: {lbl_files}")
    data = transform(data) # applying any transforms
    print(f"After transform: label shape: {data['label'].shape}")

['C:\\Users\\Dell\\Downloads\\Totalsegmentator_dataset_v201\\s0001\\segmentations\\adrenal_gland_left.nii.gz', 'C:\\Users\\Dell\\Downloads\\Totalsegmentator_dataset_v201\\s0001\\segmentations\\adrenal_gland_right.nii.gz', 'C:\\Users\\Dell\\Downloads\\Totalsegmentator_dataset_v201\\s0001\\segmentations\\aorta.nii.gz', 'C:\\Users\\Dell\\Downloads\\Totalsegmentator_dataset_v201\\s0001\\segmentations\\atrial_appendage_left.nii.gz', 'C:\\Users\\Dell\\Downloads\\Totalsegmentator_dataset_v201\\s0001\\segmentations\\autochthon_left.nii.gz', 'C:\\Users\\Dell\\Downloads\\Totalsegmentator_dataset_v201\\s0001\\segmentations\\autochthon_right.nii.gz', 'C:\\Users\\Dell\\Downloads\\Totalsegmentator_dataset_v201\\s0001\\segmentations\\brachiocephalic_trunk.nii.gz', 'C:\\Users\\Dell\\Downloads\\Totalsegmentator_dataset_v201\\s0001\\segmentations\\brachiocephalic_vein_left.nii.gz', 'C:\\Users\\Dell\\Downloads\\Totalsegmentator_dataset_v201\\s0001\\segmentations\\brachiocephalic_vein_right.nii.gz', 'C:\\

That was converting labels to tensor. The code below is meant to reverse that process

In [114]:
def get_distinct_colors():
    colormaps = ['Pastel1', 'Pastel2', 'Paired', 'Accent','Dark2', 
        'Set1', 'Set2', 'Set3','tab10', 'tab20', 'tab20b', 'tab20c']  # https://matplotlib.org/stable/users/explain/colors/colormaps.html
    colors = []
    
    for i in range(117):
        cmap_name = colormaps[i % len(colormaps)]  # cycling through available colormaps defined above.
        cmap = cm.get_cmap(cmap_name)
        normalized_idx = (i % cmap.N) / cmap.N # normalize index.
        rgb_color = cmap(normalized_idx)[:3]  # extract RGB values.
        colors.append(rgb_color)
    
    return colors

In [115]:
lbls = []
lbls.append(data)
labels =  [sample['label'] for sample in lbls]
batched_labels = torch.stack(labels, dim=0)
img_tensor = batched_labels.unsqueeze(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
img_tensor = img_tensor.to(device)

seg_out = img_tensor.squeeze().cpu().numpy()
print(f"Segmentation output shape: {seg_out.shape}")

volume = []

for class_nr, organ_name in total:
    nifti = np.zeros(seg_out.shape)
    nifti[seg_out == class_nr] = 1
    affine = nib.load(nifti).affine
    nifti = nib.Nifti1Image(nifti.astype(np.uint8), affine)
    nib.save(nifti, f"{organ_name}.nii.gz")
    colors_rgb = get_distinct_colors()

    seg_path = f'{organ_name }.stl'
    if os.path.exists(seg_path):
        os.remove(seg_path)
    demo.convert_to_stl(nifti, seg_path)
    vol = load(seg_path).color(colors_rgb[i % len(colors_rgb)])
    volume.append(vol)
    
plotter = Plotter()
plotter.show(volume)

Segmentation output shape: (128, 128, 128)


TypeError: cannot unpack non-iterable int object