# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models
from torchvision import transforms
from torchvision.models import VGG19_Weights
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm


class VGGNet(nn.Module):
    """Frozen VGG19 feature extractor used for style transfer.

    Wraps the first 29 layers of torchvision's pretrained VGG19 ``features``
    stack, with every max-pooling layer replaced by average pooling, and
    returns the activations of the layers named in ``self.select``.
    """

    def __init__(self):
        super().__init__()
        # Layer indices inside vgg19().features, listed in the order in which
        # forward() returns them:
        #   style on conv1_1, conv2_1, conv3_1, conv4_1, conv5_1 (first five)
        #   content on conv4_2 (last entry)
        self.select = ['0', '5', '10', '19', '28', '21']
        vgg = models.vgg19(weights=VGG19_Weights.DEFAULT).features  # pretrained VGG19 model
        # Only the input image is optimised, never the network weights.
        for param in vgg.parameters():
            param.requires_grad_(False)
        # Nothing past index 28 is selected, so the rest of the stack is dropped.
        self.vgg = nn.Sequential(*[self.replace_layers(layer) for layer in vgg[:29]])

    def replace_layers(self, layer):
        """Return an AvgPool2d equivalent for MaxPool2d layers, else the layer itself."""
        if isinstance(layer, nn.MaxPool2d):
            return nn.AvgPool2d(kernel_size=layer.kernel_size, stride=layer.stride, padding=layer.padding)
        return layer

    def forward(self, x):
        """Run ``x`` through the network and collect the selected activations.

        Returns:
            A list with one tensor per entry of ``self.select``, in the same
            order as ``self.select`` (not in network order), so positional
            indexing by callers lines up with the list above.
        """
        # NOTE(review): torchvision's VGG presumably uses in-place ReLU, in
        # which case a tensor captured right after a conv layer ends up holding
        # the post-activation values -- confirm which one is intended.
        wanted = set(self.select)
        captured = {}
        for name, layer in self.vgg.named_children():
            x = layer(x)
            if name in wanted:
                captured[name] = x
        # Layers run in index order ('21' before '28'); re-order to match
        # self.select so that the content layer really is the last element.
        return [captured[name] for name in self.select]


def preprocess(
        image_path,
        transform=None,
        device=None
):
    """
    read an image
    process it to tensor available for vgg

    Args:
        image_path: path of the image file to read.
        transform: optional callable turning a PIL image into a tensor;
            when omitted the image is only converted with
            ``transforms.ToTensor()``.
        device: device the result is moved to (``None`` keeps the current one).

    Returns:
        a ``torch.float`` tensor with a leading batch dimension of size 1.
    """
    # Decode inside the context manager so the file handle is released, and
    # force three channels so grayscale / palette images also fit the
    # 3-channel pipeline.
    with Image.open(image_path) as raw:
        image = raw.convert('RGB')
    # Without a transform the original code called .to() on a PIL image and
    # crashed; fall back to a plain tensor conversion instead.
    if transform is None:
        transform = transforms.ToTensor()
    batched = transform(image).unsqueeze(0)
    return batched.to(device, torch.float)


def get_features(image, model):
    """
    run the image through the model and reshape every returned activation
    to 2-D, one row per index along axis 1 of the activation

    NOTE(review): the reshape presumably relies on a batch size of 1 -- confirm.
    """
    return [
        activation.reshape(activation.shape[1], -1)
        for activation in model(image)
    ]


def gram_matrix(tensor):
    """
    gram matrix (rows against rows) of a 2-D matrix,
    divided by the number of columns
    """
    _, num_columns = tensor.size()
    return tensor.mm(tensor.t()) / num_columns


if __name__ == "__main__":
    import os  # local import: only the script entry point needs it

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    image_size = 224
    img_path = './img/'
    save_path = './results/'
    content_path = f'{img_path}content1.jpg'
    style_path = f'{img_path}style1.jpg'

    # Create the output directory up front so a missing folder fails (or is
    # fixed) before the long optimisation instead of after it.
    os.makedirs(save_path, exist_ok=True)

    # Resize -> tensor in [0, 1] -> reverse channel order -> subtract mean
    # -> scale to the 0..255 range.
    # NOTE(review): channel flip + x255 with unit std looks like the
    # Caffe-style VGG recipe, whereas torchvision documents its VGG19 weights
    # for RGB input in [0, 1] normalised with mean (0.485, 0.456, 0.406) and
    # std (0.229, 0.224, 0.225). The means below are also written in RGB
    # order although they are applied after the flip. Left unchanged because
    # the loss weights are presumably tuned to this scale -- confirm.
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])]),
        transforms.Normalize(
            (0.485, 0.456, 0.406),
            (1, 1, 1)
        ),
        transforms.Lambda(lambda x: x.mul_(255))
    ])

    # Inverse of `transform` (minus the resize), ending in a PIL image.
    # The rescale must NOT be in-place: on CPU, `.cpu()` / `.detach()` return
    # views sharing storage with `target`, so `mul_` would divide the image
    # being optimised by 255 every time it is displayed.
    postprocess = transforms.Compose([
        transforms.Lambda(lambda x: x.mul(1./255)),
        transforms.Normalize(
            mean=[-0.485, -0.456, -0.406],
            std=[1, 1, 1]
        ),
        transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])]),
        transforms.Lambda(lambda x: torch.clamp(x, 0, 1)),
        transforms.ToPILImage()
    ])

    # load content and style image
    content = preprocess(
        content_path,
        transform=transform,
        device=device
    )
    style = preprocess(
        style_path,
        transform=transform,
        device=device
    )

    model = VGGNet().to(device).eval()

    # get content and style features: the last feature returned by the model
    # is the content reference, all the others are style references.
    # These are fixed targets, so no autograd graph is needed for them.
    with torch.no_grad():
        content_features = get_features(content, model)[-1]
        style_features = get_features(style, model)[:-1]
        style_grams = [gram_matrix(feature) for feature in style_features]

    # white noise image; randn_like already allocates on content's device, and
    # the tensor must stay a leaf so the optimizer can update it.
    target = torch.randn_like(content).requires_grad_(True)
    # alternative initialisation: start from the content image instead
    # target = content.clone().requires_grad_(True)

    content_weight = 1
    style_weights = [1e3 / n ** 2 for n in [64, 128, 256, 512, 512]]
    optimizer = optim.LBFGS([target])
    steps = 500
    show_every = 10

    def compute_loss():
        """Return weighted content loss plus weighted style loss for `target`."""
        target_features = get_features(target, model)
        content_loss = F.mse_loss(target_features[-1], content_features) * content_weight
        style_loss = sum(
            F.mse_loss(gram_matrix(target_style), style_gram) * style_weight
            for target_style, style_gram, style_weight in zip(
                target_features[:-1],
                style_grams,
                style_weights
            )
        )
        return content_loss + style_loss

    def closure():
        """Re-evaluate the loss and its gradient; LBFGS calls this repeatedly."""
        optimizer.zero_grad()
        total_loss = compute_loss()
        total_loss.backward()
        return total_loss

    for i in tqdm(range(1, steps + 1)):
        optimizer.step(closure)
        if i % show_every == 0:
            # Only the value is needed for the title, so skip the backward
            # pass that calling closure() here would trigger.
            with torch.no_grad():
                loss = compute_loss()
            plt.imshow(postprocess(target[0].detach().cpu()))
            plt.axis("off")
            plt.title(f"Iter {i}, Loss: {loss.item()}")
            plt.show()

    # save image
    plt.imshow(postprocess(target[0].detach().cpu()))
    plt.axis("off")
    content_name = os.path.splitext(os.path.basename(content_path))[0]
    style_name = os.path.splitext(os.path.basename(style_path))[0]
    plt.savefig(f'{save_path}{content_name}_{style_name}.png')
    plt.show()
