In [23]:
import time
import random
import argparse
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from utils import *
from model import *
from process import *

In [24]:
# Load the Cora citation dataset through `load_citation`, which is brought in
# by one of the star imports at the top of the notebook.
data = 'cora'
adj, features, labels, idx_train, idx_val, idx_test = load_citation(data)

# Alternative split-file loader, left disabled for reference:
#splitstr = 'splits/'+data+'_split_0.6_0.2_'+str(0)+'.npz'
#adj, features, labels, idx_train, idx_val, idx_test, num_features, num_labels = full_load_data('cora',splitstr)

# Sanity check: show the shape of every loaded tensor, one per line, in the
# same order as they were unpacked above.
print(
    *(loaded.shape for loaded in (adj, features, labels, idx_train, idx_val, idx_test)),
    sep='\n',
)

torch.Size([2708, 2708])
torch.Size([2708, 1433])
torch.Size([2708])
torch.Size([140])
torch.Size([500])
torch.Size([1000])


In [25]:
# Pick the compute device. Fall back to the CPU when no CUDA device is
# available so the notebook still runs end-to-end on machines without a GPU
# (a hard-coded "cuda" raises as soon as a tensor is moved).
cudaid = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(cudaid)

# Move the inputs the model consumes on every forward pass onto the device.
# `labels` and the index tensors are not moved here; the train/eval functions
# move the label slices they need.
features = features.to(device)
# NOTE(review): `coalesce()` implies `adj` is a sparse COO tensor — confirm
# against the loader.
adj = adj.to(device).coalesce()

In [32]:
# Reproducibility: seed every RNG this notebook imports *before* the model is
# built, so weight initialisation and any later random sampling repeat across
# a Restart & Run All.
# NOTE(review): some CUDA kernels can still be non-deterministic — confirm if
# bit-exact runs are required.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Model dimensions.
feature_size = features.shape[1]
# Size the output layer from the largest class id rather than the number of
# distinct ids: `F.nll_loss` needs one output column for every id up to the
# maximum, even if some id happens to be absent from `labels`.
num_classes = int(labels.max()) + 1
hidden_dim = 16
dropout = 0.5

model = GCN(feature_size, hidden_dim, num_classes, dropout)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# Training schedule.
epochs = 100    # upper bound on the number of training epochs
patience = 10   # early stopping: epochs without a new best validation loss
# NOTE(review): a later cell defines a function that is also named `test`,
# which rebinds this name; the value set here is not visible once that
# definition has run. Consider renaming one of the two.
test = False

In [33]:
def train():
    """Run one full-graph optimisation step on the training nodes.

    Puts the model in training mode, does a single forward pass over the
    whole graph, and back-propagates the NLL loss of the ``idx_train`` rows.

    Returns:
        tuple: ``(loss, accuracy)`` on the training nodes as Python numbers,
        both taken from the forward pass made *before* the weight update.
    """
    model.train()
    optimizer.zero_grad()

    # Slice the targets and the model scores once and reuse them for both
    # the accuracy and the loss.
    train_targets = labels[idx_train].to(device)
    train_scores = model(features, adj)[idx_train]

    acc_train = accuracy(train_scores, train_targets)
    loss_train = F.nll_loss(train_scores, train_targets)

    loss_train.backward()
    optimizer.step()
    return loss_train.item(), acc_train.item()


def _evaluate(idx):
    """Score the model's current weights on the nodes selected by ``idx``.

    Switches the model to eval mode and runs one gradient-free forward pass
    over the whole graph; only the rows picked by ``idx`` contribute to the
    returned metrics.

    Args:
        idx: index tensor selecting the nodes to score (e.g. ``idx_val``).

    Returns:
        tuple: ``(nll_loss, accuracy)`` as Python numbers.
    """
    model.eval()
    with torch.no_grad():
        output = model(features, adj)
        target = labels[idx].to(device)
        loss = F.nll_loss(output[idx], target)
        acc = accuracy(output[idx], target)
    return loss.item(), acc.item()


def validate():
    """Return ``(loss, accuracy)`` of the current model on ``idx_val``."""
    return _evaluate(idx_val)


def test():
    """Return ``(loss, accuracy)`` of the current model on ``idx_test``.

    The checkpoint reload below is disabled, so this scores whatever weights
    the model holds at call time.
    """
    #model.load_state_dict(torch.load(checkpt_file))
    return _evaluate(idx_test)
    
# Train with early stopping on the validation loss, then report the metrics
# of the best epoch's weights.
t_total = time.time()
bad_counter = 0           # consecutive epochs without a new best validation loss
best = float("inf")       # lowest validation loss seen so far
best_epoch = 0            # 0-based index of the epoch that produced `best`
acc = 0                   # validation accuracy at `best_epoch`
best_state = None         # in-memory copy of the weights at `best_epoch`
log_every = 1             # print metrics every `log_every` epochs

for epoch in range(epochs):
    loss_tra, acc_tra = train()
    loss_val, acc_val = validate()
    if (epoch + 1) % log_every == 0:
        print('Epoch:{:04d}'.format(epoch+1),
            'train',
            'loss:{:.3f}'.format(loss_tra),
            'acc:{:.2f}'.format(acc_tra*100),
            '| val',
            'loss:{:.3f}'.format(loss_val),
            'acc:{:.2f}'.format(acc_val*100))

    if loss_val < best:
        best = loss_val
        best_epoch = epoch
        acc = acc_val
        # Snapshot the weights in memory (stands in for the disabled on-disk
        # checkpoint). `state_dict()` tensors alias the live parameters, so
        # they are cloned to freeze the values of this epoch.
        best_state = {name: tensor.detach().clone()
                      for name, tensor in model.state_dict().items()}
        bad_counter = 0
    else:
        bad_counter += 1

    if bad_counter >= patience:
        break

print("Train cost: {:.4f}s".format(time.time() - t_total))

# Restore the best weights so the evaluation below really scores the epoch
# that is reported as loaded, not the last epoch trained.
if best_state is not None:
    model.load_state_dict(best_state)

# The per-epoch log is 1-based, so report the best epoch 1-based as well.
print('Load {}th epoch'.format(best_epoch + 1))
print("Val","acc.:{:.1f}".format(acc*100))

# `test` is the evaluation function defined above, so the previous `if test:`
# guard was always true and the test accuracy was printed under the "Val"
# label. Report it under its own name instead; `acc` keeps the validation
# accuracy.
loss_test, acc_test = test()
print("Test","acc.:{:.1f}".format(acc_test*100))
    

Epoch:0001 train loss:1.961 acc:15.00 | val loss:1.949 acc:7.00
Epoch:0002 train loss:1.952 acc:15.00 | val loss:1.947 acc:7.40
Epoch:0003 train loss:1.955 acc:11.43 | val loss:1.945 acc:7.60
Epoch:0004 train loss:1.949 acc:14.29 | val loss:1.944 acc:7.60
Epoch:0005 train loss:1.943 acc:14.29 | val loss:1.943 acc:7.60
Epoch:0006 train loss:1.942 acc:17.86 | val loss:1.941 acc:7.80
Epoch:0007 train loss:1.931 acc:19.29 | val loss:1.939 acc:8.60
Epoch:0008 train loss:1.934 acc:20.00 | val loss:1.937 acc:10.40
Epoch:0009 train loss:1.927 acc:20.71 | val loss:1.934 acc:12.00
Epoch:0010 train loss:1.913 acc:27.86 | val loss:1.931 acc:15.40
Epoch:0011 train loss:1.905 acc:32.86 | val loss:1.929 acc:18.80
Epoch:0012 train loss:1.902 acc:36.43 | val loss:1.926 acc:22.00
Epoch:0013 train loss:1.902 acc:34.29 | val loss:1.923 acc:24.20
Epoch:0014 train loss:1.894 acc:38.57 | val loss:1.921 acc:23.80
Epoch:0015 train loss:1.887 acc:42.14 | val loss:1.918 acc:24.80
Epoch:0016 train loss:1.882 acc: