In [9]:
# src/gnn_model.py
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc

# ============================= CONFIG =============================
# Input location for the processed CSVs and output folders for artefacts.
DATA_PROCESSED = "/content"  # Colab path
RESULTS_DIR = "results"
MODELS_DIR = "models"
# Create both output folders up front; an already-existing folder is fine.
for _out_dir in (RESULTS_DIR, MODELS_DIR):
    os.makedirs(_out_dir, exist_ok=True)

# Use the GPU when torch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

# === Windowing and training hyper-parameters ===
WINDOW_SIZE = 10          # rows per sliding window (one graph per window)
STRIDE = 5                # rows to advance between consecutive windows
BATCH_SIZE = 16           # graphs per mini-batch
EPOCHS = 80               # full passes over the training graphs
MAX_ATTACK_RATIO = 0.30   # a test window counts as "mostly normal" up to this attack share

# ============================= GRAPH BUILDER =============================
def build_graphs_with_ratio(df: pd.DataFrame):
    """Slide a fixed-size window over *df* and turn every window into a graph.

    Each row of a window becomes one directed edge from ``"<service>_src"`` to
    ``"<service>_dst"``; the eight flow statistics of the row are stored as the
    edge attributes. Every node gets ten features: the mean of those eight
    statistics over the rows touching the node, the number of such rows, and a
    flag that is 1.0 when the node name contains ``"http"``.

    Args:
        df: frame with a ``service`` column, a ``label`` column and the eight
            numeric statistic columns listed in ``feature_cols`` below.

    Returns:
        A list of ``(graph, attack_ratio)`` tuples. ``attack_ratio`` is the
        fraction of rows in the window whose label equals ``'attack'``;
        ``graph.y`` is 1 when that fraction exceeds 0.5 and 0 otherwise.
    """
    # Single source of truth for the statistics used as edge attributes and,
    # averaged per node, as node features (order matters for the tensors).
    feature_cols = [
        'duration', 'src_bytes', 'dst_bytes', 'count',
        'same_srv_rate', 'diff_srv_rate', 'serror_rate', 'rerror_rate',
    ]

    graphs = []
    for start in range(0, len(df) - WINDOW_SIZE + 1, STRIDE):
        win = df.iloc[start:start + WINDOW_SIZE].copy()
        # Windows are always WINDOW_SIZE rows long here, so this guard only
        # triggers when WINDOW_SIZE itself is configured below 8.
        if len(win) < 8:
            continue

        win['src_node'] = win['service'].astype(str) + '_src'
        win['dst_node'] = win['service'].astype(str) + '_dst'

        nodes = pd.unique(win[['src_node', 'dst_node']].values.ravel())
        if len(nodes) < 2:
            continue

        node2idx = {n: i for i, n in enumerate(nodes)}

        # One edge per row, built column-wise instead of via iterrows().
        edge_index = torch.tensor(
            [[node2idx[s], node2idx[d]]
             for s, d in zip(win['src_node'], win['dst_node'])],
            dtype=torch.long,
        ).t().contiguous()
        edge_attr = torch.tensor(
            win[feature_cols].to_numpy(dtype=np.float32), dtype=torch.float
        )

        node_rows = []
        for n in nodes:
            # Select the rows incident to this node once and reuse the
            # selection for every feature.
            touching = win[(win['src_node'] == n) | (win['dst_node'] == n)]
            row = [touching[col].mean() for col in feature_cols]
            row.append(len(touching))
            row.append(1.0 if 'http' in n else 0.0)
            node_rows.append(row)
        x = torch.from_numpy(np.array(node_rows, dtype=np.float32))

        attack_ratio = (win['label'] == 'attack').mean()
        # Majority vote: the window is an attack graph only above 50% attack rows.
        y = 1 if attack_ratio > 0.5 else 0
        graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=torch.tensor([y]))
        graphs.append((graph, attack_ratio))
    return graphs

# ============================= MODEL =============================
class GNNAnomalyDetector(nn.Module):
    """Graph-level scorer: stacked GCN layers, mean pooling, two-layer MLP head.

    Args:
        input_dim: number of features per node.
        hidden: width of every GCN layer; the head narrows to ``hidden // 2``
            before the single output unit.
        layers: total number of GCN layers. The first maps ``input_dim`` to
            ``hidden``; values below 1 still produce that first layer.
        dropout: dropout probability applied after every GCN layer and after
            the first linear layer. The default matches the value that was
            previously hard-coded.
    """

    def __init__(self, input_dim, hidden=128, layers=3, dropout=0.4):
        super().__init__()
        # Attribute names are part of the saved state_dict keys; keep them stable.
        self.convs = nn.ModuleList(
            [GCNConv(input_dim, hidden)]
            + [GCNConv(hidden, hidden) for _ in range(layers - 1)]
        )
        self.lin1 = nn.Linear(hidden, hidden // 2)
        self.lin2 = nn.Linear(hidden // 2, 1)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(dropout)

    def forward(self, data):
        """Return the raw logits (no sigmoid applied) for a batch of graphs.

        Reads ``data.x``, ``data.edge_index`` and ``data.batch``; edge
        attributes are not used by this model.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        for conv in self.convs:
            x = conv(x, edge_index)
            x = self.relu(x)
            x = self.drop(x)
        # Collapse node embeddings into one embedding per graph.
        x = global_mean_pool(x, batch)
        x = self.relu(self.lin1(x))
        x = self.drop(x)
        # Drop the trailing size-1 dimension produced by lin2.
        return self.lin2(x).squeeze(-1)

# ============================= TRAINING =============================
def _score_graphs(model, loader):
    """Return (sigmoid scores, labels) as NumPy arrays for all graphs in *loader*.

    The caller is expected to have put *model* into eval mode.
    """
    scores, labels = [], []
    with torch.no_grad():
        for data in loader:
            data = data.to(DEVICE)
            scores.extend(torch.sigmoid(model(data)).cpu().numpy())
            labels.extend(data.y.cpu().numpy())
    return np.array(scores), np.array(labels)


def train_gnn():
    """Train the GNN on normal-only training windows and evaluate on test windows.

    Steps: read the processed train/test CSVs, build window graphs, train with
    an all-zero target on graphs built from rows labelled ``'normal'``, save
    the weights, score a test subset (up to 50 mostly-normal and up to 500
    attack windows), print ROC-AUC / PR-AUC / Precision@10%, and write the
    scores CSV and a score histogram into ``RESULTS_DIR``.

    Raises:
        ValueError: if no training graphs, no mostly-normal test windows, or
            no attack test windows are available.
    """
    train_df = pd.read_csv(f"{DATA_PROCESSED}/train_processed.csv")
    test_df  = pd.read_csv(f"{DATA_PROCESSED}/test_processed.csv")

    print(f"Train: {len(train_df):,} | Test: {len(test_df):,}")

    # === TRAINING GRAPHS (rows labelled 'normal' only) ===
    normal_train = train_df[train_df['label'] == 'normal'].copy()
    train_graphs = [g for g, _ in build_graphs_with_ratio(normal_train)]
    if not train_graphs:
        # Fail with a clear message instead of an opaque sampler/index error.
        raise ValueError("No training graphs could be built from the normal training rows.")

    # === TEST GRAPHS ===
    test_graphs_with_ratio = build_graphs_with_ratio(test_df)
    test_graphs = [g for g, _ in test_graphs_with_ratio]

    # Mostly-normal windows are taken from the test split only; windows with an
    # attack share between MAX_ATTACK_RATIO and 0.5 end up in neither group.
    normal_windows = [g for g, r in test_graphs_with_ratio if r <= MAX_ATTACK_RATIO]
    attack_graphs = [g for g in test_graphs if g.y.item() == 1]

    print(f"Found {len(normal_windows)} REAL mostly-normal windows (≤{MAX_ATTACK_RATIO*100}% attack)")

    # Cap both groups; slicing keeps the first windows in file order.
    normal_windows = normal_windows[:50]
    attack_graphs = attack_graphs[:500]
    final_test_graphs = normal_windows + attack_graphs

    print(f"REAL Test → Normal: {len(normal_windows)}, Attack: {len(attack_graphs)}")

    if len(normal_windows) == 0:
        raise ValueError("No real normal windows. Try WINDOW_SIZE=8")
    if len(attack_graphs) == 0:
        # ROC-AUC is undefined with a single class; stop before spending
        # EPOCHS of training only to fail inside roc_auc_score.
        raise ValueError("No attack windows in the test set; metrics would be undefined.")

    train_loader = DataLoader(train_graphs, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(final_test_graphs, batch_size=BATCH_SIZE, shuffle=False)

    input_dim = train_graphs[0].x.shape[1]
    model = GNNAnomalyDetector(input_dim).to(DEVICE)
    opt = optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-5)
    crit = nn.BCEWithLogitsLoss()

    print("Training...")
    for epoch in range(1, EPOCHS + 1):
        model.train()
        epoch_loss = 0.0
        for data in train_loader:
            data = data.to(DEVICE)
            # NOTE(review): every training target is 0, so this loss can be
            # driven to ~0 by emitting a large negative logit regardless of
            # the input graph. Confirm that this one-class objective is what
            # is intended before relying on the resulting scores.
            data.y = torch.zeros(data.num_graphs, device=DEVICE)
            logits = model(data)
            loss = crit(logits, data.y)
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            epoch_loss += loss.item()
        if epoch % 20 == 0:
            avg_loss = epoch_loss / len(train_loader)
            print(f"  Epoch {epoch:2d} | Loss: {avg_loss:.8f}")

    torch.save(model.state_dict(), f"{MODELS_DIR}/gnn_detector.pth")

    # === EVALUATION ===
    model.eval()
    scores, y_true = _score_graphs(model, test_loader)

    # Threshold = 95th percentile of the scores the model assigns to its own
    # (normal-only) training graphs.
    normal_scores, _ = _score_graphs(model, train_loader)
    threshold = np.percentile(normal_scores, 95)

    auc_roc = roc_auc_score(y_true, scores)
    prec, rec, _ = precision_recall_curve(y_true, scores)
    auc_pr = auc(rec, prec)
    # At least one graph must be selected: with k == 0 the slice [-0:] would
    # return *all* scores and silently report the overall attack rate.
    k = max(1, int(0.1 * len(scores)))
    prec10 = np.mean(y_true[np.argsort(scores)[-k:]])

    print(f"\nGNN Results (No Leakage):")
    print(f"  ROC-AUC     : {auc_roc:.6f}")
    print(f"  PR-AUC      : {auc_pr:.6f}")
    print(f"  Precision@10%: {prec10:.6f}")
    print(f"  Threshold   : {threshold:.6f}")

    pd.DataFrame({"score": scores, "label": y_true}).to_csv(f"{RESULTS_DIR}/gnn_scores.csv", index=False)

    plt.figure(figsize=(10,6))
    sns.histplot(scores[y_true==0], label="Normal", alpha=0.7, bins=40, color="#1f77b4")
    sns.histplot(scores[y_true==1], label="Attack", alpha=0.7, bins=40, color="#d62728")
    plt.axvline(threshold, color='k', ls='--', label=f'Threshold = {threshold:.4f}')
    plt.title("GNN Anomaly Score Distribution")
    plt.xlabel("Score"); plt.legend()
    plt.tight_layout()
    plt.savefig(f"{RESULTS_DIR}/gnn_score_distribution.png", dpi=200)
    plt.close()

    print(f"\nResults saved in: {RESULTS_DIR}/")
    print("GNN training completed")

# ============================= RUN =============================
# Executes the whole pipeline (CSV loading, training, evaluation, saving of
# weights/scores/plot) as soon as this cell or module is run.
train_gnn()

Using device: cuda
Train: 125,973 | Test: 21,934
Found 375 REAL mostly-normal windows (≤30.0% attack)
REAL Test → Normal: 50, Attack: 500
Training...
  Epoch 20 | Loss: 0.00000186
  Epoch 40 | Loss: 0.00000188
  Epoch 60 | Loss: 0.00000192
  Epoch 80 | Loss: 0.00000184

GNN Results (No Leakage):
  ROC-AUC     : 0.832120
  PR-AUC      : 0.979525
  Precision@10%: 1.000000
  Threshold   : 0.000001

Results saved in: results/
GNN training completed — NO LEAKAGE, REAL TEST
