In [1]:
dir_path = "/home/hyunho/sfda/"
import sys
sys.path.append(dir_path)

In [2]:
import torch
from torch import nn
from torch.utils.data import DataLoader

from ignite.engine import Engine, Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss, mIoU, ConfusionMatrix
from ignite.handlers import ModelCheckpoint
from ignite.contrib.handlers import TensorboardLogger, global_step_from_engine
from dataset.gta_loader import SegmentationDataset
from dataset.cityscapes_loader import CityscapesDataset
from torchvision import transforms
import cv2
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


image_transforms = transforms.Compose([
    transforms.Resize((720,1280)),
    transforms.ToTensor(),
    # transforms.Normalize(mean=(0.4422, 0.4379, 0.4246), std=(0.2572, 0.2516, 0.2467)),
])
mask_transforms = transforms.Compose([
    transforms.Resize((720,1280), interpolation=transforms.InterpolationMode.NEAREST)
])


train_dataset = CityscapesDataset(
    images_dir="/home/hyunho/sfda/data/cityscapes_dataset/leftImg8bit/train",
    masks_dir="/home/hyunho/sfda/data/cityscapes_dataset/gtFine/train",
    transform=image_transforms,
    target_transform=mask_transforms,
    debug=True
)

train_loader = DataLoader(
    train_dataset, 
    batch_size=2, 
    shuffle=True, 
    pin_memory=True,
    num_workers=16
)

  from torch.distributed.optim import ZeroRedundancyOptimizer


In [3]:
data = next(iter(train_loader))
img, label, name = data

In [4]:
from model.deeplabv2 import DeeplabMulti
model = DeeplabMulti(num_classes=19, pretrained=False)
model.load_state_dict(
  torch.load("/home/hyunho/sfda/exp/deeplabv2_1022/best_model_3_accuracy=0.8210.pt", map_location=device, weights_only=True)
  # torch.load("/home/hyunho/sfda/exp/deeplabv2_1022/best_model_4_accuracy=0.8218.pt", map_location=device, weights_only=True)
)

model.eval()
temp_output = model(img)

In [12]:
np.unique(label)

array([  0,   1,   2,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
        18, 255], dtype=uint8)

In [31]:
from itertools import filterfalse as  ifilterfalse

def isnan(x):
    return x != x

def mean(l, ignore_nan=False, empty=0):
    """
    nanmean compatible with generators.
    """
    l = iter(l)
    if ignore_nan:
        l = ifilterfalse(isnan, l)
    try:
        n = 1
        acc = next(l)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    for n, v in enumerate(l, 2):
        acc += v
    if n == 1:
        return acc
    return acc / n

def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
    """
    Array of IoU for each (non ignored) class
    """
    if not per_image:
        preds, labels = (preds,), (labels,)
    ious = []
    for pred, label in zip(preds, labels):
        iou = []    
        for i in range(C):
            if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes)
                intersection = ((label == i) & (pred == i)).sum()
                union = ((label == i) | ((pred == i) & (label != ignore))).sum()
                if not union:
                    iou.append(EMPTY)
                else:
                    iou.append(float(intersection) / float(union))
        ious.append(iou)
    ious = [mean(x) for x in zip(*ious)] # mean accross images if per_image
    return 100 * np.array(ious)

In [32]:
preds, classes = temp_output.softmax(1).max(1)

In [33]:
loss = iou(classes, label.long(), C=19, ignore=255)
loss

array([  5.02415371,   0.        ,  68.94382682,   0.        ,
         0.        ,   1.9728905 ,   0.        ,   0.        ,
         6.44570929,   0.        ,  69.20268769,   0.        ,
         0.        ,   7.99529264,   0.        ,   0.        ,
       100.        , 100.        ,   0.        ])

In [37]:
total_loss = np.zeros(19)
total_loss += loss

np.mean(total_loss)


18.925503192307822

In [36]:
total_loss

array([  5.02415371,   0.        ,  68.94382682,   0.        ,
         0.        ,   1.9728905 ,   0.        ,   0.        ,
         6.44570929,   0.        ,  69.20268769,   0.        ,
         0.        ,   7.99529264,   0.        ,   0.        ,
       100.        , 100.        ,   0.        ])