# In [None]:
from google.colab import drive
import torch
import torch.nn.functional as F
import numpy as np


# Mount Google Drive at /content/drive so that the '/content/drive/MyDrive/...'
# dataset paths used by load_data() resolve.
drive.mount('/content/drive')
class DiffAug():
    """Differentiable data augmentation applied directly to image tensors.

    Every op works on a 4-D tensor ``x`` whose dim 0 is the batch, dim 1 the
    channels and dims 2/3 the spatial axes (see the ``x.size(2)`` /
    ``x.size(3)`` indexing below). Ops are built from tensor arithmetic,
    ``affine_grid``/``grid_sample`` and index gathers, so gradients flow back
    to ``x``.

    Args:
        strategy: ``'_'``-separated op names taken from ``color``, ``crop``,
            ``cutout``, ``flip``, ``scale``, ``rotate``, ``translate``.
            ``''`` or ``'none'`` disables augmentation entirely.
        batch: if True a single random draw is shared by the whole batch;
            otherwise each sample gets its own random parameters.
        ratio_cutout: side of the cutout square as a fraction of the image side.
        single: if False (default), ``flip``/``color``/``cutout`` are always
            applied around the other ops; if True they are treated like every
            other op in ``strategy``.
    """

    def __init__(self,
                 strategy='color_crop_cutout_flip_scale_rotate',
                 batch=False,
                 ratio_cutout=0.5,
                 single=False):
        self.prob_flip = 0.5
        self.ratio_scale = 1.2
        self.ratio_rotate = 15.0
        self.ratio_crop_pad = 0.125
        self.ratio_cutout = ratio_cutout
        self.ratio_noise = 0.05
        self.brightness = 1.0
        self.saturation = 2.0
        self.contrast = 0.5

        self.batch = batch

        # Always initialise the op flags so the object is fully formed even
        # when augmentation is disabled.
        self.strategy = []
        self.flip = False
        self.color = False
        self.cutout = False
        self.aug = not (strategy == '' or strategy.lower() == 'none')
        if self.aug:
            for aug in strategy.lower().split('_'):
                if aug == 'flip' and not single:
                    self.flip = True
                elif aug == 'color' and not single:
                    self.color = True
                elif aug == 'cutout' and not single:
                    self.cutout = True
                else:
                    self.strategy.append(aug)

        self.aug_fn = {
            'color': [self.brightness_fn, self.saturation_fn, self.contrast_fn],
            'crop': [self.crop_fn],
            'cutout': [self.cutout_fn],
            'flip': [self.flip_fn],
            'scale': [self.scale_fn],
            'rotate': [self.rotate_fn],
            'translate': [self.translate_fn],
        }

    def __call__(self, x, single_aug=True, seed=-1):
        """Augment ``x``.

        Order: flip, colour ops, then either one randomly chosen op from
        ``self.strategy`` (``single_aug=True``) or all of them in order, then
        cutout. A positive ``seed`` reseeds numpy and torch before every op.
        Returns ``x`` itself when augmentation is disabled, otherwise a
        contiguous tensor.
        """
        if not self.aug:
            return x

        if self.flip:
            self.set_seed(seed)
            x = self.flip_fn(x, self.batch)
        if self.color:
            for f in self.aug_fn['color']:
                self.set_seed(seed)
                x = f(x, self.batch)
        if len(self.strategy) > 0:
            if single_aug:
                # Pick a single op (this choice happens before any reseeding).
                idx = np.random.randint(len(self.strategy))
                p = self.strategy[idx]
                for f in self.aug_fn[p]:
                    self.set_seed(seed)
                    x = f(x, self.batch)
            else:
                for p in self.strategy:
                    for f in self.aug_fn[p]:
                        self.set_seed(seed)
                        x = f(x, self.batch)
        if self.cutout:
            self.set_seed(seed)
            x = self.cutout_fn(x, self.batch)

        x = x.contiguous()
        return x

    def set_seed(self, seed):
        """Reseed numpy and torch when ``seed > 0``; otherwise do nothing."""
        if seed > 0:
            np.random.seed(seed)
            torch.random.manual_seed(seed)

    def scale_fn(self, x, batch=True):
        """Random anisotropic zoom via an affine grid.

        sx, sy are drawn from [1/ratio_scale, ratio_scale); a value of 1 keeps
        the original size and 0.5 enlarges the content 2x.
        """
        ratio = self.ratio_scale

        if batch:
            sx = np.random.uniform() * (ratio - 1.0 / ratio) + 1.0 / ratio
            sy = np.random.uniform() * (ratio - 1.0 / ratio) + 1.0 / ratio
            theta = [[sx, 0, 0], [0, sy, 0]]
            theta = torch.tensor(theta, dtype=torch.float, device=x.device)
            theta = theta.expand(x.shape[0], 2, 3)
        else:
            sx = np.random.uniform(size=x.shape[0]) * (ratio - 1.0 / ratio) + 1.0 / ratio
            sy = np.random.uniform(size=x.shape[0]) * (ratio - 1.0 / ratio) + 1.0 / ratio
            theta = [[[sx[i], 0, 0], [0, sy[i], 0]] for i in range(x.shape[0])]
            theta = torch.tensor(theta, dtype=torch.float, device=x.device)

        # align_corners=False is the PyTorch default; passing it explicitly
        # silences the "default behavior has changed" warning.
        grid = F.affine_grid(theta, x.shape, align_corners=False)
        x = F.grid_sample(x, grid, align_corners=False)
        return x

    def rotate_fn(self, x, batch=True):
        """Random rotation by an angle in [-ratio_rotate, ratio_rotate] degrees."""
        ratio = self.ratio_rotate

        if batch:
            theta = (np.random.uniform() - 0.5) * 2 * ratio / 180 * float(np.pi)
            theta = [[np.cos(theta), np.sin(-theta), 0], [np.sin(theta), np.cos(theta), 0]]
            theta = torch.tensor(theta, dtype=torch.float, device=x.device)
            theta = theta.expand(x.shape[0], 2, 3)
        else:
            theta = (np.random.uniform(size=x.shape[0]) - 0.5) * 2 * ratio / 180 * float(np.pi)
            theta = [[[np.cos(theta[i]), np.sin(-theta[i]), 0],
                      [np.sin(theta[i]), np.cos(theta[i]), 0]] for i in range(x.shape[0])]
            theta = torch.tensor(theta, dtype=torch.float, device=x.device)

        grid = F.affine_grid(theta, x.shape, align_corners=False)
        x = F.grid_sample(x, grid, align_corners=False)
        return x

    def flip_fn(self, x, batch=True):
        """Flip along dim 3 with probability ``prob_flip`` (per batch or per sample)."""
        prob = self.prob_flip

        if batch:
            coin = np.random.uniform()
            if coin < prob:
                return x.flip(3)
            else:
                return x
        else:
            randf = torch.rand(x.size(0), 1, 1, 1, device=x.device)
            return torch.where(randf < prob, x.flip(3), x)

    def brightness_fn(self, x, batch=True):
        """Add an offset ``(u - 0.5) * brightness`` with ``u ~ U[0, 1)``."""
        ratio = self.brightness

        if batch:
            randb = np.random.uniform()
        else:
            randb = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
        x = x + (randb - 0.5) * ratio
        return x

    def saturation_fn(self, x, batch=True):
        """Scale each pixel's deviation from its channel mean by ``u * saturation``."""
        ratio = self.saturation

        x_mean = x.mean(dim=1, keepdim=True)
        if batch:
            rands = np.random.uniform()
        else:
            rands = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
        x = (x - x_mean) * (rands * ratio) + x_mean
        return x

    def contrast_fn(self, x, batch=True):
        """Scale deviation from the per-sample mean by ``u + contrast``."""
        ratio = self.contrast

        x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
        if batch:
            randc = np.random.uniform()
        else:
            randc = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
        x = (x - x_mean) * (randc + ratio) + x_mean
        return x

    def translate_fn(self, x, batch=True):
        """Integer shift along dim 3 by up to ``ratio_crop_pad`` of the width.

        Pixels shifted in from outside are read from a 1-pixel zero pad: the
        clamped index lands on the pad, so they become 0.
        """
        ratio = self.ratio_crop_pad

        shift_y = int(x.size(3) * ratio + 0.5)
        if batch:
            translation_y = np.random.randint(-shift_y, shift_y + 1)
        else:
            translation_y = torch.randint(-shift_y,
                                          shift_y + 1,
                                          size=[x.size(0), 1, 1],
                                          device=x.device)

        grid_batch, grid_x, grid_y = torch.meshgrid(
            torch.arange(x.size(0), dtype=torch.long, device=x.device),
            torch.arange(x.size(2), dtype=torch.long, device=x.device),
            torch.arange(x.size(3), dtype=torch.long, device=x.device),
        )
        grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
        x_pad = F.pad(x, (1, 1))
        x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
        return x

    def crop_fn(self, x, batch=True):
        """Pad-and-crop: integer shift along dims 2 and 3, zero-filled at the border."""
        ratio = self.ratio_crop_pad

        shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
        if batch:
            translation_x = np.random.randint(-shift_x, shift_x + 1)
            translation_y = np.random.randint(-shift_y, shift_y + 1)
        else:
            translation_x = torch.randint(-shift_x,
                                          shift_x + 1,
                                          size=[x.size(0), 1, 1],
                                          device=x.device)

            translation_y = torch.randint(-shift_y,
                                          shift_y + 1,
                                          size=[x.size(0), 1, 1],
                                          device=x.device)

        grid_batch, grid_x, grid_y = torch.meshgrid(
            torch.arange(x.size(0), dtype=torch.long, device=x.device),
            torch.arange(x.size(2), dtype=torch.long, device=x.device),
            torch.arange(x.size(3), dtype=torch.long, device=x.device),
        )
        grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
        grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
        x_pad = F.pad(x, (1, 1, 1, 1))
        x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
        return x

    def cutout_fn(self, x, batch=True):
        """Zero a square of side ``ratio_cutout * size`` around a random centre.

        The square is clipped at the image border, so near the edges less than
        the full square may be erased.
        """
        ratio = self.ratio_cutout
        cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)

        if batch:
            offset_x = np.random.randint(0, x.size(2) + (1 - cutout_size[0] % 2))
            offset_y = np.random.randint(0, x.size(3) + (1 - cutout_size[1] % 2))
        else:
            offset_x = torch.randint(0,
                                     x.size(2) + (1 - cutout_size[0] % 2),
                                     size=[x.size(0), 1, 1],
                                     device=x.device)

            offset_y = torch.randint(0,
                                     x.size(3) + (1 - cutout_size[1] % 2),
                                     size=[x.size(0), 1, 1],
                                     device=x.device)

        grid_batch, grid_x, grid_y = torch.meshgrid(
            torch.arange(x.size(0), dtype=torch.long, device=x.device),
            torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
            torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
        )
        grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
        grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
        mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
        mask[grid_batch, grid_x, grid_y] = 0
        x = x * mask.unsqueeze(1)
        return x

    def cutout_inv_fn(self, x, batch=True):
        """Keep only a random ``ratio_cutout``-sized window that lies fully
        inside the image; zero everything else."""
        ratio = self.ratio_cutout
        cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)

        # Valid top-left offsets are 0 .. size - cutout inclusive; randint's high
        # bound is exclusive, hence the +1. (Without it the bottom/right-most
        # window is never chosen and ratio_cutout=1.0 raises ValueError.)
        high_x = x.size(2) - cutout_size[0] + 1
        high_y = x.size(3) - cutout_size[1] + 1
        if batch:
            offset_x = np.random.randint(0, high_x)
            offset_y = np.random.randint(0, high_y)
        else:
            offset_x = torch.randint(0,
                                     high_x,
                                     size=[x.size(0), 1, 1],
                                     device=x.device)
            offset_y = torch.randint(0,
                                     high_y,
                                     size=[x.size(0), 1, 1],
                                     device=x.device)

        grid_batch, grid_x, grid_y = torch.meshgrid(
            torch.arange(x.size(0), dtype=torch.long, device=x.device),
            torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
            torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
        )
        grid_x = torch.clamp(grid_x + offset_x, min=0, max=x.size(2) - 1)
        grid_y = torch.clamp(grid_y + offset_y, min=0, max=x.size(3) - 1)
        mask = torch.zeros(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
        mask[grid_batch, grid_x, grid_y] = 1.
        x = x * mask.unsqueeze(1)
        return x

# In [None]:

import torch
import random
import numpy as np
import os
import sys
import time
import matplotlib
import matplotlib.pyplot as plt

# Select matplotlib's non-interactive Agg backend so Plotter.plot() can write
# PNG files without a display. NOTE(review): this runs after
# `import matplotlib.pyplot`; confirm your matplotlib version switches backends
# here rather than warning.
matplotlib.use('Agg')

# NOTE(review): only these three names are exported by `from <module> import *`;
# other helpers defined below (Logger, accuracy, Normalize, ...) must be imported
# by name.
__all__ = ["Compose", "Lighting", "ColorJitter"]


def dist_l2(data, target):
    """Pairwise squared Euclidean distances between the rows of two matrices.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b.

    Args:
        data: tensor of shape (N, D).
        target: tensor of shape (M, D).

    Returns:
        Tensor of shape (N, M) whose [i, j] entry is ||data[i] - target[j]||^2.
    """
    sq_data = (data ** 2).sum(-1).unsqueeze(1)
    sq_target = (target ** 2).sum(-1).unsqueeze(0)
    cross = torch.matmul(data, target.transpose(1, 0))
    return sq_data + sq_target - 2 * cross


def get_time():
    """Return the current local time formatted as ``[YYYY-MM-DD HH:MM:SS]``."""
    now = time.localtime()
    return time.strftime("[%Y-%m-%d %H:%M:%S]", now)


class Logger(object):
    """Tee writes to the console and, optionally, to a log file.

    ``console`` is whatever ``sys.stdout`` was at construction time, so an
    instance can be installed as ``sys.stdout`` to mirror all prints into
    ``fpath``.

    Args:
        fpath: path of a log file to create/truncate, or None for console only.
    """

    def __init__(self, fpath=None):
        self.console = sys.stdout
        self.file = None
        if fpath is not None:
            self.file = open(fpath, 'w')

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def write(self, msg):
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            os.fsync(self.file.fileno())

    def close(self):
        """Close the log file. Safe to call more than once.

        The console stream is deliberately left open: it is the process-wide
        ``sys.stdout``, and closing it (e.g. when this object is garbage
        collected) would break every later print.
        """
        # getattr: __del__ may run on a half-initialised object if open() failed.
        log_file = getattr(self, 'file', None)
        if log_file is not None:
            log_file.close()
            self.file = None


class TimeStamp():
    """Accumulate named wall-clock intervals and print a per-name summary.

    ``stamp(name)`` records the time elapsed since the previous ``set()`` /
    ``stamp()`` under ``name``. ``flush()`` prints total/mean/std/count per name
    and clears the recorded samples. Both are no-ops when ``print_log`` is False.
    """

    def __init__(self, print_log=True):
        self.prev = time.time()
        self.print_log = print_log
        self.times = {}

    def set(self):
        """Restart the interval clock."""
        self.prev = time.time()

    def flush(self):
        if not self.print_log:
            return
        print("\n=========Summary=========")
        for key, samples in self.times.items():
            # Skip names with no samples since the last flush: the mean/std of an
            # empty array would print nan and emit a RuntimeWarning.
            if not samples:
                continue
            times = np.array(samples)
            print(
                f"{key}: {times.sum():.4f}s (avg {times.mean():.4f}s, std {times.std():.4f}, count {len(times)})"
            )
            self.times[key] = []

    def stamp(self, name=''):
        if self.print_log:
            spent = time.time() - self.prev
            self.times.setdefault(name, []).append(spent)
            self.set()


def accuracy(output, target, topk=(1, )):
    """Top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: (batch, classes) score tensor.
        target: (batch,) tensor of true class indices.
        topk: the k values to evaluate.

    Returns:
        A list with one 1-element tensor per k, holding the percentage of rows
        whose true class is among the k highest scores.
    """
    batch_size = target.size(0)
    _, top_idx = output.topk(max(topk), 1, True, True)
    ranked = top_idx.t()  # (maxk, batch): row r holds every sample's rank-r guess
    hits = ranked.eq(target.reshape(1, -1).expand_as(ranked))
    return [
        hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / batch_size)
        for k in topk
    ]

def _eval_roc(pred, gt, num):
    """Return ``(tpr, fpr)`` for ``pred`` thresholded at ``num`` ascending values
    evenly spaced in ``[0, 1 - 1e-10]``; each result has length ``num``."""
    thresholds = torch.linspace(0, 1 - 1e-10, num, device=pred.device)
    tpr = torch.zeros(num, device=pred.device)
    fpr = torch.zeros(num, device=pred.device)
    for i in range(num):
        hit = (pred >= thresholds[i]).float()
        tp = (hit * gt).sum()
        fp = (hit * (1 - gt)).sum()
        tn = ((1 - hit) * (1 - gt)).sum()
        fn = ((1 - hit) * gt).sum()
        # 1e-20 guards against maps with no positive (or no negative) pixels.
        tpr[i] = tp / (tp + fn + 1e-20)
        fpr[i] = fp / (fp + tn + 1e-20)
    return tpr, fpr


def _to_float_tensor(img):
    """Return ``img`` as a float tensor; non-tensors go through torchvision ToTensor."""
    if torch.is_tensor(img):
        return img.float()
    # Local import: torchvision is only needed for PIL images / numpy arrays.
    from torchvision import transforms
    return transforms.ToTensor()(img)


def Eval_auc(output, target=None):
    """Average ROC curve, and its AUC, over a sequence of prediction maps.

    Each prediction is min-max normalised to [0, 1] and thresholded at 255
    levels. The per-image TPR/FPR curves are averaged, and the AUC is the
    trapezoidal area under the averaged curve.

    Args:
        output: iterable of ``(pred, gt)`` pairs. Each element may be a tensor
            or anything ``torchvision.transforms.ToTensor`` accepts.
            NOTE(review): ``gt`` is assumed to be a mask with values in [0, 1];
            confirm for your data.
        target: unused; kept so existing two-argument calls keep working.

    Returns:
        ``(auc, tpr, fpr)``: the AUC as a Python float, and the averaged TPR and
        FPR curves (255 points each, ordered by ascending FPR).

    Raises:
        ValueError: if ``output`` yields no pairs.
    """
    num_thresholds = 255
    avg_tpr, avg_fpr, img_num = 0.0, 0.0, 0
    with torch.no_grad():
        for pred, gt in output:
            pred = _to_float_tensor(pred)
            gt = _to_float_tensor(gt).to(pred.device)
            pred = (pred - torch.min(pred)) / (torch.max(pred) - torch.min(pred) + 1e-20)
            tpr, fpr = _eval_roc(pred, gt, num_thresholds)
            avg_tpr = avg_tpr + tpr
            avg_fpr = avg_fpr + fpr
            img_num += 1
        if img_num == 0:
            raise ValueError('Eval_auc: `output` contained no (pred, gt) pairs')
        avg_tpr = avg_tpr / img_num
        avg_fpr = avg_fpr / img_num

        # With gt in [0, 1], TPR and FPR are both non-increasing in the threshold,
        # so reversing gives ascending FPR with ties already ordered by TPR.
        # (Unstable argsort could order ties arbitrarily and change the area.)
        avg_tpr = avg_tpr.flip(0)
        avg_fpr = avg_fpr.flip(0)
        avg_auc = torch.trapz(avg_tpr, avg_fpr)

    return avg_auc.item(), avg_tpr, avg_fpr

class AverageMeter(object):
    """Track the latest value together with a running sum, count and mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


class Plotter():
    """Collect per-epoch accuracy/loss and periodically save a 4-panel figure.

    Every ``plot_freq`` recorded epochs, ``update`` saves
    ``{path}/curve_{idx}.png``.
    """

    def __init__(self, path, nepoch, idx=0):
        self.path = path
        self.data = {key: [] for key in ('epoch', 'acc_tr', 'acc_val', 'loss_tr', 'loss_val')}
        self.nepoch = nepoch
        self.plot_freq = 10
        self.idx = idx

    def update(self, epoch, acc_tr, acc_val, loss_tr, loss_val):
        """Append one epoch of metrics; plot when the count hits a multiple of plot_freq."""
        record = (
            ('epoch', epoch),
            ('acc_tr', acc_tr),
            ('acc_val', acc_val),
            ('loss_tr', loss_tr),
            ('loss_val', loss_val),
        )
        for key, value in record:
            self.data[key].append(value)

        if len(self.data['epoch']) % self.plot_freq == 0:
            self.plot()

    def plot(self, color='black'):
        """Draw train/val accuracy and loss against epoch and save the figure."""
        fig, axes = plt.subplots(1, 4, figsize=(4 * 4, 3))
        fig.tight_layout(h_pad=3, w_pad=3)

        fig.suptitle(f"{self.path}", size=16, y=1.1)

        # (series key, y-limits, panel title), one per subplot, left to right.
        panels = (
            ('acc_tr', [0, 100], 'acc train'),
            ('acc_val', [0, 100], 'acc val'),
            ('loss_tr', [0, 3], 'loss train'),
            ('loss_val', [0, 3], 'loss val'),
        )
        for ax, (key, ylim, title) in zip(axes, panels):
            ax.plot(self.data['epoch'], self.data[key], color, lw=0.8)
            ax.set_xlim([0, self.nepoch])
            ax.set_ylim(ylim)
            ax.set_title(title)

        for ax in axes:
            ax.set_xlabel('epochs')

        plt.savefig(f'{self.path}/curve_{self.idx}.png', bbox_inches='tight')
        plt.close()


def random_indices(y, nclass=2, intraclass=False, device='cuda'):
    """Return a random permutation of ``range(len(y))``.

    With ``intraclass=True`` each position is only shuffled among positions
    that share its label (labels ``0 .. nclass-1``), so ``y[index]`` equals ``y``
    for those labels.
    """
    n = len(y)
    if not intraclass:
        return torch.randperm(n).to(device)

    index = torch.arange(n).to(device)
    for c in range(nclass):
        members = index[y == c]
        if len(members) == 0:
            continue
        perm = torch.randperm(len(members))
        index[y == c] = members[perm]
    return index


def rand_bbox(size, lam):
    """Sample a CutMix box covering roughly ``1 - lam`` of the image area.

    Args:
        size: tensor shape; ``size[2]`` and ``size[3]`` are used as the two
            spatial extents (named W and H here).
        lam: mixing ratio in [0, 1]; each box side is ``sqrt(1 - lam)`` times
            the image side.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)`` corner coordinates, clipped to the image.
    """
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    # Builtin int(): the np.int alias was removed in NumPy 1.24.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Uniformly sampled box centre.
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2


class Compose(object):
    """Apply a sequence of callables in order, feeding each output to the next."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        out = img
        for transform in self.transforms:
            out = transform(out)
        return out

    def __repr__(self):
        lines = [self.__class__.__name__ + '(']
        lines.extend('    {0}'.format(t) for t in self.transforms)
        lines.append(')')
        return '\n'.join(lines)


class Lighting(object):
    """AlexNet-style PCA lighting noise.

    Adds ``eigvec @ (alpha * eigval)`` to every pixel, with
    ``alpha ~ N(0, alphastd)`` drawn once per call (3 values), to either a
    ``(3, H, W)`` image or a ``(N, 3, H, W)`` batch.
    """

    def __init__(self, alphastd, eigval, eigvec, device='cpu'):
        self.alphastd = alphastd
        self.eigval = torch.tensor(eigval, device=device)
        self.eigvec = torch.tensor(eigvec, device=device)

    def __call__(self, img):
        if self.alphastd == 0:
            return img

        alpha = img.new_empty(3).normal_(0, self.alphastd)
        eigvec = self.eigvec.type_as(img)
        weights = alpha.view(1, 3).expand(3, 3)
        scales = self.eigval.view(1, 3).expand(3, 3)
        rgb = eigvec.mul(weights).mul(scales).sum(1).squeeze()

        # Broadcast the per-channel offset; pure tensor ops keep it differentiable.
        offset_shape = (1, 3, 1, 1) if img.dim() == 4 else (3, 1, 1)
        return img + rgb.view(offset_shape).expand_as(img)


class Grayscale(object):
    """Replace every colour channel with the luma ``0.299 R + 0.587 G + 0.114 B``.

    Accepts a ``(3, H, W)`` image or an ``(N, 3, H, W)`` batch; the channel axis
    is taken as dim -3. Returns a new tensor and leaves the input unchanged.
    """

    def __call__(self, img):
        gs = img.clone()
        # Views on the channel axis (dim -3), so batched inputs are handled per
        # sample instead of mixing batch items 0..2.
        r, g, b = (gs.select(-3, i) for i in range(3))
        # In place on the clone: r <- 0.299 r + 0.587 g + 0.114 b. Uses the
        # `alpha=` keyword instead of the deprecated add_(Number, Tensor) overload.
        r.mul_(0.299).add_(g, alpha=0.587).add_(b, alpha=0.114)
        g.copy_(r)
        b.copy_(r)
        return gs


class Saturation(object):
    """Blend the image towards its grayscale version by a random weight.

    The weight is drawn from ``U(-var, var)``; negative weights push colours
    away from gray.
    """

    def __init__(self, var):
        self.var = var  # maximum |blend weight|

    def __call__(self, img):
        weight = random.uniform(-self.var, self.var)
        gray = Grayscale()(img)
        return img.lerp(gray, weight)


class Brightness(object):
    """Blend the image towards black by a random weight drawn from ``U(-var, var)``."""

    def __init__(self, var):
        self.var = var  # maximum |blend weight|

    def __call__(self, img):
        black = img.new_zeros(img.size())
        weight = random.uniform(-self.var, self.var)
        return img.lerp(black, weight)


class Contrast(object):
    """Blend the image towards a flat image filled with its mean gray level.

    The blend weight is drawn from ``U(-var, var)``.
    """

    def __init__(self, var):
        self.var = var  # maximum |blend weight|

    def __call__(self, img):
        flat = Grayscale()(img)
        flat.fill_(flat.mean())
        weight = random.uniform(-self.var, self.var)
        return img.lerp(flat, weight)


class ColorJitter(object):
    """Apply brightness, contrast and saturation jitter in a random order.

    A strength of 0 disables that component. The op list is rebuilt on every
    call and stored on ``self.transforms``.
    """

    def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation

    def __call__(self, img):
        candidates = (
            (self.brightness, Brightness),
            (self.contrast, Contrast),
            (self.saturation, Saturation),
        )
        self.transforms = [op(strength) for strength, op in candidates if strength != 0]
        random.shuffle(self.transforms)
        return Compose(self.transforms)(img)


class CutOut():
    """Zero one random square region, shared by every sample in the batch.

    The square's side is ``ratio`` times the image side. Its centre is drawn at
    random and the square is clipped at the image border. The mask is built on
    ``device``.
    """

    def __init__(self, ratio, device='cpu'):
        self.ratio = ratio
        self.device = device

    def __call__(self, x):
        n, _, h, w = x.shape
        dev = self.device
        side_h = int(h * self.ratio + 0.5)
        side_w = int(w * self.ratio + 0.5)
        centre_h = torch.randint(h + (1 - side_h % 2), size=[1], device=dev)[0]
        centre_w = torch.randint(w + (1 - side_w % 2), size=[1], device=dev)[0]

        rows_b, rows_h, rows_w = torch.meshgrid(
            torch.arange(n, dtype=torch.long, device=dev),
            torch.arange(side_h, dtype=torch.long, device=dev),
            torch.arange(side_w, dtype=torch.long, device=dev),
        )
        rows_h = torch.clamp(rows_h + centre_h - side_h // 2, min=0, max=h - 1)
        rows_w = torch.clamp(rows_w + centre_w - side_w // 2, min=0, max=w - 1)

        keep = torch.ones(n, h, w, dtype=x.dtype, device=dev)
        keep[rows_b, rows_h, rows_w] = 0
        return x * keep.unsqueeze(1)


class Normalize():
    """Per-channel ``(x - mean) / std`` for batches of shape ``(N, C, H, W)``.

    ``seed`` is accepted for call-signature compatibility with DiffAug and is
    ignored.
    """

    def __init__(self, mean, std, device='cpu'):
        shape = (1, len(mean), 1, 1)
        self.mean = torch.tensor(mean, device=device).reshape(shape)
        self.std = torch.tensor(std, device=device).reshape(shape)

    def __call__(self, x, seed=-1):
        centred = x - self.mean
        return centred / self.std

# In [None]:
#convnet.py
import torch
import torch.nn as nn


class ConvNet(nn.Module):
    """Stack of conv -> norm -> activation -> pool stages plus one linear classifier.

    Args:
        num_classes: number of output logits.
        net_norm: 'batch', 'layer', 'instance' (GroupNorm with one group per
            channel), 'group' (GroupNorm with 4 groups) or 'none'.
        net_depth: number of conv stages.
        net_width: channels produced by every conv.
        channel: input channels. With 1 channel the first conv uses padding=3,
            which grows a 28x28 input to 32x32 (so ``im_size`` 28 is sized as 32).
        net_act: 'sigmoid', 'relu' or 'leakyrelu'.
        net_pooling: 'maxpooling', 'avgpooling' (both 2x2, stride 2) or 'none'.
        im_size: (H, W) of the input, used to size the classifier.

    Raises:
        ValueError: for an unknown ``net_act``, ``net_pooling`` or ``net_norm``.
    """

    def __init__(self,
                 num_classes,
                 net_norm='instance',
                 net_depth=3,
                 net_width=128,
                 channel=3,
                 net_act='relu',
                 net_pooling='avgpooling',
                 im_size=(32, 32)):
        super(ConvNet, self).__init__()
        # Configuration errors raise ValueError rather than calling exit():
        # exit() raises SystemExit and, under IPython, may end the kernel.
        if net_act == 'sigmoid':
            self.net_act = nn.Sigmoid()
        elif net_act == 'relu':
            self.net_act = nn.ReLU()
        elif net_act == 'leakyrelu':
            self.net_act = nn.LeakyReLU(negative_slope=0.01)
        else:
            raise ValueError('unknown activation function: %s' % net_act)

        if net_pooling == 'maxpooling':
            self.net_pooling = nn.MaxPool2d(kernel_size=2, stride=2)
        elif net_pooling == 'avgpooling':
            self.net_pooling = nn.AvgPool2d(kernel_size=2, stride=2)
        elif net_pooling == 'none':
            self.net_pooling = None
        else:
            raise ValueError('unknown net_pooling: %s' % net_pooling)

        self.depth = net_depth
        self.net_norm = net_norm

        self.layers, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm,
                                                    net_pooling, im_size)
        num_feat = shape_feat[0] * shape_feat[1] * shape_feat[2]
        self.classifier = nn.Linear(num_feat, num_classes)

    def _stage(self, x, d):
        """Apply conv stage ``d``: conv, optional norm, activation, optional pool."""
        x = self.layers['conv'][d](x)
        # Norm/pool lists are empty when disabled ('none'), so test the lists
        # themselves rather than the config strings.
        if len(self.layers['norm']) > 0:
            x = self.layers['norm'][d](x)
        x = self.layers['act'][d](x)
        if len(self.layers['pool']) > 0:
            x = self.layers['pool'][d](x)
        return x

    def forward(self, x, return_features=False):
        """Return logits, or ``(logits, flattened_features)`` if ``return_features``."""
        for d in range(self.depth):
            x = self._stage(x, d)

        out = x.view(x.shape[0], -1)
        logit = self.classifier(out)

        if return_features:
            return logit, out
        else:
            return logit

    def get_feature(self, x, idx_from, idx_to=-1, return_prob=False, return_logit=False):
        """Return the outputs of stages ``idx_from .. idx_to`` (inclusive).

        If ``idx_to`` (default: ``idx_from``) is a valid stage index, the slice is
        returned as soon as that stage has run. Otherwise all stages run and,
        with ``return_prob`` / ``return_logit``, ``(all_features, softmax probs)``
        or ``(all_features, logits)`` is returned.
        """
        if idx_to == -1:
            idx_to = idx_from
        features = []

        for d in range(self.depth):
            x = self._stage(x, d)
            features.append(x)
            if idx_to < len(features):
                return features[idx_from:idx_to + 1]

        if return_prob:
            out = x.view(x.size(0), -1)
            logit = self.classifier(out)
            prob = torch.softmax(logit, dim=-1)
            return features, prob
        elif return_logit:
            out = x.view(x.size(0), -1)
            logit = self.classifier(out)
            return features, logit
        else:
            return features[idx_from:idx_to + 1]

    def _get_normlayer(self, net_norm, shape_feat):
        """Build the normalisation layer for a feature map of shape ``(C, H, W)``."""
        if net_norm == 'batch':
            norm = nn.BatchNorm2d(shape_feat[0], affine=True)
        elif net_norm == 'layer':
            norm = nn.LayerNorm(shape_feat, elementwise_affine=True)
        elif net_norm == 'instance':
            norm = nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True)
        elif net_norm == 'group':
            norm = nn.GroupNorm(4, shape_feat[0], affine=True)
        elif net_norm == 'none':
            norm = None
        else:
            raise ValueError('unknown net_norm: %s' % net_norm)
        return norm

    def _make_layers(self, channel, net_width, net_depth, net_norm, net_pooling, im_size):
        """Create the per-stage module lists and the final feature shape ``[C, H, W]``."""
        layers = {'conv': [], 'norm': [], 'act': [], 'pool': []}

        in_channels = channel
        if im_size[0] == 28:
            # The padding=3 first conv used for 1-channel input turns 28x28 into 32x32.
            im_size = (32, 32)
        shape_feat = [in_channels, im_size[0], im_size[1]]

        for d in range(net_depth):
            layers['conv'] += [
                nn.Conv2d(in_channels,
                          net_width,
                          kernel_size=3,
                          padding=3 if channel == 1 and d == 0 else 1)
            ]
            shape_feat[0] = net_width
            if net_norm != 'none':
                layers['norm'] += [self._get_normlayer(net_norm, shape_feat)]
            layers['act'] += [self.net_act]
            in_channels = net_width
            if net_pooling != 'none':
                layers['pool'] += [self.net_pooling]
                shape_feat[1] //= 2
                shape_feat[2] //= 2

        layers['conv'] = nn.ModuleList(layers['conv'])
        layers['norm'] = nn.ModuleList(layers['norm'])
        layers['act'] = nn.ModuleList(layers['act'])
        layers['pool'] = nn.ModuleList(layers['pool'])
        layers = nn.ModuleDict(layers)

        return layers, shape_feat

# In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Discriminator(nn.Module):
    """ACGAN-style discriminator for 3x32x32 images.

    Seven conv(3x3) -> LayerNorm -> LeakyReLU stages with 196 channels. Stride 2
    at stages 2, 4 and 7 reduces 32 -> 16 -> 8 -> 4. A 4x4 max-pool then reduces
    to 1x1, followed by two linear heads: ``fc1`` (1 logit, real/fake source)
    and ``fc`` (``num_classes`` class logits).

    Args:
        num_classes: size of the class head (default 2, as before).
    """

    # The LayerNorm shapes (ending at [196, 4, 4] before the 4x4 pool) fix the
    # input resolution.
    img_size = 32

    def __init__(self, num_classes=2):
        super(Discriminator, self).__init__()
        size = self.img_size
        # (in_channels, stride, output spatial size) for each conv stage.
        stages = [
            (3, 1, size),
            (196, 2, size // 2),
            (196, 1, size // 2),
            (196, 2, size // 4),
            (196, 1, size // 4),
            (196, 1, size // 4),
            (196, 2, size // 8),
        ]
        layers = []
        for in_ch, stride, out_size in stages:
            layers += [
                nn.Conv2d(in_ch, 196, kernel_size=3, stride=stride, padding=1),
                nn.LayerNorm(normalized_shape=[196, out_size, out_size]),
                nn.LeakyReLU(),
            ]
        # Registered here, once, so the weights are trained instead of being
        # re-created on every forward pass.
        self.features = nn.Sequential(*layers)
        self.pool = nn.MaxPool2d(kernel_size=4, stride=4, padding=0)
        self.fc1 = nn.Linear(196, 1)
        self.fc = nn.Linear(196, num_classes)

    def forward(self, x, print_size=False):
        """Return ``(source_logit, class_logits)`` for a batch of 3x32x32 images."""
        x = self.features(x)
        if print_size:
            print(x.size())
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        if print_size:
            print(x.size())
        fc1 = self.fc1(x)  # source (real/fake)
        fc = self.fc(x)  # class
        return fc1, fc


class Generator(nn.Module):
    """ACGAN-style generator mapping 100-d noise to 3x32x32 images in [-1, 1].

    Layers: Linear(100 -> 196*4*4) + BatchNorm1d + ReLU, reshaped to 196x4x4;
    then a transposed conv to 8x8; three 3x3 convs; a transposed conv to 16x16
    (each with BatchNorm2d + ReLU); a final transposed conv to 3x32x32 (no
    batch-norm); and tanh.

    NOTE(review): the 3-channel 16 -> 32 output layer is inferred from the
    original placeholder ``deconv2d(..., [batch, 196, 3, ...])`` and from
    Discriminator's 32x32 input; confirm that this is the intended architecture.
    """

    def __init__(self):
        super(Generator, self).__init__()
        self.batch_size = 64
        self.fc1 = nn.Linear(100, 196 * 4 * 4)
        self.bn0 = nn.BatchNorm1d(196 * 4 * 4)
        self.relu = nn.ReLU()
        self.features = nn.Sequential(
            nn.ConvTranspose2d(196, 196, kernel_size=4, stride=2, padding=1),  # 4 -> 8
            nn.BatchNorm2d(196),
            nn.ReLU(),
            nn.Conv2d(196, 196, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(196),
            nn.ReLU(),
            nn.Conv2d(196, 196, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(196),
            nn.ReLU(),
            nn.Conv2d(196, 196, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(196),
            nn.ReLU(),
            nn.ConvTranspose2d(196, 196, kernel_size=4, stride=2, padding=1),  # 8 -> 16
            nn.BatchNorm2d(196),
            nn.ReLU(),
            # 16 -> 32, 3 output channels; bn is not applied to the output layer.
            nn.ConvTranspose2d(196, 3, kernel_size=4, stride=2, padding=1),
        )
        self.tanh = nn.Tanh()

    def forward(self, x, print_size=False):
        """Map noise ``x`` of shape (B, 100) (or (B, 100, 1, 1)) to (B, 3, 32, 32)."""
        x = x.view(x.size(0), -1)
        x = self.relu(self.bn0(self.fc1(x)))
        x = x.view(-1, 196, 4, 4)
        x = self.features(x)
        if print_size:
            print(x.size())
        x = self.tanh(x)
        return x


# In [None]:

# Experiment hyper-parameters and paths (module-level globals read by the
# training code, e.g. batch_size / num_workers in load_data and data in
# define_model).
# Argparse-style hyphenated names such as `num-classes` are not valid Python
# identifiers (they parse as an assignment to a subtraction), so underscores
# are used here.
ipc = 50
batch_size = 64
epochs = 100
epochs_eval = 1000
epochs_match = 100
epochs_match_train = 16
lr = 5e-6
eval_lr = 0.01
momentum = 0.9
weight_decay = 5e-4
match_coeff = 0.001
match_model = 'convnet'
matchs = 'feat'
eval_model = ['convnet']
dim_noise = 100
num_workers = 4
print_freq = 50
eval_interval = 10
test_interval = 200
fix_disc = False

data = 'cifar10'
num_classes = 2
data_dir = './data'
output_dir = './results/'
logs_dir = './logs/'
weight = './weight/'
match_aug = False
aug_type = 'color_crop_cutout'
mixup_net = 'cut'
metric = 'l1'
bias = False
fc = False
mix_p = -1.0
beta = 1.0
tag = 'test'
seed = 3407

# In [None]:
import os
import sys
import time
import random
import argparse
import numpy as np
import colab
from colab import drive

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid

import models.resnet as RN
import models.convnet as CN
import models.resnet_ap as RNAP
import models.densenet_cifar as DN
from gan_model import Generator, Discriminator
from utils import AverageMeter, accuracy, Normalize, Logger, rand_bbox
from augment import DiffAug
from sklearn.metrics import roc_auc_score


def str2bool(v):
    """Parse a command-line style boolean.

    Booleans pass through unchanged. The strings yes/true/t/y/1 and
    no/false/f/n/0 (case-insensitive) map to True/False respectively.

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

def transformation(pth):
    """Return an ``ImageFolder`` dataset rooted at *pth*.

    Each image is converted to a tensor and normalized per channel with the
    mean/std below (the widely used ImageNet statistics).
    NOTE(review): diffaug() later applies another Normalize with dataset
    statistics selected by `data`, so images may be normalized twice.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    return datasets.ImageFolder(pth, transform=pipeline)

def load_data(train_dir='/content/drive/MyDrive/Projects/astar/isic',
              test_dir='/content/drive/MyDrive/Projects/astar/isic2'):
    """Build the train/test DataLoaders from two ImageFolder directories.

    Args:
        train_dir: root of the training ImageFolder (one sub-folder per class).
            An alternative 32x32 copy lived at
            '/content/drive/MyDrive/Projects/astar/32x32'.
        test_dir: root of the test ImageFolder.

    Returns:
        (trainloader, testloader). Training batches are shuffled and the last
        incomplete batch is dropped. train() relies on that because it builds
        tensors of exactly `batch_size` rows. The test loader keeps order and
        every sample.

    The directories are passed straight to ImageFolder. drive.mount() mounts
    Google Drive at a mount point; its return value is not a dataset path.
    NOTE(review): the first cell mounts Drive at '/content/gdrive' while these
    defaults assume '/content/drive' -- confirm the mount point.
    The ISBI2016 ground-truth CSV
    ('/content/drive/MyDrive/Projects/astar/ISBI2016_GroundTruth.csv') is not
    read here.
    """
    trainset = transformation(train_dir)
    testset = transformation(test_dir)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, drop_last=True
    )
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers
    )
    return trainloader, testloader


def define_model(num_classes, e_model=None):
    '''Build a classifier for training, validating and matching.

    Args:
        num_classes: number of output classes.
        e_model: architecture name. When falsy, an architecture is chosen at
            random from the supported ones and printed.

    Returns:
        A freshly initialised model. It has 1 input channel when `data` is
        'mnist' or 'fashion', and 3 otherwise.

    Raises:
        ValueError: if the requested architecture is not supported.
    '''
    # Only architectures that have a constructor below may be listed here.
    # Listing others (the ResNet / ResNet-AP / DenseNet variants from RN, RNAP,
    # DN are disabled in this notebook) would make this function return None.
    supported = ('convnet',)

    if e_model:
        model = e_model
    else:
        model = random.choice(supported)
        print('Random model: {}'.format(model))

    if model not in supported:
        raise ValueError('Unsupported model {!r}; expected one of {}'.format(model, supported))

    nch = 1 if data in ('mnist', 'fashion') else 3
    return CN.ConvNet(num_classes, channel=nch)


def calc_gradient_penalty(discriminator, img_real, img_syn):
    ''' Gradient penalty from Wasserstein GAN (WGAN-GP).

    Samples random per-example interpolations between real and synthetic
    images. It then penalises the squared deviation of the discriminator's
    input-gradient L2 norm from 1, scaled by LAMBDA = 10.

    Args:
        discriminator: callable returning a tuple whose first element is the
            source score for each sample.
        img_real: real images, batch on dim 0.
        img_syn: synthetic images with the same number of elements as
            `img_real`; reshaped to its shape.

    Returns:
        Scalar tensor with the penalty (graph kept, so it can be backpropagated).
    '''
    LAMBDA = 10
    batch_size = img_real.shape[0]
    device = img_real.device

    # One mixing coefficient per sample, broadcast over all remaining dims.
    # Drawn on CPU then moved, so the random stream does not depend on device.
    alpha = torch.rand(batch_size, *([1] * (img_real.dim() - 1))).to(device)

    img_syn = img_syn.reshape(img_real.shape)
    interpolates = alpha * img_real.detach() + (1 - alpha) * img_syn.detach()
    interpolates.requires_grad_(True)

    disc_interpolates, _ = discriminator(interpolates)

    gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                    grad_outputs=torch.ones_like(disc_interpolates),
                                    create_graph=True, retain_graph=True, only_inputs=True)[0]

    gradients = gradients.reshape(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA


def dist(x, y, method='mse'):
    """Distance between two tensors, reduced to a scalar tensor.

    Args:
        x, y: tensors of the same shape. For 'l1_mean' and 'cos', dim 0 is
            treated as the batch dimension.
        method:
            'mse'     - sum of squared differences (a sum, not a mean).
            'l1'      - sum of absolute differences.
            'l1_mean' - per-sample mean absolute difference, summed over dim 0.
            'cos'     - sum over dim 0 of (1 - cosine similarity) of the
                        flattened samples (1e-6 added to the denominator).

    Raises:
        ValueError: for an unknown `method`. Previously this fell through to
            an UnboundLocalError.
    """
    if method == 'mse':
        return (x - y).pow(2).sum()
    if method == 'l1':
        return (x - y).abs().sum()
    if method == 'l1_mean':
        n_b = x.shape[0]
        return (x - y).abs().reshape(n_b, -1).mean(-1).sum()
    if method == 'cos':
        x = x.reshape(x.shape[0], -1)
        y = y.reshape(y.shape[0], -1)
        return torch.sum(1 - torch.sum(x * y, dim=-1) /
                         (torch.norm(x, dim=-1) * torch.norm(y, dim=-1) + 1e-6))
    raise ValueError("Unknown distance method {!r}; expected 'mse', 'l1', 'l1_mean' or 'cos'".format(method))


def add_loss(loss_sum, loss):
    """Accumulate `loss` into `loss_sum`, where None means "nothing yet".

    Uses an identity check (`is None`). Comparing a tensor with `== None`
    relies on Tensor.__eq__ semantics rather than checking for the sentinel.
    """
    if loss_sum is None:
        return loss
    return loss_sum + loss


def matchloss(img_real, img_syn, lab_real, lab_syn, model):
    """Matching loss between a real and a synthetic batch under `model`.

    The mode comes from the module-level `matchs` string. The first of these
    substrings that it contains wins:
      'feat'  - for each feature returned by model.get_feature(img, idx_from,
                idx_to), the distance between batch means (dim 0). Real
                features are computed without gradients. Each term is scaled
                by 0.001.
      'grad'  - distance between the parameter gradients of the cross-entropy
                loss on the real and synthetic batches. Real gradients are
                detached. 1-D parameters are skipped unless `bias` is set, and
                2-D parameters unless `fc` is set. Each term is scaled by 0.001.
      'logit' - mean squared difference of log-softmax outputs, times 0.01.
    Per-term distances use dist(..., method=metric).
    NOTE(review): idx_from / idx_to are not defined in the visible config
    cell; 'feat' mode raises NameError unless they are defined elsewhere.

    Returns:
        Scalar tensor, or None if no term was accumulated (e.g. 'grad' mode
        with every parameter skipped).

    Raises:
        ValueError: if `matchs` names none of the supported modes. Previously
            None was returned silently and the caller failed on
            `gen_loss + None`.
    """
    loss = None

    if 'feat' in matchs:
        with torch.no_grad():
            feat_tg = model.get_feature(img_real, idx_from, idx_to)
        feat = model.get_feature(img_syn, idx_from, idx_to)

        for f_real, f_syn in zip(feat_tg, feat):
            loss = add_loss(loss, dist(f_real.mean(0), f_syn.mean(0), method=metric) * 0.001)

    elif 'grad' in matchs:
        criterion = nn.CrossEntropyLoss()

        loss_real = criterion(model(img_real), lab_real)
        g_real = [g.detach() for g in torch.autograd.grad(loss_real, model.parameters())]

        loss_syn = criterion(model(img_syn), lab_syn)
        # create_graph so the loss can backpropagate into img_syn (the generator).
        g_syn = torch.autograd.grad(loss_syn, model.parameters(), create_graph=True)

        for gr, gs in zip(g_real, g_syn):
            if gr.dim() == 1 and not bias:  # bias, normalization
                continue
            if gr.dim() == 2 and not fc:
                continue
            loss = add_loss(loss, dist(gr, gs, method=metric) * 0.001)

    elif 'logit' in matchs:
        output_real = F.log_softmax(model(img_real), dim=1)
        output_syn = F.log_softmax(model(img_syn), dim=1)
        loss = add_loss(loss, ((output_real - output_syn) ** 2).mean() * 0.01)

    else:
        raise ValueError("Unknown matching mode {!r}; expected 'feat', 'grad' or 'logit'".format(matchs))

    return loss


def remove_aug(augtype, remove_aug):
    """Drop from *augtype* every '_'-separated token listed in *remove_aug*.

    Both arguments are underscore-joined augmentation names. The order of the
    remaining tokens is preserved, e.g.
    remove_aug('color_crop_cutout', 'cutout') -> 'color_crop'.
    """
    banned = set(remove_aug.split("_"))
    kept = [token for token in augtype.split("_") if token not in banned]
    return "_".join(kept)


def diffaug(device='cuda'):
    """Build the differentiable augmentation pipelines for condensation.

    Args:
        device: device passed to Normalize (previously ignored; 'cuda' was
            hard-coded).

    Returns:
        (aug_batch, aug_rand):
          aug_batch - Normalize then DiffAug(strategy=aug_type, batch=True);
                      used for matching.
          aug_rand  - Normalize then DiffAug(batch=False). When
                      mixup_net == 'cut', 'cutout' is removed from the
                      strategy. Used for network updates.

    Raises:
        ValueError: if the module-level `data` has no normalization statistics
            here (previously `normalize` was left unbound).
    """
    # Assigning to the name `aug_type` inside this function would make it
    # local and raise UnboundLocalError, so the module setting is copied into
    # a differently named local.
    strategy = aug_type

    stats = {
        'cifar10': ((0.491, 0.482, 0.447), (0.202, 0.199, 0.201)),
        'svhn': ((0.437, 0.444, 0.473), (0.198, 0.201, 0.197)),
        'fashion': ((0.286,), (0.353,)),
        'mnist': ((0.131,), (0.308,)),
    }
    if data not in stats:
        raise ValueError('No normalization statistics for data={!r}; expected one of {}'.format(
            data, sorted(stats)))
    mean, std = stats[data]
    normalize = Normalize(mean, std, device=device)

    print("Augmentataion Matching: ", strategy)
    augment = DiffAug(strategy=strategy, batch=True)
    aug_batch = transforms.Compose([normalize, augment])

    if mixup_net == 'cut':
        strategy = remove_aug(strategy, 'cutout')
    print("Augmentataion Net update: ", strategy)
    augment_rand = DiffAug(strategy=strategy, batch=False)
    aug_rand = transforms.Compose([normalize, augment_rand])

    return aug_batch, aug_rand


def _conditional_noise(labels, n_classes, noise_dim):
    '''Sample generator input: N(0, 1) noise with the first `n_classes` columns
    overwritten by the one-hot encoding of `labels`.

    The one-hot scatter is done with CPU labels. Recent PyTorch versions
    reject indexing a CPU tensor with CUDA indices. Returns a CPU tensor of
    shape (len(labels), noise_dim); the caller moves it to the GPU.
    '''
    labels_cpu = labels.cpu()
    n = labels_cpu.shape[0]
    noise = torch.normal(0, 1, (n, noise_dim))
    onehot = torch.zeros((n, n_classes))
    onehot[torch.arange(n), labels_cpu] = 1
    noise[:, :n_classes] = onehot
    return noise


def train(epoch, generator, discriminator, optim_g, optim_d, trainloader, criterion, aug, aug_rand):
    '''Run one training epoch of the class-conditional GAN with matching loss.

    For every real batch in `trainloader`:
      1. Generator step. The loss is the adversarial term (minus the mean
         discriminator source score) plus cross-entropy on the
         discriminator's class head plus matchloss() between real and
         generated images (after `aug` when `match_aug` is set). Before the
         matching loss, the matching classifier gets a few updates via
         train_match_model().
      2. Discriminator step. The loss combines the fake/real source scores,
         cross-entropy on both batches and calc_gradient_penalty().
    Progress is printed every `print_freq` batches.

    Reads module-level config: num_classes, batch_size, dim_noise, eval_lr,
    momentum, weight_decay, match_aug, print_freq.
    '''
    generator.train()
    gen_losses = AverageMeter()
    disc_losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # Top-5 accuracy is undefined with fewer than 5 classes (num_classes=2 here).
    topk = (1, min(5, num_classes))

    model = define_model(num_classes).cuda()
    model.train()
    optim_model = torch.optim.SGD(model.parameters(), eval_lr, momentum=momentum,
                                  weight_decay=weight_decay)

    for batch_idx, (img_real, lab_real) in enumerate(trainloader):
        img_real = img_real.cuda()
        lab_real = lab_real.cuda()

        # ---- generator step ----
        discriminator.eval()
        optim_g.zero_grad()

        noise = _conditional_noise(lab_real, num_classes, dim_noise).cuda()
        img_syn = generator(noise)
        gen_source, gen_class = discriminator(img_syn)
        gen_loss = -gen_source.mean() + criterion(gen_class, lab_real)

        # Update the match model to obtain more varied matching signals.
        train_match_model(model, optim_model, trainloader, criterion, aug_rand)
        # NOTE(review): match_coeff is not applied; matchloss() scales its own terms.
        if match_aug:
            img_aug = aug(torch.cat([img_real, img_syn]))
            match_loss = matchloss(img_aug[:batch_size], img_aug[batch_size:], lab_real, lab_real, model)
        else:
            match_loss = matchloss(img_real, img_syn, lab_real, lab_real, model)
        gen_loss = gen_loss + match_loss

        gen_loss.backward()
        optim_g.step()

        # ---- discriminator step ----
        discriminator.train()
        optim_d.zero_grad()
        lab_syn = torch.randint(num_classes, (batch_size,))
        noise = _conditional_noise(lab_syn, num_classes, dim_noise).cuda()
        lab_syn = lab_syn.cuda()

        with torch.no_grad():
            img_syn = generator(noise)

        disc_fake_source, disc_fake_class = discriminator(img_syn)
        disc_fake_source = disc_fake_source.mean()
        disc_fake_class = criterion(disc_fake_class, lab_syn)

        disc_real_source, disc_real_class = discriminator(img_real)
        acc1, acc5 = accuracy(disc_real_class.detach(), lab_real, topk=topk)
        # NOTE(review): Eval_AUC is not defined in this part of the file;
        # confirm it exists elsewhere in the notebook.
        auc = Eval_AUC(disc_real_class.detach(), lab_real)
        disc_real_source = disc_real_source.mean()
        disc_real_class = criterion(disc_real_class, lab_real)

        gradient_penalty = calc_gradient_penalty(discriminator, img_real, img_syn)

        disc_loss = disc_fake_source - disc_real_source + disc_fake_class + disc_real_class + gradient_penalty
        disc_loss.backward()
        optim_d.step()

        gen_losses.update(gen_loss.item())
        disc_losses.update(disc_loss.item())
        top1.update(acc1.item())
        top5.update(acc5.item())

        if (batch_idx + 1) % print_freq == 0:
            print('[Train Epoch {} Iter {}] G Loss: {:.3f}({:.3f}) D Loss: {:.3f}({:.3f}) D Acc: {:.3f}({:.3f}) AUC: {:.3f}'.format(
                epoch, batch_idx + 1, gen_losses.val, gen_losses.avg, disc_losses.val, disc_losses.avg, top1.val, top1.avg, auc)
            )


def train_match_model(model, optim_model, trainloader, criterion, aug_rand, max_batches=None):
    '''Update the matching classifier on the first few real batches.

    Args:
        model: classifier to update. Batches are moved to the device of its
            parameters.
        optim_model: optimizer over `model`'s parameters.
        trainloader: iterable of (images, labels) batches.
        criterion: loss applied to model(aug_rand(images)) and labels.
        aug_rand: augmentation applied to each image batch.
        max_batches: number of batches to train on. Defaults to the
            module-level `epochs_match_train`; a non-positive value means no
            update.

    Stops immediately after the last needed batch, so no extra batch is
    loaded from the DataLoader.
    '''
    if max_batches is None:
        max_batches = epochs_match_train
    if max_batches <= 0:
        return

    device = next(model.parameters()).device
    for batch_idx, (img, lab) in enumerate(trainloader, start=1):
        img = img.to(device)
        lab = lab.to(device)

        loss = criterion(model(aug_rand(img)), lab)

        optim_model.zero_grad()
        loss.backward()
        optim_model.step()

        if batch_idx == max_batches:
            break


def test(model, testloader, criterion):
    '''Evaluate `model` on `testloader`.

    Runs in eval mode (normalization layers use running statistics and are not
    updated) with gradients disabled. Afterwards the model's previous
    train/eval mode is restored, because validate() keeps training the same
    model after calling this.

    Returns:
        (top1, top5, loss), each averaged over all samples. "top5" is
        top-min(5, C) accuracy for C output classes.
    '''
    was_training = model.training
    model.eval()
    device = next(model.parameters()).device

    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    try:
        with torch.no_grad():
            for img, lab in testloader:
                img = img.to(device)
                lab = lab.to(device)

                output = model(img)
                loss = criterion(output, lab)
                # Top-5 is undefined with fewer than 5 classes.
                acc1, acc5 = accuracy(output, lab, topk=(1, min(5, output.shape[1])))

                n = output.shape[0]
                losses.update(loss.item(), n)
                top1.update(acc1.item(), n)
                top5.update(acc5.item(), n)
    finally:
        model.train(was_training)

    return top1.avg, top5.avg, losses.avg


def validate(generator, testloader, criterion, aug_rand):
    '''Validate the generator by training fresh classifiers on its samples only.

    For each architecture name in the module-level `eval_model`, a new
    classifier is trained for `epochs_eval` epochs on class-conditional
    generator samples. Samples are rescaled from [-1, 1] to [0, 1] and then
    augmented with `aug_rand`. CutMix is applied when mixup_net == 'cut' and
    a uniform draw falls below `mix_p`. Every `test_interval` epochs the
    classifier is evaluated on `testloader` with test().

    Returns:
        (all_best_top1, all_best_top5): for each `eval_model` entry, the best
        test top-1 accuracy and the top-5 accuracy from that same evaluation.
    '''
    all_best_top1 = []
    all_best_top5 = []
    # Top-5 is undefined with fewer than 5 classes (num_classes=2 here).
    topk = (1, min(5, num_classes))
    # NOTE(review): the factor 10 looks like CIFAR-10's class count; with
    # num_classes=2 this draws more samples per epoch than ipc * num_classes.
    n_batches = 10 * ipc // batch_size + 1

    for e_model in eval_model:
        print('Evaluating {}'.format(e_model))
        model = define_model(num_classes, e_model).cuda()
        model.train()
        optim_model = torch.optim.SGD(model.parameters(), eval_lr, momentum=momentum,
                                      weight_decay=weight_decay)

        generator.eval()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()
        best_top1 = 0.0
        best_top5 = 0.0
        for epoch_idx in range(epochs_eval):
            for _ in range(n_batches):
                # Class-conditional noise: first num_classes entries are the one-hot label.
                lab_syn = torch.randint(num_classes, (batch_size,))
                noise = torch.normal(0, 1, (batch_size, dim_noise))
                lab_onehot = torch.zeros((batch_size, num_classes))
                lab_onehot[torch.arange(batch_size), lab_syn] = 1
                noise[:, :num_classes] = lab_onehot
                noise = noise.cuda()
                lab_syn = lab_syn.cuda()

                with torch.no_grad():
                    img_syn = generator(noise)
                    # Generator output appears to be tanh-bounded; map [-1, 1] -> [0, 1].
                    img_syn = aug_rand((img_syn + 1.0) / 2.0)

                if np.random.rand(1) < mix_p and mixup_net == 'cut':
                    lam = np.random.beta(beta, beta)
                    rand_index = torch.randperm(len(img_syn)).cuda()

                    lab_syn_b = lab_syn[rand_index]
                    bbx1, bby1, bbx2, bby2 = rand_bbox(img_syn.size(), lam)
                    img_syn[:, :, bbx1:bbx2, bby1:bby2] = img_syn[rand_index, :, bbx1:bbx2, bby1:bby2]
                    # Label weight = fraction of the image NOT replaced by the pasted box.
                    ratio = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (img_syn.size()[-1] * img_syn.size()[-2]))

                    output = model(img_syn)
                    loss = criterion(output, lab_syn) * ratio + criterion(output, lab_syn_b) * (1. - ratio)
                else:
                    output = model(img_syn)
                    loss = criterion(output, lab_syn)

                acc1, acc5 = accuracy(output.detach(), lab_syn, topk=topk)

                losses.update(loss.item(), img_syn.shape[0])
                top1.update(acc1.item(), img_syn.shape[0])
                top5.update(acc5.item(), img_syn.shape[0])

                optim_model.zero_grad()
                loss.backward()
                optim_model.step()

            if (epoch_idx + 1) % test_interval == 0:
                test_top1, test_top5, test_loss = test(model, testloader, criterion)
                print('[Test Epoch {}] Top1: {:.3f} Top5: {:.3f}'.format(epoch_idx + 1, test_top1, test_top5))
                if test_top1 > best_top1:
                    best_top1 = test_top1
                    best_top5 = test_top5

        all_best_top1.append(best_top1)
        all_best_top5.append(best_top5)

    return all_best_top1, all_best_top5


#noise_multiplier=1.07

def dp_conv_hook(module, grad_input, grad_output):
    '''
    Backward hook that clips and noises the gradient w.r.t. a layer's input.

    The first entry of grad_input is flattened per sample (dim 0 = batch).
    Each row is rescaled so its L2 norm is at most CLIP_BOUND / batch_size.
    Gaussian noise scaled by (CLIP_BOUND / batch_size) * noise_multiplier *
    SENSITIVITY is then added. All other grad_input entries pass through
    unchanged.

    :param module: the hooked module (unused)
    :param grad_input: tuple of gradients w.r.t. the module inputs; entry 0 must be a tensor
    :param grad_output: gradients w.r.t. the module outputs (unused)
    :return: tuple like grad_input with entry 0 replaced by the clipped, noised gradient
    '''
    CLIP_BOUND = 1.
    SENSITIVITY = 2.
    noise_multiplier = 1.07

    first = grad_input[0]
    original_shape = first.size()
    n = original_shape[0]
    bound = CLIP_BOUND / n

    rows = first.view(n, -1)
    row_norms = torch.norm(rows, p=2, dim=1)

    # Shrink only rows whose norm exceeds the bound (scale factor capped at 1).
    scale = (bound / (row_norms + 1e-10)).clamp(max=1.0).unsqueeze(-1)
    rows = scale * rows

    rows = rows + bound * noise_multiplier * SENSITIVITY * torch.randn_like(rows)
    return (rows.view(original_shape),) + tuple(grad_input[1:])

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
# NOTE(review): colab.connect_sys / colab.setsys / colab.call_sys are not part
# of google.colab; kept as-is -- confirm the local `colab` module provides them.
gpu = colab.connect_sys("v100")

# Create <output_dir><tag>/outputs and <logs_dir><tag>. The paths are used
# directly: google.colab's drive.mount() mounts Drive and does not return a path.
output_dir = output_dir + tag
logs_dir = logs_dir + tag
os.makedirs(os.path.join(output_dir, 'outputs'), exist_ok=True)
os.makedirs(logs_dir, exist_ok=True)
sys.stdout = Logger(os.path.join(logs_dir, 'logs.txt'))


trainloader, testloader = load_data()

generator = Generator().cuda()
discriminator = Discriminator().cuda()

optim_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0, 0.9))
optim_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0, 0.9))

colab.setsys("pro")
# Resume from a checkpoint holding generator/discriminator/optimizer states.
# NOTE(review): `weight` must name a checkpoint file; './weight/' is a directory.
model_dict = torch.load(weight)
generator.load_state_dict(model_dict['generator'])
discriminator.load_state_dict(model_dict['discriminator'])
optim_g.load_state_dict(model_dict['optim_g'])
optim_d.load_state_dict(model_dict['optim_d'])
# The loaded optimizer states carry their own lr; override with the configured one.
for g in optim_g.param_groups:
    g['lr'] = lr
for g in optim_d.param_groups:
    g['lr'] = lr
criterion = nn.CrossEntropyLoss()

aug, aug_rand = diffaug()

best_top1s = np.zeros((len(eval_model),))
best_top5s = np.zeros((len(eval_model),))
best_epochs = np.zeros((len(eval_model),), dtype=int)

n_vis = 100  # samples in the visualization grid (10 x 10)
for epoch in range(epochs):
    colab.call_sys("active")
    generator.train()
    discriminator.train()
    train(epoch, generator, discriminator, optim_g, optim_d, trainloader, criterion, aug, aug_rand)
    # NOTE(review): this only binds the hook function to a name; dp_conv_hook
    # is never registered on a module, so no gradient clipping/noise is applied.
    dynamic_hook_function = dp_conv_hook

    # Save a grid of class-conditional samples for visualization. Labels
    # cycle 0..num_classes-1 so there is exactly one label per noise row.
    generator.eval()
    test_label = torch.arange(n_vis) % num_classes
    test_noise = torch.normal(0, 1, (n_vis, dim_noise))
    lab_onehot = torch.zeros((n_vis, num_classes))
    lab_onehot[torch.arange(n_vis), test_label] = 1
    test_noise[:, :num_classes] = lab_onehot
    with torch.no_grad():
        test_img_syn = (generator(test_noise.cuda()) + 1.0) / 2.0
    save_image(test_img_syn, os.path.join(output_dir, 'outputs', 'img_{}.png'.format(epoch)), nrow=10)
    generator.train()

    if (epoch + 1) % eval_interval == 0:
        top1s, top5s = validate(generator, testloader, criterion, aug_rand)
        for e_idx, e_model in enumerate(eval_model):
            if top1s[e_idx] > best_top1s[e_idx]:
                best_top1s[e_idx] = top1s[e_idx]
                best_top5s[e_idx] = top5s[e_idx]
                best_epochs[e_idx] = epoch

                model_dict = {'generator': generator.state_dict(),
                              'discriminator': discriminator.state_dict(),
                              'optim_g': optim_g.state_dict(),
                              'optim_d': optim_d.state_dict()}
                torch.save(
                    model_dict,
                    os.path.join(output_dir, 'model_dict_{}.pth'.format(e_model)))
                print('Save model for {}'.format(e_model))

            print('Current Best Epoch for {}: {}, Top1: {:.3f}, Top5: {:.3f}'.format(e_model, best_epochs[e_idx], best_top1s[e_idx], best_top5s[e_idx]))