In [1]:
# %load_ext autoreload
# %autoreload 2

In [1]:
import numpy as np
import torch
import torch.optim as optim
from PIL import Image
import matplotlib.pyplot as plt
from skimage.io import imsave
from utils import compute_gt_gradient, make_canvas_mask, numpy2tensor, laplacian_filter_tensor, \
                  MeanShift, gram_matrix
from torchvision import models
from collections import namedtuple

ModuleNotFoundError: No module named 'cv2'

In [None]:
source_file = "images/5_source.png"
mask_file = "images/5_source.png"
target_file = "images/5_source.png"

In [4]:
x_start = 250
y_start = 150

In [5]:
num_steps = 1000
ss = 300 # source image size
ts = 512 # target image size

In [6]:
grad_weight = 1e4
style_weight = 1e4
content_weight = 1
tv_weight = 1e-6

In [7]:
source_img = np.array(Image.open(source_file).convert('RGB').resize((ss, ss)))
target_img = np.array(Image.open(target_file).convert('RGB').resize((ts, ts)))
mask_img = np.array(Image.open(mask_file).convert('L').resize((ss, ss)))
mask_img[mask_img>0] = 1

In [8]:
canvas_mask = make_canvas_mask(x_start, y_start, target_img, mask_img)
canvas_mask = numpy2tensor(canvas_mask)
canvas_mask = canvas_mask.squeeze(0).repeat(3,1).view(3,ts,ts).unsqueeze(0)

In [9]:
gt_gradient = compute_gt_gradient(x_start, y_start, source_img, target_img, mask_img)

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [11]:
source_img = torch.from_numpy(source_img).unsqueeze(0).transpose(1,3).transpose(2,3).float().to(device)
target_img = torch.from_numpy(target_img).unsqueeze(0).transpose(1,3).transpose(2,3).float().to(device)
input_img = torch.randn(target_img.shape).to(device)

In [12]:
mask_img = numpy2tensor(mask_img)
mask_img = mask_img.squeeze(0).repeat(3,1).view(3,ss,ss).unsqueeze(0)

In [13]:
def get_input_optimizer(input_img):
    optimizer = optim.LBFGS([input_img.requires_grad_()])
    return optimizer
optimizer = get_input_optimizer(input_img)

In [14]:
mse = torch.nn.MSELoss()

In [15]:
mean_shift = MeanShift()

In [16]:
class Vgg16(torch.nn.Module):
    def __init__(self, requires_grad=False):
        super(Vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'])
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3)
        return out

In [17]:
vgg = Vgg16().to(device)

In [18]:
run = [0]

In [18]:
def closure1():
    blend_img = torch.zeros(target_img.shape).to(device)
    blend_img = input_img*canvas_mask + target_img*(canvas_mask-1)*(-1) 
    
    pred_gradient = laplacian_filter_tensor(blend_img)
    
    grad_loss = 0
    for c in range(len(pred_gradient)):
        grad_loss += mse(pred_gradient[c], gt_gradient[c])
    grad_loss /= len(pred_gradient)
    grad_loss *= grad_weight
    
    target_features_style = vgg(mean_shift(target_img))
    target_gram_style = [gram_matrix(y) for y in target_features_style]

    blend_features_style = vgg(mean_shift(input_img))
    blend_gram_style = [gram_matrix(y) for y in blend_features_style]

    style_loss = 0
    for layer in range(len(blend_gram_style)):
        style_loss += mse(blend_gram_style[layer], target_gram_style[layer])
    style_loss /= len(blend_gram_style)  
    style_loss *= style_weight    
    
    blend_obj = blend_img[:,:,int(x_start-source_img.shape[2]*0.5):int(x_start+source_img.shape[2]*0.5), int(y_start-source_img.shape[3]*0.5):int(y_start+source_img.shape[3]*0.5)]
    source_object_features = vgg(mean_shift(source_img*mask_img))
    blend_object_features = vgg(mean_shift(blend_obj*mask_img))
    content_loss = content_weight * mse(blend_object_features.relu2_2, source_object_features.relu2_2)
    content_loss *= content_weight

    # Compute TV Reg Loss
    tv_loss = torch.sum(torch.abs(blend_img[:, :, :, :-1] - blend_img[:, :, :, 1:])) + \
               torch.sum(torch.abs(blend_img[:, :, :-1, :] - blend_img[:, :, 1:, :]))
    tv_loss *= tv_weight

    # Compute Total Loss and Update Image
    loss = grad_loss + style_loss + content_loss + tv_loss
    optimizer.zero_grad()
    loss.backward()
    
    if run[0] % 50 == 0:
        print("run {}:".format(run))
        print('grad : {:4f}, style : {:4f}, content: {:4f}, tv: {:4f}'.format(\
                      grad_loss.item(), \
                      style_loss.item(), \
                      content_loss.item(), \
                      tv_loss.item()
                      ))
        print()

    run[0] += 1
    return loss

In [19]:
%%time
while run[0] <= num_steps:
    optimizer.step(closure1)

  This is separate from the ipykernel package so we can avoid doing imports until
  This is separate from the ipykernel package so we can avoid doing imports until


run [0]:
grad : 8817536.000000, style : 564255360.000000, content: 67266.132812, tv: 3.884570

run [50]:
grad : 6504623.500000, style : 230097008.000000, content: 72144.539062, tv: 12.650730



KeyboardInterrupt: 

In [30]:
input_img.data.clamp_(0, 255)

tensor([[[[112.8883,  69.5797, 123.5955,  ...,  88.1446,  47.8517,  87.4343],
          [ 94.7460,  84.2942, 100.5765,  ..., 107.0337,   0.0000,  24.1371],
          [146.3399, 135.2570, 115.0161,  ..., 236.9713,  32.9627,  26.1542],
          ...,
          [ 39.9588,  41.4356,  63.0307,  ..., 255.0000, 255.0000, 229.9171],
          [ 44.2169,  25.0721,  62.2607,  ..., 255.0000, 199.6423, 150.1129],
          [ 46.2332,  40.4184,  37.2825,  ..., 237.4586, 172.6453, 168.8631]],

         [[118.1725,  68.1907,  93.9025,  ...,  51.9844,  56.8669,  13.3262],
          [137.9009, 119.0598,  99.5933,  ..., 111.5192, 156.4389,  64.9105],
          [127.1557,  76.0841,  79.3315,  ..., 132.2823, 172.9769,   8.6906],
          ...,
          [ 14.4702,  45.2283,  45.9954,  ..., 189.7735, 252.7822,  90.2092],
          [ 25.5272,  17.5611,  35.3226,  ..., 194.7985, 255.0000, 165.8821],
          [ 45.1715,  48.2228,  30.7250,  ..., 202.2029, 255.0000, 231.2118]],

         [[ 75.0083,  60.0057,

In [31]:
blend_img = torch.zeros(target_img.shape).to(device)
blend_img = input_img*canvas_mask + target_img*(canvas_mask-1)*(-1) 
blend_img_np = blend_img.transpose(1,3).transpose(1,2).cpu().data.numpy()[0]

# Save image from the first pass
name = source_file.split('/')[1].split('_')[0]
imsave('results/'+str(name)+'_first_pass.png', blend_img_np.astype(np.uint8))