In [3]:
# src/gnn_model.py
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    roc_auc_score, precision_recall_curve, auc,
    accuracy_score, f1_score
)

# ============================= CONFIG =============================
# --- Sliding-window / training hyper-parameters ---
WINDOW_SIZE = 10          # rows (connections) per graph window
STRIDE = 5                # rows to advance between consecutive windows
BATCH_SIZE = 16           # graphs per mini-batch
EPOCHS = 80               # full passes over the training graphs
MAX_ATTACK_RATIO = 0.30   # test windows at or below this attack share count as "clean"

# --- Filesystem layout ---
# Directory that holds train_processed.csv and test_processed.csv.
# NOTE(review): "/content" is the usual Google Colab working directory;
# confirm this is the intended location when running elsewhere.
DATA_PROCESSED = "/content"
RESULTS_DIR = "results"
MODELS_DIR = "models"
for _out_dir in (RESULTS_DIR, MODELS_DIR):
    # Output folders are created up front so later saves cannot fail on a missing path.
    os.makedirs(_out_dir, exist_ok=True)
del _out_dir

# --- Compute device: GPU when CUDA is usable, otherwise CPU ---
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

# ============================= GRAPH BUILDER =============================
def build_graphs_with_ratio(df: pd.DataFrame, window_size=None, stride=None):
    """Turn a connection table into sliding-window graphs.

    Each window of consecutive rows becomes one graph. Every row contributes
    one directed edge from a "<service>_src" node to a "<service>_dst" node
    (or from "<row index>_src" to "<row index>_dst" when the frame has no
    'service' column). Edge attributes are the row's feature values; node
    features are the per-node mean of those features over the rows touching
    the node, followed by the number of such rows and a 1.0/0.0 flag for
    whether the node name contains "http".

    Args:
        df: Connection records. Must contain a 'label_binary' column whose
            attack rows hold the string "attack". Feature columns that are
            absent are treated as 0.0 for both edge and node features.
        window_size: Rows per window. Defaults to the module-level WINDOW_SIZE.
        stride: Rows to advance between windows. Defaults to the module-level
            STRIDE.

    Returns:
        A list of (graph, attack_ratio) tuples, where graph is a
        torch_geometric Data object with a float label y of shape [1]
        (1 when more than half of the window's rows are attacks) and
        attack_ratio is the fraction of attack rows in that window.

    Raises:
        ValueError: If window_size or stride is smaller than 1.
        KeyError: If a window is processed and 'label_binary' is missing.
    """
    window_size = WINDOW_SIZE if window_size is None else window_size
    stride = STRIDE if stride is None else stride
    if window_size < 1 or stride < 1:
        raise ValueError("window_size and stride must both be >= 1")

    label_col = "label_binary"  # Use the new binary label
    # Hoisted out of the loop: identical for every window.
    feature_cols = ['duration', 'src_bytes', 'dst_bytes', 'count',
                    'same_srv_rate', 'diff_srv_rate', 'serror_rate', 'rerror_rate']

    graphs = []
    for start in range(0, len(df) - window_size + 1, stride):
        win = df.iloc[start:start + window_size].copy()
        # Only relevant when window_size itself is below 8; kept so such
        # configurations keep producing no graphs, as before.
        if len(win) < 8:
            continue

        # Create node identifiers using 'service' (or fallback to index if missing)
        if 'service' in win.columns:
            win['src_node'] = win['service'].astype(str) + '_src'
            win['dst_node'] = win['service'].astype(str) + '_dst'
        else:
            # Fallback: use row index as unique identifier
            win['src_node'] = win.index.astype(str) + '_src'
            win['dst_node'] = win.index.astype(str) + '_dst'

        nodes = pd.unique(win[['src_node', 'dst_node']].values.ravel('K'))
        if len(nodes) < 2:
            continue

        node2idx = {node: idx for idx, node in enumerate(nodes)}

        # One numeric table shared by edge and node features. Missing columns
        # become 0.0 in both places (previously only edge attributes tolerated
        # a missing column; node features raised KeyError).
        feat_table = win.reindex(columns=feature_cols, fill_value=0.0).astype(float)

        # Edges: one per row, looked up in bulk instead of row by row.
        src_idx = win['src_node'].map(node2idx).to_numpy()
        dst_idx = win['dst_node'].map(node2idx).to_numpy()
        edge_index = torch.tensor(np.stack([src_idx, dst_idx]), dtype=torch.long)
        edge_attr = torch.tensor(feat_table.to_numpy(), dtype=torch.float)

        # Node features: aggregated stats per node. Every node originates from
        # at least one row of this window, so the selection is never empty.
        src_names = win['src_node'].to_numpy()
        dst_names = win['dst_node'].to_numpy()
        node_feat = []
        for node in nodes:
            touches = (src_names == node) | (dst_names == node)
            node_rows = feat_table.loc[touches]
            feats = node_rows.mean().tolist()  # order follows feature_cols
            feats += [len(node_rows), 1.0 if 'http' in str(node) else 0.0]
            node_feat.append(feats)

        x = torch.tensor(node_feat, dtype=torch.float)

        # Label: 1 if majority of connections in window are attacks
        attack_ratio = (win[label_col] == "attack").mean()
        y = 1 if attack_ratio > 0.5 else 0

        graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
                     y=torch.tensor([y], dtype=torch.float))
        graphs.append((graph, attack_ratio))

    return graphs


# ============================= MODEL =============================
class GNNAnomalyDetector(nn.Module):
    """Graph-level scorer: a stack of GCN layers, mean pooling and an MLP head.

    Args:
        input_dim: Width of the per-node feature vectors.
        hidden: Output width of every GCN layer; the head narrows it to
            hidden // 2 before the single output unit.
        layers: Requested number of GCN layers. One input layer is always
            built, followed by layers - 1 hidden-to-hidden layers.
    """

    def __init__(self, input_dim, hidden=128, layers=3):
        super().__init__()
        # Channel widths along the conv stack: input_dim -> hidden -> ... -> hidden.
        widths = [input_dim, hidden] + [hidden] * (layers - 1)
        self.convs = nn.ModuleList(
            [GCNConv(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])]
        )
        self.lin1 = nn.Linear(hidden, hidden // 2)
        self.lin2 = nn.Linear(hidden // 2, 1)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(0.4)

    def forward(self, data):
        """Return one raw logit per graph in the batch.

        Only data.x, data.edge_index and data.batch are read; edge attributes
        are not used by this model.
        """
        h = data.x
        for conv in self.convs:
            # Message passing, then non-linearity, then dropout.
            h = self.drop(self.relu(conv(h, data.edge_index)))

        # Collapse node embeddings into one vector per graph.
        pooled = global_mean_pool(h, data.batch)
        head = self.drop(self.relu(self.lin1(pooled)))
        return self.lin2(head).squeeze(-1)


# ============================= TRAINING & EVALUATION =============================
def _score_graphs(model, loader):
    """Run model over loader without gradients; return (sigmoid scores, labels)."""
    scores = []
    labels = []
    with torch.no_grad():
        for data in loader:
            data = data.to(DEVICE)
            logits = model(data)
            scores.extend(torch.sigmoid(logits).cpu().numpy())
            labels.extend(data.y.cpu().numpy())
    return np.array(scores), np.array(labels)


def train_gnn():
    """Train the GNN on normal-only windows, then evaluate it on test windows.

    Steps:
      1. Load train_processed.csv / test_processed.csv from DATA_PROCESSED.
      2. Build graphs from the normal training rows and from all test rows.
      3. Assemble the evaluation set from up to 100 test windows whose attack
         ratio is <= MAX_ATTACK_RATIO and up to 600 windows labelled attack.
      4. Train with BCE-with-logits against an all-zero target.
      5. Use the 95th percentile of the training scores as decision threshold.
      6. Print metrics, write gnn_scores.csv and a score histogram to
         RESULTS_DIR, and save the weights to MODELS_DIR/gnn_detector.pth.

    NOTE(review): every training target is 0, so the loss can be minimised by
    pushing all logits strongly negative regardless of the input. Confirm that
    this one-class setup is intended before relying on the reported scores.

    Raises:
        ValueError: If 'label_binary' is missing from either CSV, if no
            training graphs could be built, or if the evaluation set is empty.
    """
    print("Loading processed data...")
    train_df = pd.read_csv(f"{DATA_PROCESSED}/train_processed.csv")
    test_df  = pd.read_csv(f"{DATA_PROCESSED}/test_processed.csv")

    if "label_binary" not in train_df.columns:
        raise ValueError("Column 'label_binary' not found! Required for GNN.")
    # The test frame is labelled through the same column, so check it up
    # front instead of failing with a KeyError in the middle of graph building.
    if "label_binary" not in test_df.columns:
        raise ValueError("Column 'label_binary' not found in test data! Required for GNN.")

    print(f"Train samples: {len(train_df):,} | Test samples: {len(test_df):,}")

    # ============================= BUILD GRAPHS =============================
    print("Building graphs from normal training data...")
    normal_train = train_df[train_df["label_binary"] == "normal"].copy()
    train_graphs_with_ratio = build_graphs_with_ratio(normal_train)
    train_graphs = [g for g, _ in train_graphs_with_ratio]
    print(f"→ {len(train_graphs)} normal training graphs created")

    # Must run before the shuffling DataLoader is created: that constructor
    # rejects an empty dataset itself, which would hide this message.
    if len(train_graphs) == 0:
        raise ValueError("No training graphs generated! Check data and features.")

    print("Building test graphs...")
    test_graphs_with_ratio = build_graphs_with_ratio(test_df)
    test_graphs = [g for g, _ in test_graphs_with_ratio]
    print(f"→ {len(test_graphs)} test graphs created")

    # Evaluation subset: mostly-clean windows plus attack-majority windows.
    # The two groups are capped separately, so the result is not class-balanced.
    clean_windows = [g for g, r in test_graphs_with_ratio if r <= MAX_ATTACK_RATIO]
    attack_windows = [g for g in test_graphs if g.y.item() == 1]

    # Limit for faster evaluation (remove in production)
    clean_windows = clean_windows[:100]
    attack_windows = attack_windows[:600]

    final_test_graphs = clean_windows + attack_windows
    print(f"Final test set: {len(clean_windows)} normal + {len(attack_windows)} attack windows")

    # Fail before spending EPOCHS of training on a run that cannot be evaluated.
    if len(final_test_graphs) == 0:
        raise ValueError("No test graphs selected for evaluation! Check test data.")

    # DataLoaders
    train_loader = DataLoader(train_graphs, batch_size=BATCH_SIZE, shuffle=True)
    test_loader  = DataLoader(final_test_graphs, batch_size=BATCH_SIZE, shuffle=False)

    input_dim = train_graphs[0].x.shape[1]
    print(f"Node feature dimension: {input_dim}")

    model = GNNAnomalyDetector(input_dim=input_dim, hidden=128, layers=3).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-5)
    criterion = nn.BCEWithLogitsLoss()

    # ============================= TRAINING =============================
    print("Starting GNN training...")
    for epoch in range(1, EPOCHS + 1):
        model.train()
        epoch_loss = 0.0
        for data in train_loader:
            data = data.to(DEVICE)
            data.y = torch.zeros(data.num_graphs, dtype=torch.float, device=DEVICE)  # All normal
            logits = model(data)
            loss = criterion(logits, data.y)

            optimizer.zero_grad()
            loss.backward()
            # Cap the gradient norm at 1.0 before the optimizer step.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            epoch_loss += loss.item()

        if epoch % 20 == 0 or epoch == 1:
            avg_loss = epoch_loss / len(train_loader)
            print(f"  Epoch {epoch:3d} | Loss: {avg_loss:.6f}")

    torch.save(model.state_dict(), f"{MODELS_DIR}/gnn_detector.pth")
    print(f"Model saved → {MODELS_DIR}/gnn_detector.pth")

    # ============================= EVALUATION =============================
    print("Evaluating on test graphs...")
    model.eval()

    # Score on normal training graphs (for threshold)
    train_scores, _ = _score_graphs(model, train_loader)
    threshold = np.percentile(train_scores, 95)
    print(f"Threshold (95th percentile on normal): {threshold:.6f}")

    # Score on test set
    scores, true_labels = _score_graphs(model, test_loader)
    true_labels = true_labels.astype(int)

    y_pred = (scores > threshold).astype(int)

    # Metrics
    if np.unique(true_labels).size < 2:
        # ROC-AUC / PR-AUC are undefined when only one class is present.
        print("Warning: test set contains a single class; ROC-AUC and PR-AUC are undefined.")
        auc_roc = float("nan")
        auc_pr = float("nan")
    else:
        auc_roc = roc_auc_score(true_labels, scores)
        prec, rec, _ = precision_recall_curve(true_labels, scores)
        auc_pr = auc(rec, prec)
    accuracy = accuracy_score(true_labels, y_pred)
    f1 = f1_score(true_labels, y_pred)
    # Share of true attacks among the top-scoring 10% of windows.
    k = max(1, int(0.1 * len(scores)))
    precision_at_10 = np.mean(true_labels[np.argsort(scores)[-k:]])

    print("\n" + "="*60)
    print("GNN Anomaly Detector Results")
    print("="*60)
    print(f"  Graphs evaluated  : {len(scores)}")
    print(f"  ROC-AUC           : {auc_roc:.6f}")
    print(f"  PR-AUC            : {auc_pr:.6f}")
    print(f"  Accuracy          : {accuracy:.6f}")
    print(f"  F1 Score          : {f1:.6f}")
    print(f"  Precision@10%     : {precision_at_10:.6f}")
    print(f"  Threshold         : {threshold:.6f}")
    print("="*60)

    # Save results
    pd.DataFrame({
        "anomaly_score": scores,
        "true_label": true_labels,
        "predicted_label": y_pred
    }).to_csv(f"{RESULTS_DIR}/gnn_scores.csv", index=False)

    # Plot score distribution (a class with no windows is left out of the plot)
    plt.figure(figsize=(10, 6))
    for class_id, legend, colour in ((0, "Normal Windows", "blue"),
                                     (1, "Attack Windows", "red")):
        selected = true_labels == class_id
        if selected.any():
            sns.histplot(scores[selected], label=legend, alpha=0.7, bins=40, color=colour)
    plt.axvline(threshold, color='black', linestyle='--', linewidth=2, label=f"Threshold = {threshold:.4f}")
    plt.title("GNN Anomaly Score Distribution (Test Set)")
    plt.xlabel("Anomaly Score")
    plt.ylabel("Count")
    plt.legend()
    plt.tight_layout()
    plt.savefig(f"{RESULTS_DIR}/gnn_score_distribution.png", dpi=200, bbox_inches='tight')
    plt.close()

    print(f"\nAll results saved in '{RESULTS_DIR}'")
    print("GNN training and evaluation completed successfully!")


# ============================= RUN =============================
# Run the full train/evaluate pipeline only when this module is the entry
# point (__name__ is "__main__"), not when it is imported by another module.
if __name__ == "__main__":
    train_gnn()

  import torch_geometric.typing
  import torch_geometric.typing
  import torch_geometric.typing
  import torch_geometric.typing
  import torch_geometric.typing


Using device: cuda
Loading processed data...
Train samples: 125,973 | Test samples: 21,934
Building graphs from normal training data...
→ 13467 normal training graphs created
Building test graphs...
→ 4385 test graphs created
Final test set: 100 normal + 600 attack windows
Node feature dimension: 10
Starting GNN training...
  Epoch   1 | Loss: 0.013136
  Epoch  20 | Loss: 0.000002
  Epoch  40 | Loss: 0.000002
  Epoch  60 | Loss: 0.000002
  Epoch  80 | Loss: 0.000002
Model saved → models/gnn_detector.pth
Evaluating on test graphs...
Threshold (95th percentile on normal): 0.000001

GNN Anomaly Detector Results
  Graphs evaluated  : 700
  ROC-AUC           : 0.839233
  PR-AUC            : 0.968961
  Accuracy          : 0.765714
  F1 Score          : 0.849541
  Precision@10%     : 1.000000
  Threshold         : 0.000001

All results saved in 'results'
GNN training and evaluation completed successfully!
