In [None]:
# Reload edited modules automatically before each cell runs, so changes to
# local code (e.g. the `model` module imported below) are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import numpy as np
from PIL import Image, ImageDraw

from model import Loss

In [None]:
# Device on which the objective is placed: the first CUDA GPU.
DEVICE = torch.device('cuda', 0)

In [None]:
# Resolution of the final result as (width, height).
FINAL_SIZE = (600, 400)

# Rotation angles (degrees) and scale factors used to build
# augmented copies of the style image.
ANGLES = [-45, 0, 45]
SCALES = [0.5]

# The content image is resampled to FINAL_SIZE right after loading;
# the style image is kept at its native size, which is reported below.
CONTENT = Image.open('dog.jpg').resize(size=FINAL_SIZE, resample=Image.LANCZOS)
STYLE = Image.open('Vasily Kandinsky Small Worlds I.jpg')
print('size of the style image', STYLE.size)

# Utils

In [None]:
def show(images):
    """
    Shows a list of images stacked vertically.
    Images can be of different sizes.

    Arguments:
        images: a non-empty iterable of PIL images.
    Returns:
        a new RGB image with a white background, as wide as the widest
        input and as tall as all inputs together; the inputs are pasted
        left-aligned, top to bottom, in the given order.
    Raises:
        ValueError: if `images` is empty.
    """

    # The sequence is traversed several times, so materialize it first;
    # this also lets callers pass a generator.
    images = list(images)
    if not images:
        raise ValueError('show() needs at least one image')

    width = max(image.size[0] for image in images)
    height = sum(image.size[1] for image in images)
    background = Image.new('RGB', (width, height), (255, 255, 255))

    offset = 0
    for image in images:
        background.paste(image, (0, offset))
        offset += image.size[1]  # next image goes right below this one

    return background

# Augment style

In [None]:
width, height = STYLE.size

# One augmented copy of the style image per (angle, scale) pair,
# angles varying slowest.
STYLES = []
for angle, scale in ((a, s) for a in ANGLES for s in SCALES):
    w = int(width * scale)
    h = int(height * scale)
    # Rescale first, then rotate the rescaled copy.
    rotated = STYLE.resize((w, h), Image.LANCZOS).rotate(angle, Image.BICUBIC)
    # Keep only the central 60% along each side (a 20% border is dropped all round).
    STYLES.append(rotated.crop((0.2 * w, 0.2 * h, 0.8 * w, 0.8 * h)))

show(STYLES)

# Optimize with Adam

In [None]:
def synthesize1(content, initial, num_steps=10000, lr=1e-4,
                content_weight=2, style_weight=1, tv_weight=6000,
                log_every=100):
    """
    Synthesizes an image by minimizing the weighted loss with Adam.

    Arguments:
        content: content image, handed to `Loss` as is.
        initial: starting point of the optimization, handed to `Loss` as is.
        num_steps: number of Adam updates to perform.
        lr: Adam learning rate.
        content_weight, style_weight, tv_weight: multipliers applied to the
            three terms returned by the objective before they are summed.
        log_every: the losses are printed on the first step and then on
            every `log_every`-th step; 0 turns the logging off.
    Returns:
        a PIL image built from `objective.x` after clamping it to [0, 1].
    """

    objective = Loss(content, styles=STYLES, initial=initial).to(DEVICE)
    # only the tensors that require gradients are optimized
    params = [p for p in objective.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=lr)

    text = 'i:{0}, total:{1:.2f}, content:{2:.3f}, style:{3:.6f}, tv:{4:.4f}'
    for i in range(1, num_steps + 1):

        # keep the image inside [0, 1] before each forward pass;
        # the in-place clamp must not be recorded by autograd
        with torch.no_grad():
            objective.x.clamp_(0, 1)
        optimizer.zero_grad()

        content_loss, style_loss, tv_loss = objective()
        total_loss = (
            content_weight * content_loss
            + style_weight * style_loss
            + tv_weight * tv_loss
        )
        total_loss.backward()

        if log_every and (i == 1 or i % log_every == 0):
            print(text.format(i, total_loss.item(), content_loss.item(), style_loss.item(), tv_loss.item()))
        optimizer.step()

    # Take the first item of the 4-D batch, move its channel axis last and
    # quantize to 8 bits. Rounding (rather than truncating) avoids a
    # systematic darkening of up to one intensity level.
    pixels = objective.x.detach().clamp(0, 1)[0].permute(1, 2, 0)
    pixels = pixels.mul(255).round().to(torch.uint8)
    result = Image.fromarray(pixels.contiguous().cpu().numpy())

    # drop the objective and let the CUDA caching allocator release unused memory
    del objective
    torch.cuda.empty_cache()

    return result

In [None]:
num_upsamplings = 2  # non negative integer

# Coarse-to-fine schedule: the first synthesis runs on the content image
# shrunk by a factor of 2**num_upsamplings and starts from that same image.
shrink = 2**num_upsamplings
size = tuple(side // shrink for side in CONTENT.size)
print('synthesizing image of size', size)
x = synthesize1(
    CONTENT.resize(size, Image.LANCZOS),
    initial=CONTENT.resize(size, Image.LANCZOS),
)

# Every further level doubles the resolution and is seeded with the
# upscaled result of the previous level.
results = [x]
for _ in range(num_upsamplings):

    size = tuple(2 * side for side in x.size)
    print('\nsynthesizing image of size', size)

    x = synthesize1(
        CONTENT.resize(size, Image.LANCZOS),
        x.resize(size, Image.LANCZOS),
    )
    results.append(x)

In [None]:
# Display the output of every resolution level, coarsest first, stacked top to bottom.
show(results)

# Optimize with L-BFGS


In [None]:
def synthesize2(content, initial, num_steps=3000, lr=0.1,
                content_weight=1, style_weight=2, tv_weight=7000,
                log_every=100):
    """
    Synthesizes an image by minimizing the weighted loss with L-BFGS.

    Arguments:
        content: content image, handed to `Loss` as is.
        initial: starting point of the optimization, handed to `Loss` as is.
        num_steps: maximal number of L-BFGS iterations (`max_iter`) of the
            single `optimizer.step` call.
        lr: L-BFGS learning rate.
        content_weight, style_weight, tv_weight: multipliers applied to the
            three terms returned by the objective before they are summed.
        log_every: the losses are printed on the first closure evaluation
            and then on every `log_every`-th one; 0 turns the logging off.
    Returns:
        a PIL image built from `objective.x` after clamping it to [0, 1].
    """

    objective = Loss(content, styles=STYLES, initial=initial).to(DEVICE)
    # only the tensors that require gradients are optimized
    params = [p for p in objective.parameters() if p.requires_grad]

    # negative tolerances are passed so that the tolerance-based
    # stopping checks of L-BFGS are not met
    optimizer = optim.LBFGS(
        params=params, lr=lr, max_iter=num_steps,
        tolerance_grad=-1, tolerance_change=-1
    )

    evaluations = 0  # number of closure calls so far (the `i` of the log line)
    text = 'i:{0}, total:{1:.2f}, content:{2:.3f}, style:{3:.6f}, tv:{4:.4f}'

    def closure():
        nonlocal evaluations
        evaluations += 1

        # keep the image inside [0, 1] before each forward pass;
        # the in-place clamp must not be recorded by autograd
        with torch.no_grad():
            objective.x.clamp_(0, 1)
        optimizer.zero_grad()

        content_loss, style_loss, tv_loss = objective()
        total_loss = (
            content_weight * content_loss
            + style_weight * style_loss
            + tv_weight * tv_loss
        )
        total_loss.backward()

        if log_every and (evaluations == 1 or evaluations % log_every == 0):
            print(text.format(evaluations, total_loss.item(), content_loss.item(), style_loss.item(), tv_loss.item()))

        return total_loss

    optimizer.step(closure)

    # Take the first item of the 4-D batch, move its channel axis last and
    # quantize to 8 bits. Rounding (rather than truncating) avoids a
    # systematic darkening of up to one intensity level.
    pixels = objective.x.detach().clamp(0, 1)[0].permute(1, 2, 0)
    pixels = pixels.mul(255).round().to(torch.uint8)
    result = Image.fromarray(pixels.contiguous().cpu().numpy())

    # drop the objective and let the CUDA caching allocator release unused memory
    del objective
    torch.cuda.empty_cache()

    return result

In [None]:
num_upsamplings = 2  # non negative integer

# Coarse-to-fine schedule: the first synthesis runs on the content image
# shrunk by a factor of 2**num_upsamplings and starts from that same image.
shrink = 2**num_upsamplings
size = tuple(side // shrink for side in CONTENT.size)
print('synthesizing image of size', size)
x = synthesize2(
    CONTENT.resize(size, Image.LANCZOS),
    initial=CONTENT.resize(size, Image.LANCZOS),
)

# Every further level doubles the resolution and is seeded with the
# upscaled result of the previous level.
results = [x]
for _ in range(num_upsamplings):

    size = tuple(2 * side for side in x.size)
    print('\nsynthesizing image of size', size)

    x = synthesize2(
        CONTENT.resize(size, Image.LANCZOS),
        x.resize(size, Image.LANCZOS),
    )
    results.append(x)

In [None]:
# Display the output of every resolution level, coarsest first, stacked top to bottom.
show(results)