In [109]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim import Optimizer
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset
from tqdm import tqdm

In [31]:
class DualConv(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        mid_channels: int | None = None,
    ):
        super(DualConv, self).__init__()
        if not mid_channels:
            mid_channels = out_channels

        self.sequential = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.sequential(x)


class Down(nn.Module):
    """Encoder stage: 2x2 max-pool (stride 2) followed by a DualConv.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the DualConv.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DualConv(in_channels, out_channels)
        # Keep the "sequential" name so state_dict keys match saved checkpoints.
        self.sequential = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.sequential(x)


class Up(nn.Module):
    """Decoder stage: upsample, pad to the skip tensor's size, concatenate, DualConv.

    Args:
        in_channels: channel count the DualConv expects after concatenation.
        out_channels: channels produced by this stage.
        bilinear: if True, upsample with parameter-free bilinear interpolation
            (channel count unchanged); otherwise use a learned 2x2 transposed
            convolution that outputs ``in_channels // 2`` channels.
    """

    def __init__(self, in_channels: int, out_channels: int, bilinear: bool = True):
        super().__init__()
        half = in_channels // 2

        # `up` is registered before `conv` in both branches so parameter
        # initialisation order and state_dict layout stay the same.
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DualConv(in_channels, out_channels)
        else:
            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
            self.conv = DualConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        """
        x1: from the previous layer - decoder
        x2: from the skip connection - encoder
        """
        upsampled = self.up(x1)
        dy = x2.size(2) - upsampled.size(2)  # height mismatch
        dx = x2.size(3) - upsampled.size(3)  # width mismatch

        # F.pad takes the last dim first: [left, right, top, bottom].
        left, top = dx // 2, dy // 2
        upsampled = F.pad(upsampled, [left, dx - left, top, dy - top])

        # Skip features first, then decoder features, along the channel axis.
        return self.conv(torch.cat([x2, upsampled], dim=1))


class UNetBaseline(nn.Module):
    """Four-level U-Net that returns raw per-pixel logits.

    Stage widths are ``base_channels * 2**k`` for k = 0..4 (four encoder levels
    plus the bottleneck), and the decoder mirrors them. The default
    ``base_channels=32`` is a reduced-width model for faster training;
    ``base_channels=64`` gives the classic 64-128-256-512-1024 layout.
    Upsampling uses transposed convolutions (``bilinear=False``).

    Args:
        in_channels: channels of the input image.
        num_classes: channels of the output map (no activation is applied).
        base_channels: width of the first stage; must be >= 1.

    Raises:
        ValueError: if ``base_channels`` < 1.
    """

    def __init__(self, in_channels: int, num_classes: int, base_channels: int = 32):
        super(UNetBaseline, self).__init__()
        if base_channels < 1:
            raise ValueError(f"base_channels must be >= 1, got {base_channels}")
        c1, c2, c3, c4, c5 = (base_channels * 2**k for k in range(5))

        # Encoder (module creation order matches the original, so random
        # initialisation and state_dict keys are unchanged).
        self.inc = DualConv(in_channels, c1)
        self.down1 = Down(c1, c2)
        self.down2 = Down(c2, c3)
        self.down3 = Down(c3, c4)
        self.down4 = Down(c4, c5)  # Bottleneck

        # Decoder
        self.up1 = Up(c5, c4, bilinear=False)
        self.up2 = Up(c4, c3, bilinear=False)
        self.up3 = Up(c3, c2, bilinear=False)
        self.up4 = Up(c2, c1, bilinear=False)

        # 1x1 projection to per-pixel class logits
        self.outc = nn.Conv2d(c1, num_classes, kernel_size=1)

    def forward(self, x):
        # Encoder with skip connections
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)  # Bottleneck

        # Decoder with skip connections
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        return self.outc(x)

##### Loss Functions

In [101]:
class DiceLoss(nn.Module):
    """Soft Dice loss on logits for binary segmentation.

    Applies a sigmoid to ``pred`` and returns ``1 - Dice``. The Dice score is
    computed over all elements of the batch at once, not per sample.

    Args:
        epsilon: smoothing term added to numerator and denominator to avoid
            division by zero when both masks are empty.
    """

    def __init__(self, epsilon: float = 1e-7):
        super(DiceLoss, self).__init__()
        self.epsilon = epsilon

    def forward(self, pred, target) -> torch.Tensor:
        """
        Calculate Dice Loss.
        Formula:
            Dice Score = (2 * |A ∩ B| + eps) / (|A| + |B| + eps)
            Dice Loss = 1 - Dice Score

        Args:
            pred: Predicted mask logits (before sigmoid), e.g. (B, H, W).
            target: Ground truth mask with values in {0, 1}; same number of
                elements as ``pred``. Any dtype (cast to ``pred``'s float dtype).
        Returns:
            Scalar Dice loss tensor.
        """
        # reshape (unlike view) also accepts non-contiguous tensors.
        probs = torch.sigmoid(pred).reshape(-1)
        # Cast so integer/bool masks work when this loss is used on its own.
        target = target.reshape(-1).to(probs.dtype)

        intersection = (probs * target).sum()
        dice = (2.0 * intersection + self.epsilon) / (
            probs.sum() + target.sum() + self.epsilon
        )
        return 1 - dice

In [99]:
class CombinedLoss(nn.Module):
    """Weighted sum of BCE-with-logits and Dice loss.

    ``loss = bce_weight * BCEWithLogits(pred, target) + dice_weight * Dice(pred, target)``.
    ``pred`` are raw logits; ``target`` is converted to float before both terms.

    Args:
        bce_weight: weight of the binary cross-entropy term.
        dice_weight: weight of the Dice term.
    """

    def __init__(self, bce_weight=0.5, dice_weight=0.5):
        super().__init__()
        self.bce_weight, self.dice_weight = bce_weight, dice_weight

        self.bce_loss = nn.BCEWithLogitsLoss()
        self.dice_loss = DiceLoss()

    def forward(self, pred, target) -> torch.Tensor:
        target_f = target.float()
        return (
            self.bce_weight * self.bce_loss(pred, target_f)
            + self.dice_weight * self.dice_loss(pred, target_f)
        )

##### Segmentation Metrics

In [72]:
class SegmentationMetrics:
    """Threshold-based metrics for binary segmentation.

    Every method takes ``pred`` as probabilities (e.g. ``torch.sigmoid(logits)``),
    binarizes it with ``pred > threshold`` and compares it element-wise against
    ``target`` (values in {0, 1}, any dtype). Both tensors are flattened first,
    so they only need the same number of elements (e.g. ``[B, 1, H, W]`` vs
    ``[B, H, W]`` is fine). Scores cover the whole batch at once and are
    returned as Python floats.

    Args:
        threshold: probability above which a pixel counts as foreground.
        epsilon: smoothing term that avoids division by zero on empty masks.
    """

    def __init__(self, threshold=0.5, epsilon: float = 1e-7):
        self.threshold = threshold
        self.epsilon = epsilon

    def _binarize_and_flatten(self, pred, target):
        """Return 1-D float tensors ``(pred > threshold, target)``.

        ``reshape`` is used instead of ``view`` so non-contiguous inputs work.
        """
        pred_flat = (pred > self.threshold).float().reshape(-1)
        target_flat = target.reshape(-1).float()
        return pred_flat, target_flat

    def compute_dice_score(self, pred, target) -> float:
        """
        Compute Dice Score.
        Formula:
            Dice Score = (2 * |A ∩ B| + eps) / (|A| + |B| + eps)

        Args:
            pred: Predicted probabilities (after sigmoid), e.g. (B, H, W).
            target: Ground truth mask with values in {0, 1}.

        Returns:
            Dice score value
        """
        pred_flat, target_flat = self._binarize_and_flatten(pred, target)
        intersection = (pred_flat * target_flat).sum()
        dice = (2.0 * intersection + self.epsilon) / (
            pred_flat.sum() + target_flat.sum() + self.epsilon
        )
        return dice.item()

    def compute_iou(self, pred, target) -> float:
        """
        Compute Intersection over Union (IoU).
        Formula:
            IoU = (|A ∩ B| + eps) / (|A ∪ B| + eps)

        Args:
            pred: Predicted probabilities (after sigmoid), e.g. (B, H, W).
            target: Ground truth mask with values in {0, 1}.
        Returns:
            IoU score value
        """
        pred_flat, target_flat = self._binarize_and_flatten(pred, target)
        intersection = (pred_flat * target_flat).sum()
        union = pred_flat.sum() + target_flat.sum() - intersection
        iou = (intersection + self.epsilon) / (union + self.epsilon)
        return iou.item()

    def compute_pixel_accuracy(self, pred, target) -> float:
        """
        Compute the fraction of pixels whose binarized prediction equals the target.

        Args:
            pred: Predicted probabilities (after sigmoid), e.g. (B, H, W).
            target: Ground truth mask with values in {0, 1}.
        Returns:
            Pixel accuracy in [0, 1]
        """
        # Comparing flattened tensors avoids silent broadcasting when the
        # shapes differ only by a singleton channel dimension.
        pred_flat, target_flat = self._binarize_and_flatten(pred, target)
        correct = (pred_flat == target_flat).float().sum()
        return (correct / target_flat.numel()).item()

In [75]:
def train_one_epoch(
    mode: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    optimizer: Optimizer,
    device: torch.device,
    metrics: SegmentationMetrics,
):
    """Run one training pass over ``dataloader`` and return batch-averaged metrics.

    Args:
        mode: the network to train. (The parameter keeps the name ``mode`` for
            backward compatibility with keyword callers; it is the model.)
        dataloader: yields ``(images, masks)`` batches.
        criterion: loss called as ``criterion(logits, masks.float())``.
        optimizer: optimizer over the model's parameters; stepped once per batch.
        device: device each batch is moved to.
        metrics: threshold-based metrics evaluated on sigmoid probabilities.

    Returns:
        Dict with the mean per-batch ``"loss"``, ``"dice"``, ``"iou"`` and
        ``"accuracy"``.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no batches.
    """
    model = mode
    model.train()
    epoch_loss = 0.0
    epoch_dice = 0.0
    epoch_iou = 0.0
    epoch_acc = 0.0
    num_batches = 0

    train_progress = tqdm(dataloader, colour="blue", desc="Training", ncols=100)
    for images, masks in train_progress:
        images = images.to(device)
        masks = masks.to(device)

        # Forward pass. squeeze(1) assumes a single-channel output:
        # [B, 1, H, W] -> [B, H, W].
        optimizer.zero_grad()
        pred = model(images).squeeze(1)

        loss = criterion(pred, masks.float())

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Metrics are for monitoring only; keep them out of the autograd graph.
        with torch.no_grad():
            pred_probs = torch.sigmoid(pred)
            dice = metrics.compute_dice_score(pred_probs, masks)
            iou = metrics.compute_iou(pred_probs, masks)
            acc = metrics.compute_pixel_accuracy(pred_probs, masks)

        epoch_loss += loss.item()
        epoch_dice += dice
        epoch_iou += iou
        epoch_acc += acc
        num_batches += 1

        msg = (
            f"Training | Loss: {epoch_loss / num_batches:.4f} | "
            f"Dice: {epoch_dice / num_batches:.4f} | "
            f"IoU: {epoch_iou / num_batches:.4f} | "
            f"Acc: {epoch_acc / num_batches:.4f}"
        )
        train_progress.set_description(msg)

    return {
        "loss": epoch_loss / num_batches,
        "dice": epoch_dice / num_batches,
        "iou": epoch_iou / num_batches,
        "accuracy": epoch_acc / num_batches,
    }

In [114]:
def validate(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
    metrics: SegmentationMetrics,
):
    """Evaluate ``model`` on ``dataloader`` without gradients.

    The model is put in eval mode, and each batch's logits are squeezed from
    ``[B, 1, H, W]`` to ``[B, H, W]`` before the loss and metrics are computed.

    Returns:
        Dict with the mean per-batch ``"loss"``, ``"dice"``, ``"iou"`` and
        ``"accuracy"``. Raises ZeroDivisionError if the loader is empty.
    """
    model.eval()
    totals = {"loss": 0.0, "dice": 0.0, "iou": 0.0, "accuracy": 0.0}
    seen = 0

    with torch.no_grad():
        progress = tqdm(dataloader, colour="blue", desc="Validation", ncols=100)
        for images, masks in progress:
            images, masks = images.to(device), masks.to(device)

            logits = model(images).squeeze(1)  # [B, 1, H, W] -> [B, H, W]
            batch_loss = criterion(logits, masks.float())

            probs = torch.sigmoid(logits)
            totals["dice"] += metrics.compute_dice_score(probs, masks)
            totals["iou"] += metrics.compute_iou(probs, masks)
            totals["accuracy"] += metrics.compute_pixel_accuracy(probs, masks)
            totals["loss"] += batch_loss.item()
            seen += 1

            progress.set_description(
                f"Validation | Loss: {totals['loss'] / seen:.4f} | "
                f"Dice: {totals['dice'] / seen:.4f} | "
                f"IoU: {totals['iou'] / seen:.4f} | "
                f"Acc: {totals['accuracy'] / seen:.4f}"
            )

    return {name: total / seen for name, total in totals.items()}

In [116]:
def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    criterion: nn.Module,
    optimizer: Optimizer,
    device: torch.device,
    metrics: SegmentationMetrics,
    num_epochs: int = 50,
    scheduler: StepLR | None = None,
    last_path: str = "last.pth",
    best_path: str = "best.pth",
):
    """Train for ``num_epochs`` epochs, validating and checkpointing after each one.

    Args:
        model, train_loader, val_loader, criterion, optimizer, device, metrics:
            forwarded to ``train_one_epoch`` / ``validate``.
        num_epochs: number of train+validate cycles.
        scheduler: optional LR scheduler, stepped once per epoch after validation.
        last_path: file the latest ``state_dict`` is written to every epoch.
        best_path: file the ``state_dict`` with the highest validation Dice so
            far is written to.

    Returns:
        Dict mapping ``"{split}_{metric}"`` (split in train/val; metric in
        loss/dice/iou/acc) to a list with one value per epoch.
    """
    # History suffix -> key in the dicts returned by train_one_epoch / validate.
    metric_keys = {"loss": "loss", "dice": "dice", "iou": "iou", "acc": "accuracy"}
    history = {
        f"{split}_{name}": [] for split in ("train", "val") for name in metric_keys
    }
    # Start at -inf so best_path is always written after the first epoch,
    # even if validation Dice is exactly 0.
    best_val_dice = float("-inf")

    for _ in range(num_epochs):
        train_metrics = train_one_epoch(
            model, train_loader, criterion, optimizer, device, metrics
        )
        val_metrics = validate(model, val_loader, criterion, device, metrics)

        if scheduler is not None:
            scheduler.step()

        for name, key in metric_keys.items():
            history[f"train_{name}"].append(train_metrics[key])
            history[f"val_{name}"].append(val_metrics[key])

        # Save last model
        torch.save(model.state_dict(), last_path)

        # Save best model (by validation Dice)
        if val_metrics["dice"] > best_val_dice:
            best_val_dice = val_metrics["dice"]
            torch.save(model.state_dict(), best_path)

    return history

In [120]:
# Binary segmentation: the model emits one logit per pixel (BCE + Dice on
# sigmoid probabilities), so there is a single output class channel.
NUM_CLASSES = 1
BATCH_SIZE = 8
HEIGHT, WIDTH = 128, 128
CHANNELS = 3
NUM_EPOCHS = 10
LEARNING_RATE = 1e-4
SEED = 42

# Seed torch's global RNG so the dummy data and CPU-side weight init repeat
# across Restart & Run All.
torch.manual_seed(SEED)


class DummyDataset(Dataset):
    """Synthetic segmentation data: Gaussian-noise images with random binary masks.

    Images and masks are drawn independently on every access, so there is no
    learnable signal. It is only meant for smoke-testing the training pipeline.

    Args:
        size: number of samples reported by ``len()``.
        channels, height, width: sample dimensions; ``None`` falls back to the
            notebook's ``CHANNELS`` / ``HEIGHT`` / ``WIDTH`` constants at access time.
    """

    def __init__(self, size=1000, channels=None, height=None, width=None):
        self.size = size
        self.channels = channels
        self.height = height
        self.width = width

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Raising IndexError lets Python's sequence-iteration protocol
        # (`for x in ds`, `list(ds)`) terminate instead of looping forever.
        if not -self.size <= idx < self.size:
            raise IndexError(f"index {idx} out of range for dataset of size {self.size}")
        channels = CHANNELS if self.channels is None else self.channels
        height = HEIGHT if self.height is None else self.height
        width = WIDTH if self.width is None else self.width

        image = torch.randn(channels, height, width)
        mask = torch.randint(0, 2, (height, width))
        return image, mask


# Pick the fastest available backend: CUDA, then Apple-silicon MPS, then CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

train_dataset = DummyDataset(size=200)
val_dataset = DummyDataset(size=40)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Binary segmentation: one logit per pixel. The train/validation loops
# squeeze channel dim 1, so the model must emit exactly one channel.
model = UNetBaseline(in_channels=CHANNELS, num_classes=1).to(DEVICE)

# BCE + Dice on logits. Alternatives: nn.BCEWithLogitsLoss() or DiceLoss().
criterion = CombinedLoss(bce_weight=0.5, dice_weight=0.5)

# Optimizer & scheduler. train_model steps StepLR once per epoch, so with
# LR_STEP_SIZE = 10 and NUM_EPOCHS = 10 the first LR decay happens only
# after the final epoch.
LR_STEP_SIZE = 10
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = StepLR(optimizer, step_size=LR_STEP_SIZE, gamma=0.5)

# Metrics
metrics = SegmentationMetrics(threshold=0.5)

history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=DEVICE,
    metrics=metrics,
    num_epochs=NUM_EPOCHS,
    scheduler=scheduler,
)

# Single quotes inside the f-strings keep them valid before Python 3.12 (PEP 701).
print(f"Best Validation Dice Score: {max(history['val_dice']):.4f}")
print(f"Best Validation IoU: {max(history['val_iou']):.4f}")

Using device: mps


Validation | Loss: 0.5989 | Dice: 0.4368 | IoU: 0.2794 | Acc: 0.5002: 100%|[34m█[0m| 5/5 [00:00<00:00, 21.4[0m
Validation | Loss: 0.6025 | Dice: 0.4279 | IoU: 0.2722 | Acc: 0.4999: 100%|[34m█[0m| 5/5 [00:00<00:00, 21.2[0m
Validation | Loss: 0.5982 | Dice: 0.4993 | IoU: 0.3327 | Acc: 0.5005: 100%|[34m█[0m| 5/5 [00:00<00:00, 20.7[0m
Validation | Loss: 0.5933 | Dice: 0.5872 | IoU: 0.4156 | Acc: 0.5001: 100%|[34m█[0m| 5/5 [00:00<00:00, 21.9[0m
Validation | Loss: 0.5898 | Dice: 0.6386 | IoU: 0.4690 | Acc: 0.5000: 100%|[34m█[0m| 5/5 [00:00<00:00, 19.6[0m
Validation | Loss: 0.5884 | Dice: 0.6569 | IoU: 0.4891 | Acc: 0.5004: 100%|[34m█[0m| 5/5 [00:00<00:00,  8.8[0m
Validation | Loss: 0.5871 | Dice: 0.6654 | IoU: 0.4986 | Acc: 0.4999: 100%|[34m█[0m| 5/5 [00:00<00:00,  6.6[0m
Validation | Loss: 0.5863 | Dice: 0.6668 | IoU: 0.5001 | Acc: 0.5002: 100%|[34m█[0m| 5/5 [00:00<00:00,  6.9[0m
Validation | Loss: 0.5860 | Dice: 0.6667 | IoU: 0.5000 | Acc: 0.5000: 100%|[34m█[0m| 5

Best Validation Dice Score: 0.6674
Best Validation IoU: 0.5008
