# Training a U-net Model from Scratch 

# Loading The Data

In [32]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import pydicom
from PIL import Image
import numpy as np

First, I need to make a data structure that contains pairs of an image along with its binary masks. 


In [4]:
root_dir = "mask_and_mri"

data = []  # List to store image-masks pairs

# Iterate over patient directories in mask_and_mri
for patient_dir in os.listdir(root_dir):
    patient_path = os.path.join(root_dir, patient_dir)
    if os.path.isdir(patient_path):
        images_dir = os.path.join(patient_path, "images")
        masks_dir = os.path.join(patient_path, "masks")
        
        # Iterate over image files
        for image_file in os.listdir(images_dir):
            if image_file.endswith(".dcm"):
                image_path = os.path.join(images_dir, image_file)
                
                # Extract image ID
                image_id = image_file[:-4]  # Remove extension
                
                # Find corresponding masks
                masks = []
                for mask_file in os.listdir(masks_dir):
                    if image_id in mask_file:
                        mask_path = os.path.join(masks_dir, mask_file)
                        masks.append(mask_path)
                
                # Append image-masks pair to the data list
                data.append((image_path, masks))

# Print first few entries for verification
for i in range(30):
    print(data[i])
    
print(len(data))

('mask_and_mri\\SC-HF-I-01\\images\\IM-0001-0048.dcm', ['mask_and_mri\\SC-HF-I-01\\masks\\IM-0001-0048_icontour_1_mask.png'])
('mask_and_mri\\SC-HF-I-01\\images\\IM-0001-0059.dcm', ['mask_and_mri\\SC-HF-I-01\\masks\\IM-0001-0059_icontour_1_mask.png', 'mask_and_mri\\SC-HF-I-01\\masks\\IM-0001-0059_ocontour_1_mask.png'])
('mask_and_mri\\SC-HF-I-01\\images\\IM-0001-0068.dcm', ['mask_and_mri\\SC-HF-I-01\\masks\\IM-0001-0068_icontour_1_mask.png'])
('mask_and_mri\\SC-HF-I-01\\images\\IM-0001-0079.dcm', ['mask_and_mri\\SC-HF-I-01\\masks\\IM-0001-0079_icontour_1_mask.png', 'mask_and_mri\\SC-HF-I-01\\masks\\IM-0001-0079_ocontour_1_mask.png'])
('mask_and_mri\\SC-HF-I-01\\images\\IM-0001-0088.dcm', ['mask_and_mri\\SC-HF-I-01\\masks\\IM-0001-0088_icontour_1_mask.png'])
('mask_and_mri\\SC-HF-I-01\\images\\IM-0001-0099.dcm', ['mask_and_mri\\SC-HF-I-01\\masks\\IM-0001-0099_icontour_1_mask.png', 'mask_and_mri\\SC-HF-I-01\\masks\\IM-0001-0099_ocontour_1_mask.png', 'mask_and_mri\\SC-HF-I-01\\masks\\IM-0

In [9]:
image, masks = data[0]
print(image)
print(masks)
data[0]

mask_and_mri\SC-HF-I-01\images\IM-0001-0048.dcm
['mask_and_mri\\SC-HF-I-01\\masks\\IM-0001-0048_icontour_1_mask.png']


('mask_and_mri\\SC-HF-I-01\\images\\IM-0001-0048.dcm',
 ['mask_and_mri\\SC-HF-I-01\\masks\\IM-0001-0048_icontour_1_mask.png'])

In [45]:

class CustomDataset(Dataset):
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image_path, mask_paths = self.data[idx]
        image = pydicom.dcmread(image_path).pixel_array.astype(np.float32)  # Convert DICOM pixel data to NumPy array

        images = [image] * len(mask_paths)  # Duplicate the image for each mask

        masks = [np.array(Image.open(mask_path)).astype(np.float32) for mask_path in mask_paths]  # Ensure masks are float32

        if self.transform:
            images = [self.transform(img) for img in images]

        image_tensors = [torch.from_numpy(img) for img in images]
        masks_tensors = [torch.from_numpy(mask) for mask in masks]

        return image_tensors, masks_tensors

dataset = CustomDataset(data)
data_loader = DataLoader(dataset, batch_size=4, shuffle=True)
len(data_loader)

# Assuming data_loader is your DataLoader object
for batch_idx, (images, masks) in enumerate(data_loader):
    print(f"Batch {batch_idx + 1}:")
    
    # Print sizes of images in the batch
    print("Image sizes:")
    for img in images:
        print(img.size())
    
    # Print sizes of masks in the batch
    print("Mask sizes:")
    for mask in masks:
        print(mask.size())

RuntimeError: each element in list of batch should be of equal size