# Training a U-net Model from Scratch 

# Loading The Data

In [166]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision import models
from torch.nn.functional import relu
import pydicom
from PIL import Image
import numpy as np
from torchsummary import summary
import torch.nn.functional as F



First, I need to make a data structure that contains pairs of an image along with its binary masks. 


In [112]:
root_dir = "mask_and_mri"

data = []  # List to store image-masks pairs

# Iterate over patient directories in mask_and_mri
for patient_dir in os.listdir(root_dir):
    patient_path = os.path.join(root_dir, patient_dir)
    if os.path.isdir(patient_path):
        images_dir = os.path.join(patient_path, "images")
        masks_dir = os.path.join(patient_path, "masks")
        
        # Iterate over image files
        for image_file in os.listdir(images_dir):
            if image_file.endswith(".dcm"):
                image_path = os.path.join(images_dir, image_file)
                
                # Extract image ID
                image_id = image_file[:-4]  # Remove extension
                
                # Find corresponding masks
                masks = []
                for mask_file in os.listdir(masks_dir):
                    if image_id in mask_file:
                        mask_path = os.path.join(masks_dir, mask_file)
                        # makes a one to one mapping for masks and their image
                        data.append((mask_path,image_path))
                

# Print first few entries for verification
for i in range(30):
    print(data[i])
    
print(len(data))

('mask_and_mri\\SC-HF-I-01\\masks\\IM-0001-0048_icontour_1_mask.png', 'mask_and_mri\\SC-HF-I-01\\images\\IM-0001-0048.dcm')
('mask_and_mri\\SC-HF-I-01\\masks\\IM-0001-0059_icontour_1_mask.png', 'mask_and_mri\\SC-HF-I-01\\images\\IM-0001-0059.dcm')
('mask_and_mri\\SC-HF-I-01\\masks\\IM-0001-0059_ocontour_1_mask.png', 'mask_and_mri\\SC-HF-I-01\\images\\IM-0001-0059.dcm')
('mask_and_mri\\SC-HF-I-01\\masks\\IM-0001-0068_icontour_1_mask.png', 'mask_and_mri\\SC-HF-I-01\\images\\IM-0001-0068.dcm')
('mask_and_mri\\SC-HF-I-01\\masks\\IM-0001-0079_icontour_1_mask.png', 'mask_and_mri\\SC-HF-I-01\\images\\IM-0001-0079.dcm')
('mask_and_mri\\SC-HF-I-01\\masks\\IM-0001-0079_ocontour_1_mask.png', 'mask_and_mri\\SC-HF-I-01\\images\\IM-0001-0079.dcm')
('mask_and_mri\\SC-HF-I-01\\masks\\IM-0001-0088_icontour_1_mask.png', 'mask_and_mri\\SC-HF-I-01\\images\\IM-0001-0088.dcm')
('mask_and_mri\\SC-HF-I-01\\masks\\IM-0001-0099_icontour_1_mask.png', 'mask_and_mri\\SC-HF-I-01\\images\\IM-0001-0099.dcm')
('mask_a

In [113]:
# one mask to one image
mask, image = data[0]
print(mask)
print(image)
data[0]



mask_and_mri\SC-HF-I-01\masks\IM-0001-0048_icontour_1_mask.png
mask_and_mri\SC-HF-I-01\images\IM-0001-0048.dcm


('mask_and_mri\\SC-HF-I-01\\masks\\IM-0001-0048_icontour_1_mask.png',
 'mask_and_mri\\SC-HF-I-01\\images\\IM-0001-0048.dcm')

In [199]:

class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        mask_path, image_path = self.data[idx]
        
        # Load image
        image = pydicom.dcmread(image_path).pixel_array.astype(np.float32)
        image /= 255.0
        
        # Load mask
        mask = np.array(Image.open(mask_path)).astype(np.float32)

        # Convert to tensor
        image_tensor = torch.from_numpy(image)
        mask_tensor = torch.from_numpy(mask)
        
        return image_tensor, mask_tensor

dataset = CustomDataset(data)

# Create a DataLoader using your custom dataset
data_loader = DataLoader(dataset, batch_size=4, shuffle=True)

num_batches_to_print = 3
# Iterate over the DataLoader to inspect its contents
for batch_idx, (mask, image) in enumerate(data_loader):
    print(f"Batch {batch_idx + 1}:")
    print(f"Number of images: {len(image)}")
    print(f"Number of masks: {len(mask)}")
    print(f"Image shape: {image[0].shape}")  # Assuming images are of the same size
    print(f"Mask shape: {mask[0].shape}\n\n")    # Assuming masks are of the same size
    if batch_idx + 1 >= num_batches_to_print:
        break
    
print(len(data))
print(len(data_loader))


Batch 1:
Number of images: 4
Number of masks: 4
Image shape: torch.Size([256, 256])
Mask shape: torch.Size([256, 256])


Batch 2:
Number of images: 4
Number of masks: 4
Image shape: torch.Size([256, 256])
Mask shape: torch.Size([256, 256])


Batch 3:
Number of images: 4
Number of masks: 4
Image shape: torch.Size([256, 256])
Mask shape: torch.Size([256, 256])


495
124


# The Model

In [201]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    def __init__(self, n_class=1):
        super(UNet, self).__init__()
        
        # Encoder
        self.e11 = nn.Conv2d(1, 64, kernel_size=3, padding=1) # Adjusted for 1 input channel
        self.e12 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.e21 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.e22 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.e31 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.e32 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.e41 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.e42 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottom
        self.e51 = nn.Conv2d(512, 1024, kernel_size=3, padding=1)
        self.e52 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1)

        # Decoder
        self.upconv1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.d11 = nn.Conv2d(1024, 512, kernel_size=3, padding=1)
        self.d12 = nn.Conv2d(512, 512, kernel_size=3, padding=1)

        self.upconv2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.d21 = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.d22 = nn.Conv2d(256, 256, kernel_size=3, padding=1)

        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.d31 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.d32 = nn.Conv2d(128, 128, kernel_size=3, padding=1)

        self.upconv4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.d41 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.d42 = nn.Conv2d(64, 64, kernel_size=3, padding=1)

        # Output layer
        self.outconv = nn.Conv2d(64, n_class, kernel_size=1)

    def forward(self, x):
        # Encoder
        xe11 = F.relu(self.e11(x))
        xe12 = F.relu(self.e12(xe11))
        xp1 = self.pool1(xe12)

        xe21 = F.relu(self.e21(xp1))
        xe22 = F.relu(self.e22(xe21))
        xp2 = self.pool2(xe22)

        xe31 = F.relu(self.e31(xp2))
        xe32 = F.relu(self.e32(xe31))
        xp3 = self.pool3(xe32)

        xe41 = F.relu(self.e41(xp3))
        xe42 = F.relu(self.e42(xe41))
        xp4 = self.pool4(xe42)

        xe51 = F.relu(self.e51(xp4))
        xe52 = F.relu(self.e52(xe51))
        
        # Decoder
        xu1 = self.upconv1(xe52)
        xu11 = torch.cat([xu1, xe42], dim=1)
        xd11 = F.relu(self.d11(xu11))
        xd12 = F.relu(self.d12(xd11))

        xu2 = self.upconv2(xd12)
        xu22 = torch.cat([xu2, xe32], dim=1)
        xd21 = F.relu(self.d21(xu22))
        xd22 = F.relu(self.d22(xd21))

        xu3 = self.upconv3(xd22)
        xu33 = torch.cat([xu3, xe22], dim=1)
        xd31 = F.relu(self.d31(xu33))
        xd32 = F.relu(self.d32(xd31))

        xu4 = self.upconv4(xd32)
        xu44 = torch.cat([xu4, xe12], dim=1)
        xd41 = F.relu(self.d41(xu44))
        xd42 = F.relu(self.d42(xd41))

        # Output layer
        out = self.outconv(xd42)

        return out


In [202]:
# pass an image through the model to see if it errors

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet() # 1 image, 1 class to predict
model.to(device)
criterion = nn.BCEWithLogitsLoss()  # Binary Cross-Entropy Loss
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # Adam optimizer
input_shape = (1, 256, 256)
summary(model,input_shape)

Layer (type:depth-idx)                   Output Shape              Param #
├─Conv2d: 1-1                            [-1, 64, 256, 256]        640
├─Conv2d: 1-2                            [-1, 64, 256, 256]        36,928
├─MaxPool2d: 1-3                         [-1, 64, 128, 128]        --
├─Conv2d: 1-4                            [-1, 128, 128, 128]       73,856
├─Conv2d: 1-5                            [-1, 128, 128, 128]       147,584
├─MaxPool2d: 1-6                         [-1, 128, 64, 64]         --
├─Conv2d: 1-7                            [-1, 256, 64, 64]         295,168
├─Conv2d: 1-8                            [-1, 256, 64, 64]         590,080
├─MaxPool2d: 1-9                         [-1, 256, 32, 32]         --
├─Conv2d: 1-10                           [-1, 512, 32, 32]         1,180,160
├─Conv2d: 1-11                           [-1, 512, 32, 32]         2,359,808
├─MaxPool2d: 1-12                        [-1, 512, 16, 16]         --
├─Conv2d: 1-13                           [-1, 1

Layer (type:depth-idx)                   Output Shape              Param #
├─Conv2d: 1-1                            [-1, 64, 256, 256]        640
├─Conv2d: 1-2                            [-1, 64, 256, 256]        36,928
├─MaxPool2d: 1-3                         [-1, 64, 128, 128]        --
├─Conv2d: 1-4                            [-1, 128, 128, 128]       73,856
├─Conv2d: 1-5                            [-1, 128, 128, 128]       147,584
├─MaxPool2d: 1-6                         [-1, 128, 64, 64]         --
├─Conv2d: 1-7                            [-1, 256, 64, 64]         295,168
├─Conv2d: 1-8                            [-1, 256, 64, 64]         590,080
├─MaxPool2d: 1-9                         [-1, 256, 32, 32]         --
├─Conv2d: 1-10                           [-1, 512, 32, 32]         1,180,160
├─Conv2d: 1-11                           [-1, 512, 32, 32]         2,359,808
├─MaxPool2d: 1-12                        [-1, 512, 16, 16]         --
├─Conv2d: 1-13                           [-1, 1

In [203]:
num_epochs = 100
loss_threshold = .15 

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.1, verbose=True)


# Training loop
for epoch in range(num_epochs):
    model.train()  # Set the model to train mode
    running_loss = 0.0
    
    for mask, image in data_loader:
        optimizer.zero_grad()  # Zero the gradients
        image = image.to(device)
        mask = mask.to(device)
        image = image.unsqueeze(1)  # Add channel dimension
        mask = mask.unsqueeze(1)
        # Forward pass
        outputs = model(image)
        
        # Compute loss
        loss = criterion(outputs, mask)
        
        # Backpropagation
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * image.size(0)
    
    epoch_loss = running_loss / len(data_loader.dataset)
    scheduler.step(epoch_loss)  # Adjust learning rate based on epoch_loss
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")
    if epoch_loss < loss_threshold:
        print(f"Stopping early as loss is below the threshold of {loss_threshold:.4f}.")
        break

Epoch [1/100], Loss: 0.6343
Epoch [2/100], Loss: 0.6189
Epoch [3/100], Loss: 0.6216
Epoch [4/100], Loss: 0.6143
Epoch [5/100], Loss: 0.5995
Epoch [6/100], Loss: 0.5572
Epoch [7/100], Loss: 0.5259
Epoch [8/100], Loss: 0.5241
Epoch [9/100], Loss: 0.5157
Epoch [10/100], Loss: 0.5097
Epoch [11/100], Loss: 0.5086
Epoch [12/100], Loss: 0.5101
Epoch [13/100], Loss: 0.5079
Epoch [14/100], Loss: 0.5051
Epoch [15/100], Loss: 0.5041
Epoch [16/100], Loss: 0.5024
Epoch [17/100], Loss: 0.4933
Epoch [18/100], Loss: 0.4875
Epoch [19/100], Loss: 0.4837
Epoch [20/100], Loss: 0.4818
Epoch [21/100], Loss: 0.4816
Epoch [22/100], Loss: 0.4804
Epoch [23/100], Loss: 0.4808
Epoch [24/100], Loss: 0.4766
Epoch [25/100], Loss: 0.4757
Epoch [26/100], Loss: 0.4730
Epoch [27/100], Loss: 0.4739
Epoch [28/100], Loss: 0.4748
Epoch [29/100], Loss: 0.4719
Epoch [30/100], Loss: 0.4748
Epoch [31/100], Loss: 0.4750
Epoch [32/100], Loss: 0.4738
Epoch [33/100], Loss: 0.4759
Epoch [34/100], Loss: 0.4759
Epoch [35/100], Loss: 0