In [None]:
%env CUDA_VISIBLE_DEVICES 0
import math
import time

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

In [None]:
def np2pt(arr):
    """Turn a 3-D array into a batched, channel-first tensor.

    The last axis of ``arr`` is moved to the front and a leading axis of
    size 1 is added, so an input of shape (a, b, c) becomes a contiguous
    tensor of shape (1, c, a, b). Presumably ``arr`` is an (H, W, C) image.
    """
    channels_last = torch.tensor(arr)
    channels_first = channels_last.permute(2, 0, 1)
    return channels_first[None].contiguous()

def pt2np(tensor):
    """Inverse layout change of ``np2pt``, followed by ``to_numpy``.

    Drops the leading axis if it has size 1, moves the first remaining axis
    to the end, and hands the result to ``to_numpy`` (defined outside this
    cell).
    """
    unbatched = tensor.squeeze(0)
    channels_last = unbatched.permute(1, 2, 0)
    return to_numpy(channels_last)

def imread2pt(path, d=None, **kwargs):
    """Read an image from ``path`` and return it as a contiguous tensor.

    The file is loaded with ``imread(path, **kwargs)`` and converted with
    ``np2pt``. When ``d`` is given, the tensor is additionally passed
    through ``shave_image(image, d)`` (defined outside this cell;
    presumably it trims ``d`` pixels from the borders - confirm).
    """
    image = np2pt(imread(path, **kwargs))
    if d is None:
        return image.contiguous()
    return shave_image(image, d).contiguous()


In [None]:
import resize_right as Resizer

def get_pyramids_aux(img_torch, num_layers, ratio, print_):
    """Build a multi-scale pyramid by repeatedly resizing ``img_torch``.

    Args:
        img_torch: tensor whose shape is read at indices 1, 2 and 3 (i.e. at
            least 4-D when ``num_layers > 0``). It becomes level 0.
        num_layers: number of resized levels appended after the input.
        ratio: per-level scale factor. A bare ``int``/``float`` is applied
            to both resized axes; otherwise it must be indexable, with
            ``ratio[0]`` scaling ``shape[2]`` and ``ratio[1]`` scaling
            ``shape[3]``.
        print_: when truthy, print the shape of and ``imshow`` each level
            just before it is resized (so the last level is never shown).

    Returns:
        A list of ``num_layers + 1`` tensors, starting with ``img_torch``.
    """
    if isinstance(ratio, (int, float)):
        ratio = (ratio, ratio)

    pyr = [img_torch]
    curr_img = img_torch
    for level in range(1, num_layers + 1):
        if print_:
            print(curr_img.shape)
            imshow(curr_img)
        # The target size is derived from the ORIGINAL image and the level
        # index, not from the previous level, so ceil() rounding does not
        # accumulate from level to level. The batch entry is fixed to 1.
        out_shape = [
            1,
            img_torch.shape[1],
            np.ceil(img_torch.shape[2] * ratio[0] ** level),
            np.ceil(img_torch.shape[3] * ratio[1] ** level),
        ]
        rsizer = Resizer.Resizer(curr_img.shape, ratio, out_shape).cuda()
        curr_img = rsizer(curr_img)
        pyr.append(curr_img)

    return pyr

def get_pyramid(file, degradation, sigma, num_layers, ratio, print_=True):
    """Load an image file onto the GPU and build its pyramid.

    The array returned by ``imread(file)`` gets a leading axis of size 1,
    is moved from axis order (0, 1, 2, 3) to (0, 3, 1, 2), and is sent to
    CUDA before being handed to ``get_pyramids_aux``.

    ``degradation`` and ``sigma`` are accepted but not used by this
    function.
    """
    batched = np.expand_dims(imread(file), 0)
    as_tensor = torch.tensor(batched).contiguous()
    gt = as_tensor.permute(0, 3, 1, 2).detach().cuda()
    return get_pyramids_aux(gt, num_layers, ratio, print_)

In [None]:
# Number of chunks build_image splits the input patches into when the full
# patch-to-patch distance matrix exceeds its size threshold.
DIV = 100

def patch2im(input, w, h, patch_size=7):
    """Reassemble overlapping patches into an image, averaging the overlaps.

    Args:
        input: patch columns in the layout ``torch.nn.functional.fold``
            accepts (for example the output of ``unfold`` with the same
            ``patch_size``).
        w: first entry of fold's ``output_size``. NOTE: despite the name,
            ``build_image`` passes ``input_img.shape[-2]`` here.
        h: second entry of fold's ``output_size``; ``build_image`` passes
            ``input_img.shape[-1]``.
        patch_size: side length of the square patches (fold's default
            stride, padding and dilation are used).

    Returns:
        The folded (summed) image divided element-wise by the number of
        patches that cover each position.
    """
    fold_kwargs = dict(output_size=(w, h), kernel_size=(patch_size, patch_size))

    summed = torch.nn.functional.fold(input, **fold_kwargs)
    # Folding all-ones patches yields, per position, how many patches
    # contributed to the sum above.
    coverage = torch.nn.functional.fold(torch.ones_like(input), **fold_kwargs)

    return summed / coverage


def _calc_dist_l2(X, Y):
    Y = Y.transpose(0, 1)
    X2 = X.pow(2).sum(1, keepdim=True)
    Y2 = Y.pow(2).sum(0, keepdim=True)
    XY = X @ Y
    return X2 - (2 * XY) + Y2


def build_image(input_img, index_imgs, ref_imgs, patch_size=7):
    """Rebuild ``input_img`` out of patches copied from ``ref_imgs``.

    Every ``patch_size`` x ``patch_size`` patch of ``input_img`` is matched
    to its nearest neighbour (squared L2) among the patches of
    ``index_imgs``. The patch with the same index in ``ref_imgs`` is then
    pasted in its place, and overlapping patches are averaged by
    ``patch2im``.

    Args:
        input_img: image tensor accepted by ``torch.nn.Unfold``. A leading
            axis of size 1 is assumed (the patch tensors are ``squeeze(0)``-ed).
        index_imgs: sequence of images whose patches are searched.
        ref_imgs: sequence of images paired with ``index_imgs`` via ``zip``.
            Match indices found in the index patches are used to look up
            the reference patches, so the two are expected to be
            patch-aligned.
        patch_size: side length of the square patches.

    Returns:
        An image with the last two dimensions of ``input_img``.

    Raises:
        ValueError: if no (index, reference) pair is supplied.
    """
    pairs = list(zip(index_imgs, ref_imgs))
    if not pairs:
        raise ValueError("build_image needs at least one (index, reference) image pair")

    unfold = torch.nn.Unfold(kernel_size=(patch_size, patch_size))
    in_patches = unfold(input_img)

    # Collect the pieces first and concatenate once (no copy for a single pair).
    index_cols = [unfold(ind_img) for ind_img, _ in pairs]
    ref_cols = [unfold(ref_img) for _, ref_img in pairs]
    if len(pairs) == 1:
        index_patches, ref_patches = index_cols[0], ref_cols[0]
    else:
        index_patches = torch.cat(index_cols, dim=-1)
        ref_patches = torch.cat(ref_cols, dim=-1)

    queries = in_patches.permute(0, 2, 1).squeeze(0)   # one row per input patch
    keys = index_patches.squeeze(0).permute(1, 0)      # one row per index patch

    # Above this many (input patch, reference patch) pairs the distance
    # matrix is computed in slices of the input patches instead of at once.
    max_pairs = 14000 ** 2
    num_in = in_patches.shape[-1]
    if num_in * ref_patches.shape[-1] > max_pairs:
        chunk = math.ceil(num_in / DIV)
        # Each row's argmin is independent of the others, so the per-chunk
        # results can simply be concatenated in order.
        ind = torch.cat([
            _calc_dist_l2(queries[start:start + chunk], keys).argmin(1)
            for start in range(0, num_in, chunk)
        ])
    else:
        ind = _calc_dist_l2(queries, keys).argmin(1)

    # Row lookup: out_patches[i] is the reference patch matched to input patch i.
    out_patches = F.embedding(ind, ref_patches.squeeze(0).permute(1, 0))
    return patch2im(
        out_patches.unsqueeze(0).permute(0, 2, 1),
        input_img.shape[-2],
        input_img.shape[-1],
        patch_size,
    )

In [None]:
# Image Generation

def _coarse_to_fine(new_im, src_pyr, dst_pyr, top_level, patch_size, num_iters):
    """Shared coarse-to-fine patch replacement loop.

    Walks from ``top_level`` down to level 0. At ``top_level`` the current
    image is rebuilt once from ``src_pyr[top_level]`` matched against
    itself. At every finer level it is rebuilt ``num_iters`` times, matching
    against ``src_pyr[l + 1]`` resized to the shape of ``src_pyr[l]`` and
    pasting patches from ``src_pyr[l]``. Between levels the result is
    resized from the shape of ``dst_pyr[l]`` to that of ``dst_pyr[l - 1]``.
    """
    for level in range(top_level, -1, -1):
        if level == top_level:
            index_img, passes = src_pyr[level], 1
        else:
            # The resized coarser source level does not change between
            # passes, so build it once per level rather than once per pass.
            upsample = Resizer.Resizer(src_pyr[level + 1].shape, 4/3, src_pyr[level].shape).cuda()
            index_img, passes = upsample(src_pyr[level + 1]), num_iters
        for _ in range(passes):
            new_im = build_image(new_im, [index_img], [src_pyr[level]], patch_size=patch_size)
        if level > 0:
            to_finer = Resizer.Resizer(dst_pyr[level].shape, 4/3, dst_pyr[level - 1].shape).cuda()
            new_im = to_finer(new_im)
    return new_im


def _show_and_report(new_im, start):
    """Display the result with ``imshow`` and print the elapsed wall time."""
    imshow(new_im)
    print('Total time: %.2f[s]' % (time.time() - start,))
    return new_im


def new_image_generation(src_pyr, dst_pyr, patch_size=7, top_level=9, noise_std=0.75, num_iters=10):
    """Generate a new image from the patches of a source pyramid.

    Starts from ``dst_pyr[top_level]`` plus Gaussian noise and refines it
    coarse-to-fine using patches of ``src_pyr``.

    Args:
        src_pyr: source pyramid (list of tensors, finest level first).
        dst_pyr: pyramid that fixes the starting image and the shape of
            every output level.
        patch_size: side of the square patches handed to ``build_image``.
        top_level: coarsest pyramid level to start from.
        noise_std: multiplier applied to the standard-normal noise added at
            the top level.
        num_iters: rebuild passes per level below ``top_level`` (the top
            level always gets exactly one pass).

    Returns:
        The generated image at the shape of ``dst_pyr[0]``.
    """
    start = time.time()
    coarsest = dst_pyr[top_level]
    new_im = coarsest + torch.randn_like(coarsest) * noise_std
    new_im = _coarse_to_fine(new_im, src_pyr, dst_pyr, top_level, patch_size, num_iters)
    return _show_and_report(new_im, start)


def harmonization(src_pyr, dst_pyr, patch_size=7, top_level=1):
    """Rebuild ``dst_pyr[top_level]`` from the patches of ``src_pyr``.

    Same coarse-to-fine procedure as ``new_image_generation`` but without
    added noise and with a single rebuild pass on every level.

    Args:
        src_pyr: source pyramid providing the patches.
        dst_pyr: pyramid of the image to be rebuilt.
        patch_size: side of the square patches handed to ``build_image``.
        top_level: coarsest pyramid level to start from.

    Returns:
        The rebuilt image at the shape of ``dst_pyr[0]``.
    """
    start = time.time()
    new_im = _coarse_to_fine(dst_pyr[top_level], src_pyr, dst_pyr, top_level, patch_size, 1)
    return _show_and_report(new_im, start)


def edit(src_pyr, dst_pyr, mask_size, patch_size=7, top_level_max=9, num_iters=10):
    """Rebuild an edited image, starting at a level chosen from ``mask_size``.

    The start level is the largest ``L`` with
    ``mask_size * (3/4) ** L >= patch_size`` (this hard-codes a 3/4 ratio
    between pyramid levels), capped at ``top_level_max`` and never below 0.

    Args:
        src_pyr: source pyramid providing the patches.
        dst_pyr: pyramid of the edited image.
        mask_size: size of the edited region (presumably in pixels at full
            resolution - confirm against the caller).
        patch_size: side of the square patches; used both for the level
            choice and for ``build_image``.
        top_level_max: upper bound on the start level.
        num_iters: rebuild passes per level below the start level.

    Returns:
        The rebuilt image at the shape of ``dst_pyr[0]``.
    """
    start = time.time()
    top_level = int(np.floor(np.log(patch_size / mask_size) / np.log(3/4)))
    # Clamp at 0: when mask_size < patch_size the formula goes negative, which
    # would otherwise index the pyramid from its end and skip the loop.
    top_level = max(0, min(top_level_max, top_level))
    new_im = _coarse_to_fine(dst_pyr[top_level], src_pyr, dst_pyr, top_level, patch_size, num_iters)
    return _show_and_report(new_im, start)


In [None]:
# Mountains3 comparison.
# For every scale pair in the list, resize the finest source level `pyr[0]`,
# build a 3/4-ratio target pyramid from it, and generate two samples.
# NOTE(review): `pyr` is not assigned in any earlier visible cell, so this cell
# relies on a pyramid loaded elsewhere - confirm before Restart & Run All.
# NOTE(review): the fourth ratio is [4/5,1] here but [5/4,1] in the Balloons
# cell - confirm the difference is intentional.
for ratio in ([[1,1],[1,1/3],[1,5/4],[4/5,1],[1,2]]):
    rsizer = Resizer.Resizer(pyr[0].shape, ratio).cuda()
    dst_pyr = get_pyramids_aux(rsizer(pyr[0]), num_layers=15, ratio=3/4, print_=False)
    for k in range(2):
        im = new_image_generation(pyr, dst_pyr)
        # Disabled: save each sample to disk.
#         imwrite('/home/nivg/data/DropGAN/image_generation/' + IM_NAME + '_' + str(ratio[0]) + '_' + str(ratio[1]) +  '_im' + str(k) + '.png', im.squeeze(0).detach().cpu())

In [None]:
# Balloons comparison.
# Same procedure as the Mountains3 cell: for every scale pair, resize `pyr[0]`,
# build a 3/4-ratio target pyramid from it, and generate two samples.
# NOTE(review): `pyr` must already hold the Balloons pyramid; no visible cell
# before this one assigns it - confirm before Restart & Run All.
for ratio in ([[1,1],[1,1/3],[1,5/4],[5/4,1],[1,2]]):
    rsizer = Resizer.Resizer(pyr[0].shape, ratio).cuda()
    dst_pyr = get_pyramids_aux(rsizer(pyr[0]), num_layers=15, ratio=3/4, print_=False)
    for k in range(2):
        im = new_image_generation(pyr, dst_pyr)
        # Disabled: save each sample to disk.
#         imwrite('/home/nivg/data/DropGAN/image_generation/' + IM_NAME + '_' + str(ratio[0]) + '_' + str(ratio[1]) +  '_im' + str(k) + '.png', im.squeeze(0).detach().cpu())

In [None]:
# Inputs for the harmonization run below.
# NOTE(review): these are absolute, machine-specific paths - confirm they exist
# on the machine running the notebook.
IM_PATH = '/home/nivg/data/seascape.png'
HARM_IM_PATH = '/home/nivg/data/seascape_naive.jpg'
IM_NAME = 'seascape'
# Each pyramid holds the loaded image plus 15 resized levels (ratio 0.75 per
# level); print_=True prints and shows levels while they are built. The `None`
# and `0` arguments fill get_pyramid's `degradation` and `sigma` parameters,
# which that function does not use.
pyr = get_pyramid(IM_PATH, None, 0, 15, 0.75, print_=True) # source pyramid (provides the patches)
pyr_harm = get_pyramid(HARM_IM_PATH, None, 0, 15, 0.75, print_=True) # pyramid of the image to harmonize

In [None]:
# Harmonize using only the finest level: with top_level=0 the loop in
# harmonization runs for level 0 alone, i.e. one build_image pass and no resize.
im = harmonization(pyr, pyr_harm, top_level=0)
# Disabled: save the harmonized result to disk.
# imwrite('/home/nivg/data/DropGAN/image_harmonization/' + IM_NAME + '_harmonized.png', im.squeeze(0).detach().cpu())

In [None]:
# Editing run. With mask_size=100 and patch_size=7 the start level computed
# inside edit is floor(log(7/100) / log(3/4)) = 9, which equals top_level_max.
# NOTE(review): `pyr_edit` is not defined in any visible cell - it must be built
# (e.g. with get_pyramid on the edited image) before this cell can run.
im = edit(pyr, pyr_edit, mask_size=100, patch_size=7, top_level_max=9)
# Disabled: save the edited result to disk.
# imwrite('/home/nivg/data/DropGAN/image_editing/' + IM_NAME + '_edited.png', im.squeeze(0).detach().cpu())