# Style Transfer (pretrained model)

#### Import dependencies

In [None]:
#!g1.1 #noqa
import os
import urllib.request

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch import nn
from torch import optim
from torch.optim.lr_scheduler import StepLR
from torchsummary import summary
from torchvision import transforms
from torchvision.models import vgg19_bn, VGG19_BN_Weights
from torchvision.models.feature_extraction import create_feature_extractor

In [None]:
#!g1.1 #noqa
# IPython shell escape: runs the `nvidia-smi` command-line tool and prints its output in the notebook.
!nvidia-smi

In [None]:
#!g1.1 #noqa
# Filesystem layout used by the notebook. Every directory string keeps its
# trailing slash because file names are appended by plain concatenation.
MODEL_WEIGHTS_DIR = './models/weights/'
IMAGE_DIR = './image/'
STYLE_DIR = f'{IMAGE_DIR}style/'
CONTENT_DIR = f'{IMAGE_DIR}content/'

#### Setting seed and device

In [None]:
#!g1.1 #noqa
# Fix the torch RNG seed so repeated runs start from the same random state.
random_seed = 10
torch.manual_seed(random_seed)

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

### Model

Get the pretrained weights and save them locally (to avoid downloading them every time)

In [None]:
#!g1.1 #noqa
# Location of the locally cached state dict of the VGG19-BN convolutional part.
weights_path = MODEL_WEIGHTS_DIR + 'vgg19_bn_weights.pt'

# Fetch the pretrained weights only when no local copy exists yet.
if not os.path.isfile(weights_path):
    # exist_ok=True makes a separate "does the directory exist" check unnecessary.
    os.makedirs(MODEL_WEIGHTS_DIR, exist_ok=True)
    # Only the `.features` sub-module (no classifier) is kept and saved.
    pretrained_features = vgg19_bn(weights=VGG19_BN_Weights.DEFAULT).features
    # Save to the very path that was checked above instead of rebuilding the string.
    torch.save(pretrained_features.state_dict(), weights_path)

Loading convolutional part of the model architecture (without classifier)

In [None]:
#!g1.1 #noqa
# Build the convolutional part of VGG19-BN without weights; the cached state
# dict is loaded into it in a later cell.
# The network is used purely as a fixed feature extractor, therefore:
#   * eval() makes the BatchNorm layers use their stored running statistics
#     instead of the statistics of the single image being processed (and
#     stops the forward passes from updating those running statistics);
#   * requires_grad_(False) freezes the weights so backward passes only
#     compute the gradient with respect to the optimized image.
# Both calls return the module itself, so the chain still assigns the model.
model = (
    vgg19_bn(weights=None)
    .features
    .to(device)
    .eval()
    .requires_grad_(False)
)

Replace pooling layers with AvgPool (based on the article [How to Get Beautiful Results with Neural Style Transfer](https://towardsdatascience.com/how-to-get-beautiful-results-with-neural-style-transfer-75d0c05d6489))

In [None]:
#!g1.1 #noqa
# Swap every max-pooling layer for an average-pooling layer with the same
# 2x2 window and stride. A snapshot of the (index, layer) pairs is taken
# because entries of `model` are replaced while looping.
for idx, layer in list(enumerate(model)):
    # isinstance is robust where a class-name string comparison is not.
    if isinstance(layer, nn.MaxPool2d):
        model[idx] = nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=False)

In [None]:
#!g1.1 #noqa
# Read the cached weights from disk, mapping the tensors to the active device,
# and copy them into the model.
state_dict = torch.load(weights_path, map_location=device)
model.load_state_dict(state_dict)

In [None]:
#!g1.1 #noqa
# Print a per-layer summary of the model for a 3x224x224 input.
input_shape = (3, 224, 224)
summary(model, input_shape, device=device)

### Data preparation

Get the style image and save it locally

In [None]:
#!g1.1 #noqa
# urlretrieve does not create missing directories, so make sure the
# destination folder exists before downloading.
os.makedirs(STYLE_DIR, exist_ok=True)
# Replace the placeholder URL with the address of your own style image.
urllib.request.urlretrieve('https://path_to_your_img.jpg', STYLE_DIR + 'style.jpg')

In [None]:
#!g1.1 #noqa
def load_image(path: str) -> torch.Tensor:
    """Open an image file, resize it to 224x224 and normalize it.

    Args:
        path: location of the image file on disk.

    Returns:
        A float tensor of shape (3, 224, 224), normalized per channel with
        mean (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225).
    """
    transformation = transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    # The context manager closes the underlying file once the pixels are read.
    with Image.open(path) as img:
        # Force three channels: grayscale, palette or RGBA files would
        # otherwise not match the 3-channel normalization above.
        return transformation(img.convert('RGB'))

In [None]:
#!g1.1 #noqa
# Locations of the two input pictures.
content_path = CONTENT_DIR + 'test1.jpg'
style_path = STYLE_DIR + 'test_2.jpg'

# Load both images as normalized tensors and move them to the active device.
content = load_image(content_path).to(device)
style = load_image(style_path).to(device)

In [None]:
#!g1.1 #noqa
def conv_to_img(tensor: torch.Tensor) -> np.ndarray:
    """Convert a normalized image tensor back to a displayable image array.

    Args:
        tensor: image tensor of shape (3, H, W); leading singleton axes
            (e.g. a batch axis of size 1) are accepted and removed. The
            tensor may live on any device and may require gradients.

    Returns:
        Array of shape (H, W, 3) with values clipped to [0, 1], obtained by
        undoing the per-channel normalization (multiply by the std, add the
        mean).

    Raises:
        ValueError: if the input is not a single 3-channel image.
    """
    img = tensor.detach().cpu()

    # Strip only leading singleton axes. A blanket squeeze() would also drop
    # a height or width of 1 and break the axis reordering below.
    while img.dim() > 3 and img.size(0) == 1:
        img = img.squeeze(0)
    if img.dim() != 3 or img.size(0) != 3:
        raise ValueError(
            f'expected a single (3, H, W) image, got shape {tuple(tensor.shape)}'
        )

    # Channels-first -> channels-last, the layout matplotlib expects.
    img = img.numpy().transpose(1, 2, 0)
    # Undo the normalization; this also creates a new array, so the input
    # tensor's memory is never modified.
    img = img * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    return img.clip(0, 1)

Let's look at our images

In [None]:
#!g1.1 #noqa
# Show the content and the style image side by side.
fig, axes = plt.subplots(1, 2, figsize=(20, 20))
panels = ((content, 'Content image'), (style, 'Style image'))
for axis, (image, title) in zip(axes, panels):
    axis.imshow(conv_to_img(image))
    axis.set_title(title)
plt.show()

 ### Getting intermediate nodes (features) from the given model

In [None]:
#!g1.1 #noqa
# Names of the nodes inside `model` whose outputs should be returned.
feature_nodes = ['2', '9', '12', '22', '29', '42']
# Wrap the model so that a forward pass returns the outputs of those nodes.
features = create_feature_extractor(model, return_nodes=feature_nodes)

To better understand what the model responds to at each selected layer, let's have a look at some of the feature maps.

In [None]:
#!g1.1 #noqa
# The style features are only read, never back-propagated through, so run
# the forward pass without autograd bookkeeping (same as for the content
# image). Detaching only the input does not prevent a graph from being
# built through the model parameters.
with torch.no_grad():
    style_features = features(style.unsqueeze(0))

In [None]:
#!g1.1 #noqa
# Inspect the shape of the activations captured at node '2' for the style image.
style_features['2'].shape

In [None]:
#!g1.1 #noqa
# Draw the feature maps of node '2' for the style image on an 8x8 grid.
# NOTE(review): assumes this node yields at most 64 channels - confirm.
fig, ax = plt.subplots(nrows=8, ncols=8, figsize=(30, 30))
style_maps = style_features['2'].squeeze()

for idx, feature_map in enumerate(style_maps):
    row, col = divmod(idx, 8)
    ax[row][col].imshow(feature_map.detach().cpu().numpy())

plt.show()

In [None]:
#!g1.1 #noqa
# Add a batch axis and extract the content activations once; no gradients
# are needed for them.
content_batch = content.unsqueeze(0)
with torch.no_grad():
    content_features = features(content_batch)

In [None]:
#!g1.1 #noqa
# Draw the first 64 feature maps of node '2' for the content image on an 8x8 grid.
fig, ax = plt.subplots(nrows=8, ncols=8, figsize=(30, 30))
content_maps = content_features['2'].squeeze()[:64]

for idx, feature_map in enumerate(content_maps):
    row, col = divmod(idx, 8)
    ax[row][col].imshow(feature_map.detach().cpu().numpy())

plt.show()

### Metrics


In [None]:
#!g1.1 #noqa
def gram_matrix(tensor: torch.Tensor) -> torch.Tensor:
    """Calculate the normalized Gram matrix of a feature map.

    Args:
        tensor: activations of shape (1, d, h, w). Only a batch of one is
            supported; a larger batch cannot be flattened to (d, h * w) and
            raises a RuntimeError.

    Returns:
        Tensor of shape (d, d) holding the inner products between the
        flattened channels, divided by d * h * w.
    """
    _, d, h, w = tensor.size()  # the batch size itself is not needed
    # reshape (unlike view) also accepts non-contiguous inputs such as
    # permuted or sliced feature maps.
    flat = tensor.reshape(d, h * w)
    gram = torch.mm(flat, flat.t())

    return gram.div(d * h * w)

As the style image does not change during the training process, we can calculate its Gram matrices just once and reuse them during training.

In [None]:
#!g1.1 #noqa
# Gram matrix of the style image at every extracted node, detached from any graph.
style_grams = {
    layer_name: gram_matrix(activation).detach()
    for layer_name, activation in style_features.items()
}

In [None]:
#!g1.1 #noqa
# Use content image as a target image.
# The copy is moved to the device first and gradients are enabled last, so
# `target` is always a leaf tensor. The optimizer only accepts leaf tensors,
# and calling .to(device) after requires_grad_ yields a non-leaf one
# whenever the tensor actually has to be copied to another device.
target = content.detach().clone().to(device).requires_grad_(True)

# Use random noise as a target image
# target = torch.rand(content.shape, requires_grad=True, device=device)

In [None]:
#!g1.1 #noqa
# Per-layer weights: how strongly each extracted node contributes to the
# style loss. The keys are the node names requested from the feature extractor.
style_weights = {'2': 1,
                 '9': 0.9,
                 '12': 0.75,
                 '22': 0.2,
                 '29': 0.5,
                 '42': 0.2}

content_weight = 1  # alpha: multiplier of the content loss in the total loss
style_weight = 1e9  # beta: multiplier of the style loss in the total loss

### Setup training process  

In [None]:
#!g1.1 #noqa
# Number of optimization steps applied to the target image.
epochs = 2000
# L1 distance is used for both the content and the style term.
loss_func = nn.L1Loss()
# The target image itself is the only optimized "parameter".
optimizer = optim.Adam(params=[target], lr=0.05)
# Halve the learning rate every 1000 steps.
scheduler = StepLR(optimizer=optimizer, step_size=1000, gamma=0.5)

In [None]:
#!g1.1 #noqa
for epoch in range(epochs):

    # Forward pass of the current target image through the feature extractor.
    target_img_features = features(target.unsqueeze(0))

    # Content term: distance between target and content activations at node '2'.
    content_loss = loss_func(target_img_features['2'], content_features['2'])

    # Style term: weighted sum, over the style layers, of the distance between
    # the target's Gram matrix and the precomputed style Gram matrix.
    style_loss = sum(
        layer_weight * loss_func(gram_matrix(target_img_features[layer]), style_grams[layer])
        for layer, layer_weight in style_weights.items()
    )

    # Total loss: alpha * content + beta * style.
    loss = content_weight * content_loss + style_weight * style_loss

    # Gradient step on the target image, then advance the LR schedule.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

    # Show the intermediate image and its loss every 1000 steps.
    if (epoch + 1) % 1000 == 0:
        plt.title('Loss: {:.3f}'.format(loss.item()))
        plt.imshow(conv_to_img(target))
        plt.show()

### Compare the results

In [None]:
#!g1.1 #noqa
# Show the content image, the style image and the stylized result side by side.
fig, axes = plt.subplots(1, 3, figsize=(25, 25))

panels = (
    (content, 'Content image'),
    (style, 'Style image'),
    (target, 'Target'),
)
for axis, (image, title) in zip(axes, panels):
    axis.imshow(conv_to_img(image))
    axis.set_title(title)
plt.show()

In [None]:
#!g1.1 #noqa
