In [None]:
import os
import glob
import random
import numpy as np
import cv2

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm

from unet import UNet
from data_core import myDataset

In [None]:
# --- Run configuration -------------------------------------------------------
dim = 256          # used to pick the ./trainval_{dim} folder and passed to myDataset
# Trimap {1,2,3}
# NOTE(review): nn.CrossEntropyLoss needs class indices in [0, num_classes-1];
# presumably myDataset remaps the trimap labels -- confirm in data_core.
num_classes = 3
batch_size = 16
num_epochs = 100

# --- Dataset locations -------------------------------------------------------
train_images_dir = f'./trainval_{dim}/images'
train_masks_dir  = f'./trainval_{dim}/annotations'


def _sorted_pngs(directory):
    """Return the sorted list of ``*.png`` paths found directly inside `directory`."""
    pattern = os.path.join(directory, '*.png')
    matches = glob.glob(pattern)
    matches.sort()
    return matches


# Images and masks are later paired by position, so both lists are sorted.
image_paths = _sorted_pngs(train_images_dir)
mask_paths  = _sorted_pngs(train_masks_dir)

print(f'Found {len(image_paths)} images and {len(mask_paths)} masks')

Found 3680 images and 3680 masks


In [7]:
# Hold out 10% of the (image, mask) pairs for validation. The fixed
# random_state makes the shuffled split the same on every run.
_split = train_test_split(
    image_paths,
    mask_paths,
    test_size=0.1,
    random_state=42,
    shuffle=True,
)
train_img_paths, val_img_paths, train_mask_paths, val_mask_paths = _split
print(f"Training images: {len(train_img_paths)}  Validation images: {len(val_img_paths)}")

# One dataset per split, built from the matching path lists.
train_dataset = myDataset(train_img_paths, train_mask_paths, dim)
val_dataset = myDataset(val_img_paths, val_mask_paths, dim)

# Only the training loader reshuffles; validation keeps a fixed order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size)

Training images: 3312  Validation images: 368


In [9]:
# Use the first CUDA GPU when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device {device}")

Using device cuda:0


In [10]:
# Build the network first and move it to the target device before the
# optimizer is created from its parameters.
model = UNet(num_classes=num_classes)
model = model.to(device)

# Adam with a fixed learning rate of 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Cross-entropy over logits [B, C, H, W] against integer targets [B, H, W].
criterion = nn.CrossEntropyLoss()

def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over `loader` and return the mean per-sample loss.

    Args:
        model: network invoked as ``model(images)``; put into training mode.
        loader: iterable yielding ``(images, masks)`` batches.
        optimizer: optimiser whose gradients are cleared and stepped once per batch.
        criterion: callable ``criterion(outputs, masks)`` returning a scalar loss
            (assumed to be the batch mean, as with the default ``reduction``).
        device: device the batch tensors are moved to before the forward pass.

    Returns:
        float: loss averaged over the samples actually processed this epoch.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    samples_seen = 0
    for images, masks in loader:
        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)  # (B, num_classes, H, W)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by its size so a smaller last batch does not
        # skew the epoch mean.
        n_batch = images.size(0)
        loss_sum += loss.item() * n_batch
        samples_seen += n_batch

    if samples_seen == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")

    # Divide by the samples actually seen rather than len(loader.dataset):
    # the two differ with drop_last=True or a subsampling sampler, and some
    # iterables have no sized .dataset at all.
    return loss_sum / samples_seen

def evaluate(model, loader, criterion, device):
    """Evaluate `model` on `loader` without gradient tracking.

    Args:
        model: network invoked as ``model(images)``; put into eval mode.
        loader: iterable yielding ``(images, masks)`` batches.
        criterion: callable ``criterion(outputs, masks)`` returning a scalar loss
            (assumed to be the batch mean, as with the default ``reduction``).
        device: device the batch tensors are moved to before the forward pass.

    Returns:
        tuple[float, float]: ``(mean_loss, pixel_accuracy)`` where the loss is
        averaged over the samples actually processed and the accuracy is the
        fraction of mask elements whose argmax prediction matches the target.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    samples_seen = 0
    correct = 0
    total_pixels = 0
    with torch.no_grad():
        for images, masks in loader:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            loss = criterion(outputs, masks)

            # Weight by batch size so a smaller last batch does not skew the mean.
            n_batch = images.size(0)
            loss_sum += loss.item() * n_batch
            samples_seen += n_batch

            # Predicted class = index of the largest logit along the class axis.
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == masks).sum().item()
            total_pixels += masks.numel()

    if samples_seen == 0 or total_pixels == 0:
        raise ValueError("evaluate: loader yielded no samples")

    # Divide by the samples actually seen rather than len(loader.dataset):
    # the two differ with drop_last=True or a subsampling sampler.
    accuracy = correct / total_pixels
    return loss_sum / samples_seen, accuracy

In [None]:
# Create the checkpoint directory up front so the first torch.save cannot
# fail on a missing folder after a full epoch of training.
checkpoint_dir = "unet_weights"
os.makedirs(checkpoint_dir, exist_ok=True)

# The final (last-epoch) weights keep the original file name. The best
# validation-loss weights go to a separate "_best" file, so the final save
# no longer overwrites the best checkpoint.
final_ckpt_path = f"{checkpoint_dir}/unet_model_{dim}_epochs_{num_epochs}.pth"
best_ckpt_path = f"{checkpoint_dir}/unet_model_{dim}_epochs_{num_epochs}_best.pth"

best_val_loss = float('inf')
for epoch in tqdm(range(1, num_epochs+1), desc="Training Epochs"):
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)
    tqdm.write(f"\rEpoch [{epoch}/{num_epochs}]  Train Loss: {train_loss:.4f}  "
               f"Val Loss: {val_loss:.4f}  Val Pixel Acc: {val_acc*100:.2f}%")

    # Save the best model (lowest validation loss seen so far).
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), best_ckpt_path)

# Save the trained model as it stands after the last epoch.
torch.save(model.state_dict(), final_ckpt_path)

Training Epochs:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch [1/500]  Train Loss: 0.7363  Val Loss: 0.6959  Val Pixel Acc: 71.82%
Epoch [2/500]  Train Loss: 0.6264  Val Loss: 0.5947  Val Pixel Acc: 76.13%
Epoch [3/500]  Train Loss: 0.5895  Val Loss: 0.5734  Val Pixel Acc: 77.41%
Epoch [4/500]  Train Loss: 0.5436  Val Loss: 0.5412  Val Pixel Acc: 78.22%
Epoch [5/500]  Train Loss: 0.5175  Val Loss: 0.5018  Val Pixel Acc: 80.08%
Epoch [6/500]  Train Loss: 0.4870  Val Loss: 0.4749  Val Pixel Acc: 81.17%
Epoch [7/500]  Train Loss: 0.4720  Val Loss: 0.5171  Val Pixel Acc: 79.39%
Epoch [8/500]  Train Loss: 0.4591  Val Loss: 0.4928  Val Pixel Acc: 80.53%
Epoch [9/500]  Train Loss: 0.4400  Val Loss: 0.4744  Val Pixel Acc: 81.30%
Epoch [10/500]  Train Loss: 0.4273  Val Loss: 0.4641  Val Pixel Acc: 81.65%
Epoch [11/500]  Train Loss: 0.4074  Val Loss: 0.4485  Val Pixel Acc: 82.04%
Epoch [12/500]  Train Loss: 0.3953  Val Loss: 0.4528  Val Pixel Acc: 82.31%
Epoch [13/500]  Train Loss: 0.3801  Val Loss: 0.4670  Val Pixel Acc: 81.87%
Epoch [14/500]  Train