In [1]:
import sys
from pathlib import Path
from pprint import pprint
from timeit import default_timer

import torch
import tqdm
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import VOCSegmentation

sys.path.append(str(Path("..").resolve()))
from src.datasets import resolve_metadata
from src.models import FCN_ResNet34_Weights, fcn_resnet34
from src.pipeline import (
    TesttimeAugmentations,
    eval_one_epoch,
    inference_with_augmentations,
)
from src.utils.metrics import MetricStore
from src.utils.transform import SegmentationTransform

In [2]:
# --- Evaluation setup: model, data, loss and device -------------------------
# Pretrained FCN-ResNet34 built from the VOC2012 weights entry.
weights = FCN_ResNet34_Weights.VOC2012
model = fcn_resnet34(weights=weights)

# `transforms` is handed to the dataset; `augment` is the transform object
# obtained from the weights entry and is applied to (images, masks) batches by
# the evaluation cells further down.
transforms = SegmentationTransform()
augment = weights.value.transforms()

# Build the dataset root with pathlib instead of the raw string r"..\dataset":
# a backslash is only a path separator on Windows, so the literal would point
# at a non-existent entry on Linux/macOS. On Windows this renders to the same
# "..\dataset" string as before.
dataset_root = Path("..") / "dataset"
dataset = VOCSegmentation(
    str(dataset_root), image_set="val", transforms=transforms, year="2007"
)

# DataLoader defaults: batch size 1, no shuffling.
data_loader = DataLoader(dataset)

metadata = resolve_metadata("VOC")
# Pixels labelled with the dataset's ignore index do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=metadata.ignore_index)
device = "cuda" if torch.cuda.is_available() else "cpu"

### Without augmentations

In [3]:
# Run one evaluation pass through eval_one_epoch and pretty-print the
# summary of the metric store it returns.
ms = eval_one_epoch(
    model,
    data_loader,
    augment,
    criterion,
    device,
    metadata.num_classes,
)
baseline_summary = ms.summarize()
pprint(baseline_summary)

Eval: 100%|██████████| 213/213 [00:21<00:00,  9.74it/s, acc=0.916, macc=0.725, miou=0.626, fwiou=0.851, dice=0.752, loss=0.282, time=0.0795]

{'acc': 0.9160854087837493,
 'dice': 0.7518142404788088,
 'fwiou': 0.8508208798466645,
 'loss': 0.2819843197967245,
 'macc': 0.7247294714938348,
 'miou': 0.6264660152690998,
 'time': 0.07951267934819692}





### With augmentations

In [4]:
# Option tuples passed positionally to TesttimeAugmentations, in the order the
# constructor expects them.
# NOTE(review): the meaning of each tuple (presumably scales, flips, ...) and
# of `iter_product` is defined by TesttimeAugmentations — confirm there.
tta_option_grid = (
    (0.75, 1, 1.25),
    (False, True),
    (False,),
    (0,),
)
ttas = TesttimeAugmentations(*tta_option_grid, iter_product=True)

In [None]:
# Evaluate the model with test-time augmentation: for every batch, run the
# model through `inference_with_augmentations`, average the returned logits
# over dim 0, then accumulate loss, timing and segmentation metrics.
model.to(device).eval()
ms = MetricStore(metadata.num_classes)

# tqdm picks up the total from len(data_loader) on its own, so wrapping the
# loader in iter() and passing total= explicitly is unnecessary.
loader = tqdm.tqdm(data_loader)

# Nothing here backpropagates; disabling autograd avoids building a graph for
# every augmented forward pass (a no-op if the helper already disables it).
with torch.no_grad():
    for images, masks in loader:
        # NOTE(review): on CUDA, kernels launch asynchronously, so this
        # wall-clock span may under-report unless something inside
        # synchronizes — confirm eval_one_epoch times batches the same way
        # before comparing the two "time" values.
        start_time = default_timer()
        prelim_images, masks = augment(images.to(device), masks)
        augmented_logits = inference_with_augmentations(
            model, prelim_images, ttas
        )
        # Presumably dim 0 indexes the individual augmentations — confirm
        # against inference_with_augmentations.
        aggregated_logits = torch.mean(augmented_logits, dim=0)
        # The loss needs targets on the same device as the logits; this is a
        # no-op when `augment` has already placed the masks there.
        masks = masks.to(aggregated_logits.device)
        loss = criterion(aggregated_logits, masks)
        end_time = default_timer()

        preds = aggregated_logits.argmax(1)
        ms.store_results(masks, preds)
        batch_size = images.size(0)
        measures = {
            # criterion's default reduction is a mean, so scale back up by the
            # batch size before handing it to the store.
            "loss": loss.item() * batch_size,
            "time": end_time - start_time,
        }
        ms.store_measures(batch_size, measures)
        loader.set_postfix(ms.summarize())
pprint(ms.summarize())

100%|██████████| 213/213 [01:42<00:00,  2.07it/s, acc=0.92, macc=0.713, miou=0.636, fwiou=0.856, dice=0.755, loss=0.269, time=0.459]  

{'acc': 0.9200380971067635,
 'dice': 0.7545044644978887,
 'fwiou': 0.8555622179441071,
 'loss': 0.269082717763636,
 'macc': 0.7132029377897163,
 'miou': 0.6355081390238344,
 'time': 0.4585300286425094}



