In [1]:
import torch
# Report whether a CUDA device is usable. The device name is only queried
# when CUDA is available, because torch.cuda.get_device_name raises on
# CPU-only machines and would abort the cell.
cuda_available = torch.cuda.is_available()
print(cuda_available)
if cuda_available:
    print(torch.cuda.get_device_name(0))  # 0 corresponds to the first GPU

True
NVIDIA GeForce RTX 3050 Ti Laptop GPU


In [2]:
%matplotlib inline
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import os
from zipfile import ZipFile
from PIL import Image

# Seed PyTorch's global RNG so that torch.rand(...) calls and the random
# initialisation of newly created layers are repeatable across runs.
# The call returns the Generator object, which is what this cell displays.
torch.manual_seed(1)

<torch._C.Generator at 0x775b743b7990>

In [5]:
'''
Encoder is a pretrained VGG up to relu4_1 as in the original paper (see 6.1 paper)
'''
class VGG_Encoder(torch.nn.Module):
    """Feature extractor built from the first 21 layers of a pretrained VGG-19.

    The layers are grouped into four consecutive stages (`relu1_1`,
    `relu2_1`, `relu3_1`, `relu4_1`); each stage consumes the output of the
    previous one, and `forward` returns the output of every stage so that
    intermediate activations are available to the caller.
    """

    def __init__(self):
        super(VGG_Encoder, self).__init__()
        vgg19 = torchvision.models.vgg19(pretrained=True)

        # Keep only the first 21 feature layers and switch each of them to
        # evaluation mode.
        layers = list(vgg19.features.children())[:21]
        for layer in layers:
            layer.eval()

        # Splitting the network so we can get output of different layers
        # TODO: ADD REFLECTION PADDING LAYERS
        self.relu1_1 = torch.nn.Sequential(*layers[:2])
        self.relu2_1 = torch.nn.Sequential(*layers[2:7])
        self.relu3_1 = torch.nn.Sequential(*layers[7:12])
        self.relu4_1 = torch.nn.Sequential(*layers[12:21])

    def forward(self, x):
        """Run `x` through the four stages in order.

        Returns:
            A 4-tuple with the outputs of `relu1_1`, `relu2_1`, `relu3_1`
            and `relu4_1`, in that order.
        """
        stage_outputs = []
        for stage in (self.relu1_1, self.relu2_1, self.relu3_1, self.relu4_1):
            x = stage(x)
            stage_outputs.append(x)
        return tuple(stage_outputs)

def mean_and_std(x, eps=0.00005):
    """Per-sample, per-channel mean and standard deviation of a feature map.

    Dimensions 0 and 1 of `x` are kept; every remaining dimension is
    flattened and reduced over.

    Args:
        x: tensor with at least two dimensions; dims 0 and 1 are preserved.
        eps: small constant added to the variance before the square root so
            that the returned std is strictly positive (safe to divide by,
            finite gradient) even when a channel is constant.

    Returns:
        (mean, std), each of shape (x.shape[0], x.shape[1], 1, 1).

    Note:
        The variance uses torch's default (unbiased) estimator, so a channel
        with a single element yields NaN.
    """
    n, c = x.shape[0], x.shape[1]
    # reshape (not view) so non-contiguous inputs, e.g. transposed tensors, work.
    flat = x.reshape(n, c, -1)
    # The mean is returned unbiased; eps belongs to the variance only.
    mean = flat.mean(dim=2)
    std = (flat.var(dim=2) + eps).sqrt()
    return mean.view(n, c, 1, 1), std.view(n, c, 1, 1)

In [9]:
''' decoder is just the second part of an Unet'''
class Decoder(torch.nn.Module):
    """Convolutional decoder mapping a 512-channel feature map to 3 channels.

    The network is a flat stack of reflection-padded 3x3 convolutions with
    ReLU activations, interleaved with three nearest-neighbour x2 upsampling
    layers, so the spatial size grows by a factor of 8. The final convolution
    has no activation.
    """

    def __init__(self):
        super(Decoder, self).__init__()

        def padded_conv(in_channels, out_channels):
            # Reflection padding of 1 on every side keeps the spatial size
            # unchanged under the unpadded 3x3 convolution.
            return [
                torch.nn.ReflectionPad2d((1, 1, 1, 1)),
                torch.nn.Conv2d(in_channels, out_channels, (3, 3)),
            ]

        # (in, out) pairs are conv+ReLU blocks; None marks an upsampling step.
        plan = [
            (512, 256),
            None,
            (256, 256),
            (256, 256),
            (256, 256),
            (256, 128),
            None,
            (128, 128),
            (128, 64),
            None,
            (64, 64),
        ]

        layers = []
        for step in plan:
            if step is None:
                layers.append(torch.nn.Upsample(scale_factor=2, mode='nearest'))
            else:
                layers.extend(padded_conv(*step))
                layers.append(torch.nn.ReLU())
        # Output projection to 3 channels, deliberately without a ReLU.
        layers.extend(padded_conv(64, 3))

        self.decode = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Apply the decoder stack to `x` and return the result."""
        return self.decode(x)
"""
decode = Decoder()
img = decode(t)
concat_img((img[:12]).detach().cpu())
"""

'\ndecode = Decoder()\nimg = decode(t)\nconcat_img((img[:12]).detach().cpu())\n'

In [19]:
## Smoke test: instantiate the encoder/decoder and check that AdaIN returns a
## tensor with the same shape as its first input.
## TODO: try adain before skip connections or without to see if it makes a difference
## NOTE: `AdaIN` is defined in the cell labelled In [3], which appears further
## down in this file; that cell has to be executed before this one.
encoder = VGG_Encoder()
decoder = Decoder()
## Two random 4-D tensors of shape (12, 16, 26, 26), used as stand-ins for the
## first (normalised) and second (statistics source) AdaIN inputs.
random_tensor = torch.rand((12, 16, 26, 26))
style_image  = torch.rand((12, 16, 26, 26))
adain = AdaIN()
## `random_tensor` is rebound to the AdaIN output here; the original random
## input is no longer reachable after this line.
random_tensor = adain(random_tensor, style_image)
print(random_tensor.shape)



torch.Size([12, 16, 26, 26])




In [20]:
# End-to-end shape check on a random 1x3x256x256 input.
# `encoder` returns a 4-tuple of stage outputs; index 3 is the last stage.
x = encoder(torch.rand(1, 3, 256, 256))
print(x[3].shape)

# The same feature map is passed as both AdaIN inputs: this only exercises
# the shapes, it is not a real content/style pair.
through_adain = adain(x[3], x[3])

# Decode the AdaIN output and print the resulting shape.
output = decoder(through_adain)
print(output.shape)

torch.Size([1, 512, 32, 32])
torch.Size([1, 3, 256, 256])


In [3]:
## AdaIN implementation
class AdaIN(torch.nn.Module):
    """Adaptive instance normalisation of `x` using the statistics of `y`.

    `x` is instance-normalised (per sample and per channel, over the
    remaining dimensions) and then scaled by the per-channel standard
    deviation of `y` and shifted by the per-channel mean of `y`.

    The normalisation is done with the functional `instance_norm`, which has
    no parameters or buffers and does not depend on a fixed channel count,
    so inputs with any number of channels are accepted.

    Args:
        eps: value added to the variance inside the instance normalisation
            for numerical stability.
    """

    def __init__(self, eps=1e-5):
        super(AdaIN, self).__init__()
        self.eps = eps

    def forward(self, x, y):
        """Return `x` re-normalised to the channel statistics of `y`.

        `y` must have the same sizes as `x` in dimensions 0 and 1, because
        its statistics (shape (N, C, 1, 1)) are expanded to `x.size()`.
        The output always has the same size as `x`.
        """
        size = x.size()

        # Zero-mean / unit-variance per sample and channel.
        normalised = torch.nn.functional.instance_norm(x, eps=self.eps)

        # Statistics come from the module-level helper `mean_and_std`.
        mean_y, std_y = mean_and_std(y)
        return normalised * std_y.expand(size) + mean_y.expand(size)
""""
print(style.shape)
mean, std = mean_and_std(style)
print(mean.shape)
print(std.shape)
Ada = AdaIN()
t = Ada(vgg(content)[3], vgg(style)[3])
"""


'"\nprint(style.shape)\nmean, std = mean_and_std(style)\nprint(mean.shape)\nprint(std.shape)\nAda = AdaIN()\nt = Ada(vgg(content)[3], vgg(style)[3])\n'

torch.Size([12, 16, 26, 26])




In [15]:
# A Conv2d weight has its output channels on the leading dimension, so the
# size of dimension 0 equals out_channels (20 here).
a = torch.nn.Conv2d(in_channels=5, out_channels=20, kernel_size=1)
a.weight.shape[0]

20