In [1]:
from __future__ import absolute_import, division, print_function

import json
import multiprocessing
import os

import click
import joblib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from addict import Dict
from PIL import Image
from tensorboardX import SummaryWriter
from torchnet.meter import MovingAverageValueMeter
from tqdm import tqdm

from libs.datasets import get_dataset
from libs.models import *
from libs.utils import DenseCRF, PolynomialLR, scores

In [2]:
def makedirs(dirs):
    """Create directory ``dirs`` (and any missing parents) if it is not there yet.

    Safe to call repeatedly. Rather than checking for existence first (which
    races with any other process creating the same path between the check and
    the create), it attempts the creation and tolerates the failure only when
    a directory is already present at ``dirs``.

    Args:
        dirs: path of the directory to create.

    Raises:
        OSError: if the directory cannot be created for any other reason,
            including when ``dirs`` exists but is not a directory.
    """
    try:
        os.makedirs(dirs)
    except OSError:
        # Already existing as a directory is the expected, benign case.
        if not os.path.isdir(dirs):
            raise


def get_device(cuda):
    """Pick the torch device for this session and report it on stdout.

    Args:
        cuda: truthy to request GPU execution. The request is honoured only
            when ``torch.cuda.is_available()``; otherwise the CPU is used.

    Returns:
        ``torch.device("cuda")`` or ``torch.device("cpu")``.
    """
    if not (cuda and torch.cuda.is_available()):
        print("Device: CPU")
        return torch.device("cpu")

    # List every visible GPU, one indented line each.
    print("Device:")
    for index in range(torch.cuda.device_count()):
        print("    {}:".format(index), torch.cuda.get_device_name(index))
    return torch.device("cuda")


def get_params(model, key):
    """Yield the parameters of ``model`` that belong to one learning-rate group.

    Groups are chosen by substring match on the names produced by
    ``model.named_modules()``, and only ``nn.Conv2d`` modules contribute:

    * ``"1x"``  -- every parameter of convs whose name contains ``"layer"``
      (the dilated FCN backbone).
    * ``"10x"`` -- the weight of each conv whose name contains ``"aspp"``.
    * ``"20x"`` -- the bias of each of those ASPP convs; convs built with
      ``bias=False`` (``bias is None``) are skipped.

    Any other key yields nothing.

    Args:
        model: an ``nn.Module``.
        key: one of ``"1x"``, ``"10x"``, ``"20x"``.

    Yields:
        Parameter tensors for the requested group.
    """
    for name, module in model.named_modules():
        if not isinstance(module, nn.Conv2d):
            continue
        if key == "1x" and "layer" in name:
            # For Dilated FCN
            for p in module.parameters():
                yield p
        elif key == "10x" and "aspp" in name:
            # For conv weight in the ASPP module
            yield module.weight
        elif key == "20x" and "aspp" in name and module.bias is not None:
            # For conv bias in the ASPP module. A bias-free conv has
            # ``bias is None``; torch.optim rejects None entries in a param group.
            yield module.bias


def resize_labels(labels, size):
    """
    Downsample labels for 0.5x and 0.75x logits by nearest interpolation.

    PIL's nearest-neighbour resize is used on purpose; other nearest methods
    result in misaligned labels:
    -> F.interpolate(labels, shape, mode='nearest')
    -> cv2.resize(labels, shape, interpolation=cv2.INTER_NEAREST)

    Args:
        labels: iterable of 2-D label tensors (presumably a batch shaped
            (N, H, W) -- TODO confirm against callers).
        size: target size forwarded to ``PIL.Image.resize``, i.e.
            ``(width, height)``.

    Returns:
        A LongTensor stacking the resized labels along a new first axis.
        Empty input yields an empty 1-D LongTensor.
    """
    new_labels = []
    for label in labels:
        # Converted to float32 before handing to PIL (presumably because
        # Image.fromarray does not accept int64 arrays); small integer label
        # ids are exactly representable in float32.
        label = label.cpu().float().numpy()
        label = Image.fromarray(label).resize(size, resample=Image.NEAREST)
        new_labels.append(np.asarray(label))
    if not new_labels:
        return torch.LongTensor([])
    # Stack into a single ndarray first: building a tensor from a Python list
    # of ndarrays is very slow and warns in recent PyTorch.
    return torch.from_numpy(np.stack(new_labels)).long()

In [3]:
# Checkpoint used by the model-loading cell, and the YAML experiment config.
model_path = "data/models/S2DS/deeplabv2_resnet101_msc/train/checkpoint_final.pth"
config_path = "configs/S2DS.yaml"
# Request GPU execution; get_device() falls back to CPU when CUDA is unavailable.
cuda = True

In [4]:
# Load the experiment configuration; addict.Dict allows attribute access
# such as CONFIG.DATASET.NAME.
# NOTE(review): yaml.FullLoader is used here; yaml.safe_load should suffice for
# a plain config file -- confirm no Python-specific YAML tags are required.
with open(config_path, 'r') as file:
    CONFIG = Dict(yaml.load(file, Loader=yaml.FullLoader))
device = get_device(cuda)
# Inference only: autograd is disabled globally for the rest of the session.
torch.set_grad_enabled(False)

# Dataset
# Validation split without augmentation. The per-channel mean is passed in BGR
# order; presumably the dataset subtracts it from each image -- TODO confirm in
# libs.datasets.
dataset = get_dataset(CONFIG.DATASET.NAME)(
    root=CONFIG.DATASET.ROOT,
    split=CONFIG.DATASET.SPLIT.VAL,
    ignore_label=CONFIG.DATASET.IGNORE_LABEL,
    mean_bgr=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.R),
    augment=False,
)
print(dataset)

Device:
    0: Tesla V100-PCIE-16GB
    1: Tesla V100-PCIE-16GB
Dataset: S2DS
    # data: 157
    Split: val
    Root: S2DSdevkit/part2


In [5]:
# Unshuffled loader over the validation dataset, so samples come back in a
# fixed order; `loader_iter` lets later cells pull one batch at a time.
loader = torch.utils.data.DataLoader(
    dataset,
    shuffle=False,
    batch_size=CONFIG.SOLVER.BATCH_SIZE.TEST,
    num_workers=CONFIG.DATALOADER.NUM_WORKERS,
)
loader_iter = iter(loader)

In [6]:
"""# Model
model = eval(CONFIG.MODEL.NAME)(n_classes=CONFIG.DATASET.N_CLASSES)
state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
model.load_state_dict(state_dict)
model = nn.DataParallel(model)
model.eval()
model.to(device)"""

'# Model\nmodel = eval(CONFIG.MODEL.NAME)(n_classes=CONFIG.DATASET.N_CLASSES)\nstate_dict = torch.load(model_path, map_location=lambda storage, loc: storage)\nmodel.load_state_dict(state_dict)\nmodel = nn.DataParallel(model)\nmodel.eval()\nmodel.to(device)'

In [7]:
# Pull the next batch from the loader. Re-running this cell advances the
# iterator and shows a different sample. Each element is batched: the
# one-element id tuple in the output suggests a test batch size of 1.
image_id, image, label = next(loader_iter)
image_id

('IMG_6272',)

In [8]:
# Sanity check: confirm the image and label come out of the loader as torch tensors.
print(type(image))
print(type(label))

<class 'torch.Tensor'>
<class 'torch.Tensor'>


In [None]:
# Visualize the fetched sample with show_pair (presumably image and label side
# by side -- confirm in libs.visualization).
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
plt.rcParams["figure.figsize"] = (20,10) 
from libs.visualization import show_pair
# The same BGR mean tuple as the dataset cell is passed, presumably so
# show_pair can undo the mean subtraction before plotting -- TODO confirm.
# NOTE(review): the "Clipping input data to the valid range for imshow"
# warnings in this cell's output mean imshow received values outside
# [0..1]/[0..255]; check show_pair's de-normalization and dtype.
show_pair(image_id, image, label, (CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.R))

<class 'numpy.ndarray'>
<class 'numpy.ndarray'>


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
