In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch import unsqueeze, where, matmul, sum, repeat_interleave, sqrt, topk, flip, cat
from torch.utils.tensorboard import SummaryWriter

import numpy as np

from dataprocess.cic_ids_2017 import CIC_IDS_2107_DataLoader
from net.linear import linear_3

In [2]:
def topk_with_distance(src, dis, k):
    """
    Find, for every row of `src`, its k nearest rows of `dis` under Euclidean distance.

    params:
        src: (N, C) query points
        dis: (M, C) reference points
        k: number of neighbours to select from dis (k <= M)
    return:
        indices: (N, k) row indices into dis, nearest first
        distance: (N, k) Euclidean distances matching `indices`, ascending
    """
    src_sq = (src ** 2).sum(dim=-1, keepdim=True)  # (N, 1)
    dis_sq = (dis ** 2).sum(dim=-1).unsqueeze(0)  # (1, M)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b ; broadcasting yields (N, M) without explicit copies.
    sq_distance = src_sq + dis_sq - 2 * torch.matmul(src, dis.transpose(0, 1))
    # Cancellation can leave tiny negative values; clamp so sqrt never produces NaN.
    sq_distance = sq_distance.clamp_min(0)
    # largest=False selects the k *smallest* distances; topk returns them sorted ascending.
    sq_distance, indices = torch.topk(sq_distance, k, dim=1, largest=False)
    return indices, sq_distance.sqrt()

In [3]:
def nonconformity_measure(train_index, train_distance, train_label):
    """
    Distance-based nonconformity of each query against every candidate label.

    For query i and label j, the distances to those of its k neighbours whose
    label is NOT j are summed, and the reciprocal of that sum is returned.

    params:
        train_index: (N, k) indices of each query's neighbours into train_label
        train_distance: (N, k) distances to those neighbours
        train_label: (M, L) one-hot labels (bool, integer or floating point)
    return:
        nonconformity: (N, L)
    """
    N, k = train_distance.shape
    _, L = train_label.shape
    neighbour_labels = train_label[train_index]  # (N, k, L)
    # `== 0` marks "neighbour does not carry label j" for bool, int and float one-hot
    # encodings alike; `~` is only a logical NOT for bool (bitwise NOT on ints gives -1/-2).
    disagree = (neighbour_labels == 0).to(train_distance.dtype)  # (N, k, L)
    # (N, 1, k) @ (N, k, L) -> (N, 1, L): summed distance of disagreeing neighbours per label.
    nonconformity = torch.matmul(train_distance.reshape(N, 1, k), disagree).reshape(N, L)
    # NOTE(review): when all k neighbours carry label j the sum is 0 and this yields inf;
    # confirm whether sum(1 / d) over disagreeing neighbours was intended instead.
    return nonconformity ** (-1)

In [4]:
class kNN(nn.Module):
    """
    Per-layer kNN nonconformity scorer.

    params:
        k: number of neighbours selected by topk_with_distance for each query
    inputs:
        feature: (M, C) query features
        train_feature: (N, C) reference (training) features
        train_label: (N, L) one-hot labels of the reference points
    return:
        nonconformity: (M, L), one score per query and candidate label
    """
    def __init__(self, k):
        super().__init__()
        self.k = k

    def forward(self, feature, train_feature, train_label):
        neighbour_idx, neighbour_dist = topk_with_distance(feature, train_feature, self.k)  # (M, k) each
        return nonconformity_measure(neighbour_idx, neighbour_dist, train_label)  # (M, L)

In [10]:
class DkNN(nn.Module):
    """
    Deep-kNN aggregation: turns per-layer nonconformity into an empirical p-value per label.

    For each sample and candidate label, the nonconformity is summed over the F layers;
    the output is the fraction of calibration scores that are >= that sum.

    inputs:
        nonconformity: (N, F, L) per-sample, per-layer, per-label nonconformity
        cali_nonconformity: (M,) nonconformity scores of a calibration set
    return:
        conformity: (N, L), values in [0, 1]
    """
    def __init__(self):
        # Zero-arg super(): DkNN is not a kNN subclass, so super(kNN, self) raises TypeError.
        super().__init__()

    def forward(self, nonconformity, cali_nonconformity):
        M, = cali_nonconformity.shape  # also enforces 1-D calibration scores
        if nonconformity.dim() != 3:
            raise ValueError("nonconformity must have shape (N, F, L)")
        total = nonconformity.sum(dim=1)  # (N, L)
        # Broadcast (N, L, 1) against (M,) instead of materialising an (N, L, M) copy.
        at_least_as_extreme = total.unsqueeze(-1) <= cali_nonconformity  # (N, L, M)
        return at_least_as_extreme.sum(dim=-1) / M

In [11]:
class DkNN_linear(nn.Module):
    """
    Deep-kNN model on top of a three-output linear backbone.

    The backbone (`linear_3`) yields three intermediate representations; for each one a
    kNN nonconformity is computed against the matching stored training features, and the
    stacked per-layer scores are passed to `DkNN`.

    params:
        input_channel: input dimension forwarded to linear_3
        output_channel: output dimension forwarded to linear_3
        k: number of neighbours for the kNN stage
    inputs:
        input_data: (N, C)
        train_feature_list: indexable; entry i holds training features for backbone output i
        train_label: (M, L)
        cali_nonconformity: 1-D calibration scores forwarded to DkNN
    return:
        whatever DkNN returns for the (N, 3, L) stacked nonconformity
    """
    def __init__(self, input_channel, output_channel, k):
        super().__init__()
        self.linear = linear_3(input_channel, output_channel)
        self.knn = kNN(k)
        self.DkNN = DkNN()

    def forward(self, input_data, train_feature_list, train_label, cali_nonconformity):
        n_samples, _ = input_data.shape
        _, n_labels = train_label.shape
        first, second, third = self.linear(input_data)  # backbone must emit exactly three outputs
        per_layer = [
            self.knn(layer_out, train_feature_list[layer_idx], train_label).reshape(n_samples, 1, n_labels)
            for layer_idx, layer_out in enumerate((first, second, third))
        ]
        stacked = cat(per_layer, dim=1)  # (N, 3, L)
        return self.DkNN(stacked, cali_nonconformity)