In [1]:
# We use the pretrained VGG19 network as a fixed feature extractor for neural style transfer.
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image

In [2]:
# Load the pretrained VGG19 purely to inspect its layer stack (printed below);
# only the `.features` submodule is kept under the name `model`.
vgg19_full = models.vgg19(pretrained=True)
model = vgg19_full.features
del vgg19_full

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:07<00:00, 73.8MB/s]


In [3]:
# Show the layer stack; the numeric indices are the ones `VGG.chosen_features` refers to.
print(repr(model))

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo

In [4]:
class VGG(nn.Module):
  """Frozen, truncated VGG19 that returns the activations of selected layers.

  Args:
    chosen_features: layer indices (strings, matching the numeric child names
      of the VGG19 `features` Sequential) whose outputs are collected.
      Defaults to ['0', '5', '10', '19', '28'].

  Raises:
    ValueError: if `chosen_features` is empty.
  """

  def __init__(self, chosen_features=None):
    super(VGG, self).__init__()
    if chosen_features is None:
      chosen_features = ['0', '5', '10', '19', '28']
    self.chosen_features = [str(index) for index in chosen_features]
    if not self.chosen_features:
      raise ValueError("chosen_features must name at least one layer")
    # Keep layers only up to the deepest chosen one (default 28 -> [:29]).
    last_layer = max(int(index) for index in self.chosen_features)
    self.model = models.vgg19(pretrained=True).features[:last_layer + 1]
    # Only the input image is optimised. Freezing the weights stops backward()
    # from computing and accumulating gradients for them on every step.
    for param in self.model.parameters():
      param.requires_grad_(False)

  def forward(self, x):
    """Run `x` through the truncated VGG19 and return the chosen activations.

    Returns a list with one tensor per chosen layer index, in layer order.

    NOTE(review): in the printed architecture the layer after indices 0, 5
    and 10 (and presumably 19) is `ReLU(inplace=True)`. The next layer
    therefore overwrites the stored tensor with its rectified values. Index
    28 is the last layer kept, so that feature stays pre-activation. Confirm
    this is intended.
    """
    features = []
    for layer_num, layer in enumerate(self.model):
      x = layer(x)
      if str(layer_num) in self.chosen_features:
        features.append(x)

    return features

In [5]:
def load_image(image_name):
  """Load an image file as a batched (1, 3, H, W) tensor on `device`.

  The image is converted to RGB first, so grayscale, palette or RGBA files
  still yield the 3 input channels VGG19's first convolution expects. The
  file handle is closed once the pixels have been read.

  Relies on the module-level `loader` transform and `device`.
  """
  with Image.open(image_name) as img:
    tensor = loader(img.convert("RGB")).unsqueeze(0)  # add batch dimension
  return tensor.to(device)

# `is_available` must be *called*. The bare function object is always truthy
# and would select "cuda" even on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_size = 356  # side length (pixels) that `loader` resizes every image to

# Resize every image to the same square so content, style and generated
# images produce feature maps of matching shape, then convert to a tensor.
loader = transforms.Compose([
    transforms.Resize(size=(image_size, image_size)),
    transforms.ToTensor(),
])

In [14]:
# Content image: its features anchor the generated image's content.
original_image = load_image(image_name='annahathaway.png')

In [15]:
# Style image: its Gram matrices define the target style.
style_image = load_image(image_name='style.jpg')

In [16]:
# Feature extractor in eval mode on the chosen device.
model = VGG()
model = model.to(device)
model.eval()
# Start the optimisation from a copy of the content image; this copy is the
# only tensor whose values get updated.
generated = original_image.clone()
generated.requires_grad_(True)



In [17]:
# Optimisation settings: number of Adam steps, its learning rate, and the
# weights of the content (alpha) and style (beta) terms in the total loss.
total_steps, learning_rate = 6_000, 1e-3
alpha, beta = 1, 1e-2


In [18]:
# Only the generated image is handed to the optimizer; network weights are not.
optimizer = optim.Adam(params=[generated], lr=learning_rate)

In [22]:
def _gram_matrix(feature):
  """Return the (channels x channels) Gram matrix of a single-image feature map.

  Assumes `feature` has shape (1, C, H, W). The `view` folds the batch
  dimension away, so a batch size other than 1 raises a RuntimeError.
  """
  _, channel, height, width = feature.shape
  flat = feature.view(channel, height * width)
  return flat.mm(flat.t())


# The content and style targets never change during optimisation. Compute
# them once, without building an autograd graph, instead of on every step.
with torch.no_grad():
  original_img_features = model(original_image)
  style_grams = [_gram_matrix(feature) for feature in model(style_image)]

for step in range(total_steps):
  generated_features = model(generated)
  style_loss = original_loss = 0

  for gen_feature, orig_feature, style_gram in zip(
      generated_features, original_img_features, style_grams
  ):
    # Content term: mean squared error between feature maps.
    original_loss += torch.mean((gen_feature - orig_feature) ** 2)
    # Style term: mean squared error between (unnormalised) Gram matrices;
    # its overall scale is weighted by `beta` below.
    G = _gram_matrix(gen_feature)
    style_loss += torch.mean((G - style_gram) ** 2)

  total_loss = alpha * original_loss + beta * style_loss
  optimizer.zero_grad()
  total_loss.backward()
  optimizer.step()

  # Periodically report progress and checkpoint the current image.
  if step % 200 == 0:
    print(total_loss)
    save_image(generated, 'generated.png')

tensor(1579650.8750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(84283., device='cuda:0', grad_fn=<AddBackward0>)
tensor(39234.9766, device='cuda:0', grad_fn=<AddBackward0>)
tensor(28047.4609, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22963.9258, device='cuda:0', grad_fn=<AddBackward0>)
tensor(19655.9492, device='cuda:0', grad_fn=<AddBackward0>)
tensor(17174.5898, device='cuda:0', grad_fn=<AddBackward0>)
tensor(15166.9668, device='cuda:0', grad_fn=<AddBackward0>)
tensor(13460.9521, device='cuda:0', grad_fn=<AddBackward0>)
tensor(11982.3027, device='cuda:0', grad_fn=<AddBackward0>)
tensor(10696.8359, device='cuda:0', grad_fn=<AddBackward0>)
tensor(9583.5459, device='cuda:0', grad_fn=<AddBackward0>)
tensor(8621.6748, device='cuda:0', grad_fn=<AddBackward0>)
tensor(7799.9907, device='cuda:0', grad_fn=<AddBackward0>)
tensor(7095.5288, device='cuda:0', grad_fn=<AddBackward0>)
tensor(6490.7441, device='cuda:0', grad_fn=<AddBackward0>)
tensor(5964.9897, device='cuda:0', grad_fn=<Add