In [1]:
import sys
import time

import numpy as np
from imageio import imread
from scipy.ndimage.filters import gaussian_filter
from PIL import Image

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import MSELoss
from torch.optim import Adam, LBFGS
from torchvision.transforms import ToTensor, Normalize, Compose

import pylab as plt
plt.ion()

from tqdm import tqdm
tqdm.monitor_interval = 0

from models import Vgg19, gram_matrix, patch_match, downsampling
from data_utils import read_img

In [2]:
# Input files for one harmonization sample. Change DATA_DIR / IMAGE_ID to run
# the notebook on another sample laid out with the same naming scheme.
DATA_DIR = 'data'
IMAGE_ID = 0

style_fn = f'{DATA_DIR}/{IMAGE_ID}_target.jpg'       # background / style image
naive_fn = f'{DATA_DIR}/{IMAGE_ID}_naive.jpg'        # naive copy-paste composite to harmonize
mask_fn = f'{DATA_DIR}/{IMAGE_ID}_c_mask_dilated.jpg'  # dilated mask of the pasted region
tmask_fn = f'{DATA_DIR}/{IMAGE_ID}_c_mask.jpg'       # tight mask of the pasted region

In [3]:
# Seed torch's CPU and CUDA RNGs for reproducibility.
torch.manual_seed(316)
torch.cuda.manual_seed_all(316)

# ToTensor scales pixels to [0, 1]; Normalize then applies the ImageNet mean/std.
transform = Compose([ToTensor(), Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])])


def load_rgb(fn):
    """Load the image at `fn` as a normalized (1, 3, H, W) CUDA tensor.

    The image is converted to RGB first so grayscale/RGBA/CMYK files still give the
    three channels `Normalize` expects; the file handle is closed on return.
    """
    with Image.open(fn) as im:
        return transform(im.convert('RGB')).unsqueeze(0).cuda()


def load_mask(fn, sigma = None):
    """Load the mask at `fn` as a float32 (H, W) array in its original pixel scale.

    Multi-channel masks (RGB or RGBA) are reduced to their first channel; if `sigma`
    is given, the mask is Gaussian-blurred to feather its edges.
    """
    mask = imread(fn).astype(np.float32)
    if mask.ndim == 3:
        # Assumes all channels carry the same mask (e.g. a grayscale mask saved as RGB).
        mask = mask[..., 0]
    if sigma is not None:
        mask = gaussian_filter(mask, sigma = sigma)
    return mask


style_img = load_rgb(style_fn)
naive_img = load_rgb(naive_fn)
# Dividing by 255 assumes 8-bit mask files (values 0-255) -- verify for new inputs.
# mask_img (dilated) weights the losses; tmask_img (tight, feathered) drives the final blend.
mask_img = torch.from_numpy(load_mask(mask_fn)).unsqueeze(0).cuda() / 255.0
tmask_img = torch.from_numpy(load_mask(tmask_fn, sigma = 3)).unsqueeze(0).cuda() / 255.0

# Everything downstream multiplies these tensors together, so their spatial sizes must agree.
assert style_img.shape == naive_img.shape, 'style and naive images must be the same size'
assert mask_img.shape[-2:] == tmask_img.shape[-2:] == naive_img.shape[-2:], 'masks must match the image size'

print(style_img.shape)
print(naive_img.shape)
print(mask_img.shape)
print(tmask_img.shape)

# Keep an untouched copy as the content reference; the working copy is optimized in place.
naive_img_original = naive_img.clone()
naive_img = naive_img.requires_grad_(True)

torch.Size([1, 3, 682, 700])
torch.Size([1, 3, 682, 700])
torch.Size([1, 682, 700])
torch.Size([1, 682, 700])


In [None]:
# VGG19 is only a fixed feature extractor here: freeze its weights so backward()
# computes gradients for the image alone instead of also computing (and endlessly
# accumulating) gradients in the network parameters, which this optimizer's
# zero_grad() never clears because it only owns naive_img.
net = Vgg19().cuda()
net.requires_grad_(False)
# LBFGS optimizes the pixels of naive_img directly; each optimizer.step() runs up to
# max_iter = 100 LBFGS iterations.
optimizer = LBFGS([naive_img], max_iter = 100, lr = 1)

In [None]:
# Layers whose activations are compared: content at relu4_1, style (Gram) at relu3_1..relu5_1.
layers_content = ['relu4_1']
layers_style = ['relu3_1', 'relu4_1', 'relu5_1']

# The reference features are fixed targets: every later use is detached or purely
# diagnostic, so compute them without building (and keeping alive) an autograd graph.
with torch.no_grad():
    features_style = net(style_img)
    features_naive_original = net(naive_img_original)

    # patch_match (from models) presumably replaces each naive-image feature patch with
    # its closest style patch (3x3 patches) -- verify against models.patch_match.
    features_style_nearest = {
        l: patch_match(features_naive_original[l], features_style[l], patch_size = 3)
        for l in layers_style
    }

In [None]:
from models import cosine_similarity

# Diagnostic: for each style layer, compare the mean cosine similarity between the
# naive image's features and (a) the raw style features, (b) the patch-matched ones.
for layer in layers_style:
    source = features_naive_original[layer]
    for label, targets in (('no matching: ', features_style), ('matched: ', features_style_nearest)):
        print(label, torch.mean(cosine_similarity(source, targets[layer])))

In [None]:
# Loss weights and number of outer LBFGS steps (each step runs up to max_iter inner iterations).
CONTENT_WEIGHT = 5
STYLE_WEIGHT = 100
N_STEPS = 10

# ImageNet normalization statistics, used to map tensors back to displayable RGB.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype = np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype = np.float32)


def masks_for_layers(layers, features):
    """Resize the dilated mask to the spatial size of each layer in `layers`.

    Walks `net.features` in order, blurring the mask (3x3 average pool, stride 1,
    padding 1) at every Conv2d and halving it with `downsampling` at every MaxPool2d,
    until its height equals `features[l].size(2)`. The walk resumes where the previous
    layer stopped, so `layers` must be listed in network order.

    Returns a dict: layer name -> mask (4-D, starting from mask_img.unsqueeze(0)).
    Raises RuntimeError if the walk runs past the end of `net.features` without
    reaching the requested height.
    """
    mask = mask_img.unsqueeze(0)
    j = 0
    masks = {}
    for l in layers:
        target_h = features[l].size(2)
        while mask.size(2) != target_h:
            try:
                module = net.features[j]
            except IndexError:
                raise RuntimeError(
                    f'mask height {mask.size(2)} never matched layer {l} (height {target_h})'
                ) from None
            if isinstance(module, nn.Conv2d):
                mask = F.avg_pool2d(mask, 3, stride = 1, padding = 1)
            elif isinstance(module, nn.MaxPool2d):
                mask = downsampling(mask, scale_factor = 0.5)
            j += 1
        masks[l] = mask
    return masks


# The image size never changes during optimization, so the per-layer masks, the content
# targets and the style Gram targets are constants: compute them once, not per closure call.
with torch.no_grad():
    content_masks = masks_for_layers(layers_content, features_naive_original)
    style_masks = masks_for_layers(layers_style, features_naive_original)
    content_targets = {l: features_naive_original[l].detach() for l in layers_content}
    gram_style = {
        l: gram_matrix(style_masks[l] * features_style_nearest[l]) / torch.sum(style_masks[l])
        for l in layers_style
    }


def compute_losses():
    """Run VGG on the current naive_img and return weighted (total, content, style) losses."""
    features_naive = net(naive_img)
    # Masked MSE between current and original features (keeps the pasted content).
    loss_content = sum(
        torch.mean(content_masks[l] * (features_naive[l] - content_targets[l]) ** 2)
        for l in layers_content
    )
    # Masked Gram-matrix MSE against the patch-matched style targets.
    loss_style = sum(
        torch.mean((gram_matrix(style_masks[l] * features_naive[l]) / torch.sum(style_masks[l]) - gram_style[l]) ** 2)
        for l in layers_style
    )
    loss_content = loss_content * CONTENT_WEIGHT
    loss_style = loss_style * STYLE_WEIGHT
    return loss_content + loss_style, loss_content, loss_style


def closure(split = False):
    """LBFGS closure: recompute the losses and backpropagate into naive_img.

    Returns the total loss, or (total, content, style) when `split` is True.
    """
    optimizer.zero_grad()
    loss, loss_content, loss_style = compute_losses()
    loss.backward()
    if split:
        return loss, loss_content, loss_style
    return loss


def to_display_image(img):
    """Convert a normalized (1, 3, H, W) tensor to an (H, W, 3) float32 array clipped to [0, 1]."""
    out = np.transpose(img.detach().squeeze().cpu().numpy(), (1, 2, 0))
    # Non-in-place so the source tensor is never mutated (numpy may share its memory on CPU).
    return np.clip(out * IMAGENET_STD + IMAGENET_MEAN, 0, 1)


for i in tqdm(range(N_STEPS)):
    optimizer.step(closure)

    # End-of-step losses via a forward pass only; calling closure() again would also
    # run an unnecessary backward pass just to print numbers.
    with torch.no_grad():
        loss, loss_content, loss_style = compute_losses()
    tqdm.write(f'step {i}: loss {loss.item():.6g}  content {loss_content.item():.6g}  style {loss_style.item():.6g}')

    # Alpha-blend the optimized image into the style image with the feathered tight mask.
    with torch.no_grad():
        out_img = tmask_img * naive_img + (1 - tmask_img) * style_img
    plt.figure(figsize = (8, 8))
    plt.imshow(to_display_image(out_img))

In [None]:
# Final composite: undo the ImageNet normalization applied at load time so pixels are
# back in display range. Uses non-in-place ops: for a CPU tensor, .numpy() shares memory
# with out_img, and `out *= ...` would silently rewrite out_img on every re-run.
out = np.transpose(out_img.detach().squeeze().cpu().numpy(), (1, 2, 0))
out = out * np.array([0.229, 0.224, 0.225], dtype = np.float32) + np.array([0.485, 0.456, 0.406], dtype = np.float32)
# Range and mean before clipping show how far the optimization pushed pixels outside [0, 1].
print(np.max(out), np.min(out))
print(np.mean(out))
out = np.clip(out, 0, 1)
plt.figure(figsize = (16, 16))
plt.imshow(out)

In [None]:
# Inspect the feathered (Gaussian-blurred) tight mask used for the final blend;
# for 8-bit mask files its values are expected to lie in [0, 1].
out = tmask_img.detach().cpu().numpy().squeeze()
print(np.max(out), np.min(out))
plt.figure(figsize = (16, 16))
plt.imshow(out)