In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules (e.g. `losses`) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import numpy as np
from PIL import Image, ImageDraw

from losses import PerceptualLoss, TotalVariationLoss, MarkovRandomFieldLoss, Extractor

In [None]:
# Run on the first GPU when one is present; otherwise fall back to the CPU so
# the notebook still executes (slowly) on machines without CUDA.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Size of the final stylized image as (width, height), the order PIL expects.
FINAL_SIZE = (600, 400)

# Force three-channel RGB: `to_tensor` permutes axes (2, 0, 1), which fails on
# grayscale images and yields a wrong channel count for RGBA/CMYK files.
CONTENT = Image.open('dog.jpg').convert('RGB').resize(FINAL_SIZE, Image.LANCZOS)
STYLE = Image.open('cat.jpg').convert('RGB')

# Rotation angles (degrees) and scale factors; every (angle, scale) pair
# produces one augmented copy of the style image.
ANGLES = [-45, 0, 45]
SCALES = [0.8, 1.0, 1.2]

# Augment style

In [None]:
width, height = STYLE.size

# Build one augmented style image per (angle, scale) combination.
STYLES = []
for angle in ANGLES:
    for scale in SCALES:
        w, h = int(width * scale), int(height * scale)
        resized = STYLE.resize((w, h), Image.LANCZOS)
        rotated = resized.rotate(angle, Image.BICUBIC)
        # Rotated copies keep only the central 60% of each side; the
        # unrotated copy is used uncropped.
        box = (0.2 * w, 0.2 * h, 0.8 * w, 0.8 * h)
        cropped = rotated.crop(box) if angle != 0 else rotated
        STYLES.append(cropped)

# Preview: stack all augmented styles vertically on a white canvas.
width = max(style.size[0] for style in STYLES)
height = sum(style.size[1] for style in STYLES)
background = Image.new('RGB', (width, height), (255, 255, 255))

offset = 0
for style in STYLES:
    background.paste(style, (0, offset))
    offset += style.size[1]

background

# Define loss

In [None]:
def to_tensor(x):
    """
    Turn an image with HWC layout into a float tensor.

    Arguments:
        x: an image (anything `np.array` accepts) with
            shape [height, width, channels] and values in [0, 255].
    Returns:
        a float tensor with shape [1, channels, height, width]
        and values scaled by 1/255.
    """
    pixels = torch.FloatTensor(np.array(x))
    # HWC -> CHW, then add the leading batch dimension
    batched = pixels.permute(2, 0, 1)[None]
    return batched / 255.0


class Loss(nn.Module):
    """
    Style-transfer objective whose trainable parameter is the image itself.

    The synthesized image is stored as `self.x` (an `nn.Parameter`). Calling
    the module extracts features of `self.x` and compares them with the
    fixed features of the content and style images.
    """

    def __init__(self, content, styles, initial=None):
        """
        Arguments:
            content: an instance of PIL image.
            styles: a list of PIL images.
            initial: an instance of PIL image or None.
                If None, the image starts as low-variance noise around 0.5.
                Otherwise it must have the same size as `content`.
        """
        super(Loss, self).__init__()

        if initial is None:
            mean, std = 0.5, 1e-3
            w, h = content.size
            initial = mean + std * torch.randn(1, 3, h, w)
        else:
            assert initial.size == content.size, \
                'initial image must have the same size as the content image'
            initial = to_tensor(initial)

        # images
        content = to_tensor(content)
        styles = [to_tensor(s) for s in styles]
        self.x = nn.Parameter(data=initial, requires_grad=True)

        # features
        self.vgg = Extractor()
        # The targets are constants of the optimization, so they are
        # computed without building an autograd graph.
        with torch.no_grad():
            cf = self.vgg(content)
            sf = [self.vgg(s) for s in styles]

        # names of features to use
        content_layers = ['relu4_2']
        style_layers = ['relu3_1', 'relu4_1']

        # create losses (one module per feature name)
        self.content = nn.ModuleDict({
            n: PerceptualLoss(cf[n])
            for n in content_layers
        })
        # each style loss receives that layer's features from every style image
        self.style = nn.ModuleDict({
            n: MarkovRandomFieldLoss(
                [features[n] for features in sf],
                size=3, stride=1, threshold=1e-2
            )
            for n in style_layers
        })
        self.tv = TotalVariationLoss()

    def forward(self):
        """
        Evaluate the three loss terms for the current image `self.x`.

        Returns:
            a tuple (content_loss, style_loss, tv_loss) of tensors;
            content and style losses are summed over their layers.
        """
        f = self.vgg(self.x)
        tv_loss = self.tv(self.x)

        # Accumulate out-of-place, starting from a zero on the image's device.
        zero = torch.tensor(0.0, device=self.x.device)
        content_loss = sum((m(f[n]) for n, m in self.content.items()), zero)
        style_loss = sum((m(f[n]) for n, m in self.style.items()), zero)

        return content_loss, style_loss, tv_loss

# Optimize with Adam

In [None]:
def synthesize(content, initial, num_steps=500, lr=1e-2):
    """
    Optimize an image with Adam so that it matches `content`
    in content and the global `STYLES` in style.

    Arguments:
        content: an instance of PIL image.
        initial: an instance of PIL image with the same size
            as `content`, or None to start from noise.
        num_steps: an integer, number of Adam updates.
        lr: a float, learning rate of Adam.
    Returns:
        an instance of PIL image (RGB, uint8) with the same size as `content`.
    """
    objective = Loss(content, styles=STYLES, initial=initial).to(DEVICE)
    params = [p for p in objective.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=lr)

    text = 'i:{0},total:{1:.2f},content:{2:.3f},style:{3:.6f},tv:{4:.4f}'
    for i in range(num_steps):

        # Project the image back into the valid [0, 1] range before each
        # step; done under no_grad so autograd does not record the clamp.
        with torch.no_grad():
            objective.x.clamp_(0, 1)
        optimizer.zero_grad()

        content_loss, style_loss, tv_loss = objective()
        total_loss = 2 * content_loss + style_loss + 1000 * tv_loss
        total_loss.backward()

        print(text.format(i, total_loss.item(), content_loss.item(), style_loss.item(), tv_loss.item()))
        optimizer.step()

    # NCHW in [0, 1] -> HWC in [0, 255]
    result = 255 * objective.x.clamp(0, 1).detach().permute(0, 2, 3, 1)[0].cpu().numpy()
    return Image.fromarray(result.astype('uint8'))

In [None]:
# Coarse-to-fine synthesis: start at 1/s of the content size, then repeatedly
# double the resolution, using each result as the initial image of the next.
s = 8
# `s` must be a power of two so that repeated doubling lands exactly on the
# full resolution (an evenness check would wrongly accept e.g. s = 6).
assert s > 0 and s & (s - 1) == 0, 'downsampling factor must be a power of two'
num_upsamplings = s.bit_length() - 1  # exact log2 for powers of two

w, h = CONTENT.size
# Derive every scale from the full size so the last one equals CONTENT.size
# even when the width or height is not divisible by `s`.
sizes = [(w // (s >> k), h // (s >> k)) for k in range(num_upsamplings + 1)]

x = synthesize(CONTENT.resize(sizes[0]), initial=None)

results = [x]
for size in sizes[1:]:

    initial = x.resize(size)

    x = synthesize(CONTENT.resize(size), initial)
    results.append(x)

In [None]:
# Coarsest scale: output of the first `synthesize` call, started from noise.
results[0]

In [None]:
# Result after the first 2x upsampling step.
results[1]

In [None]:
# Result after the second 2x upsampling step.
results[2]

In [None]:
# Result after the third 2x upsampling step (the last entry when s = 8).
results[3]

In [None]:
# `objective` only exists inside `synthesize`, so it cannot be read here on a
# fresh kernel. `synthesize` already returns the clamped uint8 image, so the
# final multi-scale result is simply the last entry of `results`.
results[-1]

In [None]:
# Min-max (contrast-stretched) view of the final multi-scale result.
# `objective` only exists inside `synthesize`, so the stretch is applied to
# the returned image (already clamped to [0, 255]) instead of the raw tensor.
result = np.asarray(results[-1], dtype=np.float32)
lo, hi = result.min(), result.max()
# guard against a constant image, where hi == lo
result = 255 * (result - lo) / max(hi - lo, 1e-8)
Image.fromarray(result.astype('uint8'))

# Optimize with L-BFGS


In [None]:
# Single-scale optimization at full resolution, starting from noise.
# Uses the notebook's CONTENT image and the augmented STYLES list.
objective = Loss(CONTENT, STYLES).to(DEVICE)
params = [p for p in objective.parameters() if p.requires_grad]

# Negative tolerances disable early stopping, so one `step` call performs
# up to `max_iter` iterations.
optimizer = optim.LBFGS(
    params=params, lr=0.1, max_iter=300,
    tolerance_grad=-1, tolerance_change=-1
)

text = 'total:{0:.2f},content:{1:.3f},style:{2:.6f},tv:{3:.4f}'
def closure():
    """Re-evaluate the objective and its gradient; called by L-BFGS."""

    # keep the image in the valid [0, 1] range without recording the clamp
    with torch.no_grad():
        objective.x.clamp_(0, 1)
    optimizer.zero_grad()

    content_loss, style_loss, tv_loss = objective()
    total_loss = content_loss + 100 * style_loss + 1000 * tv_loss
    total_loss.backward()

    print(text.format(total_loss.item(), content_loss.item(), style_loss.item(), tv_loss.item()))
    return total_loss

optimizer.step(closure)

In [None]:
# Clamp the optimized image to [0, 1] and show it as an 8-bit picture.
clamped = objective.x.detach().clamp(0, 1)
# drop the batch dimension, then CHW -> HWC
result = 255 * clamped[0].permute(1, 2, 0).cpu().numpy()
Image.fromarray(result.astype('uint8'))

In [None]:
# Same image without clamping: min-max rescale the raw values to [0, 255].
result = objective.x.detach()[0].permute(1, 2, 0).cpu().numpy()
lowest, highest = result.min(), result.max()
result = 255 * (result - lowest) / (highest - lowest)
Image.fromarray(result.astype('uint8'))