In [1]:
import os
import time
import math
import copy
import random
import numpy as np
import scipy
from scipy.special import comb

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics

from dglnode import DglNodePropPredDataset
from model_arxiv_node_classification import Specformer

In [2]:
def count_parameters(model):
    """Return how many scalar values the trainable parameters of ``model`` hold.

    Every tensor yielded by ``model.parameters()`` whose ``requires_grad`` flag
    is set contributes its element count; frozen tensors are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def init_params(module):
    """Re-initialise an ``nn.Linear`` layer in place; other modules are left as they are.

    The weight matrix is refilled from a normal distribution (mean 0.0,
    std 0.01) and the bias, when the layer has one, is set to zero.
    In this notebook it is invoked through ``net.apply(init_params)``.
    """
    if not isinstance(module, nn.Linear):
        return
    module.weight.data.normal_(mean=0.0, std=0.01)
    bias = module.bias
    if bias is not None:
        bias.data.zero_()
            
def seed_everything(seed):
    """Seed every random number generator this notebook relies on with ``seed``.

    Covers Python's ``random`` module, NumPy's global generator and PyTorch
    (the default generator plus the current and all CUDA devices). The value
    is also written, as a string, to the ``PYTHONHASHSEED`` environment variable.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    
def mean_confidence_interval(data, confidence=0.95):
    """Return ``(mean, half_width)`` of a two-sided Student-t confidence interval.

    Args:
        data: sequence of numeric samples.
        confidence: coverage level of the interval (0.95 by default).

    Returns:
        A tuple ``(m, h)`` where ``m`` is the sample mean and ``h`` is the
        standard error of the mean multiplied by the t quantile at
        ``(1 + confidence) / 2`` with ``len(data) - 1`` degrees of freedom.
    """
    from scipy import stats

    samples = np.array(data) * 1.0
    dof = len(samples) - 1
    centre = np.mean(samples)
    std_err = stats.sem(samples)
    t_quantile = stats.t.ppf((1 + confidence) / 2., dof)
    return centre, std_err * t_quantile

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a float32 torch sparse COO tensor.

    Args:
        sparse_mx: a scipy sparse matrix in any format; it is converted to
            COO internally, so CSR/CSC inputs are accepted as well.

    Returns:
        A ``torch.float32`` sparse COO tensor with the same shape and
        non-zero entries as ``sparse_mx``.
    """
    # The row/col attributes exist only on the COO format, and the values must
    # be float32 to yield a float tensor regardless of the input dtype.
    coo = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((coo.row, coo.col)).astype(np.int64))
    values = torch.from_numpy(coo.data)
    shape = torch.Size(coo.shape)
    # torch.sparse_coo_tensor replaces the legacy torch.sparse.FloatTensor constructor.
    return torch.sparse_coo_tensor(indices, values, shape)

In [3]:
# Make CUDA device index 1 the current device for this process.
# NOTE(review): the index is hard-coded, so this presumably fails on hosts with
# fewer than two visible GPUs — confirm and adjust before running elsewhere.
torch.cuda.set_device(1)

In [4]:
# Load the four precomputed tensors stored for the selected dataset.
# NOTE(review): torch.load unpickles the file, so only point it at trusted data.
# NOTE(review): e_tensor / u_tensor are presumably the eigenvalues / eigenvectors
# consumed by Specformer — confirm against the preprocessing script.
data_name = 'arxiv'
e_tensor, u_tensor, x_tensor, y_tensor = torch.load('data/{}.pt'.format(data_name))

# Features and labels go to the current CUDA device; singleton label axes are dropped.
x, y = x_tensor.cuda(), y_tensor.cuda().squeeze()

# Labels that are still multi-dimensional after squeezing are reduced to the
# index of their largest entry along dim 1.
if y.dim() > 1:
    y = y.argmax(dim=1)

In [5]:
# Build the 'ogbn-arxiv' dataset object via DglNodePropPredDataset (imported from the
# local `dglnode` module); it is only used below to obtain the train/valid/test split.
dataset = DglNodePropPredDataset('ogbn-arxiv')

In [6]:
# Fetch the index split; the following cell reads its 'train', 'valid' and 'test' entries.
split = dataset.get_idx_split()

In [7]:
# Pull the three index sets out of the split mapping, in train / valid / test order.
train, valid, test = (split[part] for part in ('train', 'valid', 'test'))

In [8]:
# Convert each index set to an int64 tensor on the current CUDA device.
# torch.as_tensor accepts NumPy arrays as well as tensors of any dtype/device,
# whereas the legacy torch.LongTensor(...) constructor rejects tensors that are
# not already CPU int64 — which made re-running this cell fail once the names
# had been rebound to CUDA tensors.
train, valid, test = (
    torch.as_tensor(part, dtype=torch.long).cuda() for part in (train, valid, test)
)

In [9]:
# ---- training budget ----
epoch = 2000  # upper bound on the number of training iterations

# ---- problem size ----
nfeat = x.size(1)  # input feature width, read from the loaded feature matrix
nclass = 40        # class count handed to the Specformer constructor

# ---- Specformer architecture ----
nlayer, num_heads, hidden_dim = 1, 1, 512
tran_dropout = prop_dropout = feat_dropout = 0.1

# ---- optimiser ----
lr, weight_decay = 1e-3, 0

# ---- model selection and evaluation ----
# 'loss' makes the training cell track validation loss for early stopping;
# any other value makes it track validation accuracy instead.
metric = 'loss'
loss = nn.CrossEntropyLoss()
evaluation = torchmetrics.Accuracy()

In [10]:
# Full-batch training of Specformer with early stopping on the validation split.
# Each epoch performs one optimisation step on the training nodes, then scores the
# validation and test nodes; training stops once the tracked validation metric
# (selected by `metric`) has not improved for `patience` consecutive epochs.
seed_everything(0)

net = Specformer(nclass, nfeat, nlayer, hidden_dim, num_heads, tran_dropout, feat_dropout, prop_dropout).cuda()
net.apply(init_params)
optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)

print(count_parameters(net))

res = []      # one [train_loss, val_loss, val_acc, test_acc] row per epoch
times = []    # wall-clock seconds spent in each training step
# Start from +inf so the first epoch always counts as an improvement, however
# large the initial validation loss is.
min_loss = float('inf')
max_acc = 0
counter = 0   # epochs since the tracked validation metric last improved
patience = 200

# Keep only the first `topk` entries of e_tensor and the matching columns of u_tensor.
# torch.arange yields the same int64 indices 0..min(topk, N)-1 without first
# materialising a Python list of every node id.
topk = 5000
rand_idx = torch.arange(min(topk, x.size(0)))
e = e_tensor[rand_idx].cuda()
u = u_tensor[:, rand_idx].cuda()

for idx in range(epoch):

    # ---- one optimisation step on the training nodes ----
    net.train()
    optimizer.zero_grad()
    t1 = time.time()
    logits = net(e, u, x)
    train_loss = loss(logits[train], y[train])
    train_loss.backward()
    optimizer.step()
    # CUDA kernels run asynchronously; wait for them so the timer covers the
    # work that was actually executed, not just the kernel launches.
    torch.cuda.synchronize()
    t2 = time.time()

    # ---- evaluation: no autograd graph is needed, so do not build one ----
    net.eval()
    with torch.no_grad():
        logits = net(e, u, x)
    val_pred = logits[valid].cpu()
    val_label = y[valid].cpu()
    val_loss = loss(val_pred, val_label).item()
    val_acc = evaluation(val_pred, val_label).item()

    test_pred = logits[test].cpu()
    test_label = y[test].cpu()
    test_acc = evaluation(test_pred, test_label).item()

    res.append([train_loss.item(), val_loss, val_acc, test_acc])

    # ---- early-stopping bookkeeping; snapshot the best model for either metric ----
    if metric == 'loss':
        if val_loss < min_loss:
            min_loss = val_loss
            best_net = copy.deepcopy(net)
            counter = 0
        else:
            counter += 1
    else:
        if val_acc > max_acc:
            max_acc = val_acc
            best_net = copy.deepcopy(net)
            counter = 0
        else:
            counter += 1

    if counter >= patience:
        break

    times.append(t2-t1)

    print(idx, train_loss.item(), val_loss, val_acc, test_acc, np.mean(times))

1931305
0 3.687615394592285 3.6328325271606445 0.07661330699920654 0.058802954852581024 0.8284280300140381
1 3.6180951595306396 3.4604480266571045 0.0762777253985405 0.05861778184771538 0.4202967882156372
2 3.426513910293579 3.2544105052948 0.0762777253985405 0.05861778184771538 0.2829123338063558
3 3.224663496017456 3.117955446243286 0.0762777253985405 0.05861778184771538 0.21422284841537476
4 3.1090281009674072 3.034008502960205 0.0762777253985405 0.05861778184771538 0.17301998138427735
5 3.0556752681732178 2.965461015701294 0.0762777253985405 0.05861778184771538 0.14550721645355225
6 3.0298068523406982 2.905622720718384 0.0762777253985405 0.05861778184771538 0.12586147444588797
7 3.0034260749816895 2.862138032913208 0.24393436312675476 0.2239779382944107 0.11103111505508423
8 2.9777729511260986 2.8352723121643066 0.2966206967830658 0.267411470413208 0.09949972894456652
9 2.9471261501312256 2.818138599395752 0.2981308102607727 0.26833733916282654 0.09026808738708496
10 2.913299798965

84 1.0431909561157227 0.9926493763923645 0.6907614469528198 0.6771392822265625 0.017496883167940028
85 1.04083251953125 0.9914116859436035 0.6913990378379822 0.6790321469306946 0.017388280047926793
86 1.0362989902496338 0.9879888892173767 0.6929762959480286 0.6842581629753113 0.0172822502837784
87 1.0348559617996216 0.9847339391708374 0.6915668249130249 0.6804106831550598 0.017177318984811955
88 1.031522274017334 0.9820166230201721 0.692003071308136 0.6794230937957764 0.017076610179429644
89 1.029572606086731 0.9773633480072021 0.6935803294181824 0.68349689245224 0.016979018847147625
90 1.0261071920394897 0.9767435789108276 0.6952582597732544 0.6877353191375732 0.016882733984307927
91 1.0234053134918213 0.9752658009529114 0.6947548389434814 0.6840935945510864 0.016788969869199005
92 1.0195260047912598 0.9758597016334534 0.691264808177948 0.6779622435569763 0.016712873212752805
93 1.0204353332519531 0.9717022776603699 0.6942179203033447 0.6823858618736267 0.01661776735427532
94 1.016671

166 0.934105634689331 0.9094078540802002 0.7126413583755493 0.695471465587616 0.0128683829735853
167 0.9320888519287109 0.9024038314819336 0.71354740858078 0.7020760178565979 0.0128347376982371
168 0.9291988611221313 0.9060697555541992 0.7121044397354126 0.6980227828025818 0.012805536653868545
169 0.9292860627174377 0.9074143171310425 0.7113661766052246 0.6959446668624878 0.012772306273965275
170 0.9276077747344971 0.8999431729316711 0.7157622575759888 0.70396888256073 0.012739410177308914
171 0.9265381097793579 0.9060523509979248 0.7129433751106262 0.6963561773300171 0.012708302154097446
172 0.9262921214103699 0.9067396521568298 0.7132118344306946 0.696623682975769 0.012676132896732044
173 0.9272480607032776 0.8999077081680298 0.7149904370307922 0.7060263752937317 0.01266752166309576
174 0.9270595908164978 0.9106963276863098 0.7106949687004089 0.6902866363525391 0.012655482973371233
175 0.9265367388725281 0.900688111782074 0.7151246666908264 0.7052651047706604 0.012641446156935259
176

248 0.8869686126708984 0.8873518109321594 0.7167019248008728 0.7006563544273376 0.011358211318173083
249 0.8903257250785828 0.8853102326393127 0.7173730731010437 0.7048124670982361 0.011342562675476075
250 0.886555552482605 0.8802759647369385 0.7197221517562866 0.7053062319755554 0.011325846630263613
251 0.8830379843711853 0.8800142407417297 0.7199234962463379 0.7053474187850952 0.011314850950997973
252 0.8861725926399231 0.8883258700370789 0.7171717286109924 0.7012529969215393 0.011305109785479519
253 0.8842606544494629 0.8817322254180908 0.7190510034561157 0.7070757150650024 0.011289311206246924
254 0.8831192851066589 0.8796254992485046 0.7194536924362183 0.7038865685462952 0.01127614133498248
255 0.8832964301109314 0.8767882585525513 0.7207959890365601 0.7109437584877014 0.011262748390436172
256 0.8809265494346619 0.8875479102134705 0.7178093194961548 0.6980639100074768 0.01124740945689873
257 0.8823390007019043 0.876451849937439 0.7208631038665771 0.7093183398246765 0.0112322920052

330 0.8521960377693176 0.8658970594406128 0.7231115102767944 0.7113347053527832 0.010406608667978707
331 0.8487510681152344 0.8642048239707947 0.7244203090667725 0.7142974734306335 0.010398461876145328
332 0.8498420119285583 0.8745453953742981 0.7214000225067139 0.7045655846595764 0.010390326783463761
333 0.8496993184089661 0.8682371973991394 0.7229437232017517 0.7100179195404053 0.01038188777283994
334 0.8487161993980408 0.8672589063644409 0.7240175604820251 0.7118902206420898 0.010373775282902504
335 0.8504046201705933 0.8780972957611084 0.7173730731010437 0.6992161273956299 0.010365657153583709
336 0.8481246829032898 0.8675900101661682 0.7228094935417175 0.7131658792495728 0.01035736505638598
337 0.8488243818283081 0.8691470623016357 0.7228094935417175 0.7084130644798279 0.010349150240068605
338 0.846599280834198 0.872714638710022 0.7211315631866455 0.7035573720932007 0.010341151977359018
339 0.8466947078704834 0.8650580644607544 0.7228766083717346 0.715634822845459 0.01033311030443

412 0.817939043045044 0.8588079810142517 0.7260646224021912 0.7165812849998474 0.00987895175850709
413 0.8194596171379089 0.8624311089515686 0.7246552109718323 0.7097504138946533 0.009873434541306058
414 0.8180535435676575 0.8587132692337036 0.7255277037620544 0.7121371030807495 0.009868107646344656
415 0.8184431791305542 0.8597468733787537 0.7252256870269775 0.7152027487754822 0.009862750195539914
416 0.8174915313720703 0.8621864914894104 0.7256283760070801 0.7081249952316284 0.009857380132881
417 0.8169862627983093 0.8579983115196228 0.7266015410423279 0.7131658792495728 0.009852250226947109
418 0.8168166875839233 0.8578183054924011 0.7255948185920715 0.7148324251174927 0.009846799413457976
419 0.816391110420227 0.857026219367981 0.7256954908370972 0.7138653993606567 0.009841838337126232
420 0.8139024972915649 0.8568735122680664 0.726668655872345 0.7138242721557617 0.00983666241027397
421 0.8151047229766846 0.8561455011367798 0.7274405360221863 0.7146472334861755 0.009831378245240704

494 0.7862799763679504 0.8563331365585327 0.7262659668922424 0.7112729549407959 0.009521530613754735
495 0.7868262529373169 0.8526230454444885 0.7280781269073486 0.7148735523223877 0.009517657660668897
496 0.7867218852043152 0.8556500673294067 0.7267022132873535 0.7108408808708191 0.009513714903557085
497 0.7854506373405457 0.8499176502227783 0.7285143733024597 0.7167870402336121 0.009509933521469912
498 0.7852601408958435 0.8506782054901123 0.7282123565673828 0.7160257697105408 0.009506186884725262
499 0.7837912440299988 0.8523072004318237 0.7280445694923401 0.7154908180236816 0.00950242042541504
500 0.7842161059379578 0.8530017137527466 0.7273063063621521 0.7147089838981628 0.009498581914844628
501 0.7841751575469971 0.8511320948600769 0.7282123565673828 0.7153879404067993 0.0094977232564493
502 0.7817796468734741 0.8463667035102844 0.7293868660926819 0.720572829246521 0.00949392735839601
503 0.7840469479560852 0.8681985139846802 0.7214000225067139 0.7017056345939636 0.00949010158342

576 0.7586836218833923 0.8487688302993774 0.7304943203926086 0.717486560344696 0.0092513548852666
577 0.7569347023963928 0.848868727684021 0.7296218276023865 0.716293215751648 0.009248027339526114
578 0.7566487193107605 0.8515808582305908 0.7287492752075195 0.713474452495575 0.00924530259891701
579 0.7572858929634094 0.8504925966262817 0.7295210957527161 0.7155319452285767 0.009241828014110698
580 0.7551758289337158 0.8512259125709534 0.7297896146774292 0.7144209146499634 0.009238483163042446
581 0.7565784454345703 0.8496832251548767 0.7299909591674805 0.7142357230186462 0.009235211254395161
582 0.7555040121078491 0.8538906574249268 0.7289841771125793 0.7142974734306335 0.009231844242809158
583 0.7594396471977234 0.8492639660835266 0.7295547127723694 0.7151616215705872 0.009229026020389714
584 0.755986213684082 0.8488451242446899 0.7299909591674805 0.7148118615150452 0.009225657862475794
585 0.7558826804161072 0.8546430468559265 0.7286821603775024 0.7126103043556213 0.00922248225163274

658 0.7380666136741638 0.8524570465087891 0.729453980922699 0.7156965732574463 0.009030227053328241
659 0.7392140626907349 0.8514018058776855 0.7288499474525452 0.7159023284912109 0.009027685179854885
660 0.7393475770950317 0.8570151925086975 0.7287828326225281 0.7121988534927368 0.009025731353644343
661 0.7376716732978821 0.851326048374176 0.7299574017524719 0.7163344025611877 0.00902350683586835
662 0.736079216003418 0.8577458262443542 0.7267022132873535 0.7112729549407959 0.00902112900401672
663 0.7381367683410645 0.8491687774658203 0.7294204235076904 0.7168076038360596 0.009019089032368487
664 0.7350378632545471 0.8482481241226196 0.7310983538627625 0.718762218952179 0.009016566886041397
665 0.7343701124191284 0.855326235294342 0.7277089953422546 0.7145649194717407 0.009014145032063619
666 0.733208179473877 0.8538668155670166 0.7285814881324768 0.7140917181968689 0.009012094323245482
667 0.7339208126068115 0.8470636606216431 0.7302258610725403 0.7175483107566833 0.00900965750574351

740 0.7177945375442505 0.8484679460525513 0.7295882701873779 0.718412458896637 0.00900932555256585
741 0.7172735929489136 0.8562071323394775 0.7286150455474854 0.7128366827964783 0.009011078073650678
742 0.7183765769004822 0.8537854552268982 0.7284472584724426 0.7122399806976318 0.009012757207631743
743 0.7160608172416687 0.8462409973144531 0.7309305667877197 0.7184741497039795 0.009014995508296516
744 0.717980146408081 0.8544746041297913 0.7280445694923401 0.712425172328949 0.009016833209351404
745 0.7185677289962769 0.8544659614562988 0.7271385192871094 0.712466299533844 0.009018615807029581
746 0.7159551382064819 0.8498101830482483 0.7293533086776733 0.717095673084259 0.009020868554172745
747 0.7168228030204773 0.8503486514091492 0.7294875383377075 0.7155525088310242 0.009022257863519026
748 0.715584397315979 0.8553211092948914 0.7285479307174683 0.7138859629631042 0.009023176812043336
749 0.7143864035606384 0.8514357209205627 0.7295547127723694 0.7166841626167297 0.0090241607030232