In [1]:
import torch
import torchvision.transforms as transforms
from PIL import Image
import torch.nn as nn
import torchvision.models as models
import torch.optim as optim
from torchvision.models import VGG19_Weights
from torchvision.utils import save_image

In [2]:
# Run on the GPU when torch reports one as available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# `.features` sub-module of VGG-19 loaded with the default pretrained weights.
model = models.vgg19(weights=VGG19_Weights.DEFAULT).features

In [3]:
def image_loader(path, size=(512, 512)):
    """Load an image file as a 4-D float tensor on the global ``device``.

    Args:
        path: Path of the image file to open with PIL.
        size: Target ``(height, width)`` passed to ``transforms.Resize``.
            Defaults to ``(512, 512)``.

    Returns:
        A ``torch.float`` tensor with a leading batch dimension of 1
        (added by ``unsqueeze(0)``), moved to ``device``.
    """
    loader = transforms.Compose([transforms.Resize(size), transforms.ToTensor()])
    # The context manager closes the underlying file handle once the pixels
    # have been read.
    with Image.open(path) as img:
        # Force three channels: grayscale or RGBA files would otherwise yield
        # 1- or 4-channel tensors, which a network whose first layer is
        # Conv2d(3, ...) cannot consume.
        image = loader(img.convert('RGB')).unsqueeze(0)
    return image.to(device, torch.float)


# Style reference and content photo, both loaded through image_loader.
style_image = image_loader('images/vangogh.jpg')
original_image = image_loader('images/city.jpg')

# The image being optimised starts as a copy of the content image and is the
# tensor that tracks gradients.
generated_image = torch.clone(original_image).requires_grad_(True)

In [4]:
# Display the layer listing of the VGG-19 `.features` stack so the layer
# indices can be read off from the cell output.
models.vgg19(weights=VGG19_Weights.DEFAULT).features

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo

In [5]:
class VGG(nn.Module):
    """Frozen VGG-19 feature extractor that returns selected intermediate activations.

    ``req_features`` holds the indices, within ``vgg19(...).features``, of the
    layers whose outputs are collected. In the printed architecture indices
    0, 5 and 10 are ``Conv2d`` layers; 19 and 28 are presumably the first
    convolutions of the later blocks -- TODO confirm against the full listing.

    NOTE(review): inputs are fed to the network as-is; no mean/std
    normalisation is applied here -- confirm that this is intended.
    """

    def __init__(self):
        super().__init__()
        self.req_features = [0, 5, 10, 19, 28]
        self.model = models.vgg19(weights=VGG19_Weights.DEFAULT).features.to(device).eval()
        # Only the input image is optimised, never the network. Freezing the
        # weights stops autograd from computing and accumulating a gradient
        # for every VGG parameter on each backward pass.
        for param in self.model.parameters():
            param.requires_grad_(False)

    def forward(self, x):
        """Run ``x`` through the network and collect the requested activations.

        Args:
            x: Input tensor accepted by the first layer of ``self.model``.

        Returns:
            A list with one tensor per index in ``self.req_features`` that was
            reached, in layer order.

        Note:
            The ``ReLU(inplace=True)`` layer that follows a tapped convolution
            rewrites the tapped tensor in place, so the returned tensors hold
            post-ReLU values. To keep that behaviour, evaluation continues one
            layer past the last requested index and then stops; later layers
            cannot affect the result and are skipped.
        """
        features = []
        last_needed = max(self.req_features, default=-1) + 1
        for layer_num, layer in enumerate(self.model):
            if layer_num > last_needed:
                break
            x = layer(x)
            if layer_num in self.req_features:
                features.append(x)
        return features

In [6]:
def calc_content_loss(gen_feat, orig_feat):
    """Return the mean squared difference between two feature tensors.

    Args:
        gen_feat: Features of the generated image.
        orig_feat: Features of the content image; must be broadcast-compatible
            with ``gen_feat``.

    Returns:
        A scalar tensor: the mean over all elements of ``(gen_feat - orig_feat) ** 2``.
    """
    diff = gen_feat - orig_feat
    return diff.pow(2).mean()

In [7]:
def calc_style_loss(gen, style):
    """Return the mean squared difference between the Gram matrices of two feature maps.

    Each sample's ``(channel, height, width)`` feature map is flattened to
    ``(channel, height * width)`` and multiplied by its own transpose, giving
    one ``(channel, channel)`` Gram matrix per sample. The Gram matrices are
    not normalised by the number of elements.

    Args:
        gen: Features of the generated image, shape
            ``(batch, channel, height, width)``.
        style: Features of the style image; must hold the same number of
            elements as ``gen`` so it can be reshaped the same way.

    Returns:
        A scalar tensor: the mean of ``(G - A) ** 2`` over every Gram-matrix
        entry of every sample in the batch.
    """
    batch_size, channel, height, width = gen.shape

    # reshape (not view) also accepts non-contiguous inputs, and keeping the
    # batch axis lets batches larger than one work.
    gen_flat = gen.reshape(batch_size, channel, height * width)
    style_flat = style.reshape(batch_size, channel, height * width)

    # Batched F @ F^T: one Gram matrix per sample.
    G = torch.bmm(gen_flat, gen_flat.transpose(1, 2))
    A = torch.bmm(style_flat, style_flat.transpose(1, 2))

    return torch.mean((G - A) ** 2)

In [8]:
def calculate_loss(gen_features, orig_features, style_features, content_weight=1, style_weight=1000):
    """Combine per-layer content and style losses into one weighted total.

    The three feature sequences are walked in lockstep (stopping at the
    shortest). For each layer the content loss compares generated against
    original features and the style loss compares generated against style
    features.

    Args:
        gen_features: Per-layer features of the generated image.
        orig_features: Per-layer features of the content image.
        style_features: Per-layer features of the style image.
        content_weight: Multiplier for the summed content loss.
        style_weight: Multiplier for the summed style loss.

    Returns:
        ``content_weight * sum(content losses) + style_weight * sum(style losses)``.
    """
    layers = list(zip(gen_features, orig_features, style_features))
    content_total = sum(calc_content_loss(g, c) for g, c, _ in layers)
    style_total = sum(calc_style_loss(g, s) for g, _, s in layers)
    return content_weight * content_total + style_weight * style_total

In [9]:
def get_input_optimizer(input_img):
    """Build an L-BFGS optimizer whose only parameter is ``input_img``.

    Args:
        input_img: The tensor to optimise.

    Returns:
        An ``optim.LBFGS`` instance, created with its default settings, over
        ``[input_img]``.
    """
    params = [input_img]
    return optim.LBFGS(params)

In [10]:
# Number of outer optimisation steps.
epoch = 500

# Feature extractor, placed on the chosen device and put in evaluation mode.
model = VGG().to(device)
model.eval()

# Optimizer that updates the generated image itself.
optimizer = get_input_optimizer(generated_image)

In [None]:
# The content and style images never change, so their features are computed
# once, without building an autograd graph, instead of on every closure call.
with torch.no_grad():
    orig_features = model(original_image)
    style_features = model(style_image)

for e in range(epoch):
    def closure():
        """Recompute the loss and its gradient with respect to the generated image.

        L-BFGS may call this several times within a single ``step``, so it
        contains only the forward and backward pass -- no logging or file output.
        """
        optimizer.zero_grad()
        gen_features = model(generated_image)
        total_loss = calculate_loss(gen_features, orig_features, style_features)
        total_loss.backward()
        return total_loss

    # step() hands back the loss value returned by the closure.
    loss = optimizer.step(closure)

    # Report and checkpoint every 100 steps. The previous test
    # `e / 100 == 0` used true division and so only fired when e == 0.
    if e % 100 == 0:
        print(loss)
        save_image(generated_image, 'images/gen_v2.png')