In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.utils import save_image
from PIL import Image

In [2]:
model = models.vgg19(pretrained=True).features
model



Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo

In [3]:
class VGG(nn.Module):
    def __init__(self):
        super(VGG, self).__init__()
        self.chosen_features = [0, 5, 10, 19, 28]
        self.model = models.vgg19(pretrained=True).features[:29]

    def forward(self, x):
        features = []
        for layer_num, layer in enumerate(self.model):
            x = layer(x)
            if layer_num in self.chosen_features:
                features.append(x)

        return features

In [10]:
def load_image(image_name):
    image = Image.open(image_name)
    image = loader(image).unsqueeze(0)
    return image.to(device)

In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
imsize = 356

In [12]:
loader = transforms.Compose(
    [
        transforms.Resize((imsize, imsize)),
        transforms.ToTensor(),
    ]
)

In [15]:
original_img = load_image("MIT_dome.jpg")
style_img = load_image("Starry_Night.jpg")
generated = original_img.clone().requires_grad_(True)
model = VGG().to(device).eval()

In [16]:
total_steps = 6000
learning_rate = 0.001
alpha = 1
beta = 0.01
optimizer = optim.Adam([generated], lr=learning_rate)

In [17]:
for step in range(total_steps):
    generated_features = model(generated)
    original_img_features = model(original_img)
    style_features = model(style_img)

    style_loss = original_loss = 0

    for gen_feature, orig_feature, style_feature in zip(
        generated_features, original_img_features, style_features
    ):

        batch_size, channel, height, width = gen_feature.shape
        original_loss += torch.mean((gen_feature - orig_feature) ** 2)

        G = gen_feature.view(channel, height * width).mm(
            gen_feature.view(channel, height * width).t()
        )

        A = style_feature.view(channel, height * width).mm(
            style_feature.view(channel, height * width).t()
        )
        style_loss += torch.mean((G - A) ** 2)

    total_loss = alpha * original_loss + beta * style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if step % 200 == 0:
        print(total_loss)
        save_image(generated, "generated.png")

tensor(653701.2500, device='cuda:0', grad_fn=<AddBackward0>)
tensor(16626.0195, device='cuda:0', grad_fn=<AddBackward0>)
tensor(8407.8271, device='cuda:0', grad_fn=<AddBackward0>)
tensor(5603.6367, device='cuda:0', grad_fn=<AddBackward0>)
tensor(4188.3604, device='cuda:0', grad_fn=<AddBackward0>)
tensor(3355.8149, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2816.4109, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2438.6121, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2158.1675, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1941.0804, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1767.9083, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1626.4784, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1507.6978, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1404.3403, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1313.9401, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1234.0369, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1162.1794, device='cuda:0', grad_fn=<AddBackwa