# Alternative localy connected / spatially variant filters

## Technical needs

In [1]:
import torch
import torch.nn.functional as F
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

## The regular method

In [4]:
def method_1(feature_map, filters):
    # sizes
    N, C, H, W = feature_map.shape
    K = round(filters.shape[1] ** 0.5) # patch size 1 dim, assume squared shape
    
    # unfold feautre-map and reshape to match filters tensor
    unfolded = F.unfold(feature_map, K, padding=K//2).view(N, C, K ** 2, H, W)
    
    # sum over the filter dim
    return torch.einsum('nckhw,nkhw->nchw', unfolded, filters)

## Alternative no-unfolding method

In [5]:
def method_2(feature_map, filters):
    # sizes
    N, C, H, W = feature_map.shape
    K = round(filters.shape[1] ** 0.5) # patch size 1 dim, assume squared shape
    
    # shift channels of filter tensors
    filters = filters.view(N, K, K, H, W)
    filters = shift_all_directions(filters, K, reverse=False).view(N, K**2, H, W)
    
    # outer product on channels and patch size dims
    multiplied = torch.einsum('nchw,nkhw->nckhw', feature_map, filters).view(N, C, K, K, H, W)
    
    # shift result back and sum
    return shift_all_directions(multiplied, K, reverse=True).view(N, C, K**2, H, W).sum(2)


def shift_all_directions(x, K, reverse):
    # shift each channel with a different shift vector, row major
    for ind_h in range(K):
        for ind_w in range(K):
            shift_h, shift_w = (ind_h - K // 2) * (1 - 2*int(reverse)), (ind_w - K // 2) * (1 - 2*int(reverse))
            x[..., ind_h, ind_w, :, :] = F.pad(x[..., ind_h, ind_w, :, :], (shift_w, -shift_w, shift_h, -shift_h))
    return x

## Set experiment data

In [38]:
N = 2048  # batch size
C = 1 # feature map channels
K = 3  # filter size is K by K.
H = 32
W = 32
feature_map = torch.randn((N, C, H, W), device=DEVICE)
filters = torch.rand((N, K**2, H, W), device=DEVICE)

## Compare methods

In [39]:
start.record()
result_1 = method_1(feature_map, filters)
end.record()
torch.cuda.synchronize()
print('method_1 time:', start.elapsed_time(end))

start.record()
result_2 = method_2(feature_map, filters)
end.record()
torch.cuda.synchronize()
print('method_2 time:', start.elapsed_time(end))

print("~ same result?:", ((result_1 - result_2).abs() < 10e-5).all().item())

method_1 time: 26.50726318359375
method_2 time: 2.617151975631714
~ same result?: True
