In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image

In [2]:
model = models.vgg19(pretrained=True).features


Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


  0%|          | 0.00/548M [00:00<?, ?B/s]

In [3]:
print(model)

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo

In [5]:
from google.colab import files

uploaded = files.upload()

for fn in uploaded.keys():
  print('User uploaded file "{name}" with length {length} bytes'.format(
      name=fn, length=len(uploaded[fn])))

Saving lake.jpg to lake.jpg
Saving mount.jpg to mount.jpg
User uploaded file "lake.jpg" with length 13269 bytes
User uploaded file "mount.jpg" with length 11888 bytes


In [8]:
class VGG(nn.Module):
  def __init__(self):
    super(VGG, self).__init__()
    self.chosen_features = ['0','5','10','19','28']
    self.model = models.vgg19(pretrained=False).features[:29]

  def forward(self,x):
    features=[]

    for layer_num , layer in enumerate(self.model):
      x = layer(x)

      if str(layer_num) in self.chosen_features:
        features.append(x)
    return features

def load_image(image_name):
  image = Image.open(image_name)
  image = loader(image).unsqueeze(0)
  return image.to(device)

device = torch.device("cuda" if torch.cuda.is_available else "cpu")
image_size = 356

loader = transforms.Compose(
    [
     transforms.Resize((image_size, image_size)),
     transforms.ToTensor(),
    ]
)

original_img = load_image('mount.jpg')
style_img = load_image('lake.jpg')

model = VGG().to(device).eval()

generated = original_img.clone().requires_grad_(True)

# hyperparameters
total_steps = 6000
learning_rate = 0.001
alpha = 1
beta = 0.01
optimizer = optim.Adam([generated], lr=learning_rate)


for step in range(total_steps):
  generated_features = model(generated)
  original_img_features = model(original_img)
  style_features = model(style_img)

  style_loss = original_loss = 0

  for gen_feature, orig_feature, style_feature in zip(generated_features,original_img_features,style_features):
    batch_size , channel, height, width = gen_feature.shape
    original_loss += torch.mean((gen_feature - orig_feature)**2)

    # Compute Gram matrix
    G = gen_feature.view(channel, height*width).mm(
        gen_feature.view(channel, height*width).t()
    )

    A = style_feature.view(channel, height*width).mm(
        style_feature.view(channel, height*width).t()
    )

    style_loss += torch.mean((G-A)**2)

  total_loss = alpha*original_loss + beta * style_loss
  optimizer.zero_grad()
  total_loss.backward()
  optimizer.step()

  if step % 200 == 0:
    print(total_loss)
    save_image(generated, 'generated.png')

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


tensor(6339.1768, device='cuda:0', grad_fn=<AddBackward0>)
tensor(847.7491, device='cuda:0', grad_fn=<AddBackward0>)
tensor(228.7175, device='cuda:0', grad_fn=<AddBackward0>)
tensor(75.6339, device='cuda:0', grad_fn=<AddBackward0>)
tensor(23.2090, device='cuda:0', grad_fn=<AddBackward0>)
tensor(7.8980, device='cuda:0', grad_fn=<AddBackward0>)
tensor(3.7221, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.3968, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.8167, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.4666, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.2154, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.0216, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.8668, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.7410, device='cuda:0', grad_fn=<AddBackward0>)


KeyboardInterrupt: ignored