In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the "../input/" directory.
# For example, running this (by clicking run or pressing Shift+Enter) will list the files in the input directory

import os
# Show which image files are available in the dataset directory
input_files = os.listdir("../input/pics-test/")
print(input_files)

# Any results you write to the current directory are saved as output.

In [None]:
# Importing libraries
import torch
from torch import nn
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms,models

In [None]:
def load_image(img_path,max_size=400,shape = None):
    """Load an image file as a normalised, batched tensor.

    Args:
        img_path: Path of the image file to open.
        max_size: Upper bound, in pixels, for the longest edge of the result.
            Larger images are scaled down with their aspect ratio preserved;
            images that already fit keep their original size. Ignored when
            ``shape`` is given.
        shape: Optional size passed unchanged to ``transforms.Resize``
            (for example an ``(h, w)`` pair). Overrides ``max_size``.

    Returns:
        Tensor of shape ``(1, 3, H, W)`` whose channels were normalised with
        mean 0.5 and std 0.5.
    """
    image = Image.open(img_path).convert('RGB')
    width, height = image.size  # PIL reports (width, height)
    longest_edge = max(width, height)

    if shape is not None:
        size = shape
    elif longest_edge > max_size:
        # Resize(int) would match the *shorter* edge to the value, leaving the
        # longer edge above max_size. Compute an explicit (h, w) so the longest
        # edge is exactly max_size.
        scale = max_size / longest_edge
        size = (max(1, round(height * scale)), max(1, round(width * scale)))
    else:
        # Already within bounds: keep the original dimensions instead of
        # stretching the shorter edge up to the longer one.
        size = (height, width)

    pic_transforms = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),
    ])
    # unsqueeze(0) adds the leading batch dimension
    image = pic_transforms(image).unsqueeze(0)
    print(image.shape)
    return image


In [None]:
# Read the content image first; the style image is then resized to exactly the
# same height and width so both produce feature maps of matching size.
content_image = load_image("../input/pics-test/Aditya.jpg")
h, w = content_image.shape[-2:]
style_image = load_image("../input/pics-test/Fire.jpg", shape=(h, w))

In [None]:
def im_convert(image):
    """Convert a normalised image tensor into an array suitable for plotting.

    Reverses the mean-0.5 / std-0.5 normalisation, moves the channel axis last
    and clips the result to the displayable range.

    Args:
        image: Tensor of shape ``(1, C, H, W)`` or ``(C, H, W)``. It may
            require gradients and may live on any device.

    Returns:
        ``numpy.ndarray`` of shape ``(H, W, C)`` with values in ``[0, 1]``.
    """
    # detach() drops autograd tracking; cpu() makes .numpy() valid for tensors
    # on other devices. No clone is needed because the arithmetic below
    # allocates new arrays and never writes into the tensor's memory.
    array = image.detach().cpu().numpy()
    # Remove only the leading batch axis. A bare squeeze() would also drop a
    # height or width of 1 and break the transpose below.
    if array.ndim == 4:
        array = array.squeeze(0)
    array = array.transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
    array = array * np.array((0.5, 0.5, 0.5)) + np.array((0.5, 0.5, 0.5))
    return array.clip(0, 1)

In [None]:
# Show the content image (left) and the style image (right) side by side
fig,(ax1,ax2) = plt.subplots(1,2,figsize = (20,20))
for axis, picture in ((ax1, content_image), (ax2, style_image)):
    axis.imshow(im_convert(picture))
    axis.axis('off')
plt.show()

In [None]:
# Use a pretrained VGG19 to create feature maps for each image. Only the
# `.features` sub-module is kept, because generate_features() iterates over its
# layers by index to collect intermediate activations.
vgg = models.vgg19(pretrained=True).features

In [None]:
# Freeze the network: the optimiser only ever updates the target image, so the
# VGG weights must not accumulate gradients.
for weight in vgg.parameters():
    weight.requires_grad_(False)

In [None]:
def generate_features(image, model, layers=None):
    """Run ``image`` through ``model`` and collect selected layer outputs.

    Args:
        image: Input tensor fed to the first child module of ``model``.
        model: Module whose children are applied one after another, in
            registration order.
        layers: Optional mapping from child-module name (a string such as
            ``'21'``) to the key under which that layer's output is stored.
            Defaults to the VGG19 layers used for style transfer.

    Returns:
        Dict mapping each feature name to the output of the matching layer.
        Names in ``layers`` that do not occur in ``model`` are simply absent.
    """
    if layers is None:
        layers = {'0': 'conv1_1',
                  '5': 'conv2_1',
                  '10': 'conv3_1',
                  '19': 'conv4_1',
                  '21': 'conv4_2',  # Content Extraction
                  '28': 'conv5_1'}

    pending = set(layers)
    features = {}
    # Iterate `_modules` directly (as the original did) so that every registered
    # child runs in order, including repeated module instances.
    for name, layer in model._modules.items():
        image = layer(image)

        if name in layers:
            features[layers[name]] = image
            pending.discard(name)
            # Stop once every requested output is captured. The remaining
            # layers cannot affect the result, and this function runs on
            # every optimisation step.
            if not pending:
                break
    return features

# Extract the feature maps of both source images once, up front; the
# optimisation loop compares against these fixed values on every step.
content_features, style_features = (
    generate_features(source, vgg) for source in (content_image, style_image)
)

In [None]:
def gram_matrix(tensor):
    """Compute the Gram matrix of a single feature map.

    Args:
        tensor: Feature map of shape ``(1, d, h, w)``. The batch dimension
            must be 1, because the tensor is flattened to ``(d, h * w)``.

    Returns:
        Tensor of shape ``(d, d)`` holding the inner products between every
        pair of flattened channels.
    """
    _, d, h, w = tensor.size()
    # reshape() equals view() for contiguous input but also accepts
    # non-contiguous feature maps, where view() raises.
    flat = tensor.reshape(d, h * w)
    return torch.mm(flat, flat.t())

# Gram matrices of the style image, computed once per extracted layer
style_grams = {name: gram_matrix(fmap) for name, fmap in style_features.items()}

In [None]:
# Per-layer weights for the style loss. conv4_2 is left out on purpose: it is
# the layer used for the content loss.
style_weights = dict(
    conv1_1=1.0,
    conv2_1=0.75,
    conv3_1=0.2,
    conv4_1=0.2,
    conv5_1=0.2,
)
# Relative contribution of the content and style terms to the total loss
content_weight, style_weight = 0.5, 0.5

In [None]:
steps = 2000
# Alternative initialisation (disabled): start from a copy of the content image.
#target = content_image.clone().requires_grad_(True)
# The target starts as an all-zero tensor (shown as mid-grey by im_convert) and
# is the only tensor handed to the optimiser.
target = torch.zeros(content_image.shape,dtype=content_image.dtype).requires_grad_(True)
optimizer = torch.optim.Adam([target],lr = 0.003)

for i in range(1,steps+1):
    # One VGG forward pass per step. The original ran this pass and the content
    # loss twice in a row and discarded the first result.
    target_features = generate_features(target, vgg)
    content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2'])**2)

    # Style loss: weighted mean squared difference between the Gram matrices of
    # the target and the style image, normalised by each feature map's size.
    style_loss = 0
    for layer in style_weights:
        target_feature = target_features[layer]
        target_gram = gram_matrix(target_feature)
        style_gram = style_grams[layer]
        layer_style_loss = style_weights[layer] * torch.mean((target_gram - style_gram)**2)
        _, d, h, w = target_feature.shape
        style_loss += layer_style_loss / (d * h * w)

    total_loss = content_weight * content_loss + style_weight * style_loss

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # Report progress and preview the current target every 50 steps
    if  i % 50 == 0:
        print('Total loss: ', total_loss.item())
        print('Iteration: ', i)
        plt.imshow(im_convert(target))
        plt.axis("off")
        plt.show()

    
    