In [None]:
import sys
from pathlib import Path
from pprint import pprint
from timeit import default_timer

import torch
import tqdm
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import VOCSegmentation

sys.path.append(str(Path("..").resolve()))
from src.semantic_segmentation_toolkit.datasets import resolve_metadata
from src.semantic_segmentation_toolkit.models import FCN_ResNet34_Weights, fcn_resnet34
from src.semantic_segmentation_toolkit.pipeline import (
    TesttimeAugmentations,
    eval_one_epoch,
    inference_with_augmentations,
)
from src.semantic_segmentation_toolkit.utils.metrics import MetricStore
from src.semantic_segmentation_toolkit.utils.transform import SegmentationTransform
from src.semantic_segmentation_toolkit.utils.rng import seed

In [None]:
# Model built from the VOC2012 weights entry, plus the transform object that
# the same weights entry exposes (used below as `augment`).
weights = FCN_ResNet34_Weights.VOC2012
model = fcn_resnet34(weights=weights)
transforms = SegmentationTransform()
augment = weights.value.transforms()

# Build the dataset root with pathlib so the notebook also runs on POSIX
# systems: the previous raw string r"..\dataset" only resolves on Windows,
# where the backslash is a path separator. On Windows the resulting string is
# unchanged.
dataset_root = Path("..") / "dataset"
dataset = VOCSegmentation(
    str(dataset_root), image_set="val", transforms=transforms, year="2007"
)
# DataLoader defaults: batch_size=1, no shuffling.
data_loader = DataLoader(dataset)
metadata = resolve_metadata("VOC")
# Pixels labelled with the dataset's ignore index do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=metadata.ignore_index)
device = "cuda" if torch.cuda.is_available() else "cpu"

### Without test-time augmentations

In [None]:
# Baseline: one evaluation pass over the validation loader, reseeded first so
# this cell starts from the same seed (42) as the test-time-augmentation run.
seed(42)
ms = eval_one_epoch(
    model,
    data_loader,
    augment,
    criterion,
    device,
    metadata.num_classes,
)
baseline_summary = ms.summarize()
pprint(baseline_summary)

Eval: 100%|██████████| 213/213 [00:23<00:00,  9.06it/s, acc=0.916, macc=0.725, miou=0.626, fwiou=0.851, dice=0.752, loss=0.282, time=0.0812]

{'acc': 0.916085298055659,
 'dice': 0.7518140231973609,
 'fwiou': 0.8508206689830995,
 'loss': 0.2819842651416438,
 'macc': 0.724729110277966,
 'miou': 0.626465678574391,
 'time': 0.0811527868592893}





### With test-time augmentations

In [4]:
# Option tuples handed positionally to TesttimeAugmentations; what each
# position controls is defined by that class (presumably scales, flips, etc.
# -- TODO confirm against its signature).
tta_option_sets = (
    (0.75, 1, 1.25),
    (False, True),
    (False,),
    (0,),
)
# NOTE(review): iter_product=True presumably expands the options into their
# Cartesian product -- confirm in TesttimeAugmentations.
ttas = TesttimeAugmentations(*tta_option_sets, iter_product=True)

In [None]:
# Evaluate the model with test-time augmentation: for every batch, run the
# augmented forward passes, average their logits, and accumulate metrics.
seed(42)
model.to(device).eval()
ms = MetricStore(metadata.num_classes)
# tqdm picks up the total from len(data_loader) by itself.
loader = tqdm.tqdm(data_loader)
# Pure evaluation: disable autograd so the forward passes (several per image
# under TTA) do not build graphs or keep activations alive.
with torch.no_grad():
    for images, masks in loader:
        start_time = default_timer()
        prelim_images, masks = augment(images.to(device), masks)
        augmented_logits = inference_with_augmentations(model, prelim_images, ttas)
        # Average over dim 0 (presumably one entry per augmentation -- the
        # result is later arg-maxed over dim 1).
        aggregated_logits = torch.mean(augmented_logits, dim=0)
        # Keep targets on the same device as the logits; this is a no-op when
        # they already match and avoids a device-mismatch error otherwise.
        masks = masks.to(aggregated_logits.device)
        loss = criterion(aggregated_logits, masks)
        end_time = default_timer()

        preds = aggregated_logits.argmax(1)
        ms.store_results(masks, preds)
        batch_size = images.size(0)
        measures = {
            # Loss is scaled by the batch size before being stored.
            "loss": loss.item() * batch_size,
            "time": end_time - start_time,
        }
        ms.store_measures(batch_size, measures)
        loader.set_postfix(ms.summarize())
pprint(ms.summarize())

100%|██████████| 213/213 [02:22<00:00,  1.49it/s, acc=0.92, macc=0.713, miou=0.636, fwiou=0.856, dice=0.755, loss=0.269, time=0.646]  

{'acc': 0.9200380417427184,
 'dice': 0.7545043196274661,
 'fwiou': 0.8555621297554995,
 'loss': 0.26908271920653015,
 'macc': 0.7132025147834363,
 'miou': 0.6355080084597899,
 'time': 0.6462920990536435}



