In [11]:
import torch
import torch.nn as nn

from torch import unsqueeze, where, matmul, sum, repeat_interleave, sqrt, topk, flip, index_select

import numpy as np

In [12]:
def topk_with_distance(src, dis, k):
    """
    find for every data in src its k-nearest neighbor from dis with index and distance 
    params:
        src: (N, C) 
        dis: (M, C)
        k: numbers to select from dis
    return:
        indices: (N, k)
        distance: (N, k) 
    """
    N, _ =src.shape
    M, _ = dis.shape
    src_sqare = repeat_interleave(sum(src ** 2, -1).reshape(N,1), M, 1, output_size=M) # (N, M)
    dis_sqare = repeat_interleave(sum(dis ** 2, -1).reshape(1,M), N, 0, output_size=N) # (N, M)
    src_ids = matmul(src, dis.permute(1,0)) # (N, M)
    distance = src_sqare + dis_sqare - 2 * src_ids
    distance, indices = topk(distance, k, 1) # (N, k)
    distance = sqrt(distance)

    return flip(indices,dims=[1]), flip(distance, dims=[1])

In [27]:
def nonconformity_measure(train_index, train_distance, train_label):
    """
    params:
        train_index: (N, k)
        train_distance: (N, k)
        label: (M, L)
    return:
        nonconformity: (N, L) 
    """
    N, k = train_distance.shape
    M, L = train_label.shape
    # train_distance = repeat_interleave(train_distance.reshape(N, 1, k), L, 1, output_size=L) # (N, L, k)
    train_distance = train_distance.reshape(N, 1, k)
    labels = train_label[train_index] # (N, k, L)
    labels = ~labels
    nonconformity = matmul(train_distance, labels.float()) # (N, 1, L)
    nonconformity = nonconformity.reshape(N, L)
    nonconformity = nonconformity ** (-1)
    return nonconformity

In [28]:
class kNN(nn.Module):
    """
    params:
        input_channel:
        output_channel:
    inputs:
        feature: (M, C)
        train_feature: (N, C)
        train_label: (N, L), one-hot code
    return:
        nonconformity: (M, L)
    """
    def __init__(self, input_channel, output_channel, k):
        super(kNN, self).__init__()
        self.k = k
    
    def forward(self, feature, train_feature, train_label):
        index, distance = topk_with_distance(feature, train_feature, self.k) # (M, k), (M, k)
        nonconformity = nonconformity_measure(index, distance, train_label) # (M, L)
        return nonconformity

In [29]:
feature = torch.tensor(np.random.randint(0,16,size=(200,128))).float()
train_feature = torch.tensor(np.random.randint(0,16,size=(1000,128))).float()
train_label = torch.tensor(np.random.randint(0,2,size=(1000,10))).bool()

In [30]:
knn = kNN(87,128,100)

In [31]:
nonconformity = knn(feature, train_feature, train_label)

In [32]:
print(nonconformity.shape)

torch.Size([200, 10])


In [5]:
N, M = 20000, 5000
nonconformity = torch.Tensor(np.random.random(size=(N,))*3.4E+38)
standard_nonconformity = torch.Tensor(np.random.random(size=(M,))*3.4E+38)

In [6]:
nonconformity = unsqueeze(nonconformity, -1)
nonconformity = nonconformity.repeat(1, M)
print(nonconformity.shape)

torch.Size([20000, 5000])


In [52]:
result = nonconformity - standard_nonconformity
print(result.shape)

torch.Size([20000, 5000])


In [48]:
ones_array = torch.ones(result.shape)
zeros_array = torch.zeros(result.shape)
_result = where(result < 0,ones_array,zeros_array)

In [37]:
result = _result.sum(-1, keepdim=False)
print(result.shape)

torch.Size([20000])


In [38]:
result = result / M
print(result[:20])

tensor([0.5904, 0.8680, 0.3888, 0.2612, 0.4950, 0.0360, 0.8210, 0.2072, 0.8244,
        0.3100, 0.5700, 0.9140, 0.2904, 0.3008, 0.0592, 0.5360, 0.1212, 0.2880,
        0.4150, 0.8034])


In [39]:
result.requires_grad(True)

TypeError: 'bool' object is not callable