### Based on: https://youtu.be/imX4kSKDY7s

### Imports

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image

In [2]:
model = models.vgg19(pretrained=True).features
model

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo

In [4]:
class VGG(nn.Module):
    def __init__(self):
        super(VGG, self).__init__()

        self.chosen_features = ['0', '5', '10', '19', '28']
        self.model = models.vgg19(pretrained=True).features[:29]

    def forward(self, x):
        features = []

        for layer_num, layer in enumerate(self.model):
            x = layer(x)

            if str(layer_num) in self.chosen_features:
                features.append(x)

        return features

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
IMAGE_SIZE = 356

In [7]:
loader = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor()
    ]
)

In [60]:
def load_image(image_name):
    image = Image.open(image_name)
    image = image.convert('RGB')
    image = loader(image).unsqueeze(0)
    return image.to(device)

In [9]:
original_image = load_image('img/original.png')
style_image = load_image('img/style.png')

In [20]:
generated = torch.clone(original_image)
generated.requires_grad_(True)

tensor([[[[0.9569, 0.9569, 0.9569,  ..., 0.9686, 0.9686, 0.9647],
          [0.9569, 0.9569, 0.9569,  ..., 0.9686, 0.9686, 0.9647],
          [0.9569, 0.9529, 0.9569,  ..., 0.9686, 0.9725, 0.9686],
          ...,
          [0.9333, 0.9333, 0.9333,  ..., 0.9765, 0.9686, 0.9765],
          [0.9333, 0.9333, 0.9333,  ..., 0.9725, 0.9765, 0.9804],
          [0.9333, 0.9333, 0.9333,  ..., 0.9686, 0.9804, 0.9843]],

         [[0.9255, 0.9255, 0.9216,  ..., 0.9255, 0.9255, 0.9294],
          [0.9255, 0.9216, 0.9216,  ..., 0.9255, 0.9255, 0.9255],
          [0.9255, 0.9216, 0.9255,  ..., 0.9255, 0.9216, 0.9255],
          ...,
          [0.8745, 0.8745, 0.8745,  ..., 0.9529, 0.9490, 0.9490],
          [0.8745, 0.8745, 0.8745,  ..., 0.9451, 0.9490, 0.9490],
          [0.8745, 0.8745, 0.8745,  ..., 0.9451, 0.9451, 0.9490]],

         [[0.8745, 0.8745, 0.8745,  ..., 0.8784, 0.8784, 0.8784],
          [0.8745, 0.8784, 0.8745,  ..., 0.8784, 0.8784, 0.8784],
          [0.8745, 0.8706, 0.8745,  ..., 0

In [21]:
TOTAL_STEPS = 6000
LR = 0.001
ALPHA = 1
BETA = 0.01

In [22]:
optimizer = optim.Adam([generated], lr=LR)

In [23]:
model = VGG().to(device).eval()

In [24]:
for step in range(TOTAL_STEPS):
    generated_features = model(generated)
    original_features = model(original_image)
    style_features = model(style_image)

    style_loss = original_loss = 0

    for g_feature, o_feature, s_feature in zip(generated_features, original_features, style_features):
        batch_size, channel, height, width = g_feature.shape
        original_loss += torch.mean((g_feature - o_feature) ** 2)

        # Gram Matrix
        G = g_feature.view(channel, height * width).mm(g_feature.view(channel, height * width).t())
        A = s_feature.view(channel, height * width).mm(s_feature.view(channel, height * width).t())

        style_loss += torch.mean((G - A) ** 2)

    total_loss = ALPHA * original_loss + BETA * style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if step % 200 == 0:
        print(f'Total loss: {total_loss}')
        save_image(generated, f'results/generated_{step}.png')

Total loss: 3996007.25
Total loss: 167721.734375
Total loss: 83721.90625
Total loss: 55577.49609375
Total loss: 41924.4921875
Total loss: 33845.6796875
Total loss: 28376.2265625
Total loss: 24444.33984375
Total loss: 21557.361328125
Total loss: 19315.96875
Total loss: 17522.27734375
Total loss: 16049.3115234375
Total loss: 14804.6982421875
Total loss: 13745.90625
Total loss: 12824.0166015625
Total loss: 11998.876953125
Total loss: 11255.3671875
Total loss: 10578.796875
Total loss: 9966.2021484375
Total loss: 9400.29296875
Total loss: 8880.0693359375
Total loss: 8396.2724609375
Total loss: 7948.07568359375
Total loss: 7530.85791015625
Total loss: 7150.970703125
Total loss: 6803.1767578125
Total loss: 6485.16259765625
Total loss: 6195.2216796875
Total loss: 5924.0810546875
Total loss: 5673.20556640625


In [25]:
torch.save(model.state_dict(), 'model.tar')

### Trying to reuse the model on another image

In [61]:
original_image = load_image('img/original2.png')
style_image = load_image('img/style.png')

In [62]:
generated = torch.clone(original_image)
generated.requires_grad_(True)

tensor([[[[0.9686, 0.9686, 0.9686,  ..., 0.9686, 0.9686, 0.9686],
          [0.9686, 0.9686, 0.9686,  ..., 0.9686, 0.9686, 0.9686],
          [0.9686, 0.9686, 0.9686,  ..., 0.9686, 0.9686, 0.9686],
          ...,
          [0.9686, 0.9686, 0.9686,  ..., 0.9686, 0.9686, 0.9686],
          [0.9686, 0.9686, 0.9686,  ..., 0.9686, 0.9686, 0.9686],
          [0.9686, 0.9686, 0.9686,  ..., 0.9686, 0.9686, 0.9686]],

         [[0.9686, 0.9686, 0.9686,  ..., 0.9686, 0.9686, 0.9686],
          [0.9686, 0.9686, 0.9686,  ..., 0.9686, 0.9686, 0.9686],
          [0.9686, 0.9686, 0.9686,  ..., 0.9686, 0.9686, 0.9686],
          ...,
          [0.9686, 0.9686, 0.9686,  ..., 0.9686, 0.9686, 0.9686],
          [0.9686, 0.9686, 0.9686,  ..., 0.9686, 0.9686, 0.9686],
          [0.9686, 0.9686, 0.9686,  ..., 0.9686, 0.9686, 0.9686]],

         [[0.9686, 0.9686, 0.9686,  ..., 0.9686, 0.9686, 0.9686],
          [0.9686, 0.9686, 0.9686,  ..., 0.9686, 0.9686, 0.9686],
          [0.9686, 0.9686, 0.9686,  ..., 0

In [63]:
optimizer = optim.Adam([generated], lr=LR)

In [64]:
model = VGG().to(device)
model.load_state_dict(torch.load('model.tar'))
model.eval()

VGG(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding

In [65]:
for step in range(400):
    generated_features = model(generated)
    original_features = model(original_image)
    style_features = model(style_image)

    style_loss = original_loss = 0

    for g_feature, o_feature, s_feature in zip(generated_features, original_features, style_features):
        batch_size, channel, height, width = g_feature.shape
        original_loss += torch.mean((g_feature - o_feature) ** 2)

        # Gram Matrix
        G = g_feature.view(channel, height * width).mm(g_feature.view(channel, height * width).t())
        A = s_feature.view(channel, height * width).mm(s_feature.view(channel, height * width).t())

        style_loss += torch.mean((G - A) ** 2)

    total_loss = ALPHA * original_loss + BETA * style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if step % 200 == 0:
        print(f'Total loss: {total_loss}')
        save_image(generated, f'results/generated2_{step}.png')

Total loss: 4164718.0
Total loss: 245776.671875


KeyboardInterrupt: 