In [18]:
import torch
from ipynb.fs.full.SpatialConv import SpatialConv
from torch_geometric.nn import ChebConv, GCNConv
import torch.nn as nn
import torch.nn.functional as F

In [19]:
class AGS_layer(torch.nn.Module):
    """Graph layer that mixes three parallel branches with per-node weights.

    The same input features go through three transforms, each followed by a
    ReLU:

    * ``Aconv1`` -- a ``GCNConv``,
    * ``Sconv1`` -- a ``SpatialConv`` (imported from the SpatialConv notebook),
    * ``I1``     -- a plain ``nn.Linear`` that does not look at the edges.

    Every branch output is reduced to one score per row by its own
    ``alpha_*`` linear layer. The three scores are squashed with a sigmoid,
    mixed by the 3x3 linear layer ``w1``, divided by the temperature ``T`` and
    turned into weights with a softmax. The result is ``3 *`` the weighted sum
    of the branch outputs.

    The three ``LayerNorm`` modules and the dropout probability ``p`` are
    created and stored, but ``forward`` does not currently apply them.

    Args:
        input_channels: size of the last dimension of the input features.
        output_channels: size of the last dimension of the output.
        dropout: stored as ``self.p``; not used by ``forward``.
    """

    def __init__(self, input_channels, output_channels, dropout=0.2):
        super().__init__()
        self.T = 3          # temperature: attention logits are divided by T
        self.p = dropout    # kept for callers; forward() applies no dropout

        # Branch transforms (registration order is part of the state_dict).
        self.Aconv1 = GCNConv(input_channels, output_channels)
        self.Sconv1 = SpatialConv(input_channels, output_channels)
        self.I1 = nn.Linear(input_channels, output_channels)

        # Registered so checkpoints keep their keys; unused in forward().
        self.layer_norm_a1 = nn.LayerNorm(output_channels)
        self.layer_norm_s1 = nn.LayerNorm(output_channels)
        self.layer_norm_i1 = nn.LayerNorm(output_channels)

        # One scalar score per row for each branch, then a 3x3 mixing layer.
        self.alpha_a1 = nn.Linear(output_channels, 1)
        self.alpha_s1 = nn.Linear(output_channels, 1)
        self.alpha_i1 = nn.Linear(output_channels, 1)
        self.w1 = nn.Linear(3, 3)

        # reset_parameters() is not called here, so every sub-module keeps
        # the default initialisation done by its own constructor.

    def reset_parameters(self):
        """Re-initialise the attention weights and reset the layer norms.

        Each weight matrix is drawn from ``U(-b, b)`` with
        ``b = 1 / sqrt(fan_in)`` of the layer being initialised
        (``output_channels`` for the ``alpha_*`` layers, 3 for ``w1``).
        The weights of ``Aconv1``, ``Sconv1`` and ``I1`` and all biases are
        left unchanged.
        """
        # weight.size(1) is the fan-in of an nn.Linear (weight is out x in).
        std_att = self.alpha_a1.weight.size(1) ** -0.5
        std_att_vec = self.w1.weight.size(1) ** -0.5

        for scorer in (self.alpha_a1, self.alpha_s1, self.alpha_i1):
            nn.init.uniform_(scorer.weight, -std_att, std_att)
        nn.init.uniform_(self.w1.weight, -std_att_vec, std_att_vec)

        for norm in (self.layer_norm_a1, self.layer_norm_s1, self.layer_norm_i1):
            norm.reset_parameters()

    def forward(self, x0, edge_index, edge_weight=None):
        """Return the attention-weighted mix of the three branch outputs.

        Args:
            x0: 2-D feature tensor; the last dimension must equal
                ``input_channels``.
            edge_index: passed unchanged to ``Aconv1`` and ``Sconv1``.
            edge_weight: optional, passed unchanged to ``Aconv1`` and
                ``Sconv1``.

        Returns:
            Tensor with the same number of rows as ``x0`` and
            ``output_channels`` columns.
        """
        a1 = F.relu(self.Aconv1(x0, edge_index, edge_weight))
        s1 = F.relu(self.Sconv1(x0, edge_index, edge_weight))
        i1 = F.relu(self.I1(x0))

        # One column per branch: [score_a, score_s, score_i].
        scores = torch.cat(
            [self.alpha_a1(a1), self.alpha_s1(s1), self.alpha_i1(i1)], dim=-1
        )
        alpha1 = F.softmax(self.w1(torch.sigmoid(scores)) / self.T, dim=1)

        # The weights sum to 1 per row; the factor 3 makes uniform weights
        # (1/3 each) reproduce the plain sum a1 + s1 + i1.
        x1 = 3 * (alpha1[:, 0:1] * a1 + alpha1[:, 1:2] * s1 + alpha1[:, 2:3] * i1)

        return x1
        
class AGS_GCN(torch.nn.Module):
    """Two stacked ``AGS_layer`` blocks ending in a log-softmax.

    Args:
        num_features: input width of the first layer.
        num_classes: output width of the second layer.
        hidden_channels: width between the two layers.
        dropout: probability used for the dropout between the layers; it is
            also handed to both ``AGS_layer`` constructors.

    ``PredW`` is created but is not used by ``forward``.
    """

    def __init__(self, num_features, num_classes, hidden_channels=16, dropout=0.2):
        super().__init__()
        self.num_classes, self.p = num_classes, dropout

        self.ags_layer1 = AGS_layer(num_features, hidden_channels, dropout)
        self.ags_layer2 = AGS_layer(hidden_channels, num_classes, dropout)

        self.PredW = nn.Linear(2 * hidden_channels, num_classes)

    def forward(self, x0, edge_index, edge_weight=None):
        """Return log-probabilities over the last dimension.

        ``edge_index`` and ``edge_weight`` are forwarded unchanged to both
        layers. Dropout is active only while the module is in training mode.
        """
        hidden = self.ags_layer1(x0, edge_index, edge_weight).relu()
        hidden = F.dropout(hidden, p=self.p, training=self.training)

        logits = self.ags_layer2(hidden, edge_index, edge_weight)
        return F.log_softmax(logits, dim=-1)

In [21]:
if __name__ == '__main__':
    # Smoke test: build a small model, show its structure, and run one
    # forward pass on a 7-node toy graph. Guarded so it does not run when
    # this notebook is imported as a module.
    test = AGS_GCN(2, 2)
    print(test)

    n = 7
    # Features: first three nodes are [1, 0], remaining four are [0, 1];
    # the labels follow the same 3 / 4 split.
    x = torch.Tensor([[1, 0]] * 3 + [[0, 1]] * 4)
    y = torch.LongTensor([0] * 3 + [1] * 4)

    # Node pairs are written 1-based; transpose to shape (2, 20) and shift
    # to 0-based indices.
    edge_index = torch.LongTensor([
        [1, 2], [1, 4], [1, 5],
        [2, 1],
        [3, 6], [3, 7],
        [4, 5], [4, 1], [4, 6], [4, 7],
        [5, 1], [5, 4], [5, 6],
        [6, 3], [6, 4], [6, 5], [6, 7],
        [7, 3], [7, 4], [7, 6],
    ]).T - 1

    print(test(x, edge_index))

AGS_GCN(
  (ags_layer1): AGS_layer(
    (Aconv1): GCNConv(2, 16)
    (Sconv1): SpatialConv(2, 16)
    (I1): Linear(in_features=2, out_features=16, bias=True)
    (layer_norm_a1): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
    (layer_norm_s1): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
    (layer_norm_i1): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
    (alpha_a1): Linear(in_features=16, out_features=1, bias=True)
    (alpha_s1): Linear(in_features=16, out_features=1, bias=True)
    (alpha_i1): Linear(in_features=16, out_features=1, bias=True)
    (w1): Linear(in_features=3, out_features=3, bias=True)
  )
  (ags_layer2): AGS_layer(
    (Aconv1): GCNConv(16, 2)
    (Sconv1): SpatialConv(16, 2)
    (I1): Linear(in_features=16, out_features=2, bias=True)
    (layer_norm_a1): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
    (layer_norm_s1): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
    (layer_norm_i1): LayerNorm((2,), eps=1e-05, elementwise_affi