## 1️⃣ Install Dependencies

In [None]:
%%capture
# Install the third-party packages imported by the cells below; %%capture hides pip's output.
!pip install torch torch-geometric faiss-cpu scikit-learn pandas numpy tqdm matplotlib seaborn

In [None]:
# Check GPU: report the PyTorch build and whether a CUDA device is visible.
import torch

print(f"PyTorch: {torch.__version__}")
has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if has_cuda:
    # Device 0 is the first CUDA device.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2️⃣ Upload Data Files
Upload các file sau lên Kaggle (từ dataset-processed/):
- `X.npy` - Features (~3M samples, đã được scaled)
- `y.npy` - Labels  
- `idx_train.npy` - Training indices
- `idx_val.npy` - Validation indices
- `idx_test.npy` - Test indices

**Lưu ý**: Dataset này được build từ toàn bộ 10 file CSV CICIDS2018 (~3 triệu flows)

## 3️⃣ Utils Functions (utils.py)

In [None]:
"""Utilities for flow_gnn package."""

import logging
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, 
    confusion_matrix, roc_curve, auc, classification_report
)
from typing import Dict, Optional
import json

def get_device(device_str: str = "auto") -> torch.device:
    """Resolve a device specifier to a ``torch.device``.

    Args:
        device_str: ``"auto"`` selects CUDA when it is available and the CPU
            otherwise; any other value is passed to ``torch.device`` as is.

    Returns:
        The resolved device.
    """
    if device_str != "auto":
        return torch.device(device_str)

    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def compute_metrics(y_true: np.ndarray, y_pred: np.ndarray, 
                    y_probs: Optional[np.ndarray] = None) -> Dict[str, float]:
    """Score binary predictions against the ground truth.

    Label 1 is the positive class for precision, recall and F1.

    Args:
        y_true: Ground-truth labels.
        y_pred: Hard predicted labels.
        y_probs: Predicted probabilities; when given they are fed to
            ``roc_curve`` and the resulting AUC is stored under ``"auc"``.

    Returns:
        Dictionary with ``accuracy``, ``precision``, ``recall``, ``f1``,
        ``far`` (false alarm rate), ``detection_rate`` and, when ``y_probs``
        is supplied, ``auc``.
    """
    scores = {"accuracy": accuracy_score(y_true, y_pred)}
    positive_class_scorers = (
        ("precision", precision_score),
        ("recall", recall_score),
        ("f1", f1_score),
    )
    for key, scorer in positive_class_scorers:
        scores[key] = scorer(y_true, y_pred, pos_label=1, zero_division=0)

    # labels=[0, 1] pins the matrix to 2x2 even if one class never occurs.
    (tn, fp), (fn, tp) = confusion_matrix(y_true, y_pred, labels=[0, 1])
    benign_total = fp + tn
    scores["far"] = fp / benign_total if benign_total > 0 else 0.0
    # Detection rate is reported as an alias of recall.
    scores["detection_rate"] = scores["recall"]

    if y_probs is None:
        return scores

    false_pos_rate, true_pos_rate, _ = roc_curve(y_true, y_probs)
    scores["auc"] = auc(false_pos_rate, true_pos_rate)
    return scores


def _finish_figure(target) -> None:
    """Tighten the layout of the current figure, write it to ``target`` and close it."""
    plt.tight_layout()
    plt.savefig(target, dpi=300, bbox_inches='tight')
    plt.close()


def _plot_confusion_matrix(y_true, y_pred, output_path) -> None:
    """Draw the 2x2 Benign/Attack confusion matrix heatmap."""
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    # labels=[0, 1] keeps the matrix 2x2 (matching the two tick labels) even
    # when one class is absent from both y_true and y_pred.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=['Benign', 'Attack'],
                yticklabels=['Benign', 'Attack'], ax=ax)
    ax.set_title('Confusion Matrix', fontsize=14, fontweight='bold')
    ax.set_ylabel('True Label', fontsize=12)
    ax.set_xlabel('Predicted Label', fontsize=12)
    _finish_figure(output_path / 'confusion_matrix.png')


def _plot_roc_curve(y_true, y_probs, output_path) -> None:
    """Draw the ROC curve with its AUC in the legend."""
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    fpr, tpr, _ = roc_curve(y_true, y_probs)
    roc_auc = auc(fpr, tpr)

    ax.plot(fpr, tpr, color='darkorange', lw=2,
            label=f'ROC curve (AUC = {roc_auc:.4f})')
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate', fontsize=12)
    ax.set_ylabel('True Positive Rate', fontsize=12)
    ax.set_title('ROC Curve', fontsize=14, fontweight='bold')
    ax.legend(loc="lower right", fontsize=10)
    ax.grid(True, alpha=0.3)
    _finish_figure(output_path / 'roc_curve.png')


def _plot_metrics_bar(metrics, output_path) -> None:
    """Draw a bar chart of the headline metrics (missing entries plot as 0)."""
    fig, ax = plt.subplots(1, 1, figsize=(10, 6))
    metric_names = ['Accuracy', 'Precision', 'Recall', 'F1', 'AUC']
    metric_values = [metrics.get(key, 0)
                     for key in ('accuracy', 'precision', 'recall', 'f1', 'auc')]

    bars = ax.bar(metric_names, metric_values,
                  color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd'])
    ax.set_ylim([0, 1.1])
    ax.set_ylabel('Score', fontsize=12)
    ax.set_title('Performance Metrics', fontsize=14, fontweight='bold')
    ax.grid(axis='y', alpha=0.3)

    # Print each value just above its bar.
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{height:.4f}', ha='center', va='bottom', fontsize=10)

    _finish_figure(output_path / 'metrics_bar.png')


def _plot_training_history(history, output_path) -> None:
    """Draw train/val loss and validation F1 curves from the history dict."""
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # The loss panel is only drawn when both series are present.
    if 'train_loss' in history and 'val_loss' in history:
        axes[0].plot(history['train_loss'], label='Train Loss', linewidth=2)
        axes[0].plot(history['val_loss'], label='Val Loss', linewidth=2)
        axes[0].set_xlabel('Epoch', fontsize=12)
        axes[0].set_ylabel('Loss', fontsize=12)
        axes[0].set_title('Training & Validation Loss', fontsize=14, fontweight='bold')
        axes[0].legend(fontsize=10)
        axes[0].grid(True, alpha=0.3)

    if 'val_f1' in history:
        axes[1].plot(history['val_f1'], label='Val F1', color='green', linewidth=2)
        axes[1].set_xlabel('Epoch', fontsize=12)
        axes[1].set_ylabel('F1 Score', fontsize=12)
        axes[1].set_title('Validation F1 Score', fontsize=14, fontweight='bold')
        axes[1].legend(fontsize=10)
        axes[1].grid(True, alpha=0.3)

    _finish_figure(output_path / 'training_history.png')


def save_metrics_plots(y_true: np.ndarray, y_pred: np.ndarray, 
                       y_probs: np.ndarray, metrics: Dict[str, float],
                       output_dir: str, history: Optional[Dict] = None):
    """Save performance visualisation plots as PNG files.

    Writes ``confusion_matrix.png``, ``roc_curve.png`` and ``metrics_bar.png``
    into ``output_dir`` (created if missing), plus ``training_history.png``
    when ``history`` is given.

    Note: this sets the global matplotlib style and seaborn palette.

    Args:
        y_true: True labels
        y_pred: Predicted labels
        y_probs: Predicted probabilities
        metrics: Computed metrics dictionary
        output_dir: Directory to save plots
        history: Training history (optional); the keys read are
            ``train_loss``, ``val_loss`` and ``val_f1``.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Set style
    plt.style.use('seaborn-v0_8-darkgrid')
    sns.set_palette("husl")

    _plot_confusion_matrix(y_true, y_pred, output_path)
    _plot_roc_curve(y_true, y_probs, output_path)
    _plot_metrics_bar(metrics, output_path)
    if history is not None:
        _plot_training_history(history, output_path)

    print(f"üìä Plots saved to {output_path}/")


def save_metrics_report(metrics: Dict[str, float], output_dir: str, 
                        y_true: Optional[np.ndarray] = None,
                        y_pred: Optional[np.ndarray] = None,
                        latency: Optional[float] = None):
    """Save metrics to JSON and CSV files, plus an optional text report.

    Writes ``metrics.json`` and ``metrics.csv`` into ``output_dir`` (created
    if missing). When both ``y_true`` and ``y_pred`` are given, also writes
    ``classification_report.txt``.

    Note: when ``latency`` is given, ``metrics`` is updated in place with
    ``latency_seconds`` and ``latency_ms``, so the caller's dict gains those
    keys.

    Args:
        metrics: Metrics dictionary; values must be numeric because they are
            formatted with ``:.6f`` in the text report.
        output_dir: Directory to save reports
        y_true: True labels (for classification report)
        y_pred: Predicted labels (for classification report)
        latency: Inference latency in seconds (optional)
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Add latency if provided
    if latency is not None:
        metrics['latency_seconds'] = latency
        metrics['latency_ms'] = latency * 1000

    # Save as JSON
    with open(output_path / 'metrics.json', 'w', encoding='utf-8') as f:
        json.dump(metrics, f, indent=2)

    # Save as CSV (single row, one column per metric)
    pd.DataFrame([metrics]).to_csv(output_path / 'metrics.csv', index=False)

    # Save classification report if labels provided
    if y_true is not None and y_pred is not None:
        # labels=[0, 1] keeps the report aligned with the two target names
        # even when one class is missing from the data; without it sklearn
        # raises on the size mismatch.
        report = classification_report(y_true, y_pred,
                                       labels=[0, 1],
                                       target_names=['Benign', 'Attack'],
                                       digits=4,
                                       zero_division=0)
        with open(output_path / 'classification_report.txt', 'w', encoding='utf-8') as f:
            f.write("Classification Report\n")
            f.write("=" * 60 + "\n")
            f.write(report)
            f.write("\n\nDetailed Metrics\n")
            f.write("=" * 60 + "\n")
            f.writelines(f"{key:20s}: {value:.6f}\n" for key, value in metrics.items())

    print(f"üìÑ Metrics saved to {output_path}/")


class EarlyStopping:
    """Signal when a monitored score (higher is better) stops improving.

    Call the instance once per epoch with the latest score. It returns True
    once ``patience`` consecutive calls have failed to beat the best score by
    more than ``min_delta``.
    """

    def __init__(self, patience: int = 10, min_delta: float = 1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0          # consecutive calls without improvement
        self.best_score = None    # no score seen yet

    def __call__(self, score: float) -> bool:
        """Record ``score`` and return True when training should stop."""
        # The first score only establishes the baseline.
        if self.best_score is None:
            self.best_score = score
            return False

        improved = score > self.best_score + self.min_delta
        if improved:
            self.best_score = score
            self.counter = 0
            return False

        self.counter += 1
        return self.counter >= self.patience

# Banner printed once every definition in this cell has executed.
print("‚úÖ Utils functions loaded")

## 4️⃣ Graph Builder (graph.py)

In [None]:
"""Build KNN graph from features using FAISS for efficient ANN."""

import faiss
from tqdm import tqdm

def build_knn_graph(X_scaled: np.ndarray, k: int = 10) -> torch.Tensor:
    """Build a symmetric KNN graph over the rows of ``X_scaled`` with FAISS.

    Rows are L2-normalised on a private copy and indexed with an exact
    inner-product index, so neighbours are ranked by cosine similarity.
    Each node is linked to its ``k`` nearest neighbours, every edge is
    mirrored, and duplicate edges are dropped.

    Args:
        X_scaled: 2-D feature matrix, one row per node. It is never modified.
        k: Number of neighbours requested per node.

    Returns:
        ``edge_index`` tensor of dtype ``torch.long`` and shape ``[2, E]``.
    """

    print(f"üî® Building KNN graph (k={k})...")

    # Always work on a private copy: faiss.normalize_L2 rescales rows in
    # place, and np.ascontiguousarray would return the caller's own buffer
    # when it is already C-contiguous float32, silently normalising the
    # features the model is later trained on.
    X = np.array(X_scaled, dtype=np.float32, order="C", copy=True)
    n_samples, n_features = X.shape

    print(f"   Data shape: {n_samples:,} samples √ó {n_features} features")

    # Nothing to index: return an empty graph instead of failing in np.vstack.
    if n_samples == 0:
        return torch.empty((2, 0), dtype=torch.long)

    # Normalize vectors
    with tqdm(total=1, desc="Normalizing vectors", ncols=100) as pbar:
        faiss.normalize_L2(X)
        pbar.update(1)

    # Build FAISS index
    with tqdm(total=1, desc="Building FAISS index", ncols=100) as pbar:
        index = faiss.IndexFlatIP(n_features)
        index.add(X)
        pbar.update(1)

    # Search for k+1 neighbors (including self) in batches
    print(f"   Searching for {k} nearest neighbors...")
    with tqdm(total=n_samples, desc="KNN search", unit="samples", ncols=100) as pbar:
        batch_size = 10000
        neighbor_chunks = []

        for start in range(0, n_samples, batch_size):
            stop = min(start + batch_size, n_samples)
            _, found = index.search(X[start:stop], k + 1)

            # Remove self-loops. The query is usually hit 0, but identical
            # rows tie on similarity and zero vectors score 0 against
            # everything, so locate self explicitly instead of assuming
            # column 0.
            is_self = found == np.arange(start, stop)[:, None]
            # If a row does not list itself at all, drop its last (least
            # similar) hit so every row still keeps exactly k neighbours.
            is_self[~is_self.any(axis=1), -1] = True
            neighbor_chunks.append(found[~is_self].reshape(stop - start, k))
            pbar.update(stop - start)

        neighbors = np.vstack(neighbor_chunks)

    # Build edge list
    with tqdm(total=1, desc="Building edges", ncols=100) as pbar:
        row = np.repeat(np.arange(n_samples), k)
        col = neighbors.ravel()

        # FAISS pads with -1 when fewer than k+1 vectors exist; drop those.
        valid = col >= 0
        row, col = row[valid], col[valid]

        # Symmetrize
        edges = np.vstack([
            np.concatenate([row, col]),
            np.concatenate([col, row])
        ])
        pbar.update(1)

    # Remove duplicates
    with tqdm(total=1, desc="Removing duplicates", ncols=100) as pbar:
        edges = np.unique(edges, axis=1)
        pbar.update(1)

    edge_index = torch.tensor(edges, dtype=torch.long)

    num_edges = edge_index.shape[1]
    avg_degree = num_edges / n_samples

    print(f"‚úÖ KNN graph built: {num_edges:,} edges, avg degree: {avg_degree:.2f}")

    return edge_index

# Banner printed once every definition in this cell has executed.
print("‚úÖ Graph builder loaded")

## 5️⃣ Model Definition (model.py)

In [None]:
"""GraphSAGE model for flow classification."""

import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class FlowGraphSAGE(torch.nn.Module):
    """GraphSAGE encoder with a single-logit head for binary flow classification.

    Each layer applies SAGEConv, then BatchNorm1d, ReLU and dropout. The
    classifier maps the final hidden representation to one logit per node.
    ``num_classes`` is accepted for configuration compatibility but is not
    used: the head always has one output.
    """

    def __init__(
        self,
        in_dim: int,
        hidden_dim: int = 128,
        num_classes: int = 2,
        num_layers: int = 2,
        dropout: float = 0.3,
    ):
        super().__init__()

        self.num_layers = num_layers
        self.dropout = dropout

        # The first convolution maps in_dim -> hidden_dim; every later one is
        # hidden_dim -> hidden_dim. At least one convolution is always built.
        conv_in_widths = [in_dim] + [hidden_dim] * max(num_layers - 1, 0)
        self.convs = torch.nn.ModuleList(
            [SAGEConv(width, hidden_dim) for width in conv_in_widths]
        )

        # One batch-norm per configured layer.
        self.bns = torch.nn.ModuleList(
            [torch.nn.BatchNorm1d(hidden_dim) for _ in range(num_layers)]
        )

        # Single output logit for binary classification.
        self.classifier = torch.nn.Linear(hidden_dim, 1)

        print(f"‚úÖ FlowGraphSAGE: {in_dim}‚Üí{hidden_dim}x{num_layers}‚Üí1 (binary)")

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Return one raw logit per node (no sigmoid applied), shape [N]."""
        hidden = self.get_embeddings(x, edge_index)
        logits = self.classifier(hidden)  # [N, 1]
        return logits.squeeze(-1)

    def get_embeddings(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Return the node embeddings that feed the classifier."""
        for layer_idx, conv in enumerate(self.convs):
            x = F.relu(self.bns[layer_idx](conv(x, edge_index)))
            # Dropout is only active while the module is in training mode.
            x = F.dropout(x, p=self.dropout, training=self.training)
        return x

# Banner printed once every definition in this cell has executed.
print("‚úÖ Model definition loaded")

## 6️⃣ Training Functions (train.py)

In [None]:
"""Training functions for Flow-based GNN."""

import torch.nn as nn
from torch_geometric.data import Data
from pathlib import Path
import time
from tqdm import tqdm
from sklearn.metrics import classification_report, f1_score, precision_score, recall_score


class RandomNodeSampler:
    """Yield mini-batches of node indices drawn from a boolean mask.

    Only the nodes that contribute to the loss are batched; no neighbour
    sampling is done here (that would require pyg-lib/torch-sparse), so
    message passing is expected to run on the full graph.
    """

    def __init__(self, mask: torch.Tensor, batch_size: int, shuffle: bool = True):
        # Positions where the mask is True.
        self.node_indices = mask.nonzero(as_tuple=True)[0]
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        order = self.node_indices.clone()
        if self.shuffle:
            # A fresh permutation is drawn on every pass.
            order = order[torch.randperm(len(order))]

        step = self.batch_size
        yield from (order[start:start + step] for start in range(0, len(order), step))

    def __len__(self):
        # Number of batches per pass: ceiling of node count / batch size.
        node_count = len(self.node_indices)
        return (node_count + self.batch_size - 1) // self.batch_size


def train_flow_gnn(
    x_tensor: torch.Tensor,
    y_tensor: torch.Tensor,
    edge_index: torch.Tensor,
    train_mask: torch.Tensor,
    val_mask: torch.Tensor,
    test_mask: torch.Tensor,
    config: dict,
    device: torch.device
) -> Dict:
    """Train, tune and evaluate the Flow-based GNN.

    Every optimisation step runs a forward pass over the full graph and
    computes the loss on a mini-batch of training nodes only, so neither
    pyg-lib nor torch-sparse is required. The weights with the best
    validation F1 are written to ``<output_dir>/best_model.pt`` and restored
    before the decision threshold is tuned on the validation nodes and the
    test nodes are scored. Plots and reports are written to ``output_dir``.

    Args:
        x_tensor: Node feature matrix.
        y_tensor: Node labels (0 = benign, 1 = attack).
        edge_index: Graph connectivity, shape [2, E].
        train_mask: Boolean mask selecting training nodes.
        val_mask: Boolean mask selecting validation nodes.
        test_mask: Boolean mask selecting test nodes.
        config: Nested dict with ``model`` (``hidden_dim``, ``num_classes``,
            ``num_layers``, ``dropout``), ``training`` (``epochs``,
            ``batch_size``, ``learning_rate`` and optional ``weight_decay``,
            ``patience``, ``min_delta``) and an optional ``output_dir``.
        device: Device on which the model and graph are placed.

    Returns:
        Test-set metrics dictionary (including the latency entries added by
        ``save_metrics_report``).

    Raises:
        ValueError: If ``train_mask`` selects no nodes.
    """

    if not bool(train_mask.any()):
        raise ValueError("train_mask selects no nodes; nothing to train on")

    print("\n" + "="*80)
    print("üöÄ TRAINING FLOW-BASED GNN")
    print("="*80)

    # Create PyG Data and move to device
    data = Data(
        x=x_tensor,
        edge_index=edge_index,
        y=y_tensor,
        train_mask=train_mask,
        val_mask=val_mask,
        test_mask=test_mask
    ).to(device)

    # Calculate pos_weight from TRAINING set only
    y_train = y_tensor[train_mask]
    pos = (y_train == 1).sum().item()
    neg = (y_train == 0).sum().item()
    pos_weight = neg / pos if pos > 0 else 1.0
    # Guard the percentage display against labels that are neither 0 nor 1.
    n_labeled = max(neg + pos, 1)

    print(f"\nüìä Dataset Statistics:")
    print(f"   Training samples: {train_mask.sum().item():,}")
    print(f"   Validation samples: {val_mask.sum().item():,}")
    print(f"   Test samples: {test_mask.sum().item():,}")
    print(f"   Class distribution (train): Benign={neg:,} ({neg/n_labeled*100:.1f}%), Attack={pos:,} ({pos/n_labeled*100:.1f}%)")
    print(f"   Positive weight (for loss): {pos_weight:.4f}")

    # Model
    print(f"\nüèóÔ∏è  Building Model:")
    model = FlowGraphSAGE(
        in_dim=x_tensor.shape[1],
        hidden_dim=config['model']['hidden_dim'],
        num_classes=config['model']['num_classes'],
        num_layers=config['model']['num_layers'],
        dropout=config['model']['dropout']
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"   Total parameters: {total_params:,}")
    print(f"   Trainable parameters: {trainable_params:,}")
    print(f"   Device: {device}")

    # Optimizer & Loss (BCEWithLogitsLoss with pos_weight)
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=config['training']['learning_rate'],
        weight_decay=config['training'].get('weight_decay', 0)
    )
    criterion = nn.BCEWithLogitsLoss(
        pos_weight=torch.tensor([pos_weight], device=device)
    )

    print(f"\n‚öôÔ∏è  Training Configuration:")
    print(f"   Epochs: {config['training']['epochs']}")
    print(f"   Batch size: {config['training']['batch_size']}")
    print(f"   Learning rate: {config['training']['learning_rate']}")
    print(f"   Weight decay: {config['training'].get('weight_decay', 0)}")
    print(f"   Early stopping patience: {config['training'].get('patience', 10)}")
    print(f"   Mode: Full-graph message passing with mini-batch loss")

    # Random node sampler for training
    train_sampler = RandomNodeSampler(
        train_mask,
        batch_size=config['training']['batch_size'],
        shuffle=True
    )

    # Training loop with history tracking
    early_stopping = EarlyStopping(
        patience=config['training'].get('patience', 10),
        min_delta=config['training'].get('min_delta', 0.001)
    )

    output_dir = Path(config.get('output_dir', 'output/flow_gnn'))
    save_path = output_dir / 'best_model.pt'

    best_f1 = 0.0
    best_epoch = None
    best_state = None  # in-memory copy of the best weights seen so far
    history = {
        'train_loss': [],
        'val_loss': [],
        'val_f1': [],
        'val_accuracy': []
    }

    print(f"\nüî• Starting Training...")
    print("-" * 80)

    # Progress bar for epochs
    epoch_pbar = tqdm(range(1, config['training']['epochs'] + 1), 
                      desc="Training", unit="epoch", ncols=120)

    for epoch in epoch_pbar:
        # Train - full graph forward pass, mini-batch loss
        model.train()
        total_loss = 0
        num_batches = 0

        for batch_nodes in train_sampler:
            batch_nodes = batch_nodes.to(device)
            optimizer.zero_grad()

            # Full graph forward pass
            logits = model(data.x, data.edge_index)

            # Compute loss only on batch nodes
            loss = criterion(logits[batch_nodes], data.y[batch_nodes].float())
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1

        train_loss = total_loss / max(num_batches, 1)

        # Validate - full graph inference. Use the mask stored on `data` so it
        # already lives on the same device as the logits.
        val_loss, val_metrics = evaluate(model, data, data.val_mask, criterion, device)

        # Update history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_f1'].append(val_metrics['f1'])
        history['val_accuracy'].append(val_metrics['accuracy'])

        # Update progress bar
        epoch_pbar.set_postfix({
            'train_loss': f"{train_loss:.4f}",
            'val_loss': f"{val_loss:.4f}",
            'val_f1': f"{val_metrics['f1']:.4f}",
            'val_acc': f"{val_metrics['accuracy']:.4f}"
        })

        # Log important epochs
        if epoch % 10 == 0 or epoch == 1:
            print(
                f"   Epoch {epoch:3d}/{config['training']['epochs']} | "
                f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | "
                f"Val F1: {val_metrics['f1']:.4f} | Val Acc: {val_metrics['accuracy']:.4f}"
            )

        # Save best model (to disk and as an in-memory snapshot)
        if val_metrics['f1'] > best_f1:
            best_f1 = val_metrics['f1']
            best_epoch = epoch
            # state_dict() returns references to the live parameters, so clone
            # them; otherwise later optimiser steps would overwrite the snapshot.
            best_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            save_path.parent.mkdir(parents=True, exist_ok=True)
            torch.save({
                'epoch': epoch,
                'model_state_dict': best_state,
                'optimizer_state_dict': optimizer.state_dict(),
                'metrics': val_metrics,
                'config': config
            }, save_path)

        # Early stopping
        if early_stopping(val_metrics['f1']):
            print(f"\n‚ö†Ô∏è  Early stopping triggered at epoch {epoch}")
            break

    epoch_pbar.close()

    print(f"\n‚úÖ Training completed!")
    print(f"   Best validation F1: {best_f1:.4f}")

    # Restore the best weights. Without this, tuning and testing would use the
    # last-epoch model, which after early stopping has gone `patience` epochs
    # without improving.
    if best_state is not None:
        model.load_state_dict(best_state)
        print(f"   Restored best weights from epoch {best_epoch}")

    # Tune threshold on validation set
    print("\n" + "="*80)
    print("üéØ TUNING DECISION THRESHOLD")
    print("="*80)
    best_threshold = tune_threshold(model, data, data.val_mask, device)
    print(f"‚úÖ Optimal threshold: {best_threshold:.4f}")

    # Test with tuned threshold
    print("\n" + "="*80)
    print("üß™ FINAL EVALUATION ON TEST SET")
    print("="*80)

    start_time = time.time()
    test_loss, test_metrics, y_true, y_pred, y_probs = evaluate_with_predictions(
        model, data, data.test_mask, criterion, best_threshold, device
    )
    inference_time = time.time() - start_time
    n_test = len(y_true)
    latency_per_sample = inference_time / max(n_test, 1)
    throughput = n_test / inference_time if inference_time > 0 else float('inf')

    print(f"\nüìà Test Results:")
    print(f"   Accuracy:  {test_metrics['accuracy']:.4f}")
    print(f"   Precision: {test_metrics['precision']:.4f}")
    print(f"   Recall:    {test_metrics['recall']:.4f}")
    print(f"   F1 Score:  {test_metrics['f1']:.4f}")
    if 'auc' in test_metrics:
        print(f"   AUC:       {test_metrics['auc']:.4f}")
    print(f"   FAR:       {test_metrics['far']:.4f}")
    print(f"\n‚è±Ô∏è  Inference Performance:")
    print(f"   Total time: {inference_time:.2f}s")
    print(f"   Latency per sample: {latency_per_sample*1000:.4f}ms")
    print(f"   Throughput: {throughput:.2f} samples/sec")

    # Save visualizations and reports
    print(f"\nüíæ Saving results to {output_dir}/")

    save_metrics_plots(y_true, y_pred, y_probs, test_metrics, 
                      str(output_dir), history=history)
    save_metrics_report(test_metrics, str(output_dir), 
                       y_true, y_pred, latency=latency_per_sample)

    # Classification report. labels=[0, 1] keeps it aligned with the two
    # target names even if one class is absent from the test predictions.
    print("\nüìä Detailed Classification Report:")
    print(classification_report(y_true, y_pred, labels=[0, 1],
                                target_names=['Benign', 'Attack'],
                                digits=4, zero_division=0))

    print("\n" + "="*80)
    print("‚ú® ALL DONE!")
    print("="*80 + "\n")

    return test_metrics


def tune_threshold(model, data, mask, device):
    """Pick the probability cut-off that maximises F1 on the masked nodes.

    Runs one full-graph forward pass (no neighbor sampling), then scans 99
    evenly spaced thresholds from 0.01 to 0.99. If no threshold yields an F1
    above 0, the default 0.5 is returned. ``device`` is accepted but not used
    here.

    Returns:
        The best threshold found.
    """
    model.eval()

    with torch.no_grad():
        all_logits = model(data.x, data.edge_index)
        masked_logits = all_logits[mask]
        masked_targets = data.y[mask]

    # Logits -> probabilities, moved to NumPy for sklearn.
    probabilities = torch.sigmoid(masked_logits).cpu().numpy()
    targets_np = masked_targets.cpu().numpy()

    winner = {"threshold": 0.5, "f1": 0.0, "precision": 0.0, "recall": 0.0}

    candidates = np.linspace(0.01, 0.99, 99)
    for cutoff in tqdm(candidates, desc="Searching threshold", ncols=100, leave=False):
        hard_preds = (probabilities >= cutoff).astype(int)
        score = f1_score(targets_np, hard_preds, zero_division=0)

        # Strict improvement only, so the lowest threshold wins a tie.
        if not score > winner["f1"]:
            continue

        winner["f1"] = score
        winner["threshold"] = cutoff
        winner["precision"] = precision_score(targets_np, hard_preds, zero_division=0)
        winner["recall"] = recall_score(targets_np, hard_preds, zero_division=0)

    print(f"   Threshold: {winner['threshold']:.4f}")
    print(f"   Precision: {winner['precision']:.4f}")
    print(f"   Recall:    {winner['recall']:.4f}")
    print(f"   F1 Score:  {winner['f1']:.4f}")

    return winner["threshold"]


def evaluate(model, data, mask, criterion, device, threshold=0.5):
    """Evaluate the model on the masked nodes using full-graph inference.

    Thin wrapper around ``evaluate_with_predictions`` so the loss and metric
    computation lives in one place; only the prediction arrays are dropped.

    Args:
        model: Model called as ``model(data.x, data.edge_index)``.
        data: Graph object providing ``x``, ``edge_index`` and ``y``.
        mask: Selector for the nodes to score.
        criterion: Loss applied to the selected logits and float targets.
        device: Forwarded unchanged to ``evaluate_with_predictions``.
        threshold: Probability cut-off for the positive class.

    Returns:
        Tuple ``(loss, metrics)``.
    """
    loss, metrics, _, _, _ = evaluate_with_predictions(
        model, data, mask, criterion, threshold, device
    )
    return loss, metrics


def evaluate_with_predictions(model, data, mask, criterion, threshold, device):
    """Score the masked nodes with full-graph inference and return the raw outputs.

    ``device`` is accepted but not used inside this function.

    Returns:
        Tuple ``(loss, metrics, true, pred, probs)`` where the last three are
        NumPy arrays of labels, thresholded predictions and sigmoid
        probabilities for the masked nodes.
    """
    model.eval()

    with torch.no_grad():
        # The forward pass covers the whole graph; only masked nodes are kept.
        selected_logits = model(data.x, data.edge_index)[mask]
        selected_labels = data.y[mask]
        loss_value = criterion(selected_logits, selected_labels.float()).item()

    # Logits -> probabilities -> hard labels at the given threshold.
    scores = torch.sigmoid(selected_logits).cpu().numpy()
    hard_labels = (scores >= threshold).astype(int)
    ground_truth = selected_labels.cpu().numpy()

    summary = compute_metrics(ground_truth, hard_labels, y_probs=scores)

    return loss_value, summary, ground_truth, hard_labels, scores

# Banner printed once every definition in this cell has executed.
print("‚úÖ Training functions loaded")

## 7️⃣ Load Data
Thay đổi path nếu cần

In [None]:
# Load data (adjust DATA_DIR as needed)
print("üìÇ Loading data from dataset-processed...")

# Single place to change when the dataset is mounted elsewhere.
DATA_DIR = Path('/kaggle/input/cicids-cleaned/dataset-processed')


def _print_dataset_summary(header, features, labels, train_idx, val_idx, test_idx):
    """Print array shapes, split sizes and the share of label-1 samples."""
    print(header)
    print(f"   X shape: {features.shape}")
    print(f"   y shape: {labels.shape}")
    print(f"   Train: {len(train_idx):,} samples")
    print(f"   Val:   {len(val_idx):,} samples")
    print(f"   Test:  {len(test_idx):,} samples")
    print(f"   Attack ratio: {(labels==1).sum()/len(labels):.2%}")


X = np.load(DATA_DIR / 'X.npy')
y = np.load(DATA_DIR / 'y.npy')
idx_train = np.load(DATA_DIR / 'idx_train.npy')
idx_val = np.load(DATA_DIR / 'idx_val.npy')
idx_test = np.load(DATA_DIR / 'idx_test.npy')

_print_dataset_summary(f"‚úÖ Data loaded:", X, y, idx_train, idx_val, idx_test)

# OPTIONAL: Limit dataset size for faster training on Kaggle
MAX_SAMPLES = None  # Set to None to use full dataset, or 2_000_000 for 2M samples

if MAX_SAMPLES is not None and len(X) > MAX_SAMPLES:
    print(f"\n‚ö†Ô∏è  Dataset has {len(X):,} samples, limiting to {MAX_SAMPLES:,}...")

    # Strategy: sample from each split in proportion to its share of X, so
    # the train/val/test ratio is kept; test receives the rounding remainder.
    train_limit = int(MAX_SAMPLES * len(idx_train) / len(X))
    val_limit = int(MAX_SAMPLES * len(idx_val) / len(X))
    test_limit = MAX_SAMPLES - train_limit - val_limit

    # Sample indices (fixed seed for reproducibility)
    np.random.seed(42)
    idx_train_sampled = np.random.choice(idx_train, size=min(train_limit, len(idx_train)), replace=False)
    idx_val_sampled = np.random.choice(idx_val, size=min(val_limit, len(idx_val)), replace=False)
    idx_test_sampled = np.random.choice(idx_test, size=min(test_limit, len(idx_test)), replace=False)

    # Combine all selected indices, sorted so rows keep their original order
    all_selected_idx = np.concatenate([idx_train_sampled, idx_val_sampled, idx_test_sampled])
    all_selected_idx.sort()

    # Filter data
    X = X[all_selected_idx]
    y = y[all_selected_idx]

    # Remap old row ids to positions in the filtered arrays. all_selected_idx
    # is sorted, so a binary search gives each id's new position without
    # building a multi-million-entry Python dict. (Assumes the three splits
    # are disjoint, as train/val/test indices should be.)
    idx_train = np.searchsorted(all_selected_idx, idx_train_sampled)
    idx_val = np.searchsorted(all_selected_idx, idx_val_sampled)
    idx_test = np.searchsorted(all_selected_idx, idx_test_sampled)

    _print_dataset_summary(f"‚úÖ Data limited:", X, y, idx_train, idx_val, idx_test)
else:
    print(f"\n‚úÖ Using full dataset: {len(X):,} samples")

## 8️⃣ Build KNN Graph

In [None]:
print("üî® Building KNN graph on FULL dataset...")
edge_index = build_knn_graph(X, k=10)

## 9️⃣ Prepare Data & Configuration

In [None]:
# NumPy arrays -> tensors (float features, integer labels)
x_tensor = torch.from_numpy(X).float()
y_tensor = torch.from_numpy(y).long()

# Ensure edge_index is stored contiguously
edge_index = edge_index.contiguous()

# One boolean mask per split, initially all False, then switched on at the
# split's node indices
train_mask, val_mask, test_mask = (
    torch.zeros(len(y), dtype=torch.bool) for _ in range(3)
)
for split_mask, split_idx in (
    (train_mask, idx_train),
    (val_mask, idx_val),
    (test_mask, idx_test),
):
    split_mask[split_idx] = True

print(f"‚úÖ Tensors created")
print(f"   x_tensor: {x_tensor.shape}")
print(f"   y_tensor: {y_tensor.shape}")
print(f"   edge_index: {edge_index.shape}")

# Configuration - matching flow_gnn/config.yaml
config = dict(
    model=dict(
        hidden_dim=128,
        num_classes=2,
        num_layers=2,
        dropout=0.3,
    ),
    training=dict(
        epochs=50,
        batch_size=512,
        learning_rate=0.001,
        weight_decay=0.0001,
        patience=10,
        min_delta=0.001,
    ),
    output_dir='output',
)

print("\n‚öôÔ∏è  Configuration:")
print(f"   Model: {config['model']}")
print(f"   Training: {config['training']}")

# Pick CUDA when available, otherwise the CPU
device = get_device("auto")
print(f"\nüñ•Ô∏è  Using device: {device}")

## 🔟 Train Model

In [None]:
# Run the full pipeline: training, threshold tuning, test evaluation and
# report generation. The returned dict holds the test-set metrics.
metrics = train_flow_gnn(
    x_tensor=x_tensor,
    y_tensor=y_tensor,
    edge_index=edge_index,
    train_mask=train_mask,
    val_mask=val_mask,
    test_mask=test_mask,
    config=config,
    device=device,
)

print("\n‚úÖ Training completed!")

## 🎉 Done!
Model đã được train xong. Các metrics đã được in ra ở trên.

### Metrics bao gồm:
- **Accuracy**: Độ chính xác tổng thể
- **Precision**: Độ chính xác khi dự đoán attack
- **Recall / Detection Rate**: Tỷ lệ phát hiện attack
- **F1 Score**: Điểm cân bằng giữa precision và recall
- **FAR (False Alarm Rate)**: Tỷ lệ cảnh báo nhầm