In [None]:
# !pip install mmcv
# !pip install argparse

# import mmcv
from __future__ import print_function

import argparse
import copy
import csv
import json
import math
import multiprocessing as mul
import os
import os.path as osp
import shutil
import sys
import time
import uuid
from collections import OrderedDict
from functools import partial
from functools import wraps
from inspect import getfullargspec

import cv2
import numpy as np
import psutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from scipy.interpolate import make_interp_spline
from scipy.stats import wasserstein_distance
from skimage.metrics import normalized_root_mse
from sklearn.metrics import accuracy_score, roc_curve, confusion_matrix
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from tqdm import tqdm
# from mmcv import scandir



In [None]:
# Mount Google Drive so the dataset referenced by `dataroot`
# (/content/drive/MyDrive/...) is reachable from this runtime.
import google.colab.drive as drive

drive.mount('/content/drive')



In [None]:
# Make modules that live next to the notebook importable by appending the
# current working directory to the module search path.
sys.path += [os.getcwd()]


class Parser(object):
    """Two-stage command-line parser for the congestion experiments.

    Generic options are registered in ``__init__``. ``parse_args`` first parses
    them to learn ``--task``, then registers that task's options and parses
    again. ``parse_known_args`` is used, so unrecognised arguments are ignored
    instead of aborting the parse.
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser(add_help=False)
        # Tasks whose options are already registered. argparse raises
        # ArgumentError when the same option is added twice, so this makes
        # parse_args safe to call more than once on one Parser.
        self._registered_tasks = set()

        self.parser.add_argument('--task', default='congestion_gpdl')
        self.parser.add_argument('--save_path', default='work_dir/congestion_gpdl/')
        self.parser.add_argument('--pretrained', default=None)
        self.parser.add_argument('--max_iters', type=int, default=200000)
        self.parser.add_argument('--plot_roc', action='store_true')
        self.parser.add_argument('--arg_file', default=None)
        self.parser.add_argument('--cpu', action='store_true')

    def parse_args(self, args=None):
        """Parse ``args`` (``sys.argv[1:]`` when None) and return the namespace.

        Raises:
            ValueError: if ``--task`` names an unsupported task.
        """
        if args is not None:
            # Materialise once: the argument list is parsed twice below.
            args = list(args)
        known, _ = self.parser.parse_known_args(args)

        self._add_task_args(known.task)

        known, _ = self.parser.parse_known_args(args)
        return known

    def _add_task_args(self, task):
        """Register the options specific to ``task`` (no-op if already done).

        Raises:
            ValueError: for an unknown task.
        """
        if task in self._registered_tasks:
            return
        if task == 'congestion_gpdl':
            add = self.parser.add_argument
            add('--dataroot', default='../../training_set/congestion')
            add('--ann_file_train', default='./files/train_N28.csv')
            add('--ann_file_test', default='./files/test_N28.csv')
            add('--dataset_type', default='CongestionDataset')
            add('--batch_size', type=int, default=16)
            # nargs='+' so list-valued options given on the command line are
            # parsed into lists like their defaults; a plain string would be
            # iterated character by character by the code that consumes them.
            add('--aug_pipeline', nargs='+', default=['Flip'])

            add('--model_type', default='GPDL')
            add('--in_channels', type=int, default=3)
            add('--out_channels', type=int, default=1)
            add('--lr', type=float, default=2e-4)
            add('--weight_decay', type=float, default=0)
            add('--loss_type', default='MSELoss')
            add('--eval_metric', nargs='+', default=['NRMS', 'SSIM', 'EMD'])
        else:
            raise ValueError(f"Unknown task: {task}")
        self._registered_tasks.add(task)






class Flip:
    """Randomly mirror the arrays stored under ``keys`` of a sample dict.

    With probability ``flip_ratio`` every array under ``keys`` (or every array
    in a list stored under a key) is reversed along one axis: axis 1 for
    'horizontal', axis 0 for 'vertical'. This assumes arrays are laid out
    as (H, W, ...) at this point in the pipeline — TODO confirm for new datasets.
    Flipped arrays are replaced by contiguous copies; ``results`` is returned.
    """

    _directions = ['horizontal', 'vertical']
    # Array axis reversed for each supported direction.
    _axis = {'horizontal': 1, 'vertical': 0}

    def __init__(self, keys=('feature', 'label'), flip_ratio=0.5, direction='horizontal', **kwargs):
        if direction not in self._directions:
            raise ValueError(f'Direction {direction} is not supported. '
                             f'Currently supported ones are {self._directions}')
        self.keys = list(keys)
        self.flip_ratio = flip_ratio
        self.direction = direction

    @staticmethod
    def _flip(array, axis):
        # Contiguous copy so downstream transposes/torch conversion get a
        # normal array rather than a negative-stride view.
        return np.ascontiguousarray(np.flip(array, axis=axis))

    def __call__(self, results):
        # One draw per sample: all keys are flipped together or not at all.
        if np.random.random() < self.flip_ratio:
            axis = self._axis[self.direction]
            for key in self.keys:
                value = results[key]
                if isinstance(value, list):
                    results[key] = [self._flip(v, axis) for v in value]
                else:
                    results[key] = self._flip(value, axis)
        return results



class Rotation:
    """Randomly rotate the arrays stored under ``keys`` by a multiple of 90 degrees.

    With probability ``rotate_ratio`` a single ``k`` is drawn uniformly from
    {-1, -2, -3} and ``np.rot90(x, k, axes=...)`` is applied to every array
    under ``keys`` (and to every element of a list stored under a key).
    ``axis`` is either one tuple shared by all keys or a dict of tuples per key.
    Rotated arrays are replaced by contiguous copies; ``results`` is returned.
    """

    def __init__(self, keys=('feature', 'label'), axis=(0, 1), rotate_ratio=0.5, **kwargs):
        self.keys = list(keys)
        self.axis = {k: axis for k in self.keys} if isinstance(axis, tuple) else axis
        self.rotate_ratio = rotate_ratio
        # Candidate rot90 multiples; index 0 (no rotation) is never drawn.
        self.direction = [0, -1, -2, -3]

    @staticmethod
    def _rotate(array, k, axes):
        return np.ascontiguousarray(np.rot90(array, k, axes=axes))

    def __call__(self, results):
        if np.random.random() < self.rotate_ratio:
            # Uniform over indices 1..3 so -90, -180 and -270 degrees are all
            # possible (the previous index expression always evaluated to 1).
            rotate_angle = self.direction[np.random.randint(1, len(self.direction))]
            for key in self.keys:
                value = results[key]
                axes = self.axis[key]
                if isinstance(value, list):
                    # Rotate every element and keep the list structure.
                    results[key] = [self._rotate(v, rotate_angle, axes) for v in value]
                else:
                    results[key] = self._rotate(value, rotate_angle, axes)
        return results

class IterLoader:
    """Endless iterator over a data loader.

    ``next()`` never raises StopIteration: when the current pass is exhausted
    it sleeps for two seconds, starts a new pass with ``iter(dataloader)`` and
    returns that pass's first item. ``len()`` is the length of one pass.
    """

    def __init__(self, dataloader):
        self._dataloader = dataloader
        self.iter_loader = iter(dataloader)

    def __iter__(self):
        return self

    def __len__(self):
        return len(self._dataloader)

    def __next__(self):
        try:
            return next(self.iter_loader)
        except StopIteration:
            pass
        # Current pass finished: pause, then restart from the beginning.
        time.sleep(2)
        self.iter_loader = iter(self._dataloader)
        return next(self.iter_loader)

class CongestionDataset(object):
    """Map-style dataset of (feature, label) ``.npy`` pairs listed in a CSV file.

    Each non-blank line of ``ann_file`` is ``<feature>,<label>``, both paths
    relative to ``dataroot``. ``__getitem__`` returns
    ``(feature, label, label_path)``. The arrays are transposed with axes
    (2, 0, 1) — presumably (H, W, C) -> (C, H, W); confirm against the data —
    and cast to float32. If ``pipeline`` is given, its transforms are applied
    to the sample dict before the transpose.
    """

    def __init__(self, ann_file, dataroot, pipeline=None, test_mode=False):
        self.ann_file = ann_file
        self.dataroot = dataroot
        self.test_mode = test_mode  # stored for callers; not read by this class
        self.pipeline = Compose(pipeline) if pipeline else None
        self.data_infos = self.load_annotations()

    def load_annotations(self):
        """Read ``ann_file`` into a list of ``{'feature_path', 'label_path'}`` dicts.

        Blank lines (e.g. an empty last line) are skipped instead of failing
        the two-field unpacking.
        """
        data_infos = []
        with open(self.ann_file, 'r') as fin:
            for line in fin:
                line = line.strip()
                if not line:
                    continue
                feature, label = line.split(',')
                data_infos.append(dict(
                    feature_path=osp.join(self.dataroot, feature),
                    label_path=osp.join(self.dataroot, label)
                ))
        return data_infos

    def prepare_data(self, idx):
        """Load sample ``idx``, run the augmentation pipeline and convert layouts."""
        # Deep copy so the pipeline cannot mutate the cached annotation dict.
        results = copy.deepcopy(self.data_infos[idx])
        results['feature'] = np.load(results['feature_path'])
        results['label'] = np.load(results['label_path'])

        if self.pipeline:
            results = self.pipeline(results)

        feature = results['feature'].transpose(2, 0, 1).astype(np.float32)
        label = results['label'].transpose(2, 0, 1).astype(np.float32)

        return feature, label, results['label_path']

    def __len__(self):
        return len(self.data_infos)

    def __getitem__(self, idx):
        return self.prepare_data(idx)


def build_dataset(opt):
    """Build the data loader described by the option dict ``opt``.

    Keys read: ``ann_file``, ``dataroot``, ``test_mode`` (default False),
    ``aug_pipeline`` (names from {'Flip', 'Rotation'}, training only),
    ``batch_size`` (training only) and ``cpu`` (default False). The remaining
    keys are forwarded to ``Rotation`` when it is requested. ``opt`` itself is
    not modified.

    Returns:
        In test mode, an unshuffled ``DataLoader`` with batch size 1.
        Otherwise an ``IterLoader`` over a shuffled ``DataLoader`` that drops
        the last incomplete batch.

    Raises:
        KeyError: for an unknown augmentation name or missing required key.
    """
    opt = opt.copy()
    test_mode = opt.get('test_mode', False)

    # Factories instead of instances, so only requested augmentations are built.
    aug_factories = {
        'Flip': lambda: Flip(),
        'Rotation': lambda: Rotation(**opt),
    }

    pipeline = None
    if not test_mode and 'aug_pipeline' in opt:
        names = opt.pop('aug_pipeline') or []
        unknown = [name for name in names if name not in aug_factories]
        if unknown:
            raise KeyError(f'Unknown augmentation(s) {unknown}; '
                           f'expected names from {sorted(aug_factories)}')
        pipeline = [aug_factories[name]() for name in names]

    dataset = CongestionDataset(ann_file=opt.pop('ann_file'),
                                dataroot=opt.pop('dataroot'),
                                pipeline=pipeline,
                                test_mode=test_mode)

    if test_mode:
        return DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=1)

    # Pinned host memory only benefits host-to-GPU copies; on CPU-only runs
    # DataLoader would just warn about it.
    pin_memory = not opt.get('cpu', False) and torch.cuda.is_available()
    return IterLoader(
        DataLoader(dataset=dataset, batch_size=opt.pop('batch_size'), shuffle=True,
                   drop_last=True, num_workers=1, pin_memory=pin_memory)
    )





In [None]:
def generation_init_weights(module):
    """Initialise every convolution/linear layer inside ``module`` in place.

    A sub-module qualifies when it has a ``weight`` attribute and its class
    name contains 'Conv' or 'Linear'. Its weight is drawn from N(0, 0.02) and
    its bias (if any) set to zero. All other layers are left untouched.
    """

    def _init_layer(layer):
        layer_kind = type(layer).__name__
        is_target = 'Conv' in layer_kind or 'Linear' in layer_kind
        if not (hasattr(layer, 'weight') and is_target):
            return
        if layer.weight is not None:
            nn.init.normal_(layer.weight, 0.0, 0.02)
        if getattr(layer, 'bias', None) is not None:
            nn.init.constant_(layer.bias, 0)

    module.apply(_init_layer)

def load_state_dict(module, state_dict, strict=False, logger=None):
    """Load ``state_dict`` into ``module``, reporting mismatches instead of failing.

    Walks ``module`` recursively and calls each sub-module's private
    ``_load_from_state_dict`` with its dotted prefix. Along the way it collects
    missing keys, unexpected keys and any other load errors those calls report,
    such as shape ("size mismatch") errors.

    Args:
        module: the ``nn.Module`` to load into.
        state_dict: mapping of parameter/buffer names to tensors. Its optional
            ``_metadata`` attribute is preserved and passed per sub-module.
        strict: if True, any mismatch raises ``RuntimeError``.
        logger: object with a ``warning`` method used to report mismatches when
            ``strict`` is False. If None, the report is printed instead.

    Returns:
        list of missing keys, excluding ``num_batches_tracked`` entries.

    Raises:
        RuntimeError: if ``strict`` is True and anything did not match.
    """
    unexpected_keys = []
    all_missing_keys = []
    err_msg = []

    metadata = getattr(state_dict, '_metadata', None)
    # Work on a copy; the attribute-based _metadata is read before copying and
    # re-attached afterwards.
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module, prefix=''):
        # Per-module metadata is keyed by the prefix without its trailing '.'.
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        # The positional True is the `strict` argument of torch's
        # _load_from_state_dict; with it, torch records missing/unexpected keys
        # into the lists passed here. Whether they are fatal is decided below.
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     all_missing_keys, unexpected_keys,
                                     err_msg)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(module)
    # Drop the self-referencing recursive closure once it is no longer needed.
    load = None

    # BatchNorm's num_batches_tracked counters are not reported as missing.
    missing_keys = [
        key for key in all_missing_keys if 'num_batches_tracked' not in key
    ]

    if unexpected_keys:
        err_msg.append('unexpected key in source '
                       f'state_dict: {", ".join(unexpected_keys)}\n')
    if missing_keys:
        err_msg.append(
            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')

    if len(err_msg) > 0:
        err_msg.insert(
            0, 'The model and loaded state dict do not match exactly\n')
        err_msg = '\n'.join(err_msg)
        if strict:
            raise RuntimeError(err_msg)
        elif logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)
    return missing_keys

class conv(nn.Module):
    """Two stacked Conv2d -> InstanceNorm2d(affine) -> LeakyReLU(0.2) stages.

    The first convolution maps ``dim_in`` to ``dim_out`` channels; the second
    keeps ``dim_out``. All layers live in one ``nn.Sequential`` named ``main``,
    so parameter names (and therefore checkpoint keys) are ``main.0``,
    ``main.1``, ``main.3`` and ``main.4``.
    """

    def __init__(self, dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=True):
        super().__init__()
        stages = []
        for stage_in in (dim_in, dim_out):
            stages.append(nn.Conv2d(stage_in, dim_out, kernel_size=kernel_size,
                                    stride=stride, padding=padding, bias=bias))
            stages.append(nn.InstanceNorm2d(dim_out, affine=True))
            stages.append(nn.LeakyReLU(0.2, inplace=True))
        self.main = nn.Sequential(*stages)

    def forward(self, input):
        return self.main(input)

class upconv(nn.Module):
    """Up-sampling stage: ConvTranspose2d(k=4, s=2, p=1) -> InstanceNorm2d(affine) -> LeakyReLU(0.2).

    With these kernel/stride/padding values the spatial size is doubled.
    Layers are kept in ``self.main`` (checkpoint keys ``main.0``/``main.1``).
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        transpose = nn.ConvTranspose2d(dim_in, dim_out, 4, 2, 1)
        norm = nn.InstanceNorm2d(dim_out, affine=True)
        activation = nn.LeakyReLU(0.2, inplace=True)
        self.main = nn.Sequential(transpose, norm, activation)

    def forward(self, input):
        return self.main(input)

class Encoder(nn.Module):
    """Down-sampling half of GPDL.

    conv(in_dim->32) -> maxpool/2 -> conv(32->64) -> maxpool/2 ->
    Conv2d(64->out_dim) + BatchNorm2d + Tanh.

    ``forward`` returns ``(latent, skip)``:
      * ``latent`` is the output at a quarter of the input resolution
        (rounded down by the pools);
      * ``skip`` is the 32-channel map right after the first pool, at half
        resolution, which the decoder concatenates back in.
    """

    def __init__(self, in_dim=3, out_dim=32):
        super().__init__()
        self.in_dim = in_dim
        self.c1 = conv(in_dim, 32)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.c2 = conv(32, 64)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        head = [nn.Conv2d(64, out_dim, 3, 1, 1), nn.BatchNorm2d(out_dim), nn.Tanh()]
        self.c3 = nn.Sequential(*head)

    def init_weights(self):
        """Apply ``generation_init_weights`` to every layer of this encoder."""
        generation_init_weights(self)

    def forward(self, input):
        skip = self.pool1(self.c1(input))
        latent = self.c3(self.pool2(self.c2(skip)))
        return latent, skip


class Decoder(nn.Module):
    """Up-sampling half of GPDL.

    ``forward`` takes ``(feature, skip)``, the tuple produced by
    ``Encoder.forward``:
      1. ``feature`` goes through conv(in_dim->32) -> upconv(32->16) -> conv(16->16).
      2. The result is concatenated with the 32-channel ``skip`` along the
         channel axis.
      3. It is up-sampled again (48->4) and mapped to ``out_dim`` channels.

    The final sigmoid puts outputs in (0, 1).

    NOTE(review): the default ``out_dim=2`` differs from GPDL's
    ``out_channels=1`` default; GPDL always passes it explicitly.
    """

    def __init__(self, out_dim=2, in_dim=32):
        super().__init__()
        self.conv1 = conv(in_dim, 32)
        self.upc1 = upconv(32, 16)
        self.conv2 = conv(16, 16)
        self.upc2 = upconv(32 + 16, 4)
        self.conv3 = nn.Sequential(nn.Conv2d(4, out_dim, 3, 1, 1), nn.Sigmoid())

    def init_weights(self):
        """Apply ``generation_init_weights`` to every layer of this decoder."""
        generation_init_weights(self)

    def forward(self, input):
        feature, skip = input
        upsampled = self.conv2(self.upc1(self.conv1(feature)))
        merged = torch.cat([upsampled, skip], dim=1)
        return self.conv3(self.upc2(merged))


class GPDL(nn.Module):
    """Encoder/decoder congestion predictor built from ``Encoder`` and ``Decoder``.

    Args:
        in_channels: number of input feature channels.
        out_channels: number of predicted channels.
        **kwargs: ignored, so a whole option dict can be splatted in.
    """

    def __init__(self,
                 in_channels=3,
                 out_channels=1,
                 **kwargs):
        super().__init__()

        self.encoder = Encoder(in_dim=in_channels)
        self.decoder = Decoder(out_dim=out_channels)

    def forward(self, x):
        return self.decoder(self.encoder(x))

    def init_weights(self, pretrained=None, pretrained_transfer=None, strict=False, **kwargs):
        """Load weights from a checkpoint file, or initialise them randomly.

        Args:
            pretrained: path of a checkpoint readable by ``torch.load``. Either
                a dict holding the weights under ``'state_dict'`` or a bare
                state dict is accepted. If None, ``generation_init_weights``
                is applied instead.
            pretrained_transfer: unused; kept for interface compatibility.
            strict: forwarded to ``load_state_dict``. When False, mismatching
                keys/shapes are only reported, and the affected parameters keep
                their current values.
            **kwargs: ignored.

        Raises:
            TypeError: if ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            checkpoint = torch.load(pretrained, map_location='cpu')
            # Training checkpoints wrap the weights; plain state dicts do not.
            if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                checkpoint = checkpoint['state_dict']
            load_state_dict(self, OrderedDict(checkpoint), strict=strict, logger=None)
        elif pretrained is None:
            generation_init_weights(self)
        else:
            raise TypeError("'pretrained' must be a str or None. "
                            f'But received {type(pretrained)}.')
def build_model(opt):
    """Create and initialise a ``GPDL`` model from the option dict ``opt``.

    ``in_channels``/``out_channels`` in ``opt`` (when present) configure the
    network. Previously they were ignored, so e.g. ``test()`` setting
    ``out_channels`` had no effect. The whole dict is then passed to
    ``init_weights`` (``pretrained``, ``strict``, ...). The model is put in eval
    mode when ``opt['test_mode']`` is truthy (a missing key counts as False).
    """
    # GPDL accepts **kwargs, so unrelated option keys are ignored.
    model = GPDL(**opt)
    model.init_weights(**opt)
    if opt.get('test_mode', False):
        model.eval()
    return model


In [None]:
def mkdir_or_exist(dir_name, mode=0o777):
    """Create ``dir_name`` and any missing parents; ``~`` is expanded.

    An empty string is a no-op, and an already existing directory is not an
    error.
    """
    if dir_name != '':
        target = osp.expanduser(dir_name)
        os.makedirs(target, mode=mode, exist_ok=True)


def input_converter(apply_to=None):
    """Decorator factory: run selected arguments through ``tensor2img`` first.

    Args:
        apply_to: names of the wrapped function's parameters to convert. If
            None, every named parameter is converted.

    Converted and unconverted arguments are forwarded both positionally and
    by keyword, so options such as ``crop_border=...`` reach the wrapped
    function.
    """
    def input_converter_wrapper(old_func):
        # Introspect once at decoration time rather than on every call.
        arg_names = getfullargspec(old_func).args
        args_to_cast = arg_names if apply_to is None else apply_to

        def convert(name, value):
            return tensor2img(value) if name in args_to_cast else value

        @wraps(old_func)
        def new_func(*args, **kwargs):
            new_args = [convert(name, value) for name, value in zip(arg_names, args)]
            # Positional values beyond the named parameters (if any) pass through.
            new_args.extend(args[len(arg_names):])
            new_kwargs = {name: convert(name, value) for name, value in kwargs.items()}
            return old_func(*new_args, **new_kwargs)
        return new_func

    return input_converter_wrapper


@input_converter(apply_to=('img1', 'img2'))
def psnr(img1, img2, crop_border=0):
    """Peak signal-to-noise ratio in dB between two images, peak value 255.

    ``img1``/``img2`` are converted by ``tensor2img`` through the decorator.
    When ``crop_border`` is non-zero, that many pixels are cut from both
    spatial edges before comparing. Identical images give ``inf``.
    """
    assert img1.shape == img2.shape, (
        f'Image shapes are different: {img1.shape}, {img2.shape}.')

    ref = img1.astype(np.float32)
    other = img2.astype(np.float32)
    if crop_border != 0:
        edge = slice(crop_border, -crop_border)
        ref = ref[edge, edge, None]
        other = other[edge, edge, None]

    mse_value = np.mean(np.square(ref - other))
    if mse_value == 0:
        return float('inf')
    rmse = np.sqrt(mse_value)
    return 20. * np.log10(255. / rmse)


def _ssim(img1, img2):
    """Mean SSIM of two single-channel images on a 0-255 scale.

    Local statistics use an 11x11 Gaussian window (sigma 1.5). A 5-pixel
    border is dropped from every filtered map, keeping only positions where
    the whole window lies inside the image.
    """
    c1 = (0.01 * 255) ** 2
    c2 = (0.03 * 255) ** 2

    x = img1.astype(np.float64)
    y = img2.astype(np.float64)
    gauss = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(gauss, gauss.transpose())

    def local_mean(arr):
        return cv2.filter2D(arr, -1, window)[5:-5, 5:-5]

    mu_x = local_mean(x)
    mu_y = local_mean(y)
    mu_x_sq = mu_x ** 2
    mu_y_sq = mu_y ** 2
    mu_xy = mu_x * mu_y
    var_x = local_mean(x ** 2) - mu_x_sq
    var_y = local_mean(y ** 2) - mu_y_sq
    cov_xy = local_mean(x * y) - mu_xy

    numerator = (2 * mu_xy + c1) * (2 * cov_xy + c2)
    denominator = (mu_x_sq + mu_y_sq + c1) * (var_x + var_y + c2)
    return (numerator / denominator).mean()


@input_converter(apply_to=('img1', 'img2'))
def ssim(img1, img2, crop_border=0):
    """SSIM between two images, averaged over the channels on axis 2.

    ``img1``/``img2`` are converted by ``tensor2img`` through the decorator.
    When ``crop_border`` is non-zero, that many pixels are cut from both
    spatial edges first.
    """
    assert img1.shape == img2.shape, (
        f'Image shapes are different: {img1.shape}, {img2.shape}.')
    if crop_border != 0:
        edge = slice(crop_border, -crop_border)
        img1 = img1[edge, edge, None]
        img2 = img2[edge, edge, None]

    per_channel = [_ssim(img1[..., ch], img2[..., ch]) for ch in range(img1.shape[2])]
    return np.array(per_channel).mean()


@input_converter(apply_to=('img1', 'img2'))
def nrms(img1, img2, crop_border=0):
    """Normalised root-mean-square error (min-max normalisation) of two images.

    ``img1`` is the reference passed to ``normalized_root_mse`` as the true
    image. When the reference is constant, the normalisation range is zero and
    the ratio is undefined. Two cases are handled explicitly so a single
    degenerate sample cannot turn a dataset average into ``inf``/``nan``:

      * ``inf`` (non-zero error) -> the fixed fallback 0.05;
      * ``nan`` (0/0: the constant reference was reproduced exactly) -> 0.0,
        because the error itself is zero.
    """
    assert img1.shape == img2.shape, (
        f'Image shapes are different: {img1.shape}, {img2.shape}.')

    if crop_border != 0:
        # Crop spatial axes only; the arrays are flattened below, so the
        # channel axis is simply kept as-is.
        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

    nrmse_value = normalized_root_mse(img1.flatten(), img2.flatten(), normalization='min-max')
    if math.isnan(nrmse_value):
        return 0.0
    if math.isinf(nrmse_value):
        return 0.05
    return nrmse_value



def get_histogram(img):
    """Normalised 256-bin histogram of a 2-D integer image with values in [0, 255].

    Returns a float64 array of length 256 whose entry ``v`` is the fraction of
    pixels equal to ``v``.

    Raises:
        ValueError: if ``img`` is not 2-D or has values outside [0, 255].
        TypeError: if ``img`` does not have an integer dtype.
    """
    img = np.asarray(img)
    h, w = img.shape
    if not np.issubdtype(img.dtype, np.integer):
        raise TypeError(f'integer image expected, got dtype {img.dtype}')
    if img.size and (img.min() < 0 or img.max() > 255):
        raise ValueError('pixel values must lie in [0, 255]')
    # Vectorised count replaces the per-pixel Python loop; minlength keeps
    # all 256 bins even when high values are absent.
    counts = np.bincount(img.ravel().astype(np.intp), minlength=256)
    return counts / float(h * w)


def normalize_exposure(img):
    """Histogram-equalise a 2-D image with integer values in [0, 255].

    Each pixel value ``v`` is mapped to ``uint8(255 * cdf[v])``, where ``cdf``
    is the cumulative sum of the normalised histogram from ``get_histogram``.
    Returns an ``int`` array with the same shape as ``img``.
    """
    img = img.astype(int)
    hist = get_histogram(img)
    # Running sum instead of re-summing every prefix (O(256) vs O(256^2));
    # both add the same values in the same order.
    cdf = np.cumsum(hist)
    lookup = np.uint8(255 * cdf)
    # Fancy indexing applies the lookup table to every pixel at once.
    return lookup[img].astype(int)


@input_converter(apply_to=('img1', 'img2'))
def emd(img1, img2, crop_border=0):
    """Earth mover's distance between the histograms of two equalised images.

    Both images must have a single channel on axis 2, which is squeezed away.
    Each image is histogram-equalised with ``normalize_exposure`` and turned
    into a 256-bin normalised histogram. The two histograms are then compared
    with ``scipy.stats.wasserstein_distance``.
    """
    assert img1.shape == img2.shape, (
        f'Image shapes are different: {img1.shape}, {img2.shape}.')

    if crop_border != 0:
        # Crop the two spatial axes only and keep the channel axis, so the
        # squeeze below still sees (H, W, 1). Indexing with None would insert
        # an extra axis and break normalize_exposure.
        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

    img1 = normalize_exposure(np.squeeze(img1, axis=2))
    img2 = normalize_exposure(np.squeeze(img2, axis=2))
    hist_1 = get_histogram(img1)
    hist_2 = get_histogram(img2)

    # NOTE(review): wasserstein_distance treats the 256 histogram values as
    # samples rather than as weights over bins 0..255; kept unchanged so the
    # metric stays comparable with earlier results — confirm intent.
    emd_value = wasserstein_distance(hist_1, hist_2)
    return emd_value

def tpr(tp, fn):
    """True-positive rate (recall): tp / (tp + fn). Raises ZeroDivisionError if both are 0."""
    actual_positives = tp + fn
    return tp / actual_positives

def fpr(fp, tn):
    """False-positive rate: fp / (fp + tn). Raises ZeroDivisionError if both are 0."""
    actual_negatives = fp + tn
    return fp / actual_negatives

def precision(tp, fp):
    """Precision: tp / (tp + fp). Raises ZeroDivisionError if both are 0."""
    predicted_positives = tp + fp
    return tp / predicted_positives

def calculate_all(csv_path):
    """Average TPR, FPR and precision per threshold group from a confusion CSV.

    Each non-blank line of ``csv_path`` is ``threshold,idx,tn,fp,fn,tp``.
    A new group starts whenever a threshold that has not been seen before
    appears; other rows are added to the current group.

    Rows are skipped when a rate would be undefined:
      * no negatives (fp == tn == 0);
      * no positives (tp == fn == 0);
      * no predicted positives (tp == fp == 0).

    Groups without usable rows contribute nothing.

    Returns:
        (tpr_list, fpr_list, precision_list): per-group means in file order.
    """
    tpr_sum_List = []
    fpr_sum_List = []
    precision_sum_List = []
    seen_thresholds = set()
    num = 0
    tpr_sum = 0
    fpr_sum = 0
    precision_sum = 0

    # `with` guarantees the file is closed (it was previously left open).
    with open(csv_path, 'r') as csv_file:
        for line in csv_file:
            line = line.strip()
            if not line:
                continue
            threshold, idx, tn, fp, fn, tp = line.split(',')
            if threshold not in seen_thresholds:
                # Close the previous group before starting a new one.
                if num != 0:
                    tpr_sum_List.append(tpr_sum / num)
                    fpr_sum_List.append(fpr_sum / num)
                    precision_sum_List.append(precision_sum / num)
                seen_thresholds.add(threshold)
                tpr_sum = 0
                fpr_sum = 0
                precision_sum = 0
                num = 0

            tn, fp, fn, tp = int(tn), int(fp), int(fn), int(tp)
            if (fp == 0 and tn == 0) or (tp == 0 and fn == 0) or (tp == 0 and fp == 0):
                continue
            tpr_sum += tpr(tp, fn)
            fpr_sum += fpr(fp, tn)
            precision_sum += precision(tp, fp)
            num += 1

    if num != 0:
        tpr_sum_List.append(tpr_sum / num)
        fpr_sum_List.append(fpr_sum / num)
        precision_sum_List.append(precision_sum / num)

    return tpr_sum_List, fpr_sum_List, precision_sum_List


def calculated_score(threshold_idx=None,
                     temp_path=None,
                     label_path=None,
                     save_path=None,
                     threshold_label=None,
                     preds=None):
    """Write per-sample confusion counts for one prediction threshold to CSV.

    For every file name in ``preds``:
      * the label ``label_path/<name>`` is binarised at ``threshold_label``;
      * the prediction ``save_path/test_result/<name>`` is binarised at
        ``threshold_idx`` (value >= threshold -> positive);
      * one row ``threshold_idx, idx, tn, fp, fn, tp`` is written to
        ``temp_path/tpr_fpr_<threshold_idx>.csv``.

    Counts are computed directly from boolean masks. This works for any map
    size, whereas the previous code hard-coded 256x256 in the one-class cases.
    It also stays correct for thresholds above 1.
    """
    out_path = os.path.join(temp_path, f'tpr_fpr_{threshold_idx}.csv')
    # newline='' as recommended for csv.writer; `with` closes and flushes.
    with open(out_path, 'w', newline='') as file:
        f_csv = csv.writer(file, delimiter=',')
        for idx, pred in enumerate(preds):
            target = np.load(os.path.join(label_path, pred)).reshape(-1) >= threshold_label
            predicted = np.load(os.path.join(save_path, 'test_result', pred)).reshape(-1) >= threshold_idx

            tp = int(np.count_nonzero(target & predicted))
            tn = int(np.count_nonzero(~target & ~predicted))
            fp = int(np.count_nonzero(~target & predicted))
            fn = int(np.count_nonzero(target & ~predicted))

            f_csv.writerow([str(threshold_idx)] + [str(i) for i in [idx, tn, fp, fn, tp]])

    print(f'{threshold_idx}-done')

def _list_npy_files(root):
    """Sorted relative paths of non-hidden files ending in 'npy' under ``root`` (recursive)."""
    if not os.path.isdir(root):
        raise FileNotFoundError(f'prediction directory not found: {root}')
    found = []
    for dirpath, _, filenames in os.walk(root):
        for name in filenames:
            if not name.startswith('.') and name.endswith('npy'):
                found.append(os.path.relpath(os.path.join(dirpath, name), root))
    return sorted(found)


def multi_process_score(out_name=None, threshold=0.0, label_path=None, save_path=None):
    """Score 200 thresholds in parallel and gather the results in one CSV.

    Steps:
      1. Every ``.npy`` file under ``save_path/test_result`` (recursively) is
         scored by ``calculated_score`` for thresholds 0, 0.005, ..., 0.995,
         with labels binarised at ``threshold``.
      2. The per-threshold CSVs go to a uuid-named directory in the current
         working directory.
      3. They are concatenated, in threshold order, into ``out_name``.
      4. That file is copied to ``<cwd>/save_path/out_name``.

    The temporary directory is removed afterwards, also when an error occurs.
    """
    suid = uuid.uuid4().hex
    temp_path = f'./{suid}'

    # Size the pool by the currently idle share of CPUs (the first
    # cpu_percent call only primes the measurement), with at least one worker.
    psutil.cpu_percent(None)
    time.sleep(0.5)
    idle_share = 1 - psutil.cpu_percent(None) / 100.0
    n_workers = max(1, int(mul.cpu_count() * idle_share))

    # Replaces mmcv.scandir, whose import is not available in this notebook.
    preds = _list_npy_files(os.path.join(save_path, 'test_result'))

    os.makedirs(temp_path, exist_ok=True)
    try:
        threshold_list = np.linspace(0, 1, endpoint=False, num=200)

        score_one_threshold = partial(calculated_score, temp_path=temp_path,
                                      label_path=label_path, save_path=save_path,
                                      threshold_label=threshold, preds=preds)
        with mul.Pool(n_workers) as pool:
            pool.map(score_one_threshold, threshold_list)

        print(f'{suid}')

        merged_path = os.path.join(temp_path, f'{out_name}')
        with open(merged_path, 'a') as merged:
            for list_i in threshold_list:
                with open(os.path.join(temp_path, f'tpr_fpr_{list_i}.csv'), 'r') as part:
                    merged.write(part.read())

        print('copying')
        shutil.copy(merged_path, os.path.join(os.getcwd(), save_path, f'{out_name}'))
    finally:
        print('remove temp files')
        shutil.rmtree(temp_path, ignore_errors=True)

def get_sorted_list(fpr_sum_List, tpr_sum_List):
    """Deduplicate (x, y) pairs on x, keeping the first y seen, then sort by x.

    Returns two tuples ``(xs, ys)`` in ascending x order. Empty input raises
    ValueError (there is nothing to unpack).
    """
    first_y = {}
    for x_val, y_val in zip(fpr_sum_List, tpr_sum_List):
        first_y.setdefault(x_val, y_val)

    pairs = list(first_y.items())
    pairs.reverse()
    xs, ys = zip(*sorted(pairs))
    return xs, ys


def roc_prc(save_path):
    """Areas under the ROC and precision-recall curves from ``<cwd>/<save_path>/roc_prc.csv``.

    ROC area: trapezoidal rule over the deduplicated, x-sorted (fpr, tpr)
    points, with a final (1, 1) point appended.

    PRC area: precision as a function of recall is fitted with a cubic spline
    (``make_interp_spline``, k=3). The fit is evaluated at 25 evenly spaced
    recall values in [0, 1] and integrated with the trapezoidal rule.

    Returns:
        (roc_area, prc_area)
    """
    csv_file = os.path.join(os.getcwd(), save_path, 'roc_prc.csv')
    tpr_means, fpr_means, precision_means = calculate_all(csv_file)

    fpr_points, tpr_points = get_sorted_list(fpr_means, tpr_means)
    fpr_points = list(fpr_points) + [1]
    tpr_points = list(tpr_points) + [1]
    roc_area = sum((tpr_points[k] + tpr_points[k + 1]) * (fpr_points[k + 1] - fpr_points[k]) / 2
                   for k in range(len(tpr_points) - 1))

    recall_points, precision_points = get_sorted_list(tpr_means, precision_means)
    grid = np.linspace(0, 1, 25)
    fitted = make_interp_spline(recall_points, precision_points, k=3)(grid)
    prc_area = sum((fitted[k] + fitted[k + 1]) * (grid[k + 1] - grid[k]) / 2
                   for k in range(len(fitted) - 1))

    return roc_area, prc_area



def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    """Convert a tensor (or list of tensors) to channel-last numpy image(s).

    Up to two leading singleton axes (batch, then channel) are squeezed.
    After that:
      * a 3-D tensor, taken as (C, H, W) per torch convention, becomes (H, W, C);
      * a 2-D (H, W) tensor becomes (H, W, 1).

    Values are clamped to ``min_max`` and rescaled to [0, 1]. For
    ``out_type=np.uint8`` they are then scaled to [0, 255] and rounded.
    The input tensors are never modified.

    Returns:
        One array for a single tensor, otherwise a list of arrays.

    Raises:
        TypeError: if ``tensor`` is not a tensor or a list of tensors.
        ValueError: if a tensor is not 2-D or 3-D after squeezing.
    """
    if not (torch.is_tensor(tensor) or
            (isinstance(tensor, list)
             and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(
            f'tensor or list of tensors expected, got {type(tensor)}')

    if torch.is_tensor(tensor):
        tensor = [tensor]
    result = []
    for _tensor in tensor:
        _tensor = _tensor.squeeze(0).squeeze(0)
        # Out-of-place clamp: for a float CPU tensor, .float()/.detach()/.cpu()
        # share the caller's storage, so clamp_ would overwrite the input.
        _tensor = _tensor.float().detach().cpu().clamp(*min_max)
        _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
        n_dim = _tensor.dim()

        if n_dim == 3:
            # (C, H, W) -> (H, W, C): the metrics index channels on the last axis.
            img_np = np.transpose(_tensor.numpy(), (1, 2, 0))
        elif n_dim == 2:
            img_np = _tensor.numpy()[..., None]
        else:
            raise ValueError('Only support 4D, 3D or 2D tensor. '
                             f'But received with dimension: {n_dim}')
        if out_type == np.uint8:
            img_np = (img_np * 255.0).round()
        img_np = img_np.astype(out_type)
        result.append(img_np)
    result = result[0] if len(result) == 1 else result
    return result


def build_metric(metric_name):
    """Return the metric function called ``metric_name`` (case-insensitive).

    The lower-cased name is looked up among this module's globals, so e.g.
    'NRMS' resolves to ``nrms`` and 'SSIM' to ``ssim``.

    Raises:
        KeyError: if no callable global of that name exists. Non-callable
            globals (modules, strings, constants) are rejected here instead of
            failing later when the "metric" is called.
    """
    candidate = globals().get(metric_name.lower())
    if not callable(candidate):
        raise KeyError(f'Unknown metric {metric_name!r}')
    return candidate


def build_roc_prc_metric(threshold=None, dataroot=None, ann_file=None, save_path=None, **kwargs):
    """Compute ROC and PRC areas for predictions saved under ``save_path``.

    The label directory is ``dataroot/<first path component of the label>``.
    The label is the last comma-separated field of the first non-blank line of
    ``ann_file``. Predictions in ``save_path/test_result`` are scored by
    ``multi_process_score`` (labels binarised at ``threshold``), and the
    resulting ``roc_prc.csv`` is summarised by ``roc_prc``.

    Raises:
        FileExistsError: if ``ann_file`` is not given (type kept for
            compatibility with existing callers).
        ValueError: if ``ann_file`` has no non-blank line.
    """
    if not ann_file:
        raise FileExistsError('ann_file is required to locate the label directory')

    label = None
    with open(ann_file, 'r') as fin:
        for line in fin:
            fields = line.strip().split(',')
            if fields == ['']:
                continue
            label = fields[-1]
            break
    if label is None:
        raise ValueError(f'No annotation lines found in {ann_file}')

    label_name = label.split('/')[0]
    label_dir = os.path.join(dataroot, label_name)
    print(label_dir)
    multi_process_score(out_name='roc_prc.csv', threshold=threshold, label_path=label_dir,
                        save_path=os.path.join('.', save_path))

    return roc_prc(save_path)


In [None]:
def test(arg_dict):
    """Evaluate the model on the test split and print the mean of each metric.

    ``arg_dict`` must provide:
      * dataset options for ``build_dataset`` (``ann_file_test``, ``dataroot``);
      * model options for ``build_model`` (e.g. ``pretrained``);
      * ``eval_metric``: metric names resolved by ``build_metric``;
      * optionally ``cpu`` (default False).

    The dict passed in is not modified.

    Returns:
        dict mapping each metric name to its mean over the test set.
    """
    arg_dict = arg_dict.copy()

    arg_dict['ann_file'] = arg_dict['ann_file_test']
    arg_dict['test_mode'] = True
    arg_dict["out_channels"] = 1
    use_cpu = arg_dict.get('cpu', False)

    print('===> Loading datasets')
    dataset = build_dataset(arg_dict)

    print('===> Building model')
    model = build_model(arg_dict)
    if not use_cpu:
        model = model.cuda()
    model.eval()

    metrics = {k: build_metric(k) for k in arg_dict['eval_metric']}
    metric_sums = {k: 0.0 for k in arg_dict['eval_metric']}

    with torch.no_grad(), tqdm(total=len(dataset)) as bar:
        for feature, label, _label_path in dataset:
            if use_cpu:
                features, target = feature, label
            else:
                features, target = feature.cuda(), label.cuda()

            prediction = model(features)

            # Score inside the per-sample loop so every sample contributes
            # to the averages computed below.
            for metric, metric_func in metrics.items():
                metric_sums[metric] += metric_func(
                    target.cpu(),
                    prediction.cpu()
                )

            bar.update(1)

    n_samples = len(dataset)
    avg_metrics = {k: v / n_samples for k, v in metric_sums.items()}
    for metric, value in avg_metrics.items():
        print(f"===> Avg. {metric}: {value:.4f}")
    return avg_metrics


In [None]:
# Evaluation settings for the Colab run: the dataset lives on the mounted
# Google Drive, while the annotation CSV and the checkpoint are under /content.
test_args = dict(
    task="congestion_gpdl",
    save_path="work_dir/",
    ann_file_test="/content/test_N28.csv",
    dataroot="/content/drive/MyDrive/congestion",
    batch_size=1,
    cpu=True,
    eval_metric=["NRMS", "SSIM"],
    plot_roc=False,
    pretrained="/content/model_iters_200000.pth",
    out_channels=1,
)

test(test_args)


===> Loading datasets
===> Building model
The model and loaded state dict do not match exactly

size mismatch for decoder.conv3.0.weight: copying a param with shape torch.Size([1, 4, 3, 3]) from checkpoint, the shape in current model is torch.Size([2, 4, 3, 3]).
size mismatch for decoder.conv3.0.bias: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([2]).


  0%|          | 0/3164 [00:00<?, ?it/s]

In [None]:
# NOTE(review): this repeats the Drive-mount cell near the top of the notebook.
# Re-mounting is presumably redundant once that cell has run; consider
# deleting this cell.
from google.colab import drive
drive.mount('/content/drive')