In [None]:
# Inspect current GPU usage before choosing which device this session should use.
!gpustat -cu

In [None]:
# Set CUDA_VISIBLE_DEVICES=1 for this kernel so only that GPU is exposed to CUDA.
# NOTE(review): this has to run before the first CUDA call in the kernel to take effect.
%env CUDA_VISIBLE_DEVICES 1

In [None]:
import os
# import random
import time
from math import ceil

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair
# import faiss
# import faiss.contrib.torch_utils

from image import imread, imwrite, imshow
from fold import fold2d, unfold2d
from nnlookup import nn_lookup2d, nn_lookup_soft2d, l2_dist
from utils import view_as_2d, view_2d_as
from Resizer import Resizer

# from image import to_numpy, np2pt, pt2np
# from nnlookup import nn_lookup, nn_lookup_soft, nn_lookup_soft2d
# from nnlookup import l2_dist, inner_prod_dist
# from utils import view_3d_as_6d, view_6d_as_3d

# Default device for the notebook: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device:', device)

In [None]:
@torch.no_grad()
def get_pyramid(image, depth, ratio, verbose=False):
    """Build a multi-scale pyramid by repeatedly resizing ``image``.

    Args:
        image: 4-D tensor; dimensions 2 and 3 are the ones that get rescaled.
        depth: number of resized levels produced after the original image.
        ratio: per-level scale factor, a scalar or a (dim-2, dim-3) pair.
        verbose: when True, print the shape of and display each level right
            before it is resized (the last level is therefore not shown).

    Returns:
        A list of ``depth + 1`` tensors; element 0 is ``image`` itself and
        element ``k`` is produced by resizing element ``k - 1``.
    """
    scales = _pair(ratio)
    levels = [image]
    for level in range(1, depth + 1):
        finest = levels[-1]
        if verbose:
            print(finest.shape)
            imshow(finest)
        # Target size is derived from the ORIGINAL image size and the
        # accumulated scale, not from the previous level, so per-level
        # rounding does not compound. The leading dimension is fixed to 1.
        target_shape = [
            1,
            image.shape[1],
            ceil(image.shape[2] * scales[0] ** level),
            ceil(image.shape[3] * scales[1] ** level),
        ]
        downscale = Resizer(finest.shape, scales, target_shape).to(device=image.device)
        levels.append(downscale(finest))
    return levels


def get_pyramid_from_file(fname, *args, device=None, **kwargs):
    """Read an image from ``fname`` and build its pyramid with ``get_pyramid``.

    Args:
        fname: path of the image file, passed to ``imread(..., pt=True)``.
        *args: positional arguments forwarded to ``get_pyramid`` after the image.
        device: target passed to ``.to(device=...)`` on the loaded image.
        **kwargs: keyword arguments forwarded to ``get_pyramid``.

    Returns:
        Whatever ``get_pyramid`` returns for the loaded image.
    """
    # The original read the undefined name `file` and called the nonexistent
    # `get_pyramids_aux`, so it could never run.
    image = imread(fname, pt=True).to(device=device)
    return get_pyramid(image, *args, **kwargs)

In [None]:
# NN_DTYPE = torch.float32
# dtype passed to nn_lookup_soft2d by torch_patch_nn; only read when a temperature is set.
SOFTMIN_DTYPE = torch.float32

def extract_kvs(keys, values, patch_size, use_padding):
    """Turn key/value images into two aligned 2-D patch tensors.

    Args:
        keys: a tensor or a list of tensors used as lookup keys.
        values: a tensor or a list of tensors, pairwise the same shape as ``keys``.
        patch_size: patch size forwarded to ``unfold2d``.
        use_padding: padding flag forwarded to ``unfold2d``.

    Returns:
        ``(key, value)``: for each input image the first output of
        ``view_as_2d(unfold2d(...))``, concatenated along dim 0 and made contiguous.

    Raises:
        AssertionError: if the number of images or any paired shapes differ.
    """
    def _as_list(images):
        # A bare tensor is treated as a one-image list.
        return [images] if isinstance(images, torch.Tensor) else images

    def _patch_rows(images):
        rows = []
        for img in images:
            patches = unfold2d(img, patch_size, use_padding=use_padding)
            rows.append(view_as_2d(patches)[0])
        return torch.cat(rows, dim=0).contiguous()

    keys = _as_list(keys)
    values = _as_list(values)
    assert len(keys) == len(values)
    assert all(k.shape == v.shape for k, v in zip(keys, values))
    key = _patch_rows(keys)
    value = _patch_rows(values)
    return key, value


@torch.no_grad()
def torch_patch_weight_min_dist(query, key, dist_fn, alpha, batch_size):
    """Compute a per-key weight from how closely each key row is matched by a query row.

    For every row of ``key`` the row of ``query`` returned by ``nn_lookup2d``
    is selected, and the mean squared difference between the two rows is
    taken. The weight is ``1 / (alpha + that difference)``. Note that the
    difference is always a mean squared one, regardless of ``dist_fn``, which
    is only used for the lookup itself.

    Args:
        query: 2-D tensor of query patches, one per row.
        key: 2-D tensor of key patches, one per row.
        dist_fn: distance function forwarded to ``nn_lookup2d``.
        alpha: additive term in the denominator.
        batch_size: budget used to size the key chunks; each chunk has about
            ``batch_size / query.shape[0]`` rows.

    Returns:
        1-D tensor with one weight per row of ``key``.
    """
    num_queries = query.shape[0]
    chunk = (batch_size + num_queries - 1) // num_queries
    num_chunks = (key.shape[0] + chunk - 1) // chunk

    per_chunk = []
    for c in range(num_chunks):
        key_chunk = key[c * chunk:(c + 1) * chunk]
        # Roles are swapped on purpose: keys look up their match among the queries.
        nearest = nn_lookup2d(key_chunk, query, dist_fn=dist_fn)
        residual = key_chunk - torch.index_select(query, 0, nearest)
        per_chunk.append(residual.pow(2).mean(1))

    return 1 / (alpha + torch.cat(per_chunk, dim=0))


@torch.no_grad()
def torch_patch_nn(query, key, value, patch_size, dist_fn, bidirectional=False, alpha=1., temperature=None, reduce='mean', use_padding=False, batch_size=2**28):
    """Rebuild ``query`` from ``value`` patches chosen by a patch lookup against ``key``.

    Args:
        query: image tensor that is cut into patches with ``unfold2d``.
        key: 2-D tensor of key patches, one per row (see ``extract_kvs``).
        value: 2-D tensor of value patches, row-aligned with ``key``.
        patch_size: patch size forwarded to ``unfold2d``.
        dist_fn: distance function forwarded to the lookup routines.
        bidirectional: when True, per-key weights from
            ``torch_patch_weight_min_dist`` are passed to the lookup.
        alpha: additive term used by ``torch_patch_weight_min_dist``.
        temperature: ``None`` selects ``nn_lookup2d``; any other value selects
            ``nn_lookup_soft2d`` with that temperature and ``SOFTMIN_DTYPE``.
        reduce: reduction mode forwarded to ``fold2d``.
        use_padding: padding flag forwarded to ``unfold2d`` and ``fold2d``.
        batch_size: budget used to size query chunks; each lookup call gets
            about ``batch_size / key.shape[0]`` query rows.

    Returns:
        The output of ``fold2d`` applied to the selected value patches.
    """
    # Cut the query into patches and flatten to one row per patch; ``size`` and
    # ``ndim`` are kept so ``view_2d_as`` can restore the layout afterwards.
    query = unfold2d(query, patch_size, use_padding=use_padding)
    query, size, ndim = view_as_2d(query)

    # Query rows per lookup call (ceil division keeps it at least 1).
    qbatch = (batch_size + key.shape[0] - 1) // key.shape[0]

    if bidirectional:
        weight = torch_patch_weight_min_dist(query, key, dist_fn, alpha, batch_size).view(1, -1)
    else:
        weight = None

    idxs = []
    for i in range((query.shape[0] + qbatch - 1) // qbatch):
        chunk = query[i*qbatch:(i+1)*qbatch]
        if temperature is None:
            idxs.append(nn_lookup2d(chunk, key, weight=weight, dist_fn=dist_fn))
        else:
            idxs.append(nn_lookup_soft2d(chunk, key, weight=weight, dist_fn=dist_fn, temperature=temperature, dtype=SOFTMIN_DTYPE))
    idx = torch.cat(idxs, dim=0)

    # The lookup results are used as row indices into ``value``.
    result = torch.index_select(value, 0, idx)

    # Restore the patch layout and fold the patches back into an image.
    result = view_2d_as(result, size, ndim)
    return fold2d(result, reduce=reduce, use_padding=use_padding)


class TorchPatchNN(nn.Module):
    """Module wrapper that stores the settings for ``torch_patch_nn``.

    Args:
        patch_size: int or pair; normalised to a pair with ``_pair``.
        dist_fn: distance function forwarded to ``torch_patch_nn``.
        bidirectional: forwarded to ``torch_patch_nn``.
        alpha: forwarded to ``torch_patch_nn``.
        temperature: forwarded to ``torch_patch_nn`` (``None`` = hard lookup).
        reduce: forwarded to ``torch_patch_nn``.
        use_padding: forwarded to ``torch_patch_nn`` and ``extract_kvs``.
        batch_size: forwarded to ``torch_patch_nn``.
    """

    def __init__(self, patch_size, dist_fn, bidirectional=False, alpha=1., temperature=None, reduce='mean', use_padding=False, batch_size=2**28):
        super().__init__()
        self._patch_size = _pair(patch_size)
        self._dist_fn = dist_fn
        self._bidirectional = bidirectional
        self._alpha = alpha
        self._temperature = temperature
        self._reduce = reduce
        self._use_padding = use_padding
        self._batch_size = batch_size

    def forward(self, query, key, value):
        """Run ``torch_patch_nn`` on ``query`` with the stored settings."""
        return torch_patch_nn(
            query=query,
            key=key,
            value=value,
            patch_size=self._patch_size,
            dist_fn=self._dist_fn,
            bidirectional=self._bidirectional,
            alpha=self._alpha,
            temperature=self._temperature,
            reduce=self._reduce,
            use_padding=self._use_padding,
            batch_size=self._batch_size)

    def extract_kvs(self, keys, values):
        """Extract key/value patch tensors using this module's patch size and padding."""
        # Inside the method body this name resolves to the module-level function.
        key, value = extract_kvs(keys, values, self._patch_size, use_padding=self._use_padding)
        return key, value

    def __repr__(self):
        # Not every callable has ``__name__`` (e.g. functools.partial objects),
        # so fall back to its repr instead of raising AttributeError.
        dist_name = getattr(self._dist_fn, '__name__', None) or repr(self._dist_fn)
        return '{}(patch_size={}, dist_fn={}, bidirectional={}, alpha={}, temperature={}, reduce="{}", use_padding={}, batch_size={})'.format(
            self.__class__.__name__,
            self._patch_size,
            dist_name,
            self._bidirectional,
            self._alpha,
            self._temperature,
            self._reduce,
            self._use_padding,
            self._batch_size)

In [None]:
@torch.no_grad()
def new_image_generation(pnn, src_pyramid, dst_pyramid, ratio=4/3, noise_std=0.75, noise_decay=None, num_iters=10, top_level=9, verbose=False):
    """Generate a new image coarse-to-fine from the patches of a source pyramid.

    Starting from ``dst_pyramid[top_level]`` plus Gaussian noise, each level
    ``l`` (from ``top_level`` down to 0) replaces the patches of the current
    image via ``pnn``, then resizes the result to the shape of
    ``dst_pyramid[l - 1]``.

    Args:
        pnn: object providing ``extract_kvs(keys=, values=)`` and
            ``pnn(query, key, value)`` (e.g. ``TorchPatchNN``).
        src_pyramid: source levels, index 0 being the first level; needs at
            least ``top_level + 1`` entries.
        dst_pyramid: levels that define the output shapes; needs at least
            ``top_level + 1`` entries.
        ratio: scale factor handed to every ``Resizer``.
        noise_std: standard deviation of the noise added at ``top_level``.
        noise_decay: base of the decay applied to the noise added after each
            resize; ``None`` is treated as ``0.`` (no noise after resizing).
        num_iters: patch-replacement iterations per level below ``top_level``
            (``top_level`` itself always runs exactly one).
        top_level: index of the level to start from.
        verbose: when True, print the total run time and display the result.

    Returns:
        The generated image tensor at the shape of ``dst_pyramid[0]``.
    """
    to = {'device': src_pyramid[0].device, 'dtype': src_pyramid[0].dtype}
    # Taken once and never reset, so the verbose report covers the whole run.
    total_start = time.time()

    new_im = dst_pyramid[top_level]
    new_im = new_im + noise_std * torch.randn_like(new_im)
    if noise_decay is None:
        noise_decay = 0.

    for l in range(top_level, -1, -1):
        curr = src_pyramid[l]
        if l == top_level:
            # No coarser level is needed here: the level serves as both key and
            # value, so nothing is read from src_pyramid[top_level + 1].
            prev = curr
            level_iters = 1
        else:
            # Keys are the next level resized to this level's shape; values
            # are this level itself.
            resizer = Resizer(src_pyramid[l + 1].shape, ratio, curr.shape).to(**to)
            prev = resizer(src_pyramid[l + 1])
            level_iters = num_iters

        key_index, value = pnn.extract_kvs(keys=prev, values=curr)
        for _ in range(level_iters):
            new_im = pnn(new_im, key_index, value)

        if l > 0:
            resizer = Resizer(dst_pyramid[l].shape, ratio, dst_pyramid[l - 1].shape).to(**to)
            new_im = resizer(new_im)
            new_im = new_im + (noise_std * noise_decay ** (top_level + 1 - l)) * torch.randn_like(new_im)

    if verbose:
        print('Total time: %.2f[s]' % (time.time() - total_start,))
        imshow(new_im)

    return new_im

@torch.no_grad()
def structural_analogy(pnn, path_a, path_b, num_iters=10, top_level=9, verbose=False):
    """Re-render image B out of image A's patches, coarse to fine.

    Pyramids are built for both images. Starting from B's full-resolution
    image, two passes are run; each pass rebuilds a pyramid of the current
    result, takes its ``top_level`` entry and, for every level down to 0,
    replaces its patches ``num_iters`` times using A's patches, then resizes
    to the next level's shape. After each level the intermediate result is
    displayed next to the matching level of B.

    Args:
        pnn: object providing ``extract_kvs(keys=, values=)`` and
            ``pnn(query, key, value)`` (e.g. ``TorchPatchNN``).
        path_a: file path of the image that supplies the patches.
        path_b: file path of the image that supplies the structure.
        num_iters: patch-replacement iterations per level.
        top_level: pyramid level at which each pass starts.
        verbose: when True, print the total run time and display the result.

    Returns:
        The final image tensor.
    """
    # Taken once and never reset, so the verbose report covers the whole run.
    total_start = time.time()
    # Use CUDA when present (as before) but fall back to CPU instead of failing.
    run_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE(review): imread is called without pt=True here, unlike elsewhere in
    # this notebook; this presumably relies on imread's default - confirm.
    pyr = get_pyramid(imread(path_a).to(run_device), depth=20, ratio=0.8, verbose=False)
    pyr2 = get_pyramid(imread(path_b).to(run_device), depth=20, ratio=0.8, verbose=False)
    new_im = pyr2[0].clone()
    for i in range(2):
        new_pyr = get_pyramid(new_im, depth=20, ratio=0.8, verbose=False)
        new_im = new_pyr[top_level].clone()
        for l in range(top_level, -1, -1):
            # 5/4 is the inverse of the 0.8 ratio used to build the pyramids.
            resizer = Resizer(pyr[l + 1].shape, 5/4, pyr[l].shape).to(run_device)
            curr = pyr[l]
            prev = resizer(pyr[l + 1])
            # Two key sets share the same values: the resized next level of A
            # and level l of A itself.
            key_index, value = pnn.extract_kvs(keys=[prev, curr], values=[curr, curr])
            for k in range(num_iters):
                new_im = pnn(new_im, key_index, value)
            imshow(torch.cat([new_im, pyr2[l]], dim=-1))
            if l > 0:
                resizer = Resizer(new_pyr[l].shape, 5/4, new_pyr[l - 1].shape).to(run_device)
                new_im = resizer(new_im)
    if verbose:
        print('Total time: %.2f[s]' % (time.time() - total_start,))
        imshow(new_im)

    return new_im


In [None]:
# patch_size = (7, 7)
# im = gt.to(device=device)
# imshow(im)

# # pyramid = [level.to(device=device) for level in pyramid]
# for i, level in enumerate(pyramid):
#     print(i, level.shape)
#     imshow(level)

In [None]:
# Source image for the generation experiment below.
# NOTE(review): absolute, user-specific path - adjust when running elsewhere.
gt = imread('/home/nivg/data/balloons.png', pt=True).to(device=device)

In [None]:
# Patch matcher for the generation experiment: 7x7 patches, L2 distance,
# one-directional lookup, 'weighted_mean' reduction when folding.
patch_size = (7, 7)
# pnn = FaissPatchNNL2(patch_size=patch_size, reduce='weighted_mean', use_padding=False)
pnn = TorchPatchNN(patch_size=patch_size, dist_fn=l2_dist, batch_size=2**28, bidirectional=False, alpha=1, reduce='weighted_mean')
pnn = pnn.to(device=device)
print(pnn)

In [None]:
# Generation experiment: build the source pyramid from `gt`, then sample ten
# new images per target aspect ratio and report the time of each run.
im = gt.to(device=device)
dtype = torch.float32
orig_ratio = 3 / 4
noise_std = 0.75
noise_decay = 0.0
depth = 15
pyr = get_pyramid(im, depth=depth, ratio=orig_ratio, verbose=False)
pyr = [lvl.to(dtype=dtype) for lvl in pyr]
pyr = pyr[0:]  # change the start index to drop the first levels
imshow(pyr[0])
pnn._temperature = None
# Was misspelled `SOFTMIN_DTPYE`, which silently created an unused variable
# instead of overriding the dtype read by torch_patch_nn.
SOFTMIN_DTYPE = torch.float64

for ratio in [(1, 1)]:#, (1, 4/5), (1, 3/4), (1, 1/3), (1, 5/4), (1, 2)]:
    # Resize the first level to the target aspect ratio and build the
    # destination pyramid that fixes the output shape at every level.
    rsizer = Resizer(pyr[0].shape, ratio).to(device=device)
    dst_pyr = get_pyramid(rsizer(pyr[0]), depth=len(pyr) + 1, ratio=orig_ratio, verbose=False)
    dst_pyr = [lvl.to(dtype=dtype) for lvl in dst_pyr]
    for k in range(10):
        start = time.time()
#         out = new_image_generation_old(pyr, dst_pyr, patch_size)
        out = new_image_generation(pnn, pyr, dst_pyr, top_level=9, ratio=1/orig_ratio, noise_std=noise_std, noise_decay=noise_decay, verbose=False)
        print('ratio=%s' % (ratio,), 'k=%d' % (k,), 'time=%.2fms' % (1000 * (time.time() - start),))
        imshow(out)

In [None]:
# 4_real_a, 4_real_b - 6, 1e-2
# 4_real_b, 4_real_a - 6, 5e-2

# 10_real_a, 10_real_b - 6, 5e-3
# 10_real_b, 10_real_a - 7, 5e-3

# factory_real_a, factory_real_b - 8, 5e-2
# factory_real_b, factory_real_a - 8, 5e-3

# flowers_real_a, flowers_real_b - 7, 5e-3
# flowers_real_b, flowers_real_a - 7, 5e-3

# flowers2_real_b, flowers2_real_a - 11, 5e-3
# flowers2_real_a, flowers2_real_b - 10, 5e-3

# knit_real_a, knit_real_b - 3/4, 5e-4
# knit_real_b, knit_real_a - 6, 5e-3

# perspective_real_a, perspective_real_b - 4, 5e-3
# perspective_real_b, perspective_real_a - 8, 5e-2

# snow_real_a (1), snow_real_b (2) - 5, 5e-3
# snow_real_b (2), snow_real_a (1) - 4, 5e-2

# marble_real_a, marble_real_b - 8, 5e-3
# marble_real_b, marble_real_a - 9, 5e-4

# 0_real_a, 0_real_b - 8, 5e-4
# 0_real_b, 0_real_a - 8, 5e-4

# sky_real_a, sky_real_b - 6, 5e-4
# sky_real_b, sky_real_a - 6, 5e-4

# oranges_real_a, oranges_real_b - 6, 5e-4
# oranges_real_b, oranges_real_a - 7, 5e-4

# tennis_real_a (1), tennis_real_b - 6, 5e-5
# tennis_real_b, tennis_real_a (1) - 9, 5e-3

# dirtroads_real_a, dirtroads_real_b - 7, 5e-3
# dirtroads_real_b, dirtroads_real_a - 8, 5e-4

# snow2ice_real_a, snow2ice_real_b - 10, 5e-3
# snow2ice_real_b, snow2ice_real_a - 11, 5e-2

# 22_real_a, 22_real_b - 8, 5e-4 / 10,5e-4
# 22_real_b, 22_real_a - 7, 5e-4

# 206, 207 - 10, 5e-3
# 207, 206 - 11, 5e-1
# 204, 205 - 4, 5e-3
# 205, 204 - 4, 5e-3
# 202, 203 - 2, 5e-4
# 203, 202 - 2, 5e-4
# 208_resized, 209 - 10, 5e-3
# 209, 208_resized - 6, 5e-3
# 10, 11 - 6, 5e-3
# 108, 109 - 4, 5e-3
# 109, 108 - 6, 5e-3
# trees_paint, trees - 3, 5e-2
# cows_paint, cows - 3, 5e-4
# 309_resized, 308 - 6, 5e-3


# Structural-analogy run. This rebinds the notebook-level `pnn` to a
# bidirectional matcher with alpha=5e-3 and 'mean' reduction.
pnn = TorchPatchNN(patch_size=7, dist_fn=l2_dist, batch_size=2**28, bidirectional=True, alpha=5e-3, reduce='mean')
pnn = pnn.to(device=device)

# NOTE(review): absolute, user-specific paths - adjust when running elsewhere.
path_a = '/home/nivg/data/DropGAN/Analogies/for_assaf/marble_real_b.jpg'
path_b = '/home/nivg/data/DropGAN/Analogies/for_assaf/marble_real_a.jpg'

structural_analogy(pnn, path_a, path_b, num_iters=10, top_level=8, verbose=False)

In [None]:
# @torch.no_grad()
# def torch_patch_weight_min_dist(query, key, dist_fn, batch_size, alpha=1):
#     kbatch = (batch_size + query.shape[0] - 1) // query.shape[0]
#     min_dists = []
#     for j in range((key.shape[0] + kbatch - 1) // kbatch):
#         idx = nn_lookup2d(key[j*kbatch:(j+1)*kbatch], query, dist_fn=dist_fn)
#         match_query = torch.index_select(query, 0, idx)
#         min_dists.append((key[j*kbatch:(j+1)*kbatch] - match_query).pow(2).sum(1))
#     min_dist = torch.cat(min_dists, dim=0)
#     weight = 1 / (alpha + min_dist)
#     return weight


# @torch.no_grad()
# def torch_patch_nn(query, key, value, patch_size, dist_fn, bidirectional=False, reduce='mean', use_padding=False, batch_size=2**28):
#     # reshape input
#     query = unfold2d(query, patch_size, use_padding=use_padding)
#     query, size, ndim = view_as_2d(query)
    
#     # perform nn-search
#     idxs = []
#     qbatch = (batch_size + key.shape[0] - 1) // key.shape[0]
#     kbatch = (batch_size + query.shape[0] - 1) // query.shape[0]
#     if bidirectional:
#         weight = torch_patch_weight_min_dist(query, key, dist_fn, batch_size).view(1, -1)
#     else:
#         weight = None
#     for i in range((query.shape[0] + qbatch - 1) // qbatch):
#         idxs.append(nn_lookup2d(query[i*qbatch:(i+1)*qbatch], key, weight=weight, dist_fn=dist_fn))
#     idx = torch.cat(idxs, dim=0)
#     result = torch.index_select(value, 0, idx)
    
#     # reshape output
#     result = view_2d_as(result, size, ndim)
#     output = fold2d(result, reduce=reduce, use_padding=use_padding)
#     return output


In [None]:
# ratio=(1, 1) k=0 time=8596.04ms
# ratio=(1, 1) k=1 time=8577.48ms
# ratio=(1, 0.8) k=0 time=7105.97ms
# ratio=(1, 0.8) k=1 time=7127.98ms
# ratio=(1, 0.75) k=0 time=6812.93ms
# ratio=(1, 0.75) k=1 time=6795.47ms
# ratio=(1, 0.3333333333333333) k=0 time=3733.47ms
# ratio=(1, 0.3333333333333333) k=1 time=3683.60ms
# ratio=(1, 1.25) k=0 time=10514.17ms
# ratio=(1, 1.25) k=1 time=10468.91ms
# ratio=(1, 2) k=0 time=16047.43ms
# ratio=(1, 2) k=1 time=16064.11ms


In [None]:
# OLD FP16
# ratio=(1, 1) k=0 time=1168.27ms
# ratio=(1, 1) k=1 time=1215.02ms
# ratio=(1, 0.8) k=0 time=952.06ms
# ratio=(1, 0.8) k=1 time=1015.07ms
# ratio=(1, 0.75) k=0 time=925.23ms
# ratio=(1, 0.75) k=1 time=983.23ms
# ratio=(1, 0.3333333333333333) k=0 time=478.62ms
# ratio=(1, 0.3333333333333333) k=1 time=511.96ms
# ratio=(1, 1.25) k=0 time=1365.95ms
# ratio=(1, 1.25) k=1 time=1461.49ms
# ratio=(1, 2) k=0 time=2097.24ms
# ratio=(1, 2) k=1 time=2239.94ms


In [None]:
# def _get_faiss_gpu_config(device, dtype):
#     res = faiss.StandardGpuResources()
#     cfg = faiss.GpuIndexFlatConfig()
#     cfg.useFloat16 = (dtype == torch.float16)
#     cfg.device = device.index
#     return res, cfg

# def faiss_create_index_l2(key):
#     assert key.dim() == 2
#     d = key.shape[1]
#     if key.device.type == 'cpu':
#         index = faiss.IndexFlatL2(d)
#     elif key.device.type == 'cuda':
#         res, cfg = _get_faiss_gpu_config(key.device, key.dtype)
#         index = faiss.GpuIndexFlatL2(res, d, cfg)
#     else:
#         raise ValueError('unsupported device: {}'.format(key.device))        
#     index.add(key)
#     return index

# # @torch.no_grad()
# # def patch_nn_l2(query, keys, values, patch_size, reduce='mean', use_padding=False):
# #     # reshape input
# #     query = unfold2d(query, patch_size, use_padding=use_padding)
# #     query, size, ndim = view_as_2d(query)
# #     key, value = extract_kvs(keys, values, patch_size, use_padding=use_padding)
    
# #     # perform nn-search
# #     key_index = create_index_l2(key)
# #     idx = key_index.search(query.contiguous(), 1)[1].squeeze(-1)
# #     result = torch.index_select(value, 0, idx)
    
# #     # reshape output
# #     result = view_2d_as(result, size, ndim)
# #     output = fold2d(result, reduce=reduce, use_padding=use_padding)
# #     return output

# @torch.no_grad()
# def faiss_patch_nn_l2(query, key_index, value, patch_size, reduce='mean', use_padding=False):
#     # reshape input
#     query = unfold2d(query, patch_size, use_padding=use_padding)
#     query, size, ndim = view_as_2d(query)
#     # key, value = extract_kvs(keys, values, patch_size, use_padding=use_padding)
    
#     # perform nn-search
#     # key_index = create_index_l2(key)
#     idx = key_index.search(query.contiguous(), 1)[1].squeeze(-1)
#     result = torch.index_select(value, 0, idx)
    
#     # reshape output
#     result = view_2d_as(result, size, ndim)
#     output = fold2d(result, reduce=reduce, use_padding=use_padding)
#     return output

# class FaissPatchNNL2(nn.Module):
#     def __init__(self, patch_size, reduce='mean', use_padding=False):
#         super().__init__()
#         self._patch_size = _pair(patch_size)
#         self._reduce = reduce
#         self._use_padding = use_padding
    
#     def forward(self, query, key_index, value):
#         return faiss_patch_nn_l2(
#             query,
#             key_index,
#             value,
#             patch_size=self._patch_size,
#             reduce=self._reduce,
#             use_padding=self._use_padding)
    
#     def extract_kvs(self, keys, values):
#         key, value = extract_kvs(keys, values, self._patch_size, use_padding=self._use_padding)
#         key_index = faiss_create_index_l2(key)
#         return key_index, value
    
#     def __repr__(self):
#         return '{0}(patch_size={1}, reduce="{2}", use_padding={3})'.format(
#             self.__class__.__name__, self._patch_size, self._reduce, self._use_padding)

In [None]:
# @torch.no_grad()
# def torch_patch_nn_l2(query, key, value, patch_size, batch_size=16777216, reduce='mean', use_padding=False):
#     # reshape input
#     query = unfold2d(query, patch_size, use_padding=use_padding)
#     query, size, ndim = view_as_2d(query)
    
#     # perform nn-search
#     idxs = []
#     qbatch = (batch_size + key.shape[0] - 1) // key.shape[0]
#     kbatch = (batch_size + query.shape[0] - 1) // query.shape[0]
#     for i in range((query.shape[0] + qbatch - 1) // qbatch):
#         idxs.append(nn_lookup2d(query[i*qbatch:(i+1)*qbatch], key))
#     idx = torch.cat(idxs, dim=0)
#     result = torch.index_select(value, 0, idx)
    
#     # reshape output
#     result = view_2d_as(result, size, ndim)
#     output = fold2d(result, reduce=reduce, use_padding=use_padding)
#     return output


# class TorchPatchNNL2(nn.Module):
#     def __init__(self, patch_size, batch_size=16777216, reduce='mean', use_padding=False):
#         super().__init__()
#         self._patch_size = _pair(patch_size)
#         self._reduce = reduce
#         self._use_padding = use_padding
#         self._batch_size = batch_size
    
#     def forward(self, query, key, value):
#         return torch_patch_nn_l2(
#             query=query,
#             key=key,
#             value=value,
#             patch_size=self._patch_size,
#             batch_size=self._batch_size,
#             reduce=self._reduce,
#             use_padding=self._use_padding)
    
#     def extract_kvs(self, keys, values):
#         key, value = extract_kvs(keys, values, self._patch_size, use_padding=self._use_padding)
#         return key, value
    
#     def __repr__(self):
#         return '{0}(patch_size={1}, reduce="{2}", use_padding={3}, batch_size={4})'.format(
#             self.__class__.__name__, self._patch_size, self._reduce, self._use_padding, self._batch_size)

In [None]:
# NEW FP16
# ratio=(1, 1) k=0 time=1677.75ms
# ratio=(1, 1) k=1 time=1769.13ms
# ratio=(1, 0.8) k=0 time=1401.01ms
# ratio=(1, 0.8) k=1 time=1493.48ms
# ratio=(1, 0.75) k=0 time=1350.04ms
# ratio=(1, 0.75) k=1 time=1436.66ms
# ratio=(1, 0.3333333333333333) k=0 time=637.96ms
# ratio=(1, 0.3333333333333333) k=1 time=688.41ms
# ratio=(1, 1.25) k=0 time=1993.33ms
# ratio=(1, 1.25) k=1 time=2133.84ms
# ratio=(1, 2) k=0 time=3086.48ms
# ratio=(1, 2) k=1 time=3301.28ms


In [None]:
# pyr = get_pyramid('/home/nivg/data/mountains.jpg', None, 0, 15, 0.75, print_=True) 
# pyr=pyr[3:]
# for ratio in ([[1,1],[1,4/5],[1,3/4],[1,1/3],[1,5/4],[1,2]]):
#     rsizer = Resizer.Resizer(pyr[0].shape, ratio).cuda()
#     dst_pyr = get_pyramids_aux(rsizer(pyr[0]), num_layers=15, ratio=3/4, print_=False)
#     for k in range(2):
#         im = new_image_generation(pyr, dst_pyr)
# #         imwrite('/home/nivg/data/DropGAN/image_generation/' + IM_NAME + '_' + str(ratio[0]) + '_' + str(ratio[1]) +  '_im' + str(k) + '.png', im.squeeze(0).detach().cpu())

# OLD

In [None]:
# def patch2im(input, w, h, patch_size=7):
#     patch_size = _pair(patch_size)
#     out = input.transpose(1,2)
#     normalize = torch.ones_like(out)
#     normalize = normalize.transpose(1,2)
#     fold = F.fold(
#         input, 
#         output_size=(w, h), 
#         kernel_size=patch_size
#     )

#     norm = F.fold(
#         normalize,
#         output_size=(w, h),
#         kernel_size=patch_size
#     )

#     return fold / norm


# def _calc_dist_l2(X, Y):
#     Y = Y.transpose(0, 1)
#     X2 = X.pow(2).sum(1, keepdim=True)
#     Y2 = Y.pow(2).sum(0, keepdim=True)
#     XY = X @ Y
#     return X2 - (2 * XY) + Y2

# DIV = 100
# import math

# def build_image(input_img, index_imgs, ref_imgs, patch_size=7):
#     patch_size = _pair(patch_size)
#     unfold = nn.Unfold(kernel_size=patch_size)
#     in_patches = unfold(input_img)
#     for i, (ind_img, ref_img) in enumerate(zip(index_imgs, ref_imgs)):
#         if i == 0:
#             index_patches = unfold(ind_img)
#             ref_patches = unfold(ref_img)
#         else:
#             index_patches = torch.cat([index_patches, unfold(ind_img)],dim=-1)
#             ref_patches = torch.cat([ref_patches, unfold(ref_img)], dim=-1)
            
#     if in_patches.shape[-1]*ref_patches.shape[-1] > 14000**2:
#         for j in range(DIV):
#             start_patch = j*(math.ceil(in_patches.shape[-1]/DIV))
#             end_patch = min((j+1)*(math.ceil(in_patches.shape[-1]/DIV)), in_patches.shape[-1])
#             dist_mat = _calc_dist_l2(in_patches[:,:,start_patch:end_patch].permute(0,2,1).squeeze(0), index_patches.squeeze(0).permute(1,0))
# #                 print(start_patch, end_patch, dist_mat.shape)
#             if j == 0:
#                 ind = dist_mat.argmin(1)
#             elif start_patch < end_patch:
#                 ind = torch.cat([ind, dist_mat.argmin(1)])

#     else:
#         dist_mat = _calc_dist_l2(in_patches.permute(0,2,1).squeeze(0), index_patches.squeeze(0).permute(1,0))
#         ind = dist_mat.argmin(1)
#     out_patches = F.embedding(ind, ref_patches.squeeze(0).permute(1,0))
#     return patch2im(out_patches.unsqueeze(0).permute(0,2,1), input_img.shape[-2], input_img.shape[-1], patch_size)

In [None]:
# @torch.no_grad()
# def new_image_generation_old(src_pyramid, dst_pyramid, patch_size, ratio=4/3, noise_std=0.75, num_iters=10, top_level=9, verbose=False):
#     to = {'device': src_pyramid[0].device, 'dtype': src_pyramid[0].dtype}
#     patch_size = _pair(patch_size)
#     start = time.time()
#     new_im = dst_pyramid[top_level] + torch.randn_like(dst_pyramid[top_level]) * noise_std
#     for l in range(top_level, -1, -1):
#         start = time.time()
#         resizer = Resizer(src_pyramid[l + 1].shape, ratio, src_pyramid[l].shape).to(**to)
#         if l == top_level:
#             curr = [src_pyramid[l]]
#             prev = [src_pyramid[l]]
#         else:
#             curr = [src_pyramid[l]]
#             prev = [resizer(src_pyramid[l + 1])]
#         for k in range(num_iters if l != top_level else 1):
#             # new_im = pnn(new_im, keys=prev, values=curr)
#             start = time.time()
#             new_im = build_image(new_im, prev, curr, patch_size=patch_size)
#             # print(l, k, '%.2fms' % (1000 * (time.time() - start),))
#         if l > 0:
#             resizer = Resizer(dst_pyramid[l].shape, ratio, dst_pyramid[l - 1].shape).to(**to)
#             new_im = resizer(new_im)
    
#     if verbose:
#         print('Total time: %.2f[s]' % (time.time() - start,))
#         imshow(new_im)
    
#     return new_im

In [None]:
# patch_size = (7, 7)
# im = imread('/home/nivg/data/mountains.jpg', pt=True).to(device=device)
# # im = torch.cat([im, im], dim=-1)
# # im = torch.cat([im, im], dim=-2)

# pnn = PatchNNL2(patch_size=patch_size, reduce='weighted_mean', use_padding=False)  # , faiss_device=faiss_gpu_device)
# print(pnn)

# start = torch.cuda.Event(enable_timing=True)
# end = torch.cuda.Event(enable_timing=True)

# start.record()
# with torch.no_grad():
#     out = pnn(query=im, keys=[im], values=[im])
# end.record()
# torch.cuda.synchronize()
# print('time = %.2fms' % start.elapsed_time(end))
# # imshow(out)

In [None]:
# im = gt.to(device=device)
# pyr = get_pyramid(im, depth=15, ratio=3/4, verbose=False)
# pyr = pyr[3:]
# for ratio in [(1, 1), (1, 4/5), (1, 3/4), (1, 1/3), (1, 5/4), (1, 2)]:
#     rsizer = Resizer(pyr[0].shape, ratio).to(device=device)
#     dst_pyr = get_pyramid(rsizer(pyr[0]), depth=15, ratio=3/4, verbose=False)
#     for k in range(2):
#         start = time.time()
#         out = new_image_generation_old(pyr, dst_pyr)
#         print('ratio=%s' % (ratio,), 'k=%d' % (k,), 'time=%.2fms' % (1000 * (time.time() - start),))
#         imshow(out)

In [None]:
# class PatchNNL2(nn.Module):
#     def __init__(self, patch_size, reduce='mean', use_padding=False, faiss_device=None):
#         super().__init__()
#         self._patch_size = _pair(patch_size)
#         self._reduce = reduce
#         self._use_padding = use_padding
#         self._faiss_device = faiss_device
    
#     @torch.no_grad()
#     def forward(self, query, key, value):
#         query = unfold2d(query, self._patch_size, use_padding=self._use_padding)
#         query, size, ndim = view_as_2d(query)
#         key, value = self.extract_kvs(key, value)

#         # perform nn-search
#         key_index = self.create_index(key)
#         idx = key_index.search(query.contiguous(), 1)[1].squeeze(-1)
#         assert idx.dim() == 1
#         result = torch.index_select(value, 0, idx)

#         # reshape output
#         output = view_2d_as(result, size, ndim)
#         output = fold2d(output, reduce=self._reduce, use_padding=self._use_padding)
#         return output
    
#     @torch.no_grad()
#     def extract_kvs(self, key_images, value_images):
#         if isinstance(key_images, torch.Tensor):
#             key_images = [key_images]
#         if isinstance(value_images, torch.Tensor):
#             value_images = [value_images]
#         assert len(key_images) == len(value_images)
#         assert all(k.shape == v.shape for k, v in zip(key_images, value_images))
#         key = torch.cat([view_as_2d(unfold2d(k, self._patch_size, use_padding=self._use_padding))[0] for k in key_images], dim=0)
#         value = torch.cat([view_as_2d(unfold2d(v, self._patch_size, use_padding=self._use_padding))[0] for v in value_images], dim=0)
#         return key, value
    
#     @torch.no_grad()
#     def create_index(self, key):
#         assert key.dim() == 2
#         d = key.shape[1]
#         if self._faiss_device is not None:
#             index = faiss.GpuIndexFlatL2(self._faiss_device, d)
#         else:
#             index = faiss.IndexFlatL2(d)
#         index.add(key)
#         return index

In [None]:
# def create_key_index(key, faiss_gpu_device):
#     assert key.dim() == 2
#     d = key.shape[1]
#     if faiss_gpu_device is not None:
#         key_index = faiss.GpuIndexFlatL2(faiss_gpu_device, d)
#     else:
#         key_index = faiss.IndexFlatL2(d)
#     key_index.add(key)
#     return key_index

# def extract_kvs(key_images, value_images, patch_size, use_padding=True):
#     if isinstance(key_images, torch.Tensor):
#         key_images = [key_images]
#     if isinstance(value_images, torch.Tensor):
#         value_images = [value_images]
#     assert len(key_images) == len(value_images)
#     assert all(k.shape == v.shape for k, v in zip(key_images, value_images))
#     key = torch.cat([view_as_2d(unfold2d(k, patch_size, use_padding=use_padding))[0] for k in key_images], dim=0)
#     value = torch.cat([view_as_2d(unfold2d(v, patch_size, use_padding=use_padding))[0] for v in value_images], dim=0)
#     return key, value

# @torch.no_grad()
# def pnn_fn(query, key, value, patch_size, reduce='mean', use_padding=True):
#     # reshape input
#     query = unfold2d(query, patch_size, use_padding=use_padding)
#     query, size, ndim = view_as_2d(query)
#     key, value = extract_kvs(key, value, patch_size, use_padding=use_padding)
    
#     # perform nn-search
#     key_index = create_key_index(key, faiss_gpu_device)
#     idx = key_index.search(query.contiguous(), 1)[1].squeeze(-1)
#     assert idx.dim() == 1
#     result = torch.index_select(value, 0, idx)
    
#     # reshape output
#     result = view_2d_as(result, size, ndim)
#     output = fold2d(result, reduce=reduce, use_padding=use_padding)
#     return output

# @torch.no_grad()
# def preindexed_pnn_fn(query, key_index, value, patch_size, reduce='mean', use_padding=True, device=device, faiss_gpu_device=faiss_gpu_device):
#     # reshape input
#     query = unfold2d(query, patch_size, use_padding=use_padding)
#     query, size, ndim = view_as_2d(query)
    
#     # perform nn-search
#     idx = key_index.search(query.contiguous(), 1)[1].squeeze(-1)
#     assert idx.dim() == 1
#     result = torch.index_select(value, 0, idx)

#     # reshape output
#     result = view_2d_as(result, size, ndim)
#     output = fold2d(result, reduce=reduce, use_padding=use_padding)
#     return output