In [None]:
# Re-import edited modules (e.g. the local contextual_loss / other_losses /
# extractor modules imported in the next cell) before each cell executes,
# so changes to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import numpy as np
from PIL import Image

from contextual_loss import ContextualLoss
from other_losses import TotalVariationLoss
from extractor import Extractor

In [None]:
# Run on the first GPU when one is available; otherwise fall back to the CPU
# so the notebook still executes (slowly) on machines without CUDA.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Get images

In [None]:
CONTENT_PATH = 'dog.jpg'
STYLE_PATH = 'cat.jpg'
IMAGE_SIZE = (320, 224)  # PIL's resize takes (width, height)


def load_rgb_image(path, size):
    """Open an image file, force it to 3-channel RGB and resize it.

    Converting to RGB means grayscale or RGBA inputs still yield an
    (H, W, 3) array, which is what `to_tensor`'s HWC -> CHW permute expects.
    The `with` block closes the underlying file once the pixels are loaded.
    """
    with Image.open(path) as img:
        return img.convert('RGB').resize(size)


content = load_rgb_image(CONTENT_PATH, IMAGE_SIZE)
style = load_rgb_image(STYLE_PATH, IMAGE_SIZE)

In [None]:
def to_tensor(x):
    """Convert an image to a float32 tensor in NCHW layout, scaled to [0, 1].

    Args:
        x: a PIL image or anything `np.array` accepts, laid out as
            (height, width, channels) or, for single-channel input,
            (height, width). Pixel values are assumed to lie in [0, 255].

    Returns:
        A float32 tensor of shape (1, channels, height, width).
    """
    # np.array always copies, so the resulting tensor never aliases `x`.
    arr = np.array(x, dtype=np.float32)
    if arr.ndim == 2:
        # Grayscale images arrive as (H, W); add a channel axis so the
        # HWC -> CHW permute below is valid.
        arr = arr[:, :, None]
    # HWC -> CHW, add a batch axis, and rescale from [0, 255] to [0, 1].
    return torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0) / 255.0

# Use contextual loss

In [None]:
class Loss(nn.Module):
    """Style-transfer objective over a single optimizable image `self.x`.

    `self.x` is initialised from the content image. For every configured
    layer a `ContextualLoss` module is built from the target image's features
    at that layer (content targets from `content`, style targets from
    `style`); a `TotalVariationLoss` term is computed on `self.x` itself.
    """

    def __init__(self, content, style, content_layers=('conv4_1',),
                 style_layers=('conv4_1',), size=3, stride=1, h=0.1):
        """Build the objective.

        Args:
            content: content image (anything `to_tensor` accepts, e.g. a PIL image).
            style: style image (same requirements as `content`).
            content_layers: extractor layer names used for the content terms.
            style_layers: extractor layer names used for the style terms.
            size, stride, h: forwarded unchanged to every `ContextualLoss`
                (presumably patch size, patch stride and bandwidth — see
                ContextualLoss for their exact meaning).
        """
        super(Loss, self).__init__()

        # images
        c = to_tensor(content)
        s = to_tensor(style)
        # Clone so the optimized image does not share storage with `c`,
        # which is also pushed through the feature extractor below.
        self.x = nn.Parameter(data=c.clone(), requires_grad=True)

        # The extractor is a fixed feature map: freeze it so optimizers built
        # from the requires_grad-filtered parameters never update its weights.
        vgg = Extractor()
        for p in vgg.parameters():
            p.requires_grad_(False)

        # Target features are constants; computing them without autograd
        # keeps them from holding on to a computation graph.
        with torch.no_grad():
            cf = vgg(c)
            sf = vgg(s)

        # one loss module per layer, keyed by layer name
        self.content = nn.ModuleDict({
            n: ContextualLoss(cf[n], size=size, stride=stride, h=h)
            for n in content_layers
        })
        self.style = nn.ModuleDict({
            n: ContextualLoss(sf[n], size=size, stride=stride, h=h)
            for n in style_layers
        })
        self.tv = TotalVariationLoss()
        self.vgg = vgg

    def forward(self):
        """Return `(content_loss, style_loss, tv_loss)` for the current `self.x`.

        The content and style terms are sums over their configured layers;
        when a layer list is empty the corresponding term is a zero scalar on
        `self.x`'s device.
        """
        # mapping from layer name to feature tensor, indexed by the names above
        f = self.vgg(self.x)
        zero = self.x.new_zeros(())
        content_loss = sum((m(f[n]) for n, m in self.content.items()), zero)
        style_loss = sum((m(f[n]) for n, m in self.style.items()), zero)
        tv_loss = self.tv(self.x)
        return content_loss, style_loss, tv_loss

# Optimize with Adam

In [None]:
# Schedule and loss weights for the Adam run.
NUM_STEPS = 300
LEARNING_RATE = 3e-3
CONTENT_WEIGHT = 1.0
STYLE_WEIGHT = 2000.0
TV_WEIGHT = 100.0
LOG_EVERY = 10  # print progress every LOG_EVERY steps (and on the last step)

objective = Loss(content, style).to(DEVICE)
# Optimize only tensors that require grad (the image `objective.x`).
params = [p for p in objective.parameters() if p.requires_grad]
optimizer = optim.Adam(params, lr=LEARNING_RATE)

text = 'i:{0},total:{1:.2f},content:{2:.3f},style:{3:.6f},tv:{4:.4f}'
for i in range(NUM_STEPS):

    # Project the image back into the valid pixel range before each step;
    # done under no_grad so autograd does not record the in-place clamp.
    with torch.no_grad():
        objective.x.clamp_(0, 1)
    optimizer.zero_grad()

    content_loss, style_loss, tv_loss = objective()
    total_loss = (CONTENT_WEIGHT * content_loss
                  + STYLE_WEIGHT * style_loss
                  + TV_WEIGHT * tv_loss)
    total_loss.backward()

    if i % LOG_EVERY == 0 or i == NUM_STEPS - 1:
        print(text.format(i, total_loss.item(), content_loss.item(), style_loss.item(), tv_loss.item()))
    optimizer.step()

In [None]:
# Clamp to [0, 1], take the single batch item, convert CHW -> HWC, and round
# (rather than truncate) to 8-bit pixel values.
result = objective.x.detach().clamp(0, 1)[0].permute(1, 2, 0).cpu().numpy()
result = np.rint(255 * result).astype('uint8')
Image.fromarray(result)

In [None]:
# Unclamped view: linearly rescale the raw optimized values to [0, 255].
raw = objective.x.detach()[0].permute(1, 2, 0).cpu().numpy()
lo, hi = raw.min(), raw.max()
span = hi - lo
# Guard against a constant image, which would otherwise divide by zero and
# turn every pixel into NaN before the uint8 cast.
scaled = (raw - lo) / span if span > 0 else np.zeros_like(raw)
result = np.rint(255 * scaled).astype('uint8')
Image.fromarray(result)

# Optimize with L-BFGS


In [None]:
# Schedule and loss weights for the L-BFGS run (independent of the Adam run).
LBFGS_MAX_ITER = 300
LBFGS_LR = 0.1
LBFGS_CONTENT_WEIGHT = 1.0
LBFGS_STYLE_WEIGHT = 100.0
LBFGS_TV_WEIGHT = 1000.0
LBFGS_LOG_EVERY = 10  # print every LBFGS_LOG_EVERY closure evaluations

objective = Loss(content, style).to(DEVICE)
# Optimize only tensors that require grad (the image `objective.x`).
params = [p for p in objective.parameters() if p.requires_grad]

# Negative tolerances disable L-BFGS's gradient/change-based early stopping,
# so it runs until max_iter iterations (or its max_eval budget) are used up.
optimizer = optim.LBFGS(
    params=params, lr=LBFGS_LR, max_iter=LBFGS_MAX_ITER,
    tolerance_grad=-1, tolerance_change=-1
)

text = 'total:{0:.2f},content:{1:.3f},style:{2:.6f},tv:{3:.4f}'
lbfgs_evals = [0]  # mutable counter so closure() can track its own calls


def closure():
    """Evaluate the weighted objective and its gradient for L-BFGS."""
    # NOTE(review): this projection into [0, 1] happens outside L-BFGS's
    # knowledge; its curvature history assumes the unclamped steps.
    with torch.no_grad():
        objective.x.clamp_(0, 1)
    optimizer.zero_grad()

    content_loss, style_loss, tv_loss = objective()
    total_loss = (LBFGS_CONTENT_WEIGHT * content_loss
                  + LBFGS_STYLE_WEIGHT * style_loss
                  + LBFGS_TV_WEIGHT * tv_loss)
    total_loss.backward()

    if lbfgs_evals[0] % LBFGS_LOG_EVERY == 0:
        print(text.format(total_loss.item(), content_loss.item(), style_loss.item(), tv_loss.item()))
    lbfgs_evals[0] += 1
    return total_loss


# step() returns the first closure's loss; suppress its echo.
optimizer.step(closure);

In [None]:
# Clamp to [0, 1], take the single batch item, convert CHW -> HWC, and round
# (rather than truncate) to 8-bit pixel values.
result = objective.x.detach().clamp(0, 1)[0].permute(1, 2, 0).cpu().numpy()
result = np.rint(255 * result).astype('uint8')
Image.fromarray(result)

In [None]:
# Unclamped view: linearly rescale the raw optimized values to [0, 255].
raw = objective.x.detach()[0].permute(1, 2, 0).cpu().numpy()
lo, hi = raw.min(), raw.max()
span = hi - lo
# Guard against a constant image, which would otherwise divide by zero and
# turn every pixel into NaN before the uint8 cast.
scaled = (raw - lo) / span if span > 0 else np.zeros_like(raw)
result = np.rint(255 * scaled).astype('uint8')
Image.fromarray(result)