In [30]:
# kernels
from torch_geometric.nn import MessagePassing
from torch_geometric.data import Data
from torch_geometric.utils import degree

import torch
from torch.nn import CosineSimilarity, Module, ModuleList, Linear, Sigmoid
from torch.nn.parameter import Parameter

from itertools import permutations
import math


# kernel = Data(p_s = p_s)
# same_kernel = Data(p_s = p)
# sig = Sigmoid()



class Kernel(Module):
    def __init__(self, D, s, d, k, init_kernel:'type(Data)'=None):
        super(Kernel, self).__init__()
        if init_kernel is None:
            
            if s is None or d is None:
                raise Exception('s or d is not specified')
            else:
                init_kernel = Data(x_c = torch.randn(1, d),  x_s = torch.randn(s,d), edge_attr_s = torch.randn(s, k), p_s = torch.randn(s,D))

                
        x_c_tensor = init_kernel.x_c
        self.x_c = Parameter(x_c_tensor)

        x_s_tensor = init_kernel.x_s
        self.x_s = Parameter(x_s_tensor)
        
        edge_attr_s_tensor = init_kernel.edge_attr_s
        self.edge_attr_s = Parameter(edge_attr_s_tensor)
        
        p_s_tensor = init_kernel.p_s
        self.p_s = Parameter(p_s_tensor)
        

    def permute(self, x):
        rows = x.shape[0]
        l = [x[torch.tensor(permute)] for permute in list(permutations(range(rows)))]
        return l        
        
    def intra_angle(self,p:'shape(s,D)'):
        '''
        angles between each row vectors
        ''' 
        cos = CosineSimilarity(dim = 1)
        new_p = torch.roll(p,1,0)
        sc:'shape(s)' = cos(new_p, p)
        
        return sc  
    
    def get_angle_score(self, p, p_s):
        cos = CosineSimilarity(dim = 0)
#         print('p_dist:')
        inter_p_dist = self.intra_angle(p)
#         print('\n')
#         print('s_dist:')
        inter_p_s_dist = self.intra_angle(p_s)
#         print(f'inter_p_s_dist:{inter_p_s_dist}')
        sc = cos(inter_p_dist, inter_p_s_dist)
#         sc = torch.dot(inter_p_dist, inter_p_s_dist.T)
#         print(sc)
        return sc

    def get_length_score(self, p, p_s):
        len_p = torch.norm(p, dim = -1)
        len_p_s = torch.norm(p_s, dim = -1)
    
        # inverse of L2 norm
        diff = torch.square(len_p - len_p_s)
        sc = torch.sum(diff)
        sc = torch.atan(1/sc)
    
#         print(sc)
        return sc
    
    def get_support_attribute_score(self, x_nei, x_s:'shape(s, d)'):

        # inverse of L2 norm
        diff = torch.square(x_nei - x_s)
        sc = torch.sum(diff)
        sc:'shape([])' = torch.atan(1/sc)
        return sc
    
    def get_center_attribute_score(self, x_ori, x_c):
            pass
    
    def get_edge_attribute_score(self, edge_attr_nei, edge_attr_s):
            pass
    
    def forward(self, **kwargv):
        
        if len(kwargv) == 1:
            x = kwargv['data'].x
            edge_index = kwargv['data'].edge_index
            edge_attr = kwargv['data'].edge_attr
            p = kwargv['data'].p            
        else:
            x = kwargv['x']
            edge_index = kwargv['edge_index']
            edge_attr = kwargv['edge_attr']
            p = kwargv['p']
            
        x_c = self.x_c
        x_s = self.x_s
        edge_attr_s = self.edge_attr_s
        p_s = self.p_s            

              
        p_s_list = self.permute(p_s)
        x_s_list = self.permute(x_s)
     
        max_angle_sc = torch.tensor(-float('Inf'))
        best_angle_sc_index = 0
        for i, each_p_s in enumerate(p_s_list):
            angle_sc = self.get_angle_score(p, each_p_s)

            if angle_sc > max_angle_sc:
                best_angle_sc_index = i
                max_angle_sc = angle_sc
        
        # get the p and p_s combination that gives the best angle score
        best_p_s = p_s_list[best_angle_sc_index]
        
        # calculate the length score for the best combination
        length_sc = self.get_length_score(p, best_p_s)

        # calculate the support attribute score for the best combination
        best_x_s = x_s_list[best_angle_sc_index]
        supp_attr_sc = self.get_support_attribute_score(x, best_x_s)
        
        length_sc = length_sc.unsqueeze(dim = 0)
        max_angle_sc = max_angle_sc.unsqueeze(dim = 0)
        supp_attr_sc = supp_attr_sc.unsqueeze(dim=0)
        
        
        max_atan = torch.tensor([math.pi/2]) # the maxium value a arctain function can get
        one = torch.tensor([1])
        
        sc = torch.atan(1/(torch.square(length_sc-max_atan) + torch.square(max_angle_sc-one)+ torch.square(supp_attr_sc-max_atan)))
        
        print('\n')
        print(f'len sc:{length_sc.item()}')
        print(f'angle sc:{max_angle_sc.item()}')
        print(f'attribute_sc:{supp_attr_sc.item()}')
        print(f'total sc: {sc.item()}')
        return sc
        


    

model = Kernel(D=2, s=3, d=5, k=1)
# model = KernelSet(dim = 2, f_dim = 5)
# model = GraphConv(2, 3, 5)

# for item in model.state_dict():
#     print(item)

    
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
edge_attr =  torch.randn(4,2)
p = torch.randn(3,2)

data = Data(x=x, edge_index=edge_index, edge_attr = edge_attr, p = p)
model(data = data)




len sc:0.6297647356987
angle sc:0.6640481352806091
attribute_sc:0.027843495830893517
total sc: 0.28772425651550293


tensor([0.2877], grad_fn=<AtanBackward>)

In [None]:
from 

In [None]:
class KernelSet(Module):
    def __init__(self, dim, f_dim):
        super(KernelSet, self).__init__()        
         
        self.kernel_set = ModuleList(
            [Kernel(dim=dim, S=1, f_dim=f_dim), 
             Kernel(dim=dim, S=2, f_dim=f_dim), 
             Kernel(dim=dim, S=3, f_dim=f_dim), 
             Kernel(dim=dim, S=4, f_dim=f_dim)])
    def get_degree_index(self, x, edge_index):
        deg = degree(edge_index[0], x.shape[0])
        return deg
    
    def get_neighbor_index(self, edge_index, center_index):
#         print('edge_index')
#         print(edge_index)
#         print('\n')
#         print('center_index')
#         print(center_index)
        a = edge_index[0]
        b = a.unsqueeze(1) == center_index
        c = b.nonzero()
        d = c[:,0] 
        return edge_index[1,d]
    
    def get_focal_nodes_of_degree(self, deg, x, edge_index, p):
        '''
        inputs:
        
        
        outputs
        ori_x: a feature matrix that only contains rows (i.e. the center node) having certain degree
        ori_p: a position matrix that only contains rows (i.e. the center node) having certain degree
        '''
        deg_index  = self.get_degree_index(x, edge_index)
        selected_index = (deg_index == deg).nonzero(as_tuple=True)
        ori_x = torch.index_select(input = x, dim = 0, index = selected_index[0])
        ori_p = torch.index_select(input = p, dim = 0, index = selected_index[0])
        return ori_x, ori_p
    
    def get_neighbor_nodes_of_degree(self, deg, x, edge_index, p):
        '''
        outputs:
        nei_x: a feature matrix that only contains rows (i.e. the neighboring node) that its center node has certain degree
        nei_p: a position matrix that only contains rows (i.e. the neighboring node) that its center node has certain degree
        '''
        deg_index = self.get_degree_index(x, edge_index)
        center_index = (deg_index == deg).nonzero(as_tuple=True)[0]
        nei_index = self.get_neighbor_index(edge_index, center_index)
        
#         print(f'nei_index:{nei_index}')
        nei_x = torch.index_select(x, 0, nei_index)
        nei_p = torch.index_select(p, 0, nei_index)
        return nei_x, nei_p 
        
    def forward(self, data):
        '''
        inputs:
        data: graph data containing feature matrix (i.e. x), adjacency matrix (i.e. edge_index),, edge_attr matrix (i.e. edge_attr)
        '''
        x = data.x
        edge_index = data.edge_index
        edge_attr = data.edge_attr
        p = data.p
        
        
        deg = 2
        ori_x, ori_p = self.get_focal_nodes_of_degree(deg = deg, x = x, edge_index = edge_index, p=p)
#         print(ori_x)
#         print(ori_p)
        nei_x, nei_p = self.get_neighbor_nodes_of_degree(deg = deg, x = x, edge_index = edge_index, p=p)
        
           
        return ori_x, ori_p
        
#         output = torch.cat([x(data) for x in ModuleList], dim=0)
 
class GraphConv(MessagePassing):
    def __init__(self, dim, L, f_dim):
        '''
        inputs:
        dim: 2d or 3d
        L: number of KernelSet
        f_dim: feature dimension
        '''
        super(GraphConv, self).__init__()
        self.kernelset_list = ModuleList()
        for l in range(L):
            self.kernelset_list.append(KernelSet(dim, f_dim))
    
    
    def forward(self, data):
        x = data.x
        edge_index = data.edge_index
        edge_attr = data.edge_attr
        
        
    def message(self):
        pass

In [None]:
# kernels
from torch_geometric.nn import MessagePassing
from torch_geometric.data import Data
from torch_geometric.utils import degree

import torch
from torch.nn import CosineSimilarity, Module, ModuleList, Linear, Sigmoid
from torch.nn.parameter import Parameter

from itertools import permutations
import math


# kernel = Data(p_s = p_s)
# same_kernel = Data(p_s = p)
# sig = Sigmoid()



class Kernel(Module):
    def __init__(self, D, s, d, k, init_kernel=None):
        super(Kernel, self).__init__()
        if init_kernel is None:
            
            if s is None or d is None:
                raise Exception('s or d is not specified')
            else:
#                 print('randomizing kernel')
                init_kernel = Data(x_c = torch.randn(1, d),  x_s = torch.randn(s,d), e_s = torch.randn(s, k), p_s = torch.randn(s,D))

        
        p_s_tensor = init_kernel.p_s#torch.tensor([[1,0],[-math.sqrt(0.5),-math.sqrt(0.5)],[0,1]], dtype = torch.float)
        self.p_s = Parameter(p_s_tensor)
        
        x_s_tensor = init_kernel.x_s
        x_s = Parameter(x_s_tensor)
        self.x_s = x_s
        
        x_c_tensor = init_kernel.x_c
        x_c = Parameter(x_c_tensor)
        self.x_c = x_c

        
    def permute(self, x):
        rows = x.shape[0]
        l = [x[torch.tensor(permute)] for permute in list(permutations(range(rows)))]
        return l        
        
    def intra_angle(self,p:'shape(s,D)'):
        '''
        angles between each row vectors
        ''' 
        cos = CosineSimilarity(dim = 1)
        new_p = torch.roll(p,1,0)
        sc:'shape(s)' = cos(new_p, p)
        
        return sc  
    
    def get_angle_score(self, p, p_s):
        cos = CosineSimilarity(dim = 0)
#         print('p_dist:')
        inter_p_dist = self.intra_angle(p)
#         print('\n')
#         print('s_dist:')
        inter_p_s_dist = self.intra_angle(p_s)
#         print(f'inter_p_s_dist:{inter_p_s_dist}')
        sc = cos(inter_p_dist, inter_p_s_dist)
#         sc = torch.dot(inter_p_dist, inter_p_s_dist.T)
#         print(sc)
        return sc

    def get_length_score(self, p, p_s):
        len_p = torch.norm(p, dim = -1)
        len_p_s = torch.norm(p_s, dim = -1)
    
        # inverse of L2 norm
        diff = torch.square(len_p - len_p_s)
        sc = torch.sum(diff)
        sc = torch.atan(1/sc)
    
#         print(sc)
        return sc
    
    def get_support_attribute_score(self, x, x_s:'shape(s, d)'):

        # inverse of L2 norm
        diff = torch.square(x - x_s)
        sc = torch.sum(diff)
        sc:'shape([])' = torch.atan(1/sc)
        return sc
    def get_center_attribute_score(self, x_ori, x_c):
    
    def forward(self, data):

        # data
        p = data.p
        x = data.x
        
        # kernel
        p_s = self.p_s
        x_s = self.x_s
        x_c = self.x_c
              
        p_s_list = self.permute(p_s)
        x_s_list = self.permute(x_s)
     
        max_angle_sc = torch.tensor(-float('Inf'))
        best_angle_sc_index = 0
        for i, each_p_s in enumerate(p_s_list):
            angle_sc = self.get_angle_score(p, each_p_s)

            if angle_sc > max_angle_sc:
                best_angle_sc_index = i
                max_angle_sc = angle_sc
        
        # get the p and p_s combination that gives the best angle score
        best_p_s = p_s_list[best_angle_sc_index]
        
        # calculate the length score for the best combination
        length_sc = self.get_length_score(p, best_p_s)

        # calculate the support attribute score for the best combination
        best_x_s = x_s_list[best_angle_sc_index]
        supp_attr_sc = self.get_support_attribute_score(x, best_x_s)
        
        length_sc = length_sc.unsqueeze(dim = 0)
        max_angle_sc = max_angle_sc.unsqueeze(dim = 0)
        supp_attr_sc = supp_attr_sc.unsqueeze(dim=0)
        
        
        max_atan = torch.tensor([math.pi/2]) # the maxium value a arctain function can get
        one = torch.tensor([1])
        
        sc = torch.atan(1/(torch.square(length_sc-max_atan) + torch.square(max_angle_sc-one)+ torch.square(supp_attr_sc-max_atan)))
        
        print('\n')
        print(f'len sc:{length_sc.item()}')
        print(f'angle sc:{max_angle_sc.item()}')
        print(f'attribute_sc:{supp_attr_sc.item()}')
        print(f'total sc: {sc.item()}')
        return sc
        


    

model = Kernel(D=2, s=3, d=5, k=1)
# model = KernelSet(dim = 2, f_dim = 5)
# model = GraphConv(2, 3, 5)

# for item in model.state_dict():
#     print(item)

    
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
edge_attr =  torch.randn(4,2)
p = torch.randn(3,2)

data = Data(x=x, edge_index=edge_index, edge_attr = edge_attr, p = p)
model(data)
