In [1]:
import sys
sys.path.append('..')
import torch
import os
import torch.nn as nn
import torchvision
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import random
from models.CycleGAN import *
from datasets.UnalignedDataset import UnalignedDataset
from utils.utils import ImageBuffer, set_requires_grad, tensor_to_image, save_cyclegan_model, get_activation, get_norm_module, Identity

# Pin this notebook's process to the GPU listed in CUDA_VISIBLE_DEVICES ("2").
# NOTE(review): the index is machine-specific — adjust or remove on other hosts.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [2]:
# Display the PyTorch version the kernel is running.
torch.__version__

'1.1.0'

In [9]:
# Scratch check of the per-channel mean pattern used by the normalization layers:
# flatten H and W into one axis, reduce over it, then restore two singleton dims
# so the result is (N, C, 1, 1) and broadcasts against the (N, C, H, W) input.
a = torch.rand(1, 3, 256, 256)
torch.mean(a.view(a.size(0), a.size(1), -1), 2).unsqueeze(2).unsqueeze(3)

torch.Size([1, 3, 1, 1])


In [2]:
class AdaLIN(nn.Module):
    """Blend of instance- and layer-normalized activations with external affine terms.

    The layer owns a single learnable parameter ``rho`` of shape
    ``(1, num_features, 1, 1)``, initialised to 0.9, which weights the
    instance-normalized tensor; ``1 - rho`` weights the layer-normalized one.
    The scale (``gamma``) and shift (``beta``) are supplied by the caller on
    every forward pass instead of being stored in the module.
    """

    def __init__(self, num_features, eps=1e-5):
        super(AdaLIN, self).__init__()

        self.num_features = num_features
        self.eps = eps
        # Per-channel mixing weight, starting at 0.9 (mostly instance norm).
        self.rho = nn.Parameter(torch.full((1, num_features, 1, 1), 0.9))

    def forward(self, inputs, gamma, beta):
        """Normalize ``inputs`` and apply the caller-provided affine transform.

        Args:
            inputs: 4-D tensor; statistics are taken over dims (2, 3) for the
                instance branch and over dims (1, 2, 3) for the layer branch.
            gamma: Scale tensor; two trailing singleton dims are appended
                before it multiplies the blended activations.
            beta: Shift tensor, expanded the same way as ``gamma``.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        batch, channels = inputs.size(0), inputs.size(1)

        # Instance statistics: one mean/variance per (sample, channel).
        per_channel = inputs.view(batch, channels, -1)
        in_mean = torch.mean(per_channel, 2).view(batch, channels, 1, 1)
        in_var = torch.var(per_channel, 2).view(batch, channels, 1, 1)
        normed_in = (inputs - in_mean) / torch.sqrt(in_var + self.eps)

        # Layer statistics: one mean/variance per sample.
        per_sample = inputs.view(batch, -1)
        ln_mean = torch.mean(per_sample, 1).view(batch, 1, 1, 1)
        ln_var = torch.var(per_sample, 1).view(batch, 1, 1, 1)
        normed_ln = (inputs - ln_mean) / torch.sqrt(ln_var + self.eps)

        # rho has a leading dim of 1, so it broadcasts across the batch.
        blended = self.rho * normed_in + (1 - self.rho) * normed_ln

        scale = gamma[:, :, None, None]
        shift = beta[:, :, None, None]
        return scale * blended + shift

In [3]:
class LIN(nn.Module):
    """Layer-instance normalization with learnable mixing, scale and shift.

    Three parameters of shape ``(1, num_features, 1, 1)`` are learned:

    * ``rho``   - weight of the instance-normalized branch (initialised to 0,
      so the layer starts as pure layer normalization),
    * ``gamma`` - multiplicative scale (initialised to 1),
    * ``beta``  - additive shift (initialised to 0).
    """

    def __init__(self, num_features, eps=1e-5):
        super(LIN, self).__init__()

        self.num_features = num_features
        self.eps = eps
        self.rho = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.gamma = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, num_features, 1, 1))

    def forward(self, inputs):
        """Normalize a 4-D tensor and apply the learned affine transform.

        Args:
            inputs: 4-D tensor; instance statistics are taken over dims (2, 3)
                and layer statistics over dims (1, 2, 3).

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        batch, channels = inputs.size(0), inputs.size(1)

        # Instance statistics: one mean/variance per (sample, channel).
        flat_in = inputs.view(batch, channels, -1)
        inputs_in_mean = torch.mean(flat_in, 2).unsqueeze(2).unsqueeze(3)
        inputs_in_var = torch.var(flat_in, 2).unsqueeze(2).unsqueeze(3)
        inputs_in = (inputs - inputs_in_mean) / torch.sqrt(inputs_in_var + self.eps)

        # Layer statistics: one mean/variance per sample.
        flat_ln = inputs.view(batch, -1)
        inputs_ln_mean = torch.mean(flat_ln, 1).unsqueeze(1).unsqueeze(2).unsqueeze(3)
        inputs_ln_var = torch.var(flat_ln, 1).unsqueeze(1).unsqueeze(2).unsqueeze(3)
        inputs_ln = (inputs - inputs_ln_mean) / torch.sqrt(inputs_ln_var + self.eps)

        # All three parameters have a leading dim of 1 and broadcast over the batch.
        out = self.rho * inputs_in + (1 - self.rho) * inputs_ln

        # The affine terms are module parameters; the bare names `gamma` / `beta`
        # are not defined in this scope and raised NameError.
        return self.gamma * out + self.beta

In [4]:
class AdaLINResBlock(nn.Module):
    """Residual block: (pad) -> conv -> AdaLIN -> activation -> (pad) -> conv -> AdaLIN.

    Both convolutions keep the channel count at ``in_channels`` and the block
    output is added to its input, so the convolutions must preserve the
    spatial size for the residual sum to be valid.

    Args:
        in_channels: Channel count of the input and of both convolutions.
        kernel_size: Convolution kernel size; ``kernel_size // 2`` pixels of
            padding are applied on every side.
        activation: Forwarded to ``get_activation`` to build the activation
            applied after the first normalization.
        pad_type: ``"reflection"`` to pad with a separate ``nn.ReflectionPad2d``
            layer, or ``"zeros"`` to use the convolution's built-in padding.

    Raises:
        ValueError: If ``pad_type`` is neither ``"reflection"`` nor ``"zeros"``.
    """

    def __init__(self, in_channels, kernel_size, activation, pad_type):
        super(AdaLINResBlock, self).__init__()

        if pad_type == "reflection":
            self.pad1 = nn.ReflectionPad2d(kernel_size // 2)
            self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)
            self.pad2 = nn.ReflectionPad2d(kernel_size // 2)
            self.conv2 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)
        elif pad_type == "zeros":
            self.pad1 = None
            self.pad2 = None
            self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels,
                                   kernel_size=kernel_size, padding=kernel_size//2)
            self.conv2 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels,
                                   kernel_size=kernel_size, padding=kernel_size//2)
        else:
            # Fail at construction time; otherwise the missing layers only
            # surface as an AttributeError on the first forward pass.
            raise ValueError(
                "Unsupported pad_type {!r}; expected 'reflection' or 'zeros'".format(pad_type))

        self.norm1 = AdaLIN(in_channels)
        self.norm2 = AdaLIN(in_channels)
        self.act = get_activation(activation)

    def forward(self, inputs, gamma, beta):
        """Apply the block; the same ``gamma`` / ``beta`` feed both AdaLIN layers.

        Args:
            inputs: Feature map to transform; also used as the skip connection.
            gamma: Scale passed to both AdaLIN layers.
            beta: Shift passed to both AdaLIN layers.

        Returns:
            ``inputs`` plus the transformed features.
        """
        out = inputs
        if self.pad1 is not None:
            out = self.pad1(out)
        out = self.conv1(out)
        out = self.act(self.norm1(out, gamma, beta))
        if self.pad2 is not None:
            out = self.pad2(out)
        out = self.conv2(out)
        # No activation after the second normalization, only the residual sum.
        out = self.norm2(out, gamma, beta)
        return out + inputs
        
    

In [5]:
class UpsampleConvBlock(nn.Module):
    """Decoder stage: upsample x2 -> (pad) -> conv -> norm -> activation.

    The convolution halves the channel count (``in_channels // 2``).

    Args:
        in_channels: Channel count of the incoming feature map.
        kernel_size: Convolution kernel size; ``kernel_size // 2`` pixels of
            padding are applied on every side.
        activation: Forwarded to ``get_activation`` to build the activation.
        norm_type: Forwarded to ``get_norm_module``; the returned factory is
            called with the output channel count.
        pad_type: ``"reflection"`` to pad with a separate ``nn.ReflectionPad2d``
            layer, or ``"zeros"`` to use the convolution's built-in padding.

    Raises:
        ValueError: If ``pad_type`` is neither ``"reflection"`` nor ``"zeros"``.
    """

    def __init__(self, in_channels, kernel_size, activation, norm_type, pad_type):
        super(UpsampleConvBlock, self).__init__()

        if pad_type == "reflection":
            self.pad = nn.ReflectionPad2d(kernel_size // 2)
            self.conv = nn.Conv2d(in_channels=in_channels, out_channels=in_channels//2,
                              kernel_size=kernel_size)
        elif pad_type == "zeros":
            self.pad = None
            self.conv = nn.Conv2d(in_channels=in_channels, out_channels=in_channels//2,
                              kernel_size=kernel_size, padding=kernel_size//2)
        else:
            # Fail at construction time; otherwise the missing layers only
            # surface as an AttributeError on the first forward pass.
            raise ValueError(
                "Unsupported pad_type {!r}; expected 'reflection' or 'zeros'".format(pad_type))

        self.act = get_activation(activation)
        self.upsample = nn.Upsample(scale_factor=2)
        self.norm = get_norm_module(norm_type)(in_channels//2)

    def forward(self, inputs):
        """Upsample ``inputs`` by a factor of two and halve its channels."""
        out = self.upsample(inputs)
        if self.pad is not None:
            out = self.pad(out)
        out = self.conv(out)
        out = self.norm(out)
        out = self.act(out)
        return out

In [6]:
class Generator(nn.Module):
    """Image generator: conv encoder, CAM-style attention, AdaLIN decoder.

    Pipeline built in ``__init__`` and executed in ``forward``:

    1. Encoder: a 7x7 ``ConvNormRelu``, ``num_enc_blocks - 1`` stride-2
       ``ConvNormRelu`` blocks (channels double up to a cap of 256), then
       ``num_enc_res_blocks`` ``ResBlock`` layers.
    2. Attention: the weights of two bias-free linear classifiers (fed by
       global average / max pooling) re-weight the encoder channels; the two
       re-weighted maps are concatenated and fused by a 1x1 convolution.
    3. An MLP on the flattened attention map produces ``gamma`` / ``beta``
       for the ``AdaLINResBlock`` decoder blocks.
    4. ``num_dec_upsample_blocks`` ``UpsampleConvBlock`` stages, then a
       reflection-padded 7x7 convolution to 3 channels followed by ``tanh``.

    Args:
        in_channels: Channel count of the input image.
        img_size: Side length of the (square) input. The MLP input size is
            derived from it as ``dims * (img_size // 2 ** (num_enc_blocks - 1)) ** 2``,
            so inputs must match this size.
            NOTE(review): assumes each stride-2 encoder block exactly halves
            the spatial size - confirm against ``ConvNormRelu``.
        num_enc_blocks: Total encoder conv blocks, including the first 7x7 one.
        num_enc_res_blocks: Number of encoder residual blocks.
        num_dec_upsample_blocks: Number of x2 upsampling stages in the decoder.
        num_dec_res_blocks: Number of AdaLIN residual blocks in the decoder.
        norm_type: Normalization name forwarded to the encoder blocks.
        pad_type: Padding name forwarded to the encoder blocks (the decoder
            always uses reflection padding).
    """

    def __init__(self, in_channels, img_size, num_enc_blocks, num_enc_res_blocks, num_dec_upsample_blocks,
                 num_dec_res_blocks, norm_type, pad_type):
        super(Generator, self).__init__()

        dims = 64
        self.conv1 = ConvNormRelu(in_channels=in_channels, out_channels=dims, kernel_size=7, padding=(3, pad_type),
                                  norm=norm_type, leaky=False)

        # Downsampling encoder blocks.
        self.convs = nn.ModuleList()
        for _ in range(num_enc_blocks - 1):
            prev_dims = dims
            dims = min(dims * 2, 256)
            self.convs.append(ConvNormRelu(in_channels=prev_dims, out_channels=dims, kernel_size=3,
                                           padding=(1, pad_type), norm=norm_type, stride=2))

        self.res_blocks = nn.ModuleList()
        for _ in range(num_enc_res_blocks):
            self.res_blocks.append(ResBlock(in_planes=dims, kernel_size=3, padding=(1, pad_type), norm=norm_type))

        # Attention classifiers and the 1x1 conv that fuses both attention maps.
        self.gap_fc = nn.Linear(dims, 1, bias=False)
        self.gmp_fc = nn.Linear(dims, 1, bias=False)
        self.conv2 = nn.Conv2d(in_channels=dims * 2, out_channels=dims, kernel_size=1, stride=1)

        # MLP that maps the flattened attention map to the AdaLIN parameters.
        MLP = [nn.Linear(in_features=dims * (img_size // 2 ** (num_enc_blocks - 1)) ** 2, out_features=dims, bias=False),
               nn.ReLU(True),
               nn.Linear(in_features=dims, out_features=dims, bias=False),
               nn.ReLU(True)]

        self.mlp = nn.Sequential(*MLP)
        self.gamma = nn.Linear(in_features=dims, out_features=dims, bias=False)
        self.beta = nn.Linear(in_features=dims, out_features=dims, bias=False)

        # Decoder
        self.decoder_res_blocks = nn.ModuleList()
        for _ in range(num_dec_res_blocks):
            self.decoder_res_blocks.append(AdaLINResBlock(in_channels=dims, kernel_size=3,
                                                      activation='relu', pad_type="reflection"))

        self.upsample_blocks = nn.ModuleList()
        for _ in range(num_dec_upsample_blocks):
            self.upsample_blocks.append(UpsampleConvBlock(in_channels=dims, kernel_size=3, activation="relu",
                                                         norm_type="lin", pad_type="reflection"))
            # Each upsample block halves the channel count.
            dims = dims // 2

        self.pad_last_conv = nn.ReflectionPad2d(3)
        self.last_conv = nn.Conv2d(in_channels=dims, out_channels=3, kernel_size=7, stride=1, bias=False)

    def forward(self, inputs, return_cam_logits=False):
        """Translate a batch of images.

        Args:
            inputs: Image batch of shape ``(N, in_channels, img_size, img_size)``.
            return_cam_logits: When true, also return the attention classifier
                logits, which are otherwise discarded.

        Returns:
            The generated image batch (3 channels, values in ``[-1, 1]``), or a
            tuple ``(image, logits)`` when ``return_cam_logits`` is true, where
            ``logits`` has shape ``(N, 2)`` ordered ``[max-pool, avg-pool]``.
        """
        # Extracting features
        features = self.conv1(inputs)
        for conv in self.convs:
            features = conv(features)
        for res_block in self.res_blocks:
            features = res_block(features)

        batch = features.shape[0]

        # Global average pooling branch: the classifier weight re-weights channels.
        gap_weight = self.gap_fc.weight
        gap = gap_weight.view(gap_weight.size(0), gap_weight.size(1), 1, 1) * features

        # Global max pooling branch.
        gmp_weight = self.gmp_fc.weight
        gmp = gmp_weight.view(gmp_weight.size(0), gmp_weight.size(1), 1, 1) * features

        cam = F.relu(self.conv2(torch.cat([gmp, gap], 1)))

        # Calculating beta and gamma for AdaLIN
        style = self.mlp(cam.view(cam.size(0), -1))
        gamma = self.gamma(style)
        beta = self.beta(style)

        out = cam
        for res_block in self.decoder_res_blocks:
            out = res_block(out, gamma, beta)

        for upsample_block in self.upsample_blocks:
            out = upsample_block(out)

        out = self.pad_last_conv(out)
        # torch.tanh replaces the deprecated torch.nn.functional.tanh.
        out = torch.tanh(self.last_conv(out))

        if return_cam_logits:
            gap_logits = self.gap_fc(F.adaptive_avg_pool2d(features, 1).view(batch, -1))
            gmp_logits = self.gmp_fc(F.adaptive_max_pool2d(features, 1).view(batch, -1))
            return out, torch.cat([gmp_logits, gap_logits], 1)
        return out

In [7]:
# Smoke test: build a Generator and check the output shape for one 256x256 image.
# Fall back to the CPU so the cell also runs on machines without a usable GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Generator(3, 256, 3, 4, 2, 4, "instance", "reflection").to(device)
inputs = torch.randn(1, 3, 256, 256).to(device)
print(model(inputs).shape)

torch.Size([1, 3, 256, 256])




In [8]:
# Display the PyTorch version the kernel is running.
# NOTE(review): the saved output of this cell ('0.4.1.post2') differs from the
# earlier version cell ('1.1.0'), so the saved outputs appear to come from more
# than one environment — restart the kernel and run all cells to confirm.
torch.__version__

'0.4.1.post2'

In [8]:
class PatchGan(nn.Module):
    """Patch discriminator with a CAM-style attention head.

    Layout: a stride-2 ``ConvNormRelu``, ``num_downsample`` further stride-2
    ``ConvNormRelu`` blocks (channels double each time), one stride-1
    ``ConvNormRelu``, the attention re-weighting fused by a 1x1 convolution,
    and a final reflection-padded 4x4 convolution to a single channel.

    Args:
        in_channels: Channel count of the input image.
        num_downsample: Number of stride-2 blocks after the first one.
        norm_type: Accepted for interface compatibility; every block in this
            class is built with ``norm="sn"``.
        pad_type: Padding name forwarded to the ``ConvNormRelu`` blocks.
    """

    def __init__(self, in_channels, num_downsample, norm_type, pad_type):
        super(PatchGan, self).__init__()

        dims = 64
        self.conv1 = ConvNormRelu(in_channels=in_channels, out_channels=dims, kernel_size=4,
                                  padding=(1, pad_type), stride=2, leaky=True, norm="sn")

        self.convs = nn.ModuleList()
        for _ in range(num_downsample):
            prev_dims = dims
            dims = dims * 2
            self.convs.append(ConvNormRelu(in_channels=prev_dims, out_channels=dims, kernel_size=4,
                                  padding=(1, pad_type), stride=2, leaky=True, norm="sn"))

        prev_dims = dims
        dims = dims * 2

        self.conv2 = ConvNormRelu(in_channels=prev_dims, out_channels=dims, kernel_size=4,
                                  padding=(1, pad_type), stride=1, leaky=True, norm="sn")

        # Attention classifiers and the 1x1 conv that fuses both attention maps.
        self.gap_fc = nn.Linear(dims, 1)
        self.gmp_fc = nn.Linear(dims, 1)
        self.conv3 = nn.Conv2d(in_channels=dims * 2, out_channels=dims, kernel_size=1, stride=1, padding=0)

        self.pad_last_conv = nn.ReflectionPad2d(1)
        self.last_conv = nn.Conv2d(in_channels=dims, out_channels=1, kernel_size=4, stride=1)
        self.norm_last_conv = get_norm_module("sn")(1)

    def forward(self, inputs):
        """Score an image batch.

        Args:
            inputs: Image batch with ``in_channels`` channels.

        Returns:
            Tuple ``(out, gap_gmp_logits)``: ``out`` is the single-channel
            patch score map and ``gap_gmp_logits`` has shape ``(N, 2)`` holding
            the average-pool and max-pool classifier logits, in that order.
        """
        out = self.conv1(inputs)
        for conv in self.convs:
            out = conv(out)
        out = self.conv2(out)

        # Average-pool branch: logit from pooled features, then the classifier
        # weight re-weights the feature channels.
        gap_pooled = F.adaptive_avg_pool2d(out, 1)
        gap_logits = self.gap_fc(gap_pooled.view(gap_pooled.size(0), -1))
        gap_fc_weights = self.gap_fc.weight
        gap = out * gap_fc_weights.view(gap_fc_weights.size(0), gap_fc_weights.size(1), 1, 1)

        # Max-pool branch, same construction.
        gmp_pooled = F.adaptive_max_pool2d(out, 1)
        gmp_logits = self.gmp_fc(gmp_pooled.view(gmp_pooled.size(0), -1))
        gmp_fc_weights = self.gmp_fc.weight
        gmp = out * gmp_fc_weights.view(gmp_fc_weights.size(0), gmp_fc_weights.size(1), 1, 1)

        # The logits come from the classifier outputs, not from the re-weighted
        # feature maps (the earlier code concatenated `gap` and `gmp` here).
        gap_gmp_logits = torch.cat([gap_logits, gmp_logits], 1)
        gap_gmp = torch.cat([gap, gmp], 1)

        out = F.leaky_relu(self.conv3(gap_gmp), negative_slope=0.2)

        out = self.pad_last_conv(out)
        out = self.last_conv(out)
        out = self.norm_last_conv(out)

        return out, gap_gmp_logits