In [1]:
from ImportLocalData import loadData

In [2]:
from BalanceClassDistribution import AdjustClassSamples, NumberOfSamplesClass

In [None]:
# Load one of the custom STKG datasets from local text files.
# The '.../' prefixes are placeholders: replace them with the paths to your local copies.
data = loadData(
    '.../node_features.txt',
    '.../edges.txt',
    '.../edge_features.txt',
    '.../node_labels.txt',
)

# Optional: rebalance the number of samples per class (this step was used in the paper).
data = AdjustClassSamples(data)

Import Required Packages

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv, GAE
from torch.nn import LSTM
from torch_geometric.data import Data
from torch_geometric.utils import subgraph
from torch_geometric.loader import DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import numpy as np

Definition of the FusionGAT model

In [None]:
# FusionGAT: a GAT -> LSTM node classifier paired with a separate GAT encoder and
# linear decoder that reconstructs node features (GAE-style auxiliary objective).
class FusionGAT(torch.nn.Module):
    """Two-branch graph model for node classification plus feature reconstruction.

    Classification branch: two GATConv layers (the first with 2 concatenated
    heads) with ELU activations and dropout, followed by an LSTM whose last
    output is fed to a linear classifier.

    Reconstruction branch: an independent two-layer GATConv encoder followed by
    a linear decoder that maps every node embedding back to ``in_channels``
    features.

    Args:
        in_channels: Number of input node features.
        hidden_channels: Output width of each GAT layer (per head).
        lstm_hidden_size: Hidden size of the LSTM.
        out_channels: Number of output classes.
    """

    def __init__(self, in_channels, hidden_channels, lstm_hidden_size, out_channels):
        super().__init__()
        # --- GAT-LSTM branch for node classification ---
        self.conv1 = GATConv(in_channels, hidden_channels, heads=2, concat=True)
        # heads=2 with concat=True doubles the feature width seen by conv2.
        self.conv2 = GATConv(hidden_channels * 2, hidden_channels, heads=1, concat=False)
        self.lstm = LSTM(hidden_channels, lstm_hidden_size, batch_first=True)
        self.fc = torch.nn.Linear(lstm_hidden_size, out_channels)
        self.dropout = torch.nn.Dropout(p=0.5)

        # --- Graph autoencoder branch for the reconstruction loss ---
        self.encoder_conv1 = GATConv(in_channels, hidden_channels, heads=2, concat=True)
        self.encoder_conv2 = GATConv(hidden_channels * 2, hidden_channels, heads=1, concat=False)
        self.fc_decoder = torch.nn.Linear(hidden_channels, in_channels)  # for node reconstruction

    def forward(self, data):
        """Run both branches on a graph.

        Args:
            data: Graph object exposing ``x`` (node features) and ``edge_index``.

        Returns:
            Tuple ``(out_cls, out_reconstructed)``: class logits of shape
            (num_nodes, out_channels) and reconstructed node features of shape
            (num_nodes, in_channels).
        """
        # Put inputs on the same device as the model's weights instead of relying
        # on a notebook-global ``device`` that may not exist yet.
        param_device = self.fc.weight.device
        x = data.x.to(param_device)
        edge_index = data.edge_index.to(param_device)

        # Classification branch.
        x_cls = F.elu(self.conv1(x, edge_index))
        x_cls = self.dropout(x_cls)
        x_cls = F.elu(self.conv2(x_cls, edge_index))
        # Treat every node as a length-1 sequence: (num_nodes, 1, hidden_channels).
        x_cls = x_cls.unsqueeze(1)
        lstm_out, _ = self.lstm(x_cls)
        out_cls = self.fc(lstm_out[:, -1, :])

        # Graph autoencoder branch (no dropout here, as in the original design).
        x_enc = F.elu(self.encoder_conv1(x, edge_index))
        x_enc = self.encoder_conv2(x_enc, edge_index)
        out_reconstructed = self.fc_decoder(x_enc)

        return out_cls, out_reconstructed

Train and test the FusionGAT model (reporting top-1 / top-5 accuracy)

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Compute top-k accuracy (in this notebook: Top-1 and Top-5).
def top_k_accuracy(y_true, y_pred, k=5):
    """Fraction of samples whose true label is among the k highest-scoring classes.

    Args:
        y_true: Integer class labels, one per row of ``y_pred`` (tensor or sequence).
        y_pred: 2-D tensor of class scores, shape (num_samples, num_classes).
        k: Number of top-scoring classes to consider. It is clamped to
            ``num_classes`` so that, e.g., top-5 on a 3-class problem does not raise.

    Returns:
        Accuracy as a Python float in [0, 1]. Returns 0.0 for an empty input.
    """
    y_true = torch.as_tensor(y_true, device=y_pred.device).view(-1)
    num_samples = y_true.numel()
    if num_samples == 0:
        return 0.0
    k = min(k, y_pred.size(1))
    top_k_preds = torch.topk(y_pred, k, dim=1).indices
    # Compare every label against its row's k predictions in one vectorized op.
    hits = (top_k_preds == y_true.unsqueeze(1)).any(dim=1)
    return hits.sum().item() / num_samples

# Test the model
def test_model(model, test_loader):
    model.eval()
    test_correct = 0
    test_top5_correct = 0
    all_true = []
    
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            out_test_class, _ = model(batch)
            _, pred_test = out_test_class.max(dim=1)
            test_correct += float(pred_test.eq(batch.y).sum().item())
            test_top5_correct += float(top_k_accuracy(batch.y, out_test_class, k=5) * len(batch.y))
            all_true.extend(batch.y.cpu().numpy())

    test_accuracy = test_correct / len(all_true)
    test_top5_accuracy = test_top5_correct / len(all_true)
    print(f'Test Accuracy (Top-1): {test_accuracy:.4f}')
    print(f'Test Accuracy (Top-5): {test_top5_accuracy:.4f}')
    return test_accuracy, test_top5_accuracy

# Split nodes, normalize features, build per-split subgraphs, then train
# FusionGAT with a joint classification + reconstruction objective.

# ---- Configuration: every tunable value in one place ----
SEED = 42             # fixes the data split and the model's initial weights
NUM_EPOCHS = 4000
EVAL_EVERY = 100      # report validation metrics every N epochs
RECON_WEIGHT = 0.2    # weight of the reconstruction term in the total loss
batch_size = 32
hidden_channels = 256
lstm_hidden_size = 128

torch.manual_seed(SEED)
np.random.seed(SEED)

# Split node indices into 80% train, 10% validation, 10% test, stratified by label.
# Splitting a NumPy index array keeps sklearn's outputs as plain integer arrays.
y_np = data.y.cpu().numpy()
all_idx = np.arange(data.num_nodes)
train_idx, test_idx = train_test_split(all_idx, test_size=0.2, stratify=y_np, random_state=SEED)
train_idx, val_idx = train_test_split(train_idx, test_size=0.125, stratify=y_np[train_idx], random_state=SEED)  # 0.125 * 0.8 = 0.1

# Normalize input features. The scaler is fitted on training nodes only so that
# validation/test statistics do not leak into preprocessing.
scaler = StandardScaler()
x_np = data.x.cpu().numpy()
scaler.fit(x_np[train_idx])
x_scaled = scaler.transform(x_np)
data.x = torch.tensor(x_scaled, dtype=torch.float).to(device)

train_idx = torch.as_tensor(train_idx, dtype=torch.long)
val_idx = torch.as_tensor(val_idx, dtype=torch.long)
test_idx = torch.as_tensor(test_idx, dtype=torch.long)

# Induced subgraph per split: only edges with both endpoints in the split are kept,
# so the train/val/test graphs are disjoint. relabel_nodes=True is relied on to
# number nodes by their position in the index tensor, matching the row order of
# data.x[split_idx] below (PyG `subgraph` semantics; re-check after PyG upgrades).
train_subgraph = subgraph(train_idx, data.edge_index, relabel_nodes=True, num_nodes=data.num_nodes)
val_subgraph = subgraph(val_idx, data.edge_index, relabel_nodes=True, num_nodes=data.num_nodes)
test_subgraph = subgraph(test_idx, data.edge_index, relabel_nodes=True, num_nodes=data.num_nodes)
train_data = Data(x=data.x[train_idx], edge_index=train_subgraph[0], y=data.y[train_idx])
val_data = Data(x=data.x[val_idx], edge_index=val_subgraph[0], y=data.y[val_idx])
test_data = Data(x=data.x[test_idx], edge_index=test_subgraph[0], y=data.y[test_idx])

# NOTE: each loader wraps a list holding a single graph, so every epoch is one
# full-graph step; batch_size and shuffle have no practical effect here.
train_loader = DataLoader([train_data], batch_size=batch_size, shuffle=True)
val_loader = DataLoader([val_data], batch_size=batch_size)
test_loader = DataLoader([test_data], batch_size=batch_size)

num_node_features = data.num_node_features
# Labels are used directly as CrossEntropyLoss targets, which must lie in
# [0, num_classes). Sizing by the largest label id keeps that true even if some
# id in that range is absent (len(unique) would under-count in that case).
num_classes = int(data.y.max().item()) + 1
model = FusionGAT(num_node_features, hidden_channels, lstm_hidden_size, num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.02, weight_decay=1e-4)
# scheduler.step() is called once per epoch, so the LR halves every 500 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)
classification_loss_fn = torch.nn.CrossEntropyLoss()
reconstruction_loss_fn = torch.nn.MSELoss()

for epoch in range(NUM_EPOCHS):
    model.train()
    epoch_loss = 0.0
    num_batches = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out_class, out_reconstructed = model(batch)
        # Joint objective: classification loss + weighted feature-reconstruction loss.
        loss_class = classification_loss_fn(out_class, batch.y)
        loss_recon = reconstruction_loss_fn(out_reconstructed, batch.x)
        loss = loss_class + RECON_WEIGHT * loss_recon
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        num_batches += 1
    scheduler.step()

    if (epoch + 1) % EVAL_EVERY == 0:
        model.eval()
        with torch.no_grad():
            val_losses = []
            val_correct = 0
            val_top5_correct = 0
            for batch in val_loader:
                batch = batch.to(device)
                out_val_class, out_val_recon = model(batch)
                val_loss_class = classification_loss_fn(out_val_class, batch.y)
                val_loss_recon = reconstruction_loss_fn(out_val_recon, batch.x)
                val_losses.append((val_loss_class + RECON_WEIGHT * val_loss_recon).item())
                _, pred_val = out_val_class.max(dim=1)
                val_correct += float(pred_val.eq(batch.y).sum().item())
                val_top5_correct += float(top_k_accuracy(batch.y, out_val_class, k=5) * len(batch.y))
            val_loss = np.mean(val_losses)
            val_accuracy = val_correct / len(val_data.y)
            val_top5_accuracy = val_top5_correct / len(val_data.y)
        # Report the mean training loss over the epoch, not just the last batch.
        mean_train_loss = epoch_loss / max(num_batches, 1)
        print(f'Epoch: {epoch+1}, Loss: {mean_train_loss:.4f}, Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}, Validation Top-5 Accuracy: {val_top5_accuracy:.4f}, Learning Rate: {scheduler.get_last_lr()[0]:.6f}')

test_accuracy, test_top5_accuracy = test_model(model, test_loader)
torch.save(model.state_dict(), 'FusionGAT_model.pth')
