In [None]:
import torch
import torchvision
from torchvision.models import vgg19
from torchvision import  transforms
from torchvision.transforms import functional as TF
from PIL import Image

In [None]:
class Style_Content_Extractor(torch.nn.Module):
  """Truncated, pretrained VGG19 feature stack returning content and style activations.

  Only the layers up to the deepest requested index are kept, so nothing past
  the last needed activation is ever evaluated.
  """

  def __init__(self,content_layer,style_layers):
    """Build the extractor.

    Args:
      content_layer: index into `vgg19().features` whose output is returned
        as the content activation.
      style_layers: container of indices into `vgg19().features` whose
        outputs are returned as the style activations (may be empty).
    """
    super(Style_Content_Extractor, self).__init__()

    # content layers and style layers
    self.content_layer = content_layer
    self.style_layers = style_layers

    # deepest layer that has to be evaluated; a single max() over all indices
    # also works when no style layer is requested
    max_layer = max([content_layer, *style_layers])
    features = list(vgg19(pretrained = True).features)[:max_layer+1]

    # VGG's ReLUs operate in place. If a convolution index is selected, the
    # tensor stored in forward() would be overwritten by the ReLU that follows
    # it, so every ReLU is swapped for an out-of-place one.
    features = [
        torch.nn.ReLU(inplace=False) if isinstance(layer, torch.nn.ReLU) else layer
        for layer in features
    ]

    self.features = torch.nn.ModuleList(features).eval()


  def forward(self, x):
    """Run one image through the truncated network.

    Args:
      x: a single, unbatched image tensor. It is normalised with three
        per-channel mean/std values and then given a leading batch
        dimension. Presumably RGB scaled to [0, 1] — TODO confirm against
        the caller.

    Returns:
      (content, style): `content` is the output of layer `content_layer`;
      `style` is a list with the outputs of the layers listed in
      `style_layers`, in increasing layer order.
    """
    # channel statistics documented by torchvision for its pretrained models
    x = TF.normalize(x,mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
    x = x.unsqueeze(0)

    # forward propagation
    style = []
    content = None
    for ii,model in enumerate(self.features):
      x = model(x)
      if ii in self.style_layers:
        style.append(x)
      if ii == self.content_layer:
        content = x

    return content,style

In [None]:
# Feature extractor on the GPU: layer 18 provides the content activation,
# the listed layers provide the style activations.
extractor = Style_Content_Extractor(
    content_layer=18,
    style_layers=[4,9,11,15,18,24,27],
).cuda()
# Inference only: evaluation mode and no gradients for the network weights.
# Both calls return the module, so the cell still displays it.
extractor.eval().requires_grad_(False)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/checkpoints/vgg19-dcbb9e9d.pth


HBox(children=(FloatProgress(value=0.0, max=574673361.0), HTML(value='')))




Style_Content_Extractor(
  (features): ModuleList(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3),

In [None]:
# method for loading images
def image_loader(image_name):
    """Load an image file as a 3-channel tensor.

    Args:
        image_name: path (or file object) accepted by `PIL.Image.open`.

    Returns:
        The tensor produced by `TF.to_tensor` after the image was converted
        to RGB and resized with `TF.resize(image, 500)` (an integer size
        matches the smaller edge to that value, per the torchvision docs).
    """
    # The context manager releases the file handle even if decoding fails.
    with Image.open(image_name) as source:
        # Force three channels: grayscale, RGBA or palette files would
        # otherwise not match the 3-channel normalisation applied later.
        image = source.convert('RGB')
    image = TF.resize(image,500)
    image = TF.to_tensor(image)
    return image


In [None]:
# method for computing styling descriptions as gram matrix
def gram_matrix(style_tensor):
  """Return the per-sample channel-by-channel Gram matrix of a 4-D activation.

  For an input indexed as (b, c, i, j) the result has shape (b, c, c): entry
  [b, c, d] is the sum over i, j of the product of channels c and d, divided
  by the product of the first three dimension sizes.
  """
  # NOTE(review): the divisor uses shape[0..2] (b, c, i) and not the last
  # axis — confirm this normalisation is intended before changing it, the
  # loss weights in this notebook are tuned against it.
  dim0, dim1, dim2 = style_tensor.shape[:3]
  normaliser = dim0*dim1*dim2
  channel_products = torch.einsum('bcij,bdij->bcd', style_tensor, style_tensor)
  return channel_products/normaliser


In [None]:
# method for computing content loss
def content_loss_fn(content_op,content_input,weight):
  """Weighted sum of squared differences between two content activations.

  Args:
    content_op: activation of the image being optimised.
    content_input: activation of the reference content image.
    weight: multiplier applied to the summed squared error.
  """
  difference = content_input - content_op
  squared_error = (difference**2).sum()
  return weight*squared_error

In [None]:
# method for total variation loss, used for smoothing
def total_variation_loss(y,variance_weight):
    """Weighted total-variation penalty of an image tensor, used for smoothing.

    Sums the absolute differences between neighbouring elements along the
    third axis and along the second axis, divides the total by `y.shape[0]`
    and scales it by `variance_weight`.

    Args:
        y: tensor indexed with three axes; presumably (channels, height,
            width) — TODO confirm against the caller.
        variance_weight: multiplier applied to the penalty.

    Returns:
        The weighted penalty as a 0-dim tensor.
    """
    third_axis_diff = torch.sum(torch.abs(y[:, :, :-1] - y[:, :, 1:]))
    second_axis_diff = torch.sum(torch.abs(y[:, :-1, :] - y[:, 1:, :]))
    # the original line started with a stray `g` instead of `return`
    return variance_weight*(third_axis_diff + second_axis_diff)/y.shape[0]

In [None]:
# method for style loss
def style_loss_fn(style_op,style_input,weights):
  """Weighted sum over layers of the squared Gram-matrix differences.

  Args:
    style_op: sequence of style activations of one image.
    style_input: sequence of style activations of the other image, paired
      element-wise with `style_op` (pairing stops at the shorter sequence).
    weights: multiplier applied to the accumulated loss.
  """
  per_layer = (
      ((gram_matrix(first) - gram_matrix(second))**2).sum()
      for first,second in zip(style_op,style_input)
  )
  return sum(per_layer)*weights

In [None]:
import seaborn as sns
# Default figure size (12 x 7) for the plots drawn later in the notebook.
sns.set(rc={'figure.figsize':(12,7)})

import numpy as np
import matplotlib.pyplot as plt

In [None]:
def get_stylized_image(content,style,extractor,epochs,content_weight,style_weight,total_variation_loss_weight,learning_rate):
  """Optimise a copy of `content` with Adam so that it takes on the style of `style`.

  Each step minimises the sum of the content loss, the style loss and the
  total-variation loss, then clamps the image to [0, 1]. Every 50th step the
  losses are printed and the current image is displayed; the loss curve is
  plotted once the loop has finished. Requires a CUDA device, because the
  inputs are moved with `.cuda()`.

  Args:
    content: image tensor to stylize. It is cloned, so the argument itself
      is not modified. Presumably (3, H, W) in [0, 1] as returned by
      `image_loader` — TODO confirm.
    style: image tensor providing the style targets.
    extractor: callable returning `(content_features, style_features)` for
      an image tensor.
    epochs: number of optimisation steps.
    content_weight: multiplier passed to `content_loss_fn`.
    style_weight: multiplier passed to `style_loss_fn`.
    total_variation_loss_weight: multiplier passed to `total_variation_loss`.
    learning_rate: learning rate of the Adam optimiser.

  Returns:
    The optimised image tensor (still flagged as requiring gradients).
  """
  # the optimised image starts as a copy of the content image and is the
  # only tensor handed to the optimiser
  stylized_output = content.clone()
  stylized_output.requires_grad_(True)

  # the targets are constants, so no autograd graph is recorded for them
  with torch.no_grad():
    content_input,_ = extractor(content.cuda())
    _,style_s = extractor(style.cuda())

  optimizer = torch.optim.Adam(params = [stylized_output],lr=learning_rate)
  loss_history = np.zeros((epochs))

  # optimizing input image for style
  for e in range(epochs):
    optimizer.zero_grad()

    # one GPU copy per step, shared by the extractor and the smoothing term;
    # gradients flow back through the copy to `stylized_output`
    stylized_gpu = stylized_output.cuda()
    content_o,style_o = extractor(stylized_gpu)

    # loss computation (the original referenced the undefined name
    # `style_weights` here instead of the `style_weight` parameter)
    content_loss = content_loss_fn(content_o,content_input,content_weight)
    style_loss = style_loss_fn(style_o,style_s,style_weight)
    variance_loss = total_variation_loss(stylized_gpu,total_variation_loss_weight)
    total_loss = style_loss + content_loss + variance_loss

    # backprop
    total_loss.backward()
    optimizer.step()

    loss_history[e] = total_loss.item()

    # keep the image in the valid range without recording the clamp in autograd
    with torch.no_grad():
      stylized_output.clamp_(0, 1)

    if e%50==0:
      print(f'Epoch {e}: Loss:{total_loss:.2f}. Style loss:{style_loss:.2f}. Variance Loss:{variance_loss:.2f}. Content Loss: {content_loss:.2f}')
      # detach (and move to the CPU) so the conversion never touches autograd
      display(TF.to_pil_image(stylized_output.detach().cpu()))

  ax =  sns.lineplot(data = loss_history)
  ax.set(ylabel = 'Training Loss')
  plt.show()
  return stylized_output




In [None]:
# Image to stylize and the image whose style is transferred onto it.
content = image_loader(image_name='/content/20170110004814_IMG_5719.JPG')
style = image_loader(image_name='/content/style_4.jpg')

In [None]:
# Settings shared by the stylization runs below.
epochs = 2000            # number of optimisation steps
content_weight = 1e-6    # multiplier of the content loss
style_weight = 0.1       # multiplier of the style loss
variance_weight = 5e-5   # multiplier of the total-variation loss


In [None]:
# Stylize the loaded content image with the shared settings.
output = get_stylized_image(
    content=content,
    style=style,
    extractor=extractor,
    epochs=epochs,
    content_weight=content_weight,
    style_weight=style_weight,
    total_variation_loss_weight=variance_weight,
    learning_rate=0.004,
)

In [None]:
# Same style image and settings, applied to a different content image.
content = image_loader(image_name='/content/noisy_mountains.jpg')
output = get_stylized_image(
    content=content,
    style=style,
    extractor=extractor,
    epochs=epochs,
    content_weight=content_weight,
    style_weight=style_weight,
    total_variation_loss_weight=variance_weight,
    learning_rate=0.004,
)


In [None]:
# Same style image, run for 2500 steps instead of the shared `epochs`.
content = image_loader(image_name='/content/Scattered_billiards_balls.jpg')
output = get_stylized_image(
    content=content,
    style=style,
    extractor=extractor,
    epochs=2500,
    content_weight=content_weight,
    style_weight=style_weight,
    total_variation_loss_weight=variance_weight,
    learning_rate=0.004,
)


In [None]:
# Same style image, run for 2500 steps instead of the shared `epochs`.
content = image_loader(image_name='/content/IMG20200413154412.jpg')
output = get_stylized_image(
    content=content,
    style=style,
    extractor=extractor,
    epochs=2500,
    content_weight=content_weight,
    style_weight=style_weight,
    total_variation_loss_weight=variance_weight,
    learning_rate=0.004,
)
