# Image Style Transfer Using Convolutional Neural Networks

This notebook implements the style transfer algorithm described in [Gatys et al.
(2016)](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Gatys_Image_Style_Transfer_CVPR_2016_paper.pdf).

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchvision.transforms as transforms
import torchvision.models as models
import torchvision.utils as utils

from PIL import Image

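imsize = 128
# `Resize(imsize)` scales the shorter side to `imsize` pixels and keeps the
# aspect ratio; `ToTensor` converts the image to a float tensor in [0, 1]
loader = transforms.Compose([transforms.Resize(imsize), transforms.ToTensor()])

# Load `content_img` as a torch tensor of size 3 * H * W, min(H, W) = `imsize`
# (`convert("RGB")` guards against grayscale or RGBA files)
image = Image.open("./data/images/dancing.jpg").convert("RGB")
content_img = loader(image)

# Load `style_img` the same way (its size may differ from `content_img`)
image = Image.open("./data/images/mondrian.jpg").convert("RGB")
style_img = loader(image)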
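To make the effect of `transforms.Resize(imsize)` concrete, here is a small pure-Python sketch of how an integer size becomes an output size. The helper `resized_shape` is hypothetical: it mirrors torchvision's documented rule (the shorter edge becomes `size`, the longer edge is scaled to keep the aspect ratio) and assumes the longer edge is truncated to an integer.

```python
def resized_shape(width, height, size):
    """Hypothetical helper: output (width, height) of Resize(size) for an int size.

    The shorter edge is set to `size`; the longer edge is scaled by the same
    factor and truncated to an int (assumed to match torchvision).
    """
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


print(resized_shape(600, 400, 128))  # landscape: (192, 128)
print(resized_shape(400, 600, 128))  # portrait: (128, 192)
print(resized_shape(500, 500, 128))  # square: (128, 128)
```

`ToTensor` then returns a tensor of shape 3 × height × width, so a 600 × 400 landscape photo would become 3 × 128 × 192.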


## Feature extraction with VGG19
The next cell wraps a pretrained VGG19 network and returns the feature
maps produced by the layers listed in `modules_indexes`. It computes the
features of the content and style images, and also those of the target
image, whose pixels are reconstructed by backpropagating the loss through
the network.

In [2]:
class VGG19Features(nn.Module):
    def __init__(self, modules_indexes):
        super(VGG19Features, self).__init__()

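        # VGG19 pretrained model in evaluation mode. Its weights are frozen:
        # only the target image is optimized, so no gradients are needed here
        self.vgg19 = models.vgg19(weights=models.VGG19_Weights.DEFAULT).eval()
        self.vgg19.requires_grad_(False)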

        # Indexes of layers to remember
        self.modules_indexes = modules_indexes

    def forward(self, input):
        # Define a hardcoded `mean` and `std` of size 3 * 1 * 1
        mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)

        # First center and normalize `input` with `mean` and `std`
        # <answer>
        input_norm = (input - mean) / std
        # </answer>

        # Add a fake mini-batch dimension to `input_norm`
        # <answer>
        input_norm = input_norm.unsqueeze(0)
        # </answer>

        # Install hooks on specified modules to save their features
        features = []
        handles = []
        for module_index in self.modules_indexes:

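            def hook(module, input, output):
                # `output` is of size (`batchsize` = 1) * `n_filters`
                # * `height` * `width`; the spatial size shrinks after each
                # pooling layer. torchvision's VGG19 follows each conv with an
                # in-place ReLU, so a tensor saved from a conv layer ends up
                # holding its post-ReLU activations
                features.append(output)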

            handle = self.vgg19.features[module_index].register_forward_hook(hook)
            handles.append(handle)

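        # Forward propagate `input_norm` through the convolutional part of
        # VGG19, up to the deepest requested layer. This triggers the hooks
        # set up above and populates `features`
        self.vgg19.features[: max(self.modules_indexes) + 1](input_norm)

        # Remove hooks
        for handle in handles:
            handle.remove()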

        # The output of our custom VGG19Features neural network is a
        # list of features of `input`
        return features
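The hook mechanism used in `VGG19Features` can be tried on a toy network. The sketch below uses a small, hypothetical `nn.Sequential` stand-in for `vgg19.features` with the same conv → in-place ReLU pattern. It shows that the hooks fire in forward order and that, because the ReLU is in-place, the tensor saved from the convolution ends up holding post-ReLU values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for `vgg19.features`: conv -> in-place ReLU -> conv
net = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(4, 2, kernel_size=3, padding=1),
)

captured = []


def save_output(module, inputs, output):
    # Returning None leaves `output` unchanged; we only keep a reference
    captured.append(output)


handles = [net[i].register_forward_hook(save_output) for i in (0, 2)]
out = net(torch.rand(1, 3, 8, 8))
for handle in handles:
    handle.remove()

print([tuple(t.shape) for t in captured])  # [(1, 4, 8, 8), (1, 2, 8, 8)]
# The ReLU at index 1 overwrote the saved conv output in place
print(bool((captured[0] >= 0).all()))  # True
```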


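The next cell selects the convolutional layers used to capture the style
and the content; see the paper for the layers it uses. The paper takes
conv1_1, conv2_1, conv3_1, conv4_1 and conv5_1 for the style and conv4_2
for the content. The answer below instead uses the first five
convolutional layers for the style and computes the content loss on the
first of them (conv1_1) only.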

In [3]:

# Indexes of interesting features to extract

# Define `modules_indexes`, positions in `vgg19.features` (here conv1_1,
# conv1_2, conv2_1, conv2_2 and conv3_1)
# <answer>
modules_indexes = [0, 2, 5, 7, 10]
# </answer>

vgg19 = VGG19Features(modules_indexes)
content_features = [f.detach() for f in vgg19(content_img)]
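The indexes in `modules_indexes` are positions in `vgg19.features`. The sketch below recovers the layer names from the VGG19 configuration, assuming torchvision's layout without batch normalization (each 3×3 convolution is followed by a ReLU, and each block ends with a max-pooling layer). `vgg19_layer_names` is a hypothetical helper, not part of torchvision.

```python
# VGG19 ("E") configuration: each number is a 3x3 conv (+ ReLU), "M" a max-pool
VGG19_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
             512, 512, 512, 512, "M", 512, 512, 512, 512, "M"]


def vgg19_layer_names(cfg=VGG19_CFG):
    """Hypothetical helper: name of every module in `vgg19.features`, in order."""
    names = []
    block, conv = 1, 1
    for v in cfg:
        if v == "M":
            names.append(f"pool{block}")
            block, conv = block + 1, 1
        else:
            names.append(f"conv{block}_{conv}")
            names.append(f"relu{block}_{conv}")
            conv += 1
    return names


names = vgg19_layer_names()
index_of = {name: i for i, name in enumerate(names)}

# Layers selected in this notebook
print([names[i] for i in [0, 2, 5, 7, 10]])
# ['conv1_1', 'conv1_2', 'conv2_1', 'conv2_2', 'conv3_1']

# Layers used in the paper: style, then content
print([index_of[n] for n in ["conv1_1", "conv2_1", "conv3_1", "conv4_1", "conv5_1"]])
# [0, 5, 10, 19, 28]
print(index_of["conv4_2"])  # 21
```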


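## Style features as Gram matrices of convolutional features

The next cell computes the Gram matrix of `input`. The feature map of size
1 * `n_filters` * `height` * `width` is first reshaped into a matrix with one
row per filter and one column per pixel; the Gram matrix is the product of
this matrix with its transpose, so entry (i, j) is the inner product of the
responses of filters i and j.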

In [4]:
def gram_matrix(input):
    # Feature maps are laid out as batch * channels * height * width
    batchsize, n_filters, height, width = input.size()

    # Reshape `input` into `n_filters` * `n_pixels` (assumes `batchsize` == 1)
    # <answer>
    features = input.view(n_filters, height * width)
    # </answer>

    # Compute the inner products between filters in `G`
    # <answer>
    G = torch.mm(features, features.t())
    # </answer>

    # Normalize the Gram matrix by dividing by the total number of
    # elements in the feature maps
    return G.div(n_filters * height * width)


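style_gram_features = [gram_matrix(f.detach()) for f in vgg19(style_img)]

# Start the optimization from a copy of the content image (the paper's main
# experiments start from a white noise image instead)
target = content_img.clone().requires_grad_(True)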
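As a worked example of the normalized Gram matrix, here is a hypothetical pure-Python counterpart of `gram_matrix` applied to a tiny feature map with two 2 × 2 filter responses (no batch dimension).

```python
def gram_matrix_py(feature_map):
    """Normalized Gram matrix of a C x H x W feature map given as nested lists.

    Hypothetical pure-Python counterpart of `gram_matrix` (no batch dim).
    """
    n_filters = len(feature_map)
    height, width = len(feature_map[0]), len(feature_map[0][0])
    # Flatten each filter's H x W response into one row of length H * W
    rows = [[value for row in fmap for value in row] for fmap in feature_map]
    norm = n_filters * height * width
    # G[i][j] is the inner product of rows i and j, divided by C * H * W
    return [[sum(a * b for a, b in zip(ri, rj)) / norm for rj in rows] for ri in rows]


# Two 2 x 2 filter responses
fmap = [
    [[1, 2], [3, 4]],
    [[0, 1], [0, 1]],
]
print(gram_matrix_py(fmap))  # [[3.75, 0.75], [0.75, 0.25]]
```

Because the Gram matrix is `n_filters` × `n_filters`, its size does not depend on the height and width of the feature map, which is why the style image does not need the same size as the content image.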


## Optimizer

Check the paper to see which optimization algorithm it uses. Remember
that the variables being optimized are the pixels of the target image,
not the weights of the network.

In [5]:

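# Define `optimizer` to use the L-BFGS algorithm to minimize the loss with
# respect to `target`. Each `optimizer.step` call runs up to `max_iter`
# (20 by default) L-BFGS iterations
# <answer>
optimizer = optim.LBFGS([target])
# </answer>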
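L-BFGS needs a closure because each `optimizer.step` call runs several internal iterations, each of which re-evaluates the loss and its gradient. The toy sketch below, which is not part of the original notebook, minimizes a simple quadratic with the same `optim.LBFGS` and closure pattern and counts how many times the closure is called.

```python
import torch
import torch.optim as optim

# Toy problem (hypothetical): find x minimizing ||x - goal||^2
goal = torch.tensor([3.0, -1.0])
x = torch.zeros(2, requires_grad=True)
optimizer = optim.LBFGS([x])
n_calls = [0]  # mutable counter updated inside the closure


def closure():
    n_calls[0] += 1
    optimizer.zero_grad()
    loss = ((x - goal) ** 2).sum()
    loss.backward()
    return loss


optimizer.step(closure)
print(x.detach(), n_calls[0])
```

A single `step` call is enough here to reach `goal`, and the closure is evaluated more than once during it.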


## The algorithm

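The paper combines two losses: a content loss, which keeps the features of
the target close to those of the content image, and a style loss, which
matches the Gram matrices of the target's features to those of the style
image.

Define `style_weight`, the trade-off parameter between the style and
content losses (the content loss has weight 1):
<answer>
style_weight = 10**6
</answer>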
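As a worked formula, the loss minimized in the next cell can be written as follows, with $F^l$ the $N_l \times M_l$ matrix of feature maps at the $l$-th hooked layer ($N_l$ filters, $M_l = H_l W_l$ pixels), $x$ the target, $p$ the content image, $a$ the style image and $w$ = `style_weight`. This transcribes the code, not the paper's exact normalization (the paper weights each layer and uses a $1/(4N_l^2M_l^2)$ factor in the style loss).

```latex
\begin{aligned}
G^l(x) &= \frac{F^l(x)\,F^l(x)^\top}{N_l M_l}, \\
\mathcal{L}_{\text{content}} &= \sum_{i,k} \bigl(F^1_{ik}(x) - F^1_{ik}(p)\bigr)^2, \\
\mathcal{L}_{\text{style}} &= \sum_{l} \sum_{i,j} \bigl(G^l_{ij}(x) - G^l_{ij}(a)\bigr)^2, \\
\mathcal{L}_{\text{total}} &= \mathcal{L}_{\text{content}} + w \, \mathcal{L}_{\text{style}}, \qquad w = 10^6.
\end{aligned}
```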

In [6]:

for step in range(500):
    # To keep track of the losses in the closure
    losses = {}

    # L-BFGS needs a closure that recomputes the loss and gradients, because
    # the optimizer evaluates them several times per step
    def closure():
        optimizer.zero_grad()

        # First, forward propagate `target` through our VGG19Features neural
        # network and store its output as `target_features`
        # <answer>
        target_features = vgg19(target)
        # </answer>

        # Define `content_loss` on the first layer only
        # <answer>
        content_loss = torch.sum((target_features[0] - content_features[0])**2)
        # </answer>

        style_loss = 0
        for target_feature, style_gram_feature in zip(target_features, style_gram_features):
            # Compute Gram matrix
            # <answer>
            target_gram_feature = gram_matrix(target_feature)
            # </answer>

            # Add current loss to `style_loss`
            # <answer>
            style_loss += torch.sum((target_gram_feature - style_gram_feature)**2)
            # </answer>

        # Compute combined loss
        # <answer>
        style_weight = 10**6
        loss = content_loss + style_weight * style_loss
        # </answer>

        # Store the losses
        losses["loss"] = loss.item()
        losses["style_loss"] = style_loss.item()
        losses["content_loss"] = content_loss.item()

        # Backward propagation and return loss
        # <answer>
        loss.backward()
        return loss
        # </answer>

    # Gradient step: pass the closure to the optimizer
    # <answer>
    optimizer.step(closure)
    # </answer>

    if step % 10 == 0:
        print("step {}:".format(step))
        print(
            "Style Loss: {:4f} Content Loss: {:4f} Overall: {:4f}".format(
                losses["style_loss"], losses["content_loss"], losses["loss"]
            )
        )
        # Save the current target, clamped to the valid [0, 1] pixel range
        img = target.detach().clamp(0, 1)
        utils.save_image(img, "output-{}.png".format(step))

step 0:
Style Loss: 10.415602 Content Loss: 99735.046875 Overall: 10515337.000000
step 10:
Style Loss: 0.736319 Content Loss: 337936.906250 Overall: 1074256.125000
step 20:
Style Loss: 0.436361 Content Loss: 388684.531250 Overall: 825046.000000
step 30:
Style Loss: 0.363572 Content Loss: 400106.750000 Overall: 763678.562500
step 40:
Style Loss: 0.334233 Content Loss: 403407.687500 Overall: 737640.312500
step 50:
Style Loss: 0.318921 Content Loss: 404121.593750 Overall: 723042.312500
step 60:
Style Loss: 0.308954 Content Loss: 403457.312500 Overall: 712411.625000
step 70:
Style Loss: 0.300849 Content Loss: 403455.500000 Overall: 704304.250000
step 80:
Style Loss: 0.294807 Content Loss: 403440.718750 Overall: 698248.125000
step 90:
Style Loss: 0.289846 Content Loss: 403200.343750 Overall: 693046.375000
step 100:
Style Loss: 0.285253 Content Loss: 403203.031250 Overall: 688456.375000
step 110:
Style Loss: 0.281025 Content Loss: 403289.125000 Overall: 684314.062500
step 120:
Style Loss: 0.

KeyboardInterrupt: 
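The printed log can be checked against the combined loss: each "Overall" value should equal the content loss plus `style_weight` (10**6) times the style loss, up to the rounding of the printed style loss. The sketch below redoes this arithmetic for a few of the logged steps and reports the share of the total taken by the weighted style term.

```python
STYLE_WEIGHT = 10**6

# (step, style loss, content loss, overall) copied from the log above
logged = [
    (0, 10.415602, 99735.046875, 10515337.0),
    (50, 0.318921, 404121.59375, 723042.3125),
    (110, 0.281025, 403289.125, 684314.0625),
]

for step, style, content, overall in logged:
    recomputed = content + STYLE_WEIGHT * style
    style_share = STYLE_WEIGHT * style / recomputed
    # The style loss is printed with 6 decimals, so after multiplying by
    # 10**6 the recomputed total can differ from the logged one by ~0.5
    print(f"step {step}: recomputed {recomputed:.1f}, logged {overall:.1f}, "
          f"style share {style_share:.2f}")
```

The weighted style term accounts for almost all of the loss at step 0; afterwards the content loss rises from about 1e5 to about 4e5 while the style loss falls, which is the trade-off that `style_weight` controls.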