My Implementation of Neural Style Transfer

Based on the paper "A Neural Algorithm of Artistic Style" (Gatys, Ecker & Bethge, 2015):
https://arxiv.org/abs/1508.06576

Gholamreza Dar 2022

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image

In [2]:
# Run on the GPU whenever CUDA is usable; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Side length, in pixels, of the square every input image is resized to.
image_size = 356

# Preprocessing applied to each PIL image: resize to a fixed square, then
# convert to a tensor.
loader = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])

In [3]:
class VGG(nn.Module):
    """Frozen VGG-19 feature extractor used for neural style transfer.

    Wraps the first 29 layers of torchvision's pretrained VGG-19 ``features``
    stack. Instead of a single output, ``forward`` returns the intermediate
    activations of the layers listed in ``chosen_features``.
    """

    def __init__(self):
        super().__init__()
        # Indices (as strings) of the layers in ``self.model`` whose outputs
        # are collected by ``forward``.
        self.chosen_features = ['0', '5', '10', '19', '28']
        # Nothing past index 28 is ever read, so the remaining layers are dropped.
        self.model = models.vgg19(weights="DEFAULT").features[:29]
        # The network is a fixed feature extractor: only the input image is
        # optimised. Freezing the weights stops every backward() call from
        # computing and accumulating gradients for parameters nobody updates.
        self.model.requires_grad_(False)

    def forward(self, x):
        """Run ``x`` through the network and collect the chosen activations.

        Args:
            x: Input tensor accepted by the first layer of ``self.model``.

        Returns:
            A list with one tensor per entry of ``chosen_features``, in layer
            order.
        """
        features = []
        for layer_num, layer in enumerate(self.model):
            x = layer(x)
            # NOTE(review): if the following activation layer works in place
            # (torchvision's VGG default, to be confirmed), the tensor stored
            # here is modified by it after being appended.
            if str(layer_num) in self.chosen_features:
                features.append(x)

        return features

In [4]:
def load_image(image_name):
    """Load an image file as a batched 3-channel tensor on ``device``.

    Args:
        image_name: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The image after the module-level ``loader`` transform, with a leading
        batch dimension of size 1 added, moved to ``device``.
    """
    # The context manager closes the underlying file once the pixels are read.
    with Image.open(image_name) as img:
        # Force three channels so grayscale, RGBA and palette images produce
        # the 3-channel input the VGG network requires.
        image = loader(img.convert("RGB")).unsqueeze(0)
    return image.to(device)

In [5]:
!wget https://www.vincentvangogh.org/assets/img/paintings/the-starry-night-over-the-rhone.jpg -O "image_b.jpg"
!wget https://bonnierpublications.com/app/uploads/2022/05/woman-1.jpg -O "image_a.jpg"

--2024-12-15 22:57:01--  https://www.vincentvangogh.org/assets/img/paintings/the-starry-night-over-the-rhone.jpg
Resolving www.vincentvangogh.org (www.vincentvangogh.org)... 104.21.62.178, 172.67.137.232, 2606:4700:3033::ac43:89e8, ...
Connecting to www.vincentvangogh.org (www.vincentvangogh.org)|104.21.62.178|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 306388 (299K) [image/jpeg]
Saving to: ‘image_b.jpg’


2024-12-15 22:57:02 (693 KB/s) - ‘image_b.jpg’ saved [306388/306388]

--2024-12-15 22:57:02--  https://bonnierpublications.com/app/uploads/2022/05/woman-1.jpg
Resolving bonnierpublications.com (bonnierpublications.com)... 104.21.59.52, 172.67.214.143, 2606:4700:3035::ac43:d68f, ...
Connecting to bonnierpublications.com (bonnierpublications.com)|104.21.59.52|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 316396 (309K) [image/jpeg]
Saving to: ‘image_a.jpg’


2024-12-15 22:57:03 (8.29 MB/s) - ‘image_a.jpg’ saved [316396/316396]



In [6]:
# Feature extractor, moved to the target device and switched to eval mode.
model = VGG().to(device).eval()

# Content ("original") image and style image, both preprocessed by load_image.
original_image = load_image("image_a.jpg")
style_image = load_image("image_b.jpg")

# The image being optimised starts as a copy of the content image and is the
# only tensor that tracks gradients.
generated = original_image.clone()
generated.requires_grad_(True)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:06<00:00, 85.1MB/s]


In [None]:
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

total_steps = 3000
learning_rate = 0.001
alpha = 1     # weight of the content ("original") loss
beta = 0.01   # weight of the style loss
save_every = 200  # save/show a snapshot every this many steps


def gram_matrix(feature):
    """Return the Gram matrix (channel x channel) of one feature map.

    ``feature`` must have shape ``(1, channel, height, width)``: it is
    flattened to ``(channel, height * width)`` and multiplied by its own
    transpose, which only works for a batch of size 1.
    """
    _, channel, height, width = feature.shape
    flat = feature.view(channel, height * width)
    return flat.mm(flat.t())


# The content and style images never change, so their targets are computed
# once, outside the loop and without building an autograd graph.
with torch.no_grad():
    original_img_features = model(original_image)
    style_grams = [gram_matrix(feature) for feature in model(style_image)]

# notice how we put generated image as the "parameters"
optimizer = optim.Adam([generated], lr=learning_rate)

for step in tqdm(range(total_steps)):
    generated_features = model(generated)

    style_loss = 0
    original_loss = 0

    for gen_feature, orig_feature, style_gram in zip(
        generated_features, original_img_features, style_grams
    ):
        # Content loss: mean squared difference of the raw activations.
        original_loss += torch.mean((gen_feature - orig_feature) ** 2)
        # Style loss: mean squared difference of the Gram matrices.
        style_loss += torch.mean((gram_matrix(gen_feature) - style_gram) ** 2)

    total_loss = alpha * original_loss + beta * style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # Snapshot periodically, and always on the last step so the final result
    # is never lost.
    if step % save_every == 0 or step == total_steps - 1:
        # Nothing constrains the optimised pixels to [0, 1]; clamp the copy
        # used for saving and display so imshow gets valid RGB data.
        snapshot = generated.detach().cpu().clamp(0, 1)
        # Zero-padded step keeps filenames free of spaces and sortable.
        save_image(snapshot, f"generated_{step:04d}.png")
        fig, ax = plt.subplots()
        ax.set_title(f"Step: {step}")
        ax.imshow(snapshot.squeeze(0).permute(1, 2, 0))
        plt.show()
        # Release the figure so repeated snapshots do not pile up in memory.
        plt.close(fig)