In [1]:
import argparse
import numpy as np
import networkx as nx
import time
import torch
import torch.nn.functional as F
import dgl
from dgl.data import register_data_args
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset

from gat import GAT
from utils import EarlyStopping


In [2]:
def accuracy(logits, labels):
    _, indices = torch.max(logits, dim=1)
    correct = torch.sum(indices == labels)
    return correct.item() * 1.0 / len(labels)


In [3]:
def evaluate(model, features, labels, mask):
    model.eval()
    with torch.no_grad():
        logits = model(features)
        logits = logits[mask]
        labels = labels[mask]
        return accuracy(logits, labels)


In [4]:
#param:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
epochs = 200
num_heads = 8
num_out_heads = 1
num_layers = 3
num_hidden = [16,8,4]
residual = False
in_drop = .06
attn_drop = .06
lr = 0.005
weight_decay = 5e-4
negative_slope = 0.2
early_stop = False
fastmode = False


data = CoraGraphDataset()


g = data[0]


features = g.ndata['feat']
labels = g.ndata['label']
train_mask = g.ndata['train_mask']
val_mask = g.ndata['val_mask']
test_mask = g.ndata['test_mask']
num_feats = features.shape[1]
n_classes = data.num_labels
n_edges = g.number_of_edges()
print("""----Data statistics------'
  #Edges %d
  #Classes %d
  #Train samples %d
  #Val samples %d
  #Test samples %d""" %
      (n_edges, n_classes,
       train_mask.int().sum().item(),
       val_mask.int().sum().item(),
       test_mask.int().sum().item()))

# add self loop
g = dgl.remove_self_loop(g)
g = dgl.add_self_loop(g)
n_edges = g.number_of_edges()
# create model
heads = ([num_heads] * (num_layers-1)) + [num_out_heads]
# print(heads)
model = GAT(g,
            num_layers,
            num_feats,
            num_hidden,
            n_classes,
            heads,
            F.elu,
            in_drop,
            attn_drop,
            negative_slope,
            residual)
print(model)
if early_stop:
    stopper = EarlyStopping(patience=100)

model.to(device)
loss_fcn = torch.nn.CrossEntropyLoss()

# use optimizer
optimizer = torch.optim.Adam(
    model.parameters(), lr=lr, weight_decay=weight_decay)

# initialize graph
dur = []
for epoch in range(epochs):
    model.train()
    
    # forward
    logits = model(features)
    loss = loss_fcn(logits[train_mask], labels[train_mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


    train_acc = accuracy(logits[train_mask], labels[train_mask])

    if fastmode:
        val_acc = accuracy(logits[val_mask], labels[val_mask])
    else:
        val_acc = evaluate(model, features, labels, val_mask)
        if early_stop:
            if stopper.step(val_acc, model):
                break

    print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |"
          " ValAcc {:.4f} | ETputs(KTEPS) {:.2f}".
          format(epoch, np.mean(dur), loss.item(), train_acc,
                 val_acc, n_edges / np.mean(dur) / 1000))

print()
if early_stop:
    model.load_state_dict(torch.load('es_checkpoint.pt'))
acc = evaluate(model, features, labels, test_mask)
print("Test Accuracy {:.4f}".format(acc))




  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
----Data statistics------'
  #Edges 10556
  #Classes 7
  #Train samples 140
  #Val samples 500
  #Test samples 1000
GAT(
  (gat_layers): ModuleList(
    (0): GATConv(
      (fc): Linear(in_features=1433, out_features=128, bias=False)
      (feat_drop): Dropout(p=0.06, inplace=False)
      (attn_drop): Dropout(p=0.06, inplace=False)
      (leaky_relu): LeakyReLU(negative_slope=0.2)
    )
    (1): GATConv(
      (fc): Linear(in_features=128, out_features=64, bias=False)
      (feat_drop): Dropout(p=0.06, inplace=False)
      (attn_drop): Dropout(p=0.06, inplace=False)
      (leaky_relu): LeakyReLU(negative_slope=0.2)
    )
    (2): GATConv(
      (fc): Linear(in_features=64, out_features=7, bias=False)
      (feat_drop): Dropout(p=0.06, inplace=False)
      (attn_drop): Dropout(p=0.06, inplace=False)
      (l

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 00003 | Time(s) nan | Loss 1.7784 | TrainAcc 0.7857 | ValAcc 0.7360 | ETputs(KTEPS) nan
Epoch 00004 | Time(s) nan | Loss 1.7137 | TrainAcc 0.8571 | ValAcc 0.7500 | ETputs(KTEPS) nan
Epoch 00005 | Time(s) nan | Loss 1.6392 | TrainAcc 0.9286 | ValAcc 0.7240 | ETputs(KTEPS) nan
Epoch 00006 | Time(s) nan | Loss 1.5577 | TrainAcc 0.9143 | ValAcc 0.7160 | ETputs(KTEPS) nan
Epoch 00007 | Time(s) nan | Loss 1.4735 | TrainAcc 0.9000 | ValAcc 0.7720 | ETputs(KTEPS) nan
Epoch 00008 | Time(s) nan | Loss 1.3840 | TrainAcc 0.9071 | ValAcc 0.7840 | ETputs(KTEPS) nan
Epoch 00009 | Time(s) nan | Loss 1.2850 | TrainAcc 0.9214 | ValAcc 0.7940 | ETputs(KTEPS) nan
Epoch 00010 | Time(s) nan | Loss 1.1955 | TrainAcc 0.9429 | ValAcc 0.7960 | ETputs(KTEPS) nan
Epoch 00011 | Time(s) nan | Loss 1.0874 | TrainAcc 0.9357 | ValAcc 0.8100 | ETputs(KTEPS) nan
Epoch 00012 | Time(s) nan | Loss 0.9604 | TrainAcc 0.9500 | ValAcc 0.8060 | ETputs(KTEPS) nan
Epoch 00013 | Time(s) nan | Loss 0.8908 | TrainAcc 0.9571 | 

Epoch 00091 | Time(s) nan | Loss 0.0464 | TrainAcc 1.0000 | ValAcc 0.7780 | ETputs(KTEPS) nan
Epoch 00092 | Time(s) nan | Loss 0.0377 | TrainAcc 1.0000 | ValAcc 0.7840 | ETputs(KTEPS) nan
Epoch 00093 | Time(s) nan | Loss 0.0415 | TrainAcc 1.0000 | ValAcc 0.7940 | ETputs(KTEPS) nan
Epoch 00094 | Time(s) nan | Loss 0.0509 | TrainAcc 1.0000 | ValAcc 0.7940 | ETputs(KTEPS) nan
Epoch 00095 | Time(s) nan | Loss 0.0422 | TrainAcc 1.0000 | ValAcc 0.7940 | ETputs(KTEPS) nan
Epoch 00096 | Time(s) nan | Loss 0.0393 | TrainAcc 1.0000 | ValAcc 0.7900 | ETputs(KTEPS) nan
Epoch 00097 | Time(s) nan | Loss 0.0574 | TrainAcc 0.9929 | ValAcc 0.7860 | ETputs(KTEPS) nan
Epoch 00098 | Time(s) nan | Loss 0.0472 | TrainAcc 1.0000 | ValAcc 0.7780 | ETputs(KTEPS) nan
Epoch 00099 | Time(s) nan | Loss 0.0409 | TrainAcc 1.0000 | ValAcc 0.7760 | ETputs(KTEPS) nan
Epoch 00100 | Time(s) nan | Loss 0.0430 | TrainAcc 1.0000 | ValAcc 0.7880 | ETputs(KTEPS) nan
Epoch 00101 | Time(s) nan | Loss 0.0443 | TrainAcc 1.0000 | 

Epoch 00179 | Time(s) nan | Loss 0.0458 | TrainAcc 0.9929 | ValAcc 0.7780 | ETputs(KTEPS) nan
Epoch 00180 | Time(s) nan | Loss 0.0384 | TrainAcc 1.0000 | ValAcc 0.7760 | ETputs(KTEPS) nan
Epoch 00181 | Time(s) nan | Loss 0.0336 | TrainAcc 1.0000 | ValAcc 0.7800 | ETputs(KTEPS) nan
Epoch 00182 | Time(s) nan | Loss 0.0323 | TrainAcc 1.0000 | ValAcc 0.7800 | ETputs(KTEPS) nan
Epoch 00183 | Time(s) nan | Loss 0.0353 | TrainAcc 1.0000 | ValAcc 0.7900 | ETputs(KTEPS) nan
Epoch 00184 | Time(s) nan | Loss 0.0504 | TrainAcc 0.9929 | ValAcc 0.7840 | ETputs(KTEPS) nan
Epoch 00185 | Time(s) nan | Loss 0.0305 | TrainAcc 1.0000 | ValAcc 0.7840 | ETputs(KTEPS) nan
Epoch 00186 | Time(s) nan | Loss 0.0315 | TrainAcc 1.0000 | ValAcc 0.7880 | ETputs(KTEPS) nan
Epoch 00187 | Time(s) nan | Loss 0.0301 | TrainAcc 1.0000 | ValAcc 0.7920 | ETputs(KTEPS) nan
Epoch 00188 | Time(s) nan | Loss 0.0438 | TrainAcc 0.9929 | ValAcc 0.7840 | ETputs(KTEPS) nan
Epoch 00189 | Time(s) nan | Loss 0.0333 | TrainAcc 1.0000 | 