In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim import Optimizer
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset
from tqdm import tqdm
from metrics import SegmentationMetrics, MultiClassSegmentationMetrics
from models import UNetBaseline, ResUNet, AttentionUNet, TransUNet
from losses import CombinedLoss, MultiClassDiceLoss, MultiClassCombinedLoss, DiceLoss

##### For Binary Segmentation

In [9]:
# Smoke test for binary segmentation: push one random batch through the model,
# then compute the loss and the IoU / Dice metrics on the result.

NUM_CLASSES = 2  # mask labels are {0, 1}; the model itself emits one logit channel
EPSILON = 1e-7
THRESHOLD = 0.5
BATCH_SIZE = 8
HEIGHT, WIDTH = 128, 128
CHANNELS = 3
# torch.backends.mps.is_available() is the documented availability check and is
# also present on PyTorch releases that do not ship torch.mps.is_available().
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Seed before the model is built so that both the weight initialisation and
# the random batch below are repeatable.
torch.manual_seed(0)

# Metrics
metrics = SegmentationMetrics(threshold=THRESHOLD, epsilon=EPSILON)

# Loss function (alternative kept for quick swapping)
# criterion = nn.BCEWithLogitsLoss()
criterion = CombinedLoss(bce_weight=0.5, dice_weight=0.5)

# Binary segmentation uses a single output channel (num_classes=1).
# Alternative architectures kept for quick swapping:
# model = ResUNet(in_channels=CHANNELS, num_classes=1).to(DEVICE)
# model = AttentionUNet(in_channels=CHANNELS, num_classes=1).to(DEVICE)
model = TransUNet(in_channels=CHANNELS, num_classes=1).to(DEVICE)

# 1. Init images and sample masks

# Shape: [batch_size, channels, height, width] : [8, 3, 128, 128]
# Sampled on the CPU generator and then moved, so the values do not depend on
# which device is selected.
images = torch.randn(BATCH_SIZE, CHANNELS, HEIGHT, WIDTH).to(DEVICE)
# Shape: [batch_size, height, width] : [8, 128, 128]
masks = torch.randint(0, 2, (BATCH_SIZE, HEIGHT, WIDTH)).long().to(DEVICE)  # {0, 1}

assert images.shape == (BATCH_SIZE, CHANNELS, HEIGHT, WIDTH)
assert masks.shape == (BATCH_SIZE, HEIGHT, WIDTH)

# 2. Forward pass

# Shape: [batch_size, 1, height, width] : [8, 1, 128, 128]
pred = model(images)
assert pred.shape == (BATCH_SIZE, 1, HEIGHT, WIDTH)

# 3. Compute loss
# The criterion receives the raw model output; sigmoid is only applied further
# down for the metrics.
# NOTE(review): presumably CombinedLoss applies its own sigmoid - confirm in losses.py.
pred = pred.squeeze(1)  # Shape: [batch_size, height, width]
loss = criterion(pred, masks.float())  # float targets to match the prediction dtype

# Loss values noted from earlier runs:
#   BCEWithLogitsLoss: 0.7238839864730835
#   CombinedLoss: 0.5920742750167847
# NOTE(review): the saved output of this cell shows a different CombinedLoss
# value; the model behind the numbers above is not recorded - re-record them.

# 4. Compute metrics
pred_probs = F.sigmoid(pred)  # Convert logits to probabilities
iou = metrics.compute_iou(pred_probs, masks)
dice = metrics.compute_dice_score(pred_probs, masks)

print("Loss:", loss.item())
print("IoU:", iou)
print("Dice:", dice)

Loss: 0.5896185040473938
IoU: 0.4748704135417938
Dice: 0.6439486742019653


##### For Multi-class Segmentation

In [3]:
# Smoke test for multi-class segmentation with ResUNet: push one random batch
# through the model, then compute the loss and the IoU / Dice metrics.

NUM_CLASSES = 3
EPSILON = 1e-7
THRESHOLD = 0.5  # not used in this cell
BATCH_SIZE = 8
HEIGHT, WIDTH = 128, 128
CHANNELS = 3
# torch.backends.mps.is_available() is the documented availability check and is
# also present on PyTorch releases that do not ship torch.mps.is_available().
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Seed before the model is built so that both the weight initialisation and
# the random batch below are repeatable.
torch.manual_seed(0)

# Metrics
metrics = MultiClassSegmentationMetrics(num_classes=NUM_CLASSES, epsilon=EPSILON)

# Loss function (alternatives kept for quick swapping)
# criterion = nn.CrossEntropyLoss()
# criterion = MultiClassDiceLoss(num_classes=NUM_CLASSES, epsilon=EPSILON)
criterion = MultiClassCombinedLoss(num_classes=NUM_CLASSES)

# One output channel per class (num_classes = 3)
# model = UNetBaseline(in_channels=CHANNELS, num_classes=NUM_CLASSES).to(DEVICE)
model = ResUNet(in_channels=CHANNELS, num_classes=NUM_CLASSES).to(DEVICE)

# 1. Init images and sample masks
# Shape: [batch_size, channels, height, width] : [8, 3, 128, 128]
# Sampled on the CPU generator and then moved, so the values do not depend on
# which device is selected.
images = torch.randn(BATCH_SIZE, CHANNELS, HEIGHT, WIDTH).to(DEVICE)
# Shape: [batch_size, height, width] : [8, 128, 128], integer labels in [0, NUM_CLASSES)
masks = torch.randint(0, NUM_CLASSES, (BATCH_SIZE, HEIGHT, WIDTH)).long().to(DEVICE)

assert images.shape == (BATCH_SIZE, CHANNELS, HEIGHT, WIDTH)
assert masks.shape == (BATCH_SIZE, HEIGHT, WIDTH)

# 2. Forward pass
# Shape: [batch_size, num_classes, height, width] : [8, 3, 128, 128]
pred = model(images)
assert pred.shape == (BATCH_SIZE, NUM_CLASSES, HEIGHT, WIDTH)

# 3. Compute loss
loss = criterion(pred, masks)  # printed below via loss.item()

# Loss values noted from earlier runs:
#   CrossEntropyLoss: 1.1529099941253662
#   MultiClassDiceLoss: 0.6677639484405518
#   MultiClassCombinedLoss: 0.910336971282959
# NOTE(review): the saved output of this cell shows a different
# MultiClassCombinedLoss value; the model behind the numbers above is not
# recorded - re-record them.

# 4. Compute metrics
# The metrics are given the raw model output.
# NOTE(review): presumably MultiClassSegmentationMetrics reduces it to class
# labels itself - confirm in metrics.py.
iou = metrics.compute_iou(pred, masks)
dice = metrics.compute_dice_score(pred, masks)

print(f"Loss: {loss.item()}")
print(f"IoU: {iou}")
print(f"Dice Score: {dice}")

Bridge shape: torch.Size([8, 512, 8, 8])
Loss: 0.9455935955047607
IoU: 0.1975846290588379
Dice Score: 0.3287372887134552


In [7]:
# Smoke test for multi-class segmentation with an explicitly configured
# TransUNet: push one random batch through the model, then compute the loss
# and the IoU / Dice metrics.

NUM_CLASSES = 3
EPSILON = 1e-7
THRESHOLD = 0.5  # not used in this cell
BATCH_SIZE = 8
HEIGHT, WIDTH = 128, 128
CHANNELS = 3
# torch.backends.mps.is_available() is the documented availability check and is
# also present on PyTorch releases that do not ship torch.mps.is_available().
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Configuration for TransUNet
PATCH_SIZE = 16
EMBEDDING_DIM = 768
DEPTH = 12
NUM_HEADS = 12

# Seed before the model is built so that both the weight initialisation and
# the random batch below are repeatable.
torch.manual_seed(0)

model = TransUNet(
    img_size=HEIGHT,  # inputs are square (HEIGHT == WIDTH), so one side is passed
    in_channels=CHANNELS,
    num_classes=NUM_CLASSES,
    patch_size=PATCH_SIZE,
    embedding_dim=EMBEDDING_DIM,
    depth=DEPTH,
    num_heads=NUM_HEADS,
).to(DEVICE)

# Metrics
metrics = MultiClassSegmentationMetrics(num_classes=NUM_CLASSES, epsilon=EPSILON)

# Loss function (alternatives kept for quick swapping)
# criterion = nn.CrossEntropyLoss()
# criterion = MultiClassDiceLoss(num_classes=NUM_CLASSES, epsilon=EPSILON)
criterion = MultiClassCombinedLoss(num_classes=NUM_CLASSES)

# 1. Init images and sample masks
# Shape: [batch_size, channels, height, width] : [8, 3, 128, 128]
# Sampled on the CPU generator and then moved, so the values do not depend on
# which device is selected.
images = torch.randn(BATCH_SIZE, CHANNELS, HEIGHT, WIDTH).to(DEVICE)
# Shape: [batch_size, height, width] : [8, 128, 128], integer labels in [0, NUM_CLASSES)
masks = torch.randint(0, NUM_CLASSES, (BATCH_SIZE, HEIGHT, WIDTH)).long().to(DEVICE)

assert images.shape == (BATCH_SIZE, CHANNELS, HEIGHT, WIDTH)
assert masks.shape == (BATCH_SIZE, HEIGHT, WIDTH)

# 2. Forward pass
# Shape: [batch_size, num_classes, height, width] : [8, 3, 128, 128]
pred = model(images)
assert pred.shape == (BATCH_SIZE, NUM_CLASSES, HEIGHT, WIDTH)

# 3. Compute loss
loss = criterion(pred, masks)  # printed below via loss.item()

# Loss values noted from earlier runs:
#   CrossEntropyLoss: 1.1529099941253662
#   MultiClassDiceLoss: 0.6677639484405518
#   MultiClassCombinedLoss: 0.910336971282959
# NOTE(review): these are the same numbers recorded in the ResUNet cell and
# the saved output of this cell shows a different MultiClassCombinedLoss value;
# re-record them for TransUNet.

# 4. Compute metrics
# The metrics are given the raw model output.
# NOTE(review): presumably MultiClassSegmentationMetrics reduces it to class
# labels itself - confirm in metrics.py.
iou = metrics.compute_iou(pred, masks)
dice = metrics.compute_dice_score(pred, masks)

print(f"Loss: {loss.item()}")
print(f"IoU: {iou}")
print(f"Dice Score: {dice}")

Loss: 0.9132993817329407
IoU: 0.1756775826215744
Dice Score: 0.2911100387573242
