In [1]:
from src.segnet_model import SegNet
from src.dataset import SegmentationDataset 

import torch
import torchvision.transforms as transforms
import torch.optim as optim

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm

In [2]:
dataset_root_path = 'data/idd20k_lite_prepared'

data_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

train_dataset = SegmentationDataset(dataset_root_path, transform=data_transforms, mode='train')
val_dataset = SegmentationDataset(dataset_root_path, transform=data_transforms, mode='val')

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False)

DEVICE = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
# DEVICE = torch.device('cpu')

In [3]:
model = SegNet(in_channels=3, num_classes=34).to(DEVICE)
criterion = torch.nn.CrossEntropyLoss(ignore_index=255)  # Ignore index 255 for void class
optimizer = optim.Adam(model.parameters(), lr=1e-4)

def train_one_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    running_loss = 0.0

    for images, masks in tqdm(dataloader, desc="Training", leave=False):
        images = images.to(device)                # shape: [B, 3, H, W]
        masks = masks.to(device).long()           # shape: [B, H, W], values: [0, ..., 33]

        optimizer.zero_grad()
        outputs = model(images)                   # shape: [B, 34, H, W]
        loss = criterion(outputs, masks)          # CrossEntropyLoss expects this

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(dataloader)

def evaluate(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0

    with torch.no_grad():
        for images, masks in tqdm(dataloader, desc="Validation", leave=False):
            images = images.to(device)
            masks = masks.to(device).long()

            outputs = model(images)
            loss = criterion(outputs, masks)

            running_loss += loss.item()

    return running_loss / len(dataloader)

# Training Loop
NUM_EPOCHS = 20

for epoch in range(NUM_EPOCHS):
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion, DEVICE)
    val_loss = evaluate(model, val_loader, criterion, DEVICE)

    print(f"Epoch {epoch+1}/{NUM_EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

                                                  

NotImplementedError: The operator 'aten::max_unpool2d' is not currently implemented for the MPS device. If you want this op to be considered for addition please comment on https://github.com/pytorch/pytorch/issues/141287 and mention use-case, that resulted in missing op as well as commit hash 134179474539648ba7dee1317959529fbd0e7f89. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.