In [8]:
%matplotlib inline

import torch.optim as optim

from torch_geometric.data import DataLoader
import torch

In [9]:
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

# Load the Cora citation graph from ./data/Planetoid, applying
# NormalizeFeatures as the dataset transform.
dataset = Planetoid(
    root='data/Planetoid',
    name='Cora',
    transform=NormalizeFeatures(),
)

In [10]:
from GAT import GAT
# Seed torch's global RNG *before* building the model so that weight
# initialisation, and every later draw from the same RNG during training,
# is reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

model = GAT(dataset)
print(model)

GAT(
  (conv1): GATConv(1433, 8, heads=8)
  (conv2): GATConv(64, 7, heads=1)
)


In [11]:
def model_test(loader, model, is_validation=False, is_training=False):
    ''' Compute node-classification accuracy of `model` over every graph in `loader`.

    Which nodes are scored is chosen by the flags:
      - is_validation=True  -> data.val_mask (regardless of is_training)
      - is_training=True    -> data.train_mask
      - otherwise           -> data.test_mask

    Args:
        loader: iterable of graph objects exposing x, edge_index, y and the
            boolean train/val/test masks used below.
        model: callable as model(x, edge_index) returning (embedding, logits);
            must also provide .eval().
        is_validation: score on the validation mask.
        is_training: score on the training mask (ignored if is_validation).

    Returns:
        float: number of correctly classified masked nodes divided by the
        number of masked nodes, summed over all graphs.

    Raises:
        ZeroDivisionError: if the selected mask covers no nodes at all.
    '''
    model.eval()

    correct = 0
    total = 0
    for data in loader:
        if is_validation:
            mask = data.val_mask
        elif is_training:
            mask = data.train_mask
        else:
            mask = data.test_mask

        with torch.no_grad():
            _, logits = model(data.x, data.edge_index)
        pred = logits.argmax(dim=1)

        # Numerator and denominator must come from the *same* mask; the
        # previous version always divided by train/test counts, which made
        # validation accuracy wrong.
        correct += pred[mask].eq(data.y[mask]).sum().item()
        total += int(mask.sum().item())
    return correct / total

def model_train(dataset, writer, model, epoch_num, lr, weight_decay):
    ''' Train `model` with Adam on the nodes selected by each graph's train_mask.

    Runs epochs 0..epoch_num inclusive (epoch_num + 1 epochs) and logs to `writer`:
      - "loss" every epoch,
      - "test accuracy" every 10 epochs (computed by model_test on test_mask),
      - the model's embedding every 20 epochs, labelled with the node labels.

    Args:
        dataset: graph dataset wrapped in a DataLoader (no shuffling); each
            batch is expected to expose x, edge_index, y, train_mask and
            num_graphs.
        writer: tensorboardX SummaryWriter (uses add_scalar / add_embedding).
        model: module whose forward(x, edge_index) returns (embedding, logits)
            and which provides a loss(pred, label) method.
        epoch_num: index of the last epoch.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.

    Returns:
        The same `model` object, trained in place.
    '''
    loader = DataLoader(dataset, shuffle=False)
    opt = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    for epoch in range(epoch_num + 1):
        total_loss = 0
        model.train()  # model_test() switches to eval mode; switch back each epoch
        for batch in loader:
            opt.zero_grad()
            embedding, pred = model(batch.x, batch.edge_index)
            # Only labelled training nodes contribute to the loss.
            loss = model.loss(pred[batch.train_mask], batch.y[batch.train_mask])
            loss.backward()
            opt.step()
            total_loss += loss.item() * batch.num_graphs
        total_loss /= len(loader.dataset)
        writer.add_scalar("loss", total_loss, epoch)

        if epoch % 10 == 0:
            test_acc = model_test(loader, model, is_training=False)
            print("Epoch {}. Loss: {:.4f}. Test accuracy: {:.4f}".format(
                epoch, total_loss, test_acc))
            writer.add_scalar("test accuracy", test_acc, epoch)

        if epoch % 20 == 0:
            # `embedding` comes from the last training-mode forward pass of this
            # epoch (computed before opt.step()). Detach it from the autograd
            # graph, and pass labels as a plain list: the projector writes str()
            # of each metadata entry, so tensor elements would show up as
            # "tensor(3)" rather than "3".
            writer.add_embedding(
                embedding.detach(),
                global_step=epoch,
                tag='epoch' + str(epoch),
                metadata=batch.y.tolist(),
            )

    return model

from datetime import datetime
from tensorboardX import SummaryWriter

# Training hyper-parameters (EPOCH_NUM is the last epoch index, so
# model_train runs EPOCH_NUM + 1 epochs).
EPOCH_NUM = 200
LR = 0.01
WEIGHT_DECAY = 4e-4

# One timestamped log directory per run so runs don't overwrite each other.
writer = SummaryWriter("./log/" + datetime.now().strftime("%Y%m%d-%H%M%S"))

model = model_train(dataset, writer, model,
                    epoch_num=EPOCH_NUM, lr=LR, weight_decay=WEIGHT_DECAY)

# Push buffered events to disk now so TensorBoard shows the complete run
# without waiting for the writer's background flush; the writer is left open
# in case later cells log more.
writer.flush()

Epoch 0. Loss: 1.9449. Test accuracy: 0.1700
Epoch 10. Loss: 1.7743. Test accuracy: 0.7470
Epoch 20. Loss: 1.5636. Test accuracy: 0.7770
Epoch 30. Loss: 1.3030. Test accuracy: 0.8080
Epoch 40. Loss: 1.0818. Test accuracy: 0.8050
Epoch 50. Loss: 0.8800. Test accuracy: 0.8120
Epoch 60. Loss: 0.9632. Test accuracy: 0.8390
Epoch 70. Loss: 0.7423. Test accuracy: 0.8130
Epoch 80. Loss: 0.7738. Test accuracy: 0.8370
Epoch 90. Loss: 0.7583. Test accuracy: 0.8260
Epoch 100. Loss: 0.6778. Test accuracy: 0.8310
Epoch 110. Loss: 0.6190. Test accuracy: 0.8090
Epoch 120. Loss: 0.6063. Test accuracy: 0.8360
Epoch 130. Loss: 0.6812. Test accuracy: 0.8020
Epoch 140. Loss: 0.6551. Test accuracy: 0.8380
Epoch 150. Loss: 0.5872. Test accuracy: 0.8190
Epoch 160. Loss: 0.5565. Test accuracy: 0.8280
Epoch 170. Loss: 0.6353. Test accuracy: 0.8210
Epoch 180. Loss: 0.6758. Test accuracy: 0.8180
Epoch 190. Loss: 0.6158. Test accuracy: 0.8310
Epoch 200. Loss: 0.5733. Test accuracy: 0.8210
