In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image
from torchvision.models import VGG19_Weights

In [2]:
# Pretrained VGG19 `.features` sub-module, bound here so its layer layout can be
# inspected with the commented-out print below.
# NOTE(review): `model` is rebound to a `VGG()` instance in a later cell, so this
# binding is not used by the optimisation itself.
model = models.vgg19(weights=VGG19_Weights.DEFAULT).features
# print(model)

In [3]:
class VGG(nn.Module):
    """Truncated pretrained VGG19 that exposes intermediate activations.

    Only the first 29 modules of the VGG19 `features` stack are kept. A forward
    pass returns the outputs of the layers whose index is listed in
    `chosen_features`, in network order.
    """

    def __init__(self):
        super().__init__()

        # Layer indices (stored as strings) whose outputs are collected.
        # Presumably the first convolution of each VGG block -- verify against
        # the printed layer list.
        self.chosen_features = ['0', '5', '10', '19', '28']

        # Nothing after index 28 is tapped, so the rest of the stack is dropped.
        vgg19_layers = models.vgg19(weights=VGG19_Weights.DEFAULT).features
        self.model = vgg19_layers[:29]

    def forward(self, x):
        """Run `x` through the truncated network and collect tapped outputs.

        Args:
            x: Input passed to the first layer of `self.model`.

        Returns:
            A list with one activation per index in `chosen_features`,
            ordered as the layers are visited.
        """
        tapped = []

        for index, layer in enumerate(self.model):
            x = layer(x)

            if str(index) not in self.chosen_features:
                continue
            tapped.append(x)

        return tapped

In [4]:
# Use the GPU only when CUDA is actually usable. `is_available` must be *called*:
# the bare function object is always truthy, which would select "cuda" even on
# CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Side length in pixels that `loader` resizes every image to (square output).
image_size = 356

In [5]:
# Preprocessing applied to every input image: force a square
# image_size x image_size resolution, then convert the image to a tensor.
loader = transforms.Compose([
    transforms.Resize(size=(image_size, image_size)),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[], std=[])
])

In [6]:
def load_image(image_name):
    """Load an image file and return it as a batched tensor on `device`.

    Args:
        image_name: Path (or file object) accepted by `PIL.Image.open`.

    Returns:
        The result of applying `loader` to the RGB version of the image, with
        a leading batch dimension of size 1 added, moved to `device`.
    """
    # PNG files often carry an alpha channel or a palette. Converting to RGB
    # guarantees a three-channel tensor (presumably what the VGG feature
    # extractor expects). The context manager closes the underlying file
    # handle once the pixel data has been read.
    with Image.open(image_name) as image_file:
        image = loader(image_file.convert("RGB")).unsqueeze(0)
    return image.to(device)

In [7]:
# Content image first, then style image; both go through `load_image`.
original_img, style_img = (
    load_image(image_path)
    for image_path in ("png/annahathaway.png", "png/style.png")
)

In [8]:
# eval() only switches layers such as dropout / batch-norm to inference mode;
# it does NOT stop autograd from tracking the weights. requires_grad_(False) is
# what actually freezes them, so backward() only computes the gradient of the
# generated image instead of also accumulating unused parameter gradients.
model = VGG().to(device).eval().requires_grad_(False)

# generated = torch.randn(original_img.shape, device=device, requires_grad=True)

# Start from a copy of the content image rather than the random noise above;
# the copy is the only tensor that is optimised, hence requires_grad_(True).
generated = original_img.clone().requires_grad_(True)

In [9]:
# Hyperparameters
total_steps = 6_000     # number of optimisation iterations
learning_rate = 1e-3    # step size handed to Adam
alpha = 1               # weight of the original-image (content) loss
beta = 1e-2             # weight of the style loss

# The only "parameter" being optimised is the generated image itself.
optimizer = optim.Adam(params=[generated], lr=learning_rate)

In [None]:
def _gram_matrix(feature):
    """Return the (channel x channel) Gram matrix of one feature map.

    Args:
        feature: Tensor of shape (batch, channel, height, width). The reshape
            below only succeeds when batch == 1.

    Returns:
        `F @ F.T`, where `F` is `feature` flattened to (channel, height*width).
    """
    _, channel, height, width = feature.shape
    flat = feature.view(channel, height * width)
    return flat.mm(flat.t())


# The content and style targets do not depend on `generated`, so they are
# computed once up front, without autograd, instead of on every step.
with torch.no_grad():
    original_img_features = model(original_img)
    style_features = model(style_img)
    style_grams = [_gram_matrix(style_feature) for style_feature in style_features]

for step in range(total_steps):
    generated_features = model(generated)

    style_loss = original_loss = 0

    for gen_feature, orig_feature, style_gram in zip(
        generated_features, original_img_features, style_grams
    ):
        # Content term: mean squared difference of raw activations.
        original_loss += torch.mean((gen_feature - orig_feature) ** 2)

        # Style term: mean squared difference of the Gram matrices.
        gen_gram = _gram_matrix(gen_feature)
        style_loss += torch.mean((gen_gram - style_gram) ** 2)

    total_loss = alpha * original_loss + beta * style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # Periodic progress report and checkpoint of the current image.
    if step % 200 == 0:
        print(total_loss)
        save_image(generated, "png/generated.png")

# The last periodic checkpoint precedes the final optimisation steps, so write
# the finished image once more after the loop.
save_image(generated, "png/generated.png")

tensor(3996007.2500, device='cuda:0', grad_fn=<AddBackward0>)
tensor(167669.3438, device='cuda:0', grad_fn=<AddBackward0>)
tensor(83794.9297, device='cuda:0', grad_fn=<AddBackward0>)
tensor(55577.1406, device='cuda:0', grad_fn=<AddBackward0>)
tensor(41845.3438, device='cuda:0', grad_fn=<AddBackward0>)
tensor(33773.0156, device='cuda:0', grad_fn=<AddBackward0>)
tensor(28332.3516, device='cuda:0', grad_fn=<AddBackward0>)
tensor(24449.9551, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21566.1230, device='cuda:0', grad_fn=<AddBackward0>)
