In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image

In [11]:
# Inspection only: print the layer layout of VGG-19's convolutional stack so
# the layer indices to tap for features can be read off the output.
# Bound to its own name (not `model`) so that re-running this cell can never
# replace the `model` feature extractor that the optimisation loop calls.
vgg19_features = models.vgg19(pretrained=True).features
print(vgg19_features)

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo

In [12]:
class VGG(nn.Module):
    """Truncated pretrained VGG-19 that returns activations from chosen layers.

    Only modules 0-28 of the torchvision ``features`` stack are kept. During
    the forward pass the output of every module whose index appears in
    ``chosen_features`` is collected.
    """

    def __init__(self):
        super().__init__()

        # Indices (stored as strings) of the modules whose outputs are tapped.
        self.chosen_features = ['0', '5', '10', '19', '28']
        # Index 28 is the last tapped module, so the stack is cut right after it.
        self.model = models.vgg19(pretrained=True).features[:29]

    def forward(self, x):
        """Feed ``x`` through the stack and return the tapped outputs.

        Returns:
            A list with one tensor per tapped module, in layer order.
        """
        # NOTE(review): the printed architecture shows ReLU(inplace=True)
        # directly after convs 0, 5 and 10, so the tensors stored for those
        # taps are presumably overwritten with their post-ReLU values on the
        # next iteration -- confirm this is intended.
        tapped = []

        for index, module in enumerate(self.model):
            x = module(x)

            if str(index) not in self.chosen_features:
                continue
            tapped.append(x)

        return tapped

In [13]:
def load_image(image_name):
    """Load an image file as a batched 3-channel tensor on ``device``.

    Args:
        image_name: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The output of the module-level ``loader`` transform with a leading
        batch dimension of size 1 added, moved to the module-level ``device``.
    """
    # The context manager closes the underlying file handle; convert() loads
    # the pixel data and returns an independent in-memory image first.
    with Image.open(image_name) as source:
        # Force RGB so grayscale / RGBA / CMYK files still provide the three
        # input channels required by VGG's first layer, Conv2d(3, 64, ...).
        rgb_image = source.convert("RGB")

    return loader(rgb_image).unsqueeze(0).to(device)

In [14]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Side length in pixels; the loader resizes every image to
# (image_size, image_size).
image_size = 356

In [15]:
# Preprocessing applied by load_image: resize to a fixed square of
# image_size x image_size, then convert the PIL image to a tensor.
loader = transforms.Compose(
    [
        transforms.Resize(size=(image_size, image_size)),
        transforms.ToTensor(),
    ]
)

In [16]:
# Content and style inputs, both preprocessed by load_image.
original_img = load_image("original.jpg")
style_img = load_image("style.jpg")

# The image being optimised starts as a copy of the content image.
# requires_grad_() (trailing underscore) is the in-place method that enables
# gradient tracking; `requires_grad` without the underscore is a bool
# attribute, so calling it raises "TypeError: 'bool' object is not callable".
# Alternative initialisation from random noise:
# generated = torch.randn(original_img.shape, device=device, requires_grad=True)
generated = original_img.clone().requires_grad_(True)

In [17]:
# Optimisation schedule: number of update steps and Adam step size.
total_steps, learning_rate = 6000, 0.001

# Loss weights: alpha scales the content loss, beta scales the style loss.
# These are not the values used in the original paper.
alpha, beta = 1, 0.01

# The optimizer updates the pixels of `generated` directly; it is the only
# tensor handed to Adam.
optimizer = optim.Adam(params=[generated], lr=learning_rate)

In [18]:
# eval() only switches the layers to inference behaviour. requires_grad_(False)
# is what actually freezes the weights, so backward() no longer computes and
# accumulates gradients for the VGG parameters (only for the input image).
model = VGG().to(device).eval().requires_grad_(False)

In [19]:
def _gram_matrix(feature):
    """Return the (channel x channel) Gram matrix of one feature map.

    ``feature`` is unpacked as (batch, channel, height, width). The view to
    (channel, height * width) only succeeds when the batch size is 1.
    """
    _, channel, height, width = feature.shape
    flat = feature.view(channel, height * width)
    return flat.mm(flat.t())


# The content and style images never change, so their features (and the style
# Gram matrices) are computed once up front. no_grad() skips building autograd
# graphs for them, since no gradient ever flows back to these targets.
with torch.no_grad():
    original_img_features = model(original_img)
    style_grams = [_gram_matrix(feature) for feature in model(style_img)]

for step in range(total_steps):
    # Only the generated image is re-evaluated each step; it is the tensor
    # being optimised.
    generated_features = model(generated)

    style_loss = content_loss = 0

    for gen_feature, orig_feature, style_gram in zip(
        generated_features, original_img_features, style_grams
    ):
        # Content term: mean squared difference between feature maps.
        content_loss += torch.mean((gen_feature - orig_feature) ** 2)

        # Style term: mean squared difference between Gram matrices.
        style_loss += torch.mean((_gram_matrix(gen_feature) - style_gram) ** 2)

    total_loss = alpha * content_loss + beta * style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # Every 200 steps: report the loss and overwrite the saved snapshot.
    if step % 200 == 0:
        print(total_loss)
        save_image(generated, "generated.png")

tensor(5.6667e+08, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.9346e+08, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.0101e+08, device='cuda:0', grad_fn=<AddBackward0>)
tensor(62773408., device='cuda:0', grad_fn=<AddBackward0>)
tensor(43015628., device='cuda:0', grad_fn=<AddBackward0>)
tensor(31378500., device='cuda:0', grad_fn=<AddBackward0>)
tensor(23910148., device='cuda:0', grad_fn=<AddBackward0>)
tensor(18816808., device='cuda:0', grad_fn=<AddBackward0>)
tensor(15179368., device='cuda:0', grad_fn=<AddBackward0>)
tensor(12464425., device='cuda:0', grad_fn=<AddBackward0>)
tensor(10346301., device='cuda:0', grad_fn=<AddBackward0>)
tensor(8649906., device='cuda:0', grad_fn=<AddBackward0>)
tensor(7315893.5000, device='cuda:0', grad_fn=<AddBackward0>)
tensor(6258705.5000, device='cuda:0', grad_fn=<AddBackward0>)
tensor(5403838., device='cuda:0', grad_fn=<AddBackward0>)
tensor(4696684., device='cuda:0', grad_fn=<AddBackward0>)
tensor(4104842.5000, device='cuda:0', grad_fn=<Add