In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image

In [2]:
# ImageNet-pretrained VGG19, keeping only its convolutional `features` stack
# so the layer indices can be inspected before picking style/content layers.
model = (
    models.vgg19(pretrained=True)
    .features
)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:07<00:00, 80.1MB/s]


In [None]:
# Display the layer list of the VGG19 `features` stack loaded above; the
# numeric indices in this repr are the ones VGG.chosen_features refers to.
model

In [4]:
class VGG(nn.Module):
  """Frozen VGG19 feature extractor that returns activations of chosen layers.

  Args:
    chosen_features: indices (ints or strings) into VGG19's `features`
      Sequential whose outputs `forward` returns. The defaults 0, 5, 10, 19, 28
      are the first conv layer of each of VGG19's five conv blocks
      (conv1_1 ... conv5_1 in torchvision's layer ordering).
    backbone: optional `nn.Sequential` feature stack to use instead of the
      ImageNet-pretrained VGG19 `features` (e.g. for quick tests). Note that it
      is sliced and frozen in place (its parameters get requires_grad=False).

  Raises:
    ValueError: if `chosen_features` is empty.
  """

  def __init__(self, chosen_features=('0', '5', '10', '19', '28'), backbone=None):
    super(VGG, self).__init__()

    # Stored as strings so `forward` can compare against str(layer index).
    self.chosen_features = [str(idx) for idx in chosen_features]
    if not self.chosen_features:
      raise ValueError("chosen_features must name at least one layer")
    if backbone is None:
      backbone = models.vgg19(pretrained=True).features
    # Layers after the deepest chosen one are never needed (defaults -> [:29]).
    last = max(int(idx) for idx in self.chosen_features)
    self.model = backbone[:last + 1]
    # The network is a fixed feature extractor: only the input image is being
    # optimised. Without freezing, every backward() also computes gradients for
    # all conv weights, and they accumulate because the optimizer never zeroes them.
    self.model.requires_grad_(False)

  def forward(self, x):
    """Run x through the feature stack and collect the chosen activations.

    Returns:
      A list of tensors, the outputs of the chosen layers in increasing layer order.
    """
    # NOTE(review): torchvision's VGG19 uses ReLU(inplace=True), so a conv output
    # appended here is overwritten in place by the following ReLU (effectively
    # storing the post-ReLU activation, except for the last layer) - confirm intended.
    features = []

    for layer_num, layer in enumerate(self.model):
      x = layer(x)

      if str(layer_num) in self.chosen_features:
        features.append(x)
    return features

In [6]:
def load_image(image_name):
  """Load an image file as a batched tensor on `device`.

  The file is converted to RGB, passed through the module-level `loader`
  transform, and given a leading batch dimension of size 1.

  Args:
    image_name: path to the image file.

  Returns:
    A (1, 3, H, W) float tensor on `device`; H and W are set by `loader`.
  """
  with Image.open(image_name) as img:  # closes the underlying file handle
    # VGG's first conv layer expects 3 channels; without convert("RGB") a
    # grayscale, RGBA or palette image would yield a 1- or 4-channel tensor.
    tensor = loader(img.convert("RGB")).unsqueeze(0)  # add batch dimension
  return tensor.to(device)

# Prefer the GPU when one is present. is_available must be *called*: the bare
# function object is always truthy, so the uncalled form picks "cuda" even on
# CPU-only machines and every later .to(device) fails.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_size = 512  # content and style images are both resized to this square size

# Resize to a fixed square so content, style and generated features have
# matching shapes layer-by-layer, then convert to a float tensor in [0, 1].
# NOTE(review): no ImageNet mean/std normalisation is applied even though the
# pretrained VGG19 weights expect it - confirm this is intentional.
loader = transforms.Compose([transforms.Resize((image_size, image_size)),
                             transforms.ToTensor()])

In [9]:
# Both inputs go through the same loader, so they share size and device.
Content_img, Style_img = (load_image(path) for path in ("deer_image.jpg", "art.jpg"))

# Optimisation starts from a copy of the content image; this leaf tensor is
# what the optimizer updates.
Generated = Content_img.clone().requires_grad_(True)

In [10]:
# Feature extractor moved to the selected device and switched to inference mode.
model = VGG()
model = model.to(device)
model = model.eval()



In [11]:
# Optimisation settings: number of Adam iterations and their step size.
total_steps, learning_rate = 6000, 0.001
# Loss weights: total_loss = alpha * content_loss + beta * style_loss.
alpha, beta = 1, 0.01
# Only the pixels of the generated image are handed to the optimizer.
optimizer = optim.Adam(params=[Generated], lr=learning_rate)

In [12]:
def gram_matrix(feature):
  """Return the (C, C) Gram matrix of a (1, C, H, W) feature map.

  Entry (i, j) is the dot product of channels i and j over all spatial
  positions. The `view` assumes batch size 1, which is what load_image produces.
  """
  _, channel, height, width = feature.shape
  flat = feature.view(channel, height*width)
  return flat.mm(flat.t())  # (C, HW) @ (HW, C)

# The content and style images never change, so their VGG features (and the
# style Gram matrices) are computed once, without building an autograd graph.
# Doing this inside the loop re-ran VGG on both fixed images every iteration.
with torch.no_grad():
  Content_img_features = model(Content_img)
  Style_img_features = model(Style_img)
  style_grams = [gram_matrix(feature) for feature in Style_img_features]

for step in range(total_steps):
  Generated_features = model(Generated)

  style_loss = content_loss = 0

  for gen_feature, cont_feature, A in zip(
      Generated_features, Content_img_features, style_grams
  ):
    # Content loss: mean squared difference of the raw activations.
    content_loss += torch.mean((gen_feature - cont_feature)**2)

    # Style loss: mean squared difference of the Gram matrices.
    G = gram_matrix(gen_feature)
    style_loss += torch.mean((G - A)**2)

  total_loss = alpha*content_loss + beta*style_loss # Loss = content loss + style loss
  optimizer.zero_grad()
  total_loss.backward()
  optimizer.step()
  # Also report/save on the final step; otherwise the last updates after the
  # most recent multiple of 400 never reach the output file.
  if step % 400 == 0 or step == total_steps - 1:
    print(total_loss)
    save_image(Generated, "Generated image.jpg")

tensor(1852350., device='cuda:0', grad_fn=<AddBackward0>)
tensor(71109.2266, device='cuda:0', grad_fn=<AddBackward0>)
tensor(13625.8105, device='cuda:0', grad_fn=<AddBackward0>)
tensor(4842.1021, device='cuda:0', grad_fn=<AddBackward0>)
tensor(3264.7683, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2592.8311, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2159.7310, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1852.1746, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1620.0619, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1433.1207, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1273.4374, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1132.0062, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1005.2955, device='cuda:0', grad_fn=<AddBackward0>)
tensor(890.1053, device='cuda:0', grad_fn=<AddBackward0>)
tensor(786.9357, device='cuda:0', grad_fn=<AddBackward0>)
