In [None]:
# Snapshot of current GPU status via the third-party `gpustat` CLI, to help pick a free device below.
!gpustat -cu

In [None]:
# Expose only GPU 0 to this kernel. CUDA reads this variable when it initialises, so this cell
# must run before torch touches CUDA. The `!nvidia-smi -i $CUDA_VISIBLE_DEVICES` cells reuse it.
%env CUDA_VISIBLE_DEVICES 0

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Turn off TorchScript's profiling executor/mode (presumably to get the legacy executor's
# behaviour for the scripted functions below). These are private torch._C hooks.
# NOTE(review): they may change or disappear across PyTorch versions.
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)

# Falls back to the CPU when CUDA is unavailable.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# CUDA events used by the benchmark cells (start.record() ... end.record(); start.elapsed_time(end)).
# NOTE(review): these are created unconditionally, and the benchmarks also call
# torch.cuda.synchronize(). The CPU fallback above therefore does not make the
# timing cells runnable without a GPU.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

In [None]:
from torch.nn.modules.utils import _pair
from typing import Tuple

#@torch.jit.script
def unfold2d(x: torch.Tensor, kernel_size: Tuple[int, int], stride: Tuple[int, int] = (1, 1)):
    # type: (torch.Tensor, Tuple[int, int], Tuple[int, int]) -> torch.Tensor
    """Gather a kh x kw sliding window around each spatial position of ``x``.

    ``kernel_size`` and ``stride`` may also be plain ints; ``_pair`` expands them to pairs.
    Dims 2 and 3 of ``x`` are zero-padded by k - 1 in total: (k - 1) // 2 before and
    k // 2 after. With stride 1 the window grid therefore has the same spatial size as ``x``.
    Because of the zero padding, border windows contain zeros.

    Returns:
        The padded input unfolded along dims 2 and 3 (``Tensor.unfold``, a strided view of
        the padded copy). Two trailing dims of size (kh, kw) are appended, e.g.
        (N, C, H, W) -> (N, C, H, W, kh, kw) at stride 1.
    """
    ksize = _pair(kernel_size)
    step = _pair(stride)
    kh, kw = ksize[0], ksize[1]
    sh, sw = step[0], step[1]

    # F.pad takes (before, after) pairs starting from the LAST dim: (W_left, W_right, H_top, H_bottom).
    padded = F.pad(x, pad=((kw - 1) // 2, kw // 2, (kh - 1) // 2, kh // 2))
    rows = padded.unfold(2, kh, sh)
    return rows.unfold(3, kw, sw)

In [None]:
# `.graph` only exists on TorchScript functions. With its @torch.jit.script decorator commented
# out, unfold2d is a plain Python function, so guard instead of raising AttributeError.
(unfold2d.graph
 if hasattr(unfold2d, 'graph')
 else 'unfold2d is not TorchScript-compiled; enable @torch.jit.script above to inspect its graph')

In [None]:
# @torch.jit.script
# def l2_norm_2d(x, y):
#     return (x - y).pow(2).sum(1)

# @torch.jit.script
# def l2_norm_2d_(qs, ks, rs, i, j, t):
#     # type: (torch.Tensor, torch.Tensor, torch.Tensor, int, int, int) -> int
#     rs[t] = l2_norm_2d(qs, ks[..., i, j])
#     return 0

# @torch.jit.script
# def local_argmin(qs, ks):
#     N, C, H, W, h, w = ks.size()
#     futures : List[torch.jit.Future[int]] = []
#     results = qs.new_zeros(size=(h * w, N, H, W))
#     for i in range(h):
#         for j in range(w):
#             futures.append(torch.jit.fork(l2_norm_2d_, qs, ks, results, i, j, i * w + j))
    
# #     results = []
#     for future in futures:
#         torch.jit.wait(future)
# #         results.append(torch.jit.wait(future))

# #     results = torch.stack(results, dim=0)

#     argmin = results.argmin(0)
# #     return argmin
#     return torch.stack([argmin // w, argmin % w], dim=-1)

In [None]:
# from typing import List

# # @torch.jit.script
# def pll2_norm_2d(x, y):
# #     dxy = x - y
# #     return torch.einsum('nchw,nchw->nhw', dxy, dxy)
#     return (x - y).pow(2).sum(1)

# def parallel_local_argmin_(qs, ks, rs, t1, t2):
#     # type: (torch.Tensor, torch.Tensor, torch.Tensor, int, int) -> int
#     # qs: NCHW
#     # ks: NCHWhw
#     N, C, H, W, h, w = ks.size()
# #     results = qs.new_zeros(size=(t2 - t1, N, H, W))
#     for t in range(t1, t2):
#         i, j = t // w, t % w
#         rs[t] = ll2_norm_2d(qs, ks[..., i, j])
#     return 0
# #     argmin = results.argmin(0)
# #     return torch.stack([argmin // w, argmin % w], dim=-1)

# @torch.jit.script
# def parallel_local_argmin(qs, ks, split=1):
#     # type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor
#     # qs: NCHW
#     # ks: NCHWhw
#     N, C, H, W, h, w = ks.size()
#     T = h * w
#     results = qs.new_zeros(size=(T, N, H, W))
#     futures : List[torch.jit.Future[int]] = []
#     dt = (T + split - 1) // split
#     for t1 in range(0, T, dt):
#         t2 = min(T - 1, t1 + dt)
#         futures.append(torch.jit.fork(parallel_local_argmin_, qs, ks, results, t1, t2))
    
# #     results = []
#     for future in futures:
#         torch.jit.wait(future)
# #         results.append(torch.jit.wait(future))
# #     results = torch.cat(results, dim=0)
#     argmin = results.argmin(0)
#     return torch.stack([argmin // w, argmin % w], dim=-1)

In [None]:
# lsz = 31
# with torch.no_grad():
#     qs = torch.randn(size=(1, 75, 256, 256), device=DEVICE)
#     ks = torch.randn_like(qs)
#     ks = unfold2d(ks, (lsz, lsz))

# # local_argmin_jit = torch.jit.optimized_execution(local_argmin, (qs, ks))
    
# start.record()
# with torch.no_grad():
#     r = parallel_local_argmin(qs, ks)
# end.record()
# torch.cuda.synchronize()
# print(start.elapsed_time(end))

# # r = local_argmin(qs, ks)
# !nvidia-smi -i $CUDA_VISIBLE_DEVICES

In [None]:
# @torch.jit.script
def ll2_norm_2d(x, y):
    """Squared difference of x and y, summed over dim 1 (the channel dim for NCHW inputs)."""
    diff = x - y
    return diff.pow(2).sum(1)

# @torch.jit.script
def local_argmin(qs, ks):
    """For each query pixel, find the closest candidate in its local window (list + stack variant).

    Args:
        qs: queries, NCHW.
        ks: candidate windows, NCHWhw (e.g. the output of ``unfold2d``).

    Returns:
        Long tensor of shape (N, H, W, 2) with the (row, col) offset of the best candidate
        inside each h x w window.

    One (N, H, W) distance map is built per window offset. The maps are collected in a
    Python list and stacked once at the end.
    """
    # Unpacking six sizes also rejects a ks of the wrong rank.
    _, _, _, _, win_h, win_w = ks.size()
    per_offset = [
        ll2_norm_2d(qs, ks[..., t // win_w, t % win_w])
        for t in range(win_h * win_w)
    ]
    dist = torch.stack(per_offset, dim=0)  # (h*w, N, H, W), row-major window order
    best = dist.argmin(0)
    return torch.stack([best // win_w, best % win_w], dim=-1)

In [None]:
# @torch.jit.script
def ll2_norm_2d(x, y):
    """Elementwise squared difference of x and y, summed over dim 1 (channels for NCHW)."""
    return torch.sum((x - y) ** 2, dim=1)

# @torch.jit.script
def local_argmin(qs, ks):
    """For each query pixel, find the closest candidate in its local window (preallocated variant).

    Args:
        qs: queries, NCHW.
        ks: candidate windows, NCHWhw (e.g. the output of ``unfold2d``).

    Returns:
        Long tensor of shape (N, H, W, 2) with the (row, col) offset of the best candidate
        inside each h x w window.

    Unlike the list + stack variant, the per-offset distance maps are written into one
    buffer allocated up front. The buffer uses qs's dtype and device via ``new_zeros``.
    """
    batch, _, height, width, win_h, win_w = ks.size()
    dist = qs.new_zeros(size=(win_h * win_w, batch, height, width))
    for di in range(win_h):
        for dj in range(win_w):
            # Slot index is the row-major position of (di, dj) inside the window.
            dist[di * win_w + dj] = ll2_norm_2d(qs, ks[..., di, dj])
    best = dist.argmin(0)
    row = best // win_w
    col = best % win_w
    return torch.stack([row, col], dim=-1)

In [None]:
# `.graph` only exists on TorchScript functions. With its @torch.jit.script decorator commented
# out, local_argmin is a plain Python function, so guard instead of raising AttributeError.
(local_argmin.graph
 if hasattr(local_argmin, 'graph')
 else 'local_argmin is not TorchScript-compiled; enable @torch.jit.script above to inspect its graph')

In [None]:
# Benchmark local_argmin on 10 random 147-channel 256x256 maps: each query pixel is compared
# with the lsz x lsz window of key pixels around it.
lsz = 31  # window side length
with torch.no_grad():
    qs = torch.randn(size=(10, 147, 256, 256), device=DEVICE)
    ks = torch.randn_like(qs)
    # (N, C, H, W) -> (N, C, H, W, lsz, lsz). F.pad allocates a padded copy, but Tensor.unfold
    # returns a strided view, so the lsz*lsz windows themselves are not copied here.
    ks = unfold2d(ks, (lsz, lsz))

# local_argmin_jit = torch.jit.optimized_execution(local_argmin, (qs, ks))
    
# Time the GPU work with CUDA events; synchronize() makes sure `end` has completed before reading it.
# NOTE(review): there is no warm-up call, so this number likely includes first-call overhead.
start.record()
with torch.no_grad():
    r = local_argmin(qs, ks)
end.record()
torch.cuda.synchronize()
print(start.elapsed_time(end))  # milliseconds between the two events

# r = local_argmin(qs, ks)
# Memory/utilisation snapshot of the GPU index stored in CUDA_VISIBLE_DEVICES (set in the %env cell).
!nvidia-smi -i $CUDA_VISIBLE_DEVICES

In [None]:
# @torch.jit.script
def gl2_norm_2d(x, y):
    """All-pairs squared L2 distances between the pixels of two batched feature maps.

    Args:
        x: (N, C, H1, W1) feature map.
        y: (N, C, H2, W2) feature map.

    Returns:
        Tensor of shape (N, H1, W1, H2, W2) with
        out[n, a, b, c, d] = ||x[n, :, a, b] - y[n, :, c, d]||^2. It is computed as
        ||x||^2 + ||y||^2 - 2<x, y>, so no difference tensor with a C axis is built.
    """
    xx = torch.einsum('nchw,nchw->nhw', x, x).unsqueeze(-1).unsqueeze(-1)  # (N, H1, W1, 1, 1)
    yy = torch.einsum('ncij,ncij->nij', y, y).unsqueeze(1).unsqueeze(1)    # (N, 1, 1, H2, W2)
    xy = torch.einsum('nchw,ncij->nhwij', x, y)                            # (N, H1, W1, H2, W2)
    return xx + yy - 2 * xy

# @torch.jit.script
def global_argmin(qs, ks, max_size=2**28):
    """For every query pixel, find the nearest key pixel anywhere in the key map.

    Args:
        qs: queries, (N, C, I, J).
        ks: keys, (N, C, H, W).
        max_size: element budget for each (N, rows, J, H, W) distance block. Query rows are
            processed in blocks sized to fit it, with at least one row per block.

    Returns:
        Long tensor of shape (N, I, J, 2) holding the (row, col) in ks of the nearest key
        for each query pixel.
    """
    N, _, H, W = ks.size()
    _, _, I, J = qs.size()
    # Number of query rows whose full distance block stays within max_size.
    rows = max(1, max_size // max(1, N * J * H * W))
    idx_blocks = []
    for i0 in range(0, I, rows):
        dist = gl2_norm_2d(qs[:, :, i0:i0 + rows], ks)   # (N, rows, J, H, W)
        # Flatten and reduce over the KEY axes, giving the nearest key for each query.
        # The previous view(N, H*W, I, J) + argmin(1) reduced over the query axes instead.
        idx_blocks.append(dist.flatten(3).argmin(-1))     # (N, rows, J)
    idxs = torch.cat(idx_blocks, dim=1)                   # (N, I, J), flat index h * W + w
    return torch.stack([idxs // W, idxs % W], dim=-1)

In [None]:
# @torch.jit.script
def gl2_norm_2d(x, y):
    """Batched pairwise squared L2 distances between two point sets.

    Args:
        x: (N, P, C) points.
        y: (N, Q, C) points.

    Returns:
        Tensor of shape (N, P, Q) with out[n, p, q] = ||x[n, p] - y[n, q]||^2.
        It is computed as ||x||^2 + ||y||^2 - 2<x, y>, so no (N, P, Q, C) difference
        tensor is built. The expansion can round to tiny negative values for near-identical points.
    """
    xx = torch.einsum('nic,nic->ni', x, x).unsqueeze(2)   # (N, P, 1)
    yy = torch.einsum('njc,njc->nj', y, y).unsqueeze(1)   # (N, 1, Q)
    xy = torch.einsum('nic,njc->nij', x, y)               # (N, P, Q)
    return xx + yy - 2 * xy

@torch.jit.script
def global_argmin(qs, ks, max_size: int = 2**28):
    """For every query pixel, find the nearest key pixel anywhere in the key map.

    Args:
        qs: queries, (N, C, I, J).
        ks: keys, (N, C, H, W). Must have the same N and C as qs.
        max_size: element budget for each chunk's (N, H*W, chunk) distance matrix.
            Queries are processed in chunks sized to stay within it, with at least one
            query per chunk. Temporaries inside gl2_norm_2d make peak memory a small
            multiple of this.

    Returns:
        Long tensor of shape (N, I, J, 2) holding the (row, col) in ks of the nearest key
        for each query pixel.
    """
    N, C, H, W = ks.size()
    _, _, I, J = qs.size()
    nq, nk = I * J, H * W
    # Channels-last, spatially flattened point sets: (N, points, C).
    q_pts = qs.permute(0, 2, 3, 1).reshape(N, nq, C)
    k_pts = ks.permute(0, 2, 3, 1).reshape(N, nk, C)
    # Floor (not ceil) so a chunk never exceeds the budget, except for the 1-query minimum.
    per_batch = max_size // N
    chunk = max(1, per_batch // nk)
    idx_chunks = []
    for q0 in range(0, nq, chunk):
        dist = gl2_norm_2d(k_pts, q_pts[:, q0:q0 + chunk])  # (N, nk, <=chunk)
        idx_chunks.append(dist.argmin(1))                   # nearest key for each query
    # Separate name from the list above, so each variable has one static type (TorchScript is statically typed).
    idxs = torch.cat(idx_chunks, dim=1).view(N, I, J)       # flat key index h * W + w
    return torch.stack([idxs // W, idxs % W], dim=-1)

In [None]:
# Benchmark global_argmin on 10 pairs of random 147-channel 256x256 maps: each query pixel
# is matched against every key pixel of the same image.
with torch.no_grad():
    qs = torch.randn(size=(10, 147, 256, 256), device=DEVICE)
    ks = torch.randn(size=(10, 147, 256, 256), device=DEVICE)

# NOTE(review): there is no warm-up call, so this timing likely includes one-time costs
# (e.g. TorchScript optimisation on the first call of a scripted function).
start.record()
with torch.no_grad():
    r = global_argmin(qs, ks)  # , max_size=2**29)
end.record()
torch.cuda.synchronize()  # wait for the GPU so `end` has completed before reading it
print(start.elapsed_time(end))  # milliseconds between the two events

# NOTE(review): stale leftover — this cell benchmarks global_argmin, not local_argmin.
# r = local_argmin(qs, ks)
# Memory/utilisation snapshot of the GPU index stored in CUDA_VISIBLE_DEVICES (set in the %env cell).
!nvidia-smi -i $CUDA_VISIBLE_DEVICES