# 02. Train U-Net
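This notebook trains a U-Net model on the COCO 2017 dataset to predict a single-channel segmentation mask, using a combined BCE + Dice loss and CUDA mixed-precision training.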

In [None]:
import sys
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

sys.path.append('../src')
from dataset import COCOSegmentationDataset, get_training_augmentations, get_validation_augmentations
from model import UNet
from utils import visualize

In [8]:
def dice_loss(pred_logits, target, smooth = 1.):
    """Soft Dice loss computed directly from raw logits.

    Args:
        pred_logits: Raw model outputs. A sigmoid is applied here, so callers
            must not pre-activate them. The two successive reductions over
            dim 2 require at least a 4-D tensor (presumably (N, C, H, W) --
            confirm against the model output).
        target: Ground-truth mask. It is multiplied element-wise with the
            probabilities, so it must broadcast against ``pred_logits``.
        smooth: Constant added to both numerator and denominator; keeps the
            ratio defined when prediction and target sums are both zero.

    Returns:
        Scalar tensor: the mean of ``1 - dice`` over all entries left after
        the two reductions.
    """
    probs = torch.sigmoid(pred_logits).contiguous()
    mask = target.contiguous()

    def _reduce(t):
        # Sum over dim 2 twice: after the first sum the former dim 3
        # has shifted down to become dim 2.
        return t.sum(dim=2).sum(dim=2)

    overlap = _reduce(probs * mask)
    denominator = _reduce(probs) + _reduce(mask) + smooth

    dice = (2. * overlap + smooth) / denominator
    return (1 - dice).mean()

In [9]:
# Hyperparameters
SEED = 42 # Fixed seed so torch-driven randomness repeats across runs
LEARNING_RATE = 1e-4
BATCH_SIZE = 4 # Adjust based on GPU memory
NUM_EPOCHS = 10
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TRAIN_DIR = '../data/coco2017/train2017'
TRAIN_ANN = '../data/coco2017/annotations/instances_train2017.json'
VAL_DIR = '../data/coco2017/val2017'
VAL_ANN = '../data/coco2017/annotations/instances_val2017.json'

# Seed torch's global RNG before any model or DataLoader is created, so that
# anything drawing from it later starts from a known state.
torch.manual_seed(SEED)

print(f"Using device: {DEVICE}")

Using device: cuda


In [10]:
# Dataset & DataLoader
# Train and validation images live in separate folders, so each split
# gets its own dataset object and its own augmentation pipeline.
_loader_options = {"batch_size": BATCH_SIZE, "num_workers": 2}

train_dataset = COCOSegmentationDataset(
    TRAIN_DIR,
    TRAIN_ANN,
    transforms=get_training_augmentations(),
)
val_dataset = COCOSegmentationDataset(
    VAL_DIR,
    VAL_ANN,
    transforms=get_validation_augmentations(),
)

# Only the training split is shuffled; validation keeps a fixed order.
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)

print(f"Train size: {len(train_dataset)}, Val size: {len(val_dataset)}")

  original_init(self, **validated_kwargs)


loading annotations into memory...
Done (t=14.06s)
creating index...
index created!
loading annotations into memory...
Done (t=0.62s)
creating index...
index created!
Train size: 117266, Val size: 4952


In [11]:
# Model Setup
# Single output channel: the network emits one logit map per image.
model = UNet(n_channels=3, n_classes=1).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# BCEWithLogitsLoss combines Sigmoid and BCELoss for numerical stability
bce_criterion = nn.BCEWithLogitsLoss()

# Gradient scaler for mixed precision (torch.amp API of newer PyTorch versions).
# Loss scaling only applies on CUDA, so the scaler is created disabled when
# DEVICE is "cpu"; a disabled scaler turns scale()/step()/update() into plain
# pass-throughs instead of warning about missing CUDA.
scaler = torch.amp.GradScaler('cuda', enabled=(DEVICE == "cuda"))

In [12]:
# Training Loop
# Mixed precision is only used when running on CUDA; on CPU autocast is
# switched off explicitly rather than relying on PyTorch to warn and disable it.
use_amp = DEVICE == "cuda"

for epoch in range(NUM_EPOCHS):
    model.train()
    # The description is constant within an epoch, so set it once here.
    loop = tqdm(train_loader, leave=True, desc=f"Epoch [{epoch+1}/{NUM_EPOCHS}]")
    train_loss = 0  # running sum of per-batch losses for this epoch

    for data, targets in loop:
        data = data.to(DEVICE)
        targets = targets.to(DEVICE)

        # Forward
        with torch.amp.autocast('cuda', enabled=use_amp):
            logits = model(data)
            # Both loss terms consume raw logits (each applies its own sigmoid)
            loss = bce_criterion(logits, targets) + dice_loss(logits, targets)

        # Backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Update tqdm
        # .item() forces a host/device sync, so read the value only once per step.
        batch_loss = loss.item()
        train_loss += batch_loss
        loop.set_postfix(loss=batch_loss)

    print(f"Epoch {epoch+1} Average Loss: {train_loss/len(train_loader):.4f}")

    # Save Model
    # Weights-only snapshot every 5th epoch.
    if (epoch + 1) % 5 == 0:
        torch.save(model.state_dict(), f"unet_coco_epoch_{epoch+1}.pth")
        print("Model saved!")

Epoch [1/10]: 100%|██████████| 29317/29317 [2:56:58<00:00,  2.76it/s, loss=1.09]   


Epoch 1 Average Loss: 1.3099


Epoch [2/10]: 100%|██████████| 29317/29317 [2:58:13<00:00,  2.74it/s, loss=1.16]   


Epoch 2 Average Loss: 1.2838


Epoch [3/10]: 100%|██████████| 29317/29317 [2:57:01<00:00,  2.76it/s, loss=1.03]   


Epoch 3 Average Loss: 1.2765


Epoch [4/10]: 100%|██████████| 29317/29317 [2:54:42<00:00,  2.80it/s, loss=1.28]   


Epoch 4 Average Loss: 1.2718


Epoch [5/10]: 100%|██████████| 29317/29317 [2:41:55<00:00,  3.02it/s, loss=1.04]   


Epoch 5 Average Loss: 1.2685
Model saved!


Epoch [6/10]:  29%|██▊       | 8417/29317 [45:59<1:54:12,  3.05it/s, loss=1.32] 


KeyboardInterrupt: 

In [13]:
# Full training snapshot: weights plus optimizer and AMP scaler state, so a
# later run can resume instead of restarting from the weights alone.
# NOTE: `epoch` and `train_loss` are whatever the training loop last left
# behind. If the loop was interrupted mid-epoch, 'epoch' is the 1-based number
# of the epoch that was in progress (not a completed one), and 'loss' is the
# running SUM of batch losses for that partial epoch, not an average.
checkpoint = {
    'epoch': epoch + 1,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(), # Saves momentum
    'scaler_state_dict': scaler.state_dict(), # AMP loss-scale state, needed to resume mixed-precision training
    'loss': train_loss,
}
torch.save(checkpoint, "full_checkpoint.pth")