In [1]:
import os
import time
import math
import copy
import random
import numpy as np
import scipy
from scipy.special import comb
from scipy import io
from sklearn.preprocessing import label_binarize

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics

from model_penn_node_classification import Specformer

In [2]:
def count_parameters(model):
    """Return how many scalar values the trainable parameters of ``model`` hold.

    Parameters whose ``requires_grad`` flag is off are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def init_params(module):
    """Initialise an ``nn.Linear`` layer in place; any other module is left untouched.

    Weights are drawn from a normal distribution with mean 0 and std 0.01, and
    the bias (when the layer has one) is set to zero. Written to be passed to
    ``Module.apply``.
    """
    if not isinstance(module, nn.Linear):
        return
    nn.init.normal_(module.weight, mean=0.0, std=0.01)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
            
def seed_everything(seed):
    """Seed every random number generator used by the notebook with ``seed``.

    Covers Python's ``random``, NumPy's global generator, and PyTorch (CPU,
    the current CUDA device and all CUDA devices). ``PYTHONHASHSEED`` is also
    exported to the environment.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    
def mean_confidence_interval(data, confidence=0.95):
    """Return ``(mean, half_width)`` of a Student-t confidence interval.

    Args:
        data: sequence of numeric samples.
        confidence: two-sided confidence level, 0.95 by default.

    Returns:
        The sample mean and the half-width ``h`` such that the interval is
        ``mean - h .. mean + h``.
    """
    from scipy import stats

    samples = np.array(data) * 1.0
    center = np.mean(samples)
    # Two-sided critical value of the t distribution with n - 1 degrees of freedom.
    t_crit = stats.t.ppf((1 + confidence) / 2., len(samples) - 1)
    half_width = stats.sem(samples) * t_crit
    return center, half_width

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a float32 torch sparse COO tensor.

    Args:
        sparse_mx: a scipy sparse matrix in any storage format (CSR, CSC, COO, ...).

    Returns:
        A ``torch.float32`` sparse COO tensor with the same shape and entries.
    """
    # `.row` / `.col` only exist on COO matrices, so normalise the format first.
    # Casting to float32 keeps the float dtype that the legacy
    # `torch.sparse.FloatTensor` constructor was meant to produce.
    coo = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(np.vstack((coo.row, coo.col)).astype(np.int64))
    values = torch.from_numpy(coo.data)
    # `torch.sparse_coo_tensor` replaces the deprecated `torch.sparse.FloatTensor`.
    return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))

def load_fb100(filename, data_dir='data'):
    """Load one FB100 ``.mat`` file and return its adjacency matrix and metadata.

    Args:
        filename: name of the file inside ``data_dir``, e.g. ``Penn94.mat``,
            ``Rutgers89``, ``Cornell5``, ``Wisconsin87`` or ``Amherst41``
            (``scipy.io.loadmat`` appends ``.mat`` when it is missing).
        data_dir: directory that holds the ``.mat`` files.

    Returns:
        ``(A, metadata)`` where ``A`` is the ``'A'`` entry of the file and
        ``metadata`` is its ``'local_info'`` entry. Metadata columns are:
        student/faculty, gender, major, second major/minor, dorm/house,
        year/high school; 0 denotes a missing entry.
    """
    # Honour the requested file instead of always reading Penn94.
    mat = io.loadmat(os.path.join(data_dir, filename))
    A = mat['A']
    metadata = mat['local_info']
    return A, metadata


def load_fb100_dataset():
    """Build node-classification inputs from the Penn94 network file.

    Returns:
        num_nodes: number of rows in the metadata table.
        edge_index: ``torch.long`` tensor stacking the row and column indices
            of the non-zero adjacency entries.
        node_feat: ``torch.float`` tensor of one-hot encoded metadata columns
            (every column except column 1, which is used as the label).
        label: NumPy integer array holding metadata column 1 minus one;
            -1 means unlabeled.
    """
    A, metadata = load_fb100('Penn94.mat')

    # Stack the index arrays into a single ndarray first: building a tensor
    # directly from a tuple of ndarrays goes through a slow element-wise path.
    edge_index = torch.tensor(np.vstack(A.nonzero()), dtype=torch.long)
    metadata = metadata.astype(int)
    label = metadata[:, 1] - 1  # gender label, -1 means unlabeled

    # make features into one-hot encodings
    feature_vals = np.hstack((np.expand_dims(metadata[:, 0], 1), metadata[:, 2:]))
    # Collect the per-column encodings and concatenate once instead of
    # re-copying the growing feature matrix on every iteration. The empty
    # float block keeps the result float64 even when there are no columns.
    onehot_blocks = [np.empty((A.shape[0], 0))]
    for col in range(feature_vals.shape[1]):
        feat_col = feature_vals[:, col]
        onehot_blocks.append(label_binarize(feat_col, classes=np.unique(feat_col)))
    features = np.hstack(onehot_blocks)

    node_feat = torch.tensor(features, dtype=torch.float)
    num_nodes = metadata.shape[0]

    return num_nodes, edge_index, node_feat, label

In [3]:
# The recorded run used GPU 2. Fall back to GPU 0 when the machine exposes
# fewer devices, instead of raising an invalid-device error on a fresh kernel.
gpu_id = 2
if torch.cuda.is_available():
    torch.cuda.set_device(gpu_id if gpu_id < torch.cuda.device_count() else 0)

In [4]:
# The dataset file stores four objects, unpacked here in order.
# NOTE(review): e_tensor / u_tensor look like a precomputed eigen-decomposition
# (values / vectors) fed to the model — confirm against the preprocessing script.
data_name = 'penn'
data_path = 'data/{}.pt'.format(data_name)
e_tensor, u_tensor, x_tensor, y_tensor = torch.load(data_path)

# Features and labels stay on the GPU for the whole run.
x, y = x_tensor.cuda(), y_tensor.cuda()

# Labels with more than one dimension are collapsed to class indices.
if y.dim() > 1:
    y = y.argmax(dim=1)

In [5]:
# Load the pre-generated data splits; each entry is indexed below by
# 'train' / 'valid' / 'test'.
# NOTE: allow_pickle=True unpickles the file contents, so only load split
# files that come from a trusted source.
split = np.load('data/fb100-Penn94-splits.npy', allow_pickle=True)

In [6]:
# Pick one of the stored splits and pull out its three index sets.
split_idx = 0
chosen_split = split[split_idx]
train, valid, test = (chosen_split[part] for part in ('train', 'valid', 'test'))

In [7]:
# Convert the split indices to int64 tensors on the GPU.
# `torch.as_tensor` accepts NumPy arrays as well as existing tensors on any
# device, so this cell can be re-run; the legacy `torch.LongTensor(...)`
# constructor is expected to reject the CUDA tensors left by a previous run.
train, valid, test = (
    torch.as_tensor(part, dtype=torch.long).cuda() for part in (train, valid, test)
)

In [8]:
# Run configuration: training length, model size, optimiser settings and the
# criteria used for early stopping and evaluation.
epoch = 2000  # upper bound on the number of training epochs
nclass = 2  # number of output classes passed to the model
nfeat = x.size(1)  # input feature dimension, read from the loaded feature tensor

# Model arguments, passed positionally to the Specformer constructor.
nlayer = 1
num_heads = 1
hidden_dim = 64
tran_dropout = 0.0
prop_dropout = 0.4
feat_dropout = 0.4

# Adam optimiser settings.
lr = 1e-3
weight_decay = 1e-3

# 'loss' makes early stopping track validation loss; any other value makes it
# track validation accuracy.
metric = 'loss'
# NOTE(review): newer torchmetrics releases appear to require a `task`
# argument here — confirm against the installed version.
evaluation = torchmetrics.Accuracy()
loss = nn.CrossEntropyLoss()

In [9]:
# Train Specformer with early stopping on the validation metric chosen by `metric`.
seed_everything(1)

net = Specformer(nclass, nfeat, nlayer, hidden_dim, num_heads, tran_dropout, feat_dropout, prop_dropout).cuda()
net.apply(init_params)
optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)

print(count_parameters(net))

res = []          # one [train_loss, val_loss, val_acc, test_acc] row per epoch
times = []
min_loss = 100.0
max_acc = 0
counter = 0       # epochs since the monitored validation metric last improved
patience = 200    # stop once `counter` reaches this value
best_net = None   # snapshot of the model at its best validation epoch

# Keep only the first `topk` and the last `topk` entries of e_tensor and the
# matching columns of u_tensor. The selection is deterministic despite the name.
topk = 3000
all_idx = torch.arange(x.size(0))
rand_idx = torch.cat((all_idx[:topk], all_idx[-topk:]))
e = e_tensor[rand_idx].cuda()
u = u_tensor[:, rand_idx].cuda()

for idx in range(epoch):

    # One full-batch optimisation step on the training nodes.
    net.train()
    optimizer.zero_grad()
    logits = net(e, u, x)
    train_loss = loss(logits[train], y[train])
    train_loss.backward()
    optimizer.step()

    # Evaluation needs no gradients; skipping the autograd graph saves memory
    # and time without changing any of the reported numbers.
    net.eval()
    with torch.no_grad():
        logits = net(e, u, x)
        val_pred = logits[valid].cpu()
        val_label = y[valid].cpu()
        val_loss = loss(val_pred, val_label).item()
        val_acc = evaluation(val_pred, val_label).item()

        test_pred = logits[test].cpu()
        test_label = y[test].cpu()
        test_acc = evaluation(test_pred, test_label).item()

    res.append([train_loss.item(), val_loss, val_acc, test_acc])

    if metric == 'loss':
        if val_loss < min_loss:
            min_loss = val_loss
            # Snapshot the best model here too, as the accuracy branch does.
            best_net = copy.deepcopy(net)
            counter = 0
        else:
            counter += 1
    else:
        if val_acc > max_acc:
            max_acc = val_acc
            best_net = copy.deepcopy(net)
            counter = 0
        else:
            counter += 1

    if counter == patience:
        break

    print(idx, train_loss.item(), val_loss, val_acc, test_acc)

338179
0 0.6931772232055664 0.6930996775627136 0.5194270014762878 0.5279752612113953
1 0.6925821304321289 0.6930509209632874 0.5194270014762878 0.5279752612113953
2 0.6917356252670288 0.693000853061676 0.5194270014762878 0.5279752612113953
3 0.6910204887390137 0.6929486989974976 0.5194270014762878 0.5279752612113953
4 0.6899073123931885 0.6928951740264893 0.5194270014762878 0.5279752612113953
5 0.6887400150299072 0.6928400993347168 0.5194270014762878 0.5279752612113953
6 0.6873757839202881 0.6927828192710876 0.5194270014762878 0.5279752612113953
7 0.6855469346046448 0.6927207708358765 0.5194270014762878 0.5279752612113953
8 0.6831280589103699 0.6926555037498474 0.5194270014762878 0.5279752612113953
9 0.6808530688285828 0.6925871968269348 0.5194270014762878 0.5279752612113953
10 0.6778061389923096 0.6925115585327148 0.5194270014762878 0.5279752612113953
11 0.6740731596946716 0.6924243569374084 0.5194270014762878 0.5279752612113953
12 0.6698228716850281 0.6923224329948425 0.5194270014762

108 0.3111274838447571 0.362347811460495 0.8435535430908203 0.8437918424606323
109 0.309066504240036 0.36129575967788696 0.8426259756088257 0.8444101214408875
110 0.31008923053741455 0.3598185181617737 0.8416984677314758 0.8451313972473145
111 0.30543315410614014 0.3581065237522125 0.8415954113006592 0.8442040085792542
112 0.30851516127586365 0.35832130908966064 0.8412861824035645 0.8441009521484375
113 0.31051987409591675 0.3586030602455139 0.8409770131111145 0.8432766795158386
114 0.3044298589229584 0.36006486415863037 0.8397402763366699 0.8421432375907898
115 0.3046947121620178 0.3604001998901367 0.8389158248901367 0.8433796763420105
116 0.30578720569610596 0.3609171509742737 0.8389158248901367 0.8441009521484375
117 0.30475059151649475 0.36005863547325134 0.8393280506134033 0.8447192311286926
118 0.3045269846916199 0.35855165123939514 0.8393280506134033 0.8459556698799133
119 0.3017549514770508 0.3581233322620392 0.840255618095398 0.8451313972473145
120 0.3073551058769226 0.3574825

215 0.26175132393836975 0.3582858145236969 0.8453055620193481 0.8422462940216064
216 0.27157947421073914 0.3561048209667206 0.8479851484298706 0.843173623085022
217 0.26151272654533386 0.3573654890060425 0.8442749381065369 0.8425554037094116
218 0.26285675168037415 0.35973358154296875 0.8438627123832703 0.8412158489227295
219 0.2644636631011963 0.356655478477478 0.8442749381065369 0.8425554037094116
220 0.2613380551338196 0.35526183247566223 0.843656599521637 0.8447192311286926
221 0.2645006477832794 0.3557078242301941 0.8422137498855591 0.8449252843856812
222 0.26556727290153503 0.3588554859161377 0.8405647873878479 0.842452347278595
223 0.25795698165893555 0.3616449534893036 0.8412861824035645 0.8416280150413513
224 0.26012924313545227 0.3615897297859192 0.8414922952651978 0.8423492908477783
225 0.26191234588623047 0.35859718918800354 0.8446872234344482 0.842452347278595
226 0.26123520731925964 0.3576856255531311 0.8442749381065369 0.8426584005355835
227 0.2539701461791992 0.35943275

317 0.23634324967861176 0.3623063266277313 0.8434504866600037 0.8392581343650818
318 0.2390439212322235 0.3607495427131653 0.8421106934547424 0.8394641876220703
319 0.24138687551021576 0.35857516527175903 0.8445841670036316 0.8423492908477783
320 0.23725926876068115 0.36001038551330566 0.8441718816757202 0.8423492908477783
321 0.24181824922561646 0.36257222294807434 0.8432443737983704 0.8439979553222656
322 0.24183747172355652 0.3644603490829468 0.8418015241622925 0.8445131182670593
323 0.23975178599357605 0.3664708435535431 0.8430382609367371 0.8436888456344604
324 0.24334372580051422 0.3663458526134491 0.8439657688140869 0.8442040085792542
325 0.2403559535741806 0.36731672286987305 0.8442749381065369 0.8415249586105347
326 0.23505248129367828 0.3642983138561249 0.8447902798652649 0.8419371247291565
327 0.23912830650806427 0.3616682291030884 0.8445841670036316 0.8429675698280334
328 0.23700417578220367 0.36276036500930786 0.8435535430908203 0.8423492908477783
329 0.2377922683954239 0.

In [11]:
# Rank the per-epoch rows [train_loss, val_loss, val_acc, test_acc] both ways.
ind1 = sorted(res, key=lambda row: row[1], reverse=False)  # lowest validation loss first
ind2 = sorted(res, key=lambda row: row[2], reverse=True)   # highest validation accuracy first

# Report the epoch chosen by the same criterion that drove early stopping.
# The previous version printed `ind2[0]` while the definition of `ind2` was
# commented out, which raises NameError on a fresh kernel.
print(ind1[0] if metric == 'loss' else ind2[0])

[0.27430588006973267, 0.3515945076942444, 0.8495310544967651, 0.8457496166229248]
