In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import numpy as np
from PIL import Image

from contextual_loss import ContextualLoss
from other_losses import TotalVariationLoss
from extractor import Extractor
from utils import show

In [None]:
# Prefer the second GPU (the original hard-coded choice), then the first GPU,
# then the CPU, so the notebook still runs on machines with fewer devices.
if torch.cuda.device_count() > 1:
    DEVICE = torch.device('cuda:1')
elif torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')

# Get images

In [None]:
# Both images are forced to 3-channel RGB, so grayscale, palette or RGBA files
# still yield [h, w, 3] arrays for `to_tensor`, and are resized to one common
# size (the original aspect ratios are not preserved).
IMAGE_SIZE = (2 * 320, 2 * 224)  # (width, height), the order PIL expects
CONTENT = Image.open('dog.jpg').convert('RGB').resize(IMAGE_SIZE)
STYLE = Image.open('cat.jpg').convert('RGB').resize(IMAGE_SIZE)
show([CONTENT, STYLE])

# Define the loss

In [None]:
# names of features to use

# layers compared against the content image (one ContextualLoss per layer)
CONTENT_LAYERS = ['conv4_2']
# layers that `concat_styles` resizes and stacks into a single style feature map
STYLE_LAYERS = ['conv3_1', 'conv4_1', 'conv5_1']


class Loss(nn.Module):
    """
    Style transfer objective over a single trainable image.

    The image being optimized is stored as the parameter `x`.
    Calling the module evaluates the content, style and
    total variation losses for the current value of `x`.
    """

    def __init__(self, content, style, initial=None):
        """
        Arguments:
            content: an instance of PIL image.
            style: an instance of PIL image.
            initial: an instance of PIL image or None,
                if None then optimization starts from
                gaussian noise with mean 0.5 and std 1e-3
                that has the same size as the content image.
        """
        # zero-argument form keeps working when the class
        # is redefined by re-running the notebook cell
        super().__init__()

        # image to start optimization from
        if initial is None:
            mean, std = 0.5, 1e-3
            w, h = content.size  # PIL reports (width, height)
            initial = mean + std * torch.randn(1, 3, h, w)
        else:
            initial = to_tensor(initial)

        # the image that is being optimized
        self.x = nn.Parameter(data=initial, requires_grad=True)

        # features
        feature_names = CONTENT_LAYERS + STYLE_LAYERS
        self.vgg = Extractor(feature_names)

        # target features are constants, so they are
        # extracted without recording an autograd graph
        with torch.no_grad():
            cf = self.vgg(to_tensor(content))
            sf = self.vgg(to_tensor(style))
            style_target = concat_styles(sf)

        # create losses
        self.content = nn.ModuleDict({
            n: ContextualLoss(cf[n], size=5, stride=2, h=0.1)
            for n in CONTENT_LAYERS
        })
        self.style = ContextualLoss(style_target, size=5, stride=2, h=0.2)
        self.tv = TotalVariationLoss()

    def forward(self):
        """
        Returns:
            a tuple (content_loss, style_loss, tv_loss),
            where content_loss is the sum of the losses over
            CONTENT_LAYERS, all computed for the current `self.x`.
        """
        f = self.vgg(self.x)

        # accumulate out-of-place: no in-place write into the
        # freshly created accumulator, and a loss that is not
        # a 0-d tensor can still be added by broadcasting
        content_loss = torch.tensor(0.0, device=self.x.device)
        for n, m in self.content.items():
            content_loss = content_loss + m(f[n])

        style_loss = self.style(concat_styles(f))
        tv_loss = self.tv(self.x)
        return content_loss, style_loss, tv_loss


def to_tensor(x):
    """
    Arguments:
        x: an instance of PIL image (of any mode, it is
            converted to RGB), or any object that numpy can
            turn into an array with shape [h, w, 3] and
            values in [0, 255] range.
    Returns:
        a float tensor with shape [1, 3, h, w],
        it represents a RGB image with
        pixel values in [0, 1] range.
    """
    # PIL images expose `convert`: normalizing the mode here means that
    # grayscale, palette and RGBA inputs also give exactly three channels
    if hasattr(x, 'convert'):
        x = x.convert('RGB')

    x = np.array(x, dtype=np.float32)  # a fresh, writable [h, w, 3] array
    x = torch.from_numpy(x)
    # [h, w, 3] -> [3, h, w] -> [1, 3, h, w], then rescale to [0, 1]
    return x.permute(2, 0, 1).unsqueeze(0).div(255.0)


def concat_styles(d, layers=None):
    """
    Resize the selected feature maps to a common spatial size
    and stack them along the channel dimension. Each feature map
    is divided by its own number of channels before stacking.

    Arguments:
        d: a dict with float tensors of shape [b, c_i, h_i, w_i].
        layers: a list of keys of `d` to use,
            if None then STYLE_LAYERS is used.
    Returns:
        a float tensor with shape [b, c, h, w], where c is
        the sum of all c_i and (h, w) is the spatial size of
        the selected feature map with the largest area.
    """
    if layers is None:
        layers = STYLE_LAYERS

    style_features = [d[n] for n in layers]
    # every map is resized to the size of the largest one
    h, w = max(style_features, key=lambda f: f.size(2) * f.size(3)).shape[2:]

    # align_corners=False is what bilinear mode already defaults to,
    # it is spelled out to pin the behavior
    result = [
        F.interpolate(f, size=(h, w), mode='bilinear', align_corners=False).div(f.size(1))
        for f in style_features
    ]
    return torch.cat(result, dim=1)

# Optimize with Adam

In [None]:
# Start the optimization from the content image itself.
objective = Loss(CONTENT, STYLE, initial=CONTENT).to(DEVICE)
# A list rather than a one-shot `filter` iterator, so `params` can still be
# inspected or reused after the optimizer has consumed it.
params = [p for p in objective.parameters() if p.requires_grad]

NUM_STEPS = 4000
LOG_EVERY = 100  # print the losses every LOG_EVERY steps and at the last step
optimizer = optim.Adam(params, lr=6e-3)

text = 'i:{0},total:{1:.2f},content:{2:.3f},style:{3:.6f},tv:{4:.4f}'
for i in range(NUM_STEPS):

    # Keep the pixel values in the valid [0, 1] range before each evaluation.
    # The in-place edit is done under `no_grad` instead of through `.data`,
    # which would hide the write from autograd's bookkeeping.
    with torch.no_grad():
        objective.x.clamp_(0, 1)
    optimizer.zero_grad()

    content_loss, style_loss, tv_loss = objective()
    # weighted sum of the three terms
    total_loss = 1.5 * content_loss + 1 * style_loss + 100 * tv_loss
    total_loss.backward()

    # `.item()` is only called on the steps that are actually reported
    if i % LOG_EVERY == 0 or i == NUM_STEPS - 1:
        print(text.format(i, total_loss.item(), content_loss.item(), style_loss.item(), tv_loss.item()))
    optimizer.step()

In [None]:
# Render the optimized image: take the single batch item, clip it to [0, 1],
# move channels last and rescale to the [0, 255] range (truncated to uint8).
clipped = objective.x.detach().clamp(0, 1)[0]
result = 255 * clipped.permute(1, 2, 0).cpu().numpy()
Image.fromarray(result.astype('uint8'))

In [None]:
# Alternative rendering: instead of clipping, stretch the raw values of the
# optimized image so that its minimum maps to 0 and its maximum to 255.
result = objective.x.detach().permute(0, 2, 3, 1)[0].cpu().numpy()
lo, hi = result.min(), result.max()
# a constant image has hi == lo; divide by 1 then, to avoid 0/0 -> NaN pixels
value_range = hi - lo if hi > lo else 1.0
result = 255 * (result - lo) / value_range
Image.fromarray(result.astype('uint8'))

# Optimize with L-BFGS


In [None]:
# Start the optimization from the content image itself.
objective = Loss(CONTENT, STYLE, initial=CONTENT).to(DEVICE)
# A list rather than a one-shot `filter` iterator, so `params` can still be
# inspected or reused after the optimizer has consumed it.
params = [p for p in objective.parameters() if p.requires_grad]

# One `step` call runs up to `max_iter` iterations; the negative tolerances
# keep the gradient/change convergence tests from ever stopping it early.
optimizer = optim.LBFGS(
    params=params, lr=1.0, max_iter=4000,
    tolerance_grad=-1, tolerance_change=-1
)

LOG_EVERY = 100  # print the losses on the first and every LOG_EVERY-th evaluation
i = [0]  # evaluation counter, a list so that `closure` can update it in place
text = 'i:{0},total:{1:.2f},content:{2:.3f},style:{3:.6f},tv:{4:.4f}'
def closure():
    """
    Evaluate the weighted total loss for the current image and
    populate the gradients. L-BFGS calls this on every evaluation.

    Returns:
        the total loss, a float tensor.
    """
    # Keep the pixel values in the valid [0, 1] range before each evaluation.
    # The in-place edit is done under `no_grad` instead of through `.data`,
    # which would hide the write from autograd's bookkeeping.
    with torch.no_grad():
        objective.x.clamp_(0, 1)
    optimizer.zero_grad()

    content_loss, style_loss, tv_loss = objective()
    # weighted sum of the three terms
    total_loss = 1.5 * content_loss + 1 * style_loss + 100 * tv_loss
    total_loss.backward()

    i[0] += 1
    # `.item()` is only called on the evaluations that are actually reported
    if i[0] == 1 or i[0] % LOG_EVERY == 0:
        print(text.format(i[0], total_loss.item(), content_loss.item(), style_loss.item(), tv_loss.item()))
    return total_loss

optimizer.step(closure)

In [None]:
# Render the optimized image: take the single batch item, clip it to [0, 1],
# move channels last and rescale to the [0, 255] range (truncated to uint8).
clipped = objective.x.detach().clamp(0, 1)[0]
result = 255 * clipped.permute(1, 2, 0).cpu().numpy()
Image.fromarray(result.astype('uint8'))

In [None]:
# Alternative rendering: instead of clipping, stretch the raw values of the
# optimized image so that its minimum maps to 0 and its maximum to 255.
result = objective.x.detach().permute(0, 2, 3, 1)[0].cpu().numpy()
lo, hi = result.min(), result.max()
# a constant image has hi == lo; divide by 1 then, to avoid 0/0 -> NaN pixels
value_range = hi - lo if hi > lo else 1.0
result = 255 * (result - lo) / value_range
Image.fromarray(result.astype('uint8'))