# Image Style Transfer Using Convolutional Neural Networks

This notebook implements the algorithm found in [(Gatys
2016)](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Gatys_Image_Style_Transfer_CVPR_2016_paper.pdf).

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import torchvision.utils as utils

from PIL import Image

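# Size of images given to the network: the shorter side of each image is
# resized to `imsize` pixels. You can reduce it if the optimization is too slow.
imsize = 200
loader = transforms.Compose([transforms.Resize(imsize), transforms.ToTensor()])

# Load `content_img` as a float tensor of size 3 x H x W with values in [0, 1],
# where min(H, W) = `imsize`. `convert("RGB")` drops any alpha channel.
image = Image.open("./data/images/dancing.jpg").convert("RGB")
content_img = loader(image)

# Load `style_img` the same way. It does not need the same size as
# `content_img`: the style loss only compares n_filters x n_filters Gram matrices.
image = Image.open("./data/images/picasso.jpg").convert("RGB")
style_img = loader(image)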
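`transforms.Resize` with a single int rescales the shorter side and keeps the aspect ratio, so `content_img` and `style_img` are generally not square. As a minimal sketch of the resulting tensor size, the hypothetical helper `resized_size` below assumes the longer side is scaled proportionally and truncated to an int (small rounding differences from torchvision are possible). Note that PIL's `Image.size` is `(width, height)`, while the tensor is `3 x H x W`.

```python
def resized_size(width, height, imsize):
    """Hypothetical helper: approximate (H, W) produced by `transforms.Resize(imsize)`.

    Assumption: the shorter side becomes `imsize` and the longer side is
    scaled by the same factor, then truncated to an int.
    """
    if width <= height:
        # Portrait or square image: width is the shorter side
        return int(imsize * height / width), imsize
    # Landscape image: height is the shorter side
    return imsize, int(imsize * width / height)


# A 600 x 400 (width x height) landscape photo becomes a 3 x 200 x 300 tensor
print(resized_size(600, 400, 200))  # (200, 300)
```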

## Feature extraction with VGG19

The next cell defines a model that extracts the convolutional features
specified by `modules_indexes` from a pretrained VGG19 model, using
forward hooks. It is used to compute the features of the content and
style images, and to reconstruct the target image by backpropagation.

In [2]:
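class VGG19Features(nn.Module):
    def __init__(self, modules_indexes=None):
        super().__init__()

        # VGG19 pretrained model in evaluation mode. Its weights are frozen:
        # only the target image is optimized, so no gradients are needed here.
        self.vgg19 = models.vgg19(weights=models.VGG19_Weights.DEFAULT).eval()
        for param in self.vgg19.parameters():
            param.requires_grad_(False)

        # Indexes (in `vgg19.features`) of convolutional layers to remember
        # during the forward phase, in increasing order. `None` avoids a
        # mutable default argument.
        self.modules_indexes = [] if modules_indexes is None else list(modules_indexes)

    def forward(self, input):
        # ImageNet `mean` and `std` used to train VGG19, of size 3 x 1 x 1 so
        # that they broadcast over the H x W pixels
        mean = torch.tensor([0.485, 0.456, 0.406], device=input.device).view(-1, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=input.device).view(-1, 1, 1)

        # First center and normalize `input` with `mean` and `std`
        # <answer>
        input_norm = (input - mean) / std
        # </answer>

        # Add a fake mini-batch dimension: 3 x H x W -> 1 x 3 x H x W
        # <answer>
        input_norm = input_norm.unsqueeze(0)
        # </answer>

        # Install hooks on specified modules to save their features
        features = []
        handles = []
        for module_index in self.modules_indexes:

            def hook(module, input, output):
                # `output` is of size (`batchsize` = 1) x `n_filters` x H_l x W_l,
                # where H_l and W_l shrink after each max-pooling layer.
                # torchvision's VGG applies an in-place ReLU right after each
                # convolution, so this tensor ends up holding the post-ReLU
                # activations once the forward pass is done.
                features.append(output)

            handle = self.vgg19.features[module_index].register_forward_hook(hook)
            handles.append(handle)

        # Forward propagate `input_norm` through the convolutional part (the
        # classifier is not needed). This triggers the hooks set up above and
        # populates `features` in forward order.
        try:
            self.vgg19.features(input_norm)
        finally:
            # Remove hooks, even if the forward pass fails
            for handle in handles:
                handle.remove()

        # The output of our custom VGG19Features neural network is a
        # list of features of `input`
        return features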
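The hook mechanism used by `VGG19Features` can be seen on a small stand-in network. This sketch assumes a hypothetical toy `nn.Sequential` instead of VGG19 (so no weights are downloaded) and records the outputs of two of its layers with `register_forward_hook`. The printed shapes show that the spatial size shrinks after pooling, which is why deeper VGG19 feature maps are smaller than the input image.

```python
import torch
import torch.nn as nn

# Hypothetical toy stand-in for `vgg19.features` (not VGG19 itself)
net = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),  # index 0
    nn.ReLU(),                                  # index 1
    nn.MaxPool2d(2),                            # index 2: halves H and W
    nn.Conv2d(4, 8, kernel_size=3, padding=1),  # index 3
)


def extract(net, x, indexes):
    """Return the outputs of net[i] for i in `indexes`, in forward order."""
    features = []

    def hook(module, inputs, output):
        features.append(output)

    handles = [net[i].register_forward_hook(hook) for i in indexes]
    try:
        net(x)
    finally:
        # Remove the hooks so later forward passes are unaffected
        for handle in handles:
            handle.remove()
    return features


x = torch.rand(1, 3, 16, 16)
feats = extract(net, x, [0, 3])
print([tuple(f.shape) for f in feats])  # [(1, 4, 16, 16), (1, 8, 8, 8)]
```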

The next cell selects the convolutional layers used to capture the style
and the content. Following the paper, the style is represented by
`conv1_1`, `conv2_1`, `conv3_1`, `conv4_1` and `conv5_1`, and the content
by `conv4_2`.

In [3]:
# Names of convolutional layers to extract
modules_names = ["conv1_1", "conv2_1", "conv3_1", "conv4_1", "conv4_2", "conv5_1"]

# To access them we need their corresponding indexes in the vgg19 network.
# Print `VGG19Features().vgg19.features` to see them.

# <answer>
modules_indexes = [0, 5, 10, 19, 21, 28]
# </answer>

# Define our feature extractor model
vgg19_feat = VGG19Features(modules_indexes)

# Extract features of the content image. They stay fixed during the
# optimization, so no computation graph is recorded.
with torch.no_grad():
    content_features = vgg19_feat(content_img)
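Instead of reading the indexes off the printed model, they can be derived from the layer configuration. This sketch hard-codes what is assumed to be torchvision's VGG19 layout (configuration "E", no batch norm): each convolution is followed by a ReLU, and each block ends with a max-pooling layer. `VGG19_CFG` and `conv_layer_indexes` are hypothetical helpers, so check the result against `print(models.vgg19().features)`.

```python
# Assumed VGG19 layout: numbers are conv output channels, "M" is max-pooling
VGG19_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
             512, 512, 512, 512, "M", 512, 512, 512, 512, "M"]


def conv_layer_indexes(cfg):
    """Map names like "conv4_2" to their index in `vgg19.features`."""
    indexes = {}
    index, block, conv_in_block = 0, 1, 1
    for item in cfg:
        if item == "M":
            # A max-pooling layer occupies one slot and starts a new block
            index += 1
            block += 1
            conv_in_block = 1
        else:
            indexes[f"conv{block}_{conv_in_block}"] = index
            # Each conv is followed by a ReLU, so it occupies two slots
            index += 2
            conv_in_block += 1
    return indexes


names = ["conv1_1", "conv2_1", "conv3_1", "conv4_1", "conv4_2", "conv5_1"]
mapping = conv_layer_indexes(VGG19_CFG)
print([mapping[name] for name in names])  # [0, 5, 10, 19, 21, 28]
```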

## Style features as Gram matrices of convolutional features

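The next cell computes the Gram matrix of `input`, whose entry $(i, j)$
is the inner product between the vectorized feature maps of filters $i$
and $j$ (equation (3) in the paper). `input` must first be reshaped into
an `n_filters` x `n_pixels` matrix.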

In [4]:
def gram_matrix(input):
    # `input` has PyTorch's N x C x H x W layout, with a mini-batch of one
    batchsize, n_filters, height, width = input.size()

    # Reshape `input` into `n_filters` x `n_pixels`, one row per vectorized
    # feature map
    # <answer>
    features = input.reshape(n_filters, height * width)
    # </answer>

    # Compute the inner products between filters in `G` (n_filters x n_filters)
    # <answer>
    G = torch.mm(features, features.t())
    # </answer>

    return G


# Extract Gram features of the style image (fixed during the optimization)
with torch.no_grad():
    style_gram_features = [gram_matrix(f) for f in vgg19_feat(style_img)]
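A pure-Python sketch of equations (3) and (4) on a tiny made-up feature map (two filters, three pixels), using the paper's per-layer normalization $1/(4 N_l^2 M_l^2)$. The helpers `gram` and `layer_style_loss` are hypothetical and work on nested lists rather than tensors, only to make the arithmetic visible.

```python
def gram(features):
    """Gram matrix of N_l feature maps, each given as a list of M_l pixels.

    G[i][j] is the inner product of vectorized feature maps i and j (eq. (3)).
    """
    return [[sum(a * b for a, b in zip(fi, fj)) for fj in features]
            for fi in features]


def layer_style_loss(target_features, style_features):
    """E_l (eq. (4)): squared Gram difference scaled by 1 / (4 N_l^2 M_l^2)."""
    n_filters = len(target_features)    # N_l
    n_pixels = len(target_features[0])  # M_l
    G = gram(target_features)
    A = gram(style_features)
    squared_diff = sum((G[i][j] - A[i][j]) ** 2
                       for i in range(n_filters) for j in range(n_filters))
    return squared_diff / (4 * n_filters ** 2 * n_pixels ** 2)


# Two filters (N_l = 2) over three pixels (M_l = 3)
F = [[1.0, 0.0, 2.0],
     [0.0, 1.0, 1.0]]
S = [[1.0, 1.0, 1.0],
     [1.0, 1.0, 1.0]]
print(gram(F))  # [[5.0, 2.0], [2.0, 2.0]]
print(layer_style_loss(F, S))
```

Here every entry of $A$ is 3, so the squared differences sum to $4 + 1 + 1 + 1 = 7$ and $E_l = 7 / (4 \cdot 2^2 \cdot 3^2) = 7/144$.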

## Optimizer

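Look at the paper to find which optimization algorithm it uses (L-BFGS).
Remember that we optimize the pixels of the target image, not the network
weights; here the target is initialized as a copy of the content image.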

In [5]:
target = content_img.clone().requires_grad_(True)

# Define `optimizer` to use the L-BFGS algorithm to minimize the loss with
# respect to `target`
# <answer>
optimizer = optim.LBFGS([target])
# </answer>
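Unlike SGD or Adam, `optim.LBFGS` may re-evaluate the objective several times per call to `step`, which is why it requires a closure. A minimal sketch on a hypothetical toy problem: minimizing $(x - 3)^2$ over a single tensor `x` that plays the role of `target`.

```python
import torch
import torch.optim as optim

# Toy problem (hypothetical): the only optimized tensor is `x`, as `target` is
# the only parameter in the notebook
x = torch.zeros(1, requires_grad=True)
optimizer = optim.LBFGS([x])


def closure():
    # Recompute the loss and its gradient from scratch on every call
    optimizer.zero_grad()
    loss = ((x - 3) ** 2).sum()
    loss.backward()
    return loss


# One call runs up to `max_iter` inner iterations, calling `closure` as needed
optimizer.step(closure)
print(x.item())
```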

## The algorithm

The paper combines two losses: a content loss and a style loss.

Define `content_weight`, the trade-off parameter between the content and
style losses (the ratio $\alpha/\beta$ in the paper, with the style weight
set to 1). Its value depends on the chosen content and style images; a
good choice for `dancing.jpg` and `picasso.jpg` is $10^{-5}$.
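For reference, the losses implemented in the next cells, in the paper's notation and equation numbering: $F^l$ and $P^l$ are the feature maps of the target and content images at layer $l$ (one row per filter, one column per pixel), $G^l$ and $A^l$ the Gram matrices of the target and style images, $N_l$ the number of filters and $M_l$ the number of pixels of each feature map at layer $l$.

$$G^l_{ij} = \sum_k F^l_{ik} F^l_{jk} \qquad (3)$$

$$\mathcal{L}_{content} = \frac{1}{2} \sum_{i,j} \left(F^l_{ij} - P^l_{ij}\right)^2 \qquad (1)$$

$$E_l = \frac{1}{4 N_l^2 M_l^2} \sum_{i,j} \left(G^l_{ij} - A^l_{ij}\right)^2 \qquad (4)$$

$$\mathcal{L}_{style} = \sum_l w_l E_l \qquad (5)$$

$$\mathcal{L}_{total} = \alpha \, \mathcal{L}_{content} + \beta \, \mathcal{L}_{style} \qquad (7)$$

Dividing (7) by $\beta$ does not change the minimizer, so only the ratio $\alpha/\beta$ matters: the code fixes the style weight to 1, uses `content_weight` $= \alpha/\beta$, and sets $w_l = 1/5$ for each of the five style layers.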

In [6]:
# <answer>
content_weight = 1e-5
# </answer>

In [7]:
for step in range(1):
    # To keep track of the losses in the closure
    losses = {}

    # Need to use a closure that computes the loss and gradients to allow the
    # optimizer to evaluate repeatedly at different locations
    def closure():
        optimizer.zero_grad()

        # First, forward propagate `target` through `vgg19_feat` and store its
        # output as `target_features`.
        # <answer>
        target_features = vgg19_feat(target)
        # </answer>

        # Define `content_loss` on layer "conv4_2" (equation (1))
        conv4_2_index = modules_names.index("conv4_2")
        # <answer>
        content_loss = 0.5 * torch.sum(
            (target_features[conv4_2_index] - content_features[conv4_2_index]) ** 2
        )
        # </answer>

        # Define `style_loss` (equations (4) and (5))
        style_loss = 0
        for target_feature, style_gram_feature, module_name in zip(
            target_features, style_gram_features, modules_names
        ):
            if module_name not in [
                "conv1_1",
                "conv2_1",
                "conv3_1",
                "conv4_1",
                "conv5_1",
            ]:
                continue

            # Compute Gram matrix
            # <answer>
            target_gram_feature = gram_matrix(target_feature)
            # </answer>

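            # Compute style loss for current feature. Equation (4) divides by
            # 4 * N_l^2 * M_l^2 (N_l filters, M_l pixels); this code divides by
            # 4 * N_l^4 (`numel()` of the N_l x N_l Gram matrix, squared), which
            # changes the relative weight of the layers. The losses printed
            # below were obtained with this normalization.
            # <answer>
            current_style_loss = (
                1
                / 4
                * torch.sum((target_gram_feature - style_gram_feature) ** 2)
                / target_gram_feature.numel() ** 2
            )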
            # </answer>

            # Add current loss to `style_loss` with weight w_l = 1/5
            # (equal weights over the five style layers, equation (5))
            # <answer>
            style_loss += 1 / 5 * current_style_loss
            # </answer>

        # Compute combined loss (equation (7))
        # <answer>
        loss = content_weight * content_loss + style_loss
        # </answer>

        # Store the losses
        losses["loss"] = loss.item()
        losses["style_loss"] = style_loss.item()
        losses["content_loss"] = content_loss.item()

        # Backward propagation and return loss
        # <answer>
        loss.backward()
        return loss
        # </answer>

    # Optimization step: pass the closure to the optimizer. By default, one
    # call to `LBFGS.step` runs up to `max_iter=20` iterations.
    # <answer>
    optimizer.step(closure)
    # </answer>

    print("step {}:".format(step))
    print(
        "Style Loss: {:4f} Content Loss: {:4f} Overall: {:4f}".format(
            losses["style_loss"], losses["content_loss"], losses["loss"]
        )
    )
    # Save a copy of the current target, clamped to the valid [0, 1] range
    img = target.detach().clone().clamp_(0, 1)
    utils.save_image(img, "output-{:03}.png".format(step))

step 0:
Style Loss: 2.237663 Content Loss: 95853.664062 Overall: 3.196200
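As a sanity check of equation (7) on the log above: with `content_weight` $= 10^{-5}$, the content loss of about $9.6 \times 10^4$ contributes about $0.96$ to the total, and adding the style loss reproduces the reported overall value. The numbers below are copied from the log; the variable names are only for this check.

```python
# Values copied from the printed log of step 0
weight = 1e-5                      # `content_weight` set above
logged_style_loss = 2.237663
logged_content_loss = 95853.664062

# Equation (7) with the style weight fixed to 1
overall = weight * logged_content_loss + logged_style_loss
print("{:4f}".format(overall))  # 3.196200
```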
