In [None]:
import os
import SimpleITK as sitk
import numpy as np
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split

import torchvision.transforms as T
import torch
from torch.utils.data import Dataset
from torchvision import datasets
import torchvision.transforms.functional as TF
import random

In [2]:
def load_correct_study(patient_path):
    for root, dirs, files in os.walk(patient_path):
        #print(root, dirs, files)
        dcm_files = [f for f in files if f.endswith(".dcm")]
        if len(dcm_files) == 60:
            return os.path.join(root)
    return None

In [3]:
def load_patient_volume(patient_folder):
    """
    Load the 60-slice DICOM volume for a patient.
    Returns:
        volume_np: (Z,H,W) float32 numpy array
    """
    # Step 1: find the folder with exactly 60 .dcm files
    subfolders = [os.path.join(patient_folder, f) for f in os.listdir(patient_folder) 
                  if os.path.isdir(os.path.join(patient_folder, f))]
    
    study_folder = load_correct_study(patient_folder)
    if study_folder is None:
        #print(f"No 60-slice folder found in {patient_folder}")
        return None
    
    # Step 2: read slices
    dcm_files = sorted([os.path.join(study_folder,f) for f in os.listdir(study_folder) if f.lower().endswith('.dcm')])
    slices = []
    for f in dcm_files:
        img = sitk.ReadImage(f)
        arr = sitk.GetArrayFromImage(img)[0]  # (1,H,W) -> (H,W)
        slices.append(arr.astype(np.float32))
    
    # Step 3: stack into volume
    volume_np = np.stack(slices, axis=0)  # (Z,H,W)
    
    return volume_np

In [4]:
def generate_consecutive_triplets(volume):
    """
    Generate overlapping triplets: (slice[i], slice[i+2]) -> slice[i+1]
    Returns lists of numpy arrays: pre_slices, post_slices, middle_slices
    """
    pre_slices = []
    post_slices = []
    middle_slices = []
    
    for i in range(volume.shape[0]-2):
        pre_slices.append(volume[i])
        post_slices.append(volume[i+2])
        middle_slices.append(volume[i+1])
        
    for i in range(volume.shape[0]-4):
        pre_slices.append(volume[i])
        post_slices.append(volume[i+4])
        middle_slices.append(volume[i+2])
        
    return pre_slices, post_slices, middle_slices

In [5]:
parent_directory = os.path.dirname(os.getcwd())
parent_directory.split('/')[-1]

'Multi-Image-Super-Resolution'

In [6]:
def generate_progressive_triplets(volume):
    """
    volume: (Z,H,W)
    Returns: lists of pre_slices, post_slices, middle_slices
    """
    triplets = []
    
    def recursive_triplets(indices):
        if len(indices) < 2:
            return
        start = indices[0]
        end = indices[1]
        if(start>end or end-start<=2):
            return
        mid = (start + end) // 2
        if (mid != start and mid != end):
            # add triplet: start & end -> mid
            triplets.append((volume[start], volume[end], volume[mid]))
            # recurse on left and right halves
            recursive_triplets([start, mid])
            recursive_triplets([mid, end])
    
    recursive_triplets([0, volume.shape[0]-1])
    
    # Unpack triplets
    pre_slices = [t[0] for t in triplets]
    post_slices = [t[1] for t in triplets]
    middle_slices = [t[2] for t in triplets]
    
    return pre_slices, post_slices, middle_slices

In [7]:
BASE_DIR = os.path.join(parent_directory, 'data','manifest-1694710246744','Prostate-MRI-US-Biopsy')
#print(sorted(os.listdir(BASE_DIR)))
patient_folders = sorted([f for f in os.listdir(BASE_DIR) if f.startswith("Prostate-MRI-US-Biopsy-")])
print(len(patient_folders))

842


In [8]:
train_path, test_val_path = train_test_split(patient_folders, test_size = 0.3, random_state = 42, shuffle = True)

In [9]:
len(train_path)

589

In [10]:
len(test_val_path)

253

In [11]:
val_path, test_path = train_test_split(test_val_path, test_size = 0.6, random_state = 42, shuffle = True)
len(val_path)

101

In [12]:
len(test_path)

152

In [13]:
count_no_slices_train = []
count_no_slices_test = []
count_no_slices_val = []

In [14]:
def get_data(folders, split):
    all_pre, all_post, all_middle = [], [], []
    
    for pid in folders:
        patient_path = os.path.join(BASE_DIR, pid)
        volume = load_patient_volume(patient_path)
        if volume is None:
            if(split == 'train'):
                count_no_slices_train.append(patient_path)
            elif(split == 'test'):
                count_no_slices_test.append(patient_path)
            else:
                count_no_slices_val.append(patient_path)
            continue
        
        pre, post, middle = generate_consecutive_triplets(volume)
        all_pre.extend(pre)
        all_post.extend(post)
        all_middle.extend(middle)
        
    return all_pre, all_post, all_middle

In [15]:
train_pre, train_post, train_middle = get_data(train_path, split = 'train')
val_pre, val_post, val_middle = get_data(val_path, split = 'val')
test_pre, test_post, test_middle = get_data(test_path, split = 'test')

In [16]:
print(len(count_no_slices_train), len(count_no_slices_test), len(count_no_slices_val))

64 22 9


In [17]:
class TripletSliceDataset(Dataset):
    def __init__(self, pre_slices, post_slices, middle_slices, transform=None):
        assert len(pre_slices) == len(post_slices) == len(middle_slices)

        self.pre = pre_slices
        self.post = post_slices
        self.mid = middle_slices
        self.transform = transform

    def __len__(self):
        return len(self.pre)

    def normalize(self, x):
        x = x.astype(np.float32)
        mean = x.mean()
        std = x.std()
        if std < 1e-6: std = 1e-6
        return (x - mean) / std

    def __getitem__(self, idx):
        pre  = self.normalize(self.pre[idx])
        post = self.normalize(self.post[idx])
        mid  = self.normalize(self.mid[idx])

        # Convert to tensor CHW
        pre  = torch.tensor(pre).unsqueeze(0)
        post = torch.tensor(post).unsqueeze(0)
        mid  = torch.tensor(mid).unsqueeze(0)

        sample = {"pre": pre, "post": post, "mid": mid}

        # Apply paired transforms
        if self.transform:
            sample = self.transform(sample)

        return (sample["pre"], sample["post"]), sample["mid"]

In [18]:
class PairedTransforms:
    def __init__(self):
        pass

    def __call__(self, sample):
        pre, post, mid = sample["pre"], sample["post"], sample["mid"]

        # Random horizontal flip
        if random.random() < 0.5:
            pre = TF.hflip(pre)
            post = TF.hflip(post)
            mid = TF.hflip(mid)

        # Random vertical flip
        if random.random() < 0.5:
            pre = TF.vflip(pre)
            post = TF.vflip(post)
            mid = TF.vflip(mid)

        # Small rotation
        angle = random.uniform(-5, 5)
        pre = TF.rotate(pre, angle)
        post = TF.rotate(post, angle)
        mid = TF.rotate(mid, angle)

        return {"pre": pre, "post": post, "mid": mid}


In [None]:
transform = PairedTransforms()

dataset = TripletSliceDataset(
    pre_slices, post_slices, middle_slices,
    transform=None
)

dataloader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

In [1]:
import ModelDataGenerator as MDL

In [2]:
train_loader = MDL.build_dataloader(
        split="train",
        batch_size=32,
        augment=True
)

In [4]:
len(train_loader)

2284

In [5]:
for i,batch in enumerate(train_loader):
    print(i)
    (pre, post), mid = batch
    if(pre.shape != post.shape or pre.shape != mid.shape):
        print('yes')

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32


KeyboardInterrupt: 

In [None]:
len(train_loader)