# **Import Packages**

In [None]:
import torch
import torch.optim as optim
from torchvision import transforms, models
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Run on the GPU when CUDA is usable, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# **Transfer model**

In [None]:
# Only the convolutional/pooling stack (`.features`) of the pretrained VGG19 is kept;
# the classifier head is not used anywhere in this notebook.
vgg = models.vgg19(pretrained = True).features

# Freeze the network: it is used purely as a fixed feature extractor, so no
# gradients need to be tracked for its weights.
for weight in vgg.parameters():
  weight.requires_grad = False

In [None]:
# Move the frozen feature extractor onto the selected device. Being the last
# expression of the cell, the returned module is also displayed as the cell output.
vgg.to(device)

# **Load image**

In [None]:
def load_image(img_path, max_size=400, img_shape=None):
  """Load an image file as a normalized tensor with a leading batch axis.

  Args:
    img_path: Path of the image file to open.
    max_size: Upper bound, in pixels, for the longer side of the result.
      Larger images are scaled down with their aspect ratio preserved;
      smaller images keep their original size. Ignored when `img_shape`
      is given.
    img_shape: Optional explicit target size, handed to `transforms.Resize`
      unchanged (e.g. an `(H, W)` pair taken from another image tensor).

  Returns:
    A tensor of shape (1, 3, H, W) whose channels were normalized with
    mean 0.5 and std 0.5.
  """
  # The context manager closes the underlying file once convert() has
  # produced an independent RGB copy of the pixels.
  with Image.open(img_path) as raw:
    image = raw.convert('RGB')

  if img_shape is not None:
    size = img_shape
  else:
    width, height = image.size  # PIL reports (width, height)
    # Resize(int) matches the *shorter* edge to the given value, which would
    # let the longer edge exceed max_size and would upscale small images.
    # Passing an explicit (H, W) pair makes max_size a real upper bound.
    scale = min(1.0, max_size / max(width, height))
    size = (max(1, round(height * scale)), max(1, round(width * scale)))

  transform = transforms.Compose([
      transforms.Resize(size),
      transforms.ToTensor(),
      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  ])

  # unsqueeze(0) adds the batch axis expected by the network.
  image = transform(image).unsqueeze(0)

  return image

In [None]:
# The content image is loaded first; its last two dimensions (height, width)
# are then forced onto the style image so both tensors have the same spatial size.
content = load_image('../MonaLisa.jpg')
content = content.to(device)
style = load_image('../StarryNight.jpg', img_shape=content.shape[-2:])
style = style.to(device)

# **Conversion of tensor to numpy**

In [None]:
def im_convert(tensor):
  """Turn an image tensor into an (H, W, 3) numpy array suitable for plotting.

  Args:
    tensor: Image tensor of shape (1, 3, H, W) or (3, H, W), on any device,
      normalized with mean 0.5 and std 0.5 per channel. It may require
      gradients; the input itself is left untouched.

  Returns:
    A numpy array of shape (H, W, 3) with values clipped to [0, 1].

  Raises:
    ValueError: If a 4-D input has a batch axis larger than 1.
  """
  img = tensor.detach().to("cpu").clone().numpy()
  # Drop only the leading batch axis. A bare squeeze() would also remove a
  # height or width axis of size 1 and break the transpose below.
  if img.ndim == 4:
    img = img.squeeze(0)
  # Channels-first (C, H, W) -> channels-last (H, W, C).
  img = img.transpose(1, 2, 0)
  # Undo the 0.5 / 0.5 normalization: x * std + mean.
  img = img * np.array((0.5, 0.5, 0.5)) + np.array((0.5, 0.5, 0.5))
  img = img.clip(0, 1)

  return img

In [None]:
# Side-by-side preview: content image on the left, style image on the right.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(20,10))
ax1.imshow(im_convert(content))
ax2.imshow(im_convert(style))
# Hide the axes so that only the pictures are shown.
ax1.axis("off")
ax2.axis("off")

# **Extract features**

In [None]:
def feature_extract(image, model):
  """Run `image` through `model` and collect the activations of selected layers.

  Args:
    image: Input tensor fed to the first child module of `model`.
    model: Module whose direct children are applied one after another, in
      registration order. Children are looked up by their registered name.

  Returns:
    Dict mapping a readable layer label ('conv1_1', ..., 'conv5_1') to the
    output of the child module registered under the matching name.
    Children that are missing from `model` are simply absent from the dict.
  """
  # Child-module name -> label used by the rest of the notebook.
  layers = {'0': 'conv1_1',
            '5': 'conv2_1',
            '10': 'conv3_1',
            '19': 'conv4_1',
            '21': 'conv4_2',  # Content Extraction
            '28': 'conv5_1'}

  features = {}

  for name, layer in model._modules.items():
    image = layer(image)
    if name in layers:
      features[layers[name]] = image
      # Every requested activation is collected: running the remaining
      # layers would be wasted work because their output is never used.
      if len(features) == len(layers):
        break
  return features

In [None]:
print(content.shape)
# The two source images never change, so their activations are computed once
# here (content first, then style) instead of on every training iteration.
content_features, style_features = (feature_extract(img, vgg) for img in (content, style))

# **Gram Matrix for extracting style**

In [None]:
def gram_matrix(tensor):
    """Compute the Gram matrix of a single feature map.

    Args:
        tensor: Activation tensor of shape (1, d, h, w). The batch axis is
            discarded, so it must hold exactly one sample.

    Returns:
        A (d, d) tensor whose entry (i, j) is the dot product between the
        flattened channels i and j.
    """
    _, d, h, w = tensor.size()
    # reshape (unlike view) also accepts non-contiguous inputs, e.g.
    # transposed or channels-last activations.
    flat = tensor.reshape(d, h * w)
    return torch.mm(flat, flat.t())

In [None]:
# Gram matrix of every extracted style activation, keyed by layer label.
# The style image is fixed, so these targets are computed once up front.
style_grams = {name: gram_matrix(activation) for name, activation in style_features.items()}

# **Defining weights for loss**

In [None]:
# Per-layer weights of the style loss; the earliest layer counts the most.
style_weights = dict(conv1_1=1.,
                     conv2_1=0.75,
                     conv3_1=0.2,
                     conv4_1=0.2,
                     conv5_1=0.2)

# Relative importance of the two loss terms in the total loss.
content_weight = 1  # alpha
style_weight = 1e6  # beta

In [None]:
# The image being optimized starts as a copy of the content image.
# requires_grad_ is applied last, after the device placement: if .to(device)
# ever has to copy the data, enabling gradients before the copy would leave
# `target` as a non-leaf tensor, which the optimizer cannot update.
target = content.detach().clone().to(device).requires_grad_(True)

# **Specifications for training**

In [None]:
show_every = 300  # iterations between two progress previews
# Only the target image is handed to the optimizer; it is the sole thing updated.
optimizer = optim.Adam([target], lr=0.003)
steps = 6000

# Buffer holding snapshots of the target image taken during training.
height, width, channels = im_convert(target).shape
num_frames = 300
# Zero-initialized so that a slot which is never filled holds a black frame
# rather than uninitialized memory.
image_array = np.zeros(shape=(num_frames, height, width, channels))
# Integer capture interval (ceiling of steps / num_frames, at least 1): a float
# interval only matches `ii % capture_frame == 0` when steps is an exact
# multiple of num_frames, and rounding up keeps the number of captures within
# the buffer size.
capture_frame = max(1, -(-steps // num_frames))
counter = 0

# **Training**

In [None]:
# Restart the frame index so that re-running this cell overwrites the buffer
# from the beginning instead of indexing past its end.
counter = 0

for ii in range(1, steps+1):
  # Forward pass of the current target image through the frozen network.
  target_features = feature_extract(target, vgg)

  # Content loss: mean squared difference of the conv4_2 activations.
  content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2'])**2)

  # Style loss: weighted mean squared difference of the Gram matrices,
  # normalized per layer by the number of activation values (d * h * w).
  style_loss = 0
  for layer in style_weights:
    target_feature = target_features[layer]
    target_gram = gram_matrix(target_feature)
    style_gram = style_grams[layer]
    layer_style_loss = style_weights[layer] * torch.mean((target_gram - style_gram)**2)
    _, d, h, w = target_feature.shape
    style_loss += layer_style_loss / (d * h * w)

  total_loss = content_weight * content_loss + style_weight * style_loss

  # Gradient step on the target image.
  optimizer.zero_grad()
  total_loss.backward()
  optimizer.step()

  # Periodic progress report with a preview of the current image.
  if  ii % show_every == 0:
    print('Total loss: ', total_loss.item())
    print('Iteration: ', ii)
    plt.imshow(im_convert(target))
    plt.axis("off")
    plt.show()

  # Periodic snapshot into the frame buffer; the bounds check keeps a
  # mismatch between the capture interval and the buffer size from raising.
  if ii % capture_frame == 0 and counter < len(image_array):
    image_array[counter] = im_convert(target)
    counter = counter + 1

In [None]:
# Final comparison, left to right: content image, style image, stylized result.
fig, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=3, figsize=(20, 10))
ax1.imshow(im_convert(content))
ax2.imshow(im_convert(style))
ax3.imshow(im_convert(target))
# Hide the axes so that only the pictures are shown.
ax1.axis('off')
ax2.axis('off')
ax3.axis('off')

In [None]:
import cv2

# Write the captured snapshots to 'output.mp4' at 30 frames per second.
frame_height, frame_width, _ = im_convert(target).shape
# 'mp4v' is the FourCC that belongs to the .mp4 container ('XVID' is an AVI codec tag).
# VideoWriter expects the frame size as (width, height).
vid = cv2.VideoWriter('output.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 30, (frame_width, frame_height))
if not vid.isOpened():
  raise RuntimeError("Could not open 'output.mp4' for writing")

try:
  # Only the frames that were actually captured are written.
  for i in range(min(counter, len(image_array))):
    img = image_array[i]
    img = img*255                          # [0, 1] floats -> [0, 255]
    img = np.array(img, dtype = np.uint8)
    # Snapshots are RGB while OpenCV writes BGR, so swap the channel order.
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    vid.write(img)
finally:
  # Always finalize the file, even when writing a frame fails.
  vid.release()

In [None]:
# google.colab can only be imported inside Google Colab. Guard the import so
# that running the whole notebook elsewhere does not stop with an ImportError.
try:
  from google.colab import files
except ImportError:
  print("google.colab is unavailable; 'output.mp4' was left in the working directory.")
else:
  files.download('output.mp4')