In [121]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim import Optimizer
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset
from tqdm import tqdm

In [127]:
class DualConv(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        mid_channels: int | None = None,
    ):
        super(DualConv, self).__init__()
        if not mid_channels:
            mid_channels = out_channels

        self.sequential = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.sequential(x)


class Down(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a DualConv.

    The pooling halves H and W; the DualConv maps ``in_channels`` to
    ``out_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        layers = [nn.MaxPool2d(2), DualConv(in_channels, out_channels)]
        self.sequential = nn.Sequential(*layers)

    def forward(self, x):
        downsampled = self.sequential(x)
        return downsampled


class Up(nn.Module):
    """Decoder stage: up-sample, pad to the skip tensor's size, concatenate, DualConv.

    Args:
        in_channels: Channels the inner DualConv receives after concatenation.
            With ``bilinear=False``, ``x1`` must carry ``in_channels`` channels.
            The transposed convolution halves them, and the skip tensor is
            expected to supply the other ``in_channels // 2``. With
            ``bilinear=True``, up-sampling keeps ``x1``'s channel count, so
            ``x1`` and ``x2`` channels together must equal ``in_channels``.
        out_channels: Channels produced by the stage.
        bilinear: Use parameter-free bilinear interpolation (with
            ``in_channels // 2`` mid channels in the DualConv) instead of a
            learned 2x2 transposed convolution.
    """

    def __init__(self, in_channels: int, out_channels: int, bilinear: bool = True):
        super().__init__()

        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DualConv(in_channels, out_channels)
            return

        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = DualConv(in_channels, out_channels, in_channels // 2)

    def forward(self, x1, x2):
        """
        x1: from the previous layer - decoder
        x2: from the skip connection - encoder
        """
        upsampled = self.up(x1)
        dh = x2.size(2) - upsampled.size(2)  # height mismatch
        dw = x2.size(3) - upsampled.size(3)  # width mismatch

        # F.pad order for the last two dims is (left, right, top, bottom);
        # an odd mismatch puts the extra pixel on the right / bottom.
        half_w, half_h = dw // 2, dh // 2
        upsampled = F.pad(upsampled, [half_w, dw - half_w, half_h, dh - half_h])

        # Skip features first, then decoder features, along the channel axis
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


class UNetBaseline(nn.Module):
    """U-Net with four down-sampling and four up-sampling stages.

    Encoder widths double at every stage, from ``base_channels`` at full
    resolution to ``16 * base_channels`` at the bottleneck. The default
    (``base_channels=32``) is the compact 32 -> 512 network used for fast
    training. ``base_channels=64`` gives the 64 -> 1024 layout of the original
    U-Net. Layer names and creation order do not depend on the arguments, so
    state_dicts stay compatible across configurations with the same widths.

    Args:
        in_channels: Channels of the input image (e.g. 3 for RGB).
        num_classes: Output channels. Use 1 for binary segmentation with
            ``BCEWithLogitsLoss``, or C for C-class segmentation.
        base_channels: Width of the first encoder stage.
        bilinear: If True, decoder stages up-sample with bilinear
            interpolation, which does not change the channel count. The
            bottleneck and decoder widths are then halved so each ``Up``
            block still receives ``in_channels`` channels after
            concatenation. If False, learned transposed convolutions are used.

    Forward returns raw logits of shape ``(B, num_classes, H, W)``. Inputs
    whose H and W are not divisible by 16 are handled by the padding in ``Up``.
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        base_channels: int = 32,
        bilinear: bool = False,
    ):
        super(UNetBaseline, self).__init__()
        c = base_channels
        factor = 2 if bilinear else 1

        # Encoder
        self.inc = DualConv(in_channels, c)
        self.down1 = Down(c, 2 * c)
        self.down2 = Down(2 * c, 4 * c)
        self.down3 = Down(4 * c, 8 * c)

        # Bottleneck
        self.down4 = Down(8 * c, 16 * c // factor)

        # Decoder: each Up's in_channels is the width after concatenating the skip tensor
        self.up1 = Up(16 * c, 8 * c // factor, bilinear=bilinear)
        self.up2 = Up(8 * c, 4 * c // factor, bilinear=bilinear)
        self.up3 = Up(4 * c, 2 * c // factor, bilinear=bilinear)
        self.up4 = Up(2 * c, c, bilinear=bilinear)

        # 1x1 convolution mapping features to per-pixel class logits
        self.outc = nn.Conv2d(c, num_classes, kernel_size=1)

    def forward(self, x):
        # Encoder with skip connections
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)  # Bottleneck

        # Decoder with skip connections
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)

        return x

In [128]:
class SegmentationMetrics:
    """Threshold-based evaluation metrics for binary segmentation.

    Every method expects ``pred`` to hold per-pixel foreground probabilities
    (e.g. ``torch.sigmoid(logits)``) and ``target`` a ground-truth mask with
    values in {0, 1}. Both tensors are flattened before comparison, so any
    shapes with the same number of elements work, e.g. a ``pred`` of shape
    (B, 1, H, W) against a ``target`` of shape (B, H, W). Scores pool all
    pixels in the batch rather than averaging per image.

    Args:
        threshold: Probabilities strictly above this value count as foreground.
        epsilon: Smoothing term that avoids division by zero. When prediction
            and target are both empty, Dice and IoU evaluate to 1.
    """

    def __init__(self, threshold=0.5, epsilon: float = 1e-7):
        self.threshold = threshold
        self.epsilon = epsilon

    def _binarize_and_flatten(self, pred, target):
        """Threshold ``pred`` and return ``(pred, target)`` as 1-D tensors."""
        # reshape (not view) so non-contiguous inputs, e.g. permuted or sliced masks, are accepted
        pred_flat = (pred > self.threshold).float().reshape(-1)
        target_flat = target.reshape(-1)
        return pred_flat, target_flat

    def compute_dice_score(self, pred, target) -> float:
        """
        Compute Dice Score.
        Formula:
            Dice Score = (2 * |A ∩ B|) / (|A| + |B|)

        Args:
            pred: Predicted foreground probabilities (after sigmoid), e.g. (B, H, W)
            target: Ground truth mask, e.g. (B, H, W) - binary values {0, 1}

        Returns:
            Dice score value
        """
        pred_flat, target_flat = self._binarize_and_flatten(pred, target)

        intersection = (pred_flat * target_flat).sum()
        dice = (2. * intersection + self.epsilon) / (pred_flat.sum() + target_flat.sum() + self.epsilon)
        return dice.item()

    def compute_iou(self, pred, target) -> float:
        """
        Compute Intersection over Union (IoU).
        Formula:
            IoU = |A ∩ B| / |A ∪ B|

        Args:
            pred: Predicted foreground probabilities (after sigmoid), e.g. (B, H, W)
            target: Ground truth mask, e.g. (B, H, W) - binary values {0, 1}
        Returns:
            IoU score value
        """
        pred_flat, target_flat = self._binarize_and_flatten(pred, target)

        intersection = (pred_flat * target_flat).sum()
        union = pred_flat.sum() + target_flat.sum() - intersection
        iou = (intersection + self.epsilon) / (union + self.epsilon)
        return iou.item()

    def compute_pixel_accuracy(self, pred, target) -> float:
        """
        Compute the fraction of pixels whose thresholded prediction equals the target.

        Both tensors are flattened first, like the other metrics, so a
        (B, 1, H, W) prediction cannot broadcast against a (B, H, W) mask.

        Args:
            pred: Predicted foreground probabilities (after sigmoid)
            target: Ground truth mask - binary values {0, 1}
        Returns:
            Pixel accuracy in [0, 1]
        """
        pred_flat, target_flat = self._binarize_and_flatten(pred, target)
        correct = (pred_flat == target_flat).float().sum()  # type: ignore
        return (correct / target_flat.numel()).item()

In [130]:
class DiceLoss(nn.Module):
    """Soft Dice loss for binary segmentation, computed from raw logits.

    Dice is evaluated on sigmoid probabilities without thresholding, so the
    loss stays differentiable. It is pooled over every element of the batch.

    Args:
        epsilon: Smoothing term added to numerator and denominator to avoid
            division by zero when prediction and target are both empty.
    """

    def __init__(self, epsilon: float = 1e-7):
        super(DiceLoss, self).__init__()
        self.epsilon = epsilon

    def forward(self, pred, target) -> torch.Tensor:
        """
        Calculate Dice Loss.
        Formula:
            Dice Score = (2 * |A ∩ B|) / (|A| + |B|)
            Dice Loss = 1 - Dice Score

        Args:
            pred: Raw logits (before sigmoid), e.g. (B, H, W) or (B, 1, H, W)
            target: Ground truth mask with the same number of elements as pred,
                e.g. (B, H, W) - binary values {0, 1}
        Returns:
            Scalar Dice loss tensor
        """
        probs = torch.sigmoid(pred).reshape(-1)
        # reshape (not view) so non-contiguous masks (sliced, permuted, transposed) work
        target_flat = target.reshape(-1)

        intersection = (probs * target_flat).sum()
        dice = (2. * intersection + self.epsilon) / (probs.sum() + target_flat.sum() + self.epsilon)

        return 1 - dice


class CombinedLoss(nn.Module):
    """Weighted sum of binary cross-entropy (on logits) and soft Dice loss.

    Args:
        bce_weight: Multiplier for the ``BCEWithLogitsLoss`` term.
        dice_weight: Multiplier for the ``DiceLoss`` term.
    """

    def __init__(self, bce_weight=0.5, dice_weight=0.5):
        super(CombinedLoss, self).__init__()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight

        self.bce_loss = nn.BCEWithLogitsLoss()
        self.dice_loss = DiceLoss()

    def forward(self, pred, target) -> torch.Tensor:
        """
        Args:
            pred: Raw logits, e.g. (B, H, W) or (B, 1, H, W)
            target: Binary mask {0, 1}, either the same shape as pred or
                pred's shape without the singleton channel axis, e.g. (B, H, W)
        Returns:
            Scalar combined loss tensor
        """
        target = target.float()
        # BCEWithLogitsLoss requires identical shapes; accept (B, 1, H, W) logits with a (B, H, W) mask
        if pred.dim() == target.dim() + 1 and pred.size(1) == 1:
            target = target.unsqueeze(1)

        bce = self.bce_loss(pred, target)
        dice = self.dice_loss(pred, target)

        return self.bce_weight * bce + self.dice_weight * dice

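##### For Binary Segmentation

Smoke test: one forward pass on random images and masks, to confirm tensor shapes and that the loss and metrics run end to end. With random data the printed numbers only need to be finite and in range; they say nothing about model quality.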

In [136]:
NUM_CLASSES = 2  # semantic classes (background, foreground); the binary model below emits ONE logit for them
EPSILON = 1e-7
THRESHOLD = 0.5  # probability cut-off the metrics use to binarize predictions
BATCH_SIZE = 8
HEIGHT, WIDTH = 128, 128
CHANNELS = 3
# Use an accelerator when present: CUDA first, then Apple MPS, else CPU.
# torch.backends.mps.is_available() is the documented MPS availability check.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

# Seed before building the model and drawing the random data so the run is reproducible
torch.manual_seed(0)

# Metrics (expect probabilities, not logits)
metrics = SegmentationMetrics(threshold=THRESHOLD, epsilon=EPSILON)

# Loss function: BCEWithLogitsLoss takes raw logits
criterion = nn.BCEWithLogitsLoss()
# criterion = CombinedLoss(bce_weight=0.5, dice_weight=0.5)

# num_classes=1: a single foreground logit per pixel for binary segmentation with BCEWithLogitsLoss
model = UNetBaseline(in_channels=CHANNELS, num_classes=1).to(DEVICE)

# 1. Random images and binary masks
# images: [batch_size, channels, height, width] = [8, 3, 128, 128]
images = torch.randn(BATCH_SIZE, CHANNELS, HEIGHT, WIDTH).to(DEVICE)
# masks: [batch_size, height, width] = [8, 128, 128]; torch.randint yields int64 values in {0, 1}
masks = torch.randint(0, 2, (BATCH_SIZE, HEIGHT, WIDTH)).to(DEVICE)

assert images.shape == (BATCH_SIZE, CHANNELS, HEIGHT, WIDTH)
assert masks.shape == (BATCH_SIZE, HEIGHT, WIDTH)

# 2. Forward pass -> logits of shape [batch_size, 1, height, width]
pred = model(images)
assert pred.shape == (BATCH_SIZE, 1, HEIGHT, WIDTH)

# 3. Compute loss
# BCEWithLogitsLoss needs pred and target of identical shape, so drop the channel axis
pred = pred.squeeze(1)  # [batch_size, height, width]
loss = criterion(pred, masks.float())  # scalar tensor

# Previously recorded values:
# BCEWithLogitsLoss: 0.7238839864730835
# CombinedLoss: 0.5920742750167847

# 4. Compute metrics on probabilities
pred_probs = torch.sigmoid(pred)
iou = metrics.compute_iou(pred_probs, masks)
dice = metrics.compute_dice_score(pred_probs, masks)
pixel_acc = metrics.compute_pixel_accuracy(pred_probs, masks)

print(f"Loss: {loss.item()}")
print(f"IoU: {iou}")
print(f"Dice Score: {dice}")
print(f"Pixel Accuracy: {pixel_acc}")

Loss: 0.7238839864730835
IoU: 0.4617643356323242
Dice Score: 0.6317903995513916
Pixel Accuracy: 0.49936676025390625


In [146]:
class MultiClassDiceLoss(nn.Module):
    """Soft Dice loss for multi-class segmentation, macro-averaged over classes.

    Args:
        num_classes: Number of classes C; ``pred`` must have C channels.
        epsilon: Smoothing term that avoids division by zero.
    """

    def __init__(self, num_classes: int, epsilon: float = 1e-7):
        super(MultiClassDiceLoss, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon

    def forward(self, pred, target):
        """
        Compute Dice Loss for multi-class segmentation.
        Formula:
            Dice Score_c = (2 * |A_c ∩ B_c|) / (|A_c| + |B_c|)   for each class c
            Dice Loss = 1 - mean_c(Dice Score_c)
        Args:
            pred: Logits (before softmax) of shape (B, C, *spatial), e.g. (B, C, H, W)
                or (B, C, D, H, W)
            target: Class indices of shape (B, *spatial) with values {0, 1, ..., C-1};
                any integer dtype is accepted
        Returns:
            Scalar loss tensor
        Raises:
            ValueError: if pred does not have num_classes channels
        """
        if pred.size(1) != self.num_classes:
            raise ValueError(
                f"pred has {pred.size(1)} class channels, expected {self.num_classes}"
            )

        # Convert logits to probabilities
        probs = F.softmax(pred, dim=1)  # [B, C, *spatial]

        # one_hot only accepts int64 indices and appends the class axis last;
        # move it to dim 1 so it lines up with probs.
        target_oh = F.one_hot(target.long(), self.num_classes).movedim(-1, 1).float()

        # Reduce over the batch and every spatial axis at once -> one value per class
        reduce_dims = (0, *range(2, probs.dim()))
        intersection = (probs * target_oh).sum(dim=reduce_dims)
        cardinality = probs.sum(dim=reduce_dims) + target_oh.sum(dim=reduce_dims)
        dice = (2. * intersection + self.epsilon) / (cardinality + self.epsilon)  # [C]

        # Average Dice over all classes
        return 1 - dice.mean()


class MultiClassCombinedLoss(nn.Module):
    """Weighted sum of cross-entropy and soft multi-class Dice loss.

    Args:
        num_classes: Number of classes C, forwarded to ``MultiClassDiceLoss``.
        ce_weight: Multiplier for the ``CrossEntropyLoss`` term.
        dice_weight: Multiplier for the Dice term.
    """

    def __init__(self, num_classes: int, ce_weight=0.5, dice_weight=0.5):
        super(MultiClassCombinedLoss, self).__init__()
        self.ce_weight = ce_weight
        self.dice_weight = dice_weight

        self.ce_loss = nn.CrossEntropyLoss()
        self.dice_loss = MultiClassDiceLoss(num_classes)

    def forward(self, pred, target):
        """
        Args:
            pred: Logits of shape (B, C, H, W)
            target: Class indices of shape (B, H, W); any integer dtype is accepted
        Returns:
            Scalar combined loss tensor
        """
        # CrossEntropyLoss with class-index targets and F.one_hot both require int64
        target = target.long()

        ce = self.ce_loss(pred, target)
        dice = self.dice_loss(pred, target)

        return self.ce_weight * ce + self.dice_weight * dice

In [153]:
class MultiClassSegmentationMetrics:
    """Argmax-based Dice and IoU for multi-class segmentation.

    ``pred`` holds per-class scores of shape (B, C, *spatial). Raw logits and
    softmax probabilities give identical results because only the argmax over
    the class axis is used. ``target`` holds class indices of shape
    (B, *spatial) with values in {0, ..., C-1}; any integer dtype is accepted.
    Each metric is computed per class over all pixels in the batch, then
    averaged uniformly over classes. Because of the epsilon smoothing, a class
    absent from both prediction and target scores 1.

    Args:
        num_classes: Number of classes C.
        epsilon: Smoothing term that avoids division by zero.
    """

    def __init__(self, num_classes: int, epsilon: float = 1e-7):
        self.num_classes = num_classes
        self.epsilon = epsilon

    def _per_class_counts(self, pred, target):
        """Return per-class (intersection, predicted count, target count), each of shape (C,)."""
        pred_classes = torch.argmax(pred, dim=1)  # [B, *spatial]

        # one_hot requires int64 indices and appends the class axis last: [B, *spatial, C]
        pred_oh = F.one_hot(pred_classes, self.num_classes).float()
        target_oh = F.one_hot(target.long(), self.num_classes).float()

        # Reduce over every axis except the trailing class axis
        reduce_dims = tuple(range(pred_oh.dim() - 1))
        intersection = (pred_oh * target_oh).sum(dim=reduce_dims)
        return intersection, pred_oh.sum(dim=reduce_dims), target_oh.sum(dim=reduce_dims)

    def compute_dice_score(self, pred, target) -> float:
        """
        Compute average Dice Score for multi-class segmentation.
        Args:
            pred: Class scores (logits or softmax probabilities) of shape (B, C, H, W)
            target: Ground truth class indices of shape (B, H, W) - values {0, 1, ..., C-1}
        Returns:
            Dice score averaged over classes
        """
        intersection, pred_count, target_count = self._per_class_counts(pred, target)
        dice = (2. * intersection + self.epsilon) / (pred_count + target_count + self.epsilon)
        return dice.mean().item()

    def compute_iou(self, pred, target) -> float:
        """
        Compute average IoU for multi-class segmentation.
        Args:
            pred: Class scores (logits or softmax probabilities) of shape (B, C, H, W)
            target: Ground truth class indices of shape (B, H, W) - values {0, 1, ..., C-1}
        Returns:
            IoU averaged over classes
        """
        intersection, pred_count, target_count = self._per_class_counts(pred, target)
        union = pred_count + target_count - intersection
        iou = (intersection + self.epsilon) / (union + self.epsilon)
        return iou.mean().item()

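##### For Multi-class Segmentation

Smoke test: one forward pass with a 3-class head on random images and class-index masks, to confirm shapes and that the loss and metrics run end to end. With random data the printed values say nothing about model quality.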

In [155]:
NUM_CLASSES = 3
EPSILON = 1e-7
BATCH_SIZE = 8
HEIGHT, WIDTH = 128, 128
CHANNELS = 3
# Use an accelerator when present: CUDA first, then Apple MPS, else CPU.
# torch.backends.mps.is_available() is the documented MPS availability check.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

# Seed before building the model and drawing the random data so the run is reproducible
torch.manual_seed(0)

# Metrics: argmax over the class axis, so raw logits can be passed directly
metrics = MultiClassSegmentationMetrics(num_classes=NUM_CLASSES, epsilon=EPSILON)

# Loss function
# criterion = nn.CrossEntropyLoss()
# criterion = MultiClassDiceLoss(num_classes=NUM_CLASSES, epsilon=EPSILON)
criterion = MultiClassCombinedLoss(num_classes=NUM_CLASSES)

# One output channel (logit) per class
model = UNetBaseline(in_channels=CHANNELS, num_classes=NUM_CLASSES).to(DEVICE)

# 1. Random images and class-index masks
# images: [batch_size, channels, height, width] = [8, 3, 128, 128]
images = torch.randn(BATCH_SIZE, CHANNELS, HEIGHT, WIDTH).to(DEVICE)
# masks: [batch_size, height, width] = [8, 128, 128]; int64 values in {0, ..., NUM_CLASSES-1}
masks = torch.randint(0, NUM_CLASSES, (BATCH_SIZE, HEIGHT, WIDTH)).to(DEVICE)

assert images.shape == (BATCH_SIZE, CHANNELS, HEIGHT, WIDTH)
assert masks.shape == (BATCH_SIZE, HEIGHT, WIDTH)

# 2. Forward pass -> logits of shape [batch_size, num_classes, height, width]
pred = model(images)
assert pred.shape == (BATCH_SIZE, NUM_CLASSES, HEIGHT, WIDTH)

# 3. Compute loss (scalar tensor)
loss = criterion(pred, masks)

# Previously recorded values:
# CrossEntropyLoss: 1.1529099941253662
# MultiClassDiceLoss: 0.6677639484405518
# MultiClassCombinedLoss: 0.910336971282959

# 4. Compute metrics
iou = metrics.compute_iou(pred, masks)
dice = metrics.compute_dice_score(pred, masks)

print(f"Loss: {loss.item()}")
print(f"IoU: {iou}")
print(f"Dice Score: {dice}")

Loss: 0.910336971282959
IoU: 0.19103892147541046
Dice Score: 0.31671738624572754
