In [None]:
import kagglehub

# Download the latest version of the CT heart segmentation dataset from Kaggle.
# `path` is the local directory kagglehub extracted it to; later cells use it.
path = kagglehub.dataset_download("nikhilroxtomar/ct-heart-segmentation")

print(f"Path to dataset files: {path}")

Downloading from https://www.kaggle.com/api/v1/datasets/download/nikhilroxtomar/ct-heart-segmentation?dataset_version_number=1...


100%|██████████| 541M/541M [00:25<00:00, 22.6MB/s]

Extracting files...





Path to dataset files: /root/.cache/kagglehub/datasets/nikhilroxtomar/ct-heart-segmentation/versions/1


In [None]:
import os
import random
import time
from glob import glob
import cv2 as cv
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import pydicom

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.transforms.functional as F
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torchsummary import summary


In [None]:
# Work from inside the downloaded dataset directory so the relative
# "data/train" and "data/test" paths built below resolve against it.
# NOTE(review): this changes process-wide state; any later cell using a
# relative path is now relative to the dataset directory.
os.chdir(path)

In [None]:
# Peek at the test split layout: two directory levels (e.g. a numeric case
# folder, then a series folder) containing the .dcm slices.
!ls ./data/test/*/*

./data/test/100251/2-0OPAGELSQXD3602.5120nullnullnull-72944:
1-001.dcm  1-019.dcm  1-037.dcm  1-055.dcm  1-073.dcm  1-091.dcm  1-109.dcm
1-002.dcm  1-020.dcm  1-038.dcm  1-056.dcm  1-074.dcm  1-092.dcm  1-110.dcm
1-003.dcm  1-021.dcm  1-039.dcm  1-057.dcm  1-075.dcm  1-093.dcm  1-111.dcm
1-004.dcm  1-022.dcm  1-040.dcm  1-058.dcm  1-076.dcm  1-094.dcm  1-112.dcm
1-005.dcm  1-023.dcm  1-041.dcm  1-059.dcm  1-077.dcm  1-095.dcm  1-113.dcm
1-006.dcm  1-024.dcm  1-042.dcm  1-060.dcm  1-078.dcm  1-096.dcm  1-114.dcm
1-007.dcm  1-025.dcm  1-043.dcm  1-061.dcm  1-079.dcm  1-097.dcm  1-115.dcm
1-008.dcm  1-026.dcm  1-044.dcm  1-062.dcm  1-080.dcm  1-098.dcm  1-116.dcm
1-009.dcm  1-027.dcm  1-045.dcm  1-063.dcm  1-081.dcm  1-099.dcm  1-117.dcm
1-010.dcm  1-028.dcm  1-046.dcm  1-064.dcm  1-082.dcm  1-100.dcm  1-118.dcm
1-011.dcm  1-029.dcm  1-047.dcm  1-065.dcm  1-083.dcm  1-101.dcm  1-119.dcm
1-012.dcm  1-030.dcm  1-048.dcm  1-066.dcm  1-084.dcm  1-102.dcm  1-120.dcm
1-013.dcm  1-031.dcm  1-049

In [None]:
# Train/test split directories, relative to the current working directory.
dataset_path, test_dataset_path = (os.path.join("data", split) for split in ("train", "test"))

dataset_path, test_dataset_path

('data/train', 'data/test')

In [None]:
def create_dir(path):
    """Create directory ``path`` (and any missing parents) if it does not exist.

    ``exist_ok=True`` makes this a single atomic call. A separate
    "check exists, then create" can race if the directory appears in between.
    Raises ``FileExistsError`` if ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)

In [None]:
class SegmentationTrainTransform:
    """Joint random augmentation for a PIL (image, mask) training pair.

    Both inputs are resized to ``image_size``. The same random horizontal flip
    and vertical flip (each with probability 0.5) and the same rotation by an
    angle drawn uniformly from [-20, 20] degrees are then applied to both.

    Returns:
        ``(image, mask)`` tensors. The image is a float tensor normalized with
        mean 0.5 / std 0.5. The mask holds the raw uint8 pixel values from
        ``pil_to_tensor`` (no rescaling).
    """

    def __init__(self, image_size=(512, 512)):
        self.image_size = image_size

    def __call__(self, image, mask):
        # Masks must be resampled with nearest-neighbour. The default bilinear
        # resize would invent intermediate label values along object edges.
        nearest = transforms.InterpolationMode.NEAREST
        image = F.resize(image, self.image_size)
        mask = F.resize(mask, self.image_size, interpolation=nearest)
        if random.random() < 0.5:
            image, mask = F.hflip(image), F.hflip(mask)
        if random.random() < 0.5:
            image, mask = F.vflip(image), F.vflip(mask)
        angle = random.uniform(-20, 20)
        image = F.rotate(image, angle)
        mask = F.rotate(mask, angle, interpolation=nearest)
        image = F.normalize(F.to_tensor(image), mean=[0.5], std=[0.5])
        mask = F.pil_to_tensor(mask)
        return image, mask

class SegmentationValTransform:
    """Deterministic preprocessing for validation/test images and optional masks.

    The image is resized to ``image_size``, converted to a float tensor and
    normalized with mean 0.5 / std 0.5. If a mask is given, it is resized with
    nearest-neighbour interpolation and returned as raw uint8 values via
    ``pil_to_tensor``.

    Returns:
        ``(image, mask)`` when a mask is passed, otherwise just ``image``.
    """

    def __init__(self, image_size=(512, 512)):
        self.image_size = image_size

    def __call__(self, image, mask=None):
        image = F.resize(image, self.image_size)
        image = F.normalize(F.to_tensor(image), mean=[0.5], std=[0.5])

        if mask is None:
            return image

        # Nearest-neighbour keeps mask label values intact. The default
        # bilinear resize would blend them along edges.
        mask = F.resize(mask, self.image_size, interpolation=transforms.InterpolationMode.NEAREST)
        return image, F.pil_to_tensor(mask)

In [None]:
class SegmentationDataset(Dataset):
    """Paired grayscale image/mask dataset for binary segmentation.

    Args:
        image_paths: Paths to image files readable by PIL.
        mask_paths: Paths to the matching masks; ``mask_paths[i]`` is paired
            with ``image_paths[i]``.
        transform: Optional callable ``(image, mask) -> (image, mask)`` taking
            PIL images and returning tensors, applied jointly to each pair.
            Without it, the image is converted with ``to_tensor`` (values in
            [0, 1]) and the mask with ``pil_to_tensor`` (raw uint8 values).

    Each item is ``(image, mask)`` as float tensors. The mask is divided by 255,
    so an 8-bit 0/255 mask becomes 0/1.

    Raises:
        ValueError: If the two path lists have different lengths.
    """

    def __init__(self, image_paths, mask_paths, transform=None):
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f"Got {len(image_paths)} images but {len(mask_paths)} masks; "
                "they must pair up one-to-one."
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Context managers close the underlying files deterministically.
        with Image.open(self.image_paths[idx]) as raw_image:
            image = raw_image.convert("L")
        with Image.open(self.mask_paths[idx]) as raw_mask:
            mask = raw_mask.convert("L")

        if self.transform:
            image, mask = self.transform(image, mask)
        else:
            # Without a transform the PIL images would reach `.float()` below
            # and fail, so convert them to tensors here.
            image, mask = F.to_tensor(image), F.pil_to_tensor(mask)

        image = image.float()
        mask = mask.float() / 255.0
        return image, mask



class DICOMTestDataset(Dataset):
    """Unlabelled DICOM slices for inference.

    Args:
        image_paths: Paths to DICOM files readable by ``pydicom.dcmread``.
        transform: Optional callable taking a PIL "L" image and returning a
            tensor (e.g. the validation transform called without a mask).
            Without it, the image is converted with ``to_tensor`` (values in [0, 1]).

    Each item is ``(image, img_path)``: the slice as a float tensor and the
    path it was read from.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _to_uint8(pixel_array):
        """Min-max rescale a pixel array so its min maps to 0 and its max to 255.

        The array is promoted to float64 first. Subtracting the minimum in the
        array's own integer dtype (e.g. signed int16 CT data) can overflow and
        wrap around. A constant array maps to all zeros.
        """
        values = np.asarray(pixel_array, dtype=np.float64)
        lo, hi = values.min(), values.max()
        span = hi - lo
        if span <= 0:
            return np.zeros(values.shape, dtype=np.uint8)
        return ((values - lo) / span * 255.0).astype(np.uint8)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]

        pixels = pydicom.dcmread(img_path).pixel_array
        image = Image.fromarray(self._to_uint8(pixels)).convert("L")

        if self.transform:
            image = self.transform(image)
        else:
            # Without a transform the PIL image would reach `.float()` below
            # and fail, so convert it to a tensor here.
            image = F.to_tensor(image)

        return image.float(), img_path




def load_data(path, test_path, val_split=0.2, train_transform=None, val_transform=None,
              random_state=42):
    """Build the train/val segmentation datasets and the DICOM test dataset.

    Args:
        path: Root of the labelled split. Images are read from
            ``path/*/image/*.png`` and masks from ``path/*/mask/*.png``.
        test_path: Root of the unlabelled split. DICOM files are read from
            ``test_path/*/*/*.dcm``.
        val_split: Fraction of image/mask pairs held out for validation
            (passed to ``train_test_split`` as ``test_size``).
        train_transform: Joint (image, mask) transform for the training set.
        val_transform: Transform for the validation set. It is also applied to
            the test images, called with the image only.
        random_state: Seed for the train/val split.

    Returns:
        ``(train_dataset, val_dataset, test_dataset)``.

    Raises:
        ValueError: If no training images are found, or if the sorted image and
            mask lists cannot be paired one-to-one within the same case directory.

    NOTE(review): the split is per slice, not per case directory (presumably
    one patient per directory), so slices of the same case can land in both
    train and val. Validation scores may therefore be optimistic.
    """
    images = sorted(glob(os.path.join(path, "*", "image", "*.png")))
    masks = sorted(glob(os.path.join(path, "*", "mask", "*.png")))

    if not images:
        raise ValueError(f"No training images found under {path!r}")
    if len(images) != len(masks):
        raise ValueError(f"Found {len(images)} images but {len(masks)} masks under {path!r}")

    # The two globs are sorted independently. Make sure every pair comes from
    # the same case directory, otherwise one missing file would silently shift
    # every later image/mask pairing.
    def case_dir(file_path):
        return os.path.basename(os.path.dirname(os.path.dirname(file_path)))

    for image_file, mask_file in zip(images, masks):
        if case_dir(image_file) != case_dir(mask_file):
            raise ValueError(f"Image/mask pairing mismatch: {image_file!r} vs {mask_file!r}")

    X_train, X_val, y_train, y_val = train_test_split(
        images, masks, test_size=val_split, random_state=random_state
    )

    train_dataset = SegmentationDataset(X_train, y_train, train_transform)
    val_dataset = SegmentationDataset(X_val, y_val, val_transform)

    X_test = sorted(glob(os.path.join(test_path, "*", "*", "*.dcm")))
    test_dataset = DICOMTestDataset(X_test, transform=val_transform)

    print(f"Found {len(train_dataset)} training images.")
    print(f"Found {len(val_dataset)} validation images.")
    print(f"Found {len(test_dataset)} test images.")

    return train_dataset, val_dataset, test_dataset

## Model building

In [None]:
class UNet(nn.Module):
    """Four-level U-Net producing per-pixel logits.

    Encoder level k (k = 1..4) has ``init_features * 2**(k-1)`` channels and is
    followed by 2x max-pooling. The bottleneck has ``init_features * 16``
    channels. Each decoder level upsamples 2x with a transposed convolution,
    concatenates the matching encoder output (skip connection) and applies a
    conv block. A final 1x1 convolution maps to ``out_channels``; no activation
    is applied, so outputs are raw logits.

    Input height and width must be divisible by 16 so that the upsampled
    features line up with the skip connections for concatenation.
    """

    def __init__(self, in_channels=1, out_channels=1, init_features=64):
        super().__init__()
        # Channel widths for levels 1..4 and the bottleneck: 1x, 2x, 4x, 8x, 16x.
        widths = [init_features * (2 ** k) for k in range(5)]

        # Encoder path. Attribute names (encoder{k}, pool{k}) and registration
        # order match the checkpointed state_dict, so they must not change.
        channels_in = in_channels
        for level in range(1, 5):
            width = widths[level - 1]
            setattr(self, f"encoder{level}", self.block(channels_in, width))
            setattr(self, f"pool{level}", nn.MaxPool2d(2))
            channels_in = width

        self.bottleneck = self.block(widths[3], widths[4])

        # Decoder path, deepest level first. Each decoder block sees the
        # upsampled features concatenated with the skip (hence `wide` inputs).
        for level in range(4, 0, -1):
            wide, narrow = widths[level], widths[level - 1]
            setattr(self, f"upconv{level}", nn.ConvTranspose2d(wide, narrow, 2, stride=2))
            setattr(self, f"decoder{level}", self.block(wide, narrow))

        self.conv_final = nn.Conv2d(widths[0], out_channels, kernel_size=1)

    def block(self, in_channels, out_channels):
        """Two bias-free 3x3 convs, each followed by BatchNorm + ReLU.

        Dropout(0.1) sits between the first ReLU and the second convolution.
        """
        layers = [
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        skips = []
        features = x
        for level in range(1, 5):
            features = getattr(self, f"encoder{level}")(features)
            skips.append(features)
            features = getattr(self, f"pool{level}")(features)

        features = self.bottleneck(features)

        for level in range(4, 0, -1):
            features = getattr(self, f"upconv{level}")(features)
            features = torch.cat((features, skips[level - 1]), dim=1)
            features = getattr(self, f"decoder{level}")(features)

        return self.conv_final(features)

In [None]:
class BCEDiceLoss(nn.Module):
    """Sum of binary cross-entropy on logits and soft Dice loss.

    Args:
        smooth: Additive smoothing in the Dice ratio. It keeps the ratio
            defined when both prediction and target are empty.

    ``forward(inputs, targets)`` takes raw logits ``inputs`` of shape
    (N, C, H, W) and targets with values in [0, 1]:
    - Targets missing the channel dimension, i.e. (N, H, W), get one inserted.
    - Targets whose spatial size differs from the logits are resized with
      nearest-neighbour interpolation.
    - The Dice term is computed per sample over all non-batch dims and then
      averaged over the batch.
    """

    def __init__(self, smooth=1.0):
        super(BCEDiceLoss, self).__init__()
        self.smooth = smooth
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, inputs, targets):
        # BCEWithLogitsLoss and interpolate both need floating-point targets.
        targets = targets.to(dtype=inputs.dtype)
        if targets.dim() == inputs.dim() - 1:
            targets = targets.unsqueeze(1)
        if targets.shape != inputs.shape:
            # Use torch.nn.functional explicitly: torchvision's functional
            # module has no `interpolate`.
            targets = nn.functional.interpolate(targets, size=inputs.shape[2:], mode="nearest")

        bce_loss = self.bce(inputs, targets)

        probs_flat = torch.sigmoid(inputs).reshape(inputs.size(0), -1)
        targets_flat = targets.reshape(targets.size(0), -1)

        intersection = (probs_flat * targets_flat).sum(dim=1)
        dice_score = (2. * intersection + self.smooth) / (
            probs_flat.sum(dim=1) + targets_flat.sum(dim=1) + self.smooth
        )
        dice_loss = 1 - dice_score.mean()

        return bce_loss + dice_loss


In [None]:
# Seed every RNG the pipeline draws from, for reproducibility:
# - Python `random`: the training augmentations
# - NumPy
# - torch: weight init, dropout, DataLoader shuffling
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

IMAGE_SIZE = (512, 512)
VAL_SPLIT = 0.2
BATCH_SIZE = 4
NUM_WORKERS = 4

train_transform = SegmentationTrainTransform(IMAGE_SIZE)
val_transform = SegmentationValTransform(IMAGE_SIZE)
train_dataset, val_dataset, test_dataset = load_data(
    dataset_path, test_dataset_path, VAL_SPLIT, train_transform, val_transform
)

# Pinned memory only helps (and only avoids a warning) when copying to a GPU.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
                     pin_memory=torch.cuda.is_available())
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(in_channels=1, out_channels=1).to(device)
# torchsummary's `summary` defaults to device="cuda". Pass the actual device
# type so this cell also works on a CPU-only runtime.
summary(model, (1, *IMAGE_SIZE), device=device.type)

criterion = BCEDiceLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)



Found 2025 training images.
Found 507 validation images.
Found 832 test images.




----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 512, 512]             576
       BatchNorm2d-2         [-1, 64, 512, 512]             128
              ReLU-3         [-1, 64, 512, 512]               0
           Dropout-4         [-1, 64, 512, 512]               0
            Conv2d-5         [-1, 64, 512, 512]          36,864
       BatchNorm2d-6         [-1, 64, 512, 512]             128
              ReLU-7         [-1, 64, 512, 512]               0
         MaxPool2d-8         [-1, 64, 256, 256]               0
            Conv2d-9        [-1, 128, 256, 256]          73,728
      BatchNorm2d-10        [-1, 128, 256, 256]             256
             ReLU-11        [-1, 128, 256, 256]               0
          Dropout-12        [-1, 128, 256, 256]               0
           Conv2d-13        [-1, 128, 256, 256]         147,456
      BatchNorm2d-14        [-1, 128, 2

In [None]:
def _binarize_logits(logits, threshold=0.5):
    """Float 0/1 tensor: 1 where sigmoid(logits) > threshold."""
    return (torch.sigmoid(logits) > threshold).float()


def _binarize_targets(targets):
    """Float 0/1 tensor: 1 where the target value is > 0.5.

    Guards against soft ground-truth values (e.g. from interpolated masks),
    which would otherwise never compare equal to hard 0/1 predictions.
    """
    return (targets > 0.5).float()


def dice_coefficient(preds, targets, smooth=1e-6, threshold=0.5):
    """Dice coefficient of thresholded predictions, pooled over the whole batch.

    Args:
        preds: Raw logits.
        targets: Ground-truth masks with values in [0, 1], same shape as ``preds``.
        smooth: Additive smoothing so empty prediction and target give 1.
        threshold: Probability threshold used to binarize ``sigmoid(preds)``.

    Returns:
        A 0-d tensor in [0, 1].
    """
    pred_mask = _binarize_logits(preds, threshold)
    target_mask = _binarize_targets(targets)
    intersection = (pred_mask * target_mask).sum()
    return (2.0 * intersection + smooth) / (pred_mask.sum() + target_mask.sum() + smooth)


def iou_score(preds, targets, smooth=1e-6, threshold=0.5):
    """Intersection-over-union of thresholded predictions, pooled over the batch.

    Arguments match ``dice_coefficient``. Returns a 0-d tensor in [0, 1].
    """
    pred_mask = _binarize_logits(preds, threshold)
    target_mask = _binarize_targets(targets)
    intersection = (pred_mask * target_mask).sum()
    union = pred_mask.sum() + target_mask.sum() - intersection
    return (intersection + smooth) / (union + smooth)


def pixel_accuracy(preds, targets, threshold=0.5):
    """Fraction of pixels where the thresholded prediction matches the target.

    Args:
        preds: Raw logits.
        targets: Ground-truth masks with values in [0, 1], same shape as ``preds``.
        threshold: Probability threshold used to binarize ``sigmoid(preds)``.

    Returns:
        A 0-d tensor in [0, 1].
    """
    pred_mask = _binarize_logits(preds, threshold)
    target_mask = _binarize_targets(targets)
    correct = (pred_mask == target_mask).float().sum()
    return correct / torch.numel(target_mask)


In [None]:
num_epochs = 30
best_val_loss = float("inf")
save_path = "/content/models/best_unet_model.pth"
os.makedirs(os.path.dirname(save_path), exist_ok=True)


def _train_one_epoch(epoch):
    """Run one optimisation pass over ``train_loader``.

    Returns the mean of the per-batch training losses.
    """
    model.train()
    loss_total = 0.0
    progress = tqdm(train_loader, desc=f"Epoch [{epoch}/{num_epochs}] Training", leave=False)
    for batch_images, batch_masks in progress:
        batch_images = batch_images.to(device, dtype=torch.float32, non_blocking=True)
        batch_masks = batch_masks.to(device, dtype=torch.float32, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        batch_loss = criterion(model(batch_images), batch_masks)
        batch_loss.backward()
        optimizer.step()

        loss_total += batch_loss.item()
        progress.set_postfix(loss=f"{batch_loss.item():.4f}")
    return loss_total / len(train_loader)


def _validate(epoch):
    """Evaluate on ``val_loader`` without gradients.

    Returns ``(loss, dice, iou, accuracy)``, each the mean of its per-batch values.
    """
    model.eval()
    sums = [0.0, 0.0, 0.0, 0.0]
    with torch.no_grad():
        progress = tqdm(val_loader, desc=f"Epoch [{epoch}/{num_epochs}] Validation", leave=False)
        for batch_images, batch_masks in progress:
            batch_images = batch_images.to(device, dtype=torch.float32, non_blocking=True)
            batch_masks = batch_masks.to(device, dtype=torch.float32, non_blocking=True)

            logits = model(batch_images)
            batch_loss = criterion(logits, batch_masks)
            batch_values = (
                batch_loss.item(),
                dice_coefficient(logits, batch_masks).item(),
                iou_score(logits, batch_masks).item(),
                pixel_accuracy(logits, batch_masks).item(),
            )
            sums = [total + value for total, value in zip(sums, batch_values)]
    n_batches = len(val_loader)
    return tuple(total / n_batches for total in sums)


for epoch in range(1, num_epochs + 1):
    epoch_start = time.time()
    avg_train_loss = _train_one_epoch(epoch)
    avg_val_loss, avg_dice, avg_iou, avg_acc = _validate(epoch)
    epoch_time = time.time() - epoch_start

    for line in (
        f"\nEpoch [{epoch}/{num_epochs}]",
        f"Train Loss: {avg_train_loss:.4f}",
        f"Val Loss  : {avg_val_loss:.4f}",
        f"Dice Score: {avg_dice:.6f}",
        f"IoU Score : {avg_iou:.6f}",
        f"Accuracy  : {avg_acc:.6f}",
        f"Time Taken: {epoch_time:.2f} sec",
    ):
        print(line)

    # Checkpoint only when validation loss beats the best seen so far.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), save_path)
        print(f"New best model saved to '{save_path}'.")
    print("-" * 60)




Epoch [1/30]
Train Loss: 1.2323
Val Loss  : 1.3110
Dice Score: 0.240606
IoU Score : 0.144924
Accuracy  : 0.824064
Time Taken: 521.03 sec
New best model saved to '/content/models/best_unet_model.pth'.
------------------------------------------------------------





Epoch [2/30]
Train Loss: 0.9811
Val Loss  : 0.8715
Dice Score: 0.632841
IoU Score : 0.510425
Accuracy  : 0.979890
Time Taken: 521.87 sec
New best model saved to '/content/models/best_unet_model.pth'.
------------------------------------------------------------





Epoch [3/30]
Train Loss: 0.8593
Val Loss  : 0.8143
Dice Score: 0.710143
IoU Score : 0.601603
Accuracy  : 0.986859
Time Taken: 522.63 sec
New best model saved to '/content/models/best_unet_model.pth'.
------------------------------------------------------------





Epoch [4/30]
Train Loss: 0.7941
Val Loss  : 0.7447
Dice Score: 0.716553
IoU Score : 0.614276
Accuracy  : 0.986120
Time Taken: 522.65 sec
New best model saved to '/content/models/best_unet_model.pth'.
------------------------------------------------------------





Epoch [5/30]
Train Loss: 0.7570
Val Loss  : 0.7279
Dice Score: 0.688882
IoU Score : 0.579138
Accuracy  : 0.982175
Time Taken: 522.38 sec
New best model saved to '/content/models/best_unet_model.pth'.
------------------------------------------------------------





Epoch [6/30]
Train Loss: 0.7421
Val Loss  : 0.9161
Dice Score: 0.438622
IoU Score : 0.304788
Accuracy  : 0.929302
Time Taken: 522.82 sec
------------------------------------------------------------





Epoch [7/30]
Train Loss: 0.7276
Val Loss  : 0.6929
Dice Score: 0.751287
IoU Score : 0.662333
Accuracy  : 0.988850
Time Taken: 522.80 sec
New best model saved to '/content/models/best_unet_model.pth'.
------------------------------------------------------------





Epoch [8/30]
Train Loss: 0.7159
Val Loss  : 0.7030
Dice Score: 0.829034
IoU Score : 0.751378
Accuracy  : 0.992591
Time Taken: 522.95 sec
------------------------------------------------------------





Epoch [9/30]
Train Loss: 0.7118
Val Loss  : 0.6839
Dice Score: 0.811225
IoU Score : 0.732445
Accuracy  : 0.991615
Time Taken: 522.49 sec
New best model saved to '/content/models/best_unet_model.pth'.
------------------------------------------------------------





Epoch [10/30]
Train Loss: 0.7044
Val Loss  : 0.6812
Dice Score: 0.806846
IoU Score : 0.726347
Accuracy  : 0.991523
Time Taken: 523.10 sec
New best model saved to '/content/models/best_unet_model.pth'.
------------------------------------------------------------





Epoch [11/30]
Train Loss: 0.7028
Val Loss  : 0.6787
Dice Score: 0.790515
IoU Score : 0.706682
Accuracy  : 0.991033
Time Taken: 522.94 sec
New best model saved to '/content/models/best_unet_model.pth'.
------------------------------------------------------------





Epoch [12/30]
Train Loss: 0.6993
Val Loss  : 0.6859
Dice Score: 0.834993
IoU Score : 0.757899
Accuracy  : 0.992657
Time Taken: 522.99 sec
------------------------------------------------------------





Epoch [13/30]
Train Loss: 0.6972
Val Loss  : 0.6786
Dice Score: 0.757000
IoU Score : 0.667418
Accuracy  : 0.989698
Time Taken: 522.56 sec
New best model saved to '/content/models/best_unet_model.pth'.
------------------------------------------------------------





Epoch [14/30]
Train Loss: 0.6908
Val Loss  : 0.6505
Dice Score: 0.831126
IoU Score : 0.760614
Accuracy  : 0.992797
Time Taken: 523.71 sec
New best model saved to '/content/models/best_unet_model.pth'.
------------------------------------------------------------





Epoch [15/30]
Train Loss: 0.6022
Val Loss  : 0.4014
Dice Score: 0.758132
IoU Score : 0.651600
Accuracy  : 0.988147
Time Taken: 523.35 sec
New best model saved to '/content/models/best_unet_model.pth'.
------------------------------------------------------------





Epoch [16/30]
Train Loss: 0.5144
Val Loss  : 0.3016
Dice Score: 0.764040
IoU Score : 0.664196
Accuracy  : 0.987875
Time Taken: 522.71 sec
New best model saved to '/content/models/best_unet_model.pth'.
------------------------------------------------------------





Epoch [17/30]
Train Loss: 0.3241
Val Loss  : 0.2730
Dice Score: 0.771770
IoU Score : 0.676092
Accuracy  : 0.988419
Time Taken: 522.97 sec
New best model saved to '/content/models/best_unet_model.pth'.
------------------------------------------------------------





Epoch [18/30]
Train Loss: 0.3535
Val Loss  : 0.2565
Dice Score: 0.786547
IoU Score : 0.688380
Accuracy  : 0.989359
Time Taken: 522.38 sec
New best model saved to '/content/models/best_unet_model.pth'.
------------------------------------------------------------





Epoch [19/30]
Train Loss: 0.2815
Val Loss  : 0.2057
Dice Score: 0.819141
IoU Score : 0.734407
Accuracy  : 0.990960
Time Taken: 522.92 sec
New best model saved to '/content/models/best_unet_model.pth'.
------------------------------------------------------------





Epoch [20/30]
Train Loss: 0.2770
Val Loss  : 0.2074
Dice Score: 0.822365
IoU Score : 0.738846
Accuracy  : 0.990556
Time Taken: 522.46 sec
------------------------------------------------------------





Epoch [21/30]
Train Loss: 0.2780
Val Loss  : 0.3014
Dice Score: 0.749852
IoU Score : 0.641030
Accuracy  : 0.987636
Time Taken: 522.22 sec
------------------------------------------------------------





Epoch [22/30]
Train Loss: 0.2635
Val Loss  : 0.2039
Dice Score: 0.827758
IoU Score : 0.741973
Accuracy  : 0.990721
Time Taken: 522.70 sec
New best model saved to '/content/models/best_unet_model.pth'.
------------------------------------------------------------





Epoch [23/30]
Train Loss: 0.2263
Val Loss  : 0.2191
Dice Score: 0.835025
IoU Score : 0.746525
Accuracy  : 0.989888
Time Taken: 522.39 sec
------------------------------------------------------------





Epoch [24/30]
Train Loss: 0.2230
Val Loss  : 0.1428
Dice Score: 0.865092
IoU Score : 0.790534
Accuracy  : 0.992247
Time Taken: 522.36 sec
New best model saved to '/content/models/best_unet_model.pth'.
------------------------------------------------------------





Epoch [25/30]
Train Loss: 0.2421
Val Loss  : 0.1843
Dice Score: 0.848499
IoU Score : 0.769246
Accuracy  : 0.991550
Time Taken: 522.56 sec
------------------------------------------------------------





Epoch [26/30]
Train Loss: 0.2143
Val Loss  : 0.2003
Dice Score: 0.835786
IoU Score : 0.753813
Accuracy  : 0.990942
Time Taken: 522.62 sec
------------------------------------------------------------





Epoch [27/30]
Train Loss: 0.2169
Val Loss  : 0.2454
Dice Score: 0.813612
IoU Score : 0.719231
Accuracy  : 0.989719
Time Taken: 522.56 sec
------------------------------------------------------------





Epoch [28/30]
Train Loss: 0.2335
Val Loss  : 0.2562
Dice Score: 0.772593
IoU Score : 0.666387
Accuracy  : 0.987765
Time Taken: 522.41 sec
------------------------------------------------------------





Epoch [29/30]
Train Loss: 0.2364
Val Loss  : 0.1599
Dice Score: 0.848151
IoU Score : 0.766958
Accuracy  : 0.991412
Time Taken: 522.59 sec
------------------------------------------------------------


                                                                           


Epoch [30/30]
Train Loss: 0.2345
Val Loss  : 0.2109
Dice Score: 0.788780
IoU Score : 0.696930
Accuracy  : 0.990005
Time Taken: 522.49 sec
------------------------------------------------------------




In [None]:
model = UNet(in_channels=1, out_channels=1).to(device)

save_path = "/content/models/best_unet_model.pth"

# map_location remaps tensors saved from a GPU session onto this session's
# device, so the checkpoint also loads on a CPU-only runtime. weights_only
# restricts unpickling to tensors and plain containers, which is all a
# state_dict needs.
state_dict = torch.load(save_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)

model.eval()

print(f"Model loaded successfully from {save_path}")

Model loaded successfully from /content/models/best_unet_model.pth


In [None]:
import random

def visualize_prediction(model, dataset, device, num_images):
    """Plot random samples with their ground truth and predicted contours.

    For each randomly chosen index, shows three panels:
    - the input image;
    - its ground-truth mask;
    - the image with the contours of the thresholded prediction drawn in red.

    Args:
        model: Segmentation network returning per-pixel logits.
        dataset: Dataset yielding ``(image, mask)`` tensors. The image is
            normalized with mean 0.5 / std 0.5, which is undone here for display.
        device: Device the model runs on.
        num_images: Number of samples to plot. It is capped at ``len(dataset)``
            because ``random.sample`` raises when asked for more.
    """
    model.eval()
    num_images = min(num_images, len(dataset))
    with torch.no_grad():
        for idx in random.sample(range(len(dataset)), num_images):
            inputs, targets = dataset[idx]
            inputs = inputs.unsqueeze(0).to(device, dtype=torch.float32)  # add batch dim
            targets = targets.unsqueeze(0).to(device, dtype=torch.float32)

            outputs = model(inputs)
            preds = (torch.sigmoid(outputs) > 0.5).float()

            # Undo Normalize(0.5, 0.5) and clip before the uint8 cast so
            # out-of-range values cannot wrap around.
            img = np.clip(inputs[0, 0].cpu().numpy() * 0.5 + 0.5, 0.0, 1.0)
            img_uint8 = (img * 255).astype(np.uint8)

            mask = targets[0, 0].cpu().numpy().astype(np.uint8)
            pred = preds[0, 0].cpu().numpy().astype(np.uint8)

            img_rgb = cv.cvtColor(img_uint8, cv.COLOR_GRAY2RGB)
            contours, _ = cv.findContours(pred, cv.RETR_TREE, cv.CHAIN_APPROX_SIMPLE)
            overlay_img = cv.drawContours(img_rgb.copy(), contours, -1, (255, 0, 0), 2)

            fig, ax = plt.subplots(1, 3, figsize=(18, 6))
            ax[0].imshow(img_uint8, cmap='gray')
            ax[0].set_title("Original Image")
            ax[0].axis('off')

            ax[1].imshow(mask, cmap='gray')
            ax[1].set_title("Ground Truth Mask")
            ax[1].axis('off')

            ax[2].imshow(overlay_img)
            ax[2].set_title("Predicted Mask Overlay")
            ax[2].axis('off')

            plt.tight_layout()
            plt.show()
            plt.close(fig)  # avoid accumulating open figures inside the loop



visualize_prediction(model, val_dataset, device, num_images=16)

Output hidden; open in https://colab.research.google.com to view.

In [None]:
def visualize_test_prediction(model, dataset, device, num_images=4):
    """Plot random unlabelled test slices with predicted contours.

    For each randomly chosen index, shows the input image (titled with its file
    name) next to the image with the contours of the thresholded prediction
    drawn in red.

    Args:
        model: Segmentation network returning per-pixel logits.
        dataset: Dataset yielding ``(image, img_path)``. The image is normalized
            with mean 0.5 / std 0.5, which is undone here for display.
        device: Device the model runs on.
        num_images: Number of samples to plot. It is capped at ``len(dataset)``
            because ``random.sample`` raises when asked for more.
    """
    model.eval()
    num_images = min(num_images, len(dataset))
    with torch.no_grad():
        for idx in random.sample(range(len(dataset)), num_images):
            inputs, img_path = dataset[idx]

            inputs = inputs.unsqueeze(0).to(device, dtype=torch.float32)

            outputs = model(inputs)
            preds = (torch.sigmoid(outputs) > 0.5).float()

            # Undo Normalize(0.5, 0.5) and clip before the uint8 cast so
            # out-of-range values cannot wrap around.
            img = np.clip(inputs[0, 0].cpu().numpy() * 0.5 + 0.5, 0.0, 1.0)
            img_uint8 = (img * 255).astype(np.uint8)

            pred = preds[0, 0].cpu().numpy().astype(np.uint8)

            img_rgb = cv.cvtColor(img_uint8, cv.COLOR_GRAY2RGB)
            contours, _ = cv.findContours(pred, cv.RETR_TREE, cv.CHAIN_APPROX_SIMPLE)
            overlay_img = cv.drawContours(img_rgb.copy(), contours, -1, (255, 0, 0), 2)

            fig, ax = plt.subplots(1, 2, figsize=(12, 6))
            ax[0].imshow(img_uint8, cmap='gray')
            ax[0].set_title(f"Original: {os.path.basename(img_path)}")
            ax[0].axis('off')

            ax[1].imshow(overlay_img)
            ax[1].set_title("Predicted Mask Overlay")
            ax[1].axis('off')

            plt.tight_layout()
            plt.show()
            plt.close(fig)  # avoid accumulating open figures inside the loop

visualize_test_prediction(model, test_dataset, device, num_images=16)

Output hidden; open in https://colab.research.google.com to view.