In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# This notebook reads its content and style images from "../input/styledata/" (see path2data below)
import matplotlib.pyplot as plt
%matplotlib inline
import torch
import os
from PIL import Image 
from torchvision.transforms.functional import to_pil_image
from torchvision.transforms import Resize ,Compose,Normalize,ToTensor
import torch.nn.functional as F
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Directory that holds the content and style images.
path2data = '../input/styledata/'
# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
path2content = os.path.join(path2data, 'content.jpg')
path2style = os.path.join(path2data, 'style.jpeg')

# Open each file inside a context manager so the handle is closed promptly,
# and force 3-channel RGB: the Normalize step below is configured with three
# channel statistics and fails on RGBA / grayscale / palette images.
with Image.open(path2content) as content_file:
    content_img = content_file.convert('RGB')
with Image.open(path2style) as style_file:
    style_img = style_file.convert('RGB')

# Target spatial size (height, width) that both images are resized to.
h, w = 256, 384

# Per-channel statistics used to normalize the inputs; tensorToPil() uses the
# same values to undo the normalization for display.
VGG19_mean = torch.tensor([0.485, 0.456, 0.406])
VGG19_std = torch.tensor([0.229, 0.224, 0.225])

# PIL image -> float tensor in [0, 1] -> resized to (h, w) -> normalized.
image_transform = Compose([
    ToTensor(),
    Resize(size=(h, w)),
    Normalize(VGG19_mean, VGG19_std),
])

# NOTE: from here on style_img / content_img are normalized tensors, not PIL images.
style_img = image_transform(style_img)
content_img = image_transform(content_img)

In [None]:
def tensorToPil(img):
    """Undo the VGG19 normalization of ``img`` and return it as a PIL image.

    Args:
        img: image tensor normalized with ``VGG19_mean`` / ``VGG19_std``. The
            statistics are broadcast as ``(3, 1, 1)``, so the channel axis
            must have size 3 and sit third from the end. The tensor may live
            on any device and may require grad; it is not modified.

    Returns:
        The de-normalized image, clamped to ``[0, 1]``, converted with
        ``to_pil_image``.
    """
    # Work on a detached CPU copy: the normalization constants live on the CPU,
    # so a GPU tensor would otherwise raise a device-mismatch error.
    image = img.detach().cpu().clone()
    # Tensor.to(other) matches both the dtype and the device of ``image``.
    std = VGG19_std.to(image).view(3, 1, 1)
    mean = VGG19_mean.to(image).view(3, 1, 1)
    image = image * std + mean
    # Optimized images can drift outside the displayable range.
    image = image.clamp(0, 1)
    return to_pil_image(image)
# Show the (de-normalized) style and content images side by side.
for i, (title, img) in enumerate((('style', style_img), ('content', content_img)), start=1):
    ax = plt.subplot(1, 2, i)
    ax.imshow(tensorToPil(img))
    ax.set_title(title)
    

In [None]:
import torchvision.models as models

# Load the pretrained VGG19 and place it on the selected device.
# NOTE(review): newer torchvision releases prefer the `weights=` argument over
# `pretrained=True`; kept as-is because the installed version is not visible here.
model = models.vgg19(pretrained=True).to(device)
# The network is only used as a fixed feature extractor, so switch it to
# inference mode instead of leaving it in the default training mode.
model.eval()

# Only the `features` sub-module is used for feature extraction.
model_vgg = model.features
# Freeze the weights: gradients are needed only with respect to the input image.
model_vgg.requires_grad_(False)

In [None]:
def get_features(x, model, layers):
    """Run ``x`` through ``model`` child by child and collect selected outputs.

    Args:
        x: input tensor fed to the first child of ``model``. It is moved to
            the device of the model's parameters when the two differ, so a
            CPU tensor can be passed to a model that lives on the GPU.
        model: module whose ``children()`` are applied sequentially.
        layers: iterable of integer child indices whose outputs to keep.

    Returns:
        dict mapping each requested index that exists in ``model`` to the
        tensor produced by that child. Indices past the last child are
        ignored.
    """
    wanted = set(layers)

    # Feed the input on the same device as the weights to avoid a
    # device-mismatch error; models without parameters leave ``x`` untouched.
    first_param = next(model.parameters(), None)
    if first_param is not None and x.device != first_param.device:
        x = x.to(first_param.device)

    features = {}
    for index, layer in enumerate(model.children()):
        x = layer(x)
        if index in wanted:
            # NOTE(review): the tensor is stored without cloning, so if the
            # next child modifies its input in place (e.g. an in-place ReLU)
            # the stored feature reflects that modification. Confirm this is
            # intended before relying on pre-activation values.
            features[index] = x
    return features
# Quick check: extract a few feature maps from the content image and print their shapes.
dic = get_features(content_img.unsqueeze(0), model_vgg, [0, 5, 11, 19, 28])
print(type(dic))
for feature_map in dic.values():
    print(feature_map.shape)

In [None]:
def gram_matrix(x):
    """Return the (un-normalized) Gram matrix of a 4-D feature tensor.

    Args:
        x: tensor of size ``(n, c, h, w)``.

    Returns:
        Tensor of size ``(n*c, n*c)`` holding the inner products between the
        flattened ``h*w`` feature maps. For the usual ``n == 1`` case this is
        the ``(c, c)`` Gram matrix.
    """
    n, c, h, w = x.size()
    # Fold the batch axis into the channel axis so batches larger than one no
    # longer raise; reshape (unlike view) also accepts non-contiguous input.
    flat = x.reshape(n * c, h * w)
    return torch.mm(flat, flat.t())
# Display the Gram matrix of the (batched) style image itself as a quick check.
gram_matrix(style_img[None])

In [None]:
def content_loss(pred_features, target_features, layer):
    """Mean-squared error between two feature maps at a single layer.

    The feature maps are compared directly, element by element, so the loss
    is sensitive to *where* activations occur. Comparing Gram matrices
    instead (as this function did before) discards the spatial layout and
    turns it into a single-layer style loss.

    Args:
        pred_features: mapping from layer index to the feature tensor of the
            image being optimized.
        target_features: mapping from layer index to the feature tensor of
            the content image.
        layer: key of the layer to compare in both mappings.

    Returns:
        Scalar tensor with the MSE between the two feature maps.
    """
    return F.mse_loss(pred_features[layer], target_features[layer])
# Scratch call: extract the style image's features at the layers used later on.
a = get_features(
    style_img.unsqueeze(0),
    model_vgg,
    layers=[0, 5, 10, 19, 21, 28],
)


In [None]:
def style_loss(pred_f, target_f, layers_dic, weights=(0.75, 0.5, .25, .25, .25)):
    """Weighted sum of Gram-matrix MSE losses over several layers.

    Args:
        pred_f: mapping from layer index to the 4-D feature tensor of the
            image being optimized.
        target_f: mapping from layer index to the feature tensor of the style
            image.
        layers_dic: iterable of layer keys to include, in order.
        weights: per-layer weights, paired with ``layers_dic`` position by
            position. Defaults to the five weights this notebook has always
            used.

    Returns:
        The accumulated loss: a scalar tensor, or the int ``0`` when no layer
        is paired with a weight.

    Note:
        Layers and weights are paired with ``zip``, so surplus entries on
        either side are silently ignored. With the default five weights only
        the first five entries of ``layers_dic`` contribute.
    """
    loss = 0
    for layer, weight in zip(layers_dic, weights):
        n, c, h, w = pred_f[layer].shape
        pred = gram_matrix(pred_f[layer])
        target = gram_matrix(target_f[layer])
        # Scale by the feature-map size so that large early layers do not
        # dominate the sum.
        loss += weight * (F.mse_loss(pred, target) / (n * c * h * w))
    return loss

In [None]:
# Layer indices whose activations are extracted. Defined here because this cell
# runs before the optimization cell on a fresh kernel; without it the notebook
# fails with a NameError under "Restart & Run All".
layers_ = [0, 5, 10, 19, 21, 28]

content_features = get_features(content_img.unsqueeze(0), model_vgg, layers_)
for key in content_features.keys():
    print(content_features[key].shape)

In [None]:
from torch.optim import Adam

# Layers whose activations are extracted, and optimization settings.
layers_ = [0, 5, 10, 19, 21, 28]
number_of_epoch = 300
log_every = 50  # print the losses every `log_every` iterations

# Put everything on the same device as the model.
content_target = content_img.unsqueeze(0).to(device)
style_target = style_img.unsqueeze(0).to(device)

# The image being optimized starts as a copy of the content image.
input_tensor = content_img.clone().to(device).requires_grad_(True)
optim = Adam([input_tensor], lr=0.01)

# The content and style targets never change, so compute their features once
# instead of on every iteration, and without building an autograd graph.
with torch.no_grad():
    content_features = get_features(content_target, model_vgg, layers_)
    style_features = get_features(style_target, model_vgg, layers_)

for i in range(number_of_epoch):
    optim.zero_grad()
    input_features = get_features(input_tensor.unsqueeze(0), model_vgg, layers_)

    style_s = style_loss(input_features, style_features, layers_)
    content_s = content_loss(input_features, content_features, 28)
    loss = 1e4*style_s + 1e1*content_s
    # A fresh graph is built each iteration, so retaining it is unnecessary
    # and only holds on to memory.
    loss.backward()
    optim.step()

    if i % log_every == 0 or i == number_of_epoch - 1:
        print(f'iter {i}: content={content_s.item():.6f} style={style_s.item():.6f}')

# Total signed change applied to the content image by the optimization.
diff = input_tensor.detach() - content_img.to(input_tensor.device)
print(torch.sum(diff))
# Hand a detached CPU copy to the display helper.
plt.imshow(tensorToPil(input_tensor.detach().cpu()))
