In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import numpy as np
from PIL import Image

from contextual_loss import ContextualLoss
from other_losses import TotalVariationLoss
from extractor import Extractor

In [None]:
# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Get images

In [None]:
# PIL sizes are (width, height). convert('RGB') guarantees three channels, so
# grayscale / palette / CMYK / RGBA files don't break to_tensor or the VGG
# extractor; the `with` blocks close the underlying files once loaded.
with Image.open('dog.jpg') as img:
    content = img.convert('RGB').resize((300, 200))
with Image.open('The_Starry_Night.jpg') as img:
    style = img.convert('RGB').resize((600, 500))

In [None]:
def to_tensor(x):
    """Convert an image to a float32 tensor in NCHW layout scaled to [0, 1].

    Args:
        x: a PIL image or array-like of shape (H, W, C), or (H, W) for a
            single-channel image. Values are assumed to lie in [0, 255].

    Returns:
        A torch.float32 tensor of shape (1, C, H, W); C == 1 for 2-D input.
    """
    # np.array (not asarray) always copies, so the returned tensor never
    # aliases the caller's data
    array = np.array(x, dtype=np.float32)
    if array.ndim == 2:
        # grayscale images have no channel axis; add one so the permute works
        array = array[:, :, np.newaxis]
    # convert HWC -> NCHW and rescale from [0, 255] to [0, 1]
    return torch.from_numpy(array).permute(2, 0, 1).unsqueeze(0) / 255.0

# Use contextual loss

In [None]:
class Loss(nn.Module):
    """Contextual-loss style-transfer objective over a learnable image.

    The image being optimized is the parameter ``self.x``. It has the same
    shape as the content tensor and is initialized to ``init_mean`` plus
    Gaussian noise with std ``init_std``. ``forward`` returns the content,
    style and total-variation terms separately so the caller can weight them.

    Args:
        content: content image (anything ``to_tensor`` accepts).
        style: style image (anything ``to_tensor`` accepts).
        content_layers: names of ``Extractor`` outputs used for the content term.
        style_layers: names of ``Extractor`` outputs used for the style term.
        content_h: ``h`` passed to each content ``ContextualLoss``.
        style_h: ``h`` passed to each style ``ContextualLoss``.
        size: ``size`` passed to every ``ContextualLoss``
            (presumably a patch size — TODO confirm against ContextualLoss).
        stride: ``stride`` passed to every ``ContextualLoss``.
        init_mean: mean value of the initial image.
        init_std: std of the Gaussian noise added to the initial image.
    """

    def __init__(self, content, style,
                 content_layers=('conv4_1', 'conv5_1'),
                 style_layers=('conv4_1', 'conv5_1'),
                 content_h=0.1, style_h=0.2, size=5, stride=2,
                 init_mean=0.5, init_std=1e-3):
        super().__init__()

        # images
        c = to_tensor(content)
        s = to_tensor(style)
        self.x = nn.Parameter(
            data=init_mean + init_std * torch.randn(c.size()), requires_grad=True
        )

        # features of the fixed target images; they are constants for the
        # optimization, so no autograd graph needs to be recorded for them
        vgg = Extractor()
        with torch.no_grad():
            cf = vgg(c)
            sf = vgg(s)

        # one ContextualLoss per selected layer, each holding its target features
        self.content = nn.ModuleDict({
            n: ContextualLoss(cf[n], size=size, stride=stride, h=content_h)
            for n in content_layers
        })
        self.style = nn.ModuleDict({
            n: ContextualLoss(sf[n], size=size, stride=stride, h=style_h)
            for n in style_layers
        })
        self.tv = TotalVariationLoss()
        self.vgg = vgg

    def forward(self):
        """Evaluate the objective at the current image ``self.x``.

        Returns:
            A tuple ``(content_loss, style_loss, tv_loss)`` of tensors; the
            content/style terms are sums over their configured layers.
        """
        f = self.vgg(self.x)
        tv_loss = self.tv(self.x)

        # 0-d float start value on x's device, so empty layer lists still
        # give a tensor; out-of-place sums avoid in-place broadcasting limits
        zero = self.x.new_zeros(())
        content_loss = sum((m(f[n]) for n, m in self.content.items()), zero)
        style_loss = sum((m(f[n]) for n, m in self.style.items()), zero)

        return content_loss, style_loss, tv_loss

# Optimize with Adam

In [None]:
# Optimization hyperparameters; the weights balance the three loss terms.
NUM_STEPS = 2000
# NOTE(review): Adam's per-step update is roughly lr-sized while x is kept in
# [0, 1], so lr=10 effectively saturates pixels each step — confirm intended.
LEARNING_RATE = 10.0
CONTENT_WEIGHT = 100
STYLE_WEIGHT = 1
TV_WEIGHT = 10000
LOG_EVERY = 50  # print progress every LOG_EVERY steps (and on the last one)
SEED = 0

torch.manual_seed(SEED)  # makes the torch.randn initial image reproducible
objective = Loss(content, style).to(DEVICE)
params = [p for p in objective.parameters() if p.requires_grad]
optimizer = optim.Adam(params, lr=LEARNING_RATE)

text = 'total:{0:.2f},content:{1:.3f},style:{2:.6f},tv:{3:.4f}'
for i in range(NUM_STEPS):

    optimizer.zero_grad()
    content_loss, style_loss, tv_loss = objective()
    total_loss = (
        CONTENT_WEIGHT * content_loss
        + STYLE_WEIGHT * style_loss
        + TV_WEIGHT * tv_loss
    )
    total_loss.backward()
    optimizer.step()

    # project back onto the valid pixel range after every update, so the
    # final image is clamped too; no_grad replaces the discouraged `.data`
    with torch.no_grad():
        objective.x.clamp_(0, 1)

    if i % LOG_EVERY == 0 or i == NUM_STEPS - 1:
        print(text.format(total_loss.item(), content_loss.item(),
                          style_loss.item(), tv_loss.item()))

In [None]:
# Convert the optimized (1, C, H, W) tensor in [0, 1] to an HxWxC array in
# [0, 255]; detach first so the clamp isn't recorded by autograd.
result = 255 * objective.x.detach().clamp(0, 1)[0].permute(1, 2, 0).cpu().numpy()
# round before casting: astype('uint8') alone truncates (254.7 -> 254)
Image.fromarray(np.rint(result).astype('uint8'))