In [1]:
%load_ext autoreload
%autoreload 2

In [7]:
from config import *
from data import SegDataset, CLASS_MAP_VOID, crop_augment_preprocess_batch, NUM_CLASSES_VOID, get_image_UIDs
from models.seg_models import evaluate
from path import SPLITS_PATH

In [26]:
from torchvision.models import segmentation as segmodels
from torchvision.transforms._presets import SemanticSegmentation
from functools import partial
from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex
import torchvision.transforms.v2 as T
from torch.utils.data import DataLoader
from torch import nn

# Evaluation

In [27]:
train_ds = SegDataset(get_image_UIDs(SPLITS_PATH, split="train", shuffle=False), resize_size=CONFIG['segnet']['image_size'], class_map=CLASS_MAP_VOID)
val_ds = SegDataset(get_image_UIDs(SPLITS_PATH, split="train", shuffle=False), resize_size=CONFIG['segnet']['image_size'], class_map=CLASS_MAP_VOID)

In [28]:
model = segmodels.lraspp_mobilenet_v3_large(weights=None, weights_backbone=None).to(CONFIG["device"])
model.load_state_dict(torch.load(TORCH_WEIGHTS_CHECKPOINTS / ("lraspp_mobilenet_v3_large-full-pt" + ".pth")))
model.eval();

In [29]:
center_crop_module = T.CenterCrop(CONFIG['segnet']['image_size'])
random_crop_module = T.RandomCrop(CONFIG['segnet']['image_size'])

In [30]:
preprocess = partial(SemanticSegmentation, resize_size=CONFIG['segnet']['image_size'])()
train_collate_fn = partial(crop_augment_preprocess_batch, crop_fn=T.CenterCrop(CONFIG['segnet']['image_size']), augment_fn=None, preprocess_fn=preprocess)
val_collate_fn = partial(crop_augment_preprocess_batch, crop_fn=lambda x, y: (x, y), augment_fn=None, preprocess_fn=preprocess)

In [31]:
criterion = nn.CrossEntropyLoss(ignore_index=21)

In [32]:
train_dl = DataLoader(
    train_ds,
    batch_size=CONFIG["segnet"]['train']["batch_size"],
    shuffle=True,
    generator=TORCH_GEN.clone_state(),
    collate_fn=train_collate_fn,
)
val_dl = DataLoader(
    val_ds,
    batch_size=1,
    shuffle=False,
    generator=TORCH_GEN.clone_state(),
    collate_fn=val_collate_fn,
)

In [33]:
metrics_dict = {
    "acc": MulticlassAccuracy(num_classes=NUM_CLASSES_VOID, top_k=1, average="micro", multidim_average="global", ignore_index=21).to(CONFIG["device"]),
    "IoU_per_class": MulticlassJaccardIndex(NUM_CLASSES_VOID, average=None, ignore_index=21, zero_division=torch.nan).to(CONFIG["device"]),
}

In [34]:
train_loss, train_metrics_score = evaluate(model, train_dl, criterion, metrics_dict)
train_loss, train_metrics_score

(0.22833188439978927,
 {'IoU_per_class': tensor([0.9096, 0.8566, 0.4046, 0.8583, 0.6919, 0.5531, 0.8731, 0.5510, 0.8985,
          0.2652, 0.8288, 0.6749, 0.8436, 0.8541, 0.7938, 0.8442, 0.4622, 0.8573,
          0.4624, 0.8740, 0.6180,    nan], device='cuda:0'),
  'acc': tensor(0.9222, device='cuda:0')})

In [35]:
print(train_metrics_score["IoU_per_class"].nanmean())

tensor(0.7131, device='cuda:0')


In [36]:
val_loss, val_metrics_score = evaluate(model, val_dl, criterion, metrics_dict)
val_loss, val_metrics_score

(0.2163727689334675,
 {'IoU_per_class': tensor([0.9198, 0.8524, 0.3711, 0.8392, 0.6574, 0.5793, 0.8627, 0.5535, 0.8884,
          0.3370, 0.8183, 0.5975, 0.8327, 0.8495, 0.7918, 0.8278, 0.4517, 0.8418,
          0.4489, 0.8613, 0.5966,    nan], device='cuda:0'),
  'acc': tensor(0.9283, device='cuda:0')})

In [37]:
print(val_metrics_score["IoU_per_class"].nanmean())

tensor(0.7037, device='cuda:0')
