In [None]:
import sys
from pathlib import Path

import torch
from torch import Tensor, nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.models import segmentation
from torchvision.transforms import v2
from tqdm import tqdm

# Make the parent directory importable so the `src.semantic_segmentation_toolkit`
# imports below resolve when the notebook is run from its own folder.
sys.path.append(str(Path("..").resolve()))
from src.semantic_segmentation_toolkit.datasets import DATASET_ZOO, resolve_metadata
from src.semantic_segmentation_toolkit.models import MODEL_ZOO
from src.semantic_segmentation_toolkit.utils.metrics import MetricStore
from src.semantic_segmentation_toolkit.utils.rng import seed
from src.semantic_segmentation_toolkit.utils.transform import (
    SegmentationAugment,
    SegmentationTransform,
)

In [None]:
# --- Data -------------------------------------------------------------------
# Dataset root, one directory above the notebook. Built with the `/` operator so
# it resolves on POSIX as well as Windows (a raw "..\Cityscapes" string is a
# single file name on POSIX).
root = Path("..") / "Cityscapes"
metadata = resolve_metadata("Cityscapes")

# Shared preprocessing for both splits; masks are filled with the dataset's
# ignore index so those pixels are excluded by the loss below.
data_transform = SegmentationTransform((512, 1024), mask_fill=metadata.ignore_index)

# Batch-level augmentation applied inside the training / validation loops.
# NOTE(review): 0.5 is presumably a flip/augment probability and the default
# constructor presumably disables random augmentation — confirm in
# utils.transform.
train_augment = SegmentationAugment(0.5)
val_augment = SegmentationAugment()

train_dataset = DATASET_ZOO["Cityscapes"].construct_train(data_transform, root=root)
val_dataset = DATASET_ZOO["Cityscapes"].construct_val(data_transform, root=root)
train_loader = DataLoader(train_dataset, batch_size=4, drop_last=True, shuffle=True)
val_loader = DataLoader(val_dataset)  # default batch size of 1, no shuffling

# --- Model / optimisation -----------------------------------------------------
num_epochs = 30
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed before building the model so that any randomly initialised layers are
# reproducible across kernel restarts. The training cell seeds again right
# before the loop, so the training-time RNG stream is unaffected by this call.
seed(42)
model = segmentation.fcn_resnet50(num_classes=metadata.num_classes)
criterion = nn.CrossEntropyLoss(ignore_index=metadata.ignore_index)
optimizer = optim.SGD(model.parameters(), lr=3e-4, momentum=0.9)

In [None]:
# Train for `num_epochs` epochs, running a full validation pass after each one.
# Per-batch metrics are accumulated in a MetricStore and shown in the tqdm bar.
seed(42)

model.to(device)
for i in range(num_epochs):
    print(f"Epoch {i:>4}/{num_epochs}")

    # ---- Training pass ----
    model.train()
    train_ms = MetricStore(metadata.num_classes)
    train_tqdm = tqdm(enumerate(train_loader), total=len(train_loader), desc="Train")
    for j, (images, masks) in train_tqdm:
        images: Tensor = images.to(device)
        masks: Tensor = masks.to(device)
        images, masks = train_augment(images, masks)

        # The model returns a dict of logits (at least the "out" head). Resize
        # every head to the mask resolution so the loss is computed per pixel.
        logits: dict[str, Tensor] = {
            head: F.interpolate(
                head_logits, masks.shape[-2:], mode="bilinear", align_corners=False
            )
            for head, head_logits in model(images).items()
        }
        losses = {head: criterion(head_logits, masks) for head, head_logits in logits.items()}
        loss_sum = sum(losses.values())

        # `sum` is typed as possibly returning int; narrow it for type checkers.
        assert isinstance(loss_sum, Tensor)
        optimizer.zero_grad()
        loss_sum.backward()
        optimizer.step()

        train_ms.store_results(masks, logits["out"].argmax(1))
        train_tqdm.set_postfix(train_ms.summarize())

    # ---- Validation pass ----
    model.eval()
    val_ms = MetricStore(metadata.num_classes)
    val_tqdm = tqdm(enumerate(val_loader), total=len(val_loader), desc="Val")
    with torch.no_grad():
        for j, (images, masks) in val_tqdm:
            images: Tensor = images.to(device)
            masks: Tensor = masks.to(device)
            # Use the validation augment, not the training one, so evaluation
            # batches are not run through the training-time augmentation.
            images, masks = val_augment(images, masks)
            logits: dict[str, Tensor] = model(images)

            # Score the predictions at mask resolution (the resized logits were
            # previously computed but never used).
            out_logits = F.interpolate(
                logits["out"], masks.shape[-2:], mode="bilinear", align_corners=False
            )
            val_ms.store_results(masks, out_logits.argmax(1))
            val_tqdm.set_postfix(val_ms.summarize())