In [1]:
# Automatically reload imported modules (e.g. the local `core` / `models` packages)
# before each cell runs, so edits to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from core.config import *

from core.datasets import VOC2012SegDataset
from core.data import crop_augment_preprocess_batch
from core.data_utils import flatten_list_of_lists
from core.color_map import apply_colormap
from core.torch_utils import get_activation
from models.seg import (SEGMODELS_REGISTRY,
    compute_seg_grad_cam, 
    compute_seg_grad_cam_pp,
    compute_seg_xgrad_cam,
    compute_seg_xres_cam,
    compute_seg_hires_grad_cam)
from core.viz import normalize_sim_maps, viz_seg_saliency_maps

import torch.nn.functional as F
import torchvision.transforms.v2 as T
from torchvision.transforms._presets import SemanticSegmentation
from torchvision.transforms.functional import to_pil_image
from torch.utils.data import DataLoader
from torch import nn
from torch.utils.hooks import RemovableHandle
from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex

import seaborn as sns
import matplotlib.pyplot as plt

import os
from collections import OrderedDict
from functools import partial

In [3]:
# Evaluation config. The default below is machine-specific; set the SEG_EVAL_CONFIG
# environment variable to point at a different config.yml without editing the notebook.
CONFIG_PATH = Path(os.environ.get('SEG_EVAL_CONFIG', '/home/olivieri/exp/src/eval/seg/config.yml'))
config = setup_config(BASE_CONFIG, CONFIG_PATH)
seg_config = config['seg']

In [None]:
# Instantiate the segmentation model from the registry using the `seg` section of the config.
segmodel_kwargs = dict(
    name=seg_config['model_name'],
    pretrained_weights_path=seg_config['pretrained_weights_path'],
    device=config['device'],
    adaptation=seg_config['adaptation'],
)
segmodel = SEGMODELS_REGISTRY.get(**segmodel_kwargs)

# Apply the configured adaptation. This runs before the optional checkpoint is loaded,
# presumably so the architecture matches the fine-tuned weights (TODO confirm).
segmodel.adapt()

In [5]:
# Optionally overwrite the model weights with a saved checkpoint (skipped when the path is empty/None).
checkpoint_path = seg_config['checkpoint_path']
if checkpoint_path:
    # NOTE(review): depending on the installed torch version's `weights_only` default,
    # torch.load may unpickle arbitrary objects -- only point this at trusted files.
    checkpoint: OrderedDict = torch.load(checkpoint_path, map_location='cpu')
    # Accept either a training checkpoint that nests the weights under
    # 'model_state_dict', or a bare state dict.
    segmodel.model.load_state_dict(checkpoint.get('model_state_dict', checkpoint))

In [6]:
# VOC2012 training split, resized to the model input size. No centre crop here;
# cropping is done in the collate function instead.
train_ds = VOC2012SegDataset(
    split='train',
    root_path=config['datasets']['VOC2012_root_path'],
    resize_size=seg_config['image_size'],
    center_crop=False,
    with_unlabelled=True,
    device=config['device'],
)

In [7]:
# VOC2012 validation split, built with exactly the same options as the training split.
val_ds = VOC2012SegDataset(
    split='val',
    root_path=config['datasets']['VOC2012_root_path'],
    resize_size=seg_config['image_size'],
    center_crop=False,
    with_unlabelled=True,
    device=config['device'],
)

In [8]:
# Batch collation for evaluation: centre-crop to `image_size`, no augmentation,
# and the model's own image preprocessing.
collate_fn = partial(
    crop_augment_preprocess_batch,
    preprocess_fn=segmodel.preprocess_images,
    crop_fn=T.CenterCrop(seg_config['image_size']),
    augment_fn=None,  # evaluation: no augmentation
)

In [9]:
# Training-split loader in fixed (unshuffled) order.
train_dl = DataLoader(
    dataset=train_ds,
    collate_fn=collate_fn,
    batch_size=seg_config['batch_size'],
    shuffle=False,
    generator=get_torch_gen(),
)

In [10]:
# Validation-split loader in fixed (unshuffled) order.
val_dl = DataLoader(
    dataset=val_ds,
    collate_fn=collate_fn,
    batch_size=seg_config['batch_size'],
    shuffle=False,
    generator=get_torch_gen(),
)

In [11]:
# Pixel-wise cross-entropy averaged over non-ignored targets. Label 21 is skipped;
# it is the same index the metrics below ignore.
criterion = nn.CrossEntropyLoss(reduction='mean', ignore_index=21)

In [12]:
# Evaluation metrics over all labels reported by get_num_classes(with_unlabelled=True),
# each ignoring label 21 (same as the loss) and moved to the configured device:
#   - "acc": micro-averaged top-1 accuracy, aggregated globally over all non-ignored elements
#   - "mIoU": per-class Jaccard index (average='none'), one value per label
num_classes_with_unlabelled = val_ds.get_num_classes(with_unlabelled=True)
metrics_dict = {
    metric_name: metric.to(config['device'])
    for metric_name, metric in (
        ("acc", MulticlassAccuracy(num_classes=num_classes_with_unlabelled, top_k=1, average='micro', multidim_average='global', ignore_index=21)),
        ("mIoU", MulticlassJaccardIndex(num_classes=num_classes_with_unlabelled, average='none', ignore_index=21)),
    )
}

In [13]:
# The metric objects are stateful and shared between the train and val passes.
# Reset them first so these scores cover only this pass over the training split,
# even when the cell is re-run. This is harmless if `evaluate` already resets them.
for _metric in metrics_dict.values():
    _metric.reset()
train_loss, train_metrics_score = segmodel.evaluate(train_dl, criterion, metrics_dict)

In [14]:
# Per-class IoU on the training split. Entry 21 is the metric's ignore_index (shown as 0),
# so the mean is taken over the first 21 entries only.
train_mIoU = train_metrics_score['mIoU']
(train_mIoU[:21].mean(dim=0), train_mIoU)

(tensor(0.8063, device='cuda:0'),
 tensor([0.9524, 0.8641, 0.3496, 0.8926, 0.8507, 0.7918, 0.9279, 0.8923, 0.9344,
         0.4818, 0.7928, 0.7926, 0.8969, 0.8264, 0.8148, 0.8644, 0.6714, 0.8973,
         0.6833, 0.9207, 0.8340, 0.0000], device='cuda:0'))

In [15]:
# The same metric objects were just used for the training pass. Reset them so the
# validation scores cannot include leftover training statistics. This is harmless
# if `evaluate` already resets them.
for _metric in metrics_dict.values():
    _metric.reset()
val_loss, val_metrics_score = segmodel.evaluate(val_dl, criterion, metrics_dict)

In [16]:
# Per-class IoU on the validation split. Entry 21 is the metric's ignore_index (shown as 0),
# so the mean is taken over the first 21 entries only.
mIoU = val_metrics_score['mIoU']
(mIoU[:21].mean(dim=0), mIoU)

(tensor(0.5503, device='cuda:0'),
 tensor([0.8921, 0.7252, 0.2802, 0.7397, 0.5178, 0.5451, 0.7730, 0.7079, 0.7107,
         0.1638, 0.4484, 0.3785, 0.5951, 0.4532, 0.6296, 0.7469, 0.2850, 0.4735,
         0.3419, 0.6312, 0.5168, 0.0000], device='cuda:0'))