In [None]:
## Results are in the "saved" subfolder of this Drive folder: https://drive.google.com/drive/folders/1_Nt3xnZJpVlI0jRJQM4QJ4vdMKjSSmZn?usp=sharing
## Detailed analysis of the results: https://docs.google.com/document/d/1nRHeqZVKCQ7sC9tGEDbVhdMrRtnR3Gv627Lfv6HEB2s/edit?usp=sharing
## The training output for sketch synthesis appears after this cell: https://colab.research.google.com/drive/1u3LctUtFMHnXZlOUF-n2gN7yaUnZCht8#scrollTo=si5UIMg8hu_M&line=1&uniqifier=1
## Sketch generation starts after this cell: https://colab.research.google.com/drive/1u3LctUtFMHnXZlOUF-n2gN7yaUnZCht8#scrollTo=zRAcYGwxi5pn&line=1&uniqifier=1
## Training of the AutoEncoder for sketch-to-image synthesis starts after this cell: https://colab.research.google.com/drive/1u3LctUtFMHnXZlOUF-n2gN7yaUnZCht8#scrollTo=caaYS8KiygTo&line=1&uniqifier=1


In [None]:
import os
import time
import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as Dataset
import torchvision.utils as vutils
from torch import nn
from torch.utils.data import DataLoader
import argparse
import tqdm

In [None]:
from google.colab import drive
# Mount Google Drive; every dataset, weight and checkpoint path used below lives under /content/My_drive/MyDrive/.
drive.mount('/content/My_drive')

Mounted at /content/My_drive


In [None]:
import functools
from math import sqrt
import torch
import torch.nn as nn
import torch.optim as optim
from torch import cat, sigmoid
from torch.autograd import Variable
from torch.nn import Parameter, init
from torch.nn.utils import spectral_norm
import torch.nn.functional as F
from torch.jit import ScriptModule, script_method, trace

In [None]:
def calc_mean_std(feat, eps=1e-5):
    """Per-sample, per-channel mean and standard deviation over the spatial dims.

    Args:
        feat: 4-D tensor (N, C, H, W).
        eps: added to the (unbiased) variance before the square root.

    Returns:
        (mean, std), each shaped (N, C, 1, 1).
    """
    assert feat.dim() == 4
    n, c = feat.size()[:2]
    flat = feat.view(n, c, -1)
    std = (flat.var(dim=2) + eps).sqrt().view(n, c, 1, 1)
    mean = flat.mean(dim=2).view(n, c, 1, 1)
    return mean, std


def adain(content_feat, style_feat):
    """Adaptive Instance Normalisation.

    Re-normalises ``content_feat`` so that each channel takes on the mean and std
    of the matching channel of ``style_feat``. N and C of both inputs must agree.
    """
    assert content_feat.size()[:2] == style_feat.size()[:2]
    shape = content_feat.size()
    s_mean, s_std = calc_mean_std(style_feat)
    c_mean, c_std = calc_mean_std(content_feat)

    whitened = (content_feat - c_mean.expand(shape)) / c_std.expand(shape)
    return whitened * s_std.expand(shape) + s_mean.expand(shape)


def get_batched_gram_matrix(input):
    """Per-sample Gram matrix of a (N, C, H, W) feature map, normalised by C*H*W.

    Returns a (N, C, C) tensor.
    """
    n, ch, h, w = input.size()
    flat = input.view(n, ch, h * w)
    gram = torch.bmm(flat, flat.transpose(2, 1))
    return gram.div(ch * h * w)
    
class Adaptive_pool(nn.Module):
    """Average-pool a (batched) matrix to (channel_out, hw_out**2), then fold it into (channel_out, hw_out, hw_out).

    In the training code this squashes batches of Gram matrices (N, C, C) into a
    fixed-size, image-like tensor that the convolutional style discriminator consumes.
    """
    def __init__(self, channel_out, hw_out):
        super().__init__()
        self.channel_out = channel_out
        self.hw_out = hw_out
        self.pool = nn.AdaptiveAvgPool2d((channel_out, hw_out**2))

    def forward(self, input):
        # A 3-D input (N, H, W) is treated as single-channel images. Use the out-of-place
        # unsqueeze: the previous ``unsqueeze_`` reshaped the caller's tensor in place.
        if input.dim() == 3:
            input = input.unsqueeze(1)
        return self.pool(input).view(-1, self.channel_out, self.hw_out, self.hw_out)

class VGGSimple(nn.Module):
    """VGG-16 ("D" config) convolutional trunk used as a multi-scale feature extractor.

    Only ``self.features`` carries parameters, so the state_dict keys are
    ``features.<idx>.weight`` / ``features.<idx>.bias`` for the layer indices
    produced by ``make_layers``.
    """

    def __init__(self):
        super(VGGSimple, self).__init__()
        self.features = self.make_layers()
        # ImageNet statistics; plain tensors (not buffers), moved to the input's device on each call.
        self.norm_mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        self.norm_std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

    def forward(self, img, after_relu=True, base=4):
        """Return five feature maps tapped at increasing depth, average-pooled to base*16, *8, *4, *2, *1.

        ``img`` is expected in [-1, 1]; it is mapped to [0, 1] and ImageNet-normalised.
        With ``after_relu`` the taps sit two layers later (indices 4, 9, 16, 23, 30,
        i.e. the max-pool outputs); otherwise at the conv outputs 2, 7, 14, 21, 28.
        """
        mean = self.norm_mean.to(img.device)
        std = self.norm_std.to(img.device)
        feat = (((img + 1) * 0.5) - mean) / std

        shift = 2 if after_relu else 0
        # layer index -> pooled output size; taps are collected in increasing index order
        taps = {
            2 + shift: base * 16,
            7 + shift: base * 8,
            14 + shift: base * 4,
            21 + shift: base * 2,
            28 + shift: base,
        }
        collected = []
        for idx in range(31):
            feat = self.features[idx](feat)
            out_size = taps.get(idx)
            if out_size is not None:
                collected.append(F.adaptive_avg_pool2d(feat, out_size))
        return tuple(collected)

    def make_layers(self, cfg="D", batch_norm=False):
        """Build the VGG conv/ReLU/max-pool stack for the named configuration ('A', 'B', 'D' or 'E')."""
        cfg_dic = {
            'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
            'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
            'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
            'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
        }
        modules = []
        channels = 3
        for entry in cfg_dic[cfg]:
            if entry == 'M':
                modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            modules.append(nn.Conv2d(channels, entry, kernel_size=3, padding=1))
            if batch_norm:
                modules.append(nn.BatchNorm2d(entry))
            modules.append(nn.ReLU(inplace=False))
            channels = entry
        return nn.Sequential(*modules)


class VGG_3label(nn.Module):
    """VGG-16 ("D") trunk + shared 512-d FC embedding + three classification heads (style, genre, artist).

    ``classifier_feat`` and ``classifier`` are the same ``nn.Sequential``; both
    attribute names are registered, so its weights appear under both prefixes
    in ``state_dict()``.
    """
    def __init__(self, nclass_artist=1117, nclass_style=55, nclass_genre=26):
        super(VGG_3label, self).__init__()
        self.features = self.make_layers()
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

        # Indices 0 and 3 mirror torchvision's VGG classifier so they can receive pretrained weights.
        self.classifier_feat = self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 512))

        self.classifier_style = nn.Sequential(nn.ReLU(), nn.Dropout(), nn.Linear(512, nclass_style))
        self.classifier_genre = nn.Sequential(nn.ReLU(), nn.Dropout(), nn.Linear(512, nclass_genre))
        self.classifier_artist = nn.Sequential(nn.ReLU(), nn.Dropout(), nn.Linear(512, nclass_artist))

        # ImageNet statistics used by get_features (plain tensors, not buffers).
        self.norm_mean = torch.Tensor([0.485, 0.456, 0.406]).view(1,3,1,1)
        self.norm_std = torch.Tensor([0.229, 0.224, 0.225]).view(1,3,1,1)

        self.avgpool_4 = nn.AdaptiveAvgPool2d((4, 4))
        self.avgpool_8 = nn.AdaptiveAvgPool2d((8, 8))
        self.avgpool_16 = nn.AdaptiveAvgPool2d((16, 16))

    def get_features(self, img, after_relu=True, base=4):
        """Five trunk feature maps (input in [-1, 1]) average-pooled to base*16, *8, *4, *2, *1."""
        feat = (((img+1)*0.5) - self.norm_mean.to(img.device)) / self.norm_std.to(img.device)
        cut_points = [2, 7, 14, 21, 28]
        if after_relu:
            cut_points = [4, 9, 16, 23, 30]
        for i in range(31):
            feat = self.features[i](feat)
            if i == cut_points[0]:
                feat_64 = F.adaptive_avg_pool2d(feat, base*16)
            if i == cut_points[1]:
                feat_32 = F.adaptive_avg_pool2d(feat, base*8)
            if i == cut_points[2]:
                feat_16 = F.adaptive_avg_pool2d(feat, base*4)
            if i == cut_points[3]:
                feat_8 = F.adaptive_avg_pool2d(feat, base*2)
            if i == cut_points[4]:
                feat_4 = F.adaptive_avg_pool2d(feat, base)
        return feat_64, feat_32, feat_16, feat_8, feat_4

    def load_pretrain_weights(self, pretrained_model=None):
        """Copy ImageNet VGG-16 weights into the trunk and FC layers 0 and 3, then initialise the rest.

        Args:
            pretrained_model: optional object exposing ``.features`` and ``.classifier``
                laid out like torchvision's VGG-16. When None (default), torchvision's
                ``vgg16(pretrained=True)`` is loaded (downloads weights on first use).
        """
        if pretrained_model is None:
            # Local import: the previous code referenced an undefined global ``vgg``.
            from torchvision.models import vgg16
            pretrained_model = vgg16(pretrained=True)
        self.features.load_state_dict(pretrained_model.features.state_dict())
        self.classifier_feat[0] = pretrained_model.classifier[0]
        self.classifier_feat[3] = pretrained_model.classifier[3]
        # Initialise only the freshly created Linear layers. The previous loop also hit the
        # two layers copied above and overwrote their pretrained weights.
        pretrained_ids = {id(self.classifier_feat[0]), id(self.classifier_feat[3])}
        for m in self.modules():
            if isinstance(m, nn.Linear) and id(m) not in pretrained_ids:
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def make_layers(self, cfg="D", batch_norm=False):
        """Build the VGG conv/ReLU/max-pool stack for the named configuration ('A', 'B', 'D' or 'E')."""
        cfg_dic = {
            'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
            'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
            'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
            'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
        }
        cfg = cfg_dic[cfg]
        layers = []
        in_channels = 3
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
                if batch_norm:
                    layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=False)]
                else:
                    layers += [conv2d, nn.ReLU(inplace=False)]
                in_channels = v
        return nn.Sequential(*layers)

    def forward(self, img):
        """Return (pred_style, pred_genre, pred_artist) logits for a batch of images.

        NOTE(review): unlike get_features, no [-1, 1] -> ImageNet normalisation is
        applied here; confirm callers pass already-normalised input.
        """
        feature = self.classifier_feat( self.avgpool(self.features(img)).view(img.size(0), -1) )
        pred_style = self.classifier_style(feature)
        pred_genre = self.classifier_genre(feature)
        pred_artist = self.classifier_artist(feature)
        return pred_style, pred_genre, pred_artist


class UnFlatten(nn.Module):
    """Reshape each sample to (-1, block_size, block_size), keeping the batch dimension."""

    def __init__(self, block_size):
        super(UnFlatten, self).__init__()
        self.block_size = block_size

    def forward(self, x):
        side = self.block_size
        batch = x.size(0)
        return x.view(batch, -1, side, side)


class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension into one (via ``view``)."""

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, -1)


class UpConvBlock(nn.Module):
    """2x nearest-neighbour upsample, then reflect-pad + 3x3 spectral-norm conv, norm layer and LeakyReLU(0.01)."""

    def __init__(self, in_channel, out_channel, norm_layer=nn.BatchNorm2d):
        super().__init__()
        # Index layout of ``main`` (0 pad, 1 conv, 2 norm, 3 act) defines the state_dict keys.
        pad = nn.ReflectionPad2d(1)
        conv = spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 0, bias=True))
        self.main = nn.Sequential(pad, conv, norm_layer(out_channel), nn.LeakyReLU(0.01))

    def forward(self, x):
        return self.main(F.interpolate(x, scale_factor=2))


class DownConvBlock(nn.Module):
    """3x3 spectral-norm conv (stride 1, pad 1) -> norm layer -> LeakyReLU(0.1), optionally followed by 2x2 average pooling."""

    def __init__(self, in_channel, out_channel, norm_layer=nn.BatchNorm2d, down=True):
        super().__init__()
        stages = [spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1, bias=True))]
        stages.append(norm_layer(out_channel))
        stages.append(nn.LeakyReLU(0.1))
        if down:
            stages.append(nn.AvgPool2d(2, 2))
        self.main = nn.Sequential(*stages)

    def forward(self, x):
        out = self.main(x)
        return out




class Generator(nn.Module):
    """Decode a feature map with three 2x up-convolution stages, then a 3x3 spectral-norm conv + tanh.

    Output spatial size is 8x the input; channels go infc -> nfc*4 -> nfc*4 -> nfc*2 -> nc_out.
    """
    def __init__(self, infc=512, nfc=64, nc_out=3):
        super(Generator, self).__init__()

        # Attribute names are part of the checkpoint (state_dict) format.
        self.decode_32 = UpConvBlock(infc, nfc*4)
        self.decode_64 = UpConvBlock(nfc*4, nfc*4)
        self.decode_128 = UpConvBlock(nfc*4, nfc*2)

        self.final = nn.Sequential(
            spectral_norm(nn.Conv2d(nfc*2, nc_out, 3, 1, 1, bias=True)),
            nn.Tanh(),
        )

    def forward(self, input):
        hidden = input
        for stage in (self.decode_32, self.decode_64, self.decode_128):
            hidden = stage(hidden)
        return self.final(hidden)


class Discriminator(nn.Module):
    """Patch discriminator: two DownConvBlocks (the first without pooling) and a 4x4 stride-2 spectral-norm conv to 1 channel.

    Returns the patch logits flattened to 1-D.
    """
    def __init__(self, nfc=512, norm_layer=nn.InstanceNorm2d):
        super(Discriminator, self).__init__()
        layers = [
            DownConvBlock(nfc, nfc//2, norm_layer=norm_layer, down=False),
            DownConvBlock(nfc//2, nfc//4, norm_layer=norm_layer),
            spectral_norm(nn.Conv2d(nfc//4, 1, 4, 2, 0)),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        logits = self.main(input)
        return logits.view(-1)


In [None]:
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms, utils
import subprocess as sp
from PIL import Image
import time
import torch.utils.data as data

from skimage.color import hsv2rgb
import torch.nn.functional as F
import torch.nn as nn

In [None]:
eps = 1e-7  # guards the division in HSV_Loss.get_s against black pixels


class HSV_Loss(nn.Module):
    """MSE between an image batch's hue map and a constant target hue.

    Inputs are (B, C, H, W) tensors in [-1, 1]; they are mapped to [0, 1] via
    ``x * 0.5 + 0.5``. Channels 0/1/2 are treated as R/G/B (TODO confirm with callers).
    Only the hue term contributes to ``forward``; the saturation/value targets stored
    in ``self.hsv`` are currently unused.
    """
    def __init__(self, h=0, s=1, v=0.7):
        super(HSV_Loss, self).__init__()
        self.hsv = [h, s, v]
        self.l1 = nn.L1Loss()
        self.mse = nn.MSELoss()

    @staticmethod
    def get_h(im):
        """Per-pixel hue in [0, 1) of a (B, 3, H, W) batch; grey pixels (max == min) get hue 0.

        Returns a (B, H, W) tensor with the input's dtype and device.
        """
        img = im * 0.5 + 0.5
        batch, _, height, width = img.shape
        red, green, blue = img[:, 0], img[:, 1], img[:, 2]
        max_c = img.max(1)[0]
        min_c = img.min(1)[0]
        delta = max_c - min_c
        # Grey pixels get hue 0 at the end anyway. Dividing by a safe 1 there keeps the dense
        # division finite; dividing by 0 made their (zero) upstream gradient turn into NaN.
        safe_delta = torch.where(delta == 0, torch.ones_like(delta), delta)
        # Zero-initialised instead of uninitialised memory.
        hue = img.new_zeros(batch, height, width)
        # Later assignments win, so on ties the priority is R > G > B.
        mask = blue == max_c
        hue[mask] = (4.0 + (red - green) / safe_delta)[mask]
        mask = green == max_c
        hue[mask] = (2.0 + (blue - red) / safe_delta)[mask]
        mask = red == max_c
        hue[mask] = ((green - blue) / safe_delta)[mask]
        hue = (hue / 6.0) % 1.0
        hue[min_c == max_c] = 0.0
        return hue

    @staticmethod
    def get_v(im):
        """Per-pixel HSV value (max over channels) of a [-1, 1] batch mapped to [0, 1]; returns (B, H, W)."""
        img = im * 0.5 + 0.5
        b, c, h, w = img.shape
        it = img.transpose(1,2).transpose(2,3).contiguous().view(b, -1, c)
        value = F.max_pool1d(it, c).view(b, h, w)
        return value

    @staticmethod
    def get_s(im):
        """Per-pixel HSV saturation (max - min) / (max + eps) of a [-1, 1] batch mapped to [0, 1]; returns (B, H, W)."""
        img = im * 0.5 + 0.5
        b, c, h, w = img.shape
        it = img.transpose(1,2).transpose(2,3).contiguous().view(b, -1, c)
        max_v = F.max_pool1d(it, c).view(b, h, w)
        # max-pooling the negated values yields -min
        min_v = F.max_pool1d(it*-1, c).view(b, h, w)
        satur = (max_v + min_v) / (max_v+eps)
        return satur

    def forward(self, input):
        """Return MSE(hue(input), self.hsv[0])."""
        h = self.get_h(input)
        target_h = torch.full_like(h, self.hsv[0])
        return self.mse(h, target_h)

def InfiniteSampler(n):
    """Yield indices in [0, n) forever, reshuffling after every pass.

    The very first index is the last entry of an initial permutation. After that,
    every pass is a fresh ``np.random.permutation(n)``. Before each reshuffle,
    ``np.random.seed()`` (no argument) reseeds NumPy's global RNG from OS entropy,
    so the stream is not reproducible even if the caller seeded NumPy.
    """
    perm = np.random.permutation(n)
    yield perm[n - 1]
    while True:
        np.random.seed()
        perm = np.random.permutation(n)
        for idx in perm:
            yield idx

class InfiniteSamplerWrapper(data.sampler.Sampler):
    """DataLoader sampler that never runs out: draws indices from ``InfiniteSampler`` over ``len(data_source)`` items.

    ``__len__`` reports 2**31 so the loader treats it as effectively endless.
    """
    def __init__(self, data_source):
        self.num_samples = len(data_source)

    def __len__(self):
        return 1 << 31

    def __iter__(self):
        return iter(InfiniteSampler(self.num_samples))


def _rescale(img):
    return img * 2.0 - 1.0

def trans_maker(size=256):
    """Training transform: resize to (size+36)^2, random horizontal flip, random size x size crop, tensor in [-1, 1]."""
    enlarged = size + 36
    steps = [
        transforms.Resize((enlarged, enlarged)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop((size, size)),
        transforms.ToTensor(),
        _rescale,
    ]
    return transforms.Compose(steps)

def trans_maker_testing(size=256):
    """Deterministic evaluation transform: resize to (size, size), tensor in [-1, 1]."""
    steps = [transforms.Resize((size, size)), transforms.ToTensor(), _rescale]
    return transforms.Compose(steps)
# Module-level 128x128 training transform (not referenced by the cells visible in this part of the notebook).
transform_gan = trans_maker(size=128)

import torchvision.utils as vutils
import logging
# Module logger used by the save_image / save_model helpers below.
logger = logging.getLogger(__name__)


def save_image(net, dataloader_A, device, cur_iter, trial, save_path):
    """Run ``net.gen_a2b`` over ``dataloader_A`` and write every generated image to ``save_path`` as a JPEG.

    Files are named '<trial>_gan_epoch_<cur_iter>_iter_<batch>_<index>.jpg'. The
    network is put in eval mode for the pass and restored to train mode afterwards,
    even on error. Returns ``save_path``.

    NOTE: a later cell redefines ``save_image`` with a different signature; once that
    cell has run, this version is shadowed.
    """
    logger.info('Saving gan epoch {} images: {}'.format(cur_iter, save_path))
    net.eval()
    # Feed inputs in the parameters' tensor type (e.g. half precision / CUDA).
    data_type = next(iter(net.parameters())).type()
    try:
        with torch.no_grad():
            for itx, batch in enumerate(dataloader_A):
                g_img = net.gen_a2b(batch[0].to(device).type(data_type))
                # Out-of-place [-1, 1] -> [0, 1]. The previous add_/mul_ on g_img.cpu().float()
                # mutated g_img itself when it was already a CPU float tensor, compounding the
                # rescale once per loop iteration.
                g_img = g_img.cpu().float().add(1).mul(0.5)
                for i in range(g_img.size(0)):
                    # Save image i only; previously the whole batch grid was written each time.
                    vutils.save_image(
                        g_img[i],
                        os.path.join(save_path, "{}_gan_epoch_{}_iter_{}_{}.jpg".format(trial, cur_iter, itx, i)),)
    finally:
        net.train()
    return save_path

def save_model(net, save_folder, cuda_device, if_multi_gpu, trial, cur_iter):
    """Save ``net`` through its own ``save`` method, then delete every other ``.pth`` file in ``save_folder``.

    ``cuda_device`` and ``if_multi_gpu`` are accepted but unused.
    Returns the path of the checkpoint just written.
    """
    keep_name = "{}_gan_epoch_{}.pth".format(trial, cur_iter)
    keep_path = os.path.join(save_folder, keep_name)
    logger.info('Saving gan model: {}'.format(keep_path))
    net.save(keep_path)

    stale = [name for name in os.listdir(save_folder) if name.endswith('.pth') and name != keep_name]
    for name in stale:
        victim = os.path.join(save_folder, name)
        os.remove(victim)
        logger.info('Deleted previous gan model: {}'.format(victim))

    return keep_path

In [None]:
# Let cuDNN benchmark and cache the fastest conv algorithms; this pays off when input sizes stay constant, as with the fixed-size crops used here.
torch.backends.cudnn.benchmark = True

def creat_folder(save_folder, trial_name):
    """Ensure ``<save_folder>/train_results/<trial_name>/{models,images}`` exist.

    Uses ``os.makedirs(..., exist_ok=True)``, so missing parents (including
    ``save_folder`` itself) are created and existing folders are left alone.
    The previous chain of ``os.mkdir`` calls failed when ``save_folder`` was missing.

    Returns:
        (saved_model_folder, saved_image_folder)
    """
    saved_model_folder = os.path.join(save_folder, 'train_results/%s/models' % trial_name)
    saved_image_folder = os.path.join(save_folder, 'train_results/%s/images' % trial_name)
    for folder in (saved_image_folder, saved_model_folder):
        os.makedirs(folder, exist_ok=True)
    return saved_model_folder, saved_image_folder

def train_d(net, data, label="real"):
    """Backpropagate one discriminator hinge-loss term for ``data``.

    ``label == "real"`` uses mean(relu(1 - D(x))); any other label uses mean(relu(1 + D(x))).
    Gradients are accumulated with ``backward()``; no optimizer step is taken here.
    Returns the mean sigmoid of the raw predictions as a Python float (for logging).
    """
    pred = net(data)
    margin = 1 - pred if label == "real" else 1 + pred
    F.relu(margin).mean().backward()
    return torch.sigmoid(pred).mean().item()

def gram_matrix(input):
    """Gram matrix of all (batch x channel) feature maps flattened together, normalised by the element count.

    NOTE(review): for batch size > 1 this is an (a*b) x (a*b) matrix that also
    correlates channels of *different* samples; get_batched_gram_matrix computes
    per-sample Grams instead.
    """
    a, b, c, d = input.size()
    rows = input.view(a * b, c * d)
    return (rows @ rows.t()).div(a * b * c * d)

def gram_loss(input, target):
    """MSE between the Gram matrices of ``input`` and ``target``; ``target`` is detached so no gradient reaches it."""
    return F.mse_loss(gram_matrix(input), gram_matrix(target.detach()))

def save_image(net_g, dataloader, saved_image_folder, n_iter, vgg, device, base):
    """Write a preview grid '<saved_image_folder>/iter_<n_iter>.jpg'.

    For the first two batches of ``dataloader``, the generator is fed the third VGG
    feature map (index 2) of the real images. The grid holds all generated images
    followed by the matching real images, normalised from (-1, 1). ``net_g`` is put
    in eval mode for the preview and back in train mode afterwards.
    """
    net_g.eval()
    with torch.no_grad():
        fakes, reals = [], []
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx >= 2:
                break
            content = batch[0]
            feat = vgg(content.to(device), base=base)[2]
            fakes.append(net_g(feat).cpu())
            reals.append(content)
        grid = torch.cat([torch.cat(fakes, dim=0), torch.cat(reals, dim=0)], dim=0)

        vutils.save_image(grid, "%s/iter_%d.jpg"%(saved_image_folder, n_iter), range=(-1,1), normalize=True)
    net_g.train()

def train(net_g, net_d_style, max_iteration, save_folder, trial_name, dataloader_A_fixed, dataloader_B, dataloader_A, base, vgg, device, gram_reshape, optDS, optG):
    """Adversarial sketch-synthesis training loop.

    Each iteration:
    - pulls a content batch (``next(dataloader_A)``) and a style batch (``next(dataloader_B)``);
    - feeds the content batch's third VGG feature map to ``net_g``;
    - updates ``net_d_style`` with hinge losses on pooled Gram-matrix "style samples"
      of real-style vs generated features;
    - updates ``net_g`` with adversarial loss + 0.2 * AdaIN feature MSE + 2000 * Gram loss.

    Model and optimizer state dicts are saved when (n_iter + 1) % 100 == 0. A preview grid
    is written when n_iter % 100 == 0. Every ``log_interval`` iterations it prints the
    averages of D(real), D(fake), D(G) (sigmoid outputs) and the Gram loss.
    """
    print('training begin ... ')
    titles = ['D_r', 'D_f', 'G', 'G_rec']
    losses = {title: 0.0 for title in titles}
    steps_since_log = 0  # iterations accumulated into ``losses`` since the last print
    mse_weight = 0.2
    gram_weight = 1
    log_interval = 100

    saved_model_folder, saved_image_folder = creat_folder(save_folder, trial_name)

    for n_iter in tqdm.tqdm(range(max_iteration+1)):
        if (n_iter+1)%(100)==0:
            try:
                model_dict = {'g': net_g.state_dict(), 'ds':net_d_style.state_dict()}
                torch.save(model_dict, os.path.join(saved_model_folder, '%d_model.pth'%(n_iter)))
                opt_dict = {'g': optG.state_dict(), 'ds':optDS.state_dict()}
                torch.save(opt_dict, os.path.join(saved_model_folder, '%d_opt.pth'%(n_iter)))
            except Exception as e:
                # Checkpointing is best-effort, but the old bare ``except:`` also swallowed
                # KeyboardInterrupt / SystemExit raised mid-save.
                print("models not properly saved")
                print(repr(e))
        if n_iter%100==0:
            save_image(net_g, dataloader_A_fixed, saved_image_folder, n_iter, vgg, device, base)

        real_style = next(dataloader_B)[0].to(device)
        real_content = next(dataloader_A)[0].to(device)

        # Five VGG feature maps per batch (shallow -> deep).
        cf_1, cf_2, cf_3, cf_4, cf_5 = vgg(real_content, base=base)
        sf_1, sf_2, sf_3, sf_4, sf_5 = vgg(real_style, base=base)

        fake_img = net_g(cf_3)
        tf_1, tf_2, tf_3, tf_4, tf_5 = vgg(fake_img, base=base)

        # Feature-reconstruction target: content features carrying the style batch's AdaIN statistics.
        target_3 = adain(cf_3, sf_3)

        # Discriminator inputs: per-sample Gram matrices at three depths, pooled to a common
        # size by ``gram_reshape`` and stacked along the channel dimension.
        gram_sf_4 = gram_reshape(get_batched_gram_matrix(sf_4))
        gram_sf_3 = gram_reshape(get_batched_gram_matrix(sf_3))
        gram_sf_2 = gram_reshape(get_batched_gram_matrix(sf_2))
        real_style_sample = torch.cat([gram_sf_2, gram_sf_3, gram_sf_4], dim=1)

        gram_tf_4 = gram_reshape(get_batched_gram_matrix(tf_4))
        gram_tf_3 = gram_reshape(get_batched_gram_matrix(tf_3))
        gram_tf_2 = gram_reshape(get_batched_gram_matrix(tf_2))
        fake_style_sample = torch.cat([gram_tf_2, gram_tf_3, gram_tf_4], dim=1)

        # Discriminator step (fake sample detached so the generator is untouched).
        net_d_style.zero_grad()
        D_R = train_d(net_d_style, real_style_sample, label="real")
        D_F = train_d(net_d_style, fake_style_sample.detach(), label="fake")
        optDS.step()

        # Generator step. err.backward() also accumulates grads into net_d_style; they are
        # cleared by net_d_style.zero_grad() before the next discriminator step.
        net_g.zero_grad()
        pred_gs = net_d_style(fake_style_sample)
        err_gs = -pred_gs.mean()

        G_B = torch.sigmoid(pred_gs).mean().item()

        err_rec = F.mse_loss(tf_3, target_3)
        err_gram = 2000*(
            gram_loss(tf_4, sf_4) +
            gram_loss(tf_3, sf_3) +
            gram_loss(tf_2, sf_2))

        G_rec = err_gram.item()

        err = err_gs + mse_weight*err_rec + gram_weight*err_gram
        err.backward()

        optG.step()
        loss_values = [D_R, D_F, G_B, G_rec]
        for term, value in zip(titles, loss_values):
            losses[term] += value
        steps_since_log += 1

        if n_iter > 0 and n_iter % log_interval == 0:
            # Divide by the number of accumulated steps: the first window covers
            # n_iter 0..log_interval, i.e. log_interval + 1 iterations.
            log_line = ""
            for key, value in losses.items():
                log_line += "%s: %.5f  "%(key, value/steps_since_log)
                losses[key] = 0
            steps_since_log = 0
            print(log_line)



In [None]:
def executing(path_a, path_b):
    """Set up and run sketch-synthesis training.

    Args:
        path_a: ImageFolder root of the content images (fed to ``dataloader_A``).
        path_b: ImageFolder root of the style images (fed to ``dataloader_B``).

    Builds the frozen VGG feature extractor (weights loaded from Drive), infinite
    content/style data loaders, the generator and the Gram-matrix style
    discriminator. It optionally resumes from ``checkpt`` and then calls ``train``.

    Raises:
        ValueError: if ``im_size`` is not one of 128, 256, 512.
    """
    # Path of a '<n_iter>_model.pth' checkpoint to resume from; the string 'None' disables resuming.
    checkpt = 'None'
    trial_name = "test1"
    data_root_A = path_a
    data_root_B = path_b
    max_iteration = 1000
    device = torch.device("cuda:%d"%(0))
    batch_size = 4
    lr_ = 2e-4
    # Loss weights and the logging interval are defined inside train().

    im_size = 256
    # Pooling base for VGG features (maps are pooled to base*16 ... base).
    base_by_size = {128: 4, 256: 8, 512: 16}
    if im_size not in base_by_size:
        # Previously this only printed and then failed later with an unbound ``base``.
        raise ValueError("the size must be in [128, 256, 512]")
    base = base_by_size[im_size]

    save_folder = '/content/My_drive/MyDrive/rajat-Inspiron-3576/SEM7/DEEP_LEARNING/Project/Files/saved/'

    vgg = VGGSimple()
    vgg.load_state_dict(torch.load('/content/My_drive/MyDrive/rajat-Inspiron-3576/SEM7/DEEP_LEARNING/Project/Files/vgg-feature-weights.pth', map_location=lambda a,b:a))
    vgg.to(device)
    vgg.eval()
    for p in vgg.parameters():
        p.requires_grad = False

    dataset_A = Dataset.ImageFolder(root=data_root_A, transform=trans_maker(im_size))
    # NOTE(review): the "fixed" preview loader reuses the random training transform, so the
    # preview inputs change between dumps; trans_maker_testing would keep them identical.
    dataloader_A_fixed = DataLoader(dataset_A, 8, shuffle=False, num_workers=4)
    dataloader_A = iter(DataLoader(dataset_A, batch_size, shuffle=False,
            sampler=InfiniteSamplerWrapper(dataset_A), num_workers=4, pin_memory=True))

    dataset_B = Dataset.ImageFolder(root=data_root_B, transform=trans_maker(im_size))
    dataloader_B = iter(DataLoader(dataset_B, batch_size, shuffle=False,
            sampler=InfiniteSamplerWrapper(dataset_B), num_workers=4, pin_memory=True))

    net_g = Generator(infc=256, nfc=128)

    # 128*3 input channels: train() concatenates three 128-channel pooled Gram maps.
    net_d_style = Discriminator(nfc=128*3, norm_layer=nn.BatchNorm2d)
    gram_reshape = Adaptive_pool(128, 16)

    # Compare by value: ``is not`` on a str literal only worked through interning
    # (and is a SyntaxWarning on Python 3.8+).
    if checkpt != 'None':
        checkpoint = torch.load(checkpt, map_location=lambda storage, loc: storage)
        net_g.load_state_dict(checkpoint['g'])
        net_d_style.load_state_dict(checkpoint['ds'])
        print("saved model loaded")

    net_d_style.to(device)
    net_g.to(device)

    optG = optim.Adam(net_g.parameters(), lr=lr_, betas=(0.5, 0.99))
    optDS = optim.Adam(net_d_style.parameters(), lr=lr_, betas=(0.5, 0.99))

    if checkpt != 'None':
        opt_path = checkpt.replace("_model.pth", "_opt.pth")
        try:
            opt_weights = torch.load(opt_path, map_location=lambda a, b: a)
            optG.load_state_dict(opt_weights['g'])
            optDS.load_state_dict(opt_weights['ds'])
            print("saved optimizer loaded")
        except Exception:
            # Best-effort: resume without optimizer state rather than abort
            # (no longer swallows KeyboardInterrupt/SystemExit).
            print("no optimizer weights detected, resuming a training without optimizer weights may not let the model converge as desired")

    train(net_g, net_d_style, max_iteration, save_folder, trial_name, dataloader_A_fixed, dataloader_B, dataloader_A, base, vgg, device, gram_reshape, optDS, optG)

In [None]:
## training for sketch synthesis

In [None]:
# Content photos (A) and sketches used as the target style (B) for sketch-synthesis training.
_project_files = '/content/My_drive/MyDrive/rajat-Inspiron-3576/SEM7/DEEP_LEARNING/Project/Files/'
img_path = _project_files + 'images/'
skc_path = _project_files + 'sketch/'

executing(img_path, skc_path)

  cpuset_checked))


training begin ... 


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
 10%|█         | 101/1001 [03:05<35:21,  2.36s/it]

D_r: 0.89513  D_f: 0.11602  G: 0.10068  G_rec: 1.81017  


 20%|██        | 201/1001 [05:35<28:44,  2.16s/it]

D_r: 0.87728  D_f: 0.12672  G: 0.11242  G_rec: 0.90363  


 30%|███       | 301/1001 [08:05<25:15,  2.16s/it]

D_r: 0.86820  D_f: 0.12276  G: 0.11374  G_rec: 0.75287  


 40%|████      | 401/1001 [10:34<21:38,  2.16s/it]

D_r: 0.86533  D_f: 0.13738  G: 0.12805  G_rec: 0.65702  


 50%|█████     | 501/1001 [13:05<17:58,  2.16s/it]

D_r: 0.82999  D_f: 0.16613  G: 0.14405  G_rec: 0.55781  


 60%|██████    | 601/1001 [15:35<14:41,  2.20s/it]

D_r: 0.84595  D_f: 0.15380  G: 0.12750  G_rec: 0.60710  


 70%|███████   | 701/1001 [18:05<10:48,  2.16s/it]

D_r: 0.83098  D_f: 0.16384  G: 0.13540  G_rec: 0.56120  


 80%|████████  | 801/1001 [20:35<07:12,  2.16s/it]

D_r: 0.82504  D_f: 0.15991  G: 0.13128  G_rec: 0.55975  


 90%|█████████ | 901/1001 [23:05<03:36,  2.16s/it]

D_r: 0.81668  D_f: 0.17630  G: 0.14758  G_rec: 0.52363  


100%|██████████| 1001/1001 [25:35<00:00,  1.53s/it]

D_r: 0.83098  D_f: 0.16481  G: 0.13122  G_rec: 0.60554  





In [None]:
import os
import torch
import torchvision.datasets as Dataset
import torchvision.utils as vutils
from torch import nn
import argparse

In [None]:
def synthesizing(path_content, path_result, checkpoint, threshold=0.7):
    """Generate binarised sketches for every image under ``path_content`` and save them as JPEGs.

    Args:
        path_content: ImageFolder root of the input images.
        path_result: output directory, created with parents if missing; files are named '<index>.jpg'.
        checkpoint: path to a training checkpoint holding the generator under key 'g'. The
            string 'None' (or None) keeps the freshly initialised generator.
        threshold: grey level in [0, 1] above which a pixel becomes white (1.0).
            0.7 reproduces the previous hard-coded value.

    Raises:
        ValueError: if the configured ``im_size`` is unsupported.
    """
    device = torch.device("cuda:%d"%(0))

    im_size = 256
    base_by_size = {128: 4, 256: 8, 512: 16, 1024: 32}
    if im_size not in base_by_size:
        # Previously this only printed and then failed later with an unbound ``base``.
        raise ValueError("the size must be in [128, 256, 512, 1024]")
    base = base_by_size[im_size]

    vgg = VGGSimple()
    vgg.load_state_dict(torch.load('/content/My_drive/MyDrive/rajat-Inspiron-3576/SEM7/DEEP_LEARNING/Project/Files/vgg-feature-weights.pth', map_location=lambda a,b:a))
    vgg.to(device)
    vgg.eval()
    for p in vgg.parameters():
        p.requires_grad = False

    dataset = Dataset.ImageFolder(root=path_content, transform=trans_maker_testing(size=im_size))

    net_g = Generator(infc=256, nfc=128)

    # Compare by value: ``is not 'None'`` relied on string interning.
    if checkpoint not in (None, 'None'):
        state = torch.load(checkpoint, map_location=lambda storage, loc: storage)
        net_g.load_state_dict(state['g'])
        print("saved model loaded")

    net_g.to(device)
    net_g.eval()

    os.makedirs(path_result, exist_ok=True)

    print("begin generating images ...")
    with torch.no_grad():
        for i in range(len(dataset)):
            print("generating the %dth image"%(i))
            # Add the batch dimension explicitly; the 3-D tensor used to work only because
            # VGGSimple's (1, 3, 1, 1) normalisation stats broadcast it to 4-D.
            img = dataset[i][0].unsqueeze(0).to(device)
            feat = vgg(img, base=base)[2]
            g_img = net_g(feat)

            # Average RGB to grey, map [-1, 1] -> [0, 1], then binarise.
            g_img = g_img.mean(1).unsqueeze(1).add(1).mul(0.5)
            g_img = (g_img > threshold).float()
            vutils.save_image(g_img, os.path.join(path_result, '%d.jpg'%(i)))

In [None]:
## sketch generation

In [None]:
# Synthesize sketches for the training photos using the generator checkpoint from iteration 899.
_project_files = "/content/My_drive/MyDrive/rajat-Inspiron-3576/SEM7/DEEP_LEARNING/Project/Files/"
path_content = _project_files + "images/"
path_result = _project_files + "saved/train_results/synthesized_sketches"
checkpoint = _project_files + "saved/train_results/test1/models/899_model.pth"
synthesizing(path_content, path_result, checkpoint)

saved model loaded
begin generating images ...
generating the 0th image


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


generating the 1th image
generating the 2th image
generating the 3th image
generating the 4th image
generating the 5th image
generating the 6th image
generating the 7th image
generating the 8th image
generating the 9th image
generating the 10th image
generating the 11th image
generating the 12th image
generating the 13th image
generating the 14th image
generating the 15th image
generating the 16th image
generating the 17th image
generating the 18th image
generating the 19th image
generating the 20th image
generating the 21th image
generating the 22th image
generating the 23th image
generating the 24th image
generating the 25th image
generating the 26th image
generating the 27th image
generating the 28th image
generating the 29th image
generating the 30th image
generating the 31th image
generating the 32th image
generating the 33th image
generating the 34th image
generating the 35th image
generating the 36th image
generating the 37th image
generating the 38th image
generating the 39th i

In [None]:
# Apply the same iteration-899 checkpoint to an extra set of images under DEEP_LEARNING/Fun.
_dl_root = "/content/My_drive/MyDrive/rajat-Inspiron-3576/SEM7/DEEP_LEARNING/"
path_content = _dl_root + "Fun/Img"
path_result = _dl_root + "Fun/Sketch"
checkpoint = _dl_root + "Project/Files/saved/train_results/test1/models/899_model.pth"
synthesizing(path_content, path_result, checkpoint)

saved model loaded
begin generating images ...
generating the 0th image
generating the 1th image
generating the 2th image
generating the 3th image
generating the 4th image
generating the 5th image
generating the 6th image
generating the 7th image
generating the 8th image
generating the 9th image


In [None]:
import os
import torch
from copy import deepcopy
from random import shuffle
import torch.nn.functional as F

def d_hinge_loss(real_pred, fake_pred):
    """Discriminator hinge loss: mean(relu(1 - D(real))) + mean(relu(1 + D(fake)))."""
    return F.relu(1 - real_pred).mean() + F.relu(1 + fake_pred).mean()


def g_hinge_loss(pred):
    """Generator hinge loss: negated mean discriminator output."""
    return pred.mean().neg()


class AverageMeter(object):
    """Running weighted average; ``update(val, n)`` adds ``val`` with weight ``n``.

    Attributes: ``val`` (last value), ``sum`` (weighted sum), ``count`` (total weight), ``avg`` (sum / count).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def true_randperm(size, device='cuda'):
    """Return a random permutation of ``range(size)`` with no fixed points (``out[i] != i``) as a LongTensor on ``device``.

    Raises:
        ValueError: if ``size == 1``. No such permutation exists, and the previous
            implementation retried forever.
    """
    if size == 1:
        raise ValueError("true_randperm: no permutation of size 1 without fixed points")

    def unmatched_randperm(size):
        # Greedy attempt: for each position j, draw a random remaining value other than j.
        # Returns (perm, True) on success, or (0, False) if only j itself is left.
        l1 = [i for i in range(size)]
        l2 = []
        for j in range(size):
            deleted = False
            if j in l1:
                deleted = True
                del l1[l1.index(j)]
            shuffle(l1)
            if len(l1) == 0:
                return 0, False
            l2.append(l1[0])
            del l1[0]
            if deleted:
                l1.append(j)
        return l2, True

    flag = False
    while not flag:
        l, flag = unmatched_randperm(size)
    return torch.LongTensor(l).to(device)


def copy_G_params(model):
    """Return a deep-copied list of the ``.data`` tensors of ``model``'s parameters, in ``parameters()`` order."""
    snapshot = [param.data for param in model.parameters()]
    return deepcopy(snapshot)


def load_params(model, new_param):
    """Copy tensors from ``new_param`` into ``model``'s parameters in place, pairing them in ``parameters()`` order.

    Writes go through ``.data`` (bypassing autograd); pairing stops at the shorter of the two sequences.
    """
    for target, source in zip(model.parameters(), new_param):
        target.data.copy_(source)


def make_folders(save_folder, trial_name):
    """Create ``<save_folder>/train_results/<trial_name>/{images,models}`` and snapshot the cwd's ``*.py`` files into the trial folder.

    Folders are created with parents (``exist_ok=True``); the old ``os.mkdir`` chain
    failed when ``save_folder`` was missing. The source snapshot is best-effort:
    filesystem errors are ignored, other exceptions propagate.

    Returns:
        (saved_image_folder, saved_model_folder) — note: the reverse order of ``creat_folder``.
    """
    trial_folder = os.path.join(save_folder, 'train_results/%s' % trial_name)
    saved_model_folder = os.path.join(save_folder, 'train_results/%s/models' % trial_name)
    saved_image_folder = os.path.join(save_folder, 'train_results/%s/images' % trial_name)
    for folder in (saved_image_folder, saved_model_folder):
        os.makedirs(folder, exist_ok=True)

    from shutil import copy
    try:
        for f in os.listdir('.'):
            # endswith: the old substring test also matched '.pyc', '.pyw', 'x.py.bak', ...
            if f.endswith('.py'):
                copy(f, os.path.join(trial_folder, f))
    except OSError:
        pass
    return saved_image_folder, saved_model_folder



import cv2
import numpy as np
import math

def warp(img, mag=10, freq=100):
    """Apply a sinusoidal displacement warp to a 2-D (single-channel) image.

    Output pixel ``(i, j)`` is read from source pixel ``(i + dy[j], j + dx[i])``,
    where ``dx[i] = int(mag * sin(2*3.14*i/freq))`` and
    ``dy[j] = int(mag * cos(2*3.14*j/freq))`` (``int`` truncates toward zero).
    A source coordinate past the bottom/right edge produces 0. A negative source
    coordinate wraps around (Python modulo).

    Args:
        img: 2-D numpy array ``(rows, cols)``; multi-channel arrays are rejected
            by the shape unpacking.
        mag: displacement amplitude in pixels.
        freq: period of the sinusoid in pixels.

    Returns:
        A new array with the same shape and dtype as ``img``.
    """
    rows, cols = img.shape
    # Offsets use math.sin/cos once per row/column (rows + cols calls), so the
    # result is bit-identical to the per-pixel reference; only the pixel
    # gather is vectorised. 3.14 (not math.pi) is kept to reproduce existing outputs.
    offset_x = np.array([int(mag * math.sin(2 * 3.14 * i / freq)) for i in range(rows)], dtype=np.int64)
    offset_y = np.array([int(mag * math.cos(2 * 3.14 * j / freq)) for j in range(cols)], dtype=np.int64)

    src_rows = np.arange(rows, dtype=np.int64)[:, None] + offset_y[None, :]
    src_cols = np.arange(cols, dtype=np.int64)[None, :] + offset_x[:, None]
    valid = (src_rows < rows) & (src_cols < cols)

    img_output = np.zeros(img.shape, dtype=img.dtype)
    # numpy's % on ints follows Python semantics, so negative indices wrap.
    img_output[valid] = img[src_rows[valid] % rows, src_cols[valid] % cols]
    return img_output

In [None]:
import datetime

# ---- Experiment configuration -------------------------------------------
# Dataset tag; other cells branch on it (e.g. 'shoe' disables style paths).
DATA_NAME = 'face'

DATALOADER_WORKERS = 4
NBR_CLS = 25                 # window size / pseudo-class count for the AE dataset

EPOCH_GAN = 20

# Logging / checkpoint cadence (units are defined by the training loop).
SAVE_IMAGE_INTERVAL = 10
SAVE_MODEL_INTERVAL = 25
LOG_INTERVAL = 10
FID_INTERVAL = 25
FID_BATCH_NBR = 10

ITERATION_AE = 250

NFC=32                       # base channel width of the networks
MULTI_GPU = False

IM_SIZE_GAN = 1024
BATCH_SIZE_GAN = 4

IM_SIZE_AE = 512
BATCH_SIZE_AE = 8

ct = datetime.datetime.now()
TRIAL_NAME = 'trial-pr-%s-%d-%d-%d-%d'%(DATA_NAME, ct.month, ct.day, ct.hour, ct.minute)

# Single place to change when the project folder moves on Drive.
PROJECT_ROOT = '/content/My_drive/MyDrive/rajat-Inspiron-3576/SEM7/DEEP_LEARNING/Project/Files'
SAVE_FOLDER = PROJECT_ROOT + '/saved'

PRETRAINED_AE_PATH = None
PRETRAINED_AE_ITER = 12000

GAN_CKECKPOINT =None         # (sic) name kept: other cells may reference it

TRAIN_AE_ONLY = False
TRAIN_GAN_ONLY = False

data_root_colorful = PROJECT_ROOT + '/images/rgb_images'
data_root_sketch_1 = PROJECT_ROOT + '/sketch/sketch-1'
data_root_sketch_2 = PROJECT_ROOT + '/sketch/sketch-2'
data_root_sketch_3 = SAVE_FOLDER + '/train_results/synthesized_sketches'

In [None]:
import torchvision.transforms.functional as F

In [None]:
# Display the torchvision functional module; its crop(img, top, left, height, width) is what the datasets call as Fff.crop.
F

<module 'torchvision.transforms.functional' from '/usr/local/lib/python3.7/dist-packages/torchvision/transforms/functional.py'>

In [None]:
import os
import random
import numpy as np
from PIL import Image, ImageFilter
from PIL import ImageFile
# Let PIL decode truncated/incomplete image files instead of raising while loading.
ImageFile.LOAD_TRUNCATED_IMAGES = True

import torch
from torch.utils.data import Dataset
from torchvision import transforms
import torchvision.transforms.functional as Fff
import torch.utils.data as data

# Standard ImageNet per-channel mean/std normalisation.
# NOTE(review): these statistics assume inputs in [0, 1], while the transforms
# in this cell rescale to [-1, 1]; check before combining them.
normalize_imagenet = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

def _noise_adder(img):
    return torch.empty_like(img, dtype=img.dtype).uniform_(0.0, 1/128.0) + img


def _rescale(img):
    return img * 2.0 - 1.0


def trans_maker(size=512):
    """Deterministic pipeline: resize to ``size`` x ``size``, convert to tensor, map to [-1, 1]."""
    steps = [
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        _rescale,
    ]
    return transforms.Compose(steps)


def trans_maker_augment(size=256):
    """Augmenting pipeline.

    Resizes to 1.1x ``size``, applies a random horizontal flip and a random
    ``size`` x ``size`` crop, converts to a tensor and maps it to [-1, 1].
    """
    enlarged = int(size * 1.1)
    steps = [
        transforms.Resize((enlarged, enlarged)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop((size, size)),
        transforms.ToTensor(),
        _rescale,
    ]
    return transforms.Compose(steps)


class SelfSupervisedDataset(Dataset):
    """Self-supervised (colour image, sketch) dataset served in fixed-size windows.

    Colour images in ``data_root`` are paired with the sketch of the same file
    name in ``data_root_2``. The shuffled pair list is exposed ``nbr_cls`` items
    at a time, starting at ``set_offset``. ``_next_set`` moves the window and
    reshuffles once the end is reached. Each sample also returns its index within
    the current window (presumably used by the caller as a pseudo class label).

    Args:
        data_root: directory with colour images.
        data_root_2: directory with sketches (same file names as ``data_root``).
        im_size: side length of every returned image, in pixels.
        nbr_cls: window size (number of items exposed per set).
        rand_crop: if True, resize to 1.1x and take one random crop shared by the
            image and its sketches; if False, resize straight to ``im_size``.
    """

    def __init__(self, data_root, data_root_2, im_size=512, nbr_cls=2000, rand_crop=True):
        super(SelfSupervisedDataset, self).__init__()
        self.root = data_root
        self.skt_root = data_root_2

        self.frame = self._parse_frame()
        random.shuffle(self.frame)

        self.nbr_cls = nbr_cls
        self.set_offset = 0

        self.im_size = im_size
        # Strongly augmented colour-only view (random crop/rotation independent of the sketch).
        self.transform_rd = transforms.Compose([
                            transforms.Resize((int(im_size*1.3), int(im_size*1.3))),
                            transforms.RandomCrop( (int(im_size), int(im_size)) ),
                            transforms.RandomRotation( 30 ),
                            transforms.RandomHorizontalFlip(),
                            transforms.ToTensor(),
                            _rescale])

        self.crop = rand_crop
        if self.crop:
            self.transform_1 = transforms.Resize((int(im_size*1.1), int(im_size*1.1)))
            self.transform_2 = transforms.Compose([ transforms.ToTensor(),
                                                    _rescale
                                                    ])
            self.rand_range = int(self.im_size * 0.1)
        else:
            self.transform_normal = trans_maker(size=im_size)

        self.transform_flip = transforms.RandomHorizontalFlip(p=1)

        # Two independent erasing passes; erased patches are filled with 1,
        # the maximum of the [-1, 1] range produced by _rescale.
        self.transform_erase = transforms.Compose([
                        transforms.RandomErasing(p=0.8, scale=(0.02, 0.1), value=1),
                        transforms.RandomErasing(p=0.8, scale=(0.02, 0.1), value=1),
                                ])

    def _parse_frame(self):
        """List ``(image_path, sketch_path)`` pairs, sorted by name, for names present in both roots."""
        frame = []
        for img_name in sorted(os.listdir(self.root)):
            image_path = os.path.join(self.root, img_name)
            skt_path = os.path.join(self.skt_root, img_name)
            if os.path.exists(image_path) and os.path.exists(skt_path):
                frame.append((image_path, skt_path))
        return frame

    def __len__(self):
        # Never advertise more items than exist; otherwise __getitem__ would
        # index past the end of ``frame`` when fewer than ``nbr_cls`` pairs exist.
        return min(self.nbr_cls, len(self.frame))

    def _next_set(self):
        """Advance the window by ``nbr_cls``; reshuffle and restart when it would overrun."""
        self.set_offset += self.nbr_cls
        if self.set_offset > ( len(self.frame) - self.nbr_cls ):
            random.shuffle(self.frame)
            self.set_offset = 0

    def __getitem__(self, idx):
        """Return ``(img_rd, img, skt, skt_bold, skt_erased, skt_erased_bold, idx)`` tensors in [-1, 1]."""
        file, skt_path = self.frame[idx + self.set_offset]
        img = Image.open(file).convert('RGB')
        skt = Image.open(skt_path).convert('L')

        # MinFilter spreads dark pixels, i.e. thickens dark strokes.
        bold_factor = 3
        skt_bold = skt.filter(ImageFilter.MinFilter(size=bold_factor))

        # Flip the image and both sketches together so they stay aligned.
        if random.randint(0, 1) == 1:
            img, skt, skt_bold = [self.transform_flip(x) for x in (img, skt, skt_bold)]

        img_rd = self.transform_rd(img)

        if self.crop:
            img_normal, skt_normal, skt_bold = [self.transform_1(x) for x in (img, skt, skt_bold)]
            # One shared random offset keeps the image and sketches pixel-aligned.
            i = random.randint(0, self.rand_range)
            j = random.randint(0, self.rand_range)
            img_normal, skt_normal, skt_bold = [
                self.transform_2(Fff.crop(x, i, j, self.im_size, self.im_size))
                for x in (img_normal, skt_normal, skt_bold)
            ]
        else:
            img_normal, skt_normal, skt_bold = [self.transform_normal(x) for x in (img, skt, skt_bold)]

        skt_erased = self.transform_erase(skt_normal)
        skt_erased_bold = self.transform_erase(skt_bold)
        return img_rd, img_normal, skt_normal, skt_bold, skt_erased, skt_erased_bold, idx


class PairedMultiDataset(Dataset):
    """Aligned quadruples: one colour image plus three grayscale sketches.

    Files are matched by identical file name across the four roots. Names
    missing from any sketch root are skipped, with a message listing what is
    missing.

    Args:
        data_root_1: colour images (loaded as RGB).
        data_root_2, data_root_3, data_root_4: sketch folders (loaded as 'L').
        rand_crop: if True, resize to 1.1x and take one random crop shared by
            all four images; if False, resize straight to ``im_size``.
        im_size: output side length in pixels.
    """

    def __init__(self, data_root_1, data_root_2, data_root_3, data_root_4, rand_crop=True, im_size=512):
        super(PairedMultiDataset, self).__init__()
        self.root_a = data_root_1
        self.root_b = data_root_2
        self.root_c = data_root_3
        self.root_d = data_root_4

        self.frame = self._parse_frame()

        self.crop = rand_crop
        self.im_size = im_size
        if self.crop:
            self.transform_1 = transforms.Resize((int(im_size*1.1), int(im_size*1.1)))
            self.transform_2 = transforms.Compose([ transforms.ToTensor(),
                                                    _rescale
                                                    ])
            self.rand_range = int(self.im_size * 0.1)
        else:
            self.transform = trans_maker( int( im_size ) )

    def _parse_frame(self):
        """List ``(a, b, c, d)`` path tuples, sorted by file name, present in all four roots."""
        frame = []
        for img_name in sorted(os.listdir(self.root_a)):
            image_a_path = os.path.join(self.root_a, img_name)
            partner_paths = [os.path.join(root, img_name)
                             for root in (self.root_b, self.root_c, self.root_d)]
            missing = [path for path in partner_paths if not os.path.exists(path)]
            if missing:
                print('PairedMultiDataset: skipping %s, missing %s' % (image_a_path, ', '.join(missing)))
                continue
            frame.append((image_a_path, *partner_paths))
        return frame

    def __len__(self):
        return len(self.frame)

    def __getitem__(self, idx):
        """Return ``(img_a, img_b, img_c, img_d)`` as tensors in [-1, 1]."""
        file_a, *sketch_files = self.frame[idx]
        images = [Image.open(file_a).convert('RGB')]
        images += [Image.open(path).convert('L') for path in sketch_files]

        if self.crop:
            images = [self.transform_1(im) for im in images]
            # One shared offset keeps all four images pixel-aligned.
            i = random.randint(0, self.rand_range)
            j = random.randint(0, self.rand_range)
            images = [self.transform_2(Fff.crop(im, i, j, self.im_size, self.im_size))
                      for im in images]
        else:
            images = [self.transform(im) for im in images]

        return tuple(images)


class PairedDataset(Dataset):
    """(colour image, sketch) pairs; the image is loaded as RGB, the sketch as grayscale.

    Colour images are the ``.jpg``/``.png`` entries of ``data_root_1``, in sorted
    order. For each one, the sketch in ``data_root_2`` with the same file name is
    used when it exists. Otherwise, unless ``DATA_NAME == 'shoe'``, the legacy
    rule applies: the i-th sorted directory entry is paired with ``'<i>.jpg'``.
    Entries with no matching sketch are skipped.

    Args:
        data_root_1: directory with colour images.
        data_root_2: directory with sketches.
        transform: callable applied to both images; skipped if falsy.
    """

    def __init__(self, data_root_1, data_root_2, transform=trans_maker(512)):
        super(PairedDataset, self).__init__()
        self.root_a = data_root_1
        self.root_b = data_root_2

        self.frame = self._parse_frame()
        self.transform = transform

    def _parse_frame(self):
        frame = []
        img_names = sorted(os.listdir(self.root_a))
        for i, img_name in enumerate(img_names):
            # Check the file name only, so a directory path containing '.jpg' cannot match.
            if not (('.jpg' in img_name) or ('.png' in img_name)):
                continue
            image_a_path = os.path.join(self.root_a, img_name)
            # Same-name pairing first: the index rule mis-pairs numeric names,
            # because lexicographic sorting puts '10.jpg' before '2.jpg'.
            image_b_path = os.path.join(self.root_b, img_name)
            if not os.path.exists(image_b_path):
                if DATA_NAME == 'shoe':
                    continue
                image_b_path = os.path.join(self.root_b, '%d.jpg' % i)
                if not os.path.exists(image_b_path):
                    continue
            frame.append((image_a_path, image_b_path))
        return frame

    def __len__(self):
        return len(self.frame)

    def __getitem__(self, idx):
        file_a, file_b = self.frame[idx]
        img_a = Image.open(file_a).convert('RGB')
        img_b = Image.open(file_b).convert('L')

        if self.transform:
            img_a = self.transform(img_a)
            img_b = self.transform(img_b)

        return (img_a, img_b)


class ImageFolder(Dataset):
    """Flat folder of images.

    Lists every entry whose path contains ``.jpg`` or ``.png``, sorted by file
    name, and loads each one as RGB, applying ``transform`` when it is truthy.
    """

    def __init__(self, data_root, transform=trans_maker(512)):
        super().__init__()
        self.root = data_root

        self.frame = self._parse_frame()
        self.transform = transform

    def _parse_frame(self):
        candidates = (os.path.join(self.root, name) for name in sorted(os.listdir(self.root)))
        return [path for path in candidates if '.jpg' in path or '.png' in path]

    def __len__(self):
        return len(self.frame)

    def __getitem__(self, idx):
        img = Image.open(self.frame[idx]).convert('RGB')
        return self.transform(img) if self.transform else img




def InfiniteSampler(n):
    """Yield indices in ``range(n)`` forever, reshuffling on every pass.

    The very first value is the last element of an initial permutation. After
    that, each complete pass is a fresh permutation of ``range(n)``. Before every
    new pass the global NumPy RNG is re-seeded from OS entropy
    (``np.random.seed()``). The stream is therefore not reproducible, and the
    reseed also affects other users of ``np.random``.
    """
    initial_order = np.random.permutation(n)
    yield initial_order[n - 1]
    while True:
        np.random.seed()
        yield from np.random.permutation(n)


class InfiniteSamplerWrapper(data.sampler.Sampler):
    """DataLoader sampler that streams indices of ``data_source`` forever.

    Iteration delegates to ``InfiniteSampler``. ``__len__`` reports 2**31 as a
    stand-in for "infinite".
    """

    def __init__(self, data_source):
        self.num_samples = len(data_source)

    def __iter__(self):
        index_stream = InfiniteSampler(self.num_samples)
        return iter(index_stream)

    def __len__(self):
        return 1 << 31

In [None]:
!pip uninstall torch_dwconv

Found existing installation: torch-dwconv 0.1.0
Uninstalling torch-dwconv-0.1.0:
  Would remove:
    /usr/local/lib/python3.7/dist-packages/torch_dwconv-0.1.0.dist-info/*
    /usr/local/lib/python3.7/dist-packages/torch_dwconv/*
Proceed (y/n)? y
  Successfully uninstalled torch-dwconv-0.1.0


In [None]:
!pip install torch_dwconv

Collecting torch_dwconv
  Downloading torch_dwconv-0.1.0.tar.gz (249 kB)
[?25l[K     |█▎                              | 10 kB 17.2 MB/s eta 0:00:01[K     |██▋                             | 20 kB 17.5 MB/s eta 0:00:01[K     |████                            | 30 kB 19.9 MB/s eta 0:00:01[K     |█████▎                          | 40 kB 21.1 MB/s eta 0:00:01[K     |██████▌                         | 51 kB 16.3 MB/s eta 0:00:01[K     |███████▉                        | 61 kB 11.3 MB/s eta 0:00:01[K     |█████████▏                      | 71 kB 11.3 MB/s eta 0:00:01[K     |██████████▌                     | 81 kB 12.2 MB/s eta 0:00:01[K     |███████████▉                    | 92 kB 12.0 MB/s eta 0:00:01[K     |█████████████                   | 102 kB 12.8 MB/s eta 0:00:01[K     |██████████████▍                 | 112 kB 12.8 MB/s eta 0:00:01[K     |███████████████▊                | 122 kB 12.8 MB/s eta 0:00:01[K     |█████████████████               | 133 kB 12.8 MB/s eta 0

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from torch_dwconv import DepthwiseConv2d

import math
import random


def weights_init(m):
    """DCGAN-style initialisation, intended for ``model.apply(weights_init)``.

    * class name contains ``'Conv'``: ``weight ~ N(0, 0.02)``
    * class name contains ``'BatchNorm'``: ``weight ~ N(1, 0.02)``, ``bias = 0``

    A module without the relevant tensor (e.g. ``BatchNorm2d(affine=False)``, or
    a container whose name merely contains 'Conv') is skipped explicitly. This
    replaces a bare ``except`` that could hide unrelated errors.

    NOTE(review): for modules wrapped by ``spectral_norm`` the trainable tensor
    is ``weight_orig``; ``weight`` is recomputed from it, so this init likely
    has no lasting effect there. Confirm whether that is intended.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    if 'Conv' in classname:
        if isinstance(weight, torch.Tensor):
            weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        if isinstance(weight, torch.Tensor):
            weight.data.normal_(1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if isinstance(bias, torch.Tensor):
            bias.data.fill_(0)


class DMI(nn.Module):
    """Mask-conditioned affine modulation of a feature map.

    ``out = weight_a * feat * mask + bias_a + weight_b * feat * (1 - mask) + bias_b``

    ``weight_*`` and ``bias_*`` are learnable per-channel parameters of shape
    ``(1, in_channels, 1, 1)``. ``mask`` is resized (nearest) to ``feat``'s full
    spatial size, which may be non-square. If the mask has fewer channels, it
    is tiled along the channel axis first.
    """

    def __init__(self, in_channels):
        super().__init__()

        # Slightly asymmetric initialisations keep the two branches distinguishable.
        self.weight_a = nn.Parameter(torch.ones(1, in_channels, 1, 1)*1.01)
        self.weight_b = nn.Parameter(torch.ones(1, in_channels, 1, 1)*0.99)

        self.bias_a = nn.Parameter(torch.zeros(1, in_channels, 1, 1)+0.01)
        self.bias_b = nn.Parameter(torch.zeros(1, in_channels, 1, 1)-0.01)

    def forward(self, feat, mask):
        if feat.shape[1] > mask.shape[1]:
            # Tile mask channels; assumes feat's channel count is a multiple of the mask's.
            channel_scale = feat.shape[1] // mask.shape[1]
            mask = mask.repeat(1, channel_scale, 1, 1)

        # Match (H, W) exactly; an int size would force a square mask.
        mask = F.interpolate(mask, size=tuple(feat.shape[2:]))
        feat_a = self.weight_a * feat * mask + self.bias_a
        feat_b = self.weight_b * feat * (1-mask) + self.bias_b
        return feat_a + feat_b


class Swish(nn.Module):
    """Swish / SiLU activation: ``x * sigmoid(x)``."""

    def forward(self, feat):
        gate = torch.sigmoid(feat)
        return feat * gate


class Squeeze(nn.Module):
    """Drop the last dimension twice when it has size 1, e.g. ``(B, C, 1, 1) -> (B, C)``."""

    def forward(self, feat):
        for _ in range(2):
            feat = feat.squeeze(-1)
        return feat


class UnSqueeze(nn.Module):
    """Append two trailing singleton dims, e.g. ``(B, C) -> (B, C, 1, 1)``."""

    def forward(self, feat):
        return feat[..., None, None]


class GLU(nn.Module):
    """Gated linear unit over the channel axis (dim 1).

    The channels are split in half into ``(value, gate)`` and the result is
    ``value * sigmoid(gate)``, which halves the channel count. Odd channel counts
    fail an assertion.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        n_channels = x.size(1)
        assert n_channels % 2 == 0, 'channels dont divide 2!'
        half = n_channels // 2
        value, gate = x[:, :half], x[:, half:]
        return value * torch.sigmoid(gate)


class NoiseInjection(nn.Module):
    """Add per-pixel Gaussian noise scaled by a learnable per-channel weight.

    ``out = feat + weight * noise`` where ``weight`` has shape ``(1, ch, 1, 1)``
    and starts at zero, so the layer is initially the identity. If ``noise`` is
    not given, it is sampled as ``(B, 1, H, W)`` standard normal, shared across
    channels.
    """

    def __init__(self, ch):
        super().__init__()

        self.weight = nn.Parameter(torch.zeros(1, ch, 1, 1), requires_grad=True)

    def forward(self, feat, noise=None):
        if noise is None:
            batch, _, height, width = feat.shape
            # Sample directly on feat's device: avoids a host->device copy every call.
            noise = torch.randn(batch, 1, height, width, device=feat.device)

        return feat + self.weight * noise


class SELayer(nn.Module):
    """Squeeze-and-excitation channel attention.

    Each channel is global-average-pooled. The resulting ``(B, C)`` vector goes
    through a bottleneck MLP (C -> C // reduction -> C, spectral-normalised
    linears, ReLU, then sigmoid). Every input channel is then rescaled by its
    gate in (0, 1).
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            spectral_norm(nn.Linear(channel, channel // reduction, bias=False)),
            nn.ReLU(inplace=True),
            spectral_norm(nn.Linear(channel // reduction, channel, bias=False)),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        pooled = self.avg_pool(x).view(b, c)
        gates = self.fc(pooled).view(b, c, 1, 1)
        return x * gates.expand_as(x)


class ResBlkG(nn.Module):
    """Inverted-bottleneck residual block used by the decoder/generators.

    Returns ``feat + main(feat)``. ``main`` is: BN -> 1x1 conv expanding to
    ``ch*ch_m`` channels -> BN -> Swish -> 5x5 depthwise conv -> BN -> Swish ->
    1x1 conv back to ``ch`` -> BN -> squeeze-excitation. The channel count is
    preserved. The spatial size is presumably preserved too, assuming
    ``DepthwiseConv2d`` takes (in, out, kernel, stride, padding); TODO confirm.

    NOTE(review): ``spectral_norm`` also wraps the BatchNorm layers, i.e. it
    normalises their 1-D affine weight, which is unusual; confirm this is intended.
    The layer order defines the checkpoint keys (``main.<idx>.*``); do not reorder.
    """
    def __init__(self, ch, ch_m=4):
        super().__init__()
        self.main = nn.Sequential(spectral_norm( nn.BatchNorm2d(ch) ),
                            spectral_norm( nn.Conv2d(ch, ch*ch_m, 1, 1, 0, bias=False) ),
                            spectral_norm( nn.BatchNorm2d(ch*ch_m) ), Swish(),
                            spectral_norm( DepthwiseConv2d(ch*ch_m, ch*ch_m, 5, 1, 2) ),
                            spectral_norm( nn.BatchNorm2d(ch*ch_m) ), Swish(),
                            spectral_norm( nn.Conv2d(ch*ch_m, ch, 1, 1, 0, bias=False) ),
                            spectral_norm( nn.BatchNorm2d(ch) ),
                            SELayer(ch))
    def forward(self, feat):
        # Residual connection: main() must return feat's shape.
        return feat + self.main(feat)


class ResBlkE(nn.Module):
    """Pre-activation residual block used by the style and content encoders.

    Returns ``feat + main(feat)``. ``main`` is: BN -> Swish -> 3x3 conv -> BN ->
    Swish -> 3x3 conv -> squeeze-excitation. The shape is preserved (stride 1,
    padding 1, ``ch`` -> ``ch``).

    NOTE(review): ``spectral_norm`` is also applied to the BatchNorm layers
    (normalising their 1-D affine weight), which is unusual; confirm it is intended.
    The layer order defines the checkpoint keys (``main.<idx>.*``); do not reorder.
    """
    def __init__(self, ch):
        super().__init__()

        self.main = nn.Sequential(
                            spectral_norm( nn.BatchNorm2d(ch) ), Swish(),
                            spectral_norm( nn.Conv2d(ch, ch, 3, 1, 1, bias=False) ),
                            spectral_norm( nn.BatchNorm2d(ch) ), Swish(),
                            spectral_norm( nn.Conv2d(ch, ch, 3, 1, 1, bias=False) ),
                            SELayer(ch))

    def forward(self, feat):
        return feat + self.main(feat)


class StyleEncoder(nn.Module):
    """Encodes an RGB image into three multi-scale style vectors (+ class logits).

    Six stride-2 4x4 convs halve the resolution each time. For a 512x512 input
    (as resized by ``AE.forward``) the ``sf_32``/``sf_16``/``sf_8`` maps are
    32x32, 16x16 and 8x8. Each of these maps is average-pooled to 4x4 and reduced
    by a 4x4 valid conv plus ``Squeeze`` to a vector (``sfv_*``). The vectors have
    shapes ``(B, nfc*2)``, ``(B, nfc*4)`` and ``(B, nfc*8)``.

    ``final_cls`` is a linear classifier over the ``nfc*8`` vector with
    ``nbr_cls`` outputs. It is created lazily by ``reset_cls``: ``forward``
    needs it, while ``get_feats`` does not.
    """
    def __init__(self, nfc=64, nbr_cls=500):
        super().__init__()

        self.nfc = nfc

        self.sf_256 = nn.Sequential(nn.Conv2d(3, nfc//4, 4, 2, 1, bias=False),nn.LeakyReLU(0.2,inplace=True))
        self.sf_128 = nn.Sequential(nn.Conv2d(nfc//4, nfc//2, 4, 2, 1, bias=False),nn.BatchNorm2d(nfc//2),nn.LeakyReLU(0.1,inplace=True)) 
        self.sf_64 = nn.Sequential(nn.Conv2d(nfc//2, nfc, 4, 2, 1, bias=False),nn.BatchNorm2d(nfc),nn.LeakyReLU(0.1,inplace=True)) 
        
        self.sf_32 = nn.Sequential(nn.Conv2d(nfc, nfc*2, 4, 2, 1, bias=False), ResBlkE(nfc*2))
        # NOTE(review): sf_16 / sf_8 start with an *in-place* LeakyReLU, so running
        # them also activates their input tensor (feat_32 / feat_16) in place.
        self.sf_16 = nn.Sequential(nn.LeakyReLU(0.1,inplace=True), nn.Conv2d(nfc*2, nfc*4, 4, 2, 1, bias=False), ResBlkE(nfc*4))
        self.sf_8 = nn.Sequential(nn.LeakyReLU(0.1,inplace=True), nn.Conv2d(nfc*4, nfc*8, 4, 2, 1, bias=False), ResBlkE(nfc*8))
        
        # Map -> vector heads: pool to 4x4, 4x4 valid conv -> 1x1, squeeze to (B, C).
        self.sfv_32 = nn.Sequential( nn.AdaptiveAvgPool2d(output_size=4), nn.Conv2d(nfc*2, nfc*2, 4, 1, 0, bias=False), Squeeze() )
        self.sfv_16 = nn.Sequential( nn.AdaptiveAvgPool2d(output_size=4), nn.Conv2d(nfc*4, nfc*4, 4, 1, 0, bias=False), Squeeze() )
        self.sfv_8 = nn.Sequential( nn.AdaptiveAvgPool2d(output_size=4), nn.Conv2d(nfc*8, nfc*8, 4, 1, 0, bias=False), Squeeze() )

        self.nbr_cls = nbr_cls
        self.final_cls = None

    def reset_cls(self):
        """Create the classifier head on first call; (re)initialise its weights every call."""
        # NOTE(review): a head created here lives on the CPU even if the model was
        # already moved to a GPU; the caller must move it.
        if self.final_cls is None:
            self.final_cls = nn.Sequential(nn.LeakyReLU(0.1), nn.Linear(self.nfc*8, self.nbr_cls))
        # Uniform(+-1/sqrt(fan_in)) for the weight; a 10x narrower range for the bias.
        stdv = 1. / math.sqrt(self.final_cls[1].weight.size(1))
        self.final_cls[1].weight.data.uniform_(-stdv, stdv)
        if self.final_cls[1].bias is not None:
            self.final_cls[1].bias.data.uniform_(-0.1*stdv, 0.1*stdv)

    def get_feats(self, image):
        """Return the three style vectors ``(v_32, v_16, v_8)`` for ``image``."""
        feat = self.sf_256(image)
        feat = self.sf_128(feat)
        feat = self.sf_64(feat)
        feat_32 = self.sf_32(feat)
        feat_16 = self.sf_16(feat_32)
        feat_8 = self.sf_8(feat_16)
        
        # feat_32 / feat_16 were activated in place by sf_16 / sf_8 above.
        feat_32 = self.sfv_32(feat_32)
        feat_16 = self.sfv_16(feat_16)
        feat_8 = self.sfv_8(feat_8)

        return feat_32, feat_16, feat_8

    def forward(self, image):
        """Return ``([v_32, v_16, v_8], logits)``; requires ``reset_cls()`` to have been called."""
        feat_32, feat_16, feat_8 = self.get_feats(image)

        pred_cls = self.final_cls(feat_8)
        return [feat_32, feat_16, feat_8], pred_cls


class ContentEncoder(nn.Module):
    """Encodes a 1-channel sketch into three content feature maps.

    Six stride-2 4x4 convs halve the resolution each time. For a 512x512 input
    it returns maps of shape ``(B, nfc*2, 32, 32)``, ``(B, nfc*4, 16, 16)`` and
    ``(B, nfc*8, 8, 8)``.

    NOTE(review): ``cf_16`` and ``cf_8`` start with an in-place LeakyReLU, so the
    returned ``feat_32`` / ``feat_16`` are the *activated* tensors (modified in
    place by the next stage). Changing this would alter trained models.
    """
    def __init__(self, nfc=64):
        super().__init__()

        self.cf_256 = nn.Sequential(nn.Conv2d(1, nfc//4, 4, 2, 1, bias=False),nn.LeakyReLU(0.2,inplace=True))
        self.cf_128 = nn.Sequential(nn.Conv2d(nfc//4, nfc//2, 4, 2, 1, bias=False),nn.BatchNorm2d( nfc//2),nn.LeakyReLU(0.1,inplace=True)) 
        self.cf_64 = nn.Sequential(nn.Conv2d( nfc//2, nfc, 4, 2, 1, bias=False),nn.BatchNorm2d(nfc),nn.LeakyReLU(0.1,inplace=True)) 
        
        self.cf_32 = nn.Sequential(nn.Conv2d(nfc, nfc*2, 4, 2, 1, bias=False), ResBlkE(nfc*2))
        self.cf_16 = nn.Sequential(nn.LeakyReLU(0.1,inplace=True), nn.Conv2d(nfc*2, nfc*4, 4, 2, 1, bias=False), ResBlkE(nfc*4))
        self.cf_8 = nn.Sequential(nn.LeakyReLU(0.1,inplace=True), nn.Conv2d(nfc*4, nfc*8, 4, 2, 1, bias=False), ResBlkE(nfc*8))
        
    def get_feats(self, image):
        """Return ``(feat_32, feat_16, feat_8)`` feature maps for a sketch batch."""
        feat = self.cf_256(image)
        feat = self.cf_128(feat)
        feat = self.cf_64(feat)
        feat_32 = self.cf_32(feat)
        feat_16 = self.cf_16(feat_32)
        feat_8 = self.cf_8(feat_16)

        return feat_32, feat_16, feat_8

    def forward(self, image):
        """Same as ``get_feats`` but returns a list, as ``Decoder.forward`` indexes it."""
        feat_32, feat_16, feat_8 = self.get_feats(image)
        return [feat_32, feat_16, feat_8]


def up_decoder(ch_in, ch_out):
    """2x upsampling stage for the decoder.

    The stage is: nearest 2x upsample -> 3x3 conv to ``2*ch_out`` ->
    InstanceNorm -> GLU, which halves the channels back to ``ch_out``. The
    Sequential indices (0..3) are checkpoint keys.
    """
    layers = [
        nn.UpsamplingNearest2d(scale_factor=2),
        nn.Conv2d(ch_in, ch_out * 2, 3, 1, 1, bias=False),
        nn.InstanceNorm2d(ch_out * 2),
        GLU(),
    ]
    return nn.Sequential(*layers)


class Decoder(nn.Module):
    """Decodes content feature maps + style vectors into a Tanh RGB image.

    Args:
        nfc: base channel width; must match the encoders' ``nfc``.

    ``forward(content_feats, style_vectors)``:
        content_feats: ``[feat_32, feat_16, feat_8]`` with ``nfc*2``, ``nfc*4``
            and ``nfc*8`` channels (as produced by ``ContentEncoder``).
        style_vectors: ``[v_32, v_16, v_8]`` with ``nfc*2``, ``nfc*4`` and
            ``nfc*8`` features (as produced by ``StyleEncoder.get_feats``).
        Returns a 3-channel image at 64x the spatial size of ``feat_8``
        (512x512 for an 8x8 ``feat_8``).
    """
    def __init__(self, nfc=64):
        super().__init__()
 
        self.base_feat = nn.Parameter(torch.randn(1, nfc*8, 8, 8).normal_(0, 1), requires_grad=True)
        
        self.dmi_8 = DMI(nfc*8)
        self.dmi_16 = DMI(nfc*4)

        self.feat_8_1 = nn.Sequential( ResBlkG(nfc*16), nn.LeakyReLU(0.1,inplace=True), nn.Conv2d(nfc*16, nfc*8, 3, 1, 1, bias=False), nn.InstanceNorm2d(nfc*8) )
        self.feat_8_2 = nn.Sequential( nn.LeakyReLU(0.1,inplace=True), ResBlkG(nfc*8) )
        self.feat_16  = nn.Sequential( nn.LeakyReLU(0.1,inplace=True), nn.UpsamplingNearest2d(scale_factor=2), nn.Conv2d(nfc*8, nfc*4, 3, 1, 1, bias=False), ResBlkG(nfc*4) )
        self.feat_32  = nn.Sequential( nn.LeakyReLU(0.1,inplace=True), nn.UpsamplingNearest2d(scale_factor=2), nn.Conv2d(nfc*8, nfc*2, 3, 1, 1, bias=False), ResBlkG(nfc*2) )
        self.feat_64  = nn.Sequential( nn.LeakyReLU(0.1,inplace=True), up_decoder(nfc*4, nfc) ) 
        self.feat_128 = up_decoder(nfc*1, nfc//2)
        self.feat_256 = up_decoder(nfc//2, nfc//4)
        self.feat_512 = up_decoder(nfc//4, nfc//8)
        
        self.to_rgb = nn.Sequential( nn.Conv2d(nfc//8, 3, 3, 1, 1, bias=False), nn.Tanh() )
        
        self.style_8 = nn.Sequential( nn.Linear(nfc*8, nfc*8), nn.ReLU(), nn.Linear(nfc*8, nfc*8), nn.BatchNorm1d(nfc*8), UnSqueeze() )
        self.style_64 = nn.Sequential( nn.Linear(nfc*8, nfc), nn.ReLU(), nn.Linear(nfc, nfc), nn.Sigmoid() , UnSqueeze())
        self.style_128 = nn.Sequential( nn.Linear(nfc*4, nfc//2), nn.ReLU(), nn.Linear(nfc//2, nfc//2), nn.Sigmoid() , UnSqueeze())
        self.style_256 = nn.Sequential( nn.Linear(nfc*2, nfc//4), nn.ReLU(), nn.Linear(nfc//4, nfc//4), nn.Sigmoid() , UnSqueeze())

    def forward(self, content_feats, style_vectors):
        # Learned base map tiled over the batch, fused with the coarsest content map.
        base = self.base_feat.repeat(style_vectors[0].shape[0], 1, 1, 1)
        feat_8 = self.feat_8_1( torch.cat( [content_feats[2], base], dim=1 ) )
        feat_8 = self.dmi_8(feat_8, content_feats[2])

        # Channel-wise modulation by the coarsest style vector.
        feat_8 = feat_8 * self.style_8( style_vectors[2] )
        feat_8 = self.feat_8_2(feat_8)

        feat_16 = self.feat_16(feat_8)
        feat_16 = self.dmi_16(feat_16, content_feats[1])
        # Skip connections: concatenate the content map of matching resolution.
        feat_16 = torch.cat([feat_16, content_feats[1]], dim=1)

        feat_32 = self.feat_32(feat_16)
        feat_32 = torch.cat([feat_32, content_feats[0]], dim=1)

        # Per-channel sigmoid gates from the style vectors at each upsampling stage.
        feat_64 = self.feat_64(feat_32) * self.style_64(style_vectors[2])
        feat_128 = self.feat_128(feat_64) * self.style_128(style_vectors[1])
        feat_256 = self.feat_256(feat_128) * self.style_256(style_vectors[0])
        feat_512 = self.feat_512(feat_256)

        return self.to_rgb(feat_512)


class AE(nn.Module):
    """Sketch-to-image autoencoder: style encoder + content encoder + decoder.

    ``forward`` runs under ``torch.no_grad()`` and resizes both inputs to 512
    pixels before encoding, so it is meant for inference only.
    """

    def __init__(self, nfc, nbr_cls=500):
        super().__init__()

        self.style_encoder = StyleEncoder(nfc, nbr_cls=nbr_cls)
        self.content_encoder = ContentEncoder(nfc)
        self.decoder = Decoder(nfc)

    @torch.no_grad()
    def forward(self, skt_img, style_img):
        """Render ``skt_img`` in the style of ``style_img``.

        Returns:
            ``(generated_image, style_vectors)``.
        """
        style_feats = self.style_encoder.get_feats( F.interpolate(style_img, size=512) )
        content_feats = self.content_encoder( F.interpolate( skt_img , size=512) )
        gimg = self.decoder(content_feats, style_feats)
        return gimg, style_feats

    def load_state_dicts(self, path, map_location=None):
        """Load weights from a checkpoint dict with keys ``'s'``, ``'c'`` and ``'d'``.

        Args:
            path: checkpoint file written with ``torch.save``.
            map_location: forwarded to ``torch.load``. Pass e.g. ``'cpu'`` to
                load a GPU-saved checkpoint on a CPU-only machine. ``None``
                keeps the saved device placement.
        """
        ckpt = torch.load(path, map_location=map_location)
        # The classifier head is created lazily and must exist before ckpt['s'] loads.
        self.style_encoder.reset_cls()
        self.style_encoder.load_state_dict(ckpt['s'])
        self.content_encoder.load_state_dict(ckpt['c'])
        self.decoder.load_state_dict(ckpt['d'])
        print('AE load success')

def down_gan(ch_in, ch_out):
    """2x downsampling stage.

    The stage is: spectral-normalised 4x4 stride-2 conv -> BatchNorm ->
    LeakyReLU(0.2). The Sequential indices (0..2) are checkpoint keys.
    """
    layers = [
        spectral_norm(nn.Conv2d(ch_in, ch_out, 4, 2, 1, bias=False)),
        nn.BatchNorm2d(ch_out),
        nn.LeakyReLU(0.2, inplace=True),
    ]
    return nn.Sequential(*layers)


def up_gan(ch_in, ch_out):
    """2x upsampling stage for the generators.

    The stage is: nearest 2x upsample -> spectral-normalised 3x3 conv to
    ``2*ch_out`` -> BatchNorm -> noise injection -> GLU, which halves the
    channels to ``ch_out``. The Sequential indices (0..4) are checkpoint keys.
    """
    doubled = ch_out * 2
    layers = [
        nn.UpsamplingNearest2d(scale_factor=2),
        spectral_norm(nn.Conv2d(ch_in, doubled, 3, 1, 1, bias=False)),
        nn.BatchNorm2d(doubled),
        NoiseInjection(doubled),
        GLU(),
    ]
    return nn.Sequential(*layers)


def repeat_upscale(feat, scale_factor=2):
    """Enlarge the last two dims of a 4-D tensor by *tiling* it ``scale_factor`` times.

    NOTE: this tiles the whole map (``[a b] -> [a b a b]``); it is not
    nearest-neighbour upsampling (``[a b] -> [a a b b]``), despite the name.
    Existing checkpoints were trained with the tiling behaviour.
    """
    tiles = (1, 1, scale_factor, scale_factor)
    return feat.repeat(*tiles)


class RefineGenerator_art(nn.Module):
    """U-Net-like generator that re-renders ``image`` conditioned on style vectors.

    ``forward(image, style_vectors)``:
        image: 3-channel input. The five stride-2 encoders and the spatial sizes
            of ``s_16`` (16x16) and ``n_32`` (32x32) only line up for a
            512x512 input.
        style_vectors: ``[v_32, v_16, v_8]`` with ``nfc*2``, ``nfc*4`` and
            ``nfc*8`` features (``from_style`` consumes their ``nfc*14`` concatenation).
        Returns a Tanh image at the input size, or at twice the input size
        when ``im_size == 1024``.

    When ``DATA_NAME == 'shoe'``, the style map ``s_16`` is zeroed and the
    per-scale style gates are skipped.
    NOTE(review): ``bs_0``..``bs_2`` (created only for 'shoe') are never used in
    ``forward``.
    """
    def __init__(self, nfc=64, im_size=512):
        super().__init__()  

        self.im_size = im_size

        # Channel widths, named by the resolution they have for a 512x512 input.
        d16, d32, d64, d128, d256, d512 = nfc*8, nfc*8, nfc*4, nfc*2, nfc, nfc//2 

        # Noise vector -> 32x32 map with nfc channels.
        self.from_noise_32 = nn.Sequential( UnSqueeze(),
            spectral_norm(nn.ConvTranspose2d(nfc*8, nfc*8, 4, 1, 0, bias=False)), # 1x1 -> 4x4
            nn.BatchNorm2d(nfc*8), GLU(), up_gan(nfc*4, nfc*2),  up_gan(nfc*2, nfc*2), up_gan(nfc*2, nfc*1)) # -> 32x32 

        # Concatenated style vectors -> 8x8 map with nfc*4 channels (tiled to 16x16 in forward).
        self.from_style = nn.Sequential( UnSqueeze(),
            spectral_norm(nn.ConvTranspose2d(nfc*(8+4+2), nfc*16, 4, 1, 0, bias=False)), # 1x1 -> 4x4
            nn.BatchNorm2d(nfc*16), GLU(), up_gan(nfc*8, nfc*4) )
        
        self.encode_256 = nn.Sequential( spectral_norm(nn.Conv2d(3, d256, 4, 2, 1, bias=False)),nn.LeakyReLU(0.2,inplace=True))
        self.encode_128 = down_gan(d256, d128)
        self.encode_64 = down_gan(d128, d64)
        self.encode_32 = down_gan(d64, d32)
        self.encode_16 = down_gan(d32, d16)

        self.residual_16 = nn.Sequential( ResBlkG(d16+nfc*4), Swish(), ResBlkG(d16+nfc*4), Swish() )

        self.decode_32  = up_gan(d16+nfc*4, d32)
        self.decode_64  = up_gan(d32+nfc, d64) 
        self.decode_128 = up_gan(d64, d128)
        self.decode_256 = up_gan(d128, d256)
        self.decode_512 = up_gan(d256, d512)
        if im_size == 1024:
            self.decode_1024 = up_gan(d512, nfc//4)

        # Style vector -> per-channel sigmoid gates for each decoder stage.
        self.style_64  =  nn.Sequential( spectral_norm( nn.Linear(nfc*8, d64) ), nn.ReLU(), nn.Linear(d64, d64),  nn.Sigmoid(), UnSqueeze())
        self.style_128 =  nn.Sequential( spectral_norm( nn.Linear(nfc*8, d128)), nn.ReLU(), nn.Linear(d128, d128),nn.Sigmoid(), UnSqueeze())
        self.style_256 =  nn.Sequential( spectral_norm( nn.Linear(nfc*4, d256)), nn.ReLU(), nn.Linear(d256, d256),nn.Sigmoid(), UnSqueeze())
        self.style_512 =  nn.Sequential( spectral_norm( nn.Linear(nfc*2, d512)), nn.ReLU(), nn.Linear(d512, d512),nn.Sigmoid(), UnSqueeze())
        
        self.to_rgb = nn.Sequential( nn.Conv2d(nfc//2, 3, 3, 1, 1, bias=False), nn.Tanh() )
        if im_size == 1024:
            self.to_rgb = nn.Sequential( nn.Conv2d(nfc//4, 3, 3, 1, 1, bias=False), nn.Tanh() )
        
        if DATA_NAME=='shoe':
            self.bs_0 = nn.Parameter(torch.randn(1, nfc*2))
            self.bs_1 = nn.Parameter(torch.randn(1, nfc*4))
            self.bs_2 = nn.Parameter(torch.randn(1, nfc*8))

    def forward(self, image, style_vectors):
         
        # repeat_upscale tiles the 8x8 style map to 16x16 (not nearest upsampling).
        s_16 = repeat_upscale( self.from_style(torch.cat(style_vectors,1)), scale_factor=2 )
        if DATA_NAME=='shoe':  
            s_16 = torch.zeros_like(s_16)
            
        # Fresh noise each call, shaped like the coarsest style vector.
        n_32 = self.from_noise_32(torch.randn_like(style_vectors[2]))

        e_256 = self.encode_256( image )
        e_128 = self.encode_128( e_256 )
        e_64 = self.encode_64( e_128 )
        e_32 = self.encode_32( e_64 )
        e_16 = self.encode_16(e_32)

        e_16 = self.residual_16( torch.cat([e_16, s_16],dim=1) )
        
        # Decoder: additive skips from the encoder, multiplicative style gates.
        d_32 = self.decode_32( e_16 )
        d_64 = self.decode_64( torch.cat([d_32, n_32], dim=1) ) 
        if DATA_NAME!='shoe':
            d_64 *= self.style_64(style_vectors[2])
        d_128 = self.decode_128( d_64 + e_64 ) 
        if DATA_NAME!='shoe':
            d_128 *= self.style_128(style_vectors[2])
        d_256 = self.decode_256( d_128 + e_128 )
        if DATA_NAME!='shoe':
            d_256 *= self.style_256(style_vectors[1])
        d_512 = self.decode_512( d_256 + e_256 ) 
        if DATA_NAME!='shoe':
            d_512 *= self.style_512(style_vectors[0])
        
        if self.im_size == 1024:
            d_final = self.decode_1024(d_512)
        else:
            d_final = d_512
        return self.to_rgb(d_final)


class RefineGenerator_face(nn.Module):
    """U-Net-like generator for faces, conditioned on a projected style code.

    ``forward(image, style)``:
        image: 3-channel input. The four stride-2 encoders and the 32x32
            ``s_32`` map line up for a 512x512 input.
        style: list of style vectors whose concatenation has ``nfc*14``
            features (``self.style`` maps it to 512).
        Returns a Tanh image at the input size, or at twice the input size
        when ``im_size == 1024``.

    NOTE(review): ``s_32`` is randomly flipped vertically/horizontally (Python
    ``random``) on every call, including in eval mode, so outputs are not
    deterministic at inference. Confirm this is intended.
    """
    def __init__(self, nfc, im_size):
        super().__init__()  

        self.im_size = im_size

        # Encoder widths (e1 is the first/finest stage).
        e1, e2, e3, e4 = 16, 32, 64, 128
        self.encode_1 = down_gan(3, e1)      
        self.encode_2 = down_gan(e1, e2)     
        self.encode_3 = down_gan(e2, e3)     
        self.encode_4 = down_gan(e3, e4)    

        # Style code (512-d) -> 4x4 via ConvTranspose, then three 2x ups to 32x32.
        s1, s2, s3, s4 = 256, 128, 128, 64
        self.style = nn.Sequential(nn.Linear(nfc*(8+4+2), 512), nn.LeakyReLU(0.1))
        self.from_style_32 = nn.Sequential(
            spectral_norm(nn.ConvTranspose2d(512, s1, 4, 1, 0, bias=False)), 
            nn.BatchNorm2d(s1), GLU(), up_gan(s1//2, s2), up_gan(s2, s3), up_gan(s3, s4)) 

        d1, d2, d3, d4, d5 = 256, 128, 64, 32, 16
        self.decode_64 = up_gan( e4 + s4 , d1)
        self.decode_128 = up_gan(d1+e3, d2)
        self.decode_256 = up_gan(d2+e2, d3)
        self.decode_512 = up_gan(d3+e1, d4)
        if im_size == 1024:
            self.decode_1024 = up_gan(d4, d5)

        # One sigmoid channel-gate MLP per decoder stage, all fed by the 512-d style code.
        self.style_blocks = nn.ModuleList()
        
        chs = [d1, d2, d3, d4]
        if im_size == 1024:
            chs.append(d5)
        for i in range(len(chs)):
            self.style_blocks.append(nn.Sequential( 
                    nn.Linear(512, chs[i]), nn.ReLU(), nn.Linear(chs[i], chs[i]), nn.Sigmoid() ))

        self.final = nn.Sequential( spectral_norm( 
                            nn.Conv2d(d4, 3, 3, 1, 1, bias=False) ), nn.Tanh() )
        if im_size == 1024:
            self.final = nn.Sequential( spectral_norm( 
                            nn.Conv2d(d5, 3, 3, 1, 1, bias=False) ), nn.Tanh() )
        
    def forward(self, image, style):
        e_256 = self.encode_1( image )
        e_128 = self.encode_2( e_256 )
        e_64 = self.encode_3( e_128 )
        e_32 = self.encode_4( e_64 )

        style = self.style(torch.cat(style, dim=1))
        s_32 = self.from_style_32( style.unsqueeze(-1).unsqueeze(-1) )
        
        # Random flips of the style map (applied in train *and* eval mode).
        if random.randint(0, 1) == 1:
            s_32 = s_32.flip(2)
        if random.randint(0, 1) == 1:
            s_32 = s_32.flip(3)
        
        # Each stage: concat the encoder skip, upsample, then gate channels by style.
        feat_64 = self.decode_64( torch.cat([e_32, s_32], dim=1) ) * self.style_blocks[0](style).unsqueeze(-1).unsqueeze(-1)
        feat_128 = self.decode_128( torch.cat([e_64, feat_64], dim=1) ) * self.style_blocks[1](style).unsqueeze(-1).unsqueeze(-1)
        feat_256 = self.decode_256( torch.cat([e_128, feat_128], dim=1) ) * self.style_blocks[2](style).unsqueeze(-1).unsqueeze(-1)
        feat_512 = self.decode_512( torch.cat([e_256, feat_256], dim=1) ) * self.style_blocks[3](style).unsqueeze(-1).unsqueeze(-1)
        if self.im_size == 1024:
            feat_1024 = self.decode_1024( feat_512 ) * self.style_blocks[4](style).unsqueeze(-1).unsqueeze(-1)
            return self.final(feat_1024)
        else:
            return self.final(feat_512)


class DownBlock(nn.Module):
    """Stride-2 downsampling block with an optional skip-feature gate.

    ``down_main`` maps ``ch_in`` -> ``ch_out`` channels and halves the spatial
    size using two spectral-normalised 3x3 convs (stride 2, then stride 1),
    each followed by BatchNorm and LeakyReLU(0.2).

    When ``ch_skip > 0``, ``skip_conv`` pools an auxiliary feature map to 4x4
    and projects it to ``2 * ch_out`` channels. The first half becomes a
    sigmoid scale and the second half a tanh shift applied to the main output.
    """

    def __init__(self, ch_in, ch_out, ch_skip=0):
        super().__init__()
        self.ch_out = ch_out

        def conv_bn_act(c_in, stride):
            # One spectral-norm conv -> BN -> LeakyReLU triple.
            return [
                spectral_norm(nn.Conv2d(c_in, ch_out, 3, stride, 1, bias=False)),
                nn.BatchNorm2d(ch_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        self.down_main = nn.Sequential(*conv_bn_act(ch_in, 2), *conv_bn_act(ch_out, 1))

        self.skip = ch_skip > 0
        if self.skip:
            self.skip_conv = nn.Sequential(
                nn.AdaptiveAvgPool2d(4),
                spectral_norm(nn.Conv2d(ch_skip, ch_out, 4, 1, 0, bias=False)),
                nn.ReLU(),
                spectral_norm(nn.Conv2d(ch_out, ch_out * 2, 1, 1, 0, bias=False)),
            )

    def forward(self, feat, skip_feat=None):
        """Downsample ``feat``; gate it with ``skip_feat`` when both are available."""
        feat_out = self.down_main(feat)
        if not self.skip or skip_feat is None:
            return feat_out
        scale, shift = self.skip_conv(skip_feat).split(self.ch_out, dim=1)
        return feat_out * torch.sigmoid(scale) + torch.tanh(shift)


class Discriminator(nn.Module):
    """Image discriminator for 512x512 or 1024x1024 inputs.

    A stem conv is followed by stride-2 ``DownBlock`` stages. The deeper
    stages are gated by earlier, higher-resolution features (``skip_feat``).
    A 1x1 conv + BN + LeakyReLU + 4x4 conv head produces an unactivated
    score map.

    Args:
        ndf: base channel width.
        nc: number of input image channels.
        im_size: input resolution; only 512 and 1024 are supported.

    Raises:
        ValueError: if ``im_size`` is not 512 or 1024.
    """
    def __init__(self, ndf=64, nc=3, im_size=512):
        super(Discriminator, self).__init__()
        # forward() treats main[5] (and main[6] for 1024) as skip-gated
        # DownBlocks. For any other size main[5] would be the head and the
        # call would fail there, so reject the size up front.
        if im_size not in (512, 1024):
            raise ValueError(
                'Discriminator supports im_size 512 or 1024, got %r' % (im_size,))
        self.ndf = ndf
        self.im_size = im_size

        modules = [
            nn.Sequential(spectral_norm(nn.Conv2d(nc, ndf//4, 4, 2, 1, bias=False)),
                          nn.LeakyReLU(0.2, inplace=True)),
            DownBlock(ndf//4, ndf//2),
            DownBlock(ndf//2, ndf*1),
            DownBlock(ndf*1,  ndf*2),
            DownBlock(ndf*2,  ndf*4, ch_skip=ndf//4),
            ]

        if im_size == 512:
            modules.append(
                DownBlock(ndf*4,  ndf*16, ch_skip=ndf//2),
            )
        else:  # im_size == 1024
            modules.append(
                DownBlock(ndf*4,  ndf*8, ch_skip=ndf//2))
            modules.append(
                DownBlock(ndf*8,  ndf*16, ch_skip=ndf*1),
            )
        # Output head (always the last entry of self.main).
        modules.append(
                        nn.Sequential(
                            spectral_norm(nn.Conv2d(ndf*16, ndf*16, 1, 1, 0, bias=False)),
                            nn.BatchNorm2d(ndf*16),
                            nn.LeakyReLU(0.2, inplace=True),
                            spectral_norm(nn.Conv2d(ndf*16, 1, 4, 1, 0, bias=False)))
                       )

        self.main = nn.ModuleList(modules)

        self.apply(weights_init)


    def forward(self, x):
        """Score image batch ``x``; deeper stages are gated by stem/early features."""
        feat_256 = self.main[0](x)
        feat_128 = self.main[1](feat_256)
        feat_64 = self.main[2](feat_128)
        feat_32 = self.main[3](feat_64)

        feat_16 = self.main[4](feat_32, feat_256)
        feat_8 = self.main[5](feat_16, feat_128)
        if self.im_size == 1024:
            feat_last = self.main[6](feat_8, feat_64)
        else:
            feat_last = feat_8

        return self.main[-1](feat_last)

In [None]:
import torch
import torchvision
# Report the torch / torchvision versions actually loaded in this kernel.
print(torch.__version__, torchvision.__version__, sep='\n')

1.9.0+cu102
0.10.0+cu102


In [None]:
!pip3 install torch==1.9.0 torchvision==0.10.0

Collecting torch==1.9.0
  Downloading torch-1.9.0-cp37-cp37m-manylinux1_x86_64.whl (831.4 MB)
[K     |████████████████████████████████| 831.4 MB 2.6 kB/s 
[?25hCollecting torchvision==0.10.0
  Downloading torchvision-0.10.0-cp37-cp37m-manylinux1_x86_64.whl (22.1 MB)
[K     |████████████████████████████████| 22.1 MB 1.5 MB/s 
Installing collected packages: torch, torchvision
  Attempting uninstall: torch
    Found existing installation: torch 1.7.0
    Uninstalling torch-1.7.0:
      Successfully uninstalled torch-1.7.0
  Attempting uninstall: torchvision
    Found existing installation: torchvision 0.8.1
    Uninstalling torchvision-0.8.1:
      Successfully uninstalled torchvision-0.8.1
Successfully installed torch-1.9.0 torchvision-0.10.0


In [None]:
import os
import torch
from torch.autograd import Variable
from pdb import set_trace as st
from IPython import embed

class BaseModel():
    """Minimal base class for the perceptual-distance models.

    Stores the GPU configuration and provides checkpoint save/load helpers.
    Most hooks are no-op placeholders meant to be overridden by subclasses.
    """

    def __init__(self):
        pass

    def name(self):
        return 'BaseModel'

    def initialize(self, use_gpu=True, gpu_ids=None):
        """Record GPU usage and device ids (defaults to ``[0]``).

        ``gpu_ids`` defaults to None so that each instance gets its own fresh
        ``[0]`` list instead of sharing one mutable default.
        """
        self.use_gpu = use_gpu
        self.gpu_ids = [0] if gpu_ids is None else gpu_ids

    def forward(self):
        pass

    def get_image_paths(self):
        """Return ``self.image_paths``, which a subclass must set."""
        return self.image_paths

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        return self.input

    def get_current_errors(self):
        return {}

    def save(self, label):
        pass

    def save_network(self, network, path, network_label, epoch_label):
        """Save ``network.state_dict()`` to ``<path>/<epoch>_net_<label>.pth``."""
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(path, save_filename)
        torch.save(network.state_dict(), save_path)

    def load_network(self, network, network_label, epoch_label):
        """Load ``<self.save_dir>/<epoch>_net_<label>.pth`` into ``network``.

        ``self.save_dir`` is not set by this class; the caller or subclass
        must provide it.
        """
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        print('Loading network from %s'%save_path)
        network.load_state_dict(torch.load(save_path))

    def update_learning_rate(self, *args, **kwargs):
        """No-op hook.

        Accepts ``self`` (the original signature had no parameters and could
        not be called on an instance) and also any arguments that
        subclass-style callers pass.
        """
        pass

    def save_done(self, flag=False):
        """Write ``flag`` to ``save_dir`` as ``done_flag.npy`` and as text ``done_flag``."""
        # Local import: this cell does not import numpy at module level.
        import numpy as np
        np.save(os.path.join(self.save_dir, 'done_flag'),flag)
        np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')

In [None]:
from collections import namedtuple
import torch
from torchvision import models as tv
from IPython import embed

class squeezenet(torch.nn.Module):
    """SqueezeNet-1.1 feature extractor split into seven sequential slices.

    ``forward`` returns the output of every slice as a ``SqueezeOutputs``
    namedtuple (``relu1`` ... ``relu7``). All parameters are frozen unless
    ``requires_grad`` is True.
    """
    def __init__(self, requires_grad=False, pretrained=True):
        super(squeezenet, self).__init__()
        features = tv.squeezenet1_1(pretrained=pretrained).features
        # Half-open [start, stop) index ranges into ``features``, one per slice.
        bounds = [(0, 2), (2, 5), (5, 8), (8, 10), (10, 11), (11, 12), (12, 13)]
        for slice_no in range(1, len(bounds) + 1):
            setattr(self, 'slice%d' % slice_no, torch.nn.Sequential())
        self.N_slices = 7
        for slice_no, (start, stop) in enumerate(bounds, start=1):
            block = getattr(self, 'slice%d' % slice_no)
            for idx in range(start, stop):
                block.add_module(str(idx), features[idx])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        activations = []
        h = X
        for slice_no in range(1, self.N_slices + 1):
            h = getattr(self, 'slice%d' % slice_no)(h)
            activations.append(h)
        SqueezeOutputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
        return SqueezeOutputs(*activations)


class alexnet(torch.nn.Module):
    """AlexNet feature extractor split into five sequential slices.

    ``forward`` returns each slice's output as an ``AlexnetOutputs``
    namedtuple (``relu1`` ... ``relu5``). Parameters are frozen unless
    ``requires_grad`` is True.
    """
    def __init__(self, requires_grad=False, pretrained=True):
        super(alexnet, self).__init__()
        backbone = tv.alexnet(pretrained=pretrained).features
        slice_ranges = (range(0, 2), range(2, 5), range(5, 8), range(8, 10), range(10, 12))
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        targets = (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5)
        for target, layer_ids in zip(targets, slice_ranges):
            for layer_id in layer_ids:
                target.add_module(str(layer_id), backbone[layer_id])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        outputs = []
        h = X
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = stage(h)
            outputs.append(h)
        AlexnetOutputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
        return AlexnetOutputs(*outputs)

class vgg16(torch.nn.Module):
    """VGG-16 feature extractor split into five sequential slices.

    ``forward`` returns each slice's output as a ``VggOutputs`` namedtuple
    (``relu1_2``, ``relu2_2``, ``relu3_3``, ``relu4_3``, ``relu5_3``).
    Parameters are frozen unless ``requires_grad`` is True.
    """
    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        layers = tv.vgg16(pretrained=pretrained).features
        # Cut points into ``layers``: slice k covers [cuts[k-1], cuts[k]).
        cuts = [0, 4, 9, 16, 23, 30]
        for k in range(1, len(cuts)):
            setattr(self, 'slice%d' % k, torch.nn.Sequential())
        self.N_slices = 5
        for k in range(1, len(cuts)):
            container = getattr(self, 'slice%d' % k)
            for x in range(cuts[k - 1], cuts[k]):
                container.add_module(str(x), layers[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        collected = []
        h = X
        for k in range(1, self.N_slices + 1):
            h = getattr(self, 'slice%d' % k)(h)
            collected.append(h)
        VggOutputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        return VggOutputs(*collected)



class resnet(torch.nn.Module):
    """ResNet feature extractor exposing five intermediate activations.

    ``forward`` returns an ``Outputs`` namedtuple with fields relu1 (after
    conv1/bn1/relu) and conv2..conv5 (after layer1..layer4).

    Args:
        requires_grad: if False (default), all parameters are frozen,
            matching the sibling backbones in this cell.
        pretrained: load torchvision pretrained weights.
        num: depth; one of 18, 34, 50, 101, 152.

    Raises:
        ValueError: for an unsupported ``num``. Previously ``self.net`` was
            left unset and the constructor failed later with AttributeError.
    """
    def __init__(self, requires_grad=False, pretrained=True, num=18):
        super(resnet, self).__init__()
        constructors = {
            18: tv.resnet18,
            34: tv.resnet34,
            50: tv.resnet50,
            101: tv.resnet101,
            152: tv.resnet152,
        }
        if num not in constructors:
            raise ValueError('Unsupported ResNet depth %r; expected one of %s'
                             % (num, sorted(constructors)))
        self.net = constructors[num](pretrained=pretrained)
        self.N_slices = 5

        self.conv1 = self.net.conv1
        self.bn1 = self.net.bn1
        self.relu = self.net.relu
        self.maxpool = self.net.maxpool
        self.layer1 = self.net.layer1
        self.layer2 = self.net.layer2
        self.layer3 = self.net.layer3
        self.layer4 = self.net.layer4

        # Honour requires_grad (it was previously accepted but ignored).
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.conv1(X)
        h = self.bn1(h)
        h = self.relu(h)
        h_relu1 = h
        h = self.maxpool(h)
        h = self.layer1(h)
        h_conv2 = h
        h = self.layer2(h)
        h_conv3 = h
        h = self.layer3(h)
        h_conv4 = h
        h = self.layer4(h)
        h_conv5 = h

        outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
        out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)

        return out

In [None]:
!pip install lpips

Collecting lpips
  Downloading lpips-0.1.4-py3-none-any.whl (53 kB)
[?25l[K     |██████                          | 10 kB 21.7 MB/s eta 0:00:01[K     |████████████▏                   | 20 kB 26.4 MB/s eta 0:00:01[K     |██████████████████▎             | 30 kB 27.3 MB/s eta 0:00:01[K     |████████████████████████▍       | 40 kB 20.9 MB/s eta 0:00:01[K     |██████████████████████████████▌ | 51 kB 15.4 MB/s eta 0:00:01[K     |████████████████████████████████| 53 kB 2.1 MB/s 
Installing collected packages: lpips
Successfully installed lpips-0.1.4


In [None]:
from __future__ import absolute_import

import sys
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable
import numpy as np
from pdb import set_trace as st
from skimage import color
from IPython import embed

import lpips as util

def spatial_average(in_tens, keepdim=True):
    """Mean of ``in_tens`` over dims 2 and 3 (the spatial axes of an NCHW tensor)."""
    return torch.mean(in_tens, dim=(2, 3), keepdim=keepdim)

def upsample(in_tens, out_H=64):
    """Bilinearly rescale ``in_tens`` so that dim 2 has size ``out_H``.

    The same scale factor is applied to dim 3, so non-square maps keep their
    aspect ratio.
    """
    factor = float(out_H) / in_tens.shape[2]
    resize = nn.Upsample(scale_factor=factor, mode='bilinear', align_corners=False)
    return resize(in_tens)

class PNetLin(nn.Module):
    """Perceptual distance between two image batches, from backbone features.

    Inputs pass through ``ScalingLayer`` (version '0.1' only) and a backbone
    ('vgg'/'vgg16', 'alex' or 'squeeze'). For each backbone layer, features
    are unit-normalised across channels (``util.normalize_tensor``) and their
    squared difference is taken. That difference is then reduced in one of
    two ways:

    * ``lpips=True``: weighted by a learned ``NetLinLayer`` head.
    * ``lpips=False``: summed over channels.

    The result is spatially averaged, or upsampled to the input height when
    ``spatial=True``. Per-layer results are summed.

    Raises:
        ValueError: if ``pnet_type`` is not a supported backbone.
    """
    def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True):
        super(PNetLin, self).__init__()

        self.pnet_type = pnet_type
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial
        self.lpips = lpips
        self.version = version
        self.scaling_layer = ScalingLayer()

        if(self.pnet_type in ['vgg','vgg16']):
            net_type = vgg16
            self.chns = [64,128,256,512,512]
        elif(self.pnet_type=='alex'):
            net_type = alexnet
            self.chns = [64,192,384,256,256]
        elif(self.pnet_type=='squeeze'):
            net_type = squeezenet
            self.chns = [64,128,256,384,384,512,512]
        else:
            # Previously fell through to a NameError on ``net_type``.
            raise ValueError("Unknown pnet_type %r; expected 'vgg', 'vgg16', 'alex' or 'squeeze'"
                             % (pnet_type,))
        self.L = len(self.chns)

        self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)

        if(lpips):
            # Attribute names lin0..lin6 are kept so pretrained state_dicts load.
            self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
            self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
            self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
            self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
            self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
            self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
            if(self.pnet_type=='squeeze'):
                self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
                self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
                self.lins+=[self.lin5,self.lin6]

    def forward(self, in0, in1, retPerLayer=False):
        """Return the summed distance, plus the per-layer list if ``retPerLayer``."""
        in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
        feats0, feats1, diffs = {}, {}, {}

        for kk in range(self.L):
            feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk]-feats1[kk])**2

        if(self.lpips):
            if(self.spatial):
                res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
            else:
                res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
        else:
            if(self.spatial):
                res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
            else:
                res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]

        # Accumulate out of place: an in-place ``+=`` on res[0] would
        # overwrite the first per-layer score returned with retPerLayer.
        val = res[0]
        for layer_score in res[1:]:
            val = val + layer_score

        if(retPerLayer):
            return (val, res)
        else:
            return val

class ScalingLayer(nn.Module):
    """Fixed per-channel affine normalisation applied before the backbone.

    Computes ``(inp - shift) / scale`` with 3-channel constants stored as
    (1, 3, 1, 1) buffers, broadcast over an NCHW input.
    """
    def __init__(self):
        super(ScalingLayer, self).__init__()
        shift = torch.Tensor([-.030, -.088, -.188]).view(1, 3, 1, 1)
        scale = torch.Tensor([.458, .448, .450]).view(1, 3, 1, 1)
        self.register_buffer('shift', shift)
        self.register_buffer('scale', scale)

    def forward(self, inp):
        centred = inp - self.shift
        return centred / self.scale


class NetLinLayer(nn.Module):
    """Learned 1x1 projection: optional Dropout, then a bias-free 1x1 conv.

    Maps ``chn_in`` channels to ``chn_out`` (default 1). The layers live in
    ``self.model`` so existing ``model.*`` state_dict keys are unchanged.
    """
    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()

        layers = [nn.Dropout(),] if(use_dropout) else []
        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # Without this, calling the module directly raised NotImplementedError;
        # callers that use ``.model(x)`` are unaffected.
        return self.model(x)


class Dist2LogitLayer(nn.Module):
    """Small 1x1-conv MLP that maps a pair of distance maps to a score.

    The five input channels are d0, d1, d0 - d1, d0 / (d1 + eps) and
    d1 / (d0 + eps), concatenated on dim 1 (assumes d0/d1 are 1-channel).
    A final Sigmoid is appended when ``use_sigmoid`` is True.
    """
    def __init__(self, chn_mid=32, use_sigmoid=True):
        super(Dist2LogitLayer, self).__init__()

        layers = [
            nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),
        ]
        if use_sigmoid:
            layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def forward(self,d0,d1,eps=0.1):
        features = [d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)]
        return self.model.forward(torch.cat(features, dim=1))

class BCERankingLoss(nn.Module):
    """BCE between a learned preference score and a target preference.

    ``self.net`` (a ``Dist2LogitLayer``) turns the distance pair (d0, d1)
    into a probability. ``judge`` in [-1, 1] is mapped to a [0, 1] target.
    The predicted score is kept in ``self.logit``.
    """
    def __init__(self, chn_mid=32):
        super(BCERankingLoss, self).__init__()
        self.net = Dist2LogitLayer(chn_mid=chn_mid)
        self.loss = torch.nn.BCELoss()

    def forward(self, d0, d1, judge):
        logit = self.net.forward(d0, d1)
        self.logit = logit
        return self.loss(logit, (judge + 1.) / 2.)

class FakeNet(nn.Module):
    """Base for the parameter-free L2 / DSSIM "networks".

    ``colorspace`` selects the comparison space, 'RGB' or 'Lab'. It is
    matched case-insensitively, so e.g. 'rgb' (PerceptualLoss's default)
    selects the RGB path.
    """
    def __init__(self, use_gpu=True, colorspace='Lab'):
        super(FakeNet, self).__init__()
        self.use_gpu = use_gpu
        self.colorspace=colorspace

    def _colorspace_key(self):
        """Return 'rgb' or 'lab'; raise ValueError for anything else."""
        key = str(self.colorspace).lower()
        if key not in ('rgb', 'lab'):
            raise ValueError("Unsupported colorspace %r; expected 'RGB' or 'Lab'" % (self.colorspace,))
        return key

    def _wrap_scalar(self, value):
        """Wrap a Python/numpy scalar as a 1-element tensor (on GPU if use_gpu)."""
        ret_var = Variable( torch.Tensor((value,) ) )
        if(self.use_gpu):
            ret_var = ret_var.cuda()
        return ret_var

class L2(FakeNet):
    """Mean squared error; only batch size 1 is supported."""

    def forward(self, in0, in1, retPerLayer=None):
        assert(in0.size()[0]==1)

        if self._colorspace_key() == 'rgb':
            (N,C,X,Y) = in0.size()
            # Mean over channels, then rows, then columns -> shape (N,).
            value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)
            return value
        value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
            util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
        return self._wrap_scalar(value)

class DSSIM(FakeNet):
    """Structural dissimilarity ((1 - SSIM) / 2); only batch size 1 is supported."""

    def forward(self, in0, in1, retPerLayer=None):
        assert(in0.size()[0]==1)

        if self._colorspace_key() == 'rgb':
            value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float')
        else:
            value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
                util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
        return self._wrap_scalar(value)

def print_network(net):
    """Print ``net``'s repr followed by its total number of parameters."""
    total = sum(p.numel() for p in net.parameters())
    print('Network',net)
    print('Total number of parameters: %d' % total)

In [None]:
from __future__ import absolute_import
import sys
import numpy as np
import torch
from torch import nn
import os
from collections import OrderedDict
from torch.autograd import Variable
import itertools
from scipy.ndimage import zoom
import fractions
import functools
import skimage.transform
from tqdm import tqdm

from IPython import embed
import lpips as util

class DistModel(BaseModel):
    """Wrapper that builds, loads and optionally trains a perceptual distance.

    ``initialize`` picks a backend ('net-lin', 'net', L2 or DSSIM) and, for
    'net-lin' in eval mode, loads pretrained weights. When ``use_gpu`` is set
    it moves the net to ``gpu_ids[0]`` and wraps it in ``DataParallel``. When
    ``is_train`` is set it also creates a ``BCERankingLoss`` and an Adam
    optimizer for 2AFC training.
    """
    def name(self):
        return self.model_name

    def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
            use_gpu=True, printNet=False, spatial=False,
            is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]):
        """Build the distance network.

        Args:
            model: 'net-lin' (backbone + learned linear heads), 'net'
                (backbone only), 'L2'/'l2', or 'DSSIM'/'dssim'/'SSIM'/'ssim'.
            net: backbone for 'net-lin'/'net' ('alex', 'vgg'/'vgg16', 'squeeze').
            colorspace: comparison space for the L2/DSSIM models.
            pnet_rand: if True, the backbone is built without pretrained weights.
            pnet_tune: if True, backbone parameters keep requires_grad.
            model_path: 'net-lin' weights; resolved automatically when None.
            use_gpu: move to ``gpu_ids[0]`` and wrap in ``DataParallel``.
            printNet: print the network and its parameter count.
            spatial: return a per-pixel distance map instead of a scalar.
            is_train: set up ranking loss + optimizer instead of eval mode.
            lr, beta1: Adam hyper-parameters (training only).
            version: weights version string, e.g. '0.1'.
            gpu_ids: device ids for ``DataParallel``.

        Raises:
            ValueError: if ``model`` is not recognised.
        """
        BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids)

        self.model = model
        self.net = net
        self.is_train = is_train
        self.spatial = spatial
        self.gpu_ids = gpu_ids
        self.model_name = '%s [%s]'%(model,net)

        if(self.model == 'net-lin'):
            self.net = PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
                use_dropout=True, spatial=spatial, version=version, lpips=True)
            kw = {}
            if not use_gpu:
                kw['map_location'] = 'cpu'
            if(model_path is None):
                model_path = self._default_weights_path(version, net)

            if(not is_train):
                print('Loading model from: %s'%model_path)
                self.net.load_state_dict(torch.load(model_path, **kw), strict=False)

        elif(self.model=='net'):
            self.net = PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
        elif(self.model in ['L2','l2']):
            self.net = L2(use_gpu=use_gpu,colorspace=colorspace)
            self.model_name = 'L2'
        elif(self.model in ['DSSIM','dssim','SSIM','ssim']):
            self.net = DSSIM(use_gpu=use_gpu,colorspace=colorspace)
            self.model_name = 'SSIM'
        else:
            raise ValueError("Model [%s] not recognized." % self.model)

        self.parameters = list(self.net.parameters())

        if self.is_train:
            self.rankLoss = BCERankingLoss()
            self.parameters += list(self.rankLoss.net.parameters())
            self.lr = lr
            self.old_lr = lr
            self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
        else:
            self.net.eval()

        if(use_gpu):
            self.net.to(gpu_ids[0])
            self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
            if(self.is_train):
                self.rankLoss = self.rankLoss.to(device=gpu_ids[0])

        if(printNet):
            print('---------- Networks initialized -------------')
            print_network(self.net)
            print('-----------------------------------------------')

    def _default_weights_path(self, version, net):
        """Resolve ``weights/v<version>/<net>.pth`` for the 'net-lin' model.

        The first candidate is the original location: relative to the file
        that defines ``initialize``. In a notebook that "file" is presumably a
        cell pseudo-name, so it resolves under the working directory. If that
        path does not exist, the same relative path inside the installed
        ``lpips`` package (``util``) is tried. If neither exists, the first
        candidate is returned so that ``torch.load`` reports it.
        """
        import inspect
        rel_path = 'weights/v%s/%s.pth' % (version, net)
        primary = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', rel_path))
        if os.path.exists(primary):
            return primary
        pkg_file = getattr(util, '__file__', None)
        if pkg_file:
            fallback = os.path.join(os.path.dirname(os.path.abspath(pkg_file)), rel_path)
            if os.path.exists(fallback):
                return fallback
        return primary

    def forward(self, in0, in1, retPerLayer=False):

        return self.net.forward(in0, in1, retPerLayer=retPerLayer)

    def optimize_parameters(self):
        """One training step: forward, backward, optimizer step, weight clamp."""
        self.forward_train()
        self.optimizer_net.zero_grad()
        self.backward_train()
        self.optimizer_net.step()
        self.clamp_weights()

    def clamp_weights(self):
        """Clamp the weights of every 1x1-kernel module to be non-negative."""
        for module in self.net.modules():
            # getattr: modules that have a weight but no kernel_size (Linear,
            # BatchNorm, ...) are skipped instead of raising AttributeError.
            if(hasattr(module, 'weight') and getattr(module, 'kernel_size', None)==(1,1)):
                module.weight.data = torch.clamp(module.weight.data,min=0)

    def set_input(self, data):
        """Store a 2AFC batch (``ref``, ``p0``, ``p1``, ``judge``), moved to GPU if enabled."""
        self.input_ref = data['ref']
        self.input_p0 = data['p0']
        self.input_p1 = data['p1']
        self.input_judge = data['judge']

        if(self.use_gpu):
            self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
            self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
            self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
            self.input_judge = self.input_judge.to(device=self.gpu_ids[0])

        self.var_ref = Variable(self.input_ref,requires_grad=True)
        self.var_p0 = Variable(self.input_p0,requires_grad=True)
        self.var_p1 = Variable(self.input_p1,requires_grad=True)

    def forward_train(self):
        self.d0 = self.forward(self.var_ref, self.var_p0)
        self.d1 = self.forward(self.var_ref, self.var_p1)
        self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge)

        self.var_judge = Variable(1.*self.input_judge).view(self.d0.size())

        # judge in [0, 1] is mapped to [-1, 1] for BCERankingLoss.
        self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.)

        return self.loss_total

    def backward_train(self):
        torch.mean(self.loss_total).backward()

    def compute_accuracy(self,d0,d1,judge):
        """Per-sample agreement between (d1 < d0) and the judge value."""
        d1_lt_d0 = (d1<d0).cpu().data.numpy().flatten()
        judge_per = judge.cpu().numpy().flatten()
        return d1_lt_d0*judge_per + (1-d1_lt_d0)*(1-judge_per)

    def get_current_errors(self):
        retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),
                            ('acc_r', self.acc_r)])

        for key in retDict.keys():
            retDict[key] = np.mean(retDict[key])

        return retDict

    def get_current_visuals(self):
        """Return ref/p0/p1 images nearest-neighbour zoomed to height 256."""
        zoom_factor = 256/self.var_ref.data.size()[2]

        ref_img = util.tensor2im(self.var_ref.data)
        p0_img = util.tensor2im(self.var_p0.data)
        p1_img = util.tensor2im(self.var_p1.data)

        ref_img_vis = zoom(ref_img,[zoom_factor, zoom_factor, 1],order=0)
        p0_img_vis = zoom(p0_img,[zoom_factor, zoom_factor, 1],order=0)
        p1_img_vis = zoom(p1_img,[zoom_factor, zoom_factor, 1],order=0)

        return OrderedDict([('ref', ref_img_vis),
                            ('p0', p0_img_vis),
                            ('p1', p1_img_vis)])

    def save(self, path, label):
        """Save the distance net (unwrapped from DataParallel) and, if present, the rank net."""
        if(self.use_gpu):
            self.save_network(self.net.module, path, '', label)
        else:
            self.save_network(self.net, path, '', label)
        # rankLoss only exists when initialized with is_train=True.
        if hasattr(self, 'rankLoss'):
            self.save_network(self.rankLoss.net, path, 'rank', label)

    def update_learning_rate(self,nepoch_decay):
        """Linearly decay the learning rate by ``lr / nepoch_decay``."""
        lrd = self.lr / nepoch_decay
        lr = self.old_lr - lrd

        for param_group in self.optimizer_net.param_groups:
            param_group['lr'] = lr

        # Previously formatted the builtin ``type`` into the message.
        print('update lr [%s] decay: %f -> %f' % (self.model_name, self.old_lr, lr))
        self.old_lr = lr

def score_2afc_dataset(data_loader, func, name=''):
    """Score distance function ``func`` on a 2AFC dataset.

    Each batch is used as follows:

    * ``func(ref, p0)`` and ``func(ref, p1)`` give the two distances.
    * ``judge`` is presumably the fraction of raters preferring p1.

    A sample earns ``1 - judge`` when p0 is closer, ``judge`` when p1 is
    closer, and 0.5 on a tie.

    Returns:
        (mean score, dict with d0s, d1s, gts and per-sample scores).
    """
    d0s, d1s, gts = [], [], []
    for batch in tqdm(data_loader.load_data(), desc=name):
        ref = batch['ref']
        d0s.extend(func(ref, batch['p0']).data.cpu().numpy().flatten().tolist())
        d1s.extend(func(ref, batch['p1']).data.cpu().numpy().flatten().tolist())
        gts.extend(batch['judge'].cpu().numpy().flatten().tolist())

    d0s, d1s, gts = np.array(d0s), np.array(d1s), np.array(gts)
    scores = (d0s < d1s) * (1. - gts) + (d1s < d0s) * gts + (d1s == d0s) * .5

    return (np.mean(scores), dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores))

def score_jnd_dataset(data_loader, func, name=''):
    """Score distance function ``func`` on a JND (same/different) dataset.

    Pairs are ranked by ``func(p0, p1)``. Precision/recall of the ``same``
    labels along that ranking is summarised with ``util.voc_ap``.

    Returns:
        (AP score, dict with flat arrays ``ds`` and ``sames``).
    """
    ds = []
    gts = []

    for data in tqdm(data_loader.load_data(), desc=name):
        # Flatten so that per-pair outputs shaped (N,1,1,1), as PNetLin
        # returns, become N scalars. Without it, ds was a nested array and
        # the argsort/indexing below was wrong.
        ds+=func(data['p0'],data['p1']).data.cpu().numpy().flatten().tolist()
        gts+=data['same'].cpu().numpy().flatten().tolist()

    sames = np.array(gts)
    ds = np.array(ds)

    sorted_inds = np.argsort(ds)
    ds_sorted = ds[sorted_inds]
    sames_sorted = sames[sorted_inds]

    TPs = np.cumsum(sames_sorted)
    FPs = np.cumsum(1-sames_sorted)
    FNs = np.sum(sames_sorted)-TPs

    precs = TPs/(TPs+FPs)
    recs = TPs/(TPs+FNs)
    score = util.voc_ap(recs,precs)

    return(score, dict(ds=ds,sames=sames))

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from skimage.measure import compare_ssim
import torch
from torch.autograd import Variable

class PerceptualLoss(torch.nn.Module):
    """``nn.Module`` wrapper that builds a ``DistModel`` and exposes it as a loss.

    With ``normalize=True`` in ``forward``, inputs are mapped from [0, 1] to
    [-1, 1] first. The wrapped model is called as ``model.forward(target, pred)``.
    """
    def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]):
        super(PerceptualLoss, self).__init__()
        print('Setting up Perceptual loss...')
        self.use_gpu = use_gpu
        self.spatial = spatial
        self.gpu_ids = gpu_ids
        dist = DistModel()
        self.model = dist
        dist.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=spatial, gpu_ids=gpu_ids)
        print('...[%s] initialized' % dist.name())
        print('...Done')

    def forward(self, pred, target, normalize=False):
        if normalize:
            pred, target = 2 * pred - 1, 2 * target - 1
        return self.model.forward(target, pred)

def normalize_tensor(in_feat,eps=1e-10):
    """Divide ``in_feat`` by its L2 norm over dim 1 (plus ``eps`` to avoid 0/0)."""
    channel_norm = in_feat.pow(2).sum(dim=1, keepdim=True).sqrt()
    return in_feat / (channel_norm + eps)

def l2(p0, p1, range=255.):
    """Half the mean squared difference after scaling both inputs by ``1 / range``."""
    scaled_diff = p0 / range - p1 / range
    return .5 * np.mean(scaled_diff ** 2)

def psnr(p0, p1, peak=255.):
    """Peak signal-to-noise ratio in dB: 10 * log10(peak^2 / MSE)."""
    mse = np.mean((1. * p0 - 1. * p1) ** 2)
    return 10 * np.log10(peak ** 2 / mse)

def dssim(p0, p1, range=255.):
    """Structural dissimilarity ``(1 - SSIM) / 2`` of two multichannel images."""
    similarity = compare_ssim(p0, p1, data_range=range, multichannel=True)
    return (1 - similarity) / 2.

def rgb2lab(in_img,mean_cent=False):
    """Convert an RGB image to CIE-Lab, optionally subtracting 50 from L.

    NOTE(review): a later ``def rgb2lab`` in this same cell rebinds the name,
    so this definition is shadowed once the cell finishes executing.
    """
    from skimage import color
    lab = color.rgb2lab(in_img)
    if mean_cent:
        lab[:, :, 0] -= 50
    return lab

def tensor2np(tensor_obj):
    """First item of a batch as a float32 numpy array with channels moved last."""
    first = tensor_obj[0].cpu().float()
    return np.transpose(first.numpy(), (1, 2, 0))

def np2tensor(np_obj):
    """Wrap an HxWxC array as a 1xCxHxW float tensor."""
    with_batch = np_obj[:, :, :, np.newaxis]
    return torch.Tensor(with_batch.transpose((3, 2, 0, 1)))

def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
    """Convert the first image of a batch to a 1xCxHxW Lab tensor.

    L is shifted by -50 when ``mc_only`` or ``to_norm`` is set. With
    ``to_norm`` (and not ``mc_only``) the whole result is also divided by 100.
    """
    from skimage import color

    lab = color.rgb2lab(tensor2im(image_tensor))
    if mc_only or to_norm:
        lab[:, :, 0] = lab[:, :, 0] - 50
    if to_norm and not mc_only:
        lab = lab / 100.
    return np2tensor(lab)

def tensorlab2tensor(lab_tensor,return_inbnd=False):
    """Convert a normalised Lab tensor back to an RGB tensor.

    This reverses the ``to_norm`` scaling of ``tensor2tensorlab``: values are
    multiplied by 100 and 50 is added back to L before ``lab2rgb``. The RGB
    result is clipped to [0, 1] and scaled to 0-255 before ``im2tensor``.

    With ``return_inbnd`` the function also returns a 1x1xHxW mask of pixels
    whose Lab values survive a uint8 RGB round trip within ``atol=2``.
    """
    from skimage import color
    import warnings

    # Silence conversion warnings only inside this call. Previously
    # warnings.filterwarnings("ignore") disabled all warnings process-wide.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        lab = tensor2np(lab_tensor)*100.
        lab[:,:,0] = lab[:,:,0]+50

        rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1)
        if(return_inbnd):
            lab_back = color.rgb2lab(rgb_back.astype('uint8'))
            mask = 1.*np.isclose(lab_back,lab,atol=2.)
            mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis])
            return (im2tensor(rgb_back),mask)
        return im2tensor(rgb_back)

def rgb2lab(input, mean_cent=False):
    """Convert an RGB image on a 0-255 scale to CIE-Lab.

    Args:
        input: RGB image with values on a 0-255 scale; it is divided by 255
            before conversion.
        mean_cent: if True, subtract 50 from the L channel (index 0 of the
            last axis). This restores the option of the earlier ``rgb2lab`` in
            this cell, which this definition shadows. The default False keeps
            the previous behaviour.
    """
    from skimage import color
    img_lab = color.rgb2lab(input / 255.)
    if mean_cent:
        img_lab[..., 0] = img_lab[..., 0] - 50
    return img_lab

def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
    """First image of a batch as an HxWxC array: ``(x + cent) * factor`` cast to ``imtype``.

    The defaults map values in [-1, 1] to [0, 255] uint8.
    NOTE(review): an identical ``tensor2im`` is defined again later in this cell.
    """
    hwc = np.transpose(image_tensor[0].cpu().float().numpy(), (1, 2, 0))
    return ((hwc + cent) * factor).astype(imtype)

def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
    """HxWxC image -> 1xCxHxW float tensor via ``image / factor - cent``.

    ``imtype`` is accepted only for symmetry with ``tensor2im`` and is unused.
    NOTE(review): an identical ``im2tensor`` is defined again later in this cell.
    """
    rescaled = image / factor - cent
    batched = rescaled[:, :, :, np.newaxis]
    return torch.Tensor(batched.transpose((3, 2, 0, 1)))

def tensor2vec(vector_tensor):
    """Take spatial position [0, 0] of an (N, C, H, W) tensor as an (N, C) numpy array."""
    arr = vector_tensor.data.cpu().numpy()
    return arr[:, :, 0, 0]

def voc_ap(rec, prec, use_07_metric=False):
    """PASCAL VOC average precision from recall / precision arrays.

    With ``use_07_metric`` the 11-point interpolated AP (VOC2007) is used.
    Otherwise the area under the monotone precision envelope is computed.
    """
    if use_07_metric:
        ap = 0.
        for t in np.arange(0., 1.1, 0.1):
            reached = rec >= t
            p = 0 if np.sum(reached) == 0 else np.max(prec[reached])
            ap = ap + p / 11.
        return ap

    mrec = np.concatenate(([0.], rec, [1.]))
    mpre = np.concatenate(([0.], prec, [0.]))
    # Precision envelope: running maximum taken from the right.
    mpre = np.maximum.accumulate(mpre[::-1])[::-1]
    # Indices where recall changes value.
    steps = np.where(mrec[1:] != mrec[:-1])[0]
    return np.sum((mrec[steps + 1] - mrec[steps]) * mpre[steps + 1])

def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
    # NOTE(review): duplicate of the earlier tensor2im in this cell; being
    # later, this is the binding in effect after the cell runs.
    # Converts the first image of a batch to HxWxC: (x + cent) * factor -> imtype.
    chw = image_tensor[0].cpu().float().numpy()
    hwc = chw.transpose(1, 2, 0)
    shifted = hwc + cent
    return (shifted * factor).astype(imtype)

def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
    # NOTE(review): duplicate of the earlier im2tensor in this cell; being
    # later, this is the binding in effect after the cell runs.
    # HxWxC image -> 1xCxHxW tensor via image / factor - cent (imtype unused).
    normalised = image / factor - cent
    return torch.Tensor(np.expand_dims(normalised, 3).transpose((3, 2, 0, 1)))

In [None]:
import torch
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from torchvision import utils as vutils

import random
from tqdm import tqdm

def loss_for_list(loss, fl1, fl2, detach_second=True):
    """Sum ``loss(fl1[i], fl2[i])`` over paired lists of features.

    When ``detach_second`` is True, the fl2 entries are detached, so no
    gradient flows into them.
    """
    result_loss = 0
    for idx, feat in enumerate(fl1):
        target = fl2[idx].detach() if detach_second else fl2[idx]
        result_loss += loss(feat, target)
    return result_loss


def loss_for_list_perm(loss, fl1, fl2, detach_second=True, margin=2):
    """Hinge (triplet-style) loss summed over paired feature lists.

    For each index i the positive pair is ``(fl1[i], fl2[i])`` and the
    negative pair is ``(fl1[i][perm], fl2[i])``, where ``perm`` is a fresh
    permutation of the batch dimension from ``true_randperm``. Each term is
    ``relu(margin + loss(positive) - loss(negative))``.

    Args:
        loss: callable(pred, target) returning a scalar (e.g. ``F.mse_loss``).
        fl1: sequence of feature tensors; dim 0 is permuted, so it is assumed
            to be the batch dimension.
        fl2: sequence of target tensors, indexed in step with ``fl1``.
        detach_second: if true, targets are ``.detach()``-ed so gradients only
            flow through ``fl1``.
        margin: hinge margin (default 2, the previously hard-coded value).
    Returns:
        Sum of the hinge terms (integer 0 for an empty ``fl1``).
    """
    result_loss = 0
    for f_idx in range(len(fl1)):
        anchor = fl1[f_idx]
        target = fl2[f_idx].detach() if detach_second else fl2[f_idx]
        # Size/device the permutation from the feature actually being permuted.
        perm = true_randperm(anchor.shape[0], anchor.device)
        result_loss += F.relu(margin + loss(anchor, target) - loss(anchor[perm], target))
    return result_loss


def loss_for_list_mean(feat_list):
    """L1 penalty accumulated over a list of feature tensors.

    Rank-4 features are averaged over dims 2 and 3 and pulled toward 1.
    Features of any other rank are pulled toward 0 element-wise.
    Returns integer 0 for an empty list.
    """
    total = 0
    for feat in feat_list:
        if len(feat.shape) != 4:
            total += F.l1_loss(feat, torch.zeros_like(feat))
            continue
        pooled = feat.mean(dim=[2, 3])
        total += F.l1_loss(pooled, torch.ones_like(pooled))
    return total


def _save_ae_checkpoint(path, style_encoder, content_encoder, decoder,
                        opt_c, opt_s_cls, opt_s, opt_d):
    """torch.save the three AE networks and all four optimizer states to ``path``."""
    torch.save({'s': style_encoder.state_dict(),
                'd': decoder.state_dict(),
                'c': content_encoder.state_dict(),
                'opt_c': opt_c.state_dict(),
                'opt_s_cls': opt_s_cls.state_dict(),
                'opt_s': opt_s.state_dict(),
                'opt_d': opt_d.state_dict(),
                }, path)


def _append_log(log_file_path, msg):
    """Append ``msg`` plus a newline to the text log at ``log_file_path``."""
    with open(log_file_path, 'a') as log_file:
        log_file.write(msg + '\n')


def train_ae():
    """Stage 1: train the style encoder, content encoder and decoder.

    Every iteration performs two optimisation steps:
      1. Self-supervised step on ``SelfSupervisedDataset`` batches. It combines
         a cross-entropy classification loss on ``img_idx``, a cosine style
         consistency loss between the two RGB views, a hinge content
         consistency loss between sketch variants, and a reconstruction loss
         from one randomly chosen sketch variant.
      2. Paired step on ``PairedMultiDataset`` batches. It reconstructs the
         RGB image from one of its three sketches. When ``DATA_NAME == 'shoe'``
         it adds a grey-level loss on a style-permuted output.

    Relies on notebook globals (data roots, ``IM_SIZE_AE``, ``BATCH_SIZE_AE``,
    ``NBR_CLS``, ``NFC``, ``ITERATION_AE``, ``*_INTERVAL`` constants,
    ``SAVE_FOLDER``, ``TRIAL_NAME``, ``DATA_NAME``, ``DATALOADER_WORKERS``) and
    on CUDA. Writes sample images, ``ae_log.txt`` and ``<iteration>.pth``
    checkpoints under the folders returned by ``make_folders``.
    """
    dataset = PairedMultiDataset(data_root_colorful, data_root_sketch_1, data_root_sketch_2, data_root_sketch_3, im_size=IM_SIZE_AE, rand_crop=True)
    print(len(dataset))
    dataloader = iter(DataLoader(dataset, BATCH_SIZE_AE,
        sampler=InfiniteSamplerWrapper(dataset), num_workers=DATALOADER_WORKERS, pin_memory=True))

    dataset_ss = SelfSupervisedDataset(data_root_colorful, data_root_sketch_3, im_size=IM_SIZE_AE, nbr_cls=NBR_CLS, rand_crop=True)
    print(len(dataset_ss), len(dataset_ss.frame))
    dataloader_ss = iter(DataLoader(dataset_ss, BATCH_SIZE_AE,
        sampler=InfiniteSamplerWrapper(dataset_ss), num_workers=DATALOADER_WORKERS, pin_memory=True))

    style_encoder = StyleEncoder(nfc=NFC, nbr_cls=NBR_CLS).cuda()
    content_encoder = ContentEncoder(nfc=NFC).cuda()
    decoder = Decoder(nfc=NFC).cuda()

    opt_c = optim.Adam(content_encoder.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_s = optim.Adam(style_encoder.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_d = optim.Adam(decoder.parameters(), lr=2e-4, betas=(0.5, 0.999))

    style_encoder.reset_cls()
    style_encoder.final_cls.cuda()

    # Set PRETRAINED_AE_PATH to an earlier run folder to resume from
    # <PRETRAINED_AE_PATH>/models/<PRETRAINED_AE_ITER>.pth.
    PRETRAINED_AE_PATH = None
    PRETRAINED_AE_ITER = 12000

    if PRETRAINED_AE_PATH is not None:
        PRETRAINED_AE_PATH = PRETRAINED_AE_PATH + '/models/%d.pth'%PRETRAINED_AE_ITER
        ckpt = torch.load(PRETRAINED_AE_PATH)

        print(PRETRAINED_AE_PATH)

        style_encoder.load_state_dict(ckpt['s'])
        content_encoder.load_state_dict(ckpt['c'])
        decoder.load_state_dict(ckpt['d'])

        opt_c.load_state_dict(ckpt['opt_c'])
        opt_s.load_state_dict(ckpt['opt_s'])
        opt_d.load_state_dict(ckpt['opt_d'])
        print('loaded pre-trained AE')

    # reset_cls() is called again after any checkpoint load, and the
    # classifier head gets its own optimizer (presumably a fresh head for the
    # current class set -- confirm against StyleEncoder.reset_cls).
    style_encoder.reset_cls()
    style_encoder.final_cls.cuda()
    opt_s_cls = optim.Adam(style_encoder.final_cls.parameters(), lr=2e-4, betas=(0.5, 0.999))

    saved_image_folder, saved_model_folder = make_folders(SAVE_FOLDER, 'AE_'+TRIAL_NAME)
    log_file_path = saved_image_folder+'/../ae_log.txt'
    with open(log_file_path, 'w'):
        pass  # truncate any log left over from a previous run
    losses_sf_consist = AverageMeter()
    losses_cf_consist = AverageMeter()
    losses_cls = AverageMeter()
    losses_rec_rd = AverageMeter()
    losses_rec_org = AverageMeter()
    losses_rec_grey = AverageMeter()

    percept = PerceptualLoss(model='net-lin', net='vgg', use_gpu=True)

    # Number of iterations between switches to the next self-supervised class set.
    cls_set_period = (NBR_CLS*100)//BATCH_SIZE_AE

    for iteration in tqdm(range(ITERATION_AE)):

        if iteration % cls_set_period == 0 and iteration > 1:
            dataset_ss._next_set()
            dataloader_ss = iter(DataLoader(dataset_ss, BATCH_SIZE_AE, sampler=InfiniteSamplerWrapper(dataset_ss), num_workers=DATALOADER_WORKERS, pin_memory=True))
            style_encoder.reset_cls()
            # Match the other reset_cls() call sites: put the head on the GPU
            # before building its optimizer. Otherwise it may stay on the CPU
            # while its inputs are CUDA tensors.
            style_encoder.final_cls.cuda()
            opt_s_cls = optim.Adam(style_encoder.final_cls.parameters(), lr=2e-4, betas=(0.5, 0.999))

            opt_s.param_groups[0]['lr'] = 1e-4
            opt_d.param_groups[0]['lr'] = 1e-4

        # ---- Step 1: self-supervised batch ----
        rgb_img_rd, rgb_img_org, skt_org, skt_bold, skt_erased, skt_erased_bold, img_idx = next(dataloader_ss)
        rgb_img_rd = rgb_img_rd.cuda()
        rgb_img_org = rgb_img_org.cuda()
        img_idx = img_idx.cuda()

        skt_org = F.interpolate(skt_org, size=512).cuda()
        skt_bold = F.interpolate(skt_bold, size=512).cuda()
        skt_erased = F.interpolate(skt_erased, size=512).cuda()
        skt_erased_bold = F.interpolate(skt_erased_bold, size=512).cuda()

        style_encoder.zero_grad()
        decoder.zero_grad()
        content_encoder.zero_grad()

        style_vector_rd, pred_cls_rd = style_encoder(rgb_img_rd)
        style_vector_org, pred_cls_org = style_encoder(rgb_img_org)

        content_feats = content_encoder(skt_org)
        content_feats_bold = content_encoder(skt_bold)
        content_feats_erased = content_encoder(skt_erased)
        content_feats_eb = content_encoder(skt_erased_bold)

        # Reconstruct from one of the four sketch variants (randint is inclusive).
        rd = random.randint(0, 3)
        gimg_rd = None
        if rd==0:
            gimg_rd = decoder(content_feats, style_vector_rd)
        elif rd==1:
            gimg_rd = decoder(content_feats_bold, style_vector_rd)
        elif rd==2:
            gimg_rd = decoder(content_feats_erased, style_vector_rd)
        elif rd==3:
            gimg_rd = decoder(content_feats_eb, style_vector_rd)

        loss_cf_consist = loss_for_list_perm(F.mse_loss, content_feats_bold, content_feats) +\
                            loss_for_list_perm(F.mse_loss, content_feats_erased, content_feats) +\
                                loss_for_list_perm(F.mse_loss, content_feats_eb, content_feats)

        # Pull each of the first three style vectors toward its own org-view
        # vector and push it away from a batch-shuffled one.
        loss_sf_consist = 0
        for loss_idx in range(3):
            loss_sf_consist += -F.cosine_similarity(style_vector_rd[loss_idx], style_vector_org[loss_idx].detach()).mean() + \
                                    F.cosine_similarity(style_vector_rd[loss_idx], style_vector_org[loss_idx][torch.randperm(BATCH_SIZE_AE)].detach()).mean()

        loss_cls = F.cross_entropy(pred_cls_rd, img_idx) + F.cross_entropy(pred_cls_org, img_idx)
        loss_rec_rd = F.mse_loss(gimg_rd, rgb_img_org)
        if DATA_NAME != 'shoe':
            loss_rec_rd += percept(F.adaptive_avg_pool2d(gimg_rd, output_size=256), F.adaptive_avg_pool2d(rgb_img_org, output_size=256)).sum()
        else:
            loss_rec_rd += F.l1_loss(gimg_rd, rgb_img_org)

        loss_total = loss_cls + loss_sf_consist + loss_rec_rd + loss_cf_consist
        loss_total.backward()

        opt_s.step()
        opt_s_cls.step()
        opt_c.step()
        opt_d.step()

        # ---- Step 2: paired RGB/sketch batch ----
        rgb_img, skt_img_1, skt_img_2, skt_img_3 = next(dataloader)

        rgb_img = rgb_img.cuda()

        # NOTE(review): randint(0, 3) is inclusive, so skt_img_3 is picked for
        # rd in {2, 3} (half the time); use randint(0, 2) if a uniform choice
        # was intended.
        rd = random.randint(0, 3)
        if rd == 0:
            skt_img = skt_img_1
        elif rd == 1:
            skt_img = skt_img_2
        else:
            skt_img = skt_img_3

        skt_img = F.interpolate(skt_img, size=512).cuda()

        style_encoder.zero_grad()
        decoder.zero_grad()
        content_encoder.zero_grad()

        style_vector, _ = style_encoder(rgb_img)
        content_feats = content_encoder(skt_img)
        gimg = decoder(content_feats, style_vector)

        loss_rec_org = F.mse_loss(gimg, rgb_img)
        if DATA_NAME != 'shoe':
            loss_rec_org += percept(F.adaptive_avg_pool2d(gimg, output_size=256),
                                    F.adaptive_avg_pool2d(rgb_img, output_size=256)).sum()

        loss_rec = loss_rec_org
        if DATA_NAME == 'shoe':
            perm = true_randperm(BATCH_SIZE_AE)
            gimg_perm = decoder(content_feats, [s[perm] for s in style_vector])
            gimg_grey = gimg_perm.mean(dim=1, keepdim=True)
            real_grey = rgb_img.mean(dim=1, keepdim=True)
            loss_rec_grey = F.mse_loss(gimg_grey, real_grey)
            loss_rec += loss_rec_grey
        loss_rec.backward()

        opt_s.step()
        opt_d.step()
        opt_c.step()

        losses_cf_consist.update(loss_cf_consist.mean().item(), BATCH_SIZE_AE)
        losses_sf_consist.update(loss_sf_consist.mean().item(), BATCH_SIZE_AE)
        losses_cls.update(loss_cls.mean().item(), BATCH_SIZE_AE)
        losses_rec_rd.update(loss_rec_rd.item(), BATCH_SIZE_AE)
        losses_rec_org.update(loss_rec_org.item(), BATCH_SIZE_AE)
        if DATA_NAME=='shoe':
            losses_rec_grey.update(loss_rec_grey.item(), BATCH_SIZE_AE)

        if iteration%LOG_INTERVAL==0:
            log_msg = 'Train Stage 1: AE: \nrec_rd: %.4f  rec_org: %.4f  cls: %.4f  style_consist: %.4f  content_consist: %.4f  rec_grey: %.4f'%(losses_rec_rd.avg,
                    losses_rec_org.avg, losses_cls.avg, losses_sf_consist.avg, losses_cf_consist.avg, losses_rec_grey.avg)

            print(log_msg)
            _append_log(log_file_path, log_msg)

            losses_sf_consist.reset()
            losses_cls.reset()
            losses_rec_rd.reset()
            losses_rec_org.reset()
            losses_cf_consist.reset()
            losses_rec_grey.reset()

        if iteration%SAVE_IMAGE_INTERVAL==0:
            vutils.save_image(torch.cat([rgb_img_rd, F.interpolate(skt_org.repeat(1,3,1,1), size=512), gimg_rd]), '%s/rd_%d.jpg'%(saved_image_folder, iteration), normalize=True, range=(-1,1))
            if DATA_NAME != 'shoe':
                # For 'shoe', gimg_perm already exists from the grey-loss branch above.
                with torch.no_grad():
                    perm = true_randperm(BATCH_SIZE_AE)
                    gimg_perm = decoder([c for c in content_feats], [s[perm] for s in style_vector])
            vutils.save_image(torch.cat([rgb_img, F.interpolate(skt_img.repeat(1,3,1,1), size=512), gimg, gimg_perm]), '%s/org_%d.jpg'%(saved_image_folder, iteration), normalize=True, range=(-1,1))

        if iteration%SAVE_MODEL_INTERVAL==0:
            print('Saving history model')
            _save_ae_checkpoint('%s/%d.pth'%(saved_model_folder, iteration),
                                style_encoder, content_encoder, decoder,
                                opt_c, opt_s_cls, opt_s, opt_d)

    _save_ae_checkpoint('%s/%d.pth'%(saved_model_folder, ITERATION_AE),
                        style_encoder, content_encoder, decoder,
                        opt_c, opt_s_cls, opt_s, opt_d)




In [None]:
# List the working directory; the Drive mount (My_drive) should appear here.
!ls

My_drive  sample_data


In [None]:
# -p keeps the cell idempotent under Restart & Run All (no "File exists" error).
!mkdir -p weights

In [None]:
# -p keeps the cell idempotent and also creates the parent if missing.
!mkdir -p weights/v0.0

In [None]:
# PerceptualLoss loads its linear-layer weights from weights/v0.1/vgg.pth.
# -p keeps the cell idempotent and also creates the parent if missing.
!mkdir -p weights/v0.1

In [None]:
# Stage 1: train the autoencoder; images, ae_log.txt and checkpoints go under SAVE_FOLDER.
train_ae()

621


  cpuset_checked))


25 621
Setting up Perceptual loss...


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth


  0%|          | 0.00/528M [00:00<?, ?B/s]

Loading model from: /content/weights/v0.1/vgg.pth
...[net-lin [vgg]] initialized
...Done


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Train Stage 1: AE: 
rec_rd: 6.8075  rec_org: 6.9349  cls: 6.4574  style_consist: -0.2744  content_consist: 17.7025  rec_grey: 0.0000




Saving history model


  4%|▍         | 10/250 [00:28<11:58,  2.99s/it]

Train Stage 1: AE: 
rec_rd: 6.2104  rec_org: 6.1814  cls: 6.3548  style_consist: -1.3483  content_consist: 12.5245  rec_grey: 0.0000


  8%|▊         | 20/250 [01:07<12:46,  3.33s/it]

Train Stage 1: AE: 
rec_rd: 5.5885  rec_org: 5.5994  cls: 5.5251  style_consist: -1.7936  content_consist: 0.5564  rec_grey: 0.0000


 10%|█         | 25/250 [01:31<16:27,  4.39s/it]

Saving history model


 12%|█▏        | 30/250 [01:50<14:53,  4.06s/it]

Train Stage 1: AE: 
rec_rd: 5.2456  rec_org: 5.3962  cls: 3.9595  style_consist: -2.0250  content_consist: 0.0000  rec_grey: 0.0000


 16%|█▌        | 40/250 [02:26<11:37,  3.32s/it]

Train Stage 1: AE: 
rec_rd: 5.0454  rec_org: 5.1442  cls: 2.4294  style_consist: -2.0636  content_consist: 0.0277  rec_grey: 0.0000


 20%|██        | 50/250 [03:09<13:50,  4.15s/it]

Train Stage 1: AE: 
rec_rd: 4.9796  rec_org: 5.2178  cls: 1.4169  style_consist: -2.2404  content_consist: 0.0000  rec_grey: 0.0000
Saving history model


 24%|██▍       | 60/250 [03:46<10:22,  3.28s/it]

Train Stage 1: AE: 
rec_rd: 4.8234  rec_org: 4.9810  cls: 1.0604  style_consist: -2.3526  content_consist: 0.0000  rec_grey: 0.0000


 28%|██▊       | 70/250 [04:29<13:06,  4.37s/it]

Train Stage 1: AE: 
rec_rd: 4.7502  rec_org: 4.8336  cls: 0.8187  style_consist: -2.2954  content_consist: 0.0000  rec_grey: 0.0000


 30%|███       | 75/250 [04:48<11:13,  3.85s/it]

Saving history model


 32%|███▏      | 80/250 [05:04<08:52,  3.13s/it]

Train Stage 1: AE: 
rec_rd: 4.6044  rec_org: 4.8657  cls: 1.2272  style_consist: -2.4151  content_consist: 0.0054  rec_grey: 0.0000


 36%|███▌      | 90/250 [05:30<06:38,  2.49s/it]

Train Stage 1: AE: 
rec_rd: 4.5464  rec_org: 4.6606  cls: 0.8970  style_consist: -2.4180  content_consist: 0.0000  rec_grey: 0.0000


 40%|████      | 100/250 [05:55<06:12,  2.49s/it]

Train Stage 1: AE: 
rec_rd: 4.4639  rec_org: 4.6732  cls: 0.3976  style_consist: -2.4249  content_consist: 0.0000  rec_grey: 0.0000
Saving history model


 44%|████▍     | 110/250 [06:22<05:49,  2.50s/it]

Train Stage 1: AE: 
rec_rd: 4.3206  rec_org: 4.5696  cls: 0.7144  style_consist: -2.2830  content_consist: 0.0000  rec_grey: 0.0000


 48%|████▊     | 120/250 [06:47<05:21,  2.47s/it]

Train Stage 1: AE: 
rec_rd: 4.3936  rec_org: 4.5237  cls: 0.6654  style_consist: -2.4463  content_consist: 0.0290  rec_grey: 0.0000


 50%|█████     | 125/250 [07:01<05:17,  2.54s/it]

Saving history model


 52%|█████▏    | 130/250 [07:14<05:04,  2.54s/it]

Train Stage 1: AE: 
rec_rd: 4.2233  rec_org: 4.4986  cls: 0.6299  style_consist: -2.2808  content_consist: 0.0000  rec_grey: 0.0000


 56%|█████▌    | 140/250 [07:39<04:31,  2.47s/it]

Train Stage 1: AE: 
rec_rd: 4.2685  rec_org: 4.5046  cls: 0.4144  style_consist: -2.3868  content_consist: 0.0655  rec_grey: 0.0000


 60%|██████    | 150/250 [08:05<04:08,  2.48s/it]

Train Stage 1: AE: 
rec_rd: 4.1558  rec_org: 4.4580  cls: 0.4255  style_consist: -2.3294  content_consist: 0.0000  rec_grey: 0.0000
Saving history model


 64%|██████▍   | 160/250 [08:32<03:43,  2.49s/it]

Train Stage 1: AE: 
rec_rd: 4.1014  rec_org: 4.6007  cls: 0.1780  style_consist: -2.3593  content_consist: 0.0000  rec_grey: 0.0000


 68%|██████▊   | 170/250 [08:57<03:17,  2.47s/it]

Train Stage 1: AE: 
rec_rd: 4.1129  rec_org: 4.4137  cls: 0.4441  style_consist: -2.3759  content_consist: 0.0000  rec_grey: 0.0000


 70%|███████   | 175/250 [09:10<03:10,  2.54s/it]

Saving history model


 72%|███████▏  | 180/250 [09:24<02:57,  2.53s/it]

Train Stage 1: AE: 
rec_rd: 4.0884  rec_org: 4.2982  cls: 0.4297  style_consist: -2.3657  content_consist: 0.0000  rec_grey: 0.0000


 76%|███████▌  | 190/250 [09:49<02:28,  2.48s/it]

Train Stage 1: AE: 
rec_rd: 4.0136  rec_org: 4.2589  cls: 0.8495  style_consist: -2.2947  content_consist: 0.0000  rec_grey: 0.0000


 80%|████████  | 200/250 [10:15<02:04,  2.48s/it]

Train Stage 1: AE: 
rec_rd: 4.0515  rec_org: 4.4541  cls: 0.9142  style_consist: -2.3074  content_consist: 0.0583  rec_grey: 0.0000
Saving history model


 84%|████████▍ | 210/250 [10:41<01:39,  2.49s/it]

Train Stage 1: AE: 
rec_rd: 4.0042  rec_org: 4.1851  cls: 0.2444  style_consist: -2.1925  content_consist: 0.0000  rec_grey: 0.0000


 88%|████████▊ | 220/250 [11:07<01:14,  2.47s/it]

Train Stage 1: AE: 
rec_rd: 3.9685  rec_org: 4.4719  cls: 0.2516  style_consist: -2.2200  content_consist: 0.0000  rec_grey: 0.0000


 90%|█████████ | 225/250 [11:20<01:03,  2.53s/it]

Saving history model


 92%|█████████▏| 230/250 [11:33<00:50,  2.53s/it]

Train Stage 1: AE: 
rec_rd: 3.9538  rec_org: 4.1825  cls: 0.1346  style_consist: -2.3445  content_consist: 0.0000  rec_grey: 0.0000


 96%|█████████▌| 240/250 [11:59<00:24,  2.48s/it]

Train Stage 1: AE: 
rec_rd: 3.9313  rec_org: 4.5582  cls: 0.2181  style_consist: -2.3423  content_consist: 0.0000  rec_grey: 0.0000


100%|██████████| 250/250 [12:24<00:00,  2.98s/it]


In [None]:
# Record the Python version used for this run.
!python -V

Python 3.7.12


In [None]:
# Record the PyTorch version used for this run.
print(torch.__version__)

1.9.0+cu111
