In [22]:
import copy
import torch.nn as nn

import torch
from utils import transforms as my_transforms
from torchvision.models.segmentation import deeplabv3_resnet50

# Root folder of the H2giga data; passed to the dataset as 'root_dir'.
H2GIGA_DIR = '../Data/augmented/H2giga/'


# Configuration template for evaluation. Callers should go through get_args()
# so this module-level dict is never mutated directly.
args = {
    # runtime switches
    'cuda': True,
    'display': True,

    # output locations and the checkpoint to evaluate
    'save': True,
    'save_dir': './test_deepLabv3',
    'checkpoint_path': './deepLabv3/checkpoint.pth',

    # class index -> colour triple (presumably RGB, consumed by the visualizer -- confirm)
    'color_map': {
        0: (0, 0, 0),
        1: (21, 176, 26),
        2: (5, 73, 7),
        3: (170, 166, 98),
        4: (229, 0, 0),
        5: (140, 0, 15),
    },
    'num_class': 6,

    # dataset name plus the keyword arguments it is built with
    'dataset': {
        'name': 'H2giga',
        'kwargs': {
            'root_dir': H2GIGA_DIR,
            'type': 'test',
            'class_id': None,
            # Only a ToTensor step: one target tensor type per listed sample key.
            'transform': my_transforms.get_transform([
                {
                    'name': 'ToTensor',
                    'opts': {
                        'keys': ('image', 'instance', 'label'),
                        'type': (torch.FloatTensor, torch.ByteTensor, torch.ByteTensor),
                    },
                },
            ]),
        },
    },
}


def get_args():
    """Return an independent deep copy of the module-level ``args`` dict.

    Because the copy is deep, callers may modify the returned dict (including
    its nested containers) without affecting the shared template.
    """
    config_snapshot = copy.deepcopy(args)
    return config_snapshot
def prepare_model(num_classes=6):
    """Build a DeepLabV3-ResNet50 whose heads predict ``num_classes`` channels.

    The network is created with ``weights='DEFAULT'`` and then the layer at
    index 4 of both the main classifier and the auxiliary classifier is
    replaced by a fresh 1x1 convolution (256 input channels,
    ``num_classes`` output channels).

    Args:
        num_classes: number of output channels for both heads.

    Returns:
        The modified torchvision segmentation model.
    """
    net = deeplabv3_resnet50(weights='DEFAULT')
    # Main head first, auxiliary head second -- same order as layer creation matters
    # for reproducible random initialisation.
    for head in (net.classifier, net.aux_classifier):
        head[4] = nn.Conv2d(256, num_classes, 1)
    return net


In [20]:
import os
import time

import torchvision
from matplotlib import pyplot as plt
from tqdm import tqdm

import test_config
import torch
from datasets import get_dataset
from models import get_model
from utils.utils import Cluster,Visualizer,Metrics
from skimage.color import label2rgb
from skimage import io
import numpy as np

   
# Take a deep copy of the configuration and bind it to the notebook-level name `args`.
# NOTE(review): get_args is called unqualified although this cell imports test_config,
# so this line relies on get_args already being defined in the running kernel -- confirm.
args = get_args()


def begin_test(args, n_sigma=2):
    """Evaluate the checkpointed DeepLabV3 model on the configured dataset.

    For every sample the model output is turned into a class prediction by the
    visualizer, accumulated into ``Metrics`` and, when saving is enabled,
    written out as a side-by-side image (ground-truth overlay | prediction).

    Args:
        args: configuration dict. Keys read here: 'save', 'save_dir', 'cuda',
            'dataset' (with 'name' and 'kwargs'), 'num_class' and
            'checkpoint_path'. The whole dict is also passed to ``Visualizer``.
        n_sigma: not used by this function; kept for backward compatibility
            with existing callers.

    Raises:
        AssertionError: if ``args['checkpoint_path']`` does not exist.

    Side effects:
        Creates ``args['save_dir']`` and writes one PNG per sample into it
        when ``args['save']`` is truthy, and calls
        ``metrics.log("evaluation.txt")`` after the loop.
    """
    torch.backends.cudnn.benchmark = True

    if args['save']:
        # exist_ok avoids the race between a separate exists() check and makedirs().
        os.makedirs(args['save_dir'], exist_ok=True)

    # Use the GPU only when it was requested AND is actually present, so the
    # same notebook still runs on a CPU-only machine.
    use_cuda = bool(args['cuda']) and torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # dataloader: one sample per batch, in dataset order
    dataset = get_dataset(args['dataset']['name'], args['dataset']['kwargs'])
    dataset_it = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        num_workers=2,
        pin_memory=use_cuda,
    )

    # Build the model and wrap it in DataParallel before loading the weights.
    # NOTE(review): presumably the checkpoint was saved from a DataParallel
    # model (keys prefixed with 'module.'); strict loading will fail otherwise.
    model = prepare_model(args["num_class"])
    model = torch.nn.DataParallel(model).to(device)

    # NOTE(review): hard-coded rather than derived from args['num_class'].
    # It equals the number of channels kept by `out[1:]` below only while the
    # model has 6 output channels -- confirm what Metrics expects.
    num_class = 5

    # load snapshot
    if not os.path.exists(args['checkpoint_path']):
        # Explicit raise (same exception type as the former `assert False`),
        # so the check is not stripped when Python runs with -O.
        raise AssertionError(
            'checkpoint_path {} does not exist!'.format(args['checkpoint_path']))
    # map_location remaps tensors saved on another device (e.g. a GPU that is
    # not present here) onto the device actually in use.
    state = torch.load(args['checkpoint_path'], map_location=device)
    model.load_state_dict(state['model_state_dict'], strict=True)

    model.eval()

    visualizer = Visualizer(args)
    metrics = Metrics(num_class=num_class)

    with torch.no_grad():

        for sample in tqdm(dataset_it):

            im = sample['image'].to(device)
            label_np = sample['label'].squeeze().numpy()

            output = model(im)
            # batch_size is 1, so take the single item of the batch.
            out = output["out"][0]
            class_pred = visualizer.prepare_pred(out)
            # Drop the first channel of the raw output before scoring.
            pred_score = out[1:]
            pred_np = class_pred.cpu().numpy()
            metrics.add(label_np, pred_np, pred_score.cpu().numpy())

            if args['save']:
                im_name = sample["im_name"][0]
                img = io.imread(im_name)
                label_rgb = visualizer.label2colormap(label_np)
                pred_rgb = visualizer.label2colormap(pred_np)

                # Left: ground truth drawn over the first three image channels;
                # right: colour-mapped prediction.
                ground_truth = visualizer.overlay_image(img[..., :3], label_rgb)
                grid = np.concatenate((ground_truth, pred_rgb), axis=1)
                base, _ = os.path.splitext(os.path.basename(im_name))
                io.imsave(os.path.join(args['save_dir'], base+'.png'), grid)
        metrics.log("evaluation.txt")


In [21]:

# Take a fresh copy of the configuration and run the evaluation with it.
# NOTE(review): the output recorded for this cell ends in
# "ValueError: Found input variables with inconsistent numbers of samples: [1048576, 0]",
# i.e. the run stopped on the first sample -- the cause is not visible in this notebook.
args =get_args()
begin_test(args)

  0%|          | 0/3 [00:04<?, ?it/s]


ValueError: Found input variables with inconsistent numbers of samples: [1048576, 0]