In [1]:
import numpy as np
import pickle as pkl
import networkx as nx
import scipy.sparse as sp
import torch
from scipy.sparse import csgraph
import sys
import time
import argparse
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from utils import *
from graphConvolution import *
from torch import nn

### Load data

In [2]:
adj, features, labels, idx_train, idx_val, idx_test = load_data(dataset='cora')


[STEP 1]: Upload cora dataset.
| # of nodes : 2708
| # of edges : 5278.0
| # of features : 1433
| # of clases   : 7
| # of train set : 140
| # of val set   : 500
| # of test set  : 1000


In [3]:
from torch.autograd import Variable
features, adj, labels = Variable(features), Variable(adj), Variable(labels)
# torch.cuda.manual_seed(72)
features = features.cuda()
adj = adj.cuda()
labels = labels.cuda()
idx_train = idx_train.cuda()
idx_val = idx_val.cuda()
idx_test = idx_test.cuda()

### FM-6layers

In [4]:
class GraphConvolutionFM(Module):
    def __init__(self, in_features, out_features, embedding, bias=True):
        super(GraphConvolutionFM, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.embedding = embedding
        self.weight = Parameter(torch.FloatTensor(in_features, out_features),requires_grad=True)
        self.V = Parameter(torch.randn(out_features, in_features, embedding),requires_grad=True)
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        self.V.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj,nhid1,nhid2,nhid3,nhid4):
        out_lin = torch.mm(input, self.weight) + self.bias
        # all
        out_1 = torch.matmul(input,self.V).pow(2).sum(2, keepdim=True).view(self.out_features,input.shape[0]).t()
        out_2 = torch.matmul(input.pow(2), self.V.pow(2)).sum(2, keepdim=True).view(self.out_features,input.shape[0]).t() 
        out_inter1 = 0.5*(out_1 - out_2)
        # x1-xnhid1
        out_3 = torch.matmul(input[:,:nhid1],self.V[:,:nhid1,:]).pow(2).sum(2, keepdim=True).view(self.out_features,input.shape[0]).t()
        out_4 = torch.matmul(input[:,:nhid1].pow(2), self.V[:,:nhid1,:].pow(2)).sum(2, keepdim=True).view(self.out_features,input.shape[0]).t() 
        out_inter2 = 0.5*(out_3 - out_4)
        # xnhid1-xnhid2
        out_5 = torch.matmul(input[:,nhid1:nhid2],self.V[:,nhid1:nhid2,:]).pow(2).sum(2, keepdim=True).view(self.out_features,input.shape[0]).t()
        out_6 = torch.matmul(input[:,nhid1:nhid2].pow(2), self.V[:,nhid1:nhid2,:].pow(2)).sum(2, keepdim=True).view(self.out_features,input.shape[0]).t() 
        out_inter3 = 0.5*(out_5 - out_6)
        # xnhid2-xnhid3
        out_7 = torch.matmul(input[:,nhid2:nhid3],self.V[:,nhid2:nhid3,:]).pow(2).sum(2, keepdim=True).view(self.out_features,input.shape[0]).t()
        out_8 = torch.matmul(input[:,nhid2:nhid3].pow(2), self.V[:,nhid2:nhid3,:].pow(2)).sum(2, keepdim=True).view(self.out_features,input.shape[0]).t() 
        out_inter4 = 0.5*(out_7 - out_8)
        # xnhid3-xnhid4
        out_9 = torch.matmul(input[:,nhid3:nhid4],self.V[:,nhid3:nhid4,:]).pow(2).sum(2, keepdim=True).view(self.out_features,input.shape[0]).t()
        out_10 = torch.matmul(input[:,nhid3:nhid4].pow(2), self.V[:,nhid3:nhid4,:].pow(2)).sum(2, keepdim=True).view(self.out_features,input.shape[0]).t() 
        out_inter5 = 0.5*(out_9- out_10)
        
        # xnhid4-xnhid5
        out_11 = torch.matmul(input[:,nhid4:],self.V[:,nhid4:,:]).pow(2).sum(2, keepdim=True).view(self.out_features,input.shape[0]).t()
        out_12 = torch.matmul(input[:,nhid4:].pow(2), self.V[:,nhid4:,:].pow(2)).sum(2, keepdim=True).view(self.out_features,input.shape[0]).t() 
        out_inter6 = 0.5*(out_11- out_12)
        
        out_inter = out_inter1 - out_inter2 - out_inter3- out_inter4- out_inter5-out_inter6
        
        output = out_inter + out_lin
        
        output = torch.spmm(adj, output) 
        return output
    

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'

In [5]:
class GCN(nn.Module):
    def __init__(self, nfeat, nhid1,nhid2, nhid3,nhid4,nhid5,nclass, dropout):
        super(GCN, self).__init__()
        self.nhid1 = nhid1
        self.nhid2 = nhid2
        self.nhid3 = nhid3
        self.nhid4 = nhid4
        self.nhid5 = nhid5
        
        self.gc_1 = GraphConvolution(nfeat, nhid1,bias=True)
        self.gc_2 = GraphConvolution(nfeat, nhid2,bias=True)
        self.gc_3 = GraphConvolution(nfeat, nhid3,bias=True)
        self.gc_4 = GraphConvolution(nfeat, nhid4,bias=True)
        
        self.gc1 = GraphConvolution(nfeat, nhid1,bias=True)
        self.gc1_2 = GraphConvolution(nhid1, nhid2,bias=True)
        self.gc1_3 = GraphConvolution(nhid1, nhid3,bias=True)
        self.gc1_4 = GraphConvolution(nhid1, nhid4,bias=True)
        
        
        self.gc2 = GraphConvolution(nhid1, nhid2,bias=True)
        self.gc2_3 = GraphConvolution(nhid2, nhid3,bias=True)
        self.gc2_4 = GraphConvolution(nhid2, nhid4,bias=True)

        
        self.gc3 = GraphConvolution(nhid2, nhid3,bias=True)
        self.gc3_4 = GraphConvolution(nhid3, nhid4,bias=True)
        
        self.gc4 = GraphConvolution(nhid3, nhid4,bias=True)
        
        self.gc5 = GraphConvolution(nhid4, nhid5,bias=True)
        
        self.gcFM6 = GraphConvolutionFM(nhid5+nhid4+nhid3+nhid2+nhid1, nclass,embedding=5,bias=True)
        
        self.dropout = dropout

    def forward(self, x, adj):
        x_d = F.dropout(x, self.dropout, training=self.training)
        x_df1 = F.dropout(F.relu(self.gc_1(x_d,adj)), self.dropout, training=self.training)
        x_df2 = F.dropout(F.relu(self.gc_2(x_d,adj)), self.dropout, training=self.training)
        x_df3 = F.dropout(F.relu(self.gc_3(x_d,adj)), self.dropout, training=self.training)
        x_df4 = F.dropout(F.relu(self.gc_4(x_d,adj)), self.dropout, training=self.training)

        x1_d = F.dropout(F.relu(self.gc1(x_d, adj)), training=self.training)
        x1_df2 = F.dropout(F.relu(self.gc1_2(x1_d,adj)), self.dropout, training=self.training)
        x1_df3 = F.dropout(F.relu(self.gc1_3(x1_d,adj)), self.dropout, training=self.training)
        x1_df4 = F.dropout(F.relu(self.gc1_4(x1_d,adj)), self.dropout, training=self.training)
        
        combined1 =  torch.cat([x_df1,x1_d],dim=1).view(x.shape[0],2,self.nhid1)
        combined1 = torch.max(combined1,dim=1).values
        combined1 = F.dropout(combined1, self.dropout, training=self.training)
        
        x2_d = F.dropout(F.relu(self.gc2(combined1, adj)), self.dropout, training=self.training)
        x2_df3 = F.dropout(F.relu(self.gc2_3(x2_d,adj)), self.dropout, training=self.training)
        x2_df4 = F.dropout(F.relu(self.gc2_4(x2_d,adj)), self.dropout, training=self.training)

        combined2 =  torch.cat([x_df2,x1_df2,x2_d],dim=1).view(x.shape[0],3,self.nhid2)
        combined2 = torch.max(combined2,dim=1).values
        combined2 = F.dropout(combined2, self.dropout, training=self.training)

        x3_d = F.dropout(F.relu(self.gc3(combined2, adj)), self.dropout, training=self.training)
        x3_df4 = F.dropout(F.relu(self.gc3_4(x3_d,adj)), self.dropout, training=self.training)
        
        combined3 =  torch.cat([x_df3,x1_df3,x2_df3,x3_d],dim=1).view(x.shape[0],4,self.nhid3)
        combined3 = torch.max(combined3,dim=1).values
        combined3 = F.dropout(combined3, self.dropout, training=self.training)
        
        x4_d = F.dropout(F.relu(self.gc4(combined3, adj)), self.dropout, training=self.training)
        
        combined4 =  torch.cat([x_df4,x1_df4,x2_df4,x3_df4,x4_d],dim=1).view(x.shape[0],5,self.nhid4)
        combined4 = torch.max(combined4,dim=1).values
        combined4 = F.dropout(combined4, self.dropout, training=self.training)
        
        
        x5_d = F.dropout(F.relu(self.gc5(combined4, adj)), self.dropout, training=self.training)
        
        combined5 = torch.cat([x1_d, x2_d, x3_d, x4_d, x5_d], dim=1)
        combined5 = F.dropout(combined5, self.dropout, training=self.training)
        
        x6 = self.gcFM6(combined5,adj,self.nhid1,self.nhid2+self.nhid1,self.nhid3+self.nhid2+self.nhid1,self.nhid4+self.nhid3+self.nhid2+self.nhid1) 
        return F.log_softmax(x6, dim=1)

In [6]:
def train(epoch, model,record):
    t = time.time()
    model.train()
    optimizer.zero_grad()
    output = model(features, adj)
    loss_train = F.cross_entropy(output[idx_train], labels[idx_train]) 
    acc_train = accuracy(output[idx_train], labels[idx_train])
    loss_train.backward()
    optimizer.step()
    model.eval()
    output = model(features, adj)

    loss_val = F.cross_entropy(output[idx_val], labels[idx_val])
    acc_val = accuracy(output[idx_val], labels[idx_val])
    
    loss_test = F.cross_entropy(output[idx_test], labels[idx_test])
    acc_test = accuracy(output[idx_test], labels[idx_test])
    print('Epoch: {:04d}'.format(epoch+1),
          'loss_train: {:.4f}'.format(loss_train.item()),
          'acc_train: {:.4f}'.format(acc_train.item()),
          'acc_val: {:.4f}'.format(acc_val.item()),
          'acc_test: {:.4f}'.format(acc_test.item()),
          'time: {:.4f}s'.format(time.time() - t))
    record[acc_val.item()] = acc_test.item()

In [7]:
model = GCN(nfeat=features.shape[1],
                nhid1=32,
                nhid2=32,
                nhid3=32,
                nhid4=32,
                nhid5=32,
                nclass=labels.max().item() + 1,
                dropout=0.8)
model.cuda()
optimizer = optim.Adam(model.parameters(),
                       lr=0.02, weight_decay=5e-4)
t_total = time.time()
record = {}
for epoch in range(400):  
    train(epoch,model,record)
print("Optimization Finished!")
print("Total time elapsed: {:.4f}s".format(time.time() - t_total))
bit_list = sorted(record.keys())
bit_list.reverse()
for key in bit_list[:10]:
    value = record[key]
    print(key,value)

Epoch: 0001 loss_train: 10.1382 acc_train: 0.1286 acc_val: 0.1320 acc_test: 0.1190 time: 0.4096s
Epoch: 0002 loss_train: 3.3943 acc_train: 0.1571 acc_val: 0.1620 acc_test: 0.1490 time: 0.0247s
Epoch: 0003 loss_train: 2.7748 acc_train: 0.1143 acc_val: 0.1540 acc_test: 0.1470 time: 0.0302s
Epoch: 0004 loss_train: 2.2759 acc_train: 0.1357 acc_val: 0.1260 acc_test: 0.1310 time: 0.0236s
Epoch: 0005 loss_train: 2.2139 acc_train: 0.1071 acc_val: 0.1420 acc_test: 0.1400 time: 0.0257s
Epoch: 0006 loss_train: 2.0917 acc_train: 0.1571 acc_val: 0.1580 acc_test: 0.1500 time: 0.0303s
Epoch: 0007 loss_train: 2.0865 acc_train: 0.1429 acc_val: 0.1740 acc_test: 0.1500 time: 0.0237s
Epoch: 0008 loss_train: 2.0954 acc_train: 0.1071 acc_val: 0.3160 acc_test: 0.3330 time: 0.0259s
Epoch: 0009 loss_train: 1.9814 acc_train: 0.1571 acc_val: 0.3100 acc_test: 0.3240 time: 0.0326s
Epoch: 0010 loss_train: 1.9121 acc_train: 0.1857 acc_val: 0.3080 acc_test: 0.3200 time: 0.0305s
Epoch: 0011 loss_train: 1.9890 acc_trai

Epoch: 0089 loss_train: 1.2710 acc_train: 0.5571 acc_val: 0.7740 acc_test: 0.7940 time: 0.0307s
Epoch: 0090 loss_train: 1.2557 acc_train: 0.6143 acc_val: 0.7760 acc_test: 0.8000 time: 0.0309s
Epoch: 0091 loss_train: 1.2630 acc_train: 0.5500 acc_val: 0.7740 acc_test: 0.7970 time: 0.0267s
Epoch: 0092 loss_train: 1.2676 acc_train: 0.5429 acc_val: 0.7640 acc_test: 0.7910 time: 0.0318s
Epoch: 0093 loss_train: 1.2937 acc_train: 0.5786 acc_val: 0.7540 acc_test: 0.7840 time: 0.0257s
Epoch: 0094 loss_train: 1.2805 acc_train: 0.5857 acc_val: 0.7420 acc_test: 0.7720 time: 0.0329s
Epoch: 0095 loss_train: 1.2516 acc_train: 0.5929 acc_val: 0.7180 acc_test: 0.7520 time: 0.0327s
Epoch: 0096 loss_train: 1.2579 acc_train: 0.6429 acc_val: 0.7080 acc_test: 0.7420 time: 0.0258s
Epoch: 0097 loss_train: 1.2329 acc_train: 0.5857 acc_val: 0.7000 acc_test: 0.7400 time: 0.0317s
Epoch: 0098 loss_train: 1.1861 acc_train: 0.6143 acc_val: 0.7020 acc_test: 0.7400 time: 0.0308s
Epoch: 0099 loss_train: 1.2017 acc_train

Epoch: 0180 loss_train: 1.0066 acc_train: 0.6643 acc_val: 0.7700 acc_test: 0.8020 time: 0.0304s
Epoch: 0181 loss_train: 0.9928 acc_train: 0.6786 acc_val: 0.7660 acc_test: 0.8020 time: 0.0295s
Epoch: 0182 loss_train: 0.9782 acc_train: 0.6786 acc_val: 0.7720 acc_test: 0.8060 time: 0.0303s
Epoch: 0183 loss_train: 0.9688 acc_train: 0.6929 acc_val: 0.7780 acc_test: 0.8030 time: 0.0303s
Epoch: 0184 loss_train: 0.9624 acc_train: 0.7071 acc_val: 0.7840 acc_test: 0.8040 time: 0.0288s
Epoch: 0185 loss_train: 0.9866 acc_train: 0.6857 acc_val: 0.7900 acc_test: 0.8120 time: 0.0323s
Epoch: 0186 loss_train: 0.9964 acc_train: 0.6429 acc_val: 0.7920 acc_test: 0.8160 time: 0.0306s
Epoch: 0187 loss_train: 1.0849 acc_train: 0.6214 acc_val: 0.7980 acc_test: 0.8240 time: 0.0298s
Epoch: 0188 loss_train: 1.0860 acc_train: 0.6429 acc_val: 0.7960 acc_test: 0.8230 time: 0.0300s
Epoch: 0189 loss_train: 1.0138 acc_train: 0.7000 acc_val: 0.7940 acc_test: 0.8250 time: 0.0306s
Epoch: 0190 loss_train: 0.9713 acc_train

Epoch: 0271 loss_train: 0.9499 acc_train: 0.6857 acc_val: 0.7800 acc_test: 0.7740 time: 0.0305s
Epoch: 0272 loss_train: 0.9585 acc_train: 0.6143 acc_val: 0.7760 acc_test: 0.7800 time: 0.0299s
Epoch: 0273 loss_train: 0.8502 acc_train: 0.7429 acc_val: 0.7740 acc_test: 0.7830 time: 0.0300s
Epoch: 0274 loss_train: 0.9972 acc_train: 0.6786 acc_val: 0.7760 acc_test: 0.7910 time: 0.0303s
Epoch: 0275 loss_train: 0.8353 acc_train: 0.7357 acc_val: 0.7880 acc_test: 0.8060 time: 0.0290s
Epoch: 0276 loss_train: 0.7987 acc_train: 0.7571 acc_val: 0.7940 acc_test: 0.8090 time: 0.0325s
Epoch: 0277 loss_train: 0.8196 acc_train: 0.7714 acc_val: 0.7940 acc_test: 0.8110 time: 0.0306s
Epoch: 0278 loss_train: 0.8312 acc_train: 0.7643 acc_val: 0.7900 acc_test: 0.8210 time: 0.0294s
Epoch: 0279 loss_train: 0.8826 acc_train: 0.7286 acc_val: 0.7960 acc_test: 0.8270 time: 0.0305s
Epoch: 0280 loss_train: 0.8351 acc_train: 0.7214 acc_val: 0.7940 acc_test: 0.8250 time: 0.0258s
Epoch: 0281 loss_train: 0.8190 acc_train

Epoch: 0362 loss_train: 0.8813 acc_train: 0.7000 acc_val: 0.8060 acc_test: 0.8350 time: 0.0268s
Epoch: 0363 loss_train: 0.8079 acc_train: 0.7357 acc_val: 0.8040 acc_test: 0.8330 time: 0.0318s
Epoch: 0364 loss_train: 0.8521 acc_train: 0.7357 acc_val: 0.8080 acc_test: 0.8280 time: 0.0286s
Epoch: 0365 loss_train: 0.7034 acc_train: 0.7429 acc_val: 0.8020 acc_test: 0.8240 time: 0.0298s
Epoch: 0366 loss_train: 0.8372 acc_train: 0.7286 acc_val: 0.8000 acc_test: 0.8160 time: 0.0258s
Epoch: 0367 loss_train: 0.7171 acc_train: 0.7857 acc_val: 0.8020 acc_test: 0.8130 time: 0.0317s
Epoch: 0368 loss_train: 0.8240 acc_train: 0.7143 acc_val: 0.8040 acc_test: 0.8150 time: 0.0267s
Epoch: 0369 loss_train: 0.7457 acc_train: 0.7571 acc_val: 0.8060 acc_test: 0.8240 time: 0.0289s
Epoch: 0370 loss_train: 0.9346 acc_train: 0.7500 acc_val: 0.8020 acc_test: 0.8210 time: 0.0257s
Epoch: 0371 loss_train: 0.8690 acc_train: 0.7143 acc_val: 0.8000 acc_test: 0.8200 time: 0.0303s
Epoch: 0372 loss_train: 0.8416 acc_train