In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from config import *
from data import SegDataset, image_train_UIDs, image_val_UIDs, CLASS_MAP_VOID, crop_augment_preprocess_batch, NUM_CLASSES_VOID
from models.vl_models import evaluate

In [None]:
from functools import partial

import torch
import torchvision.transforms as T
from torch import nn
from torch.utils.data import DataLoader
from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex
from torchvision.models import segmentation as segmodels
# NOTE: `_presets` is a private torchvision module and may change between releases.
from torchvision.transforms._presets import SemanticSegmentation

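# Evaluation: LR-ASPP MobileNetV3-Large checkpoint on the train and val splits

Reports cross-entropy loss, pixel accuracy and per-class / mean IoU; pixels with label index 21 are ignored by the loss and by both metrics.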

In [4]:
# Both splits are built identically: same resize size and the same CLASS_MAP_VOID label mapping.
_seg_ds_kwargs = dict(resize_size=CONFIG['seg']['image_size'], class_map=CLASS_MAP_VOID)
train_ds = SegDataset(image_train_UIDs, **_seg_ds_kwargs)
val_ds = SegDataset(image_val_UIDs, **_seg_ds_kwargs)

In [None]:
# Architecture only (no torchvision-downloaded weights); all parameters come from the local checkpoint.
model = segmodels.lraspp_mobilenet_v3_large(weights=None, weights_backbone=None).to(CONFIG["device"])
checkpoint_path = TORCH_WEIGHTS_CHECKPOINTS / "lraspp_mobilenet_v3_large-full-pt.pth"
# map_location: place tensors directly on CONFIG["device"], so a checkpoint saved on another
#   device (e.g. CUDA) still loads on a machine without it.
# weights_only=True: restrict unpickling to tensors/plain containers — enough for a state_dict,
#   and it refuses to execute arbitrary pickled code from the file.
model.load_state_dict(torch.load(checkpoint_path, map_location=CONFIG["device"], weights_only=True))
model.eval();

In [6]:
# Crop transforms at the configured segmentation image size.
_crop_size = CONFIG['seg']['image_size']
center_crop_module = T.CenterCrop(_crop_size)  # deterministic crop
random_crop_module = T.RandomCrop(_crop_size)  # NOTE(review): not referenced by any visible cell — remove if unused

In [7]:
# Eval-time preprocessing from torchvision's segmentation preset transform
# (constructed directly; wrapping it in partial(...) and immediately calling it added nothing).
preprocess = SemanticSegmentation(resize_size=CONFIG['seg']['image_size'])

# Train split: deterministic center crop, reusing the CenterCrop built in the previous cell
# so the crop size has a single source of truth.
train_collate_fn = partial(
    crop_augment_preprocess_batch,
    crop_module=center_crop_module,
    augment_fn=None,
    preprocess_fn=preprocess,
)

# Val split: pass-through "crop", so samples keep their (resized) size — presumably why the
# val loader uses batch_size=1; confirm against crop_augment_preprocess_batch.
val_collate_fn = partial(
    crop_augment_preprocess_batch,
    crop_module=lambda x, y: (x, y),
    augment_fn=None,
    preprocess_fn=preprocess,
)

In [8]:
# Pixels labelled 21 contribute nothing to the loss; the remaining pixels are averaged.
# NOTE(review): 21 appears to be the extra "void" class added by CLASS_MAP_VOID — confirm in data.py.
criterion = nn.CrossEntropyLoss(
    ignore_index=21,
    reduction="mean",
)

In [None]:
def _make_loader(dataset, batch_size, shuffle, collate_fn):
    """Wrap `dataset` in a DataLoader that owns a generator cloned from TORCH_GEN's current state."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        generator=TORCH_GEN.clone_state(),
        collate_fn=collate_fn,
    )


# NOTE(review): shuffling is unnecessary for evaluation; kept so the reported train numbers stay reproducible.
train_dl = _make_loader(train_ds, CONFIG["seg"]["batch_size"], True, train_collate_fn)
# batch_size=1, presumably because val_collate_fn does no cropping and val samples may differ in size — confirm.
val_dl = _make_loader(val_ds, 1, False, val_collate_fn)

In [10]:
# Pixel-level metrics over NUM_CLASSES_VOID classes; label 21 is excluded from both.
_metric_device = CONFIG["device"]
_void_index = 21
metrics_dict = {}
# Micro-averaged top-1 accuracy over all non-ignored pixels of the whole split.
metrics_dict["acc"] = MulticlassAccuracy(
    num_classes=NUM_CLASSES_VOID,
    top_k=1,
    average="micro",
    multidim_average="global",
    ignore_index=_void_index,
).to(_metric_device)
# Per-class IoU (one value per class); an undefined IoU (zero denominator) is reported as NaN,
# as seen for the ignored class in the outputs below.
metrics_dict["IoU_per_class"] = MulticlassJaccardIndex(
    NUM_CLASSES_VOID,
    average=None,
    ignore_index=_void_index,
    zero_division=torch.nan,
).to(_metric_device)

In [11]:
# torchmetrics objects accumulate state across update() calls until reset(). metrics_dict is
# shared with the val evaluation, so clear it here; the scores then reflect only this split,
# whether or not `evaluate` resets internally.
for _metric in metrics_dict.values():
    _metric.reset()
train_loss, train_metrics_score = evaluate(model, train_dl, criterion, metrics_dict)
train_loss, train_metrics_score

(0.22815202925700306,
 {'IoU_per_class': tensor([0.9096, 0.8566, 0.4038, 0.8585, 0.6917, 0.5530, 0.8724, 0.5522, 0.8974,
          0.2665, 0.8278, 0.6740, 0.8436, 0.8540, 0.7945, 0.8437, 0.4622, 0.8569,
          0.4651, 0.8744, 0.6183,    nan], device='cuda:0'),
  'acc': tensor(0.9222, device='cuda:0')})

In [12]:
# Mean IoU across classes, skipping NaN entries (classes whose IoU is undefined, e.g. the ignored label).
train_miou = train_metrics_score["IoU_per_class"].nanmean()
train_miou

tensor(0.7131, device='cuda:0')


In [11]:
# torchmetrics objects accumulate state across update() calls until reset(). metrics_dict was
# already used for the train split, so clear it; otherwise train statistics could leak into
# the val scores.
for _metric in metrics_dict.values():
    _metric.reset()
val_loss, val_metrics_score = evaluate(model, val_dl, criterion, metrics_dict)
val_loss, val_metrics_score

(0.20719398433272518,
 {'IoU_per_class': tensor([0.9220, 0.8434, 0.3509, 0.8503, 0.6659, 0.5618, 0.8918, 0.6914, 0.8643,
          0.3536, 0.8243, 0.4685, 0.7804, 0.8229, 0.7969, 0.8273, 0.5317, 0.7762,
          0.5054, 0.8581, 0.6442,    nan], device='cuda:0'),
  'acc': tensor(0.9301, device='cuda:0')})

In [12]:
# Mean IoU across classes, skipping NaN entries (classes whose IoU is undefined, e.g. the ignored label).
val_miou = val_metrics_score["IoU_per_class"].nanmean()
val_miou

tensor(0.7063, device='cuda:0')
