In [13]:
import torch 
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms #convert image to tensor
import torchvision.models as models  #to use a pre-trained model
from torchvision.utils import save_image #storing the generated image

from PIL import Image #load the image

In [14]:
# Inspect the convolutional "features" stack of the pretrained VGG19; the
# displayed module indices are the ones VGG.chosen_features refers to below.
# NOTE(review): this cell loads VGG19 separately from the VGG class, and `model`
# is reassigned later in the notebook, so this cell is for inspection only.
model = models.vgg19(pretrained=True).features
model

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo

In [25]:
class VGG(nn.Module):
  """Pretrained VGG19 convolutional stack that returns selected activations.

  Only modules 0..28 of ``vgg19(pretrained=True).features`` are kept; the
  forward pass returns the outputs of the modules whose index (as a string)
  appears in ``chosen_features``, shallowest first.
  """

  def __init__(self):
    super().__init__()
    # Module indices (as strings) whose outputs are collected in forward().
    self.chosen_features = ['0', '5', '10', '19', '28']
    # Truncate right after index 28: nothing past the deepest tap is needed.
    self.model = models.vgg19(pretrained=True).features[:29]

  def forward(self, x):
    """Run ``x`` through the truncated stack and return the tapped outputs.

    NOTE(review): in the printed model, the module after the convs at indices
    0, 5 and 10 is ReLU(inplace=True), which overwrites the appended tensor in
    place, so those taps end up holding post-ReLU values. Index 28 is the last
    module kept, so nothing runs after it.
    """
    wanted = set(self.chosen_features)
    taps = []
    for position, layer in enumerate(self.model):
      x = layer(x)
      if str(position) not in wanted:
        continue
      taps.append(x)
    return taps

In [26]:
def load_img(img_name):
  """Load an image file as a batched tensor on the module-level ``device``.

  The image is converted to 3-channel RGB first: the network's first layer is
  Conv2d(3, ...), and grayscale, palette or RGBA files would otherwise yield
  1 or 4 channels. The result is passed through the module-level ``loader``
  transform, given a leading batch dimension, and moved to ``device``.

  Args:
    img_name: path of the image file to open.

  Returns:
    ``loader``'s output with a batch dimension of size 1 prepended, on ``device``.
  """
  # Context manager so the underlying file handle is released; convert() reads
  # the pixel data into a new, independent image before the file is closed.
  with Image.open(img_name) as img:
    rgb = img.convert("RGB")
  img_tensor = loader(rgb).unsqueeze(0)  # add the batch dimension
  return img_tensor.to(device)

In [27]:
# Run on the GPU when one is usable, otherwise fall back to the CPU.
# is_available must be *called*: the bare function object is always truthy and
# would select "cuda" unconditionally, even on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [28]:
img_size = 200  # side length (pixels) of the square every input image is resized to

In [29]:
# Preprocessing used by load_img: resize every image to an img_size x img_size
# square (aspect ratio is not preserved), then convert it to a tensor.
# NOTE(review): no mean/std normalization is applied even though the VGG19
# weights are pretrained — confirm this is intentional.
loader = transforms.Compose([
    transforms.Resize((img_size,) * 2),
    transforms.ToTensor(),
])

In [30]:
# Content image (its activations anchor the content loss) and style image (its
# Gram matrices anchor the style loss), both loaded via load_img.
original_img, style_img = (load_img(path) for path in ('naruto.jpg', 'style.jpg'))

In [31]:
# Pretrained VGG used purely as a fixed feature extractor.
# .eval() only switches modules to inference mode; it does not stop autograd
# from tracking the weights, so gradients for the parameters are disabled
# explicitly. Gradients still flow through the network to the optimized image.
model = VGG().to(device).eval()
model.requires_grad_(False)

In [32]:
# The image being optimized: start from a copy of the content image and let
# autograd track it, since its pixels are what the optimizer updates.
generated = original_img.clone()
generated.requires_grad_(True)

In [33]:
# Hyperparameters
total_steps, learning_rate = 6000, 1e-3  # number of Adam updates and step size
alpha, beta = 1, 0.01                    # weights of the content and style losses
# Only the generated image's pixels are handed to Adam.
optimizer = optim.Adam(params=[generated], lr=learning_rate)

In [36]:
# Optimize the pixels of `generated` so that its VGG activations stay close to
# those of the content image (original_img) while its Gram matrices approach
# those of the style image (style_img).
# The content and style images never change, so their activations (and the
# style Gram matrices) are computed once before the loop with autograd off,
# instead of being recomputed and tracked on every step.


def _gram(feature):
  """Return the (C, C) Gram matrix of a single-image activation map.

  The ``view(channel, height * width)`` only succeeds when the batch
  dimension is 1.
  """
  _, channel, height, width = feature.shape
  flat = feature.view(channel, height * width)
  return flat.mm(flat.t())


with torch.no_grad():
  original_features = model(original_img)
  style_grams = [_gram(feature) for feature in model(style_img)]

for step in range(total_steps):
  generated_features = model(generated)

  style_loss = original_loss = 0

  for gen_feature, original_feature, S in zip(generated_features, original_features, style_grams):
    # Content term: mean squared difference of the activations.
    original_loss += torch.mean((gen_feature - original_feature) ** 2)

    # Style term: mean squared difference of the Gram matrices.
    G = _gram(gen_feature)
    style_loss += torch.mean((G - S) ** 2)

  total_loss = alpha * original_loss + beta * style_loss

  optimizer.zero_grad()
  total_loss.backward()
  optimizer.step()
  # Periodically report the loss and write the current image to disk.
  if step % 200 == 0:
    print(total_loss)
    save_image(generated, "gen.jpg")

tensor(119825.3359, device='cuda:0', grad_fn=<AddBackward0>)
tensor(28667.5918, device='cuda:0', grad_fn=<AddBackward0>)
tensor(8485.2061, device='cuda:0', grad_fn=<AddBackward0>)
tensor(4171.2158, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2627.1440, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2100.1914, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1804.5573, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1569.4468, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1380.1438, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1230.0936, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1107.3485, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1010.5127, device='cuda:0', grad_fn=<AddBackward0>)
tensor(933.8871, device='cuda:0', grad_fn=<AddBackward0>)
tensor(872.4640, device='cuda:0', grad_fn=<AddBackward0>)
tensor(821.5547, device='cuda:0', grad_fn=<AddBackward0>)
tensor(778.1966, device='cuda:0', grad_fn=<AddBackward0>)
tensor(739.0781, device='cuda:0', grad_fn=<AddBackward0>)