In [None]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import SAGEConv, global_mean_pool
from torch_geometric.data import Dataset, Data
from torch_geometric.loader import DataLoader
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from tqdm import tqdm
import json
import seaborn as sns
import matplotlib.pyplot as plt

# --- CONFIGURATION ---
class Config:
    """Central place for every tunable setting used by this notebook."""

    # Paths, relative to the notebook's working directory.
    METADATA_DIR = "metadata"  # holds label_map.json and the train/val/test CSV splits
    MODEL_DIR = "models"       # where the best checkpoint is written

    # Runtime settings.
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    NUM_WORKERS = 4  # worker processes used by each DataLoader
    SEED = 42        # seed for the random number generators, see below

    # Optimisation hyper-parameters.
    BATCH_SIZE = 64
    LEARNING_RATE = 0.001
    EPOCHS = 20
    HIDDEN_CHANNELS = 128  # width of the GNN hidden layers


# Weight initialisation, batch shuffling and dropout are all random; seeding
# up front makes a "Restart Kernel -> Run All" far more repeatable. This does
# not by itself guarantee bit-identical results on GPU.
np.random.seed(Config.SEED)
torch.manual_seed(Config.SEED)

print(f"Using device: {Config.DEVICE}")

In [None]:
# --- EFFICIENT DATASET LOADER ---
class AlertGraphDataset(Dataset):
    """Dataset that builds one graph per metadata row, reading it from disk on access.

    The metadata CSV is read once at construction. Each row must provide a
    ``filepath`` column (an archive readable by ``np.load``) and a
    ``label_id`` column (the graph's class id).
    """

    def __init__(self, metadata_path):
        """Read the metadata table located at ``metadata_path``."""
        super().__init__()
        self.metadata = pd.read_csv(metadata_path)

    def len(self):
        """Return the number of graphs, i.e. the number of metadata rows."""
        return len(self.metadata)

    def get(self, idx):
        """Load the graph described by metadata row ``idx``.

        Returns a ``Data`` object with float node features ``x``, a long
        ``edge_index`` with two rows, and the label ``y`` as a long tensor.
        """
        row = self.metadata.iloc[idx]
        filepath, label = row['filepath'], row['label_id']

        with np.load(filepath) as archive:
            x = torch.from_numpy(archive['node_feats']).float()

            # A graph may have no edges: the 'sources' array is then either
            # missing from the archive or empty.
            sources = archive['sources'] if 'sources' in archive else None
            if sources is not None and sources.size > 0:
                stacked = np.vstack([sources, archive['destinations']])
                edge_index = torch.from_numpy(stacked).long()
            else:
                edge_index = torch.empty((2, 0), dtype=torch.long)

        y = torch.tensor(label, dtype=torch.long)
        return Data(x=x, edge_index=edge_index, y=y)


# Read the class-name -> label-id mapping written alongside the metadata CSVs
label_map_path = os.path.join(Config.METADATA_DIR, 'label_map.json')
with open(label_map_path, 'r') as f:
    label_map = json.load(f)

# Invert the mapping so a class name can be looked up from its label id.
id_to_label = dict((label_id, name) for name, label_id in label_map.items())
num_classes = len(label_map)
# Class names listed in ascending label-id order.
class_names = [id_to_label[label_id] for label_id in sorted(id_to_label)]

# Build one dataset per split (train, validation, test), in that order
print("Loading datasets...")
train_dataset, val_dataset, test_dataset = (
    AlertGraphDataset(os.path.join(Config.METADATA_DIR, csv_name))
    for csv_name in ('train.csv', 'val.csv', 'test.csv')
)

# Wrap each split in a DataLoader; only the training split is shuffled.
_loader_kwargs = dict(batch_size=Config.BATCH_SIZE, num_workers=Config.NUM_WORKERS)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)


In [None]:
# --- GNN MODEL FOR GRAPH CLASSIFICATION ---
class GNNClassifier(torch.nn.Module):
    """Two SAGEConv layers, mean pooling per graph, then a linear classification head.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Output width of both SAGEConv layers.
        out_channels: Number of logits produced per graph (one per class).
        dropout: Dropout probability applied to the pooled graph embedding
            while the module is in training mode. Defaults to 0.5.

    Raises:
        ValueError: If ``dropout`` is outside the range [0, 1].
    """

    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.5):
        super().__init__()
        if not 0.0 <= dropout <= 1.0:
            raise ValueError(f"dropout must be in [0, 1], got {dropout}")
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.classifier = Linear(hidden_channels, out_channels)
        self.dropout = dropout

    def forward(self, data):
        """Return the class logits for every graph contained in ``data``.

        ``data`` must expose ``x``, ``edge_index`` and ``batch`` attributes.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # Node embedding phase
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()

        # Graph pooling phase: aggregate node features into a single graph vector
        x = global_mean_pool(x, batch)

        # Final classification; dropout is only active in training mode
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.classifier(x)

# --- TRAINING AND EVALUATION FUNCTIONS ---
def train(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader`` and return the average loss.

    Args:
        model: Module called on each batch to produce logits.
        loader: Iterable of batches; ``loader.dataset`` must support ``len``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss function called as ``criterion(logits, batch.y)``.

    Returns:
        The sum of per-batch losses, each scaled by its number of graphs,
        divided by the size of ``loader.dataset``.
    """
    model.train()
    running_loss = 0
    progress = tqdm(loader, desc="Training", leave=False)
    for graphs in progress:
        graphs = graphs.to(Config.DEVICE)

        optimizer.zero_grad()
        batch_loss = criterion(model(graphs), graphs.y)
        batch_loss.backward()
        optimizer.step()

        # Scale by the graph count so batches of different sizes are
        # weighted proportionally in the epoch average.
        running_loss += batch_loss.item() * graphs.num_graphs

    return running_loss / len(loader.dataset)

def evaluate(model, loader, criterion):
    """Score ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Module called on each batch to produce logits.
        loader: Iterable of batches; ``loader.dataset`` must support ``len``.
        criterion: Loss function called as ``criterion(logits, batch.y)``.

    Returns:
        A tuple ``(avg_loss, accuracy, predictions, labels)`` where
        ``predictions`` and ``labels`` are flat lists covering every graph
        seen, in loader order.
    """
    model.eval()
    loss_sum = 0
    predicted, expected = [], []
    progress = tqdm(loader, desc="Evaluating", leave=False)
    with torch.no_grad():
        for graphs in progress:
            graphs = graphs.to(Config.DEVICE)
            logits = model(graphs)
            loss_sum += criterion(logits, graphs.y).item() * graphs.num_graphs

            # The predicted class is the index of the largest logit.
            predicted.extend(logits.argmax(dim=1).cpu().numpy())
            expected.extend(graphs.y.cpu().numpy())

    mean_loss = loss_sum / len(loader.dataset)
    return mean_loss, accuracy_score(expected, predicted), predicted, expected


# Calculate class weights for the loss function
print("Calculating class weights...")
# Reuse the metadata the training dataset already loaded from train.csv.
train_labels = train_dataset.metadata['label_id'].to_numpy()
present_classes = np.unique(train_labels)
if present_classes.size == 0:
    raise ValueError("The training split contains no labels; cannot compute class weights.")
if present_classes.min() < 0 or present_classes.max() >= num_classes:
    raise ValueError(
        f"Training label ids must lie in [0, {num_classes - 1}], "
        f"found range [{present_classes.min()}, {present_classes.max()}]."
    )

# 'balanced' weights only exist for classes that occur in the training split,
# but CrossEntropyLoss needs exactly one weight per output class. Start from a
# neutral weight of 1.0 for every class and overwrite the entries of the
# classes actually seen, indexed by label id.
class_weights = np.ones(num_classes, dtype=np.float64)
class_weights[present_classes] = compute_class_weight(
    'balanced', classes=present_classes, y=train_labels
)
weights = torch.tensor(class_weights, dtype=torch.float).to(Config.DEVICE)
criterion = torch.nn.CrossEntropyLoss(weight=weights)
print("Weights calculated.")

# Build the classifier, move it to the configured device, and attach Adam
model = GNNClassifier(
    train_dataset.num_features,
    Config.HIDDEN_CHANNELS,
    num_classes,
)
model = model.to(Config.DEVICE)
optimizer = torch.optim.Adam(params=model.parameters(), lr=Config.LEARNING_RATE)

# --- Training Loop with Validation ---
print("\n--- Starting Training ---")
# Lowest validation loss seen so far; the first finite loss always beats it.
best_val_loss = float('inf')

In [None]:
# Make sure the checkpoint directory exists (torch.save does not create it)
# and bind the checkpoint path before the loop, so `save_dir` is defined even
# if the validation loss never improves (e.g. it is NaN on every epoch).
os.makedirs(Config.MODEL_DIR, exist_ok=True)
save_dir = os.path.join(Config.MODEL_DIR, "graph_sage_classifier.pth")

for epoch in range(1, Config.EPOCHS + 1):
    train_loss = train(model, train_loader, optimizer, criterion)
    val_loss, val_acc, _, _ = evaluate(model, val_loader, criterion)

    print(f"Epoch: {epoch:02d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

    # Save the model if validation loss improves
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), save_dir)
        print(f"  -> Best model saved to {save_dir}")

print("\n--- Training Finished ---")

# --- Final Evaluation on Test Set ---
print("\n--- Evaluating on Test Set ---")
# Load the best performing model. The path is rebuilt here rather than taken
# from the training loop, so this step does not depend on a variable that is
# only bound when a checkpoint was saved in the current kernel session.
best_model_path = os.path.join(Config.MODEL_DIR, "graph_sage_classifier.pth")
if not os.path.isfile(best_model_path):
    raise FileNotFoundError(
        f"No checkpoint found at {best_model_path}; run the training loop first."
    )
# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
model.load_state_dict(torch.load(best_model_path, map_location=Config.DEVICE))

test_loss, test_acc, test_preds, test_labels = evaluate(model, test_loader, criterion)
print(f"\nFinal Test Accuracy: {test_acc:.4f}")

# Pass the label ids explicitly, in the same order as class_names. Otherwise
# scikit-learn infers the classes from the data, and a class that is absent
# from both the test labels and the predictions makes the report fail and
# misaligns the confusion-matrix tick labels.
label_ids = sorted(id_to_label)

# Display detailed classification report
report = classification_report(
    test_labels,
    test_preds,
    labels=label_ids,
    target_names=class_names,
    zero_division=0,
)
print("\n--- Classification Report ---")
print(report)

# Display confusion matrix
print("\n--- Confusion Matrix ---")
cm = confusion_matrix(test_labels, test_preds, labels=label_ids)
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
    ax=ax,
)
ax.set_xlabel('Predicted Label')
ax.set_ylabel('True Label')
ax.set_title('Confusion Matrix on Test Set')
plt.show()