In [1]:
import random

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm

from dataset import PetSegmentationDataset
from models import UNet
from trainer import SegmentationTrainer

class Config:
    """Central configuration for the pet segmentation experiment.

    Every tunable knob (data, model, training, reproducibility) lives here so
    the remaining cells read their settings from a single place.
    """

    # Data configs
    DATA_ROOT = './Dataset'  # relative to the notebook's working directory
    IMG_SIZE = (256, 256)
    BATCH_SIZE = 8
    NUM_WORKERS = 2

    # Model configs
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    NUM_CLASSES = 3  # background, cat, dog

    # Training configs
    LEARNING_RATE = 3e-4
    NUM_EPOCHS = 100
    EARLY_STOPPING_PATIENCE = 10
    WEIGHT_DECAY = 1e-5

    # Reproducibility configs
    SEED = 42


# Seed every RNG the pipeline may draw from before the model is built or any
# loader is iterated. torch.manual_seed drives weight init and DataLoader
# shuffling. numpy / random are seeded too, in case the dataset's augmentation
# uses them (not visible from this notebook). Together these make
# Restart & Run All repeatable.
torch.manual_seed(Config.SEED)
np.random.seed(Config.SEED)
random.seed(Config.SEED)

print(f"Using device: {Config.DEVICE}")

  Referenced from: <EB3FF92A-5EB1-3EE8-AF8B-5923C1265422> /opt/anaconda3/envs/torch/lib/python3.11/site-packages/torchvision/image.so
  warn(


Using device: cpu


In [2]:
def create_dataloaders():
    """Build the train / val / test DataLoaders for the pet segmentation data.

    Only the training split is augmented and shuffled; the val and test splits
    are read in a fixed order without augmentation. All three share
    Config's image size, batch size and worker count.

    Returns:
        tuple: (train_loader, val_loader, test_loader), each a
        torch.utils.data.DataLoader.
    """
    # Pinned host memory only speeds up host->GPU transfers; on a CPU-only run
    # it is useless, and PyTorch warns about it. Enable it only for CUDA.
    pin_memory = Config.DEVICE.type == 'cuda'
    # Keep worker processes alive between epochs instead of re-spawning them
    # for every pass over the data. This is only legal when workers exist.
    persistent_workers = Config.NUM_WORKERS > 0

    def _make_loader(split, train):
        # Builds one split's dataset and loader; `train` toggles augmentation and shuffling.
        dataset = PetSegmentationDataset(
            Config.DATA_ROOT,
            split=split,
            img_size=Config.IMG_SIZE,
            augment=train,
        )
        return DataLoader(
            dataset,
            batch_size=Config.BATCH_SIZE,
            shuffle=train,
            num_workers=Config.NUM_WORKERS,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
        )

    train_loader = _make_loader('train', train=True)
    val_loader = _make_loader('val', train=False)
    test_loader = _make_loader('test', train=False)
    return train_loader, val_loader, test_loader

train_loader, val_loader, test_loader = create_dataloaders()

# Report how many batches each split yields.
# NOTE(review): the recorded output shows the same batch count for training and
# validation (460 each). Confirm the 'val' split is disjoint from 'train' rather
# than a copy of it; otherwise validation / early stopping sees training data.
for split_label, loader in (
    ("training", train_loader),
    ("validation", val_loader),
    ("test", test_loader),
):
    print(f"Number of {split_label} batches: {len(loader)}")

Number of training batches: 460
Number of validation batches: 460
Number of test batches: 464


In [3]:
# Build the U-Net (3 input channels, Config.NUM_CLASSES output classes) on the target device.
model = UNet(n_channels=3, n_classes=Config.NUM_CLASSES).to(Config.DEVICE)
# NOTE(review): this calls the private `_init_weights` explicitly. Confirm that
# UNet's constructor does not already do this; if it does, this is a redundant re-init.
model._init_weights()

# Optimizer and loss. Adam's `weight_decay` is applied as a coupled L2 penalty
# on the gradients. If decoupled weight decay is intended, use torch.optim.AdamW.
adam_kwargs = {
    'lr': Config.LEARNING_RATE,
    'weight_decay': Config.WEIGHT_DECAY,
}
optimizer = torch.optim.Adam(model.parameters(), **adam_kwargs)
criterion = nn.CrossEntropyLoss()

# The trainer bundles model, device, loss and optimizer for the training cell below.
trainer = SegmentationTrainer(
    model=model,
    device=Config.DEVICE,
    criterion=criterion,
    optimizer=optimizer,
)

# Parameter summary, computed in a single pass over the parameters:
# all elements, plus the subset with requires_grad set.
total_params = 0
trainable_params = 0
for param in model.parameters():
    total_params += param.numel()
    if param.requires_grad:
        trainable_params += param.numel()
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

Total parameters: 31,043,651
Trainable parameters: 31,043,651


In [None]:
# Run training with early-stopping patience from Config. The stopping criterion
# is presumably evaluated on val_loader inside SegmentationTrainer (not visible here).
# NOTE(review): the recorded CPU run shows ~14.7 s per iteration over 460
# training batches, i.e. roughly 1h50m per epoch. NUM_EPOCHS=100 is impractical
# without a GPU, so reduce epochs, image size or model width for CPU runs.
fit_kwargs = dict(
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=Config.NUM_EPOCHS,
    early_stopping_patience=Config.EARLY_STOPPING_PATIENCE,
)
history = trainer.train(**fit_kwargs)


Epoch 1/100


  Referenced from: <EB3FF92A-5EB1-3EE8-AF8B-5923C1265422> /opt/anaconda3/envs/torch/lib/python3.11/site-packages/torchvision/image.so
  warn(
  Referenced from: <EB3FF92A-5EB1-3EE8-AF8B-5923C1265422> /opt/anaconda3/envs/torch/lib/python3.11/site-packages/torchvision/image.so
  warn(
Training:   1%|▍                              | 6/460 [01:29<1:51:30, 14.74s/it, loss=3.23, mean_iou=0.188]