In [2]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
from torchvision.utils import save_image

import torchvision.transforms as transforms

In [3]:
# Prefer the GPU when CUDA is usable; otherwise run everything on the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"
device

'cuda'

In [4]:
# Pretrained VGG-19: keep only its `.features` sub-module and place it on
# `device` so its layer indices can be inspected below.
model = (
  models.vgg19(pretrained=True)
  .features
  .to(device)
)



In [5]:
# Show the layer listing; each layer is printed with its numeric index,
# which is what the feature-selection list in the VGG wrapper refers to.
print(model)

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo

In [6]:
# Layers '0', '5', '10', '19' and '28' are the layers of interest; layers after 28 are not needed.

In [7]:
class VGG(nn.Module):
  """Truncated, frozen VGG-19 that returns activations of selected layers.

  Only the first 29 layers of the pretrained ``features`` stack are kept
  (indices 0..28), because no layer after index 28 is tapped.

  Note: the layer listing of the pretrained model shows ``ReLU(inplace=True)``
  directly after the tapped conv layers. Because ``forward`` stores the conv
  output tensor itself (not a copy), the following in-place ReLU overwrites
  it, so a stored activation that is followed by a ReLU inside the truncated
  stack is effectively the post-ReLU activation. The tensor from index 28 is
  the last layer kept and is therefore returned unmodified.
  """

  def __init__(self):
    super().__init__()
    # Indices (as strings) of the layers whose outputs are collected.
    self.choosen_features=['0','5','10','19','28']
    self.model=models.vgg19(pretrained=True).features[:29]
    # The network is only used as a fixed feature extractor. Freezing the
    # weights stops autograd from computing and accumulating gradients for
    # every VGG parameter on each backward pass; gradients with respect to
    # the *input* image still flow as before.
    for param in self.model.parameters():
      param.requires_grad_(False)

  def forward(self,x):
    """Run ``x`` through the truncated network.

    Args:
      x: Input tensor accepted by the first layer of ``self.model``.

    Returns:
      list of tensors, one per index listed in ``self.choosen_features``,
      in the order the layers are executed.
    """
    features=[]
    for layer_num,layer in enumerate(self.model):
      x=layer(x)
      if str(layer_num) in self.choosen_features:
        features.append(x)
    return features

In [8]:
# Build the feature extractor, move it to the selected device and switch it
# to evaluation mode. This rebinds `model`, replacing the raw VGG-19 feature
# stack that was created earlier only for inspection.
model = VGG()
model = model.to(device)
model = model.eval()

In [9]:
# Side length, in pixels, that every input image is resized to.
IMAGE_SIZE = 512

# Preprocessing pipeline: force a square IMAGE_SIZE x IMAGE_SIZE image, then
# convert the PIL image to a tensor.
_resize = transforms.Resize((IMAGE_SIZE, IMAGE_SIZE))
_to_tensor = transforms.ToTensor()
loader = transforms.Compose([_resize, _to_tensor])

In [10]:
def load_image(image_path):
  """Load an image file as a batched RGB tensor on ``device``.

  Args:
    image_path: Path (or file object) accepted by ``PIL.Image.open``.

  Returns:
    The tensor produced by ``loader`` with a leading batch dimension of
    size 1 added, moved to ``device``.
  """
  # Open the file in a context manager so the underlying file handle is
  # closed as soon as the pixels are decoded; ``convert`` returns a separate
  # in-memory image that stays valid after the file is closed.
  with Image.open(image_path) as source:
    image=source.convert('RGB')
  image=loader(image).unsqueeze(0)
  return image.to(device)

In [12]:
# Content image: its structure is what the generated image should keep.
original_image = load_image("content.png")
# Style image: its feature correlations are what the generated image should match.
style_image = load_image("style1.jpg")

In [13]:
# Alternative initialisation (random noise instead of the content image):
# generated_image=torch.randn(original_image.shape,device=device,requires_grad=True)

# Start from a copy of the content image (gives better results) and mark it
# as the tensor to optimise.
generated_image = original_image.clone()
generated_image.requires_grad = True

In [14]:
# Number of optimisation steps and Adam step size.
total_steps = 2000
learning_rate = 0.003

# Loss weights: `alpha` scales the content (original) loss, `beta` scales the
# style loss. Previously tried: alpha=1, beta=90.
alpha = 1
beta = 0.001

# The only thing being optimised is the generated image itself.
optimizer = optim.Adam(
  [generated_image],
  lr=learning_rate,
)

In [None]:
# The content and style images never change during optimisation, so their
# activations (and the style Gram matrices) are computed once up front,
# without recording an autograd graph, instead of on every step.
with torch.no_grad():
  original_features=model(original_image)
  style_grams=[]
  for style_feature in model(style_image):
    _,channel,height,width=style_feature.shape
    # Flattening to (channel, height*width) only works for a batch of one.
    style_flat=style_feature.view(channel,height*width)
    style_grams.append(style_flat.mm(style_flat.t()))

for step in range(total_steps):
  # Only the generated image needs a fresh forward pass each step.
  generated_features=model(generated_image)
  style_loss = 0
  original_loss = 0

  for gen_feature,orig_feature,A in zip(generated_features,original_features,style_grams):
    _,channel,height,width=gen_feature.shape

    # Content loss: mean squared difference between activations.
    original_loss+=torch.mean((gen_feature-orig_feature)**2)

    # Gram matrix of the generated activations (channel x channel).
    gen_flat=gen_feature.view(channel,height*width)
    G=gen_flat.mm(gen_flat.t())

    # Style loss: mean squared difference between Gram matrices.
    style_loss+=torch.mean((G-A)**2)

  total_loss=alpha*original_loss+beta*style_loss
  optimizer.zero_grad()
  total_loss.backward()
  optimizer.step()

  # Report progress every 200 steps.
  if step%200==0:
    print(total_loss)
  # Write the result to disk once, on the final step.
  if step==(total_steps-1):
    save_image(generated_image,"generated"+str(step)+".png")

tensor(316617.2500, device='cuda:0', grad_fn=<AddBackward0>)
tensor(6467.0444, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2372.6353, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1357.5688, device='cuda:0', grad_fn=<AddBackward0>)
tensor(954.9714, device='cuda:0', grad_fn=<AddBackward0>)
tensor(736.1094, device='cuda:0', grad_fn=<AddBackward0>)
tensor(595.2604, device='cuda:0', grad_fn=<AddBackward0>)
tensor(496.4590, device='cuda:0', grad_fn=<AddBackward0>)
tensor(429.7458, device='cuda:0', grad_fn=<AddBackward0>)
tensor(371.3081, device='cuda:0', grad_fn=<AddBackward0>)
