In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
from glob import glob
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Training hyperparameters used by the cells that follow.
learning_rate = 1e-3
num_epochs = 50
batch_size = 32
weight_decay = 5e-4

# Seed PyTorch's RNG so that weight initialisation and DataLoader shuffling
# are repeatable after "Restart Kernel -> Run All".
seed = 42
torch.manual_seed(seed)

print(device)

In [None]:
class MRISegmentationDataset(Dataset):
    """Paired image / segmentation-mask dataset read from two folders of PNGs.

    Images and masks are matched purely by position after sorting the file
    paths of each folder, so both folders must contain the same number of
    ``*.png`` files with names that sort into the same order.

    Args:
        img_dir: Folder containing the input ``*.png`` images.
        mask_dir: Folder containing the ``*.png`` segmentation masks.
        transform: Optional callable applied to the grayscale input image.
            When ``None`` the PIL image is returned unchanged.
        target_transform: Optional callable applied to the grayscale mask.
            When ``None`` the PIL image is returned unchanged.

    Raises:
        RuntimeError: If either folder contains no ``*.png`` files.
        AssertionError: If the two folders hold a different number of files.
    """

    def __init__(self, img_dir, mask_dir, transform=None, target_transform=None):
        self.img_paths = sorted(glob(os.path.join(img_dir, "*.png")))
        self.mask_paths = sorted(glob(os.path.join(mask_dir, "*.png")))

        if len(self.img_paths) == 0 or len(self.mask_paths) == 0:
            raise RuntimeError("No images or masks found")

        assert len(self.img_paths) == len(self.mask_paths), \
            f"Number of images ({len(self.img_paths)}) and masks ({len(self.mask_paths)}) do not match"

        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of (image, mask) pairs."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Load, convert to single-channel ("L") and transform the pair at ``idx``."""
        # Image.open keeps the file handle open; convert() produces an
        # independent in-memory copy, so the handle can be released right away.
        with Image.open(self.img_paths[idx]) as img_file:
            img = img_file.convert("L")      # grayscale MRI
        with Image.open(self.mask_paths[idx]) as mask_file:
            mask = mask_file.convert("L")    # segmentation mask (labels)

        # Both transforms default to None in __init__, so only call them when given.
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            mask = self.target_transform(mask)
        return img, mask

# Root folder of the PNG slice dataset. Set the MRI_DATA_DIR environment
# variable to point at a different location; otherwise the original local
# path is used unchanged.
base_dir = os.environ.get(
    "MRI_DATA_DIR",
    "/mnt/c/Users/Acer/Downloads/keras_png_slices_data/keras_png_slices_data",
)

# Image folders and their matching segmentation-mask folders.
train_dir = f"{base_dir}/keras_png_slices_train"
train_dir_seg = f"{base_dir}/keras_png_slices_seg_train"
test_dir  = f"{base_dir}/keras_png_slices_test"
test_dir_seg  = f"{base_dir}/keras_png_slices_seg_test"

# Input pipeline for the MRI slices: convert to a tensor, then standardise
# the single channel with a fixed mean / std.
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.1317], std=[0.1864]),
])


def _mask_to_int64_tensor(mask_img):
    """Return the mask's raw pixel values as an int64 tensor."""
    return torch.from_numpy(np.array(mask_img, dtype=np.int64))


def _grey_levels_to_class_ids(pixels):
    """Map grey levels 85 / 170 / 255 to class ids 1 / 2 / 3.

    Any other value (including 0) is passed through unchanged.
    """
    labels = pixels.clone()
    for grey_level, class_id in ((85, 1), (170, 2), (255, 3)):
        # Compare against the untouched input so remapped values are never re-matched.
        labels[pixels == grey_level] = class_id
    return labels


# Masks are not normalised: they are only turned into a tensor of class ids.
mask_transform = transforms.Compose([
    transforms.Lambda(_mask_to_int64_tensor),
    transforms.Lambda(_grey_levels_to_class_ids),
])

# Build one dataset per split; both splits share the same image and mask pipelines.
train_dataset = MRISegmentationDataset(
    img_dir=train_dir,
    mask_dir=train_dir_seg,
    transform=img_transform,
    target_transform=mask_transform,
)
test_dataset = MRISegmentationDataset(
    img_dir=test_dir,
    mask_dir=test_dir_seg,
    transform=img_transform,
    target_transform=mask_transform,
)

# Only the training split is shuffled; the test split keeps its sorted order.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Number of individual samples per split.
print("Train samples:", len(train_dataset))
print("Test samples:", len(test_dataset))

# Number of batches per split.
print("Train loader:", len(train_loader))
print("Test loader:", len(test_loader))

# Pull a single training batch as a quick shape check.
train_batches = iter(train_loader)
imgs, masks = next(train_batches)
print("Image batch:", imgs.shape)   # (B,1,H,W)
print("Mask batch:", masks.shape)   # (B,H,W)

In [None]:
# Fetch a single (image, mask) example directly from the dataset, bypassing the loader.
img, mask = test_dataset[0]

# Basic structure of the sample.
print("Image shape:", img.shape)   # e.g. (1, 256, 256)
print("Mask shape:", mask.shape)   # e.g. (256, 256)
print("Mask dtype:", mask.dtype)   # should be torch.int64

# Which label ids occur in this mask, plus a 5x5 peek at its top-left corner.
print("Unique label values in this mask:", mask.unique())
print("Top-left corner of mask:\n", mask[0:5, 0:5])


In [6]:
class DoubleConv(nn.Module):
    """Two 3x3 convolutions with padding 1, each followed by a ReLU.

    The padding keeps the spatial size unchanged; only the channel count
    goes from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv_kwargs = {"kernel_size": 3, "padding": 1}
        stages = [
            nn.Conv2d(in_channels, out_channels, **conv_kwargs),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, **conv_kwargs),
            nn.ReLU(),
        ]
        self.conv_op = nn.Sequential(*stages)

    def forward(self, x):
        """Apply conv -> ReLU -> conv -> ReLU to ``x``."""
        features = self.conv_op(x)
        return features
    
    
class DownSample(nn.Module):
    """Encoder stage: a DoubleConv followed by a 2x2, stride-2 max-pool."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(2, stride=2)

    def forward(self, x):
        """Return ``(features, pooled)``.

        ``features`` is the pre-pooling output (used as a skip connection),
        ``pooled`` is the same tensor after max-pooling.
        """
        skip = self.conv(x)
        return skip, self.pool(skip)
    
class UpSample(nn.Module):
    """Decoder stage: upsample, concatenate the skip connection, then DoubleConv.

    Args:
        in_channels: Channels of the tensor coming from the deeper level. The
            transposed convolution halves this, and the skip tensor is expected
            to supply the other half so the concatenation has ``in_channels``.
        out_channels: Channels produced by the stage.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1`` by 2x, merge it with skip tensor ``x2`` and convolve."""
        x1 = self.up(x1)

        # Max-pooling floors odd sizes, so after upsampling x1 can be one pixel
        # smaller than the skip tensor. Pad it to match so the concatenation
        # works; when the sizes already agree nothing is changed.
        diff_h = x2.size(2) - x1.size(2)
        diff_w = x2.size(3) - x1.size(3)
        if diff_h != 0 or diff_w != 0:
            x1 = F.pad(
                x1,
                [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
            )

        x = torch.cat([x1, x2], 1)
        x = self.conv(x)
        return x
    
class Unet(nn.Module):
    """U-Net with four encoder stages, a bottleneck and four decoder stages.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the output map produced by the final 1x1
            convolution (one per class when used for segmentation).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        widths = (64, 128, 256, 512)
        bottleneck_width = 1024

        # Encoder: channel count grows while DownSample halves the resolution.
        self.down_convolution_1 = DownSample(in_channels, widths[0])
        self.down_convolution_2 = DownSample(widths[0], widths[1])
        self.down_convolution_3 = DownSample(widths[1], widths[2])
        self.down_convolution_4 = DownSample(widths[2], widths[3])

        self.bottle_neck = DoubleConv(widths[3], bottleneck_width)

        # Decoder: mirrors the encoder, deepest level first.
        self.up_convolution_1 = UpSample(bottleneck_width, widths[3])
        self.up_convolution_2 = UpSample(widths[3], widths[2])
        self.up_convolution_3 = UpSample(widths[2], widths[1])
        self.up_convolution_4 = UpSample(widths[1], widths[0])

        # 1x1 convolution mapping the final features to the output channels.
        self.out = nn.Conv2d(in_channels=widths[0], out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder and return the raw output map."""
        encoders = (
            self.down_convolution_1,
            self.down_convolution_2,
            self.down_convolution_3,
            self.down_convolution_4,
        )
        decoders = (
            self.up_convolution_1,
            self.up_convolution_2,
            self.up_convolution_3,
            self.up_convolution_4,
        )

        # Keep each stage's pre-pooling features for the skip connections.
        skips = []
        for encoder in encoders:
            skip, x = encoder(x)
            skips.append(skip)

        x = self.bottle_neck(x)

        # The deepest skip is consumed first on the way back up.
        for decoder, skip in zip(decoders, reversed(skips)):
            x = decoder(x, skip)

        return self.out(x)

In [12]:
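# Optional shape sanity check (left disabled). It relies on `model`, which is
# only created in the next cell, so run that cell first; the input must also be
# moved to the same `device` as the model before uncommenting.
# input_image = torch.rand((1,1,256,256))
# output = model(input_image)
# print(output.size())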

torch.Size([1, 4, 256, 256])


In [8]:
# One input channel (grayscale slice) and four output channels, one per
# label id 0-3 produced by the mask transform.
model = Unet(1, 4).to(device=device)

# Pass the configured weight_decay to Adam; it was defined with the other
# hyperparameters but never applied.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Cross-entropy over the class channel of the model's raw outputs.
criterion = nn.CrossEntropyLoss()

In [None]:
model.train()
total_step = 50  # report the averaged loss once every `total_step` batches

for epoch in range(num_epochs):
    running_loss = 0.0  # loss accumulated since the last report (reset each epoch)
    for i, (imgs, masks) in enumerate(train_loader):
        imgs, masks = imgs.to(device), masks.to(device)

        # Clear stale gradients, then forward, backward and parameter update.
        optimizer.zero_grad()
        y_pred = model(imgs)
        loss = criterion(y_pred, masks)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Only report at the end of each `total_step`-sized window.
        if (i + 1) % total_step != 0:
            continue

        print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], "
              f"Loss: {running_loss / total_step:.4f}")
        running_loss = 0.0

In [None]:
def dice_coefficient(preds, target, num_classes, epsilon = 1e-7):
    """Compute per-class and mean Dice scores over a whole batch.

    Args:
        preds: Per-class scores with the class axis at dim 1, e.g. (B, C, H, W).
            Only the arg-max over that axis is used.
        target: Integer class ids with the shape of ``preds`` minus the class axis.
        num_classes: Number of classes; ids ``0 .. num_classes - 1`` are scored.
        epsilon: Small constant added to numerator and denominator. It avoids
            division by zero and makes a class that is absent from both
            prediction and target score 1.

    Returns:
        Tuple ``(dice_per_class, mean_dice)``: a 1-D tensor of length
        ``num_classes`` and its scalar mean.
    """
    # argmax of the raw scores equals argmax of their softmax (softmax is
    # monotonic), so the softmax pass is unnecessary.
    preds = preds.argmax(dim=1)  # (B, H, W)

    dice_per_class = []
    for c in range(num_classes):
        pred_c = (preds == c).float()
        target_c = (target == c).float()

        intersection = (pred_c * target_c).sum()
        # Sum of both areas (the Dice denominator), not a set union.
        union = pred_c.sum() + target_c.sum()
        dice = (2. * intersection + epsilon) / (union + epsilon)
        dice_per_class.append(dice)

    dice_per_class = torch.stack(dice_per_class)
    mean_dice = dice_per_class.mean()

    return dice_per_class, mean_dice


model.eval()
batch_dice = []   # one per-class Dice vector per test batch

# Inference only: no gradients are needed while scoring the test split.
with torch.no_grad():
    for imgs, masks in test_loader:
        imgs, masks = imgs.to(device), masks.to(device).long()   # masks: (B,H,W)

        logits = model(imgs)              # (B,C,H,W)
        dice_per_class, mean_dice = dice_coefficient(logits, masks, num_classes=4)

        batch_dice.append(dice_per_class.cpu())

# One row per batch, one column per class; average over the batches.
all_dice = torch.stack(batch_dice)
mean_per_class = all_dice.mean(dim=0)

print("\n=== Test Set Dice ===")
for i, d in enumerate(mean_per_class):
    print(f"Class {i}: {d:.4f}")
print(f"Mean Dice: {mean_per_class.mean():.4f}")

In [None]:
model.eval()
imgs, masks = next(iter(test_loader))         # imgs: (B,1,H,W), masks: (B,H,W)
imgs, masks = imgs.to(device), masks.to(device).long()

with torch.no_grad():
    logits = model(imgs)                      # (B,C,H,W)
    # argmax of the logits equals argmax of their softmax, so no softmax is needed.
    preds = logits.argmax(1)                  # (B,H,W) with {0,1,2,3}

# convert preds {0,1,2,3} -> {0,85,170,255}
mapping = torch.tensor([0,85,170,255], device=preds.device)
preds_val = mapping[preds]                    # (B,H,W)

# Same statistics as the Normalize step of the image transform.
mean, std = 0.1317, 0.1864

# Show first 8 samples
n = min(8, imgs.size(0))
# squeeze=False keeps `axes` two-dimensional even when n == 1, so axes[i, j]
# indexing below is always valid.
fig, axes = plt.subplots(n, 3, figsize=(9, 3*n), squeeze=False)

for i in range(n):
    img_np   = imgs[i][0].cpu().numpy()
    mask_np  = masks[i].cpu().numpy()
    pred_np  = preds_val[i].cpu().numpy()

    # Undo normalization for MRI
    img_np = img_np * std + mean
    img_np = img_np.clip(0,1)

    axes[i,0].imshow(img_np, cmap="gray")
    axes[i,0].set_title("MRI")
    axes[i,0].axis("off")

    # Fix the colour range for both label panels so a given class has the same
    # shade in each, even when some classes are missing from one of them.
    axes[i,1].imshow(mask_np, cmap="gray", vmin=0, vmax=3)
    axes[i,1].set_title("Ground Truth")
    axes[i,1].axis("off")

    axes[i,2].imshow(pred_np, cmap="gray", vmin=0, vmax=255)
    axes[i,2].set_title("Prediction")
    axes[i,2].axis("off")

plt.tight_layout()
plt.show()