In [None]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from scipy import ndimage

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torchvision.models import vgg16
from torchvision import transforms
import torchvision.transforms.functional as TVF

In [None]:
# Target image size in PIL convention: (width, height) in pixels.
img_size = (800, 600)
# Seed used for both the NumPy and the PyTorch random number generators.
seed = 42
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing on the first `.to(device)`.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Load the content image and resize it to img_size (PIL order: width, height).
# convert("RGB") guarantees exactly three channels: grayscale, CMYK or RGBA
# files would otherwise break the 3-channel normalisation applied before the
# network. The context manager closes the underlying file once the pixel data
# has been copied into the converted image.
with Image.open("data/content.jpg") as content_file:
    content_img = content_file.convert("RGB").resize(img_size)

plt.imshow(content_img)
plt.show()

In [None]:
# Load the style image and resize it to img_size (PIL order: width, height).
# convert("RGB") guarantees exactly three channels: grayscale, CMYK or RGBA
# files would otherwise break the 3-channel normalisation applied before the
# network. The context manager closes the underlying file once the pixel data
# has been copied into the converted image.
with Image.open("data/style.jpg") as style_file:
    style_img = style_file.convert("RGB").resize(img_size)

plt.imshow(style_img)
plt.show()

In [None]:
# img_size follows the PIL (width, height) convention, whereas the NumPy image
# array is laid out as (height, width, channels) - hence the reversed size.
np.random.seed(seed)
noise_shape = img_size[::-1] + (3,)
uniform_noise = np.random.uniform(0, 1, size=noise_shape)

# Smooth the noise with an 8x8 median filter applied to each channel
# separately; starting from a filtered image can help convergence.
opt_image = ndimage.median_filter(uniform_noise, size=(8, 8, 1))

In [None]:
# Preview the median-filtered noise image that the optimisation starts from.
plt.imshow(opt_image)
plt.show()

In [None]:
# Load VGG16 with its pretrained weights (`pretrained` is a boolean flag) and
# keep only the convolutional feature extractor, moved to the target device
# and switched to eval mode.
model = vgg16(pretrained=True)
model = model.features.to(device).eval()

# Freeze the network weights. `requires_grad` has to be switched off on the
# parameters themselves: assigning `layer.requires_grad = False` on a module
# merely creates an unused Python attribute and leaves every weight trainable.
for param in model.parameters():
    param.requires_grad_(False)

In [None]:
# Indices of the layer sitting directly in front of each MaxPool2d, i.e. the
# end of every convolutional block (its trailing ReLU, per the original note).
[pool_idx - 1
 for pool_idx, module in enumerate(model)
 if isinstance(module, nn.MaxPool2d)]

In [None]:
# Inspect the layer at index 22 of the feature extractor - the layer whose
# output is captured by the forward hook registered later in the notebook.
model[22]

In [None]:
class SaveFeatures:
    """Capture the output of a layer on every forward pass.

    A forward hook is registered on ``layer`` at construction time. Each time
    the layer runs, its output is stored in ``self.features`` (overwriting the
    previous value). The stored tensor is the layer's actual output object, so
    it stays attached to the autograd graph of that forward pass.

    Call :meth:`close` to remove the hook, or use the instance as a context
    manager to have it removed automatically::

        with SaveFeatures(layer) as sf:
            model(x)
            activations = sf.features

    Args:
        layer: Module exposing ``register_forward_hook`` (e.g. an ``nn.Module``).
    """

    # Most recent output of the hooked layer; None until the layer has run.
    features = None

    def __init__(self, layer):
        self.hook = layer.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        """Forward-hook callback: remember the latest output of the layer."""
        self.features = output

    def close(self):
        """Remove the forward hook; ``features`` keeps its last value."""
        self.hook.remove()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always detach the hook; never suppress an exception from the body.
        self.close()
        return False

In [None]:
# Hook model[22]: every forward pass through the model now stores that
# layer's output in sf.features.
sf = SaveFeatures(model[22])

In [None]:
# Per-channel statistics used to normalise the network input.
channel_mean = (0.485, 0.456, 0.406)
channel_std = (0.229, 0.224, 0.225)

# PIL image -> tensor, followed by per-channel normalisation.
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]
transform = transforms.Compose(preprocessing_steps)

In [None]:
# Content image: preprocess, add a leading batch dimension, cast to float32
# and move to the compute device.
content_tensor = transform(content_img).unsqueeze(0).float().to(device)

# Style image: only preprocessed for now. It is deliberately left without a
# batch dimension and is not moved to the device (that step is disabled):
#style_tensor = style_tensor.unsqueeze(0).float().to(device)
style_tensor = transform(style_img)

In [None]:
# Turn the initial noise image into the tensor that will be optimised.
# (The former `opt_tensor = opt_image / 2` line was dead code: its result was
# overwritten immediately by the conversion below, which reads `opt_image`.)
opt_tensor = TVF.to_tensor(opt_image)
# Add a batch dimension, cast to float32 and move to the compute device.
opt_tensor = opt_tensor.unsqueeze(0).float().to(device)
# The image itself is the optimisation variable, so it must track gradients.
opt_tensor.requires_grad_(True)

In [None]:
# NOTE(review): in the visible cells `lr` is only referenced by the
# commented-out Adam optimizer; the LBFGS optimizer below hard-codes lr=0.5.
lr = 1e-3
# Number of outer optimizer.step() calls in the optimisation loop.
# NOTE(review): this name shadows the builtin `iter`.
iter = 100

# Alternative first-order optimizer, kept for reference:
#optimizer = optim.Adam([opt_tensor], lr=lr)
# L-BFGS over the image tensor, which is the only parameter being optimised.
optimizer = optim.LBFGS([opt_tensor], lr=0.5)

In [None]:
# Record the target activations of the hooked layer for the content image.
# The target is a constant of the optimisation, so the forward pass runs
# without autograd and the stored copy is detached. Otherwise the target stays
# attached to its own graph and every later backward() would traverse it again.
with torch.no_grad():
    model(content_tensor)  # run only for the hook's side effect on sf.features
content_hidden = sf.features.detach().clone()

In [None]:
torch.manual_seed(seed)

# Number of closure evaluations so far. A one-element list lets the closure
# update the counter in place.
run = [0]


def lbfgs_step():
    """L-BFGS closure: recompute the content loss and its gradient.

    Clears the gradients, runs the model on the image being optimised (the
    forward hook stores the tapped activations in ``sf.features``),
    back-propagates the MSE against the content activations and returns the
    loss, as required by ``optimizer.step(closure)``.
    """
    optimizer.zero_grad()

    # Forward pass is needed only for the hook's side effect.
    model(opt_tensor)
    opt_hidden = sf.features

    # The target is a constant: detaching it keeps backward() away from the
    # graph of the content forward pass, so no graph has to be retained
    # between closure calls.
    loss = F.mse_loss(opt_hidden, content_hidden.detach())
    loss.backward()

    if run[0] % 10 == 0:
        print(f"closure call {run[0]}: loss = {loss.item():.6f}")

    run[0] += 1
    return loss


# The closure may be evaluated several times per optimizer.step() call.
for it in range(1, iter + 1):
    optimizer.step(lbfgs_step)

In [None]:
# The optimised tensor is fed to the network as-is and matched against the
# activations of the *normalised* content image, so it lives in normalised
# input space with unbounded values. Undo the per-channel normalisation and
# clamp to [0, 1] before converting; otherwise the float -> 8-bit conversion
# in to_pil_image overflows for out-of-range values.
display_mean = torch.tensor((0.485, 0.456, 0.406)).view(3, 1, 1)
display_std = torch.tensor((0.229, 0.224, 0.225)).view(3, 1, 1)

opt_img = opt_tensor.squeeze(0).detach().cpu()
opt_img = (opt_img * display_std + display_mean).clamp(0, 1)
opt_img = TVF.to_pil_image(opt_img)
plt.imshow(opt_img)
plt.show()