In [2]:
# Shell escape: run the external `gpustat` CLI to show the current state of the
# GPUs on this machine before picking a device.
!gpustat -cu

[1m[37mn103.mcl11.weizmann.ac.il[m  Mon Jan  4 19:38:16 2021  [1m[30m418.87.00[m
[36m[0][m [34mGeForce RTX 2080 Ti[m |[31m 25'C[m, [32m  0 %[m | [36m[1m[33m    0[m / [33m10989[m MB |
[36m[1][m [34mGeForce RTX 2080 Ti[m |[31m 24'C[m, [32m  0 %[m | [36m[1m[33m    0[m / [33m10989[m MB |
[36m[2][m [34mGeForce RTX 2080 Ti[m |[31m 26'C[m, [32m  0 %[m | [36m[1m[33m    0[m / [33m10989[m MB |
[36m[3][m [34mGeForce RTX 2080 Ti[m |[31m 25'C[m, [32m  0 %[m | [36m[1m[33m    0[m / [33m10989[m MB |


In [3]:
# Set CUDA_VISIBLE_DEVICES=0 for this kernel so that only the first GPU is exposed.
# Keep the comment on its own line: text after the value would become part of it.
%env CUDA_VISIBLE_DEVICES 0

env: CUDA_VISIBLE_DEVICES=0


In [4]:
import torch
import torch.nn.functional as F
from torch.nn.modules.utils import _pair

# Device used for the benchmark tensors: CUDA when available, otherwise the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
# Pair of CUDA events with timing enabled; the benchmark cells record them around
# each call and read the interval with start.elapsed_time(end).
# NOTE(review): these are created unconditionally even though DEVICE may fall back
# to 'cpu' — presumably this notebook is only meant to run with CUDA; confirm.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

In [43]:
def unfold2d(input, kernel_size):
    """Extract every sliding patch of a 4D tensor, keeping the patch axes separate.

    The input is padded by ``kernel - 1`` on each spatial side before calling
    ``F.unfold``, so the patch grid has ``h + kh - 1`` rows and ``w + kw - 1``
    columns.

    Args:
        input: tensor unpacked as ``(n, c, h, w)``.
        kernel_size: int or pair of ints, normalised with ``_pair``.

    Returns:
        Tensor of shape ``(n, c, kh, kw, h + kh - 1, w + kw - 1)``.

    Raises:
        ValueError: if ``input`` is not 4-dimensional.
    """
    if input.dim() != 4:
        raise ValueError('expected 4D tensor as input.')

    batch, channels, height, width = input.shape
    kernel = _pair(kernel_size)
    k_rows, k_cols = kernel

    # "Full" padding: every pixel ends up at every position inside some patch.
    pad = (k_rows - 1, k_cols - 1)
    patches = F.unfold(input, kernel, padding=pad)

    # F.unfold flattens (c, kh, kw) and the patch grid; split them back apart.
    grid_rows = height + k_rows - 1
    grid_cols = width + k_cols - 1
    return patches.view(batch, channels, k_rows, k_cols, grid_rows, grid_cols)


def fold2d(input, reduce='sum', std=1.7):
    """Combine overlapping patches of a 6D tensor into a 4D tensor.

    Args:
        input: tensor unpacked as ``(n, c, kh, kw, h, w)``.
        reduce: how overlapping values are combined:
            ``'sum'`` - plain overlap-add via ``_fold2d_sum``;
            ``'mean'`` - overlap-add divided by the sum of a ``kh x kw``
            all-ones weight tensor;
            ``'weighted_mean'`` - patches are multiplied by the weights from
            ``_get_weights_fold2d_weighted_mean`` before the overlap-add, and
            the result is divided by the sum of those weights;
            ``'median'`` - delegated to ``_fold2d_median``.
        std: passed to the weight helper for ``'weighted_mean'``; not used by
            the other reductions.

    Returns:
        The tensor produced by the selected reduction.

    Raises:
        ValueError: if ``input`` is not 6-dimensional or ``reduce`` is unknown.
    """
    if input.dim() != 6:
        raise ValueError('expected 6D tensor as input.')
    k_rows, k_cols = input.shape[2:4]

    # Reductions that need no normalisation return straight away.
    if reduce == 'sum':
        return _fold2d_sum(input)
    if reduce == 'median':
        return _fold2d_median(input)

    # Both mean variants are "overlap-add, then divide by the total weight".
    if reduce == 'mean':
        weights = _get_weights_fold2d_mean(input, k_rows, k_cols)
        summed = _fold2d_sum(input)
    elif reduce == 'weighted_mean':
        weights = _get_weights_fold2d_weighted_mean(input, k_rows, k_cols, std)
        summed = _fold2d_sum(input * weights)
    else:
        raise ValueError(f'unknown reduction: {reduce}')
    return summed / weights.sum()


def _fold2d_sum(input):
    if input.dim() != 6:
        raise ValueError('expected 6D tensor as input.')
    n, c, kh, kw, h, w = input.shape
    ph, pw = padding = (kh - 1, kw - 1)
    oh, ow = output_size = (h + kh - 1 - 2 * ph, w + kw - 1 - 2 * pw)
    kernel_size = (kh, kw)
    input = input.reshape(n, c * kh * kw, h * w)
    output = F.fold(input, output_size, kernel_size, padding=padding)
    return output


def _fold2d_median(input):
    if input.dim() != 6:
        raise ValueError('expected 6D tensor as input.')
    n, c, kh, kw, h, w = input.shape
    ph, pw = (kh - 1, kw - 1)
    oh, ow = (h + kh - 1 - 2 * ph, w + kw - 1 - 2 * pw)
    output = input.new_zeros(size=(kh * kw, n, c, oh, ow))
    for i in range(kh):
        for j in range(kw):
            output[i * kw + j] = input[:, :, i, j, kh - 1 - i:h - i, kw - 1 - j:w - j]  # noqa
    output = output.median(dim=0)[0]
    return output


def _get_weights_fold2d_mean(input, kh, kw):
    weights = input.new_ones(size=(kh, kw))
    return weights.view(1, 1, kh, kw, 1, 1)


def _get_weights_fold2d_weighted_mean(input, kh, kw, std):
    to = {'device': input.device, 'dtype': input.dtype}
    gh = torch.linspace(-1, 1, kh, **to)
    gw = torch.linspace(-1, 1, kw, **to)
    nh = torch.exp(-0.5 * (gh / std).pow(2))
    nw = torch.exp(-0.5 * (gw / std).pow(2))
    weights = torch.einsum('i,j->ij', nh, nw)
    return weights.view(1, 1, kh, kw, 1, 1)


In [93]:
# Benchmark fixture: a (10, 3, 100, 100) batch on DEVICE, i.e. torch.rand output
# scaled by 10, and its 2x2 patch decomposition from unfold2d.
# NOTE(review): no RNG seed is set, so `x` differs between runs — presumably fine
# for timing and allclose round-trip checks; confirm.
kernel_size = (2, 2)
x = 10 * torch.rand(size=(10, 3, 100, 100), device=DEVICE)
y = unfold2d(x, kernel_size)

In [94]:
# Time fold2d(reduce='sum') with CUDA events.
with torch.no_grad():
    start.record()
    z1 = fold2d(y, reduce='sum')
    end.record()
    # Wait for the GPU to finish so that both events have completed before reading.
    torch.cuda.synchronize()
    # Elapsed time between the two events, in milliseconds.
    print(start.elapsed_time(end))

0.6042879819869995


In [96]:
# Time fold2d(reduce='mean') with CUDA events, then check the round trip.
with torch.no_grad():
    start.record()
    z2 = fold2d(y, reduce='mean')
    end.record()
    # Wait for the GPU to finish so that both events have completed before reading.
    torch.cuda.synchronize()
    # Elapsed time between the two events, in milliseconds.
    print(start.elapsed_time(end))
    # Folding the unfolded tensor with 'mean' should reproduce the input.
    assert torch.allclose(x, z2)

1.1708159446716309


In [97]:
# Time fold2d(reduce='weighted_mean') with CUDA events (default std).
with torch.no_grad():
    start.record()
    z3 = fold2d(y, reduce='weighted_mean')
    end.record()
    # Wait for the GPU to finish so that both events have completed before reading.
    torch.cuda.synchronize()
    # Elapsed time between the two events, in milliseconds.
    print(start.elapsed_time(end))

2.0504000186920166


In [98]:
# Time fold2d(reduce='median') with CUDA events.
with torch.no_grad():
    start.record()
    z4 = fold2d(y, reduce='median')
    end.record()
    # Wait for the GPU to finish so that both events have completed before reading.
    torch.cuda.synchronize()
    # Elapsed time between the two events, in milliseconds.
    print(start.elapsed_time(end))

5.095359802246094


In [102]:

# Round-trip check: the median fold of the unfolded tensor should match the input.
torch.allclose(x, z4)

True