<a href="https://colab.research.google.com/github/antonrufino/colab-notebooks/blob/main/Neural_Style_Transfer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

An implementation of [Image Style Transfer Using Convolutional Neural Networks](https://openaccess.thecvf.com/content_cvpr_2016/html/Gatys_Image_Style_Transfer_CVPR_2016_paper.html) (Gatys et al., CVPR 2016) in PyTorch. The implementation is based on the [neural style transfer tutorial](https://pytorch.org/tutorials/advanced/neural_style_tutorial.html) provided by PyTorch.

# Libraries

In [None]:
from google.colab import files

import torch
from torch import nn
import torch.nn.functional as F
from torch import optim

from torchvision import transforms
from torchvision.models import vgg19

from PIL import Image
import matplotlib.pyplot as plt

# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Parameters

In [None]:
# Number of optimizer steps taken on the generated image.
EPOCHS = 300
# Channel count; only referenced by the commented-out random initialisation
# in the training cell.
NUM_CHANNELS = 3
# Both input images are resized to IMG_HEIGHT x IMG_WIDTH during preprocessing.
IMG_HEIGHT = 800
IMG_WIDTH = 800
# Weight applied to the content loss in the total loss.
ALPHA = 1e-3
# Weight applied to the summed style loss in the total loss.
BETA = 1

# Input

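Use only square photos, i.e. photos with the same width and height. Both images are resized to `IMG_HEIGHT` × `IMG_WIDTH` during preprocessing, so a non-square photo would be stretched.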

In [None]:
def load_img():
  """Ask for a file upload in Colab and return the first file as an RGB image.

  The mapping returned by ``files.upload()`` is keyed by file name; the first
  key is opened with PIL as a path. An empty mapping raises ``IndexError``.
  """
  uploaded = files.upload()
  first_name = list(uploaded)[0]
  picture = Image.open(first_name)
  return picture.convert('RGB')

## Style Image

In [None]:
#style_img = load_img()
# Placeholder path: replace with the location of the style image.
# Convert to RGB (as load_img() does) so the tensor always has three channels,
# even for RGBA or grayscale files.
style_img = Image.open('/path/to/style/img').convert('RGB')

## Content Image

In [None]:
#content_img = load_img()
# Placeholder path: replace with the location of the content photo.
# Convert to RGB (as load_img() does) so the tensor always has three channels,
# even for RGBA or grayscale files.
content_img = Image.open('/path/to/content/img').convert('RGB')

# Preprocessing

In [None]:
# Resize each image to (IMG_HEIGHT, IMG_WIDTH), then convert it to a tensor.
# NOTE(review): no ImageNet mean/std normalisation is applied before the
# images reach VGG — confirm this is intentional.
pipeline = transforms.Compose([
  transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
  transforms.ToTensor()
])

# unsqueeze(0) adds a leading batch dimension of size 1.
style_tensor = pipeline(style_img).unsqueeze(0)
content_tensor = pipeline(content_img).unsqueeze(0)

# Print both tensors as a quick sanity check.
print(style_tensor)
print(content_tensor)

# Loss Functions

## Style Loss

In [None]:
class StyleLoss(nn.Module):
  """Style loss between two feature maps, computed on their Gram matrices.

  NOTE(review): the Gram matrix is not divided by the number of feature-map
  elements, so the loss grows with image size; ALPHA/BETA are presumably
  tuned for that scale — confirm before adding normalisation.
  """

  def __init__(self):
    super().__init__()

  def gram_matrix(self, tensor):
    """Return the unnormalised Gram matrix of a 4-D feature map.

    Args:
      tensor: feature map of shape (n, c, h, w).

    Returns:
      An (n * c, n * c) matrix holding the inner product of every pair of
      flattened channels.
    """
    n, c, h, w = tensor.shape
    # reshape (rather than view) also accepts non-contiguous feature maps.
    flat = tensor.reshape(n * c, h * w)

    return torch.mm(flat, flat.t())

  def forward(self, input, target):
    """Return the MSE between the two Gram matrices, divided by 4.

    Args:
      input: 4-D feature map of the image being optimised.
      target: 4-D feature map of the style image, with the same shape.
    """
    input_gram = self.gram_matrix(input)
    target_gram = self.gram_matrix(target)
    return F.mse_loss(input_gram, target_gram) / 4.0

## Content Loss

In [None]:
class ContentLoss(nn.Module):
  """Content loss: mean squared error between two activation tensors."""

  def __init__(self):
    super().__init__()

  def forward(self, input, target):
    """Return the mean squared error of ``input`` against ``target``."""
    mse = F.mse_loss(input, target, reduction='mean')
    return mse

# Model

In [None]:
# Load VGG19 with pretrained weights and print its layer layout; the layer
# indices used by NeuralStyleTransfer below refer to this layout.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=` — confirm against the installed version.
vgg_model = vgg19(pretrained=True)
print(vgg_model)

class NeuralStyleTransfer(nn.Module):
  """Frozen feature extractor returning five intermediate VGG activations.

  The convolutional part of the given model is cut into five consecutive
  stages. With torchvision's VGG19 layout, each stage ends at the ReLU that
  follows conv1_1, conv2_1, conv3_1, conv4_1 and conv5_1 respectively.
  """

  # Positions inside ``vgg_model.features`` of the max-pooling layers that are
  # swapped for average pooling.
  _POOL_INDICES = (4, 9, 18, 27)

  def __init__(self, vgg_model):
    """Build the staged extractor.

    Args:
      vgg_model: model exposing a ``features`` ``nn.Sequential`` laid out like
        torchvision's VGG19. It is modified in place: the pooling layers
        listed in ``_POOL_INDICES`` are replaced and the weights of the layers
        used here are frozen.
    """
    super().__init__()

    feature_extractor = vgg_model.features

    # Replace max pooling with avg pooling
    for idx in self._POOL_INDICES:
      feature_extractor[idx] = nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=False)

    # Isolate portions of network that produce needed activations
    self.conv1_1 = feature_extractor[:2]
    self.conv2_1 = feature_extractor[2:7]
    self.conv3_1 = feature_extractor[7:12]
    self.conv4_1 = feature_extractor[12:21]
    self.conv5_1 = feature_extractor[21:30]

    # Only the input image is optimised. Freezing the weights stops backward()
    # from computing and accumulating a gradient for every network weight.
    for param in self.parameters():
      param.requires_grad_(False)

  def forward(self, x):
    """Run ``x`` through the five stages in order.

    Args:
      x: batch of images; each stage's output feeds the next stage.

    Returns:
      List of the five stage outputs, shallowest first.
    """
    a = self.conv1_1(x)
    b = self.conv2_1(a)
    c = self.conv3_1(b)
    d = self.conv4_1(c)
    e = self.conv5_1(d)

    return [a, b, c, d, e]

# Wrap the loaded VGG in the staged feature extractor, move it to the selected
# device, and print the resulting module.
nst_model = NeuralStyleTransfer(vgg_model).to(device)
print(nst_model)

# Training

In [None]:
# The style and content targets never change, so their activations are
# computed once, without tracking gradients.
with torch.no_grad():
  style_tensor = style_tensor.to(device)
  content_tensor = content_tensor.to(device)

  style_activations = nst_model(style_tensor)
  content_activations = nst_model(content_tensor)

# The image being optimised starts as a copy of the content image (not random
# noise, despite the name); the commented-out call is the random alternative.
white_noise_img = content_tensor.clone() #torch.rand(1, NUM_CHANNELS, IMG_HEIGHT, IMG_WIDTH)
white_noise_img = white_noise_img.to(device)
# Gradients are taken with respect to the image itself.
white_noise_img = white_noise_img.requires_grad_()

style_criterion = StyleLoss()
content_criterion = ContentLoss()
# The optimizer is given only the image tensor, so it updates pixels and never
# the network weights.
optimizer = optim.Adam([white_noise_img], lr=3e-2)

for i in range(EPOCHS):
  # Keep pixel values in [0, 1]. The in-place edit of the leaf tensor is done
  # under no_grad rather than through `.data`, which bypasses autograd's checks.
  with torch.no_grad():
    white_noise_img.clamp_(0, 1)
  optimizer.zero_grad()

  activations = nst_model(white_noise_img)

  # Style loss: summed over every extracted layer.
  style_loss = sum(
    style_criterion(generated, style_target)
    for generated, style_target in zip(activations, style_activations)
  )

  # Content loss: compared at a single layer (index 3 of the returned list).
  content_loss = content_criterion(activations[3], content_activations[3])

  loss = ALPHA * content_loss + BETA * style_loss
  loss.backward()
  optimizer.step()

  # Report the total loss every 50 iterations.
  if (i + 1) % 50 == 0:
    print(loss.item())

# Detach before post-processing: on a CPU run `.cpu()` returns the same leaf
# tensor, and clamping a leaf that requires grad in place raises RuntimeError.
white_noise_img = white_noise_img.detach().cpu()
# The last optimizer step may have pushed pixels outside [0, 1]; clamp them
# and drop the batch dimension so the tensor can be converted to an image.
white_noise_img = white_noise_img.clamp(0, 1).squeeze(0)

convert_to_img = transforms.ToPILImage()
final_img = convert_to_img(white_noise_img)
final_img