In [1]:
import sys
sys.path.append('..')
import torch
import os
import torch.nn as nn
import torchvision
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import random
from models.CycleGAN import *
from datasets.UnalignedDataset import UnalignedDataset
from utils.utils import ImageBuffer, set_requires_grad, tensor_to_image, save_cyclegan_model, get_activation, get_norm_module

os.environ.update({"CUDA_VISIBLE_DEVICES": "2"})  # pin this notebook to GPU index 2 (must precede any CUDA init)

In [2]:
class ContentEncoder(nn.Module):
    """Encode an image into a spatial content feature map.

    A 7x7 stem (64 channels) is followed by two stride-2 4x4 convolutions and
    ``num_blocks`` residual blocks that operate at ``content_dim`` channels.

    Args:
        num_channels: number of input image channels.
        num_blocks: number of residual blocks after downsampling.
        content_dim: channel width of the content code (output channels).
    """

    def __init__(self, num_channels, num_blocks, content_dim):
        super(ContentEncoder, self).__init__()

        self.conv1 = ConvNormRelu(in_channels=num_channels, out_channels=64,
                                  kernel_size=7, padding=(3, "zeros"), leaky=False, norm='instance')
        self.conv2 = ConvNormRelu(in_channels=64, out_channels=128,
                                  kernel_size=4, padding=(1, "zeros"), stride=2, leaky=False, norm='instance')
        # Emit exactly `content_dim` channels so the residual blocks below (built with
        # in_planes=content_dim) always match; previously this was hard-coded to 256.
        self.conv3 = ConvNormRelu(in_channels=128, out_channels=content_dim,
                                  kernel_size=4, padding=(1, "zeros"), stride=2, leaky=False, norm='instance')
        self.blocks = nn.ModuleList()
        for _ in range(num_blocks):
            self.blocks.append(ResBlock(in_planes=content_dim, kernel_size=3, padding=(1, "reflection"), norm="instance"))

    def forward(self, inputs):
        """Return the content code for ``inputs`` (presumably (B, num_channels, H, W) — TODO confirm)."""
        out = self.conv1(inputs)
        out = self.conv2(out)
        out = self.conv3(out)
        for block in self.blocks:
            out = block(out)
        return out

In [3]:
class StyleEncoder(nn.Module):
    """Encode an image into a global style vector.

    A 7x7 stem (64 channels) is followed by 4 stride-2 4x4 convolutions whose
    width doubles from 64 and is capped at 256. The result is then globally
    average-pooled and projected to ``style_dims`` by a 1x1 convolution, which
    acts as a fully-connected layer.

    Args:
        num_channels: number of input image channels.
        style_dims: size of the style code.

    Returns (forward):
        Tensor of shape ``(batch, style_dims)``.
    """

    def __init__(self, num_channels, style_dims):
        super(StyleEncoder, self).__init__()

        self.conv1 = ConvNormRelu(in_channels=num_channels, out_channels=64,
                                  kernel_size=7, padding=(3, "zeros"), leaky=False, norm='instance')

        self.convs = nn.ModuleList()
        dims = 64
        n_convs = 4

        for _ in range(n_convs):
            prev_dims = dims
            dims = min(dims * 2, 256)
            self.convs.append(ConvNormRelu(in_channels=prev_dims, out_channels=dims,
                                           kernel_size=4, padding=(1, "zeros"), stride=2, leaky=False, norm='instance'))

        self.conv_fc = nn.Conv2d(dims, style_dims, kernel_size=1, stride=1, padding=0)

    def forward(self, inputs):
        out = self.conv1(inputs)
        for conv in self.convs:
            out = conv(out)

        # Global average pooling. keepdim=True preserves the (B, C, 1, 1) layout,
        # which nn.Conv2d requires (it rejects a 2-D (B, C) tensor).
        out = out.mean(dim=(2, 3), keepdim=True)
        out = self.conv_fc(out)
        # Drop the singleton spatial dims: (B, style_dims, 1, 1) -> (B, style_dims).
        return out.view(out.size(0), -1)

In [4]:
#Nvidia implementation
class AdaptiveInstanceNorm2d(nn.Module):
    """Instance normalisation with externally injected affine parameters (AdaIN).

    Adapted from NVIDIA's MUNIT implementation. ``weight`` and ``bias`` start as
    ``None``. They must be assigned before ``forward`` is called, and each must
    hold ``batch * num_features`` values (one per sample/channel pair), because
    the batch is folded into the channel axis.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(AdaptiveInstanceNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # Affine parameters are set from outside before every forward pass.
        self.weight = None
        self.bias = None
        # Placeholder buffers: F.batch_norm needs running stats, but only
        # repeated copies are handed to it, so these are never updated.
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!"
        batch, channels = x.shape[:2]
        spatial = x.shape[2:]

        # With a batch of one and B*C channels, batch-norm statistics become
        # per-instance, per-channel statistics, i.e. instance normalisation.
        folded = x.contiguous().view(1, batch * channels, *spatial)
        normed = F.batch_norm(
            folded,
            self.running_mean.repeat(batch),
            self.running_var.repeat(batch),
            weight=self.weight,
            bias=self.bias,
            training=True,
            momentum=self.momentum,
            eps=self.eps,
        )
        return normed.view(batch, channels, *spatial)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.num_features})"

In [5]:
class Decoder(nn.Module):
    """Decode a content feature map into an image.

    Structure:
      * 4 residual blocks with ``norm="adain"`` (their AdaIN parameters must
        be assigned externally before ``forward``);
      * 2 rounds of x2 nearest upsampling + 5x5 conv, each halving the
        channel count starting from ``in_channels``;
      * a final 7x7 conv to ``out_channels``, followed by ``tanh``.

    Args:
        in_channels: channels of the incoming content code.
        out_channels: channels of the produced image (default 3, the previous
            hard-coded value).
    """

    def __init__(self, in_channels, out_channels=3):
        super(Decoder, self).__init__()

        self.blocks = nn.ModuleList()
        n_blocks = 4
        for _ in range(n_blocks):
            self.blocks.append(ResBlock(in_planes=in_channels, kernel_size=3,
                                        padding=(1, "reflection"), norm="adain"))
        n_blocks = 2
        self.upsample_blocks = nn.ModuleList()
        # Start from the actual input width (was hard-coded to 256, which broke
        # any in_channels != 256).
        dims = in_channels
        for _ in range(n_blocks):
            prev_dims = dims
            dims = dims // 2
            self.upsample_blocks.append(nn.Upsample(scale_factor=2))
            self.upsample_blocks.append(ConvNormRelu(in_channels=prev_dims, out_channels=dims,
                                                     kernel_size=5, padding=(2, "reflection"), stride=1, norm="ln"))

        # NOTE(review): if ConvNormRelu always applies a ReLU, the tanh below only
        # ever sees non-negative values and outputs lie in [0, 1) instead of
        # (-1, 1). Confirm ConvNormRelu's activation behaviour.
        self.last_layer = ConvNormRelu(in_channels=dims, out_channels=out_channels, kernel_size=7,
                                       padding=(3, "reflection"), stride=1, norm=None)

    def forward(self, inputs):
        out = inputs
        for block in self.blocks:
            out = block(out)
        for block in self.upsample_blocks:
            out = block(out)
        # torch.tanh replaces the deprecated F.tanh; the result is identical.
        return torch.tanh(self.last_layer(out))

In [6]:
class LinearNormAct(nn.Module):
    """Fully-connected layer followed by optional normalisation and activation.

    Args:
        in_channels: input feature size.
        out_channels: output feature size.
        norm: name resolved via ``get_norm_module`` and instantiated with
            ``out_channels``; ``"none"`` or ``None`` disables normalisation.
        activation: name resolved via ``get_activation``; ``"none"`` or
            ``None`` disables the activation.
    """

    def __init__(self, in_channels, out_channels, norm='batch', activation="relu"):
        super(LinearNormAct, self).__init__()

        self.fc = nn.Linear(in_channels, out_channels)
        # "none" is resolved here so this layer does not depend on how the
        # helper functions treat that name.
        norm_layer = None if norm in (None, "none") else get_norm_module(norm)
        self.norm = norm_layer(out_channels) if norm_layer is not None else None
        # Previously referenced the undefined names `leaky` and `name` (NameError).
        self.activation = None if activation in (None, "none") else get_activation(activation)

    def forward(self, inputs):
        out = self.fc(inputs)
        if self.norm is not None:
            out = self.norm(out)
        # Explicit None check: a module's truthiness is not a reliable "is set" test.
        if self.activation is not None:
            out = self.activation(out)
        return out

In [7]:
class MLP(nn.Module):
    """Stack of ``LinearNormAct`` layers mapping a style code to a flat vector.

    Layout:
      * ``fc1``: in -> hidden;
      * ``num_blocks - 2`` hidden layers: hidden -> hidden;
      * ``last_fc``: hidden -> out, with no activation.

    No normalisation is used. Hidden layers use LinearNormAct's default
    activation. With ``num_blocks < 2`` the stack still has its two end layers.
    """

    def __init__(self, in_channels, out_channels, hidden_dim, num_blocks):
        # nn.Module.__init__ must run before any sub-module is assigned as an
        # attribute; without it the first assignment raises AttributeError.
        super(MLP, self).__init__()

        self.fc1 = LinearNormAct(in_channels=in_channels, out_channels=hidden_dim, norm="none")

        self.blocks = nn.ModuleList()
        for _ in range(num_blocks - 2):
            self.blocks.append(LinearNormAct(in_channels=hidden_dim, out_channels=hidden_dim, norm="none"))

        self.last_fc = LinearNormAct(in_channels=hidden_dim, out_channels=out_channels, norm="none", activation="none")

    def forward(self, inputs):
        # Accept (B, D) as well as e.g. (B, D, 1, 1) codes; nn.Linear needs (B, D).
        out = self.fc1(inputs.flatten(1))
        for block in self.blocks:
            out = block(out)
        return self.last_fc(out)
    

In [9]:
class MUnitAutoencoder(nn.Module):
    """MUNIT-style auto-encoder: content encoder, style encoder, AdaIN decoder and style MLP.

    The MLP maps a style code to the AdaIN parameters of the decoder. Those
    parameters are written into the decoder before each decode.

    Args:
        in_channels: image channels.
        mlp_hidden_dim: hidden width of the style MLP.
        mlp_num_blocks: number of layers in the style MLP.
        enc_style_dims: size of the style code.
        enc_cont_num_blocks: residual blocks in the content encoder.
        enc_cont_dim: channel width of the content code (default 256).
    """

    def __init__(self, in_channels, mlp_hidden_dim, mlp_num_blocks, enc_style_dims, enc_cont_num_blocks, enc_cont_dim=256):
        super(MUnitAutoencoder, self).__init__()

        self.enc_cont = ContentEncoder(num_channels=in_channels, num_blocks=enc_cont_num_blocks, content_dim=enc_cont_dim)
        self.enc_style = StyleEncoder(num_channels=in_channels, style_dims=enc_style_dims)

        # The decoder must exist first: the MLP's output size is the decoder's AdaIN
        # parameter count. (Previously the bare, undefined name `get_num_adain_params`
        # was passed instead of calling the method.)
        # NOTE(review): the count is 0 if the ResBlocks' AdaIN layers are not instances
        # of a class literally named "AdaptiveInstanceNorm2d" — confirm in models.CycleGAN.
        self.decoder = Decoder(in_channels=enc_cont_dim)
        self.mlp = MLP(in_channels=enc_style_dims, out_channels=self.get_num_adain_params(self.decoder),
                       hidden_dim=mlp_hidden_dim, num_blocks=mlp_num_blocks)

    def encode(self, inputs):
        """Return ``(content_code, style_code)`` for ``inputs``."""
        enc_cont = self.enc_cont(inputs)
        enc_style = self.enc_style(inputs)
        return enc_cont, enc_style

    def decode(self, enc_cont, enc_style):
        """Decode ``enc_cont`` using AdaIN parameters predicted from ``enc_style``."""
        features = self.mlp(enc_style)
        # Must be called as a method; the bare name is undefined at module scope.
        self.assign_adain_params(features, self.decoder)
        return self.decoder(enc_cont)

    def forward(self, inputs):
        """Reconstruct ``inputs`` from their own content and style codes."""
        enc_cont, enc_style = self.encode(inputs)
        rec_inputs = self.decode(enc_cont, enc_style)
        return rec_inputs

    # Nvidia
    def assign_adain_params(self, adain_params, model):
        """Distribute ``adain_params`` over every AdaptiveInstanceNorm2d in ``model``.

        Modules are visited in ``model.modules()`` order. Each one consumes
        ``2 * num_features`` columns: the first ``num_features`` become its
        ``bias`` and the next ``num_features`` its ``weight``, each flattened
        with ``view(-1)``.
        """
        for m in model.modules():
            if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
                mean = adain_params[:, :m.num_features]
                std = adain_params[:, m.num_features:2*m.num_features]
                m.bias = mean.contiguous().view(-1)
                m.weight = std.contiguous().view(-1)
                if adain_params.size(1) > 2*m.num_features:
                    adain_params = adain_params[:, 2*m.num_features:]

    # Nvidia
    def get_num_adain_params(self, model):
        """Return the number of AdaIN parameters ``model`` needs (2 * num_features per layer)."""
        num_adain_params = 0
        for m in model.modules():
            if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
                num_adain_params += 2*m.num_features
        return num_adain_params

In [None]:
class MSDiscriminator(nn.Module):
    """Multi-scale least-squares GAN discriminator.

    ``num_scales`` sub-discriminators of identical architecture are applied
    to the input at successively halved resolutions. Downsampling uses 3x3
    average pooling with stride 2 and excludes padding from the average.

    Args:
        in_channels: channels of the images being judged.
        num_scales: number of resolutions / sub-discriminators.
    """

    def __init__(self, in_channels, num_scales):
        super(MSDiscriminator, self).__init__()

        self.discrs = nn.ModuleList()
        self.in_channels = in_channels
        self.num_scales = num_scales
        for _ in range(self.num_scales):
            self.discrs.append(self.create_discr(self.in_channels))

        self.downsample = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False)

    def create_discr(self, in_channels):
        """Build one fully-convolutional sub-discriminator.

        It has four stride-2 4x4 convs (64 -> 128 -> 256 -> 512 channels, all
        with leaky activation, the first without normalisation), followed by a
        1x1 conv that yields a 1-channel score map.
        """
        dims = 64
        n_blocks = 3

        # Local list (previously appended to a nonexistent `self.model`).
        layers = [ConvNormRelu(in_channels=in_channels, out_channels=dims,
                               kernel_size=4, padding=(1, "zeros"), stride=2, leaky=True, norm='none')]

        for _ in range(n_blocks):
            prev_dims = dims
            dims = dims * 2
            layers.append(ConvNormRelu(in_channels=prev_dims, out_channels=dims,
                                       kernel_size=4, padding=(1, "zeros"), stride=2, leaky=True, norm='instance'))

        layers.append(nn.Conv2d(dims, out_channels=1, kernel_size=1, padding=0))
        return nn.Sequential(*layers)

    def forward(self, inputs):
        """Return a list with one score map per scale, finest first."""
        outputs = []
        for discr in self.discrs:
            outputs.append(discr(inputs))
            inputs = self.downsample(inputs)
        return outputs

    @staticmethod
    def _mse_to_target(output, target):
        # Least-squares GAN term: mean squared distance to a constant label.
        return F.mse_loss(output, torch.full_like(output, target))

    def discr_loss(self, real, fake):
        """LSGAN discriminator loss summed over scales: push D(real) -> 1 and D(fake) -> 0.

        The result is the average of the real and fake terms. ``fake`` is not
        detached here; detach it in the caller if generator gradients must not
        flow through this loss.
        """
        outputs_real = self.forward(real)
        outputs_fake = self.forward(fake)
        # Targets were previously inverted (fake -> 1, real -> 0).
        loss_real = sum(self._mse_to_target(out, 1.0) for out in outputs_real)
        loss_fake = sum(self._mse_to_target(out, 0.0) for out in outputs_fake)
        return (loss_fake + loss_real) * 0.5

    def gen_loss(self, fake):
        """LSGAN generator loss summed over scales: push D(fake) -> 1."""
        outputs_fake = self.forward(fake)
        return sum(self._mse_to_target(out, 1.0) for out in outputs_fake)

In [6]:
a = torch.randn((1, 2, 3))  # scratch tensor used to sanity-check shape handling

In [11]:
print(torch.randn(*a.shape).shape)  # a same-shaped random tensor has the same torch.Size

torch.Size([1, 2, 3])


In [None]:
        cont_A, style_A = G1.encode(inputs)
        rec_img_A = G1.decode(cont_A, style_A)
        
        fake_style = torch.randn(*style.shape)
        
        cont_B, fake_style_B = G2.decode(cont_A, fake_style)
        
        
        img_rec_loss = F.l1_loss(rec_image, inputs)
        
        cont_loss = F.l1_loss(cont_B, cont_A)
        
        style_loss = F.l1_loss(fake_style_B, fake_style)
        
        adv_loss = 