In [8]:
import math
from math import log

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

In [6]:
def compute_distances(xe, ye, I, train=True):
    r"""
    Evaluates, for every query item, the distance to each of its candidate
    neighbors and returns the values with flipped sign.

    :param xe: BxNxE tensor of database item embeddings
    :param ye: BxMxE tensor of query item embeddings
    :param I: BxMxO index tensor that selects O potential neighbors for each item in ye
    :param train: if True, a dense BxMxN matrix is requested from
        ``ops.euclidean_distance`` and the candidates are gathered from it;
        if False, only the indexed pairs are evaluated through
        ``ops.indexed_matmul_1_efficient`` (forward only)
    :return: a BxMxO tensor of negated distances
    """
    # xe -> b n e,  ye -> b m e,  I -> b m o
    batch, _, _ = xe.shape
    num_query = ye.shape[1]
    num_cand = I.shape[2]

    if train:
        # dense -> b m n, candidates picked along the last axis -> b m o
        dense = ops.euclidean_distance(ye, xe.permute(0, 2, 1))
        return -dense.gather(dim=2, index=I)

    # Inference: assemble |x|^2 - 2<x, y> + |y|^2 for the indexed pairs only.
    # NOTE(review): this reading assumes indexed_matmul_1_efficient returns the
    # inner product of each query with its selected database items -- confirm.
    flat_idx = I.view(batch, num_query * num_cand, 1)
    ye_col = ye.unsqueeze(3)                                   # b m e 1

    # cross term -> b m o 1
    dist = -2 * ops.indexed_matmul_1_efficient(xe, ye, I).unsqueeze(3)

    # squared norms of the database items, picked per candidate -> b m o 1
    xe_sq = (xe ** 2).sum(dim=-1, keepdim=True)
    xe_sq_sel = xe_sq.gather(dim=1, index=flat_idx).view(batch, num_query, num_cand, 1)

    # squared norms of the queries -> b m 1 1, broadcast over the candidates
    ye_sq = (ye_col ** 2).sum(dim=-2, keepdim=True)

    dist.add_(xe_sq_sel).add_(ye_sq)
    return -dist.squeeze(3)

In [10]:
def aggregate_output(W, x, I, train=True):
    r"""
    Calculates weighted averages for k nearest neighbor volumes.

    The computation is delegated entirely to ``ops.indexed_matmul_2_efficient``.

    :param W: BxMxOxK matrix of weights
    :param x: BxNxF tensor of database items
    :param I: BxMxO index tensor that selects O potential neighbors for each query item
    :param train: accepted so the signature parallels `compute_distances`;
        it is currently ignored, the same code path serves training and inference
    :return: a BxMxFxK tensor of the k nearest neighbor volumes for each query item
    """
    # W -> b m o k
    # x -> b n f
    # I -> b m o
    return ops.indexed_matmul_2_efficient(x, W, I)

In [7]:
class N3AggregationBase(nn.Module):
    r"""
    Domain agnostic base class for computing neural nearest neighbors
    """
    def __init__(self, k, temp_opt=None):
        r"""
        :param k: Number of neighbor volumes to compute
        :param temp_opt: options for handling temperatures, see `NeuralNearestNeighbors`;
            when omitted, a new empty dict is created for this instance
        """
        super(N3AggregationBase, self).__init__()
        self.k = k
        # Build the default per instance: a `{}` default in the signature would
        # be one dict object shared by every instance constructed without it.
        if temp_opt is None:
            temp_opt = {}
        self.nnn = NeuralNearestNeighbors(k, temp_opt=temp_opt)

    def forward(self, x, xe, ye, I, log_temp=None):
        r"""
        :param x: database items, shape BxNxF
        :param xe: embedding of database items, shape BxNxE
        :param ye: embedding of query items, shape BxMxE
        :param I: Indexing tensor defining O potential neighbors for each query item
            shape BxMxO
        :param log_temp: optional log temperature, passed through to the
            `NeuralNearestNeighbors` module
        :return: the aggregated neighbor volumes, shape BxMxFxK
        """

        # x  -> b n f
        # xe -> b n e
        # ye -> b m e
        # I  -> b m o
        b, n, f = x.shape
        m, e = ye.shape[1:]
        o = I.shape[2]
        k = self.k

        # All inputs must agree on batch size, item counts and embedding size.
        assert (b, n, e) == xe.shape
        assert (b, m, e) == ye.shape
        assert (b, m, o) == I.shape

        # compute distance
        D = compute_distances(xe, ye, I, train=self.training)
        assert (b, m, o) == D.shape

        # compute aggregation weights
        W = self.nnn(D, log_temp=log_temp)
        assert (b, m, o, k) == W.shape

        # aggregate output
        z = aggregate_output(W, x, I, train=self.training)
        assert (b, m, f, k) == z.shape

        return z

In [5]:
class N3Aggregation2D(nn.Module):
    r"""
    Computes neural nearest neighbors for image data based on extracting patches
    in strides.
    """
    def __init__(self, indexing, k, patchsize, stride, temp_opt=None, padding=None):
        r"""
        :param indexing: function for creating index tensor; called as
            ``indexing(xe_patch, ye_patch)``
        :param k: number of neighbor volumes; for ``k <= 0`` no aggregation module
            is created and `forward` passes its input through
        :param patchsize: size of patches that are matched
        :param stride: stride with which patches are extracted
        :param temp_opt: options for handling temperatures, see `NeuralNearestNeighbors`.
            When omitted, a new empty dict is created for this instance. Note that
            `forward` reads ``temp_opt["avgpool"]`` whenever a `log_temp` image is
            supplied, so that key must be present in that case.
        :param padding: padding argument handed to ``ops.im2patch`` for the
            embedding and log-temperature images
        """
        super(N3Aggregation2D, self).__init__()
        # Build the default per instance: a `{}` default in the signature would
        # be one dict object stored on (and shared by) every instance.
        if temp_opt is None:
            temp_opt = {}
        self.patchsize = patchsize
        self.stride = stride
        self.indexing = indexing
        self.k = k
        self.temp_opt = temp_opt
        self.padding = padding
        if k <= 0:
            self.aggregation = None
        else:
            self.aggregation = N3AggregationBase(k, temp_opt=temp_opt)

    def forward(self, x, xe, ye, y=None, log_temp=None):
        r"""
        :param x: database image
        :param xe: embedding of database image
        :param ye: embedding of query image
        :param y: query image, if None then y=x is assumed
        :param log_temp: optional log temperature image
        :return: for ``k <= 0`` the query image (`y`, or `x` if `y` is None) unchanged;
            otherwise `y` concatenated along the channel axis with the k neighbor
            volumes minus `y`, i.e. a Bx((k+1)*C)xHxW tensor
        """
        if self.aggregation is None:
            return y if y is not None else x

        # Convert everything to patches.
        # The database image is extracted with padding=None and the padding that
        # im2patch reports back is reused for patch2im below; the embeddings use
        # the padding configured on the module.
        x_patch, padding = ops.im2patch(x, self.patchsize, self.stride, None, returnpadding=True)
        xe_patch = ops.im2patch(xe, self.patchsize, self.stride, self.padding)
        if y is None:
            y = x
            ye_patch = xe_patch
        else:
            ye_patch = ops.im2patch(ye, self.patchsize, self.stride, self.padding)

        I = self.indexing(xe_patch, ye_patch)
        if not self.training:
            # index_neighbours_cache is defined outside this class; it is emptied
            # after every indexing call made in evaluation mode.
            index_neighbours_cache.clear()

        b, c, p1, p2, n1, n2 = x_patch.shape
        _, ce, e1, e2, m1, m2 = ye_patch.shape
        # o is not used below; the unpack doubles as a check that I is 3-dimensional.
        _, _, o = I.shape
        k = self.k
        _, _, H, W = y.shape
        n = n1 * n2
        m = m1 * m2
        f = c * p1 * p2
        e = ce * e1 * e2

        # Flatten the patch grids: one row per patch, one column per patch entry.
        x_patch = x_patch.permute(0, 4, 5, 1, 2, 3).contiguous().view(b, n, f)
        xe_patch = xe_patch.permute(0, 4, 5, 1, 2, 3).contiguous().view(b, n, e)
        ye_patch = ye_patch.permute(0, 4, 5, 1, 2, 3).contiguous().view(b, m, e)

        if log_temp is not None:
            log_temp_patch = ops.im2patch(log_temp, self.patchsize, self.stride, self.padding)
            log_temp_patch = log_temp_patch.permute(0, 4, 5, 2, 3, 1).contiguous().view(
                b, m, self.patchsize**2, log_temp.shape[1])
            if self.temp_opt["avgpool"]:
                # one temperature per patch: average over the patch positions
                log_temp_patch = log_temp_patch.mean(dim=2)
            else:
                # one temperature per patch: take the middle patch position
                log_temp_patch = log_temp_patch[:, :, log_temp_patch.shape[2] // 2, :].contiguous()
        else:
            log_temp_patch = None

        # Get nearest neighbor volumes
        # z  -> b m1*m2 c*p1*p2 k
        z_patch = self.aggregation(x_patch, xe_patch, ye_patch, I, log_temp=log_temp_patch)
        z_patch = z_patch.permute(0, 1, 3, 2).contiguous().view(b, m1, m2, k * c, p1, p2)
        z_patch = z_patch.permute(0, 3, 4, 5, 1, 2).contiguous()

        # Convert patches back to whole images
        z = ops.patch2im(z_patch, self.patchsize, self.stride, padding)

        # Subtract the query image from each of the k neighbor volumes.
        z = z.contiguous().view(b, k, c, H, W)
        z = z - y.view(b, 1, c, H, W)
        z = z.view(b, k * c, H, W)

        # Concat with input
        z = torch.cat([y, z], dim=1)

        return z