In [13]:
# Packages
import numpy as np
import torch
import torch.optim as optim
from PIL import Image
from skimage.io import imsave
from torchvision.utils import save_image
from utils import compute_gt_gradient, make_canvas_mask, numpy2tensor, laplacian_filter_tensor, \
                  MeanShift, Vgg16, gram_matrix
import argparse
import pdb
import os
import imageio.v2 as iio
import torch.nn.functional as F

In [14]:

###################################
########### First Pass ###########
###################################

# ---------------------------------------------------------------------------
# Configuration: everything a reader might want to tune lives in this cell.
# ---------------------------------------------------------------------------

# Inputs
source_file = "data/rsz_ad_jotform.png"
mask_file = "data/rsz_ad_mask_jotform.png"
target_file = "data/avatar_700_bg.png"

# Outputs
output_dir = "results/1/"

# Hyperparameter Inputs
gpu_id = 1         # only referenced by legacy (commented-out) CUDA code paths
num_steps = 1000   # upper bound on closure evaluations per pass (was set twice; 1000 was the effective value)
ts = 1024          # side length the target image is resized to
ss = 50            # side length the source image and its mask are resized to

# blending operation
x_start = 512
y_start = 512

# Pick the best available backend instead of hard-coding 'mps', so the notebook
# also runs on machines without Apple-silicon GPU support.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

save_video = False

# Default weights for loss functions in the first pass
grad_weight = 1e4
style_weight = 1e4
content_weight = 1
tv_weight = 1e-6

In [15]:
# Create the results folder up front; an already existing folder is not an error.
os.makedirs(name=output_dir, exist_ok=True)

In [16]:
# ---- Read the inputs as numpy arrays ---------------------------------------
# The source object and its mask share the (ss, ss) grid; the target is (ts, ts).
source_img = np.array(Image.open(source_file).convert('RGB').resize((ss, ss)))
target_img = np.array(Image.open(target_file).convert('RGB').resize((ts, ts)))
mask_img = np.array(Image.open(mask_file).convert('L').resize((ss, ss)))
# Binarise in place: every non-zero mask pixel becomes 1.
np.putmask(mask_img, mask_img > 0, 1)

# ---- Canvas-sized mask -----------------------------------------------------
# make_canvas_mask / numpy2tensor come from utils; the reshape below replicates
# the single-channel mask to a (1, 3, ts, ts) tensor.
canvas_mask = numpy2tensor(
    make_canvas_mask(x_start, y_start, target_img, mask_img), device
)
canvas_mask = canvas_mask.squeeze(0).repeat(3, 1).view(3, ts, ts).unsqueeze(0)

# ---- Ground-truth gradients ------------------------------------------------
# Must run while source/target/mask are still numpy arrays.
gt_gradient = compute_gt_gradient(x_start, y_start, source_img, target_img, mask_img, device)


def _hwc_to_nchw(image_np):
    """Turn an (H, W, C) numpy image into a float (1, C, H, W) tensor on `device`."""
    return torch.from_numpy(image_np).unsqueeze(0).permute(0, 3, 1, 2).float().to(device)


# ---- numpy -> torch --------------------------------------------------------
source_img = _hwc_to_nchw(source_img)
target_img = _hwc_to_nchw(target_img)
# The image being optimised starts from Gaussian noise of the target's shape.
input_img = torch.randn(target_img.shape).to(device)

# Source-sized mask replicated to (1, 3, ss, ss).
mask_img = numpy2tensor(mask_img, device).squeeze(0).repeat(3, 1).view(3, ss, ss).unsqueeze(0)

# Define LBFGS optimizer
def get_input_optimizer(input_img):
    """Build an L-BFGS optimizer whose only parameter is the given image tensor.

    The tensor is flagged as requiring gradients in place, so the caller's
    tensor itself is what gets optimised.
    """
    trainable_image = input_img.requires_grad_(True)
    return optim.LBFGS([trainable_image])
# L-BFGS over the pixels of the noise-initialised image.
optimizer = get_input_optimizer(input_img)

# Mean-squared error is the distance used by every loss term below.
mse = torch.nn.MSELoss()

# VGG16 feature extractor (plus its input normalisation) used for the style
# and content losses; both helpers come from utils.
mean_shift = MeanShift(device)
vgg = Vgg16().to(device)

# Optionally record the reconstruction process as a video.
if save_video:
    video_path = os.path.join(output_dir, 'recon_process.mp4')
    recon_process_video = iio.get_writer(video_path, format='FFMPEG', mode='I', fps=400)

# ---------------------------------------------------------------------------
# First-pass optimisation.
#
# L-BFGS may evaluate the closure several times per optimizer.step(), so
# everything that does not depend on input_img is computed once, up front,
# instead of on every closure call.
# ---------------------------------------------------------------------------

# Background part of the composite: target pixels outside the pasted region.
inverse_canvas_mask = (canvas_mask - 1) * (-1)
background_img = target_img * inverse_canvas_mask

# Window of the canvas that the source object occupies (centred on x_start, y_start).
half_rows = source_img.shape[2] * 0.5
half_cols = source_img.shape[3] * 0.5
obj_rows = slice(int(x_start - half_rows), int(x_start + half_rows))
obj_cols = slice(int(y_start - half_cols), int(y_start + half_cols))

# Fixed optimisation targets; no_grad keeps them out of the autograd graph so
# they can be reused safely across backward passes.
with torch.no_grad():
    first_pass_target_gram = [gram_matrix(y) for y in vgg(mean_shift(target_img))]
    source_object_relu2_2 = vgg(mean_shift(source_img * mask_img)).relu2_2

# run is a one-element list so the closure can mutate the counter.
run = [0]
while run[0] <= num_steps:

    def closure():
        """Evaluate the first-pass loss, back-propagate, and return the loss."""
        # Composite foreground and background to make the blended image
        blend_img = input_img * canvas_mask + background_img

        # Gradient loss: Laplacian of the blend vs. the ground-truth gradient
        pred_gradient = laplacian_filter_tensor(blend_img, device)
        grad_loss = 0
        for c in range(len(pred_gradient)):
            grad_loss += mse(pred_gradient[c], gt_gradient[c])
        grad_loss /= len(pred_gradient)
        grad_loss *= grad_weight

        # Style loss: Gram matrices of the optimised image vs. the target
        blend_gram_style = [gram_matrix(y) for y in vgg(mean_shift(input_img))]
        style_loss = 0
        for layer in range(len(blend_gram_style)):
            style_loss += mse(blend_gram_style[layer], first_pass_target_gram[layer])
        style_loss /= len(blend_gram_style)
        style_loss *= style_weight

        # Content loss on the pasted object only.
        # content_weight is applied exactly once (it used to be multiplied in twice).
        blend_obj = blend_img[:, :, obj_rows, obj_cols]
        blend_object_features = vgg(mean_shift(blend_obj * mask_img))
        content_loss = content_weight * mse(blend_object_features.relu2_2, source_object_relu2_2)

        # Total-variation regulariser
        tv_loss = torch.sum(torch.abs(blend_img[:, :, :, :-1] - blend_img[:, :, :, 1:])) + \
                  torch.sum(torch.abs(blend_img[:, :, :-1, :] - blend_img[:, :, 1:, :]))
        tv_loss *= tv_weight

        # Total loss and gradients
        loss = grad_loss + style_loss + content_loss + tv_loss
        optimizer.zero_grad()
        loss.backward()

        # Write the current state to the reconstruction video
        if save_video:
            with torch.no_grad():
                foreground = input_img * canvas_mask
                foreground = (foreground - foreground.min()) / (foreground.max() - foreground.min())
                final_blend_img = foreground + background_img / 255.0
                frame = final_blend_img[0].permute(1, 2, 0).cpu().numpy()
            # more frames for early optimization by repeatedly appending the frame
            repeats = 10 if run[0] < 200 else 1
            for _ in range(repeats):
                recon_process_video.append_data(frame)

        # Print loss (note: formats the one-element list, e.g. "run [0]:")
        print("run {}:".format(run))
        print('grad : {:4f}, style : {:4f}, content: {:4f}, tv: {:4f}'.format(
            grad_loss.item(),
            style_loss.item(),
            content_loss.item(),
            tv_loss.item(),
        ))
        print()

        run[0] += 1
        return loss

    optimizer.step(closure)

# Bring the optimised pixels back into the valid 0..255 range (in place).
with torch.no_grad():
    input_img.clamp_(0, 255)

# Final first-pass composite: optimised pixels inside the canvas mask,
# untouched target pixels everywhere else.
outside_mask = (canvas_mask - 1) * (-1)
blend_img = input_img * canvas_mask + target_img * outside_mask

# (1, C, H, W) tensor -> (H, W, C) numpy image
blend_img_np = blend_img.permute(0, 2, 3, 1).detach().cpu().numpy()[0]

# Persist the first-pass result; the second pass reloads it from this file.
first_pass_img_file = os.path.join(output_dir, 'first_pass.png')
imsave(first_pass_img_file, blend_img_np.astype(np.uint8))

###################################
########### Second Pass ###########
###################################

# Loss weights for the second pass (the style term dominates here).
style_weight = 1e7
content_weight = 1
tv_weight = 1e-6

# The second pass works on the whole canvas, so the "source" size becomes the
# target size. num_steps is reused unchanged from the first pass.
ss = ts

# Reload the first-pass result from disk and the original target image.
first_pass_img = np.array(Image.open(first_pass_img_file).convert('RGB').resize((ss, ss)))
target_img = np.array(Image.open(target_file).convert('RGB').resize((ts, ts)))

# (H, W, C) numpy -> contiguous float (1, C, H, W) tensors on the device.
first_pass_img = (
    torch.from_numpy(first_pass_img).unsqueeze(0).permute(0, 3, 1, 2).float().to(device).contiguous()
)
target_img = (
    torch.from_numpy(target_img).unsqueeze(0).permute(0, 3, 1, 2).float().to(device).contiguous()
)

# Define LBFGS optimizer
# get_input_optimizer is already defined for the first pass earlier in this
# cell with an identical body; redefining it here only shadowed that
# definition, so the existing helper is reused. It marks first_pass_img as
# requiring gradients and wraps it in an L-BFGS optimizer.
optimizer = get_input_optimizer(first_pass_img)

print('Optimizing...')

# ---------------------------------------------------------------------------
# Second-pass optimisation: refine the first-pass result towards the target
# style while staying close, in VGG relu2_2 feature space, to the first-pass
# result itself.
#
# Both targets are fixed, so they are computed once here under no_grad.
# Previously the content features were recomputed from the *current*
# first_pass_img inside the closure, which compared the image with itself and
# made the content loss identically zero.
# ---------------------------------------------------------------------------
with torch.no_grad():
    second_pass_target_gram = [gram_matrix(y) for y in vgg(mean_shift(target_img))]
    # Snapshot of the first-pass content, taken before any optimisation step.
    first_pass_content_relu2_2 = vgg(mean_shift(first_pass_img)).relu2_2

# run is a one-element list so the closure can mutate the counter.
run = [0]
while run[0] <= num_steps:

    def closure():
        """Evaluate the second-pass loss, back-propagate, and return the loss."""
        # One VGG forward pass serves both the style and the content term.
        blend_features = vgg(mean_shift(first_pass_img))

        # Compute Style Loss
        blend_gram_style = [gram_matrix(y) for y in blend_features]
        style_loss = 0
        for layer in range(len(blend_gram_style)):
            style_loss += mse(blend_gram_style[layer], second_pass_target_gram[layer])
        style_loss /= len(blend_gram_style)
        style_loss *= style_weight

        # Compute Content Loss against the frozen first-pass snapshot
        content_loss = content_weight * mse(blend_features.relu2_2, first_pass_content_relu2_2)

        # Compute Total Loss and Update Image
        loss = style_loss + content_loss
        optimizer.zero_grad()
        loss.backward()

        # Write the current state to the reconstruction video
        if save_video:
            with torch.no_grad():
                foreground = first_pass_img * canvas_mask
                foreground = (foreground - foreground.min()) / (foreground.max() - foreground.min())
                background = target_img * (canvas_mask - 1) * (-1)
                background = background / 255.0
                final_blend_img = foreground + background
                frame = final_blend_img[0].permute(1, 2, 0).cpu().numpy()
            recon_process_video.append_data(frame)

        # Print loss (note: formats the one-element list, e.g. "run [0]:")
        print("run {}:".format(run))
        print(' style : {:4f}, content: {:4f}'.format(
            style_loss.item(),
            content_loss.item(),
        ))
        print()

        run[0] += 1
        return loss

    optimizer.step(closure)

# Bring the refined pixels back into the valid 0..255 range (in place).
with torch.no_grad():
    first_pass_img.clamp_(0, 255)

# (1, C, H, W) tensor -> (H, W, C) numpy image
input_img_np = first_pass_img.permute(0, 2, 3, 1).detach().cpu().numpy()[0]

# Write the second-pass result next to the first-pass image.
second_pass_img_file = os.path.join(output_dir, 'second_pass.png')
imsave(second_pass_img_file, input_img_np.astype(np.uint8))

# Finalise the reconstruction video, if one was being recorded.
if save_video:
    recon_process_video.close()

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


run [0]:
grad : 278022.000000, style : 15176670.000000, content: 73046.148438, tv: 13.910302

run [1]:
grad : 278020.531250, style : 15176670.000000, content: 73046.109375, tv: 13.910302

run [2]:
grad : 170658.015625, style : 15157843.000000, content: 58511.398438, tv: 14.031243

run [3]:
grad : 86077.148438, style : 15151732.000000, content: 51578.355469, tv: 14.004866

run [4]:
grad : 65655.554688, style : 15129552.000000, content: 43953.585938, tv: 14.029720

run [5]:
grad : 82639.726562, style : 15046895.000000, content: 41416.332031, tv: 14.091745

run [6]:
grad : 225467.171875, style : 14715626.000000, content: 87867.062500, tv: 14.240730

run [7]:
grad : 1958781.000000, style : 12265021.000000, content: 869668.125000, tv: 15.011994

run [8]:
grad : 614841.437500, style : 13593513.000000, content: 278864.593750, tv: 14.537506

run [9]:
grad : 650040.000000, style : 12819596.000000, content: 326062.843750, tv: 14.599748

run [10]:
grad : 7779883.000000, style : 71872600.000000, c

KeyboardInterrupt: 