# DLMI Fall 2025 
# Project 02: Segmentation

## Setup

In [None]:
import torch
import torchvision
import SimpleITK as sitk
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from tqdm import tqdm
import glob
import os
import json

In [None]:
# Candidate locations for the course data, checked in the order listed below;
# the first one that exists as a directory becomes `data_dir`.
engr_dir = "/opt/nfsopt/DLMI"
idas_dir = os.path.join(os.path.expanduser('~'), "classdata")

for candidate_dir in (engr_dir, idas_dir):
    if os.path.isdir(candidate_dir):
        data_dir = candidate_dir
        break
else:
    # Neither location exists: `data_dir` is left undefined, only a message is shown.
    print("Data directory not found")


In [None]:
# Prefer CUDA, then Apple MPS, and fall back to the CPU when neither is available.
# The MPS query is only evaluated when CUDA is unavailable.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
print(device)

## Provided Example Code

In [None]:
# Label value -> class name for the three named classes.
classes = {
    1: 'parenchyma',
    2: 'tumor',
    3: 'cyst',
}

# Four colours in list order; the first entry ('black') has no counterpart in `classes`.
colors = ['black', 'cyan', 'magenta', 'yellow']
cmap = ListedColormap(colors)

### JSON file with Custom Dataset

Read the JSON file describing the dataset. At the top level, the data is a dictionary with 3 items corresponding to the training, validation and testing datasets. Each item is a list of all samples in that dataset. Each sample is a dictionary of the filenames for the image and label.

In [None]:
# Parse the dataset description into `data`.
with open('seg_dataset.json') as json_file:
    data = json.load(json_file)

# The training dataset: a list of dictionaries.
data['training']

# The first sample of the training dataset: a single dictionary.
data['training'][0]

In the `KidneyDataset` defined below, `self.data` is expected to be a **list** of **dictionaries**. Indexing an element of the list (`self.data[idx]`) returns the dictionary for sample `idx`: its keys are "image" and "label", and its values are the corresponding filenames for that sample.

In [None]:
class KidneyDataset(torch.utils.data.Dataset):
    """Dataset whose samples are dictionaries of file paths.

    Each entry of ``data`` is a dictionary whose values are paths relative
    to ``root_dir``. Indexing the dataset joins every value with
    ``root_dir`` and, when a transform was supplied, passes the resulting
    dictionary through it.
    """

    def __init__(self, data, root_dir, transform=None):
        """Store the sample list, the root directory and the transform.

        Args:
            data: Indexable collection of dictionaries mapping a key to a
                path relative to ``root_dir``.
            root_dir: Directory prepended to every relative path.
            transform: Optional callable that receives the dictionary of
                joined paths and returns the object handed back by
                ``__getitem__``. ``None`` disables the transform.
        """
        self.data = data
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of samples in ``data``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample ``idx`` with its paths joined to ``root_dir``.

        Args:
            idx: Index into ``data``.

        Returns:
            The dictionary of joined paths, or the result of calling the
            transform on that dictionary when a transform was supplied.
        """
        sample = {
            key: os.path.join(self.root_dir, rel_path)
            for key, rel_path in self.data[idx].items()
        }

        # Test against None instead of truthiness: a callable that defines
        # __len__ or __bool__ (for example an empty container of transforms)
        # can be falsy and would otherwise be skipped silently.
        if self.transform is not None:
            sample = self.transform(sample)

        return sample

Below, an instance of the `KidneyDataset` class is created for the training dataset by passing the "training" item of the JSON data structure. Accessing a single element of the dataset returns a dictionary of the filenames for a sample (since no transforms are provided).

In [None]:
# No transform is passed, so each item is the dictionary of joined file paths.
training_dataset = KidneyDataset(data=data["training"], root_dir=data_dir)

# Fetch the first training sample and display it.
sample = training_dataset[0]

sample

### Example transform for multiple images (image and segmentation)

Below is an example of a custom transform that handles both images (CT and segmentation). Here, `sample` is expected to be a dictionary with the filenames for the image and segmentation.

In [None]:
class Read:
    """Transform that reads the files named under selected keys of a sample."""

    def __init__(self, keys):
        """Remember which keys of the sample dictionary should be read."""
        self.keys = keys

    def __call__(self, sample):
        """Return a new dictionary with ``sitk.ReadImage`` applied to each selected entry.

        Only the keys given at construction appear in the result.
        """
        return {key: sitk.ReadImage(sample[key]) for key in self.keys}

Below a training dataset is created with the custom `Read` transform. Now, accessing a sample will return a dictionary with the SimpleITK image objects.

In [None]:
# Rebuild the training dataset so that the "image" and "label" files are read on access.
transform = Read(keys=["image", "label"])
training_dataset = KidneyDataset(data["training"], data_dir, transform=transform)

# Fetch the first training sample (now produced by the transform) and display it.
sample = training_dataset[0]

sample

### Visualize sample CT with segmentation overlaid

In [None]:
# One panel per array axis: each shows a single slice of the CT with the labels on top.
fig, axs = plt.subplots(1, 3)

im = sample["image"]     # SimpleITK image
label = sample["label"]  # SimpleITK image

imnp = sitk.GetArrayFromImage(im)
labelnp = sitk.GetArrayFromImage(label)

# Mask out label 0 so that only non-zero labels are drawn over the CT.
labelnp = np.ma.masked_where(labelnp == 0, labelnp)
shape = imnp.shape

alpha = 0.5

# Index tuples picking one slice along the first, second and third array axis.
whole_axis = slice(None)
slice_indices = (
    (shape[0] // 3, whole_axis, whole_axis),
    (whole_axis, 2 * shape[1] // 3, whole_axis),
    (whole_axis, whole_axis, shape[2] // 2),
)

for ax, slice_index in zip(axs, slice_indices):
    ax.imshow(imnp[slice_index], cmap='gray', vmin=-300, vmax=300)
    ax.imshow(labelnp[slice_index], cmap=cmap, vmin=0, vmax=3, alpha=alpha)
    ax.axis('off')

plt.tight_layout()

## Data Exploration

## Custom Dataset and Transforms

## DataLoader

## Model

## Training

## Evaluation

## Visualization