In [1]:
# import dependencies 
import math
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter

In [2]:
# Graph Convolutional Network 
class GCN(torch.nn.Module):
    """Stack of graph-convolution layers applied to the latest graph snapshot.

    Each layer computes ``activation(Ahat @ (H @ W_i))``. ``H`` starts as the
    last node-feature matrix in ``Nodes_list``, and ``Ahat`` is the last
    adjacency matrix in ``A_list``.

    Args:
        num_layers: Number of graph-convolution layers; must be >= 1.
        feats_per_node: Width of the input node features.
        layer_1_feats: Output width of the first layer.
        layer_2_feats: Output width of the second and every later layer.
        activation: Callable applied after each layer (e.g. ``torch.nn.RReLU()``).

    Raises:
        ValueError: If ``num_layers < 1``.
    """

    def __init__(self, num_layers, feats_per_node, layer_1_feats, layer_2_feats, activation):
        super().__init__()
        if num_layers < 1:
            # forward() indexes w_list[0]; an empty stack can never run.
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.activation = activation
        self.num_layers = num_layers

        self.w_list = nn.ParameterList()
        for i in range(self.num_layers):
            if i == 0:
                in_feats, out_feats = feats_per_node, layer_1_feats
            elif i == 1:
                in_feats, out_feats = layer_1_feats, layer_2_feats
            else:
                # Layers past the second consume the previous layer's
                # layer_2_feats-wide output, so their weights must be square.
                in_feats, out_feats = layer_2_feats, layer_2_feats
            w_i = Parameter(torch.empty(in_feats, out_feats))
            self.reset_param(w_i)
            self.w_list.append(w_i)

    def forward(self, A_list, Nodes_list, nodes_mask_list):
        """Embed nodes using only the most recent adjacency and feature matrices.

        Args:
            A_list: Sequence of adjacency matrices over time (dense or sparse
                tensors); only ``A_list[-1]`` is used.
            Nodes_list: Sequence of node-feature matrices over time; only
                ``Nodes_list[-1]`` is used.
            nodes_mask_list: Unused; accepted for interface compatibility.

        Returns:
            Node embeddings whose width is ``layer_1_feats`` when
            ``num_layers == 1``, otherwise ``layer_2_feats``.
        """
        # NOTE(review): the original notes Ahat as NxN (~30k nodes) and sparse — confirm per dataset.
        Ahat = A_list[-1]
        last_l = Nodes_list[-1]
        for w in self.w_list:
            # Multiply features by W first so that the (possibly sparse) Ahat
            # product runs on the out_feats-wide result. This is cheaper when
            # out_feats < in_feats (order change noted by bwheatman, tfk).
            last_l = self.activation(Ahat.matmul(last_l.matmul(w)))
        return last_l

    def reset_param(self, t):
        """Fill ``t`` in place from U(-stdv, stdv), where stdv = 2 / sqrt(t.size(0))."""
        stdv = 2. / math.sqrt(t.size(0))
        with torch.no_grad():
            t.uniform_(-stdv, stdv)

In [3]:
# Cross Entropy Class
class Cross_Entropy(torch.nn.Module):
    """Cross-entropy loss over class logits with optional per-class weights.

    Args:
        weights: Optional sequence or tensor with one weight per class. It is
            stored as a buffer so ``.to(device)`` moves it along with the
            module. ``None`` (the default) weights all classes equally.
    """

    def __init__(self, weights=None):
        super().__init__()
        if weights is not None:
            weights = torch.as_tensor(weights, dtype=torch.float)
        self.register_buffer("weights", weights)

    def forward(self, logits, labels):
        """Return the (optionally class-weighted) mean cross-entropy.

        Args:
            logits: Float tensor of shape (M, C) holding unnormalized class scores.
            labels: Integer tensor with M class indices; any shape with M
                elements (e.g. (M,) or (M, 1)) is accepted.

        Returns:
            Scalar loss tensor. With class weights, the mean is normalized by
            the summed weights of the target classes (PyTorch ``weight`` semantics).
        """
        weight = None if self.weights is None else self.weights.to(logits.dtype)
        return torch.nn.functional.cross_entropy(
            logits, labels.reshape(-1).long(), weight=weight
        )

In [4]:
# Trainer Class 
class Trainer():
    """Holds the model that is being trained.

    Args:
        gcn: Model instance to train (e.g. a ``GCN``). It may also be attached
            later through :meth:`init`.
    """

    def __init__(self, gcn=None):
        # The original misspelled the constructor as `init`, so `Trainer(gcn)`
        # raised TypeError. The default keeps `Trainer()` working.
        self.gcn = gcn

    def init(self, gcn):
        """Attach ``gcn`` after construction (kept for backward compatibility)."""
        self.gcn = gcn

In [5]:
# start training to extract node embeddings 
# Instantiate the model and the loss; no training happens in this cell.
# Seed first: the uniform weight init in GCN.reset_param and RReLU's random
# negative slopes (in training mode) both draw from torch's RNG.
SEED = 0
torch.manual_seed(SEED)

device = "cpu"

# new instance of graph convolutional neural network
# NOTE(review): feats_per_node presumably must match the dataset's node-feature width — confirm.
gcn = GCN(num_layers=2,
          feats_per_node=164,
          layer_1_feats=20,
          layer_2_feats=10,
          activation=torch.nn.RReLU()).to(device)

# new instance of cross entropy
cross_entropy = Cross_Entropy().to(device)