In [1]:
import argparse
from collections import namedtuple
import os
from time import clock
from random import choices
import numpy as np

from PIL import Image
from PIL import ImageFile
from tqdm import tqdm_notebook as tqdm

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim
from tensorboardX import SummaryWriter
from torchvision import transforms
from torchvision.utils import save_image
from torch.distributions.normal import Normal
import torchvision.models as models 

<h2> Steps & Ideas</h2>
<ul>
    <li> Use pretrained encoder, train decoder</li>
    <li> But first AdaIn's decoder</li>
</ul>

In [2]:
decoder = nn.Sequential(
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 256, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 128, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 64, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 3, (3, 3)),
)

encoder = nn.Sequential(
    nn.Conv2d(3, 3, (1, 1)),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(3, 64, (3, 3)),
    nn.ReLU(),  # relu1-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),  # relu1-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 128, (3, 3)),
    nn.ReLU(),  # relu2-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),  # relu2-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 256, (3, 3)),
    nn.ReLU(),  # relu3-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 512, (3, 3)),
    nn.ReLU(),  # relu4-1, this is the last layer used
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU()  # relu5-4
)

In [3]:
class Vgg16(torch.nn.Module):
    def __init__(self,requires_grad=False):
        super(Vgg16,self).__init__()
        vgg_features = models.vgg16(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        for x in range(4):
            self.slice1.add_module(str(x),vgg_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x),vgg_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_features[x])
            
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False
                
    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'])
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3)
        return out

In [4]:
def gram_matrix(y):
    (b, ch, h, w) = y.size()
    features = y.view(b, ch, w * h)
    features_t = features.transpose(1, 2)
    gram = features.bmm(features_t) / (ch * h * w)
    return gram

def normalize_batch(batch):
    # normalize using imagenet mean and std
    mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    batch = batch.div_(255.0)
    return (batch - mean) / std

def styleGram(style_image_path,style_transform,vgg):
    style = load_image(style_image_path, size=512)
    style = style_transform(style)
    style = style.repeat(4, 1, 1, 1).to("cuda")
    print(vgg.training)
    with torch.no_grad():
        features_style = vgg(normalize_batch(style))
    gram_style = [gram_matrix(y) for y in features_style] #style layer no X batch size
    return gram_style

In [5]:
trans = transforms.Compose([
        transforms.Resize(size=(512, 512)),
        transforms.ToTensor()])

In [6]:
device= torch.device('cuda')

decoder.eval()
encoder.eval()

decoder.load_state_dict(torch.load("d:/AdaIn/models/decoder.pth"))
encoder.load_state_dict(torch.load("d:/AdaIn/models/vgg_normalised.pth"))
encoder = nn.Sequential(*list(encoder.children())[:31])
for param in encoder.parameters():
    param.requires_grad = False
for param in decoder.parameters():
    param.requires_grad = False

encoder.to(device)
decoder.to(device);

In [7]:
vgg = Vgg16(requires_grad=False).to(device)

In [8]:
Ic = trans(Image.open("d:/Images/dancing.jpg")).unsqueeze(0).cuda()
Is = trans(Image.open("d:/Images/picasso.jpg")).unsqueeze(0).cuda()
GS = []
with torch.no_grad():
    fc = encoder(Ic)
    Vc = vgg(Ic)
    Vs = vgg(Is)
    for f in Vs:
        GS.append(gram_matrix(f))
fs = Normal(fc.mean(),fc.var()).sample(fc.size()).cuda().requires_grad_(True)

In [9]:
optimizer = optim.Adam([fs.requires_grad_(True)])
mse_loss = torch.nn.MSELoss()
writer = SummaryWriter("d:/visual")

In [None]:
for it in tqdm(range(10001)):
    f = .5*fc+.5*fs 
    I = decoder(f)
    Vi = vgg(I)
    contL = mse_loss(Vc.relu2_2, Vi.relu2_2) / int(np.prod( Vc.relu2_2.size()))
    styleL = 0.
    for i in range(len(GS)):
        gm_y = gram_matrix(Vi[i])
        styleL += mse_loss(gm_y,GS[i])
    
    styleL *= 5e6
    contL *= 10
    
    total_loss = contL + styleL
    writer.add_scalar("loss",total_loss.item(),it)
    writer.add_scalar("Sloss",styleL.item(),it)
    writer.add_scalar("Closs",contL.item(),it)
    optimizer.zero_grad()
    total_loss.backward()
    if it%1000==0:
        print(fs.grad.mean().item(),fs.grad.var().item(),fs.mean().item(),fs.var().item())
        save_image(I[0],f"d:/foo{it//100}.jpg")
    optimizer.step()
    
    

HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))



-1.023144477585447e-06 1.0949904094559315e-08 0.5842747688293457 0.9976884722709656
-4.991998281411725e-08 1.095843381315742e-11 0.5893219709396362 1.0308541059494019


In [None]:
save_image(I[0],"d:/doo.jpg")

In [None]:
save_image(I[0],"d:/foo.jpg")

In [None]:
l.backward()

In [None]:
 Vc.relu2_2.size()

In [None]:
np.prod( Vc.relu2_2.size())