In [1]:
from torch_geometric.nn import GATConv
import torch
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
from torch.optim import Adam
import scipy.sparse as sp
import numpy as np
import json
from sklearn.model_selection import KFold
import copy
import matplotlib.pyplot as plt
import torch.nn as nn
from torch_geometric.nn import SAGEConv


In [2]:
# generic GraphSAGE model
class GraphSAGE(torch.nn.Module):
    def __init__(self, num_features, num_classes, hidden_units):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(num_features, hidden_units)
        self.conv2 = SAGEConv(hidden_units, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)


In [3]:
def load_data(): # the same as used with GCN
    adj = sp.load_npz('adj.npz')
    features = np.load('features.npy')
    labels = np.load('labels.npy')
    with open('splits.json', 'r') as file:
        splits = json.load(file)
    idx_train, idx_test = splits['idx_train'], splits['idx_test']
    # convert adjacency matrix to edge index
    adj = adj.tocoo()
    edge_index = np.vstack((adj.row, adj.col))

    # normalize features
    features = features / features.sum(1, keepdims=True)

    # convert to tensors
    edge_index = torch.tensor(edge_index, dtype=torch.long)
    features = torch.tensor(features, dtype=torch.float)
    labels = torch.tensor(labels, dtype=torch.long)

    # create torch geometric data object
    data = Data(x=features, edge_index=edge_index, y=labels)

    return data, torch.tensor(idx_train, dtype=torch.long), torch.tensor(idx_test, dtype=torch.long)

data, idx_train, idx_test = load_data()

We choose to optimize on loss, rather than accuracy, due to the tiny size of the labelled training data we have available. Hopefully we will be better about avoiding over-fitting this way.

In [4]:
# full set of labels with a default value (-1)
full_labels = torch.full((2480,), -1, dtype=torch.long)
full_labels[idx_train] = data.y

kf = KFold(n_splits=10, shuffle=True, random_state=42)
fold = 0

avg_training_losses = []
avg_validation_accuracies = []

best_validation_accuracy = 0
best_validation_loss = float('inf')  # Initialize best validation loss
best_model_state = None

for train_index, val_index in kf.split(idx_train.numpy()): # folding
    fold += 1
    model = GraphSAGE(num_features=1390, num_classes=7, hidden_units=64)
    optimizer = Adam(model.parameters(), lr=0.011)
    training_losses = []
    validation_accuracies = []
    validation_losses = []  # Track validation loss for each epoch

    
    for epoch in range(200):
        model.train()
        optimizer.zero_grad()
        out = model(data)
        # compute loss on the training part
        train_mask = full_labels[idx_train[train_index]] != -1
        loss = F.nll_loss(out[idx_train[train_index]][train_mask], full_labels[idx_train[train_index]][train_mask])
        loss.backward()
        optimizer.step()
        
        training_losses.append(loss.item())

        model.eval()
        with torch.no_grad():
            # validation accuracy
            out = model(data)
            _, pred = out.max(1)
            val_mask = full_labels[idx_train[val_index]] != -1  # Ensure valid labels for val subset
            val_loss = F.nll_loss(out[idx_train[val_index]][val_mask], full_labels[idx_train[val_index]][val_mask]).item()
            validation_losses.append(val_loss)  # Store validation loss for this epoch
            
            # accuracy
            correct = pred[idx_train[val_index]][val_mask].eq(full_labels[idx_train[val_index]][val_mask]).sum().item()
            accuracy = correct / val_mask.sum().item()
            validation_accuracies.append(accuracy)
        
            # Update best model based on validation loss
            if val_loss < best_validation_loss:
                best_validation_loss = val_loss
                best_validation_accuracy = accuracy
                best_model_state = copy.deepcopy(model.state_dict())

        if epoch % 10 == 0:
            print(f'Fold {fold}, Epoch {epoch}: Training Loss {np.mean(training_losses[-10:])}, Validation Loss {val_loss}, Validation Accuracy: {accuracy}')
    
    avg_training_losses.append(np.mean(training_losses))
    avg_validation_accuracies.append(np.mean(validation_accuracies))

# load the best model state
model.load_state_dict(best_model_state)

Fold 1, Epoch 0: Training Loss 1.9426805973052979, Validation Loss 1.9149760007858276, Validation Accuracy: 0.14
Fold 1, Epoch 10: Training Loss 1.7043107390403747, Validation Loss 1.5302152633666992, Validation Accuracy: 0.34
Fold 1, Epoch 20: Training Loss 1.0492207944393157, Validation Loss 0.8729313015937805, Validation Accuracy: 0.78
Fold 1, Epoch 30: Training Loss 0.43951942324638366, Validation Loss 0.4909323453903198, Validation Accuracy: 0.88
Fold 1, Epoch 40: Training Loss 0.1713249646127224, Validation Loss 0.42892152070999146, Validation Accuracy: 0.86
Fold 1, Epoch 50: Training Loss 0.07042262591421604, Validation Loss 0.4346557557582855, Validation Accuracy: 0.84
Fold 1, Epoch 60: Training Loss 0.033582423254847525, Validation Loss 0.4620679020881653, Validation Accuracy: 0.82
Fold 1, Epoch 70: Training Loss 0.020745658315718174, Validation Loss 0.5022260546684265, Validation Accuracy: 0.82
Fold 1, Epoch 80: Training Loss 0.015056735463440418, Validation Loss 0.5055391788

<All keys matched successfully>

In [5]:
best_validation_accuracy

0.9387755102040817

In [6]:
best_validation_loss

0.3019876182079315

In [7]:
best_model_state

OrderedDict([('conv1.lin_l.weight',
              tensor([[ 0.0064, -0.0263,  0.0123,  ..., -0.0211, -0.0128, -0.0816],
                      [ 0.0844,  0.0597, -0.0415,  ..., -0.2679, -0.0391,  0.0435],
                      [ 0.2282,  0.2783, -0.1716,  ..., -0.0869,  0.3411, -0.1982],
                      ...,
                      [ 0.0752,  0.4038, -0.0530,  ..., -0.3953,  0.0818,  0.3353],
                      [-0.2868,  0.0044,  0.1315,  ..., -0.1185, -0.3480,  0.1689],
                      [ 0.0216, -0.0062, -0.0155,  ..., -0.0098,  0.0113, -0.0263]])),
             ('conv1.lin_l.bias',
              tensor([-0.0812,  0.1357,  0.1614,  0.1796, -0.0665,  0.1211,  0.1273,  0.1290,
                       0.1111, -0.0246, -0.0837,  0.1510,  0.1391,  0.1243,  0.1707, -0.0186,
                       0.1846,  0.1440,  0.1168, -0.0222,  0.1410,  0.1362,  0.1544,  0.1720,
                       0.1219,  0.1454,  0.1446,  0.2195,  0.1093,  0.1213, -0.0221,  0.1397,
                    

In [8]:
model.eval()
with torch.no_grad():
    out = model(data)
    _, pred = out.max(1)
    test_labels_pred = pred[idx_test].numpy()
submission_file_path = 'submission_graphSage.txt'
with open(submission_file_path, 'w') as file:
    for label in test_labels_pred:
        file.write(f'{label}\n')