# Libraries

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset
import os
import albumentations as A
import numpy as np
from PIL import Image
from tqdm import tqdm
from albumentations.pytorch import ToTensorV2
import torch
import os

# DoubleConv Module

In [None]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by BatchNorm2d and ReLU.

    Both convolutions use stride 1 and padding 1, so the spatial size of the
    input is kept; only the channel count changes from ``in_channels`` to
    ``out_channels``.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        stages = []
        # The first stage maps in_channels -> out_channels, the second keeps
        # out_channels. bias=False: every convolution is directly followed by
        # a BatchNorm2d layer.
        for stage_in in (in_channels, out_channels):
            stages.append(
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
            )
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv/norm/ReLU stages and return the result."""
        stack = self.conv
        return stack(x)

# U-Net Architecture

In [8]:
class UNET(nn.Module):
    """U-Net style encoder/decoder built from ``DoubleConv`` blocks.

    The encoder applies one ``DoubleConv`` per entry of ``features``, each
    followed by 2x2 max pooling. A bottleneck ``DoubleConv`` doubles the last
    feature count. The decoder then mirrors the encoder: every step is a
    ``ConvTranspose2d`` (2x upsampling) followed by a ``DoubleConv`` applied
    to the concatenation of the matching encoder output and the upsampled map.
    A final 1x1 convolution maps to ``out_channels``.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the returned map.
        features: Channel widths of the encoder levels, shallowest first.
            Any sequence works; the default is a tuple so that it is not a
            shared mutable default argument.
    """

    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
        super().__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Downsampling path: each level consumes the previous level's width.
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Upsampling path: `ups` alternates [ConvTranspose2d, DoubleConv, ...].
        # The DoubleConv takes feature * 2 channels because its input is the
        # skip connection concatenated with the upsampled map.
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            self.ups.append(DoubleConv(feature * 2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Return the ``out_channels`` map for ``x`` (raw outputs, no activation)."""
        skip_connections = []

        for down in self.downs:
            x = down(x)
            # Keep the pre-pooling activation for the decoder.
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        # The decoder walks from the deepest level back to the shallowest.
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)  # ConvTranspose2d
            skip_connection = skip_connections[idx // 2]

            # Pooling floors odd sizes, so the upsampled map can be smaller
            # than the skip connection; resize it before concatenating.
            if x.shape != skip_connection.shape:
                x = F.interpolate(x, size=skip_connection.shape[2:], mode="bilinear", align_corners=True)

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)  # DoubleConv

        return self.final_conv(x)

# Testing of the input and output shape

In [None]:
def test():
    """Smoke test: UNET must return a tensor with the same shape as its input."""
    # Batch size=3, 1-channel, 161x161 image (odd size exercises the resize path)
    x = torch.randn((3, 1, 161, 161))
    preds = UNET(in_channels=1, out_channels=1)(x)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {preds.shape}")

    shapes_agree = preds.shape == x.shape
    assert shapes_agree, "Output shape does not match input shape!"


if __name__ == "__main__":
    test()

Input shape: torch.Size([3, 1, 161, 161])
Output shape: torch.Size([3, 1, 161, 161])


# Dataset

In [2]:
class SegmentationDataset(Dataset):
    """Image/mask pairs read from two directories with matching file names.

    Every entry of ``img_dir`` is treated as one sample; its mask is looked
    up in ``mask_dir`` under the same file name.

    Args:
        img_dir: Directory containing the input images.
        mask_dir: Directory containing the masks, named like the images.
        transform: Optional callable invoked as
            ``transform(image=image, mask=mask)`` that returns a mapping with
            ``"image"`` and ``"mask"`` entries.
    """

    def __init__(self, img_dir, mask_dir, transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sort them so that an
        # index always refers to the same file across runs and machines.
        self.images = sorted(os.listdir(img_dir))

    def __len__(self):
        """Return the number of entries found in ``img_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load the ``idx``-th image and mask and return ``(image, mask)``."""
        img_path = os.path.join(self.img_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.images[idx])

        # Image as an RGB array, mask as a single-channel float32 array.
        image = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)

        # Normalize mask (convert 255 -> 1); other values are left unchanged.
        mask[mask == 255.0] = 1.0

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image, mask = augmented["image"], augmented["mask"]

        return image, mask

# Train

In [None]:
 from dataset import SegmentationDataset
 from utils import (
     load_checkpoint,
     save_checkpoint,
     get_loaders,
     check_accuracy,
     save_predictions_as_imgs,
 )
# Hyperparameters
LEARNING_RATE = 1e-4
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
NUM_EPOCHS = 15
NUM_WORKERS = 2
IMAGE_WIDTH = 1024  # Original: 3840
IMAGE_HEIGHT = 1024  # Original: 2160
PIN_MEMORY = True
LOAD_MODEL = False

# Paths (Update these before running)
# NOTE: the values below are placeholders so that the module parses; point
# them at the real dataset directories.
TRAIN_IMG_DIR = "data/train_images/"
TRAIN_MASK_DIR = "data/train_masks/"
VAL_IMG_DIR = "data/val_images/"
VAL_MASK_DIR = "data/val_masks/"

# Training function
def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Train ``model`` for one pass over ``loader``.

    Args:
        loader: Iterable yielding ``(data, targets)`` batches.
        model: Network being optimized; called as ``model(data)``.
        optimizer: Optimizer stepped once per batch through ``scaler``.
        loss_fn: Callable invoked as ``loss_fn(predictions, targets)``.
        scaler: Gradient scaler used to scale the loss and step the optimizer.
    """
    progress = tqdm(loader)

    for data, targets in progress:
        data = data.to(device=DEVICE)

        # Cast to float and insert an axis at dim 1 before moving to DEVICE
        # (presumably the channel axis expected by the loss; verify against
        # the dataset's mask shape).
        targets = targets.float()
        targets = targets.unsqueeze(1)
        targets = targets.to(device=DEVICE)

        # Forward pass under autocast
        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # Backward pass: clear old gradients, backprop the scaled loss,
        # then let the scaler step the optimizer and update its scale.
        optimizer.zero_grad()
        scaled_loss = scaler.scale(loss)
        scaled_loss.backward()
        scaler.step(optimizer)
        scaler.update()

        # Show the current batch loss on the progress bar
        batch_loss = loss.item()
        progress.set_postfix(loss=batch_loss)

# Main function
def main():
    """Build the transforms, model and data loaders, then train the model.

    Training runs ``train_fn`` once per epoch for ``NUM_EPOCHS`` epochs using
    ``BCEWithLogitsLoss`` and Adam with ``LEARNING_RATE``.

    Returns:
        tuple: ``(model, optimizer, val_loader)`` so that the caller can
        evaluate the model or save a checkpoint after training. Previously
        these objects were local to this function and were lost on return.
    """
    # Training pipeline: resize, random rotation/flips, scale pixels to
    # [0, 1] (mean 0, std 1, max_pixel_value 255), then convert to tensors.
    train_transform = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ])

    # Validation pipeline: same resize and scaling, no random augmentation.
    val_transforms = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ])

    model = UNET(in_channels=3, out_channels=1).to(DEVICE)  # Single Class Segmentation
    # The model returns raw outputs, so use the loss that applies the
    # sigmoid internally.
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scaler = torch.cuda.amp.GradScaler()

    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transforms,
        NUM_WORKERS,
        PIN_MEMORY,
    )

    for epoch in range(NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler)

    return model, optimizer, val_loader


if __name__ == "__main__":
    # Bind the training results at module scope so that code running after
    # training (checkpointing, evaluation) can reach them.
    model, optimizer, val_loader = main()

In [6]:
# Save model and check accuracy
# NOTE(review): this code relies on `model`, `optimizer` and `val_loader`
# already being bound at module scope -- confirm they are defined before it runs.
checkpoint = dict(
    state_dict=model.state_dict(),
    optimizer=optimizer.state_dict(),
)
save_checkpoint(checkpoint)

# Report pixel accuracy and Dice score on the validation loader
check_accuracy(val_loader, model, device=DEVICE)

# Write the predicted masks (and their targets) for the validation loader to disk
save_predictions_as_imgs(val_loader, model, folder="saved_predictions/", device=DEVICE)

# Utils file

In [None]:
from dataset import SegmentationDataset

def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Write ``state`` to ``filename`` with ``torch.save``.

    Args:
        state: Object to serialize (e.g. a dict of state dicts).
        filename: Destination path of the checkpoint file.
    """
    print("=> Saving checkpoint")
    torch.save(obj=state, f=filename)

def load_checkpoint(checkpoint, model):
    """Copy the weights stored under ``checkpoint["state_dict"]`` into ``model``.

    Args:
        checkpoint: Mapping that holds the model weights under ``"state_dict"``.
        model: Module whose parameters are overwritten in place.
    """
    print("=> Loading checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)

def get_loaders(
    train_dir, train_maskdir, val_dir, val_maskdir, batch_size,
    train_transform, val_transform, num_workers=4, pin_memory=True
):
    """Create the training and validation data loaders.

    Args:
        train_dir: Directory with the training images.
        train_maskdir: Directory with the training masks.
        val_dir: Directory with the validation images.
        val_maskdir: Directory with the validation masks.
        batch_size: Batch size used by both loaders.
        train_transform: Transform applied to training samples.
        val_transform: Transform applied to validation samples.
        num_workers: Worker process count passed to both loaders.
        pin_memory: ``pin_memory`` flag passed to both loaders.

    Returns:
        tuple: ``(train_loader, val_loader)``. Only the training loader
        shuffles its samples.
    """
    # DataLoader is not imported at file level, so import it here.
    from torch.utils.data import DataLoader

    # The directories are passed positionally: the dataset's first parameter
    # is named `img_dir`, so the former `image_dir=` keyword raised TypeError.
    train_ds = SegmentationDataset(
        train_dir,
        train_maskdir,
        transform=train_transform,
    )
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=True,
    )

    val_ds = SegmentationDataset(
        val_dir,
        val_maskdir,
        transform=val_transform,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=False,
    )

    return train_loader, val_loader

def check_accuracy(loader, model, device="cuda"):
    """Compute pixel accuracy and mean Dice score of ``model`` over ``loader``.

    The model output is passed through a sigmoid and thresholded at 0.5 to
    obtain a binary mask, which is compared with the target element-wise.
    The model is put in eval mode for the evaluation and afterwards restored
    to whichever mode (train or eval) it was in when this function was
    called.

    Args:
        loader: Sized iterable yielding ``(x, y)`` batches; an axis is
            inserted into ``y`` at dim 1 before comparing it with the
            prediction.
        model: Module to evaluate.
        device: Device the batches are moved to.

    Returns:
        tuple: ``(acc, dice)`` where ``acc`` is the percentage of correctly
        predicted elements and ``dice`` is the per-batch Dice score averaged
        over ``len(loader)``.
    """
    num_correct = 0
    num_pixels = 0
    dice_score = 0

    # Remember the caller's mode so evaluating an eval-mode model does not
    # silently flip it to training mode.
    was_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device).unsqueeze(1)  # Ensure proper shape

                preds = torch.sigmoid(model(x))  # Apply sigmoid activation
                preds = (preds > 0.5).float()  # Convert to binary mask

                num_correct += (preds == y).sum().item()
                num_pixels += preds.numel()
                # Dice = 2 * |A ∩ B| / (|A| + |B|); epsilon avoids 0 / 0 when
                # both prediction and target are empty.
                dice_score += (2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)
    finally:
        # Restore the original mode even if evaluation raised.
        model.train(was_training)

    acc = (num_correct / num_pixels) * 100
    print(f"Got {num_correct}/{num_pixels} pixels correct with accuracy {acc:.2f}%")
    print(f"Dice score: {dice_score / len(loader):.4f}")

    return acc, dice_score / len(loader)

def save_predictions_as_imgs(loader, model, folder="saved_images/", device="cuda"):
    """Save the thresholded predictions and the targets of every batch as PNGs.

    For batch number ``idx`` the binary prediction is written to
    ``{folder}/pred_{idx}.png`` and the target to ``{folder}/{idx}.png``.
    The model is put in eval mode while predicting and afterwards restored
    to whichever mode (train or eval) it was in when this function was
    called.

    Args:
        loader: Iterable yielding ``(x, y)`` batches.
        model: Module used to predict the masks.
        folder: Output directory; created if it does not exist.
        device: Device the inputs are moved to.
    """
    # torchvision is not imported at file level, so import it here.
    import torchvision

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(folder, exist_ok=True)

    was_training = model.training
    model.eval()
    try:
        for idx, (x, y) in enumerate(loader):
            x = x.to(device=device)
            with torch.no_grad():
                preds = torch.sigmoid(model(x))
                preds = (preds > 0.5).float()

            torchvision.utils.save_image(preds, f"{folder}/pred_{idx}.png")
            # Insert an axis at dim 1 into the target before saving it.
            torchvision.utils.save_image(y.unsqueeze(1), f"{folder}/{idx}.png")
    finally:
        # Restore the caller's mode even if saving raised.
        model.train(was_training)