In [2]:
from torch_geometric.nn import GATConv
import torch
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
from torch.optim import Adam
import scipy.sparse as sp
import numpy as np
import json
from sklearn.model_selection import KFold
import copy
import matplotlib.pyplot as plt

from torch_geometric.nn import SAGEConv


In [None]:
# generic GraphSAGE model
class GraphSAGE(torch.nn.Module):
    """Two-layer GraphSAGE node classifier.

    The first ``SAGEConv`` maps the input features to a hidden representation,
    the second maps that representation to one score per class.

    Args:
        num_features: Size of each input node feature vector.
        num_classes: Number of output classes (width of the last layer).
        hidden_channels: Width of the hidden layer. Defaults to 64, the value
            that used to be hard-coded.
        dropout: Dropout probability applied after the first layer while the
            module is in training mode. Defaults to 0.5, which is the default
            of ``F.dropout`` and therefore the previous behaviour.
    """

    def __init__(self, num_features, num_classes, hidden_channels=64, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        self.conv1 = SAGEConv(num_features, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, num_classes)

    def forward(self, data):
        """Return per-node log-probabilities over the classes.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.

        Returns:
            Tensor of log-softmax scores, normalised along dimension 1.
        """
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        # Dropout is only active in training mode; in eval mode it is a no-op.
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)


In [3]:
def load_data(data_dir='./data_2024'): # the same as used with GCN
    """Load the graph dataset from disk and wrap it in a PyG ``Data`` object.

    Reads ``adj.npz``, ``features.npy``, ``labels.npy`` and ``splits.json``
    from ``data_dir``.

    Args:
        data_dir: Directory containing the four dataset files. Defaults to
            ``'./data_2024'``, the location that used to be hard-coded.

    Returns:
        A tuple ``(data, idx_train, idx_test)`` where ``data`` is a ``Data``
        object holding row-normalised float features, the edge index built
        from the sparse adjacency matrix and the labels, and ``idx_train`` /
        ``idx_test`` are long tensors built from the ``'idx_train'`` and
        ``'idx_test'`` entries of ``splits.json``.
    """
    adj = sp.load_npz(f'{data_dir}/adj.npz')
    features = np.load(f'{data_dir}/features.npy')
    labels = np.load(f'{data_dir}/labels.npy')
    with open(f'{data_dir}/splits.json', 'r') as file:
        splits = json.load(file)
    idx_train, idx_test = splits['idx_train'], splits['idx_test']

    # convert adjacency matrix to edge index (2 x num_edges: source row, target row)
    adj = adj.tocoo()
    edge_index = np.vstack((adj.row, adj.col))

    # normalize features so each row sums to 1; a row that sums to 0 would
    # otherwise become NaN (0 / 0), so divide those rows by 1 and keep them zero
    row_sums = features.sum(1, keepdims=True)
    row_sums[row_sums == 0] = 1
    features = features / row_sums

    # convert to tensors
    edge_index = torch.tensor(edge_index, dtype=torch.long)
    features = torch.tensor(features, dtype=torch.float)
    labels = torch.tensor(labels, dtype=torch.long)

    # create torch geometric data object
    data = Data(x=features, edge_index=edge_index, y=labels)

    return data, torch.tensor(idx_train, dtype=torch.long), torch.tensor(idx_test, dtype=torch.long)

# Build the graph object and the train/test index tensors once for the notebook.
data, idx_train, idx_test = load_data()

In [5]:
# Settings for the cross-validation run.
NUM_CLASSES = 7
NUM_FOLDS = 5
NUM_EPOCHS = 200
LEARNING_RATE = 0.01
SEED = 42

# Seed torch so weight initialisation and dropout repeat on a fresh kernel;
# the KFold split below is seeded separately through random_state.
torch.manual_seed(SEED)

# Take the graph dimensions from the loaded features instead of hard-coding them.
num_nodes, num_features = data.x.shape

# full set of labels with a default value (-1)
full_labels = torch.full((num_nodes,), -1, dtype=torch.long)
# This assignment requires data.y to be aligned with idx_train.
full_labels[idx_train] = data.y

kf = KFold(n_splits=NUM_FOLDS, shuffle=True, random_state=SEED)

# Per-fold means over all epochs (one entry per fold).
avg_training_losses = []
avg_validation_accuracies = []

# Best validation accuracy seen in any epoch of any fold, and its weights.
best_validation_accuracy = 0
best_model_state = None

for fold, (train_index, val_index) in enumerate(kf.split(idx_train.numpy()), start=1): # folding
    model = GraphSAGE(num_features=num_features, num_classes=NUM_CLASSES)
    optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

    training_losses = []
    validation_accuracies = []

    # The nodes and labels of a fold do not change between epochs, so select
    # them once here instead of on every iteration.
    train_nodes = idx_train[train_index]
    train_mask = full_labels[train_nodes] != -1
    train_targets = full_labels[train_nodes][train_mask]
    train_nodes = train_nodes[train_mask]

    val_nodes = idx_train[val_index]
    val_mask = full_labels[val_nodes] != -1  # Ensure valid labels for val subset
    val_targets = full_labels[val_nodes][val_mask]
    val_nodes = val_nodes[val_mask]
    val_count = val_mask.sum().item()

    for epoch in range(NUM_EPOCHS):
        model.train()
        optimizer.zero_grad()
        out = model(data)
        # compute loss on the training part
        loss = F.nll_loss(out[train_nodes], train_targets)
        loss.backward()
        optimizer.step()

        training_losses.append(loss.item())

        model.eval()
        with torch.no_grad():
            # validation accuracy, from a second forward pass in eval mode
            out = model(data)
            pred = out.argmax(dim=1)
            correct = pred[val_nodes].eq(val_targets).sum().item()
            accuracy = correct / val_count
            validation_accuracies.append(accuracy)

            # Keep a copy of the weights whenever the validation accuracy
            # improves; the comparison runs across folds, not per fold.
            if accuracy > best_validation_accuracy:
                best_validation_accuracy = accuracy
                best_model_state = copy.deepcopy(model.state_dict())

        if epoch % 10 == 0:
            print(f'Fold {fold}, Epoch {epoch}: Loss {loss.item()}, Validation Accuracy: {accuracy}')

    avg_training_losses.append(np.mean(training_losses))
    avg_validation_accuracies.append(np.mean(validation_accuracies))

# load the best model state (into the model instance created for the last fold)
model.load_state_dict(best_model_state)

Fold 1, Epoch 0: Loss 1.9440292119979858, Validation Accuracy: 0.3
Fold 1, Epoch 10: Loss 1.456589698791504, Validation Accuracy: 0.3
Fold 1, Epoch 20: Loss 0.7617794275283813, Validation Accuracy: 0.72
Fold 1, Epoch 30: Loss 0.30648693442344666, Validation Accuracy: 0.82
Fold 1, Epoch 40: Loss 0.1387600600719452, Validation Accuracy: 0.83
Fold 1, Epoch 50: Loss 0.05323858559131622, Validation Accuracy: 0.83
Fold 1, Epoch 60: Loss 0.028148505836725235, Validation Accuracy: 0.81
Fold 1, Epoch 70: Loss 0.019652577117085457, Validation Accuracy: 0.81
Fold 1, Epoch 80: Loss 0.015705496072769165, Validation Accuracy: 0.81
Fold 1, Epoch 90: Loss 0.011096468195319176, Validation Accuracy: 0.81
Fold 1, Epoch 100: Loss 0.012005605734884739, Validation Accuracy: 0.82
Fold 1, Epoch 110: Loss 0.009671076200902462, Validation Accuracy: 0.81
Fold 1, Epoch 120: Loss 0.006942235864698887, Validation Accuracy: 0.81
Fold 1, Epoch 130: Loss 0.005364997312426567, Validation Accuracy: 0.82
Fold 1, Epoch 14

<All keys matched successfully>