In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F



In [2]:
# Hyperparameters used by FeatureExtraction:
#   alpha - numerator scale in g = tanh(alpha / (distance + beta))
#   beta  - additive offset in that denominator
#   s     - number of displacement branches (input width of fc_sigma)
alpha, beta, s = 2, 3, 3


In [3]:
def knn(x,k):
    """Top-``k`` negated squared Euclidean distances for every point.

    Args:
        x: point features laid out as ``[B, Dim, N]`` (feature axis first).
        k: number of neighbours to keep per point (must be <= N).

    Returns:
        Tensor of shape ``[B, N, k]``. Entry ``[b, i, j]`` is
        ``-||x_i - x_n||^2`` for the ``j``-th closest point ``n``. Values are
        sorted descending, so the closest point comes first. A point's distance
        to itself is ~0, so column 0 is normally the point itself. Only the
        distance values are returned, not neighbour indices.
    """
    points_first = x.transpose(2, 1).contiguous()        # [B, N, Dim]
    gram = torch.matmul(points_first, x)                 # [B, N, N] pairwise dot products
    sq_norm = torch.sum(x ** 2, dim=1, keepdim=True)     # [B, 1, N]
    # -(|xi|^2 - 2<xi,xj> + |xj|^2) == -||xi - xj||^2
    neg_sq_dist = 2 * gram - sq_norm - sq_norm.transpose(2, 1)
    top_values, _ = neg_sq_dist.topk(k=k, dim=-1)
    return top_values



In [45]:
class FeatureExtraction(nn.Module):
    """Per-point scalar feature built from learned displacement branches.

    For each branch ``i`` a learned constant vector ``delta_i`` is added to
    every point. ``delta_i`` is ``relu(fc_i(ones))``, i.e. a ``Linear(1, d_in)``
    applied to a ones tensor. Next, the squared distance from each shifted point
    to its nearest other shifted point is measured. That distance is mapped
    through ``tanh(alpha / (dist + beta))``. The branch scores are fused by
    ``fc_sigma`` into ``g``, and a per-point term ``h = relu(fc_h(input))`` is
    added. The resulting scalar is tiled ``d_out`` times.

    Args:
        d_in: feature width of the input points.
        d_out: number of times the per-point scalar is repeated in the output.
        num_shifts: number of displacement branches. Defaults to the
            module-level hyperparameter ``s`` (the original, fixed behaviour).
    """

    def __init__(self, d_in, d_out, num_shifts=None):
        super(FeatureExtraction, self).__init__()

        self.d_in = d_in
        self.d_out = d_out
        self.num_shifts = s if num_shifts is None else num_shifts

        # One displacement layer per branch. They are registered as
        # fc1..fc{num_shifts}, so with three branches the parameter names and
        # initialisation order match the original hand-written fc1/fc2/fc3.
        for i in range(1, self.num_shifts + 1):
            self.add_module(f"fc{i}", nn.Linear(1, d_in))

        # Last layer for learning sigma: fuses the per-branch scores.
        self.fc_sigma = nn.Linear(self.num_shifts, 1)

        # For calculating function h.
        self.fc_h = nn.Linear(d_in, 1)

    def forward(self, input):
        """Compute the tiled per-point feature and the fused score ``g``.

        Args:
            input: point features of shape ``[B, N, d_in]`` (N >= 2).

        Returns:
            ``(output_feature, g)``: ``output_feature`` has shape
            ``[B, N, d_out]`` and ``g`` has shape ``[B, N, 1]``.
        """
        B, N, _ = input.shape

        # Match the input's device and dtype (torch.ones would always be CPU float32).
        all_one = input.new_ones(B, N, 1)

        branch_scores = []
        for i in range(1, self.num_shifts + 1):
            delta = F.relu(getattr(self, f"fc{i}")(all_one))
            shifted = input + delta
            # knn returns *negated* squared distances, sorted descending. Index 0
            # is (normally) the point itself and index 1 is the nearest other
            # point. Negate to get a non-negative squared distance; the clamp
            # absorbs rounding error. This keeps the denominator >= beta > 0.
            # The old code used -d^2 directly, so (d^2 == beta) divided by zero.
            nearest_sq_dist = (-knn(shifted.transpose(1, 2), 2)[:, :, 1:2]).clamp(min=0)
            branch_scores.append(torch.tanh(alpha / (nearest_sq_dist + beta)))

        # Concatenate the branch scores and fuse them into one score per point.
        g = torch.cat(branch_scores, dim=2)
        g = F.relu(self.fc_sigma(g))

        h = F.relu(self.fc_h(input))

        output_feature = (g + h).repeat(1, 1, self.d_out)

        return output_feature, g
    


In [46]:
class NeighborPooling(nn.Module):
    """Keep the ``fraction`` points with the largest ``tanh(|g|)`` score per cloud.

    Despite its name, ``fraction`` is an absolute point count (the ``k`` passed
    to ``topk``), not a ratio.

    Args:
        fraction: number of points kept per batch element (must be <= N).
    """

    def __init__(self, fraction):
        super(NeighborPooling, self).__init__()

        self.fraction = fraction

    def forward(self, output_feature, g):
        """Gather the top-scoring points of every cloud.

        Args:
            output_feature: per-point features of shape ``[B, N, d]``.
            g: per-point scores of shape ``[B, N, 1]``.

        Returns:
            Tensor of shape ``[B, fraction, d]`` holding each cloud's selected
            points, ordered by descending score.
        """
        d = output_feature.shape[-1]

        score = torch.tanh(torch.abs(g))

        # Indices are local to each cloud: [B, fraction, 1]. Gathering along
        # the point axis keeps each batch element's selection inside its own
        # points. The previous flattened scheme offset batch b by
        # b * fraction instead of b * N, so it read the wrong cloud whenever
        # N != fraction. Gathering also keeps everything on the input's device.
        top_idx = score.topk(k=self.fraction, dim=1)[1]

        return output_feature.gather(1, top_idx.expand(-1, -1, d))
    
        


In [74]:
class Encoder(nn.Module):
    """Four FeatureExtraction stages with NeighborPooling between them, then a max.

    Stage 1 maps ``d_in`` -> ``d_out``; stages 2-4 map ``d_out`` -> ``d_out``.
    After each of the first three stages, ``NeighborPooling(fraction)`` keeps
    ``fraction`` points. The last stage's features are max-reduced over the
    point axis, giving an output of shape ``[B, 1, d_out]``.

    Args:
        d_in: width of the input point features (input is ``[B, N, d_in]``).
        d_out: width produced by every FeatureExtraction stage.
        fraction: point count kept by each NeighborPooling stage.
    """

    def __init__(self, d_in, d_out, fraction):
        super(Encoder, self).__init__()

        self.d_in = d_in
        self.d_out = d_out
        self.fraction = fraction

        # Submodules are registered in the order feature1, neighborpooling1,
        # feature2, neighborpooling2, feature3, neighborpooling3, feature4.
        # This fixes both parameter-initialisation order and state_dict keys.
        for stage, width in enumerate((d_in, d_out, d_out, d_out), start=1):
            setattr(self, f"feature{stage}", FeatureExtraction(width, d_out))
            if stage < 4:
                setattr(self, f"neighborpooling{stage}", NeighborPooling(fraction))

    def forward(self, x):
        current = x
        for stage in range(1, 4):
            feature, g = getattr(self, f"feature{stage}")(current)
            current = getattr(self, f"neighborpooling{stage}")(feature, g)

        last_feature, _ = self.feature4(current)

        # Global max over the point axis -> one descriptor per cloud.
        return last_feature.max(dim=1, keepdim=True).values


In [85]:
class Upsample(nn.Module):
    """Tile a point set along the point axis.

    The whole ``[B, N, D]`` tensor is repeated ``upsampling_factor`` times
    along dim 1. The copies are concatenated, not interleaved, giving
    ``[B, N * upsampling_factor, D]``. The module has no learnable parameters.

    Args:
        upsampling_factor: number of copies (default 3, the original value).
    """

    def __init__(self, upsampling_factor=3):
        super(Upsample, self).__init__()

        self.upsampling_factor = upsampling_factor

    def forward(self, input):
        # Keep the original contract: only 3-D [B, N, D] input is accepted.
        if input.dim() != 3:
            raise ValueError(
                f"Upsample expects a 3-D [B, N, D] tensor, got shape {tuple(input.shape)}"
            )
        return input.repeat(1, self.upsampling_factor, 1)
        

In [86]:
class Decoder(nn.Module):
    """Repeatedly tile the encoded point set with a shared ``Upsample``.

    Each stage multiplies the point axis by the Upsample factor (3 by
    default). With the default of 7 stages, a ``[B, 1, D]`` encoding becomes
    ``[B, 3**7, D] = [B, 2187, D]``.

    Args:
        num_stages: how many times ``Upsample`` is applied (default 7, the
            original hard-coded count).
    """

    def __init__(self, num_stages=7):
        super(Decoder, self).__init__()

        self.num_stages = num_stages
        self.upsml = Upsample()

    def forward(self, input):
        out = input
        for _ in range(self.num_stages):
            out = self.upsml(out)
        return out
    
    

In [110]:
class Model(nn.Module):
    """Point-cloud autoencoder: an ``Encoder`` followed by a ``Decoder``.

    Args:
        d_in: width of the input point features.
        d_out: feature width produced by the encoder stages.
        fraction: point count kept by each NeighborPooling stage.
    """

    def __init__(self, d_in, d_out, fraction):
        super(Model, self).__init__()

        self.d_in, self.d_out, self.fraction = d_in, d_out, fraction

        # Encoder is built first, then Decoder (fixes parameter-init order).
        self.encoder = Encoder(d_in, d_out, fraction)
        self.decoder = Decoder()

    def forward(self, input):
        encoded = self.encoder(input)
        return self.decoder(encoded)
       

In [115]:
# Smoke test: 3 clouds of 10 points, each with 3 features.
# Seeded so that Restart & Run All reproduces the same weights and input.
torch.manual_seed(0)

t=torch.randn(3,10,3)
m=Model(3,3,5)
a=m(t)

# The encoder reduces each cloud to a single point; the decoder then tiles it 3**7 times.
assert a.shape == (3, 3 ** 7, 3), a.shape
a.shape





torch.Size([3, 2187, 3])
