In [35]:
# !pip install mmcv
# !pip install argparse

# import mmcv
import numpy as np
import os.path as osp
import copy
import time
import argparse
import os
import sys
import json
import cv2
import numpy as np
import torch
import multiprocessing as mul
import uuid
import psutil
import time
import csv
import math
import torch.optim as optim


import functools

import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from sklearn.metrics import accuracy_score, roc_curve, confusion_matrix
from scipy.interpolate import make_interp_spline
from functools import partial
# from mmcv import scandir

from scipy.stats import wasserstein_distance
from skimage.metrics import normalized_root_mse

from __future__ import print_function



In [22]:
# Mount Google Drive at /content/drive (uses the google.colab helper module).
from google.colab import drive
drive.mount('/content/drive')



Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [36]:
# Append the current working directory to the module search path.
sys.path.append(os.getcwd())


class Parser(object):
    """Two-pass command-line parser.

    The first pass reads the task-independent options (notably ``--task``).
    The options belonging to that task are then registered and the command
    line is parsed a second time.  Both passes use ``parse_known_args``, so
    unrelated entries in ``sys.argv`` are ignored instead of aborting.
    """

    def __init__(self):
        # conflict_handler='resolve' makes parse_args() safe to call more than
        # once: re-registering a task option replaces the earlier definition
        # instead of raising argparse.ArgumentError.
        self.parser = argparse.ArgumentParser(add_help=False,
                                              conflict_handler='resolve')

        self.parser.add_argument('--task', default='congestion_gpdl')
        self.parser.add_argument('--save_path', default='work_dir/congestion_gpdl/')
        self.parser.add_argument('--pretrained', default=None)
        self.parser.add_argument('--max_iters', type=int, default=200000)
        self.parser.add_argument('--plot_roc', action='store_true')
        self.parser.add_argument('--arg_file', default=None)
        self.parser.add_argument('--cpu', action='store_true')

    def parse_args(self):
        """Parse ``sys.argv`` and return the resulting namespace.

        Returns:
            argparse.Namespace: common options plus the options of the
            selected task.

        Raises:
            ValueError: if ``--task`` names an unknown task.
        """
        args, _ = self.parser.parse_known_args()

        self._add_task_args(args.task)

        args, _ = self.parser.parse_known_args()
        return args

    def _add_task_args(self, task):
        """Register the options that only exist for ``task``.

        List-valued options use ``nargs='+'`` so that values given on the
        command line arrive as a list, exactly like their defaults.
        Options registered for a previously selected task are not removed.
        """
        add = self.parser.add_argument

        if task == 'congestion_gpdl':
            add('--dataroot', default='../../training_set/congestion')
            add('--ann_file_train', default='./files/train_N28.csv')
            add('--ann_file_test', default='./files/test_N28.csv')
            add('--dataset_type', default='CongestionDataset')
            add('--batch_size', type=int, default=16)
            add('--aug_pipeline', nargs='+', default=['Flip'])

            add('--model_type', default='GPDL')
            add('--in_channels', type=int, default=3)
            add('--out_channels', type=int, default=1)
            add('--lr', type=float, default=2e-4)
            add('--weight_decay', type=float, default=0)
            add('--loss_type', default='MSELoss')
            add('--eval_metric', nargs='+', default=['NRMS', 'SSIM', 'EMD'])

        elif task == 'drc_routenet':
            add('--dataroot', default='../../training_set/DRC')
            add('--ann_file_train', default='./files/train_N28.csv')
            add('--ann_file_test', default='./files/test_N28.csv')
            add('--dataset_type', default='DRCDataset')
            add('--batch_size', type=int, default=8)
            add('--aug_pipeline', nargs='+', default=['Flip'])

            add('--model_type', default='RouteNet')
            add('--in_channels', type=int, default=9)
            add('--out_channels', type=int, default=1)
            add('--lr', type=float, default=2e-4)
            add('--weight_decay', type=float, default=1e-4)
            add('--loss_type', default='MSELoss')
            add('--eval_metric', nargs='+', default=['NRMS', 'SSIM'])
            add('--threshold', type=float, default=0.1)

        elif task == 'irdrop_mavi':
            add('--dataroot', default='../../training_set/IR_drop')
            add('--ann_file_train', default='./files/train_N28.csv')
            add('--ann_file_test', default='./files/test_N28.csv')
            add('--dataset_type', default='IRDropDataset')
            add('--batch_size', type=int, default=2)

            add('--model_type', default='MAVI')
            add('--in_channels', type=int, default=1)
            add('--out_channels', type=int, default=4)
            add('--lr', type=float, default=2e-4)
            add('--weight_decay', type=float, default=1e-2)
            add('--loss_type', default='L1Loss')
            add('--eval_metric', nargs='+', default=['NRMS', 'SSIM'])
            add('--threshold', type=float, default=0.9885)

        else:
            raise ValueError(f"Unknown task: {task}")






class Flip:
    """Randomly mirror the arrays stored under ``keys`` of a sample dict.

    The flip is implemented with NumPy, so no ``mmcv`` import is required.

    Args:
        keys (list[str]): entries of the sample dict to flip. Each entry is
            either an array or a list of arrays.
        flip_ratio (float): probability of flipping a sample.
        direction (str): ``'horizontal'`` reverses axis 1, ``'vertical'``
            reverses axis 0.
        **kwargs: ignored.

    Raises:
        ValueError: if ``direction`` is not supported.
    """
    _directions = ['horizontal', 'vertical']

    def __init__(self, keys=['feature', 'label'], flip_ratio=0.5, direction='horizontal', **kwargs):
        if direction not in self._directions:
            raise ValueError(f'Direction {direction} is not supported.'
                             f'Currently support ones are {self._directions}')
        # Copy so the shared default list is never aliased between instances.
        self.keys = list(keys)
        self.flip_ratio = flip_ratio
        self.direction = direction

    def _flip_array(self, array):
        """Return a contiguous, mirrored copy of ``array``."""
        axis = 1 if self.direction == 'horizontal' else 0
        return np.ascontiguousarray(np.flip(array, axis=axis))

    def __call__(self, results):
        """Flip every configured entry with probability ``flip_ratio``.

        All keys of one sample are flipped together (one random draw per
        call).  Flipped arrays replace the entries in ``results``; the same
        dict is returned.
        """
        if np.random.random() < self.flip_ratio:
            for key in self.keys:
                value = results[key]
                if isinstance(value, list):
                    results[key] = [self._flip_array(item) for item in value]
                else:
                    results[key] = self._flip_array(value)

        return results



class Rotation:
    """Randomly rotate the arrays stored under ``keys`` by a multiple of 90 degrees.

    Args:
        keys (list[str]): entries of the sample dict to rotate. Each entry is
            either an array or a list of arrays.
        axis (tuple | dict): plane of rotation passed to ``np.rot90`` as
            ``axes``. A tuple applies to every key; a dict maps key -> axes.
        rotate_ratio (float): probability of rotating a sample.
        **kwargs: ignored.
    """

    def __init__(self, keys=['feature', 'label'], axis=(0,1), rotate_ratio=0.5, **kwargs):
        self.keys = keys
        self.axis = {k:axis for k in keys} if isinstance(axis, tuple) else axis
        self.rotate_ratio = rotate_ratio
        # ``k`` values handed to np.rot90; index 0 (no rotation) is never drawn.
        self.direction = [0, -1, -2, -3]

    def __call__(self, results):
        """Rotate every configured entry with probability ``rotate_ratio``.

        One rotation is drawn per call and applied to all keys so that
        feature and label stay aligned.  The same dict is returned.
        """
        rotate = np.random.random() < self.rotate_ratio

        if rotate:
            # Pick uniformly among the non-zero rotations.  The previous
            # expression int(random / (10/3)) + 1 always evaluated to 1.
            rotate_angle = self.direction[np.random.randint(1, len(self.direction))]
            for key in self.keys:
                axes = self.axis[key]
                value = results[key]
                if isinstance(value, list):
                    # Rotate each element and keep the list structure
                    # (previously the list was overwritten by its last element).
                    results[key] = [
                        np.ascontiguousarray(np.rot90(item, rotate_angle, axes=axes))
                        for item in value
                    ]
                else:
                    results[key] = np.ascontiguousarray(np.rot90(value, rotate_angle, axes=axes))

        return results

class IterLoader:
    """Endless iterator over a dataloader.

    When the wrapped loader is exhausted a new pass is started after a
    two-second pause, so ``next()`` keeps yielding batches.
    """

    def __init__(self, dataloader):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)

    def __iter__(self):
        return self

    def __len__(self):
        # Length of one pass over the wrapped loader.
        return len(self._dataloader)

    def __next__(self):
        try:
            return next(self.iter_loader)
        except StopIteration:
            pass

        # Current pass is finished: wait, then restart from the beginning.
        time.sleep(2)
        self.iter_loader = iter(self._dataloader)
        return next(self.iter_loader)

class CongestionDataset(object):
    """Map-style dataset of (feature, label) ``.npy`` pairs listed in a CSV file.

    Each non-empty line of ``ann_file`` has the form ``feature,label``; both
    entries are paths relative to ``dataroot``.

    Args:
        ann_file (str): path of the annotation CSV.
        dataroot (str): directory the CSV entries are relative to.
        pipeline (list | None): callables applied to the sample dict, chained
            with ``Compose``; ``None`` or an empty list disables augmentation.
        test_mode (bool): stored on the instance; not used by this class.
    """

    def __init__(self, ann_file, dataroot, pipeline=None, test_mode=False):
        self.ann_file = ann_file
        self.dataroot = dataroot
        self.test_mode = test_mode
        self.pipeline = Compose(pipeline) if pipeline else None
        self.data_infos = self.load_annotations()

    def load_annotations(self):
        """Read the annotation file into a list of path dicts.

        Blank lines are skipped and whitespace around each field is removed.

        Raises:
            ValueError: if a non-empty line does not hold exactly two
                comma-separated fields.
        """
        data_infos = []
        with open(self.ann_file, 'r') as fin:
            for line_no, raw_line in enumerate(fin, start=1):
                line = raw_line.strip()
                if not line:
                    # Tolerate empty lines (e.g. trailing newlines).
                    continue
                fields = [field.strip() for field in line.split(',')]
                if len(fields) != 2:
                    raise ValueError(
                        f'{self.ann_file}:{line_no}: expected "feature,label", '
                        f'got {line!r}')
                feature, label = fields
                data_infos.append(dict(
                    feature_path=osp.join(self.dataroot, feature),
                    label_path=osp.join(self.dataroot, label)
                ))
        return data_infos

    def prepare_data(self, idx):
        """Load, augment and format sample ``idx``.

        Returns:
            tuple: ``(feature, label, label_path)`` where both arrays are
            float32 with their last axis moved to the front
            (presumably HWC -> CHW; confirm against the stored data).
        """
        results = copy.deepcopy(self.data_infos[idx])
        results['feature'] = np.load(results['feature_path'])
        results['label'] = np.load(results['label_path'])

        if self.pipeline:
            results = self.pipeline(results)

        feature = results['feature'].transpose(2, 0, 1).astype(np.float32)
        label = results['label'].transpose(2, 0, 1).astype(np.float32)

        return feature, label, results['label_path']

    def __len__(self):
        return len(self.data_infos)

    def __getitem__(self, idx):
        return self.prepare_data(idx)


def build_dataset(opt):
    """Create the data loader described by ``opt``.

    In test mode a plain ``DataLoader`` (batch size 1, no shuffling) is
    returned.  Otherwise a shuffling ``DataLoader`` wrapped in ``IterLoader``
    is returned.  ``opt`` itself is left untouched; a shallow copy is
    consumed instead.
    """
    cfg = opt.copy()

    # Both augmentations are instantiated up front; Rotation receives the
    # whole option dict as keyword arguments.
    available_augs = {
        'Flip': Flip(),
        'Rotation': Rotation(**cfg),
    }

    is_test = cfg.get('test_mode', False)

    # Augmentation is only configured for training runs that request it.
    pipeline = None
    if 'aug_pipeline' in cfg and not is_test:
        pipeline = [available_augs[aug_name] for aug_name in cfg.pop('aug_pipeline')]

    dataset = CongestionDataset(
        ann_file=cfg.pop('ann_file'),
        dataroot=cfg.pop('dataroot'),
        pipeline=pipeline,
        test_mode=is_test,
    )

    if is_test:
        return DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=1)

    train_loader = DataLoader(
        dataset=dataset,
        batch_size=cfg.pop('batch_size'),
        shuffle=True,
        drop_last=True,
        num_workers=1,
        pin_memory=True,
    )
    return IterLoader(train_loader)





In [37]:
def generation_init_weights(module):
    """Re-initialise every Conv*/Linear* sub-module of ``module`` in place.

    Weights are drawn from N(0, 0.02) and biases are set to zero.  Layers are
    selected by class name, so any class whose name contains ``Conv`` or
    ``Linear`` and that exposes a ``weight`` attribute is affected.
    """
    def _init_layer(layer):
        layer_type = layer.__class__.__name__
        is_target = 'Conv' in layer_type or 'Linear' in layer_type
        if not (hasattr(layer, 'weight') and is_target):
            return

        if layer.weight is not None:
            nn.init.normal_(layer.weight, 0.0, 0.02)

        bias = getattr(layer, 'bias', None)
        if bias is not None:
            nn.init.constant_(bias, 0)

    module.apply(_init_layer)

def load_state_dict(module, state_dict, strict=False, logger=None):
    """Load ``state_dict`` into ``module`` and report key mismatches.

    Args:
        module: module whose parameters are filled in (walked recursively
            through ``module._modules``).
        state_dict: mapping of parameter names to values; it is copied before
            use, so the caller's object is not modified.
        strict (bool): when True, any reported problem raises RuntimeError.
        logger: optional object with a ``warning`` method; when given and
            ``strict`` is False the report is sent there instead of ``print``.

    Returns:
        list: missing keys, excluding those containing ``num_batches_tracked``.

    Raises:
        RuntimeError: if ``strict`` is True and any mismatch was recorded.
    """
    unexpected_keys = []
    all_missing_keys = []
    err_msg = []

    # Preserve the optional ``_metadata`` attribute across the copy.
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module, prefix=''):
        """Recursively load ``module`` and its children under ``prefix``."""
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        # The literal True is passed positionally (the ``strict`` slot of
        # PyTorch's Module._load_from_state_dict -- confirm against the
        # installed version); mismatches are collected in the three lists.
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     all_missing_keys, unexpected_keys,
                                     err_msg)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(module)
    # Drop the reference to the recursive closure.
    load = None

    # ``num_batches_tracked`` entries are not reported as missing.
    missing_keys = [
        key for key in all_missing_keys if 'num_batches_tracked' not in key
    ]

    if unexpected_keys:
        err_msg.append('unexpected key in source '
                       f'state_dict: {", ".join(unexpected_keys)}\n')
    if missing_keys:
        err_msg.append(
            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')

    if len(err_msg) > 0:
        err_msg.insert(
            0, 'The model and loaded state dict do not match exactly\n')
        err_msg = '\n'.join(err_msg)
        # Report channel: exception (strict), logger, or stdout.
        if strict:
            raise RuntimeError(err_msg)
        elif logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)
    return missing_keys

class conv(nn.Module):
    """Two stacked Conv2d -> InstanceNorm2d(affine) -> LeakyReLU(0.2) stages.

    The first stage maps ``dim_in`` to ``dim_out`` channels, the second keeps
    ``dim_out``.  Both convolutions share the same kernel size, stride,
    padding and bias setting.
    """

    def __init__(self, dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=True):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride,
                           padding=padding, bias=bias)

        stages = []
        for stage_in in (dim_in, dim_out):
            stages.append(nn.Conv2d(stage_in, dim_out, **conv_kwargs))
            stages.append(nn.InstanceNorm2d(dim_out, affine=True))
            stages.append(nn.LeakyReLU(0.2, inplace=True))

        self.main = nn.Sequential(*stages)

    def forward(self, input):
        features = self.main(input)
        return features

class upconv(nn.Module):
    """ConvTranspose2d(kernel 4, stride 2, padding 1) -> InstanceNorm2d(affine) -> LeakyReLU(0.2)."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        upsample = nn.ConvTranspose2d(dim_in, dim_out, 4, 2, 1)
        normalise = nn.InstanceNorm2d(dim_out, affine=True)
        activate = nn.LeakyReLU(0.2, inplace=True)
        self.main = nn.Sequential(upsample, normalise, activate)

    def forward(self, input):
        upsampled = self.main(input)
        return upsampled

class Encoder(nn.Module):
    """Down-sampling half of the network.

    Two ``conv`` blocks, each followed by a 2x2 max-pool, then a
    Conv2d -> BatchNorm2d -> Tanh head producing ``out_dim`` channels.
    """

    def __init__(self, in_dim=3, out_dim=32):
        super().__init__()
        self.in_dim = in_dim

        self.c1 = conv(in_dim, 32)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.c2 = conv(32, 64)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        head_layers = [
            nn.Conv2d(64, out_dim, 3, 1, 1),
            nn.BatchNorm2d(out_dim),
            nn.Tanh(),
        ]
        self.c3 = nn.Sequential(*head_layers)

    def init_weights(self):
        generation_init_weights(self)

    def forward(self, input):
        # Output of the first pooling stage doubles as the skip connection.
        skip = self.pool1(self.c1(input))
        encoded = self.c3(self.pool2(self.c2(skip)))
        return encoded, skip


class Decoder(nn.Module):
    """Up-sampling half of the network.

    ``forward`` takes the ``(feature, skip)`` pair, up-samples ``feature``,
    concatenates ``skip`` on the channel axis, up-samples again and maps the
    result to ``out_dim`` channels through a Sigmoid.
    """

    def __init__(self, out_dim=2, in_dim=32):
        super().__init__()
        self.conv1 = conv(in_dim, 32)
        self.upc1 = upconv(32, 16)
        self.conv2 = conv(16, 16)
        # Input channels: the 32 + 16 of the two tensors concatenated in forward().
        self.upc2 = upconv(32 + 16, 4)
        self.conv3 = nn.Sequential(
            nn.Conv2d(4, out_dim, 3, 1, 1),
            nn.Sigmoid(),
        )

    def init_weights(self):
        generation_init_weights(self)

    def forward(self, input):
        feature, skip = input
        upsampled = self.conv2(self.upc1(self.conv1(feature)))
        merged = torch.cat([upsampled, skip], dim=1)
        return self.conv3(self.upc2(merged))


class GPDL(nn.Module):
    """Encoder/decoder network built from ``Encoder`` and ``Decoder``.

    Args:
        in_channels (int): channels of the input tensor.
        out_channels (int): channels of the prediction.
        **kwargs: ignored, so a full option dict can be passed.
    """

    def __init__(self,
                 in_channels=3,
                 out_channels=2,
                 **kwargs):
        super().__init__()

        self.encoder = Encoder(in_dim=in_channels)
        self.decoder = Decoder(out_dim=out_channels)

    def forward(self, x):
        # The encoder returns a (feature, skip) pair consumed by the decoder.
        x = self.encoder(x)
        return self.decoder(x)

    def init_weights(self, pretrained=None, pretrained_transfer=None, strict=False, **kwargs):
        """Initialise the network, optionally from a checkpoint file.

        Args:
            pretrained (str | None): path of a checkpoint holding a
                ``'state_dict'`` entry, or None for random initialisation.
            pretrained_transfer: accepted for interface compatibility; unused.
            strict (bool): forwarded to ``load_state_dict``.
            **kwargs: ignored.

        Raises:
            TypeError: if ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            # Imported locally: the notebook's import cell does not provide it,
            # which previously made this branch fail with NameError.
            from collections import OrderedDict

            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            weight = torch.load(pretrained, map_location='cpu')['state_dict']
            new_dict = OrderedDict(weight)
            load_state_dict(self, new_dict, strict=strict, logger=None)
        elif pretrained is None:
            generation_init_weights(self)
        else:
            raise TypeError("'pretrained' must be a str or None. "
                            f'But received {type(pretrained)}.')
def build_model(opt):
    """Create and initialise a ``GPDL`` model from an option dict.

    Keys read from ``opt``:
        in_channels / out_channels: forwarded to ``GPDL`` (its defaults of
            3 / 2 apply when a key is absent).
        pretrained and friends: consumed by ``GPDL.init_weights(**opt)``.
        test_mode: when truthy the model is switched to eval mode; a missing
            key is treated as False.

    NOTE(review): GPDL defaults to 2 output channels; pass ``out_channels``
    explicitly if the labels have a different channel count.
    """
    model = GPDL(
        in_channels=opt.get('in_channels', 3),
        out_channels=opt.get('out_channels', 2),
    )
    model.init_weights(**opt)
    if opt.get('test_mode', False):
        model.eval()
    return model


In [38]:
def reduce_loss(loss, reduction):
    """Apply ``reduction`` ('none', 'mean' or 'sum') to a loss tensor.

    Raises:
        ValueError: for any other ``reduction`` value.
    """
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    if reduction == 'none':
        return loss
    raise ValueError(f"Invalid reduction: {reduction}")


def mask_reduce_loss(loss, weight=None, reduction='mean', sample_wise=False):
    """Weight an element-wise loss and reduce it.

    Args:
        loss: element-wise loss tensor.
        weight: optional tensor multiplied into ``loss``.
        reduction: 'sum' and 'mean' reduce to a scalar; any other value
            returns the (weighted) element-wise loss unchanged.
        sample_wise: with 'mean' and a weight, normalise each sample by its
            own weight sum (over dims 1, 2, 3) and average over the batch,
            instead of dividing the total by the total weight.
    """
    weighted = loss if weight is None else loss * weight

    if reduction == 'sum':
        return weighted.sum()
    if reduction != 'mean':
        return weighted

    # --- mean reduction ---
    if weight is None:
        return weighted.mean()

    eps = 1e-12  # guards against division by an all-zero weight
    if not sample_wise:
        return weighted.sum() / (weight.sum() + eps)

    per_sample_weight = weight.sum(dim=[1, 2, 3], keepdim=True)
    return (weighted / (per_sample_weight + eps)).sum() / weight.size(0)


def masked_loss(loss_func):
    """Decorator adding ``weight``/``reduction``/``sample_wise`` handling.

    ``loss_func(pred, target)`` must return an element-wise loss; the
    decorated function passes that result through ``mask_reduce_loss``.
    """
    def _masked(pred, target, weight=None, reduction='mean', sample_wise=False):
        elementwise = loss_func(pred, target)
        return mask_reduce_loss(elementwise, weight, reduction, sample_wise)

    # Carry over name, docstring, etc. from the wrapped function.
    return functools.wraps(loss_func)(_masked)


@masked_loss
def mse_loss(pred, target):
    """Element-wise squared error between ``pred`` and ``target`` (no reduction).

    The ``masked_loss`` decorator wraps this function, so the resulting
    callable also accepts ``weight``, ``reduction`` and ``sample_wise``.
    """
    return F.mse_loss(pred, target, reduction='none')


class MSELoss(nn.Module):
    """Mean-squared-error loss scaled by ``loss_weight``.

    Args:
        loss_weight (float): factor applied to the reduced loss.
        reduction (str): reduction forwarded to ``mse_loss``.
        sample_wise (bool): forwarded to ``mse_loss``.
    """

    def __init__(self, loss_weight=100.0, reduction='mean', sample_wise=False):
        super().__init__()
        self.loss_weight, self.reduction, self.sample_wise = (
            loss_weight, reduction, sample_wise)

    def forward(self, pred, target, weight=None):
        unscaled = mse_loss(pred, target,
                            weight=weight,
                            reduction=self.reduction,
                            sample_wise=self.sample_wise)
        return self.loss_weight * unscaled
def build_loss(opt):
    """Create the training criterion named by ``opt['loss_type']``.

    The name is resolved in ``torch.nn`` (e.g. 'MSELoss', 'L1Loss').  When the
    key is absent the result is ``torch.nn.MSELoss()``, which is what this
    function returned unconditionally before.

    Args:
        opt (dict | None): option dict; only 'loss_type' is read.

    Raises:
        ValueError: if the name is not a ``torch.nn.Module`` subclass in
            ``torch.nn``.
    """
    loss_type = (opt or {}).get('loss_type', 'MSELoss')

    loss_cls = getattr(torch.nn, str(loss_type), None)
    if not (isinstance(loss_cls, type) and issubclass(loss_cls, torch.nn.Module)):
        raise ValueError(f"Unknown loss type: {loss_type}")

    return loss_cls()



In [47]:

def checkpoint(model, epoch, save_path):
    """Save ``model.state_dict()`` to ``<save_path>/model_iters_<epoch>.pth``.

    Args:
        model: module whose ``state_dict()`` is stored under the
            ``'state_dict'`` key of the saved dict.
        epoch: value embedded in the file name.
        save_path (str): target directory, relative or absolute; created if
            missing.  An empty string means the current directory.
    """
    if save_path:
        os.makedirs(save_path, exist_ok=True)
    # os.path.join keeps absolute paths intact; the former "./{save_path}/"
    # prefix redirected them to a directory that was never created.
    model_out_path = os.path.join(save_path, f"model_iters_{epoch}.pth")
    torch.save({'state_dict': model.state_dict()}, model_out_path)


class CosineRestartLr(object):
    """Cosine-annealing learning-rate schedule with restarts.

    Args:
        base_lr (float | list[float]): initial learning rate(s); replaced by
            the optimizer's values when ``set_init_lr`` is called.
        periods (list[int]): length, in iterations, of each cosine cycle.
        restart_weights (list[float]): per-cycle scale of the cosine term.
        min_lr (float | None): learning rate reached at the end of a cycle.
        min_lr_ratio (float | None): alternative to ``min_lr``; the final
            rate is ``base_lr * min_lr_ratio``.  Used only when ``min_lr`` is
            None.
    """

    def __init__(self,
                 base_lr,
                 periods,
                 restart_weights = [1],
                 min_lr = None,
                 min_lr_ratio = None):
        self.periods = periods
        self.min_lr = min_lr
        self.min_lr_ratio = min_lr_ratio
        # Copy so the shared default list is never aliased between instances.
        self.restart_weights = list(restart_weights)
        super().__init__()

        # Running totals: cumulative_periods[i] is the first iteration that
        # no longer belongs to cycle i.
        self.cumulative_periods = [
            sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
        ]

        self.base_lr = base_lr

    def annealing_cos(self, start: float,
                    end: float,
                    factor: float,
                    weight: float = 1.) -> float:
        """Cosine interpolation from ``start`` (factor 0) to ``end`` (factor 1)."""
        cos_out = math.cos(math.pi * factor) + 1
        return end + 0.5 * weight * (start - end) * cos_out

    def get_position_from_periods(self, iteration: int, cumulative_periods):
        """Return the index of the cycle that contains ``iteration``.

        Raises:
            ValueError: if ``iteration`` lies at or beyond the last cycle end.
        """
        for i, period in enumerate(cumulative_periods):
            if iteration < period:
                return i
        raise ValueError(f'Current iteration {iteration} exceeds '
                        f'cumulative_periods {cumulative_periods}')

    def get_lr(self, iter_num, base_lr: float):
        """Learning rate at ``iter_num`` for one parameter group.

        Raises:
            ValueError: if neither ``min_lr`` nor ``min_lr_ratio`` was given,
                or if ``iter_num`` is outside the configured periods.
        """
        if self.min_lr is not None:
            target_lr = self.min_lr
        elif self.min_lr_ratio is not None:
            target_lr = base_lr * self.min_lr_ratio
        else:
            raise ValueError("Either 'min_lr' or 'min_lr_ratio' must be set")

        idx = self.get_position_from_periods(iter_num, self.cumulative_periods)
        current_weight = self.restart_weights[idx]
        nearest_restart = 0 if idx == 0 else self.cumulative_periods[idx - 1]
        current_periods = self.periods[idx]

        # Progress through the current cycle, clamped to [0, 1].
        alpha = min((iter_num - nearest_restart) / current_periods, 1)
        return self.annealing_cos(base_lr, target_lr, alpha, current_weight)

    def _set_lr(self, optimizer, lr_groups):
        """Write ``lr_groups`` into the optimizer's parameter groups.

        ``optimizer`` may be a single optimizer or a dict of optimizers, in
        which case ``lr_groups`` is indexed with the same keys.
        """
        if isinstance(optimizer, dict):
            for k, single_optimizer in optimizer.items():
                for param_group, lr in zip(single_optimizer.param_groups, lr_groups[k]):
                    param_group['lr'] = lr
        else:
            for param_group, lr in zip(optimizer.param_groups,
                                        lr_groups):
                param_group['lr'] = lr

    def get_regular_lr(self, iter_num):
        """Return one learning rate per base rate for ``iter_num``."""
        # A scalar base_lr (as passed to the constructor) counts as one group.
        if isinstance(self.base_lr, (list, tuple)):
            base_lrs = self.base_lr
        else:
            base_lrs = [self.base_lr]
        return [self.get_lr(iter_num, _base_lr) for _base_lr in base_lrs]  # iters

    def set_init_lr(self, optimizer):
        """Record each parameter group's starting rate as ``base_lr``.

        Every group first receives an 'initial_lr' entry (defaulting to its
        current 'lr'); only then is the list of base rates collected, so
        optimizers with several parameter groups are handled correctly.
        """
        for group in optimizer.param_groups:  # type: ignore
            group.setdefault('initial_lr', group['lr'])
        self.base_lr = [
            group['initial_lr'] for group in optimizer.param_groups  # type: ignore
        ]


def train(arg_dict):
    """Train a model with AdamW and a cosine learning-rate schedule.

    Args:
        arg_dict (dict): run configuration.  Keys read here: ``save_path``,
            ``ann_file_train``, ``lr``, ``weight_decay``, ``max_iters`` and
            optionally ``cpu`` and ``batch_size``.  The dict (with
            ``ann_file`` and ``test_mode`` added) is also forwarded to
            ``build_dataset``, ``build_model`` and ``build_loss``.

    Side effects:
        Writes ``arg.json`` and model checkpoints into ``save_path`` and
        prints progress.  A checkpoint is written every 10000 iterations and
        once more when training ends.

    Raises:
        RuntimeError: if the data loader yields no batches.
    """
    # Work on a shallow copy so the caller's dict is not modified by the
    # derived 'ann_file' / 'test_mode' entries added below.
    arg_dict = dict(arg_dict)

    use_cpu = arg_dict.get("cpu", False)
    if not use_cpu and not torch.cuda.is_available():
        print("CUDA is not available; falling back to CPU")
        use_cpu = True
    device = torch.device("cpu" if use_cpu else "cuda")
    print("Using device:", device)

    os.makedirs(arg_dict["save_path"], exist_ok=True)

    with open(os.path.join(arg_dict["save_path"], "arg.json"), "w") as f:
        json.dump(arg_dict, f, indent=4)

    arg_dict["ann_file"] = arg_dict["ann_file_train"]
    arg_dict["test_mode"] = False

    print("===> Loading datasets")
    dataset = build_dataset(arg_dict)

    # Fallback for a build_dataset that returns a bare dataset object.
    if not hasattr(dataset, "__iter__") or not hasattr(dataset, "__len__"):
        dataset = DataLoader(dataset, batch_size=arg_dict["batch_size"], shuffle=True)

    print("===> Building model")
    model = build_model(arg_dict).to(device)

    criterion = build_loss(arg_dict)

    optimizer = optim.AdamW(
        model.parameters(),
        lr=arg_dict["lr"],
        betas=(0.9, 0.999),
        weight_decay=arg_dict["weight_decay"]
    )

    max_iters = arg_dict["max_iters"]

    cosine_lr = CosineRestartLr(
        arg_dict["lr"],
        [max_iters],
        [1],
        min_lr=1e-7
    )
    cosine_lr.set_init_lr(optimizer)

    print_freq = 100
    save_freq = 10000
    iter_num = 0
    last_saved_iter = None

    while iter_num < max_iters:
        window_loss = 0.0
        window_steps = 0
        # Iterations left until the next report or the end of training.
        window_total = min(print_freq - iter_num % print_freq, max_iters - iter_num)

        with tqdm(total=window_total) as bar:
            for feature, label, _ in dataset:
                feature = feature.to(device)
                label = label.to(device)

                lr = cosine_lr.get_regular_lr(iter_num)
                cosine_lr._set_lr(optimizer, lr)

                optimizer.zero_grad()
                prediction = model(feature)
                loss = criterion(prediction, label)
                loss.backward()
                optimizer.step()

                window_loss += loss.item()
                window_steps += 1
                iter_num += 1
                bar.update(1)

                # Stop at max_iters too: the schedule is undefined beyond it.
                if iter_num % print_freq == 0 or iter_num >= max_iters:
                    break

        if window_steps == 0:
            # An empty loader would otherwise spin forever.
            raise RuntimeError("The training data loader produced no batches")

        # Average over the steps actually taken in this window.
        print(
            f"===> Iters[{iter_num}/{arg_dict['max_iters']}]: "
            f"Loss: {window_loss / window_steps:.4f}"
        )

        if iter_num % save_freq == 0:
            checkpoint(model, iter_num, arg_dict["save_path"])
            last_saved_iter = iter_num

    # Always keep the final weights, even when max_iters is not a multiple
    # of save_freq.
    if iter_num > 0 and last_saved_iter != iter_num:
        checkpoint(model, iter_num, arg_dict["save_path"])


In [None]:
# Configuration for this run; the dict is handed to train() as-is.
args = {
    # Not read by train() itself; it is only written to arg.json with the rest.
    "task": "congestion_gpdl",
    # Directory that receives arg.json and the model checkpoints.
    "save_path": "work_dir/",
    # Annotation CSV listing the training samples.
    "ann_file_train": "/content/train_N28.csv",
    # Directory that the paths inside the annotation CSV are joined to.
    "dataroot": "/content/drive/MyDrive/congestion",
    # Optimizer settings passed to AdamW.
    "lr": 1e-4,
    "weight_decay": 1e-4,
    # Batch size of the training DataLoader.
    "batch_size": 1,
    # Number of training iterations; also the length of the cosine schedule.
    "max_iters": 200,
    # True selects torch.device("cpu").
    "cpu": True,
}

train(args)


Using device: cpu
===> Loading datasets
===> Building model


100%|██████████| 100/100 [01:03<00:00,  1.59it/s]


===> Iters[100/200]: Loss: 0.1381


  2%|▏         | 2/100 [00:00<00:42,  2.33it/s]