In [1]:
import argparse
import os
import random
import time
import uuid

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim

from utils import *
from model import *
from process import *

In [2]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU
# so the notebook still executes on machines without CUDA.
cudaid = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(cudaid)

# Each kernel session writes its best-model checkpoint to a unique file so that
# concurrent or repeated runs do not overwrite one another.
checkpt_file = 'pretrained/'+uuid.uuid4().hex+'.pt'
# torch.save does not create missing parent directories; make sure it exists.
os.makedirs(os.path.dirname(checkpt_file), exist_ok=True)

In [None]:
def train_step(model,optimizer,features,labels,adj,idx_train):
    """Perform a single optimisation step using the training indices.

    The model is switched to training mode, gradients are cleared, a forward
    pass is run on ``(features, adj)``, and the NLL loss over the rows selected
    by ``idx_train`` is back-propagated before the optimizer is stepped.

    Returns:
        tuple: ``(loss, accuracy)`` on the training indices, each converted
        with ``.item()``.
    """
    model.train()
    optimizer.zero_grad()

    predictions = model(features, adj)
    train_pred = predictions[idx_train]
    # Labels are moved to the compute device once and reused for both metrics.
    train_target = labels[idx_train].to(device)

    acc_train = accuracy(train_pred, train_target)
    loss_train = F.nll_loss(train_pred, train_target)

    loss_train.backward()
    optimizer.step()
    return loss_train.item(), acc_train.item()


def validate_step(model,features,labels,adj,idx_val):
    """Evaluate the model on the validation indices without tracking gradients.

    Returns:
        tuple: ``(loss, accuracy)`` on the rows selected by ``idx_val``, each
        converted with ``.item()``.
    """
    model.eval()
    with torch.no_grad():
        val_pred = model(features, adj)[idx_val]
        val_target = labels[idx_val].to(device)
        loss_val = F.nll_loss(val_pred, val_target)
        acc_val = accuracy(val_pred, val_target)
    return loss_val.item(), acc_val.item()

def test_step(model,features,labels,adj,idx_test):
    """Reload the saved checkpoint and evaluate it on the test indices.

    The weights stored at the module-level ``checkpt_file`` path are loaded
    into ``model`` before evaluation, so the result reflects the checkpointed
    parameters rather than whatever ``model`` held on entry.

    Returns:
        tuple: ``(loss, accuracy)`` on the rows selected by ``idx_test``, each
        converted with ``.item()``.
    """
    saved_state = torch.load(checkpt_file)
    model.load_state_dict(saved_state)
    model.eval()
    with torch.no_grad():
        test_pred = model(features, adj)[idx_test]
        test_target = labels[idx_test].to(device)
        loss_test = F.nll_loss(test_pred, test_target)
        acc_test = accuracy(test_pred, test_target)
    return loss_test.item(), acc_test.item()
    

def train(datastr,splitstr):
    """Train a model on one dataset split and return its test accuracy in percent.

    Training runs for at most ``epochs`` epochs with early stopping on the
    validation loss: whenever the validation loss improves, the weights are
    checkpointed to ``checkpt_file``; once ``patience`` consecutive epochs pass
    without improvement, training stops. The checkpoint is then reloaded by
    ``test_step`` for the final evaluation.

    Args:
        datastr: Dataset name, forwarded to ``full_load_data``.
        splitstr: Path of the split file, forwarded to ``full_load_data``.

    Returns:
        The accuracy reported by ``test_step`` on ``idx_test``, multiplied by 100.
    """
    # Load exactly the dataset/split the caller asked for. Rebuilding the split
    # path here from a global would make every call evaluate the same split.
    (adj, features, labels,
     idx_train, idx_val, idx_test,
     _num_features, _num_labels) = full_load_data(datastr,splitstr)

    features = features.to(device)
    adj = adj.to(device)
    n = features.shape[0]
    feature_size = features.shape[1]
    num_classes = len(torch.unique(labels))

    # Model / optimiser hyper-parameters.
    hidden_dim = 16
    dropout = 0.5
    model = my_GCN(n,feature_size,hidden_dim,num_classes,dropout)
    #model=GCN(feature_size,hidden_dim,num_classes,dropout)
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(),lr=0.01,weight_decay=5e-4)

    epochs = 200
    patience = 10

    bad_counter = 0
    # Any finite first validation loss counts as an improvement.
    best = float('inf')
    for epoch in range(epochs):
        loss_tra,acc_tra = train_step(model,optimizer,features,labels,adj,idx_train)
        loss_val,acc_val = validate_step(model,features,labels,adj,idx_val)
        print('Epoch:{:04d}'.format(epoch+1),
            'train',
            'loss:{:.3f}'.format(loss_tra),
            'acc:{:.2f}'.format(acc_tra*100),
            '| val',
            'loss:{:.3f}'.format(loss_val),
            'acc:{:.2f}'.format(acc_val*100))

        if loss_val < best:
            # New best validation loss: checkpoint and reset the patience counter.
            best = loss_val
            torch.save(model.state_dict(), checkpt_file)
            bad_counter = 0
        else:
            bad_counter += 1

        if bad_counter >= patience:
            break

    # test_step restores the best checkpoint before evaluating.
    acc = test_step(model,features,labels,adj,idx_test)[1]

    return acc*100

In [5]:
# Train and evaluate once per split file (indices 0-9), then report the total
# wall-clock time and the mean test accuracy across runs.
t_total = time.time()
acc_list = []
#citeseer, cora
data='citeseer'
for i in range(10):
    datastr = data
    splitstr = 'splits/'+data+'_split_0.6_0.2_'+str(i)+'.npz'
    run_acc = train(datastr,splitstr)
    acc_list.append(run_acc)
    print(i,": {:.2f}".format(run_acc))
elapsed = time.time() - t_total
print("Train cost: {:.4f}s".format(elapsed))
print("Test acc.:{:.2f}".format(np.mean(acc_list)))

Epoch:0001 train loss:1.929 acc:12.97 | val loss:1.683 acc:40.94
Epoch:0002 train loss:1.596 acc:46.93 | val loss:1.340 acc:62.35
Epoch:0003 train loss:1.024 acc:72.74 | val loss:1.073 acc:66.95
Epoch:0004 train loss:0.626 acc:82.83 | val loss:0.885 acc:72.49
Epoch:0005 train loss:0.361 acc:90.91 | val loss:0.841 acc:75.02
Epoch:0006 train loss:0.252 acc:93.98 | val loss:0.921 acc:68.17
Epoch:0007 train loss:0.219 acc:94.74 | val loss:0.953 acc:66.10
Epoch:0008 train loss:0.181 acc:95.93 | val loss:0.931 acc:68.36
Epoch:0009 train loss:0.132 acc:96.87 | val loss:0.907 acc:71.64
Epoch:0010 train loss:0.094 acc:97.93 | val loss:0.902 acc:72.58
Epoch:0011 train loss:0.081 acc:97.99 | val loss:0.909 acc:73.05
Epoch:0012 train loss:0.073 acc:98.18 | val loss:0.923 acc:73.15
Epoch:0013 train loss:0.061 acc:98.43 | val loss:0.940 acc:73.05
Epoch:0014 train loss:0.061 acc:98.31 | val loss:0.959 acc:71.74
Epoch:0015 train loss:0.064 acc:98.75 | val loss:0.977 acc:70.89
0 : 72.22
Epoch:0001 trai

In [None]:

# Debug shape checks. None of `adj`, `support` or `output` is assigned at
# notebook top level in the visible cells, so these prints presumably belong
# inside the model's forward pass — TODO confirm before running this cell.
print(adj.to_dense().shape)
print(support.shape)
print(output.shape)


torch.Size([2708, 2708])  adj
torch.Size([2708, 16])    support = input x weight
torch.Size([2708, 2724])  output (2724 = 2708 + 16)