In [1]:
import sys
sys.path.append('..')
import torch
import os
import torch.nn as nn
import torchvision
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import random
from models.CycleGAN import *
from datasets.UnalignedDataset import UnalignedDataset
from utils.utils import ImageBuffer, set_requires_grad, tensor_to_image, save_cyclegan_model

# Expose only the GPU with index 2 to this process. CUDA reads this variable when it
# initializes, so this cell has to run before the first CUDA call is made.
# NOTE(review): the index is hard-coded for one machine — adjust before running elsewhere.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [2]:
class ContentEncoder(nn.Module):
    """Three-layer convolutional stem followed by a stack of residual blocks.

    Args:
        num_channels: number of channels of the tensor fed to the first layer.
        num_blocks: how many ``ResBlock`` modules are stacked after the stem.
    """

    def __init__(self, num_channels, num_blocks):
        super(ContentEncoder, self).__init__()

        # Options shared by every stem layer.
        stem_kwargs = dict(leaky=False, norm='instance')
        # Options shared by the two stride-2 layers.
        down_kwargs = dict(kernel_size=4, padding=(1, "zeros"), stride=2, **stem_kwargs)

        self.conv1 = ConvNormRelu(in_channels=num_channels, out_channels=64,
                                  kernel_size=7, padding=(3, "zeros"), **stem_kwargs)
        self.conv2 = ConvNormRelu(in_channels=64, out_channels=128, **down_kwargs)
        self.conv3 = ConvNormRelu(in_channels=128, out_channels=256, **down_kwargs)

        # The residual blocks operate at the channel width produced by conv3.
        res_channels = 256
        self.blocks = nn.ModuleList([
            ResBlock(in_planes=res_channels, kernel_size=3,
                     padding=(1, "reflection"), norm="instance")
            for _ in range(num_blocks)
        ])

    def forward(self, inputs):
        """Run ``inputs`` through conv1, conv2, conv3 and then each residual block in order."""
        features = inputs
        for layer in (self.conv1, self.conv2, self.conv3, *self.blocks):
            features = layer(features)
        return features

In [4]:
class StyleEncoder(nn.Module):
    """Encode an image batch into one style code per sample.

    A 7x7 stem layer is followed by four stride-2 layers whose output width
    doubles from 64 up to a cap of 256. The resulting feature map is averaged
    over its spatial positions and projected to ``style_dims`` channels by a
    1x1 convolution.

    Args:
        num_channels: number of channels of the input tensor.
        style_dims: number of channels of the returned style code.

    NOTE(review): if this is meant to follow MUNIT, the paper leaves instance
    norm out of the style encoder because it removes the feature statistics
    that carry style — confirm that ``norm='instance'`` is intended here.
    """

    def __init__(self, num_channels, style_dims):
        super(StyleEncoder, self).__init__()

        self.conv1 = ConvNormRelu(in_channels=num_channels, out_channels=64,
                                  kernel_size=7, padding=(3, "zeros"), leaky=False, norm='instance')

        self.convs = nn.ModuleList()
        dims = 64
        n_convs = 4

        for _ in range(n_convs):
            # Double the width at every layer, but never go above 256 channels.
            prev_dims, dims = dims, min(dims * 2, 256)
            self.convs.append(ConvNormRelu(in_channels=prev_dims, out_channels=dims,
                                           kernel_size=4, padding=(1, "zeros"), stride=2,
                                           leaky=False, norm='instance'))

        self.conv_fc = nn.Conv2d(dims, style_dims, kernel_size=1, stride=1, padding=0)

    def forward(self, inputs):
        """Return the style code with shape ``(batch, style_dims, 1, 1)``."""
        out = self.conv1(inputs)
        for conv in self.convs:
            out = conv(out)

        # Global average pooling over all spatial positions.
        batch, channels = out.size(0), out.size(1)
        pooled = torch.mean(out.view(batch, channels, -1), dim=2)
        # nn.Conv2d does not accept a 2-D (batch, channels) tensor, so restore
        # singleton spatial dimensions before the 1x1 projection.
        out = self.conv_fc(pooled.view(batch, channels, 1, 1))
        return out

In [None]:
# NVIDIA implementation of adaptive instance normalization (AdaIN).
class AdaptiveInstanceNorm2d(nn.Module):
    """Instance normalization whose affine terms are supplied from outside.

    ``weight`` and ``bias`` start out as ``None`` and must be assigned on the
    module before ``forward`` is called. ``forward`` folds the batch into the
    channel axis, so both need one entry per (sample, channel) pair.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(AdaptiveInstanceNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # Affine parameters are injected at run time by the caller.
        self.weight = None
        self.bias = None
        # Placeholder statistics; forward() only hands tiled copies of them to batch_norm.
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        """Normalize every (sample, channel) plane of ``x``, then apply ``weight`` and ``bias``."""
        assert not (self.weight is None or self.bias is None), "Please assign weight and bias before calling AdaIN!"
        batch, channels = x.shape[:2]
        spatial = x.shape[2:]

        tiled_mean = self.running_mean.repeat(batch)
        tiled_var = self.running_var.repeat(batch)

        # Treat each (sample, channel) pair as its own channel of a single-sample
        # batch, so batch_norm in training mode computes per-plane statistics.
        folded = x.contiguous().view(1, batch * channels, *spatial)
        normalized = F.batch_norm(
            folded, tiled_mean, tiled_var,
            weight=self.weight, bias=self.bias,
            training=True, momentum=self.momentum, eps=self.eps)

        return normalized.view(batch, channels, *spatial)

    def __repr__(self):
        return '{}({})'.format(self.__class__.__name__, self.num_features)

In [5]:
class Decoder(nn.Module):
    """Decode a feature map back to an image-shaped tensor.

    Four residual blocks (created with ``norm="adain"``) are followed by two
    rounds of 2x upsampling plus a 5x5 layer that halves the channel width,
    and a final 7x7 layer whose output is passed through ``tanh``.

    Args:
        in_channels: channel width of the incoming feature map. The upsampling
            path starts from this width as well, instead of assuming 256.
        out_channels: channel count of the produced tensor (defaults to 3).

    NOTE(review): ``forward`` only receives the feature map; presumably the
    AdaIN parameters of the residual blocks are assigned externally before
    the call — confirm against ``ResBlock`` and the training loop.
    """

    def __init__(self, in_channels, out_channels=3):
        super(Decoder, self).__init__()

        n_res_blocks = 4
        self.blocks = nn.ModuleList()
        for _ in range(n_res_blocks):
            self.blocks.append(ResBlock(in_planes=in_channels, kernel_size=3,
                                        padding=(1, "reflection"), norm="adain"))

        n_upsample = 2
        self.upsample_blocks = nn.ModuleList()
        # Start from the width the residual blocks emit so that any
        # ``in_channels`` value yields matching channel counts.
        dims = in_channels
        for _ in range(n_upsample):
            prev_dims, dims = dims, dims // 2
            self.upsample_blocks.append(nn.Upsample(scale_factor=2))
            self.upsample_blocks.append(ConvNormRelu(in_channels=prev_dims, out_channels=dims,
                                                     kernel_size=5, padding=(2, "reflection"),
                                                     stride=1, norm="ln"))

        # NOTE(review): if ConvNormRelu applies a ReLU even with norm=None, the
        # tanh in forward() can only produce values in [0, 1) — confirm.
        self.last_layer = ConvNormRelu(in_channels=dims, out_channels=out_channels, kernel_size=7,
                                       padding=(3, "reflection"), stride=1, norm=None)

    def forward(self, inputs):
        """Apply the residual blocks, the upsampling path and the final layer, then ``tanh``."""
        out = inputs
        for block in self.blocks:
            out = block(out)
        for block in self.upsample_blocks:
            out = block(out)
        # torch.tanh replaces the deprecated F.tanh.
        return torch.tanh(self.last_layer(out))