In [None]:
import shutil
import sys
from pathlib import Path
from pprint import pprint
from timeit import default_timer

import torch
import tqdm
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import VOCSegmentation
from torchvision.models.segmentation import fcn_resnet50
from torchvision.transforms.v2 import functional as TF

sys.path.append(str(Path("..").resolve()))
from src.semantic_segmentation_toolkit.datasets import (
    DATASET_ZOO,
    CityscapesClass,
    resolve_metadata,
)
from src.semantic_segmentation_toolkit.models import FCN_ResNet34_Weights, fcn_resnet34
from src.semantic_segmentation_toolkit.pipeline import (
    TesttimeAugmentations,
    create_snapshots,
    eval_one_epoch,
    inference_with_augmentations,
    inference_with_sliding_window,
)
from src.semantic_segmentation_toolkit.utils.metrics import MetricStore
from src.semantic_segmentation_toolkit.utils.rng import seed
from src.semantic_segmentation_toolkit.utils.transform import (
    SegmentationAugment,
    SegmentationTransform,
)
from src.semantic_segmentation_toolkit.utils.visual import combine_images

In [None]:
# Dataset metadata plus the transform/augment pipelines; both pipelines fill
# masks with the dataset's ignore index.
metadata = resolve_metadata("Cityscapes")
transforms = SegmentationTransform(mask_fill=metadata.ignore_index)
augment = SegmentationAugment(mask_fill=metadata.ignore_index)

# Build the root from path components instead of a backslash-separated string:
# on POSIX a backslash is an ordinary filename character, so the old literal
# only resolved on Windows. On Windows the resulting string is unchanged.
dataset_root = Path("..", "..", "Datasets", "Cityscapes")
dataset = CityscapesClass(
    root=str(dataset_root),
    split="val",
    target_type="semantic",
    transforms=transforms,
)
# These are the DataLoader defaults, spelled out because the export cell at the
# end of the notebook indexes `dataset.images` by batch index and therefore
# relies on one unshuffled sample per batch.
data_loader = DataLoader(dataset, batch_size=1, shuffle=False)

In [None]:
# Checkpoint path built from components so it resolves on any OS (the previous
# backslash literal was a single file name on POSIX).
model_state_file = Path("..", "runs", "20250227_171651", "latest_model.pth")
model = fcn_resnet50(num_classes=metadata.num_classes, aux_loss=True)
# map_location="cpu": a checkpoint saved from a CUDA run would otherwise fail to
# deserialize on a CPU-only machine; the model is moved to `device` afterwards.
# weights_only=True: the file is consumed as a plain state dict, so restrict
# unpickling to tensors/containers rather than arbitrary Python objects.
model_weights = torch.load(model_state_file, map_location="cpu", weights_only=True)
model.load_state_dict(model_weights)

# Loss skips pixels labelled with the dataset's ignore index.
criterion = nn.CrossEntropyLoss(ignore_index=metadata.ignore_index)
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Move the network and the loss module to the selected device and switch both
# to evaluation mode. The final expression returns the criterion module, which
# the notebook displays as the cell output.
model.to(device)
model.eval()
criterion.to(device)
criterion.eval()

In [None]:
# create_snapshots yields a nested collection of images; flatten it one level,
# tile everything into a single image and display it shrunk by a factor of 7.
snapshots = create_snapshots(model, dataset, augment, device, metadata.colors)
flat_snapshots = []
for snapshot_group in snapshots:
    flat_snapshots.extend(snapshot_group)
combined = combine_images(flat_snapshots)
combined_pil: Image.Image = TF.to_pil_image(combined)
combined_pil.reduce(7)

### Without augmentations

In [None]:
# Baseline run: reseed, evaluate once over the data loader with the project's
# eval_one_epoch helper, then pretty-print the resulting metric summary.
seed(42)
ms = eval_one_epoch(model, data_loader, augment, criterion, device, metadata.num_classes)
baseline_summary = ms.summarize()
pprint(baseline_summary)

### With augmentations

In [None]:
# Test-time augmentation configuration consumed by inference_with_augmentations.
# Only the second positional argument offers more than one option (False, True).
# NOTE(review): the meaning of each positional tuple and of `iter_product` is
# defined by TesttimeAugmentations and not visible here — presumably
# iter_product=True enumerates every combination of the options; confirm.
ttas = TesttimeAugmentations(
    (1,), (False, True), (False,), (0,), iter_product=True
)

In [None]:
# Evaluate with test-time augmentation: run the model once per augmentation in
# `ttas`, average the logits over the augmentation axis (dim 0), and accumulate
# loss, timing and prediction results in a MetricStore.
seed(42)
ms = MetricStore(metadata.num_classes)
# CUDA kernels are launched asynchronously, so the wall clock must only be read
# after the device has finished; otherwise "time" under-reports GPU work.
use_cuda = str(device).startswith("cuda")
loader = tqdm.tqdm(data_loader, total=len(data_loader))
# Pure evaluation: disable autograd so no graph is built for the forward passes.
with torch.no_grad():
    for images, masks in loader:
        start_time = default_timer()
        prelim_images, masks = augment(images.to(device), masks.to(device))
        augmented_logits = inference_with_augmentations(model, prelim_images, ttas)

        # may change how to aggregate results
        aggregated_logits = torch.mean(augmented_logits, dim=0)
        loss = criterion(aggregated_logits, masks)
        if use_cuda:
            torch.cuda.synchronize()
        end_time = default_timer()

        preds = aggregated_logits.argmax(1)
        ms.store_results(masks, preds)
        batch_size = images.size(0)
        measures = {
            # Stored as a per-batch sum (mean loss times batch size).
            "loss": loss.item() * batch_size,
            "time": end_time - start_time,
        }
        ms.store_measures(batch_size, measures)
        loader.set_postfix(ms.summarize())
pprint(ms.summarize())

### With sliding window

In [None]:
# Window size handed to inference_with_sliding_window in the next cell.
# Presumably (height, width) in pixels — TODO confirm against the helper.
window_size = (512, 1024)

In [None]:
# Evaluate with sliding-window inference: predict window by window using
# `window_size`, reduce the result over dim 0, and accumulate timing and
# prediction results in a MetricStore (no loss is recorded in this variant).
seed(42)
ms = MetricStore(metadata.num_classes)
# CUDA kernels are launched asynchronously, so the wall clock must only be read
# after the device has finished; otherwise "time" under-reports GPU work.
use_cuda = str(device).startswith("cuda")
loader = tqdm.tqdm(data_loader, total=len(data_loader))
# Pure evaluation: disable autograd so no graph is built for the forward passes.
with torch.no_grad():
    for images, masks in loader:
        start_time = default_timer()
        prelim_images, masks = augment(images.to(device), masks.to(device))
        augmented_logits = inference_with_sliding_window(
            model, prelim_images, window_size
        )
        # NOTE(review): this mean over dim 0 mirrors the TTA cell and assumes the
        # helper returns a leading axis to average over — confirm against
        # inference_with_sliding_window's return shape.
        aggregated_logits = torch.mean(augmented_logits, dim=0)
        if use_cuda:
            torch.cuda.synchronize()
        end_time = default_timer()

        preds = aggregated_logits.argmax(1)
        ms.store_results(masks, preds)
        batch_size = images.size(0)
        measures = {"time": end_time - start_time}
        ms.store_measures(batch_size, measures)
        loader.set_postfix(ms.summarize())
pprint(ms.summarize())

### For submitting to the Cityscapes benchmark suite

In [None]:
# Export predictions as palettised PNGs (one per input image, named after the
# source image) and zip the folder.
out_folder = Path(r"cityscapes_semantic")
# exist_ok=True keeps the cell re-runnable; without it a second run raises
# FileExistsError before any prediction is written.
out_folder.mkdir(parents=True, exist_ok=True)
full_metadata = resolve_metadata("CityscapesFull")

# Loop-invariant lookup: for every training label that also exists in the full
# label list, pair its train id with its position in that list. Training labels
# without a counterpart are left out and therefore stay 0 in the output.
# NOTE(review): this assumes a label's index in full_metadata.labels equals the
# Cityscapes labelID — confirm against the metadata definition.
train_to_label_id = [
    (train_id, full_metadata.labels.index(train_label))
    for train_id, train_label in enumerate(metadata.labels)
    if train_label in full_metadata.labels
]
# Flat [r, g, b, r, g, b, ...] palette, also identical for every image.
palette = [c for rgb in full_metadata.colors for c in rgb]

ttas = TesttimeAugmentations((1,), (False, True), (False,), (0,), iter_product=True)
loader = tqdm.tqdm(enumerate(data_loader), total=len(data_loader))
# Pure inference: disable autograd so no graph is built for the forward passes.
with torch.no_grad():
    for i, (images, masks) in loader:
        prelim_images, masks = augment(images.to(device), masks.to(device))
        augmented_logits = inference_with_augmentations(model, prelim_images, ttas)
        aggregated_logits = torch.mean(augmented_logits, dim=0)

        # convert preds to labelIDs
        preds = aggregated_logits.argmax(1)
        label_id_pred = torch.zeros_like(preds, dtype=torch.uint8)
        for train_id, label_id in train_to_label_id:
            label_id_pred[preds == train_id] = label_id

        # Relies on the loader yielding one unshuffled sample per batch, so that
        # batch index `i` lines up with `dataset.images[i]`.
        image_path = Path(dataset.images[i])
        preds_pil: Image.Image = TF.to_pil_image(label_id_pred)
        preds_pil.putpalette(palette)
        preds_pil.save(out_folder / image_path.with_suffix(".png").name)

shutil.make_archive("cityscapes_semantic", "zip", out_folder)