In [1]:
from src.get_data import load_benchmark, load_synthetic
from src.normalization import get_adj_feats
from src.args import get_args
from src.models import get_model
from src.utils import accuracy, LDA_loss
from src.plots import plot_feature
import torch.optim as optim
import torch
import torch.nn.functional as F
import time
import matplotlib.pyplot as plt
import numpy as np
import pickle as pkl

In [2]:
# Benchmark to run on. To pick one interactively instead, use:
#   dataset_name = input('input dataset name: cora/citeseer/pubmed/...')
dataset_name = 'citeseer'

# Load the graph, node features, labels and the train/val/test index splits
# (original note: everything comes back as dense tensors).
(adj, feats, labels,
 idx_train, idx_val, idx_test) = load_benchmark(dataset_name)

Loading citeseer dataset...


  r_inv = np.power(rowsum, -1).flatten()


finish load data


In [3]:
# Model and normalisation choice. To pick the model interactively instead, use:
#   model_name = input('choose model: GCN/SGC/GFNN/GFN/AGNN/GIN/...')
model_name = 'PreCompute_AFGNN'
norm_method = 'row'

# Hyper-parameters for this (model, dataset) pair.
args = get_args(model_opt=model_name, dataset=dataset_name)

In [4]:
# Optional manual overrides for the hyper-parameters in `args` (uncomment to apply):
# args.lr = 0.1
# args.weight_decay = 5e-4
# args.degree = 3
# args.epochs = 200

In [5]:
# Turn the raw graph and features into the inputs expected by the chosen model.
# NOTE: `adj` and `feats` are rebound to the processed versions here, so
# re-running this cell without reloading the data processes them a second time.
adj, feats = get_adj_feats(
    adj=adj,
    feats=feats,
    model_opt=model_name,
    degree=args.degree,
    norm_method=norm_method,
)

for SGC, return identity matrix and propagated feats


In [6]:
# Number of classes as a plain Python int (same convention as the
# `labels.max().item()+1` used for `nclass` in get_acc).
nb_class = int(labels.max().item()) + 1

# One-hot label matrix with one row per node and one column per class.
Y_onehot = torch.zeros(labels.shape[0], nb_class).scatter_(1, labels.unsqueeze(-1), 1)


def _class_count_stats(onehot, idx):
    """Return per-class statistics for the nodes selected by ``idx``.

    Args:
        onehot: one-hot label matrix (rows are nodes, columns are classes).
        idx: index selecting the rows that belong to one split.

    Returns:
        Tuple ``(counts, inv_counts, inv_mat)``: the number of selected nodes
        per class, its element-wise reciprocal, and the diagonal matrix built
        from that reciprocal. A class with no selected node gets an infinite
        reciprocal.
    """
    counts = torch.sum(onehot[idx], dim=0)
    inv_counts = 1.0 / counts
    return counts, inv_counts, torch.diag(inv_counts)


# Same statistics for each split; the train matrix is what train() hands to LDA_loss.
nb_each_class_train, nb_each_class_inv_train, nb_each_class_inv_mat_train = \
    _class_count_stats(Y_onehot, idx_train)

nb_each_class_val, nb_each_class_inv_val, nb_each_class_inv_mat_val = \
    _class_count_stats(Y_onehot, idx_val)

nb_each_class_test, nb_each_class_inv_test, nb_each_class_inv_mat_test = \
    _class_count_stats(Y_onehot, idx_test)

In [7]:
# train, test


def train(epoch, model, optimizer, adj, feats, labels, idx_train, idx_val,
          idx_test, Y_onehot, nb_each_class_inv_mat_train):
    """Run one optimisation step, then evaluate on the validation and test splits.

    The model is called as ``model(feats, adj)`` and must return a triple whose
    first element is scored with ``F.nll_loss``. When the notebook-level
    ``model_name`` is ``'AGNN'``, the LDA loss of the second returned element
    (restricted to the training nodes) is subtracted from the training loss.

    Args:
        epoch: zero-based epoch index (reported as ``epoch + 1``).
        model: module to train; switched to train mode for the step and left
            in eval mode on return.
        optimizer: optimizer over the model's parameters.
        adj, feats: model inputs, forwarded unchanged to the model.
        labels: class index per node.
        idx_train, idx_val, idx_test: indices selecting each split.
        Y_onehot: one-hot labels, only read in the AGNN branch.
        nb_each_class_inv_mat_train: forwarded to ``LDA_loss``, only read in
            the AGNN branch.

    Returns:
        Tuple ``(epoch + 1, loss_train, acc_train, loss_val, acc_val,
        loss_test, acc_test, elapsed_seconds)``; the losses and accuracies are
        obtained with ``.item()``.
    """
    t = time.time()

    # ---- optimisation step on the training nodes ----
    model.train()
    optimizer.zero_grad()
    output, fp1, fp2 = model(feats, adj)
    CE_loss_train = F.nll_loss(output[idx_train], labels[idx_train])
    if model_name == 'AGNN':
        LDA_loss_train = LDA_loss(fp1[idx_train], Y_onehot[idx_train],
                                  nb_each_class_inv_mat_train, norm_or_not=False)
        loss_train = CE_loss_train - LDA_loss_train
    else:
        loss_train = CE_loss_train
    acc_train = accuracy(output[idx_train], labels[idx_train])
    loss_train.backward()
    optimizer.step()

    # ---- evaluation with the updated weights ----
    # Nothing is back-propagated from here on, so skip building the autograd graph.
    model.eval()
    with torch.no_grad():
        output, _, _ = model(feats, adj)

        loss_val = F.nll_loss(output[idx_val], labels[idx_val])
        acc_val = accuracy(output[idx_val], labels[idx_val])

        loss_test = F.nll_loss(output[idx_test], labels[idx_test])
        acc_test = accuracy(output[idx_test], labels[idx_test])

    # Measured once so the printed and the returned duration agree.
    elapsed = time.time() - t

    print('Epoch: {:04d}'.format(epoch+1),
          'loss_train: {:.4f}'.format(loss_train.item()),
          'acc_train: {:.4f}'.format(acc_train.item()),
          'loss_val: {:.4f}'.format(loss_val.item()),
          'acc_val: {:.4f}'.format(acc_val.item()),
#           'loss_test: {:.4f}'.format(loss_test.item()),
#           'acc_test: {:.4f}'.format(acc_test.item()),
          'time: {:.4f}s'.format(elapsed))

    return (epoch + 1, loss_train.item(), acc_train.item(), loss_val.item(),
            acc_val.item(), loss_test.item(), acc_test.item(), elapsed)




In [8]:
def get_acc(adj, feats, labels, idx_train, idx_val, idx_test):
    """Train a fresh model and return the test accuracy at the best validation loss.

    Reads the notebook-level ``model_name``, ``dataset_name``, ``args``,
    ``Y_onehot`` and ``nb_each_class_inv_mat_train``. A new model and Adam
    optimizer are created on every call. Each time the validation loss reaches
    a new minimum (ties included), the model's state_dict is written to
    ``save/model_param/<model_name><dataset_name>.pt`` and the test loss and
    accuracy of that epoch are remembered.

    Args:
        adj, feats: model inputs, forwarded to ``train``.
        labels: class index per node; ``labels.max().item()+1`` is used as the
            number of classes.
        idx_train, idx_val, idx_test: indices selecting each split.

    Returns:
        The test accuracy recorded at the epoch with the lowest validation
        loss (0 if no epoch improved on the initial sentinel).
    """
    import os  # local import: only needed to create the checkpoint directory

    # get model
    model = get_model(model_opt=model_name, nfeat=feats.size(1),
                      nclass=labels.max().item()+1, nhid=args.hidden,
                      dropout=args.dropout, cuda=args.cuda,
                      dataset=dataset_name, degree=args.degree)
    # optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr, weight_decay=args.weight_decay)

    # AGNN and GIN are kept on the CPU even when CUDA is requested.
    if args.cuda and model_name not in ('AGNN', 'GIN'):
        model.cuda()
        feats = feats.cuda()
        adj = adj.cuda()
        labels = labels.cuda()
        idx_train = idx_train.cuda()
        idx_val = idx_val.cuda()
        idx_test = idx_test.cuda()

    # Print model's state_dict
    print("Model's state_dict:")
    for param_name, param_value in model.state_dict().items():
        print(param_name, "\t", param_value.size())

    # Print optimizer's state_dict
    print("optimizer's state_dict:")
    for var_name, var_value in optimizer.state_dict().items():
        print(var_name, "\t", var_value)

    # # Print parameters
    # for name,param in model.named_parameters():
    #     print(name, param)

    # Per-epoch records; only consumed by the optional logging/plotting code below.
    training_log = []

    # Train model
    t_total = time.time()
    temp_val_loss = float('inf')  # any finite validation loss beats this
    temp_test_loss = 0
    temp_test_acc = 0
    checkpoint_saved = False
    PATH = "save/model_param/{}{}.pt".format(model_name, dataset_name)
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(PATH), exist_ok=True)

    for epoch in range(args.epochs):

        epo, trainloss, trainacc, valloss, valacc, testloss, testacc, epotime = train(
            epoch, model, optimizer, adj, feats, labels, idx_train, idx_val,
            idx_test, Y_onehot, nb_each_class_inv_mat_train)
        training_log.append([epo, trainloss, trainacc, valloss, valacc, testloss, testacc, epotime])

        # Model selection on validation loss; the test numbers are only recorded.
        if valloss <= temp_val_loss:
            temp_val_loss = valloss
            temp_test_loss = testloss
            temp_test_acc = testacc
            torch.save(model.state_dict(), PATH)
            checkpoint_saved = True

    print("Optimization Finished!")
    print("Total time elapsed: {:.4f}s".format(time.time() - t_total))
    print("Best result:",
          "val_loss=", temp_val_loss,
          "test_loss=", temp_test_loss,
          "test_acc=", temp_test_acc)

    # The checkpoint is only read back for AGNN, and only if this run wrote it,
    # so a stale file from an earlier run is never reported.
    if model_name == 'AGNN' and checkpoint_saved:
        bestmodel = torch.load(PATH)
        print("the weight is: ", torch.softmax(bestmodel['gc1.linear_weight'].data, dim=0))

    # # save training log
    # # expname = input('input experiment name: ')
    # expname = dataset_name + '_' + model_name
    # log_pk = open('./save/trainlog_'+expname+'.pkl','wb')
    # pkl.dump(np.array(training_log),log_pk)
    # log_pk.close()
    # print("finish save log")

    # # store result

    # X_epoch = np.array(training_log)[:,0]
    # Y1_trainloss = np.array(training_log)[:,1]
    # Y2_valloss = np.array(training_log)[:,3]
    # Y3_testloss = np.array(training_log)[:,5]

    # plt.plot(X_epoch, Y1_trainloss, color = 'k', label = 'train')
    # plt.plot(X_epoch, Y2_valloss, color = 'b', label = 'val')
    # plt.plot(X_epoch, Y3_testloss, color = 'r', linestyle = '-.', label = 'test')
    # plt.xlabel("epochs")
    # plt.ylabel("loss")
    # plt.legend(loc = 'upper right')
    # plt.show()

    return temp_test_acc

In [None]:
# Repeat the complete train/evaluate cycle ten times and keep each run's
# test accuracy (get_acc builds a fresh model on every call).
acc = []
for i in range(10):
    run_test_acc = get_acc(adj, feats, labels, idx_train, idx_val, idx_test)
    acc.append(run_test_acc)

Model's state_dict:
W1.weight 	 torch.Size([6, 3703])
W1.bias 	 torch.Size([6])
optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.1, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0.0005, 'amsgrad': False, 'params': [4526314984, 4526314696]}]
Epoch: 0001 loss_train: 1.7917 acc_train: 0.1667 loss_val: 1.7870 acc_val: 0.3980 time: 0.0420s
Epoch: 0002 loss_train: 1.7192 acc_train: 0.6583 loss_val: 1.7487 acc_val: 0.6360 time: 0.0309s
Epoch: 0003 loss_train: 1.6657 acc_train: 0.9500 loss_val: 1.7221 acc_val: 0.5500 time: 0.0309s
Epoch: 0004 loss_train: 1.6337 acc_train: 0.8667 loss_val: 1.7089 acc_val: 0.5500 time: 0.0436s
Epoch: 0005 loss_train: 1.6119 acc_train: 0.8583 loss_val: 1.7035 acc_val: 0.6840 time: 0.0310s
Epoch: 0006 loss_train: 1.5946 acc_train: 0.9083 loss_val: 1.7025 acc_val: 0.7080 time: 0.0387s
Epoch: 0007 loss_train: 1.5819 acc_train: 0.9333 loss_val: 1.7023 acc_val: 0.6620 time: 0.0348s
Epoch: 0008 loss_train: 1.5732 acc_train: 0.9333 loss_val: 1.6995 acc

Epoch: 0085 loss_train: 1.5513 acc_train: 0.9083 loss_val: 1.6614 acc_val: 0.7200 time: 0.0747s
Epoch: 0086 loss_train: 1.5513 acc_train: 0.9083 loss_val: 1.6613 acc_val: 0.7200 time: 0.0621s
Epoch: 0087 loss_train: 1.5513 acc_train: 0.9083 loss_val: 1.6612 acc_val: 0.7200 time: 0.0719s
Epoch: 0088 loss_train: 1.5512 acc_train: 0.9083 loss_val: 1.6612 acc_val: 0.7200 time: 0.0742s
Epoch: 0089 loss_train: 1.5511 acc_train: 0.9083 loss_val: 1.6611 acc_val: 0.7200 time: 0.0723s
Epoch: 0090 loss_train: 1.5510 acc_train: 0.9083 loss_val: 1.6611 acc_val: 0.7200 time: 0.0328s
Epoch: 0091 loss_train: 1.5510 acc_train: 0.9083 loss_val: 1.6610 acc_val: 0.7220 time: 0.0303s
Epoch: 0092 loss_train: 1.5509 acc_train: 0.9083 loss_val: 1.6609 acc_val: 0.7220 time: 0.0310s
Epoch: 0093 loss_train: 1.5509 acc_train: 0.9083 loss_val: 1.6608 acc_val: 0.7280 time: 0.0310s
Epoch: 0094 loss_train: 1.5509 acc_train: 0.9083 loss_val: 1.6608 acc_val: 0.7280 time: 0.0402s
Epoch: 0095 loss_train: 1.5508 acc_train

Epoch: 0173 loss_train: 1.5500 acc_train: 0.9000 loss_val: 1.6599 acc_val: 0.7240 time: 0.0780s
Epoch: 0174 loss_train: 1.5500 acc_train: 0.9000 loss_val: 1.6599 acc_val: 0.7240 time: 0.0763s
Epoch: 0175 loss_train: 1.5500 acc_train: 0.9000 loss_val: 1.6599 acc_val: 0.7240 time: 0.0752s
Epoch: 0176 loss_train: 1.5500 acc_train: 0.9000 loss_val: 1.6599 acc_val: 0.7240 time: 0.0586s
Epoch: 0177 loss_train: 1.5500 acc_train: 0.9000 loss_val: 1.6599 acc_val: 0.7240 time: 0.0632s
Epoch: 0178 loss_train: 1.5500 acc_train: 0.9000 loss_val: 1.6599 acc_val: 0.7240 time: 0.0586s
Epoch: 0179 loss_train: 1.5500 acc_train: 0.9000 loss_val: 1.6599 acc_val: 0.7240 time: 0.0503s
Epoch: 0180 loss_train: 1.5500 acc_train: 0.9000 loss_val: 1.6599 acc_val: 0.7240 time: 0.0581s
Epoch: 0181 loss_train: 1.5500 acc_train: 0.9000 loss_val: 1.6599 acc_val: 0.7240 time: 0.0584s
Epoch: 0182 loss_train: 1.5500 acc_train: 0.9000 loss_val: 1.6599 acc_val: 0.7240 time: 0.0504s
Epoch: 0183 loss_train: 1.5500 acc_train

Epoch: 0058 loss_train: 1.5540 acc_train: 0.9083 loss_val: 1.6637 acc_val: 0.7240 time: 0.0394s
Epoch: 0059 loss_train: 1.5539 acc_train: 0.9083 loss_val: 1.6637 acc_val: 0.7240 time: 0.0371s
Epoch: 0060 loss_train: 1.5537 acc_train: 0.9167 loss_val: 1.6637 acc_val: 0.7240 time: 0.0328s
Epoch: 0061 loss_train: 1.5535 acc_train: 0.9167 loss_val: 1.6636 acc_val: 0.7220 time: 0.0434s
Epoch: 0062 loss_train: 1.5533 acc_train: 0.9167 loss_val: 1.6634 acc_val: 0.7200 time: 0.0323s
Epoch: 0063 loss_train: 1.5531 acc_train: 0.9083 loss_val: 1.6632 acc_val: 0.7220 time: 0.0625s
Epoch: 0064 loss_train: 1.5529 acc_train: 0.9083 loss_val: 1.6629 acc_val: 0.7240 time: 0.0663s
Epoch: 0065 loss_train: 1.5528 acc_train: 0.9083 loss_val: 1.6628 acc_val: 0.7260 time: 0.0648s
Epoch: 0066 loss_train: 1.5527 acc_train: 0.9083 loss_val: 1.6627 acc_val: 0.7260 time: 0.0574s
Epoch: 0067 loss_train: 1.5525 acc_train: 0.9083 loss_val: 1.6627 acc_val: 0.7260 time: 0.0677s
Epoch: 0068 loss_train: 1.5524 acc_train

Epoch: 0143 loss_train: 1.5500 acc_train: 0.9000 loss_val: 1.6599 acc_val: 0.7240 time: 0.0290s
Epoch: 0144 loss_train: 1.5500 acc_train: 0.9000 loss_val: 1.6599 acc_val: 0.7240 time: 0.0375s
Epoch: 0145 loss_train: 1.5500 acc_train: 0.9000 loss_val: 1.6599 acc_val: 0.7240 time: 0.0305s
Epoch: 0146 loss_train: 1.5500 acc_train: 0.9000 loss_val: 1.6599 acc_val: 0.7240 time: 0.0317s
Epoch: 0147 loss_train: 1.5500 acc_train: 0.9000 loss_val: 1.6599 acc_val: 0.7240 time: 0.0300s
Epoch: 0148 loss_train: 1.5500 acc_train: 0.9000 loss_val: 1.6599 acc_val: 0.7240 time: 0.0300s
Epoch: 0149 loss_train: 1.5500 acc_train: 0.9000 loss_val: 1.6599 acc_val: 0.7240 time: 0.0317s
Epoch: 0150 loss_train: 1.5500 acc_train: 0.9000 loss_val: 1.6599 acc_val: 0.7240 time: 0.0434s
Epoch: 0151 loss_train: 1.5500 acc_train: 0.9000 loss_val: 1.6599 acc_val: 0.7240 time: 0.0310s
Epoch: 0152 loss_train: 1.5500 acc_train: 0.9000 loss_val: 1.6599 acc_val: 0.7240 time: 0.0313s
Epoch: 0153 loss_train: 1.5500 acc_train

Epoch: 0028 loss_train: 1.5597 acc_train: 0.9083 loss_val: 1.6696 acc_val: 0.7120 time: 0.0347s
Epoch: 0029 loss_train: 1.5591 acc_train: 0.9083 loss_val: 1.6690 acc_val: 0.7220 time: 0.0357s
Epoch: 0030 loss_train: 1.5589 acc_train: 0.9167 loss_val: 1.6688 acc_val: 0.7300 time: 0.0309s
Epoch: 0031 loss_train: 1.5591 acc_train: 0.9167 loss_val: 1.6689 acc_val: 0.7200 time: 0.0303s
Epoch: 0032 loss_train: 1.5593 acc_train: 0.9083 loss_val: 1.6689 acc_val: 0.7160 time: 0.0307s
Epoch: 0033 loss_train: 1.5593 acc_train: 0.9083 loss_val: 1.6688 acc_val: 0.7100 time: 0.0318s
Epoch: 0034 loss_train: 1.5592 acc_train: 0.9083 loss_val: 1.6685 acc_val: 0.7080 time: 0.0298s
Epoch: 0035 loss_train: 1.5588 acc_train: 0.9083 loss_val: 1.6681 acc_val: 0.7060 time: 0.0372s
Epoch: 0036 loss_train: 1.5583 acc_train: 0.9083 loss_val: 1.6677 acc_val: 0.7160 time: 0.0306s
Epoch: 0037 loss_train: 1.5577 acc_train: 0.9083 loss_val: 1.6675 acc_val: 0.7260 time: 0.0318s
Epoch: 0038 loss_train: 1.5573 acc_train

Epoch: 0114 loss_train: 1.5501 acc_train: 0.9000 loss_val: 1.6601 acc_val: 0.7240 time: 0.0396s
Epoch: 0115 loss_train: 1.5501 acc_train: 0.9000 loss_val: 1.6601 acc_val: 0.7240 time: 0.0361s
Epoch: 0116 loss_train: 1.5501 acc_train: 0.9000 loss_val: 1.6601 acc_val: 0.7240 time: 0.0364s
Epoch: 0117 loss_train: 1.5501 acc_train: 0.9000 loss_val: 1.6601 acc_val: 0.7240 time: 0.0303s
Epoch: 0118 loss_train: 1.5501 acc_train: 0.9000 loss_val: 1.6601 acc_val: 0.7240 time: 0.0292s
Epoch: 0119 loss_train: 1.5501 acc_train: 0.9000 loss_val: 1.6600 acc_val: 0.7240 time: 0.0294s
Epoch: 0120 loss_train: 1.5501 acc_train: 0.9000 loss_val: 1.6600 acc_val: 0.7240 time: 0.0298s
Epoch: 0121 loss_train: 1.5501 acc_train: 0.9000 loss_val: 1.6600 acc_val: 0.7240 time: 0.0392s
Epoch: 0122 loss_train: 1.5501 acc_train: 0.9000 loss_val: 1.6600 acc_val: 0.7240 time: 0.0311s
Epoch: 0123 loss_train: 1.5500 acc_train: 0.9000 loss_val: 1.6600 acc_val: 0.7240 time: 0.0313s
Epoch: 0124 loss_train: 1.5500 acc_train

Epoch: 0200 loss_train: 1.5500 acc_train: 0.9000 loss_val: 1.6600 acc_val: 0.7240 time: 0.0393s
Optimization Finished!
Total time elapsed: 7.3486s
Best result: val_loss= 1.6599222421646118 test_loss= 1.6585291624069214 test_acc= 0.713
Model's state_dict:
W1.weight 	 torch.Size([6, 3703])
W1.bias 	 torch.Size([6])
optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.1, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0.0005, 'amsgrad': False, 'params': [5105001528, 5105100120]}]
Epoch: 0001 loss_train: 1.7917 acc_train: 0.1667 loss_val: 1.7855 acc_val: 0.3240 time: 0.0394s
Epoch: 0002 loss_train: 1.7201 acc_train: 0.5000 loss_val: 1.7482 acc_val: 0.6200 time: 0.0352s
Epoch: 0003 loss_train: 1.6658 acc_train: 0.9583 loss_val: 1.7237 acc_val: 0.5060 time: 0.0321s
Epoch: 0004 loss_train: 1.6340 acc_train: 0.8417 loss_val: 1.7108 acc_val: 0.5060 time: 0.0368s
Epoch: 0005 loss_train: 1.6123 acc_train: 0.8167 loss_val: 1.7042 acc_val: 0.6020 time: 0.0375s
Epoch: 0006 loss_train: 1.

In [None]:
# Summarise the per-run test accuracies: raw values, then mean and variance.
acc = np.array(acc)
mean, var = acc.mean(), acc.var()
print(acc)
print(mean, var)

In [None]:
# Report mean and standard deviation of the test accuracy in percent.
std = np.sqrt(var)
mean_pct = mean * 100
std_pct = std * 100
print(mean_pct, std_pct)