In [6]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv, GraphSAGE, TransformerConv
from torch_geometric.loader import NeighborLoader
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler, RobustScaler
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, precision_recall_curve, f1_score
from sklearn.utils.class_weight import compute_class_weight
from imblearn.over_sampling import SMOTE, ADASYN
import networkx as nx
import warnings
import pickle
import os
import json
from collections import defaultdict
warnings.filterwarnings('ignore')

# Mount Google Drive (if using Colab)
try:
    CHECKPOINT_DIR = 'Ransomware_Checkpoints_v2'
except:
    CHECKPOINT_DIR = './checkpoints'

os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# ============================================================================
# CRITICAL FIX 1: FOCAL LOSS FOR EXTREME IMBALANCE
# ============================================================================
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, weight=None, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.weight = weight
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, weight=self.weight, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1-pt)**self.gamma * ce_loss
        
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

# ============================================================================
# CRITICAL FIX 2: LEAKAGE-FREE EVALUATION
# ============================================================================
def evaluate_without_leakage(model, data, test_indices, device, threshold=0.5):
    """
    Evaluate model on test set WITHOUT data leakage
    Uses only test nodes without neighbor sampling
    """
    model.eval()
    with torch.no_grad():
        # Create a mapping from original indices to test subgraph indices
        test_nodes = torch.tensor(test_indices, device=device)
        
        # Get the subgraph containing test nodes and their direct neighbors
        # This is necessary because GNN needs neighbor information
        all_edges = data.edge_index
        
        # Find all neighbors of test nodes
        test_neighbors = set()
        for node in test_indices:
            # Find edges where this node appears
            neighbors = all_edges[1][all_edges[0] == node].cpu().numpy()
            test_neighbors.update(neighbors)
            neighbors = all_edges[0][all_edges[1] == node].cpu().numpy()
            test_neighbors.update(neighbors)
        
        # Combine test nodes with their neighbors
        subgraph_nodes = list(set(test_indices) | test_neighbors)
        subgraph_nodes = torch.tensor(subgraph_nodes, device=device)
        
        # Create node mapping
        node_mapping = {old_idx.item(): new_idx for new_idx, old_idx in enumerate(subgraph_nodes)}
        
        # Extract subgraph
        subgraph_x = data.x[subgraph_nodes]
        
        # Filter edges to only include those within the subgraph
        mask = torch.isin(all_edges[0], subgraph_nodes) & torch.isin(all_edges[1], subgraph_nodes)
        subgraph_edges = all_edges[:, mask]
        
        # Remap edge indices
        subgraph_edges_remapped = torch.zeros_like(subgraph_edges)
        for i in range(subgraph_edges.shape[1]):
            subgraph_edges_remapped[0, i] = node_mapping[subgraph_edges[0, i].item()]
            subgraph_edges_remapped[1, i] = node_mapping[subgraph_edges[1, i].item()]
        
        # Forward pass on subgraph
        out = model(subgraph_x, subgraph_edges_remapped)
        
        # Extract predictions only for original test nodes
        test_node_indices_in_subgraph = [node_mapping[idx] for idx in test_indices if idx in node_mapping]
        test_out = out[test_node_indices_in_subgraph]
        
        # Get predictions and probabilities
        pred_proba = F.softmax(test_out, dim=1)[:, 1]
        pred_binary = (pred_proba >= threshold).long()
        pred_argmax = test_out.argmax(dim=1)
        
        # Get true labels for test nodes
        true_labels = data.y[test_indices]
        
    return {
        'predictions': pred_binary.cpu().numpy(),
        'predictions_argmax': pred_argmax.cpu().numpy(),
        'probabilities': pred_proba.cpu().numpy(),
        'true_labels': true_labels.cpu().numpy()
    }

# ============================================================================
# ADVANCED FEATURE ENGINEERING: GRAPH CENTRALITY MEASURES
# ============================================================================
def compute_graph_centrality_features(edge_index, num_nodes):
    """Compute centrality measures for graph nodes"""
    print("Computing graph centrality features...")
    
    # Convert to NetworkX graph
    edge_list = edge_index.t().cpu().numpy()
    G = nx.Graph()
    G.add_nodes_from(range(num_nodes))
    G.add_edges_from(edge_list)
    
    # Compute centrality measures
    print("  - Computing PageRank...")
    pagerank = nx.pagerank(G, max_iter=50)
    
    print("  - Computing Betweenness Centrality...")
    betweenness = nx.betweenness_centrality(G, k=min(1000, num_nodes))
    
    # print("  - Computing Closeness Centrality...")
    # closeness = nx.closeness_centrality(G)
    
    print("  - Computing Degree Centrality...")
    degree_centrality = nx.degree_centrality(G)
    
    # Convert to arrays
    pagerank_array = np.array([pagerank.get(i, 0) for i in range(num_nodes)])
    betweenness_array = np.array([betweenness.get(i, 0) for i in range(num_nodes)])
    # closeness_array = np.array([closeness.get(i, 0) for i in range(num_nodes)])
    degree_array = np.array([degree_centrality.get(i, 0) for i in range(num_nodes)])
    
    return {
        'pagerank': pagerank_array,
        'betweenness': betweenness_array,
        # 'closeness': closeness_array,
        'degree_centrality': degree_array
    }

# ============================================================================
# ADVANCED ARCHITECTURE: MULTI-LAYER GNN WITH ATTENTION FUSION
# ============================================================================
class AdvancedGNN(nn.Module):
    def __init__(self, num_features, hidden_dim=128, num_classes=2, dropout=0.3):
        super(AdvancedGNN, self).__init__()
        
        # --- FIX: Define the input dimension for the second and third layers ---
        # This is the dimension after concatenating the first GAT and GCN outputs.
        layer2_input_dim = (hidden_dim * 4) + hidden_dim  # 512 + 128 = 640
        
        # --- FIX: Define the input dimension for the final layer ---
        # This is the dimension after concatenating the second GAT and GCN outputs.
        layer3_input_dim = (hidden_dim * 4) + hidden_dim  # 512 + 128 = 640

        # Multiple GNN Types
        self.gat_layers = nn.ModuleList([
            GATConv(num_features, hidden_dim, heads=4, dropout=dropout),
            # --- FIX: Correct the input dimension ---
            GATConv(layer2_input_dim, hidden_dim, heads=4, dropout=dropout),
            # --- FIX: Correct the input dimension ---
            GATConv(layer3_input_dim, hidden_dim // 2, heads=2, dropout=dropout)
        ])
        
        self.gcn_layers = nn.ModuleList([
            GCNConv(num_features, hidden_dim),
            # --- FIX: Correct the input dimension ---
            GCNConv(layer2_input_dim, hidden_dim),
            # --- FIX: Correct the input dimension ---
            GCNConv(layer3_input_dim, hidden_dim // 2)
        ])
        
        # Batch normalization layers
        self.bn_layers = nn.ModuleList([
            nn.BatchNorm1d(hidden_dim * 4 + hidden_dim),  # GAT(512) + GCN(128) = 640
            nn.BatchNorm1d(hidden_dim * 4 + hidden_dim),  # GAT(512) + GCN(128) = 640
            nn.BatchNorm1d((hidden_dim // 2 * 2) + (hidden_dim // 2)) # GAT(128) + GCN(64) = 192
        ])
        
        # Residual connections
        self.residual_projections = nn.ModuleList([
            nn.Linear(num_features, layer2_input_dim),
            # --- FIX: Correct the input dimension for the projection ---
            nn.Linear(layer2_input_dim, layer3_input_dim),
            # --- FIX: Correct the input dimension for the projection ---
            nn.Linear(layer3_input_dim, (hidden_dim // 2 * 2) + (hidden_dim // 2))
        ])
        
        # Feature-only path (XGBoost mimic)
        self.feature_path = nn.Sequential(
            nn.Linear(num_features, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.Dropout(dropout)
        )
        
        # Combined classifier
        # --- FIX: Adjust combined_dim based on the final layer's output ---
        final_graph_dim = (hidden_dim // 2 * 2) + (hidden_dim // 2) # 192
        feature_only_dim = hidden_dim // 2 # 64
        combined_dim = final_graph_dim + feature_only_dim # 192 + 64 = 256
        self.classifier = nn.Sequential(
            nn.Linear(combined_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.Dropout(dropout // 2),
            nn.Linear(hidden_dim // 2, 64),
            nn.ReLU(),
            nn.Dropout(dropout // 2),
            nn.Linear(64, num_classes)
        )
        
        self.dropout = dropout
        
    def forward(self, x, edge_index, batch=None):
        # Store original features for the feature-only path
        original_x = x
        current_x = x
        
        for i, (gat_layer, gcn_layer) in enumerate(zip(self.gat_layers, self.gcn_layers)):
            # --- FIX: Store the input of the current block for the residual connection ---
            input_for_residual = current_x
            
            # GAT path
            h_gat = F.elu(gat_layer(current_x, edge_index))
            
            # GCN path
            h_gcn = F.elu(gcn_layer(current_x, edge_index))
            
            # Concatenate different GNN outputs
            h_combined = torch.cat([h_gat, h_gcn], dim=1)
            
            # Residual connection
            residual = self.residual_projections[i](input_for_residual)
            
            # Add residual connection
            h_combined = h_combined + residual

            # Batch normalization
            h_combined = self.bn_layers[i](h_combined)
            
            # Dropout
            current_x = F.dropout(h_combined, p=self.dropout, training=self.training)
        
        # Graph features are the final output of the GNN blocks
        graph_features = current_x
        
        # Feature-only path
        feature_only = self.feature_path(original_x)
        
        # Combine all paths
        combined = torch.cat([graph_features, feature_only], dim=1)
        
        # Final classification
        out = self.classifier(combined)
        
        return out
    
# ============================================================================
# CHECKPOINT SYSTEM (KEEPING YOUR EXISTING SYSTEM)
# ============================================================================
def save_checkpoint(checkpoint_name, data, message="Checkpoint saved"):
    """Save checkpoint data"""
    checkpoint_path = os.path.join(CHECKPOINT_DIR, f"{checkpoint_name}.pkl")
    with open(checkpoint_path, 'wb') as f:
        pickle.dump(data, f)
    print(f"‚úÖ {message}: {checkpoint_path}")

def load_checkpoint(checkpoint_name):
    """Load checkpoint data"""
    checkpoint_path = os.path.join(CHECKPOINT_DIR, f"{checkpoint_name}.pkl")
    if os.path.exists(checkpoint_path):
        with open(checkpoint_path, 'rb') as f:
            return pickle.load(f)
    return None

def checkpoint_exists(checkpoint_name):
    """Check if checkpoint exists"""
    checkpoint_path = os.path.join(CHECKPOINT_DIR, f"{checkpoint_name}.pkl")
    return os.path.exists(checkpoint_path)

# ============================================================================
# MAIN INTEGRATION FUNCTION
# ============================================================================
def main_advanced_gnn_training():
    """Complete advanced GNN training with your existing structure"""
    
    print("üöÄ ADVANCED GNN RANSOMWARE DETECTION WITH LEAKAGE FIXES")
    print("="*70)
    
    # STEP 1: Data Loading (keeping your existing logic)
    if checkpoint_exists('data_loaded'):
        print("[RESUME] Loading data from checkpoint...")
        checkpoint_data = load_checkpoint('data_loaded')
        wallets_df = checkpoint_data['wallets_df']
        edges_df = checkpoint_data['edges_df']
        print(f"[SUCCESS] Resumed with {len(wallets_df)} wallet records and {len(edges_df)} edges")
    else:
        print("[INFO] Loading wallet features and classes...")
        wallets_df = pd.read_csv('Elliptic++ Dataset/wallets_features_classes_combined.csv')
        print(f"[SUCCESS] Loaded {len(wallets_df)} wallet records")

        print("[INFO] Loading address-to-address edges...")
        edges_df = pd.read_csv('Elliptic++ Dataset/AddrAddr_edgelist.csv')
        print(f"[SUCCESS] Loaded {len(edges_df)} edges")

        save_checkpoint('data_loaded', {
            'wallets_df': wallets_df,
            'edges_df': edges_df
        }, "Data loading checkpoint saved")

    # STEP 2: Feature Engineering (using your enhanced features)
    def create_enhanced_pattern_features(df):
        """Create the EXACT same enhanced features as XGBoost model"""
        numeric_columns = df.select_dtypes(include=[np.number]).columns
        df[numeric_columns] = df[numeric_columns].fillna(0)

        # Your existing feature engineering
        df['partner_transaction_ratio'] = (
            df.get('transacted_w_address_total', 0) /
            (df.get('total_txs', 1) + 1e-8)
        )
        df['activity_density'] = (
            df.get('total_txs', 0) /
            (df.get('lifetime_in_blocks', 1) + 1e-8)
        )
        df['transaction_size_variance'] = (
            df.get('btc_transacted_max', 0) - df.get('btc_transacted_min', 0)
        ) / (df.get('btc_transacted_mean', 1) + 1e-8)
        df['flow_imbalance'] = (
            (df.get('btc_sent_total', 0) - df.get('btc_received_total', 0)) /
            (df.get('btc_transacted_total', 1) + 1e-8)
        )
        df['temporal_spread'] = (
            df.get('last_block_appeared_in', 0) - df.get('first_block_appeared_in', 0)
        ) / (df.get('num_timesteps_appeared_in', 1) + 1e-8)
        df['fee_percentile'] = (
            df.get('fees_total', 0) /
            (df.get('btc_transacted_total', 1) + 1e-8)
        )
        df['interaction_intensity'] = (
            df.get('num_addr_transacted_multiple', 0) /
            (df.get('transacted_w_address_total', 1) + 1e-8)
        )
        df['value_per_transaction'] = (
            df.get('btc_transacted_total', 0) /
            (df.get('total_txs', 1) + 1e-8)
        )
        df['burst_activity'] = (
            df.get('total_txs', 0) * df.get('activity_density', 0)
        )
        df['mixing_intensity'] = (
            df.get('partner_transaction_ratio', 0) * df.get('interaction_intensity', 0)
        )
        return df

    if checkpoint_exists('features_engineered'):
        print("[RESUME] Loading feature engineering from checkpoint...")
        checkpoint_data = load_checkpoint('features_engineered')
        wallets_df = checkpoint_data['wallets_df']
        print(f"[SUCCESS] Resumed with {len(wallets_df.columns)} columns")
    else:
        print("[INFO] Creating enhanced features...")
        wallets_df = create_enhanced_pattern_features(wallets_df)
        print(f"[SUCCESS] Enhanced features created. Now have {len(wallets_df.columns)} columns")
        save_checkpoint('features_engineered', {'wallets_df': wallets_df}, "Feature engineering checkpoint saved")

    # STEP 3: Data Cleaning (keeping your existing logic)
    if checkpoint_exists('data_cleaned'):
        print("[RESUME] Loading cleaned data from checkpoint...")
        checkpoint_data = load_checkpoint('data_cleaned')
        wallets_clean = checkpoint_data['wallets_clean']
        class_weight_dict = checkpoint_data['class_weight_dict']
        print(f"[SUCCESS] Resumed with {len(wallets_clean)} cleaned addresses")
    else:
        print("[INFO] Cleaning data...")
        wallets_clean = wallets_df[wallets_df['class'].isin([1, 2])].copy()
        wallets_clean['class'] = wallets_clean['class'].map({1: 1, 2: 0})
        
        class_weights = compute_class_weight('balanced', classes=np.unique(wallets_clean['class']), y=wallets_clean['class'])
        class_weight_dict = {0: class_weights[0], 1: class_weights[1]}
        
        save_checkpoint('data_cleaned', {
            'wallets_clean': wallets_clean,
            'class_weight_dict': class_weight_dict
        }, "Data cleaning checkpoint saved")

    # STEP 4: Graph Construction with Centrality Features
    if checkpoint_exists('graph_built_advanced'):
        print("[RESUME] Loading advanced graph from checkpoint...")
        checkpoint_data = load_checkpoint('graph_built_advanced')
        data = checkpoint_data['data']
        addr_to_idx = checkpoint_data['addr_to_idx']
        feature_cols = checkpoint_data['feature_cols']
        scaler = checkpoint_data['scaler']
        print(f"[SUCCESS] Resumed with graph: {data.x.shape[0]} nodes, {data.edge_index.shape[1]} edges")
    else:
        print("[INFO] Building advanced graph with centrality features...")
        
        # Your existing graph construction
        unique_addresses = wallets_clean['address'].unique()
        addr_to_idx = {addr: idx for idx, addr in enumerate(unique_addresses)}
        
        exclude_cols = ['address', 'Time step', 'class']
        feature_cols = [col for col in wallets_clean.columns if col not in exclude_cols]
        
        features_list = []
        labels_list = []
        
        for addr in unique_addresses:
            addr_data = wallets_clean[wallets_clean['address'] == addr].iloc[0]
            features_list.append(addr_data[feature_cols].values)
            labels_list.append(addr_data['class'])
        
        X = np.array(features_list)
        y = np.array(labels_list)
        
        # Create edge list
        edge_list = []
        for _, row in edges_df.iterrows():
            input_addr = row['input_address']
            output_addr = row['output_address']
            
            if input_addr in addr_to_idx and output_addr in addr_to_idx:
                input_idx = addr_to_idx[input_addr]
                output_idx = addr_to_idx[output_addr]
                edge_list.extend([[input_idx, output_idx], [output_idx, input_idx]])
        
        edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
        
        # CRITICAL ADDITION: Compute centrality features
        print("[ADVANCED] Computing centrality features...")
        centrality_features = compute_graph_centrality_features(edge_index, len(unique_addresses))
        
        # Combine original features with centrality features
        centrality_matrix = np.column_stack([
            centrality_features['pagerank'],
            centrality_features['betweenness'],
            # centrality_features['closeness'],
            centrality_features['degree_centrality']
        ])
        
        X_enhanced = np.concatenate([X, centrality_matrix], axis=1)
        print(f"[SUCCESS] Enhanced features: {X.shape[1]} -> {X_enhanced.shape[1]}")
        
        # Scale features
        scaler = RobustScaler()
        X_scaled = scaler.fit_transform(X_enhanced)
        
        # Create PyTorch tensors
        x = torch.tensor(X_scaled, dtype=torch.float)
        y = torch.tensor(y, dtype=torch.long)
        
        data = Data(x=x, edge_index=edge_index, y=y)
        
        save_checkpoint('graph_built_advanced', {
            'data': data,
            'addr_to_idx': addr_to_idx,
            'feature_cols': feature_cols,
            'scaler': scaler
        }, "Advanced graph construction checkpoint saved")

    # STEP 5: CRITICAL - Proper Train/Test Split (NO DATA LEAKAGE)
    print("[CRITICAL FIX] Creating proper train/test split...")
    
    # Node-level split to prevent data leakage
    train_idx, test_idx = train_test_split(
        range(len(data.y)), 
        test_size=0.2, 
        random_state=42,
        stratify=data.y.numpy()
    )
    
    train_mask = torch.zeros(data.y.size(0), dtype=torch.bool)
    test_mask = torch.zeros(data.y.size(0), dtype=torch.bool)
    train_mask[train_idx] = True
    test_mask[test_idx] = True
    
    print(f"Train nodes: {train_mask.sum().item()}")
    print(f"Test nodes: {test_mask.sum().item()}")
    
    # STEP 6: Apply SMOTE to Training Data
    print("[ADVANCED] Applying SMOTE to training data...")
    X_train = data.x[train_mask].cpu().numpy()
    y_train = data.y[train_mask].cpu().numpy()
    
    # Apply SMOTE only if we have minority class samples
    if sum(y_train) > 0 and sum(y_train) < len(y_train):
        smote = SMOTE(random_state=42, k_neighbors=min(5, sum(y_train) - 1))
        X_train_smote, y_train_smote = smote.fit_resample(X_train, y_train)
        print(f"Original: {np.bincount(y_train)}")
        print(f"SMOTE: {np.bincount(y_train_smote)}")
        
        # Update training data in graph
        # Note: This is complex with graph data, so we'll use class weights instead
        print("[INFO] Using class weights due to graph structure complexity")
    
    # STEP 7: Initialize Advanced Model
    print("[ADVANCED] Initializing advanced GNN model...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"[INFO] Using device: {device}")
    
    model = AdvancedGNN(
        num_features=data.x.shape[1],
        hidden_dim=128,
        dropout=0.3
    ).to(device)
    
    data = data.to(device)
    
    # STEP 8: Advanced Loss and Optimizer
    print("[ADVANCED] Setting up Focal Loss and optimizer...")
    class_weights_tensor = torch.tensor([class_weight_dict[0], class_weight_dict[1]], dtype=torch.float).to(device)
    
    # Use Focal Loss instead of CrossEntropy
    criterion = FocalLoss(alpha=1, gamma=2, weight=class_weights_tensor)
    
    # Advanced optimizer with different learning rates
    optimizer = torch.optim.AdamW([
        {'params': model.gat_layers.parameters(), 'lr': 0.001},
        {'params': model.gcn_layers.parameters(), 'lr': 0.001},
        {'params': model.classifier.parameters(), 'lr': 0.002}
    ], weight_decay=1e-4)
    
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2)
    
    # STEP 9: Create NeighborLoader
    print("[INFO] Creating NeighborLoader...")
    train_loader = NeighborLoader(
        data,
        num_neighbors=[15, 10, 5],
        batch_size=512,
        input_nodes=train_mask,
        shuffle=True,
        num_workers=0
    )
    
    # STEP 10: Training Loop
    print("[TRAINING] Starting advanced training...")
    best_f1 = 0
    patience = 0
    max_patience = 50
    
    for epoch in range(500):
        model.train()
        total_loss = 0
        num_batches = 0
        
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            
            out = model(batch.x, batch.edge_index)
            loss = criterion(out, batch.y)
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            total_loss += loss.item()
            num_batches += 1
        
        scheduler.step()
        
        # Evaluation every 10 epochs
        if epoch % 10 == 0:
            # CRITICAL: Use leakage-free evaluation
            results = evaluate_without_leakage(model, data, test_idx, device, threshold=0.5)
            
            # Find optimal threshold
            if len(np.unique(results['true_labels'])) > 1:
                precision, recall, thresholds = precision_recall_curve(
                    results['true_labels'], results['probabilities']
                )
                f1_scores = 2 * (precision * recall) / (precision + recall + 1e-8)
                best_threshold_idx = np.argmax(f1_scores)
                optimal_threshold = thresholds[best_threshold_idx] if len(thresholds) > 0 else 0.5
                
                # Re-evaluate with optimal threshold
                results = evaluate_without_leakage(model, data, test_idx, device, threshold=optimal_threshold)
                
                test_f1 = f1_score(results['true_labels'], results['predictions'])
                test_acc = np.mean(results['predictions'] == results['true_labels'])
                
                print(f'Epoch {epoch:03d}, Loss: {total_loss/num_batches:.4f}, '
                      f'Test Acc: {test_acc:.4f}, F1: {test_f1:.4f}, '
                      f'Threshold: {optimal_threshold:.3f}')
                
                # Early stopping
                if test_f1 > best_f1:
                    best_f1 = test_f1
                    patience = 0
                    torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, 'best_advanced_gnn.pth'))
                else:
                    patience += 1
                    
                if patience >= max_patience:
                    print(f"Early stopping at epoch {epoch}")
                    break
    
    # STEP 11: Final Evaluation
    print("\n" + "="*70)
    print("üìä FINAL EVALUATION (NO DATA LEAKAGE)")
    print("="*70)
    
    # Load best model
    model.load_state_dict(torch.load(os.path.join(CHECKPOINT_DIR, 'best_advanced_gnn.pth')))
    
    # Final evaluation
    results = evaluate_without_leakage(model, data, test_idx, device, threshold=optimal_threshold)
    
    test_acc = np.mean(results['predictions'] == results['true_labels'])
    test_f1 = f1_score(results['true_labels'], results['predictions'])
    auc_score = roc_auc_score(results['true_labels'], results['probabilities'])
    
    print(f"Final Test Accuracy: {test_acc:.4f}")
    print(f"Final F1 Score: {test_f1:.4f}")
    print(f"Final AUC Score: {auc_score:.4f}")
    print(f"Optimal Threshold: {optimal_threshold:.3f}")
    
    print("\nClassification Report:")
    print(classification_report(results['true_labels'], results['predictions'], 
                              target_names=['Licit', 'Illicit'], zero_division=0))
    
    print("\nConfusion Matrix:")
    cm = confusion_matrix(results['true_labels'], results['predictions'])
    print(f"   Predicted:  Licit  Illicit")
    print(f"   Licit:      {cm[0,0]:5d}     {cm[0,1]:4d}")
    print(f"   Illicit:    {cm[1,0]:5d}     {cm[1,1]:4d}")
    
    # Save final model
    model_data = {
        'model_state_dict': model.state_dict(),
        'scaler': scaler,
        'addr_to_idx': addr_to_idx,
        'feature_cols': feature_cols,
        'optimal_threshold': optimal_threshold,
        'class_weights': class_weight_dict,
        'model_config': {
            'num_features': data.x.shape[1],
            'hidden_dim': 128,
            'num_classes': 2,
            'dropout': 0.3
        },
        'final_metrics': {
            'test_accuracy': test_acc,
            'auc_score': auc_score,
            'f1_score': test_f1,
            'optimal_threshold': optimal_threshold
        }
    }
    
    save_checkpoint('final_advanced_model', model_data, "Final advanced model saved")
    
    print("\nüéâ ADVANCED GNN TRAINING COMPLETED!")
    print(f"‚úÖ Best F1 Score: {best_f1:.4f}")
    print(f"‚úÖ Final AUC Score: {auc_score:.4f}")
    print(f"‚úÖ Model saved to: {CHECKPOINT_DIR}/final_advanced_model.pkl")
    
    return {
        'model': model,
        'data': data,
        'results': results,
        'metrics': {
            'accuracy': test_acc,
            'f1_score': test_f1,
            'auc_score': auc_score,
            'threshold': optimal_threshold
        }
    }

# ============================================================================
# INFERENCE FUNCTION FOR NEW DATA
# ============================================================================
def predict_new_addresses(model_checkpoint_path, new_addresses_data, scaler, addr_to_idx):
    """
    Predict ransomware probability for new addresses
    """
    print("üîç PREDICTING NEW ADDRESSES...")
    
    # Load model
    checkpoint_data = load_checkpoint('final_advanced_model')
    model_config = checkpoint_data['model_config']
    
    # Initialize model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = AdvancedGNN(**model_config).to(device)
    model.load_state_dict(checkpoint_data['model_state_dict'])
    model.eval()
    
    # Process new data (assuming same format as training data)
    predictions = []
    
    with torch.no_grad():
        for addr_data in new_addresses_data:
            # Scale features
            features_scaled = scaler.transform(addr_data.reshape(1, -1))
            features_tensor = torch.tensor(features_scaled, dtype=torch.float).to(device)
            
            # Note: For new addresses, we might not have graph structure
            # This is a simplified prediction - in practice, you'd need to handle graph structure
            # For now, using only the feature path of the model
            
            # Create dummy edge index for single node
            edge_index = torch.tensor([[0], [0]], dtype=torch.long).to(device)
            
            # Predict
            output = model(features_tensor, edge_index)
            prob = torch.softmax(output, dim=1)[0, 1].cpu().item()
            
            predictions.append({
                'probability': prob,
                'prediction': 1 if prob >= checkpoint_data['optimal_threshold'] else 0
            })
    
    return predictions

# ============================================================================
# COMPARISON WITH BASELINE MODELS
# ============================================================================
def compare_with_baselines(data, train_idx, test_idx):
    """
    Compare Advanced GNN with baseline models
    """
    print("\nüìä COMPARING WITH BASELINE MODELS")
    print("="*50)
    
    X_train = data.x[train_idx].cpu().numpy()
    y_train = data.y[train_idx].cpu().numpy()
    X_test = data.x[test_idx].cpu().numpy()
    y_test = data.y[test_idx].cpu().numpy()
    
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.linear_model import LogisticRegression
    from xgboost import XGBClassifier
    
    models = {
        'Random Forest': RandomForestClassifier(n_estimators=100, random_state=42, class_weight='balanced'),
        'Logistic Regression': LogisticRegression(random_state=42, class_weight='balanced', max_iter=1000),
        'XGBoost': XGBClassifier(random_state=42, scale_pos_weight=len(y_train[y_train==0])/len(y_train[y_train==1]))
    }
    
    results = {}
    
    for name, model in models.items():
        print(f"\nTraining {name}...")
        model.fit(X_train, y_train)
        
        y_pred = model.predict(X_test)
        y_prob = model.predict_proba(X_test)[:, 1]
        
        acc = np.mean(y_pred == y_test)
        f1 = f1_score(y_test, y_pred)
        auc = roc_auc_score(y_test, y_prob)
        
        results[name] = {
            'accuracy': acc,
            'f1_score': f1,
            'auc_score': auc
        }
        
        print(f"{name} - Acc: {acc:.4f}, F1: {f1:.4f}, AUC: {auc:.4f}")
    
    return results

# ============================================================================
# FEATURE IMPORTANCE ANALYSIS
# ============================================================================
def analyze_feature_importance(model, data, feature_cols, device):
    """
    Analyze which features are most important for the model
    """
    print("\nüîç FEATURE IMPORTANCE ANALYSIS")
    print("="*40)
    
    model.eval()
    
    # Get baseline prediction
    with torch.no_grad():
        baseline_output = model(data.x, data.edge_index)
        baseline_probs = torch.softmax(baseline_output, dim=1)[:, 1]
    
    feature_importance = []
    
    # Permutation importance
    for i in range(data.x.shape[1]):
        # Permute feature i
        data_permuted = data.x.clone()
        data_permuted[:, i] = data_permuted[torch.randperm(data_permuted.shape[0]), i]
        
        with torch.no_grad():
            permuted_output = model(data_permuted, data.edge_index)
            permuted_probs = torch.softmax(permuted_output, dim=1)[:, 1]
        
        # Calculate importance as difference in predictions
        importance = torch.mean(torch.abs(baseline_probs - permuted_probs)).item()
        feature_importance.append(importance)
    
    # Sort by importance
    #'closeness',
    feature_names = feature_cols + ['pagerank', 'betweenness', 'degree_centrality']
    importance_pairs = list(zip(feature_names, feature_importance))
    importance_pairs.sort(key=lambda x: x[1], reverse=True)
    
    print("Top 10 Most Important Features:")
    for i, (feature, importance) in enumerate(importance_pairs[:10]):
        print(f"{i+1:2d}. {feature:30s}: {importance:.6f}")
    
    return importance_pairs

# ============================================================================
# MAIN EXECUTION
# ============================================================================
if __name__ == "__main__":
    try:
        # Run the main training function
        results = main_advanced_gnn_training()
        
        print("\n" + "="*70)
        print("üöÄ ADDITIONAL ANALYSIS")
        print("="*70)
        
        # Feature importance analysis
        if 'model' in results and 'data' in results:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            checkpoint_data = load_checkpoint('features_engineered')
            feature_cols = load_checkpoint('graph_built_advanced')['feature_cols']
            
            importance_results = analyze_feature_importance(
                results['model'], 
                results['data'], 
                feature_cols, 
                device
            )
        
        print("\nüéâ ALL ANALYSIS COMPLETED!")
        print("Check the checkpoints directory for saved models and data.")
        
    except Exception as e:
        print(f"‚ùå Error occurred: {str(e)}")
        import traceback
        traceback.print_exc()
        print("\nüí° Check if all required data files are available and paths are correct.")

üöÄ ADVANCED GNN RANSOMWARE DETECTION WITH LEAKAGE FIXES
[RESUME] Loading data from checkpoint...
[SUCCESS] Resumed with 1268260 wallet records and 2868964 edges
[RESUME] Loading feature engineering from checkpoint...
[SUCCESS] Resumed with 68 columns
[RESUME] Loading cleaned data from checkpoint...
[SUCCESS] Resumed with 367472 cleaned addresses
[RESUME] Loading advanced graph from checkpoint...
[SUCCESS] Resumed with graph: 265354 nodes, 2236444 edges
[CRITICAL FIX] Creating proper train/test split...
Train nodes: 212283
Test nodes: 53071
[ADVANCED] Applying SMOTE to training data...
Original: [200870  11413]
SMOTE: [200870 200870]
[INFO] Using class weights due to graph structure complexity
[ADVANCED] Initializing advanced GNN model...
[INFO] Using device: cuda
[ADVANCED] Setting up Focal Loss and optimizer...
[INFO] Creating NeighborLoader...
[TRAINING] Starting advanced training...
‚ùå Error occurred: CUDA out of memory. Tried to allocate 3.80 GiB. GPU 0 has a total capacity of 6

Traceback (most recent call last):
  File "C:\Users\CATURWARGA COMPUTER\AppData\Local\Temp\ipykernel_6624\914404923.py", line 844, in <module>
    results = main_advanced_gnn_training()
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\CATURWARGA COMPUTER\AppData\Local\Temp\ipykernel_6624\914404923.py", line 595, in main_advanced_gnn_training
    results = evaluate_without_leakage(model, data, test_idx, device, threshold=0.5)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\CATURWARGA COMPUTER\AppData\Local\Temp\ipykernel_6624\914404923.py", line 100, in evaluate_without_leakage
    out = model(subgraph_x, subgraph_edges_remapped)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\CATURWARGA COMPUTER\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv, GraphSAGE, TransformerConv
from torch_geometric.loader import NeighborLoader
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler, RobustScaler
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, precision_recall_curve, f1_score
from sklearn.utils.class_weight import compute_class_weight
from imblearn.over_sampling import SMOTE, ADASYN
import networkx as nx
import warnings
import pickle
import os
import json
from collections import defaultdict
warnings.filterwarnings('ignore')

# Mount Google Drive (if using Colab)
try:
    CHECKPOINT_DIR = 'Ransomware_Checkpoints_v2'
except:
    CHECKPOINT_DIR = './checkpoints'

os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# ============================================================================
# CRITICAL FIX 1: FOCAL LOSS FOR EXTREME IMBALANCE
# ============================================================================
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, weight=None, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.weight = weight
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, weight=self.weight, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1-pt)**self.gamma * ce_loss
        
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

# ============================================================================
# MEMORY OPTIMIZATION: BATCHED EVALUATION TO PREVENT OOM
# ============================================================================
def evaluate_in_batches(model, data, node_indices, device, threshold=0.5, batch_size=1024):
    """
    Evaluate the model on a given set of nodes using mini-batches (NeighborLoader)
    to prevent CUDA Out of Memory errors.
    """
    model.eval()
    
    # Create a NeighborLoader for the nodes to be evaluated
    eval_loader = NeighborLoader(
        data,
        num_neighbors=[15, 10, 5],  # Same neighborhood sampling as training
        batch_size=batch_size,
        input_nodes=node_indices,
        shuffle=False,
        num_workers=0
    )
    
    all_preds = []
    all_probs = []
    all_labels = []

    with torch.no_grad():
        for batch in eval_loader:
            batch = batch.to(device)
            out = model(batch.x, batch.edge_index)
            
            # Note: The output 'out' is for all nodes in the batch (sampled neighbors included).
            # We only care about the predictions for the primary 'input_nodes' of the batch.
            # batch.batch_size gives the number of target nodes in the current batch.
            target_out = out[:batch.batch_size]
            
            pred_proba = F.softmax(target_out, dim=1)[:, 1]
            pred_binary = (pred_proba >= threshold).long()
            
            all_preds.append(pred_binary.cpu())
            all_probs.append(pred_proba.cpu())
            all_labels.append(batch.y[:batch.batch_size].cpu())

    return {
        'predictions': torch.cat(all_preds).numpy(),
        'probabilities': torch.cat(all_probs).numpy(),
        'true_labels': torch.cat(all_labels).numpy()
    }


# ============================================================================
# ADVANCED FEATURE ENGINEERING: GRAPH CENTRALITY MEASURES
# ============================================================================
def compute_graph_centrality_features(edge_index, num_nodes):
    """Compute centrality measures for graph nodes"""
    print("Computing graph centrality features...")
    
    # Convert to NetworkX graph
    edge_list = edge_index.t().cpu().numpy()
    G = nx.Graph()
    G.add_nodes_from(range(num_nodes))
    G.add_edges_from(edge_list)
    
    # Compute centrality measures
    print("  - Computing PageRank...")
    pagerank = nx.pagerank(G, max_iter=50)
    
    print("  - Computing Betweenness Centrality...")
    betweenness = nx.betweenness_centrality(G, k=min(1000, num_nodes))
    
    # print("  - Computing Closeness Centrality...")
    # closeness = nx.closeness_centrality(G)
    
    print("  - Computing Degree Centrality...")
    degree_centrality = nx.degree_centrality(G)
    
    # Convert to arrays
    pagerank_array = np.array([pagerank.get(i, 0) for i in range(num_nodes)])
    betweenness_array = np.array([betweenness.get(i, 0) for i in range(num_nodes)])
    # closeness_array = np.array([closeness.get(i, 0) for i in range(num_nodes)])
    degree_array = np.array([degree_centrality.get(i, 0) for i in range(num_nodes)])
    
    return {
        'pagerank': pagerank_array,
        'betweenness': betweenness_array,
        # 'closeness': closeness_array,
        'degree_centrality': degree_array
    }

# ============================================================================
# ADVANCED ARCHITECTURE: MULTI-LAYER GNN WITH ATTENTION FUSION
# ============================================================================
class AdvancedGNN(nn.Module):
    def __init__(self, num_features, hidden_dim=128, num_classes=2, dropout=0.3):
        super(AdvancedGNN, self).__init__()
        
        layer2_input_dim = (hidden_dim * 4) + hidden_dim
        layer3_input_dim = (hidden_dim * 4) + hidden_dim

        # Multiple GNN Types
        self.gat_layers = nn.ModuleList([
            GATConv(num_features, hidden_dim, heads=4, dropout=dropout),
            GATConv(layer2_input_dim, hidden_dim, heads=4, dropout=dropout),
            GATConv(layer3_input_dim, hidden_dim // 2, heads=2, dropout=dropout)
        ])
        
        self.gcn_layers = nn.ModuleList([
            GCNConv(num_features, hidden_dim),
            GCNConv(layer2_input_dim, hidden_dim),
            GCNConv(layer3_input_dim, hidden_dim // 2)
        ])
        
        # Batch normalization layers
        self.bn_layers = nn.ModuleList([
            nn.BatchNorm1d(hidden_dim * 4 + hidden_dim),
            nn.BatchNorm1d(hidden_dim * 4 + hidden_dim),
            nn.BatchNorm1d((hidden_dim // 2 * 2) + (hidden_dim // 2))
        ])
        
        # Residual connections
        self.residual_projections = nn.ModuleList([
            nn.Linear(num_features, layer2_input_dim),
            nn.Linear(layer2_input_dim, layer3_input_dim),
            nn.Linear(layer3_input_dim, (hidden_dim // 2 * 2) + (hidden_dim // 2))
        ])
        
        # Feature-only path (XGBoost mimic)
        self.feature_path = nn.Sequential(
            nn.Linear(num_features, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.Dropout(dropout)
        )
        
        # Combined classifier
        final_graph_dim = (hidden_dim // 2 * 2) + (hidden_dim // 2)
        feature_only_dim = hidden_dim // 2
        combined_dim = final_graph_dim + feature_only_dim
        self.classifier = nn.Sequential(
            nn.Linear(combined_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.Dropout(dropout // 2),
            nn.Linear(hidden_dim // 2, 64),
            nn.ReLU(),
            nn.Dropout(dropout // 2),
            nn.Linear(64, num_classes)
        )
        
        self.dropout = dropout
        
    def forward(self, x, edge_index, batch=None):
        original_x = x
        current_x = x
        
        for i, (gat_layer, gcn_layer) in enumerate(zip(self.gat_layers, self.gcn_layers)):
            input_for_residual = current_x
            
            h_gat = F.elu(gat_layer(current_x, edge_index))
            h_gcn = F.elu(gcn_layer(current_x, edge_index))
            
            h_combined = torch.cat([h_gat, h_gcn], dim=1)
            
            residual = self.residual_projections[i](input_for_residual)
            h_combined = h_combined + residual
            h_combined = self.bn_layers[i](h_combined)
            current_x = F.dropout(h_combined, p=self.dropout, training=self.training)
        
        graph_features = current_x
        feature_only = self.feature_path(original_x)
        combined = torch.cat([graph_features, feature_only], dim=1)
        out = self.classifier(combined)
        
        return out
    
# ============================================================================
# CHECKPOINT SYSTEM (KEEPING YOUR EXISTING SYSTEM)
# ============================================================================
def save_checkpoint(checkpoint_name, data, message="Checkpoint saved"):
    """Save checkpoint data"""
    checkpoint_path = os.path.join(CHECKPOINT_DIR, f"{checkpoint_name}.pkl")
    with open(checkpoint_path, 'wb') as f:
        pickle.dump(data, f)
    print(f"‚úÖ {message}: {checkpoint_path}")

def load_checkpoint(checkpoint_name):
    """Load checkpoint data"""
    checkpoint_path = os.path.join(CHECKPOINT_DIR, f"{checkpoint_name}.pkl")
    if os.path.exists(checkpoint_path):
        with open(checkpoint_path, 'rb') as f:
            return pickle.load(f)
    return None

def checkpoint_exists(checkpoint_name):
    """Check if checkpoint exists"""
    checkpoint_path = os.path.join(CHECKPOINT_DIR, f"{checkpoint_name}.pkl")
    return os.path.exists(checkpoint_path)

# ============================================================================
# MAIN INTEGRATION FUNCTION
# ============================================================================
def main_advanced_gnn_training():
    """Complete advanced GNN training with your existing structure"""
    
    print("üöÄ ADVANCED GNN RANSOMWARE DETECTION WITH LEAKAGE FIXES")
    print("="*70)
    
    # STEP 1: Data Loading (keeping your existing logic)
    if checkpoint_exists('data_loaded'):
        print("[RESUME] Loading data from checkpoint...")
        checkpoint_data = load_checkpoint('data_loaded')
        wallets_df = checkpoint_data['wallets_df']
        edges_df = checkpoint_data['edges_df']
        print(f"[SUCCESS] Resumed with {len(wallets_df)} wallet records and {len(edges_df)} edges")
    else:
        print("[INFO] Loading wallet features and classes...")
        wallets_df = pd.read_csv('Elliptic++ Dataset/wallets_features_classes_combined.csv')
        print(f"[SUCCESS] Loaded {len(wallets_df)} wallet records")

        print("[INFO] Loading address-to-address edges...")
        edges_df = pd.read_csv('Elliptic++ Dataset/AddrAddr_edgelist.csv')
        print(f"[SUCCESS] Loaded {len(edges_df)} edges")

        save_checkpoint('data_loaded', {
            'wallets_df': wallets_df,
            'edges_df': edges_df
        }, "Data loading checkpoint saved")

    # STEP 2: Feature Engineering (using your enhanced features)
    def create_enhanced_pattern_features(df):
        """Create the EXACT same enhanced features as XGBoost model"""
        numeric_columns = df.select_dtypes(include=[np.number]).columns
        df[numeric_columns] = df[numeric_columns].fillna(0)
        # Your existing feature engineering...
        df['partner_transaction_ratio'] = (df.get('transacted_w_address_total', 0) / (df.get('total_txs', 1) + 1e-8))
        df['activity_density'] = (df.get('total_txs', 0) / (df.get('lifetime_in_blocks', 1) + 1e-8))
        df['transaction_size_variance'] = ((df.get('btc_transacted_max', 0) - df.get('btc_transacted_min', 0)) / (df.get('btc_transacted_mean', 1) + 1e-8))
        df['flow_imbalance'] = ((df.get('btc_sent_total', 0) - df.get('btc_received_total', 0)) / (df.get('btc_transacted_total', 1) + 1e-8))
        df['temporal_spread'] = ((df.get('last_block_appeared_in', 0) - df.get('first_block_appeared_in', 0)) / (df.get('num_timesteps_appeared_in', 1) + 1e-8))
        df['fee_percentile'] = (df.get('fees_total', 0) / (df.get('btc_transacted_total', 1) + 1e-8))
        df['interaction_intensity'] = (df.get('num_addr_transacted_multiple', 0) / (df.get('transacted_w_address_total', 1) + 1e-8))
        df['value_per_transaction'] = (df.get('btc_transacted_total', 0) / (df.get('total_txs', 1) + 1e-8))
        df['burst_activity'] = (df.get('total_txs', 0) * df.get('activity_density', 0))
        df['mixing_intensity'] = (df.get('partner_transaction_ratio', 0) * df.get('interaction_intensity', 0))
        return df

    if checkpoint_exists('features_engineered'):
        print("[RESUME] Loading feature engineering from checkpoint...")
        checkpoint_data = load_checkpoint('features_engineered')
        wallets_df = checkpoint_data['wallets_df']
        print(f"[SUCCESS] Resumed with {len(wallets_df.columns)} columns")
    else:
        print("[INFO] Creating enhanced features...")
        wallets_df = create_enhanced_pattern_features(wallets_df)
        print(f"[SUCCESS] Enhanced features created. Now have {len(wallets_df.columns)} columns")
        save_checkpoint('features_engineered', {'wallets_df': wallets_df}, "Feature engineering checkpoint saved")

    # STEP 3: Data Cleaning (keeping your existing logic)
    if checkpoint_exists('data_cleaned'):
        print("[RESUME] Loading cleaned data from checkpoint...")
        checkpoint_data = load_checkpoint('data_cleaned')
        wallets_clean = checkpoint_data['wallets_clean']
        class_weight_dict = checkpoint_data['class_weight_dict']
        print(f"[SUCCESS] Resumed with {len(wallets_clean)} cleaned addresses")
    else:
        print("[INFO] Cleaning data...")
        wallets_clean = wallets_df[wallets_df['class'].isin([1, 2])].copy()
        wallets_clean['class'] = wallets_clean['class'].map({1: 1, 2: 0})
        
        class_weights = compute_class_weight('balanced', classes=np.unique(wallets_clean['class']), y=wallets_clean['class'])
        class_weight_dict = {0: class_weights[0], 1: class_weights[1]}
        
        save_checkpoint('data_cleaned', {
            'wallets_clean': wallets_clean,
            'class_weight_dict': class_weight_dict
        }, "Data cleaning checkpoint saved")

    # STEP 4: Graph Construction with Centrality Features
    if checkpoint_exists('graph_built_advanced'):
        print("[RESUME] Loading advanced graph from checkpoint...")
        checkpoint_data = load_checkpoint('graph_built_advanced')
        data = checkpoint_data['data']
        addr_to_idx = checkpoint_data['addr_to_idx']
        feature_cols = checkpoint_data['feature_cols']
        scaler = checkpoint_data['scaler']
        print(f"[SUCCESS] Resumed with graph: {data.x.shape[0]} nodes, {data.edge_index.shape[1]} edges")
    else:
        print("[INFO] Building advanced graph with centrality features...")
        
        unique_addresses = wallets_clean['address'].unique()
        addr_to_idx = {addr: idx for idx, addr in enumerate(unique_addresses)}
        
        exclude_cols = ['address', 'Time step', 'class']
        feature_cols = [col for col in wallets_clean.columns if col not in exclude_cols]
        
        features_list = []
        labels_list = []
        
        for addr in unique_addresses:
            addr_data = wallets_clean[wallets_clean['address'] == addr].iloc[0]
            features_list.append(addr_data[feature_cols].values)
            labels_list.append(addr_data['class'])
        
        X = np.array(features_list)
        y = np.array(labels_list)
        
        edge_list = []
        for _, row in edges_df.iterrows():
            input_addr, output_addr = row['input_address'], row['output_address']
            if input_addr in addr_to_idx and output_addr in addr_to_idx:
                input_idx, output_idx = addr_to_idx[input_addr], addr_to_idx[output_addr]
                edge_list.extend([[input_idx, output_idx], [output_idx, input_idx]])
        
        edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
        
        print("[ADVANCED] Computing centrality features...")
        centrality_features = compute_graph_centrality_features(edge_index, len(unique_addresses))
        
        centrality_matrix = np.column_stack([
            centrality_features['pagerank'],
            centrality_features['betweenness'],
            centrality_features['degree_centrality']
        ])
        
        X_enhanced = np.concatenate([X, centrality_matrix], axis=1)
        print(f"[SUCCESS] Enhanced features: {X.shape[1]} -> {X_enhanced.shape[1]}")
        
        scaler = RobustScaler()
        X_scaled = scaler.fit_transform(X_enhanced)
        
        x = torch.tensor(X_scaled, dtype=torch.float)
        y = torch.tensor(y, dtype=torch.long)
        
        data = Data(x=x, edge_index=edge_index, y=y)
        
        save_checkpoint('graph_built_advanced', {
            'data': data, 'addr_to_idx': addr_to_idx,
            'feature_cols': feature_cols, 'scaler': scaler
        }, "Advanced graph construction checkpoint saved")

    # STEP 5: CRITICAL - Proper Train/Test Split (NO DATA LEAKAGE)
    print("[CRITICAL FIX] Creating proper train/test split...")
    
    indices = range(len(data.y))
    train_idx, test_idx, y_train, y_test = train_test_split(
        indices, data.y.numpy(), test_size=0.2, random_state=42, stratify=data.y.numpy()
    )
    
    train_mask = torch.zeros(data.y.size(0), dtype=torch.bool)
    test_mask = torch.zeros(data.y.size(0), dtype=torch.bool)
    train_mask[train_idx] = True
    test_mask[test_idx] = True
    
    print(f"Train nodes: {len(train_idx)}")
    print(f"Test nodes: {len(test_idx)}")
    
    # STEP 6: SMOTE Info
    print("[INFO] SMOTE will not be directly applied to the graph.")
    print("[INFO] Using class weights and Focal Loss to handle imbalance.")

    # STEP 7: Initialize Advanced Model
    print("[ADVANCED] Initializing advanced GNN model...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"[INFO] Using device: {device}")
    
    model = AdvancedGNN(
        num_features=data.x.shape[1], hidden_dim=128, dropout=0.3
    ).to(device)
    
    data = data.to(device)
    
    # STEP 8: Advanced Loss and Optimizer
    print("[ADVANCED] Setting up Focal Loss and optimizer...")
    class_weights_tensor = torch.tensor([class_weight_dict[0], class_weight_dict[1]], dtype=torch.float).to(device)
    
    criterion = FocalLoss(alpha=1, gamma=2, weight=class_weights_tensor)
    
    optimizer = torch.optim.AdamW([
        {'params': model.gat_layers.parameters(), 'lr': 0.001},
        {'params': model.gcn_layers.parameters(), 'lr': 0.001},
        {'params': model.classifier.parameters(), 'lr': 0.002}
    ], weight_decay=1e-4)
    
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2)
    
    # STEP 9: Create NeighborLoader
    print("[INFO] Creating NeighborLoader...")
    train_loader = NeighborLoader(
        data, num_neighbors=[15, 10, 5], batch_size=512,
        input_nodes=train_mask, shuffle=True, num_workers=0
    )
    
    # STEP 10: Training Loop
    print("[TRAINING] Starting advanced training...")
    best_f1 = 0
    patience = 0
    max_patience = 50
    optimal_threshold = 0.5
    
    for epoch in range(20):
        model.train()
        total_loss = 0
        num_batches = 0
        
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            out = model(batch.x, batch.edge_index)
            loss = criterion(out[:batch.batch_size], batch.y[:batch.batch_size])
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        
        scheduler.step()
        
        if epoch % 10 == 0:
            torch.cuda.empty_cache() # Free memory before evaluation
            # OPTIMIZATION: Use batched evaluation
            results = evaluate_in_batches(model, data, test_mask, device)
            
            if len(np.unique(results['true_labels'])) > 1:
                precision, recall, thresholds = precision_recall_curve(
                    results['true_labels'], results['probabilities']
                )
                f1_scores = 2 * (precision * recall) / (precision + recall + 1e-8)
                f1_scores = np.nan_to_num(f1_scores) # Handle division by zero
                best_threshold_idx = np.argmax(f1_scores)
                current_optimal_threshold = thresholds[best_threshold_idx] if len(thresholds) > 0 else 0.5
                
                # Re-evaluate with this optimal threshold for logging
                temp_results = evaluate_in_batches(model, data, test_mask, device, threshold=current_optimal_threshold)
                test_f1 = f1_score(temp_results['true_labels'], temp_results['predictions'])
                test_acc = np.mean(temp_results['predictions'] == temp_results['true_labels'])
                
                print(f'Epoch {epoch:03d}, Loss: {total_loss/num_batches:.4f}, '
                      f'Test Acc: {test_acc:.4f}, F1: {test_f1:.4f}, '
                      f'Threshold: {current_optimal_threshold:.3f}')
                
                if test_f1 > best_f1:
                    best_f1 = test_f1
                    optimal_threshold = current_optimal_threshold # Save the best threshold
                    patience = 0
                    torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, 'best_advanced_gnn.pth'))
                else:
                    patience += 1
                
                if patience >= max_patience:
                    print(f"Early stopping at epoch {epoch}")
                    break
            torch.cuda.empty_cache() # Free memory after evaluation

    # STEP 11: Final Evaluation
    print("\n" + "="*70)
    print("üìä FINAL EVALUATION (NO DATA LEAKAGE)")
    print("="*70)
    
    model.load_state_dict(torch.load(os.path.join(CHECKPOINT_DIR, 'best_advanced_gnn.pth')))
    
    # Final evaluation using the best threshold found during training
    results = evaluate_in_batches(model, data, test_mask, device, threshold=optimal_threshold)
    
    test_acc = np.mean(results['predictions'] == results['true_labels'])
    test_f1 = f1_score(results['true_labels'], results['predictions'])
    auc_score = roc_auc_score(results['true_labels'], results['probabilities'])
    
    print(f"Final Test Accuracy: {test_acc:.4f}")
    print(f"Final F1 Score: {test_f1:.4f}")
    print(f"Final AUC Score: {auc_score:.4f}")
    print(f"Optimal Threshold Used: {optimal_threshold:.3f}")
    
    print("\nClassification Report:")
    print(classification_report(results['true_labels'], results['predictions'], 
                                target_names=['Licit', 'Illicit'], zero_division=0))
    
    print("\nConfusion Matrix:")
    cm = confusion_matrix(results['true_labels'], results['predictions'])
    print(f"   Predicted:  Licit  Illicit")
    print(f"   Actual Licit:   {cm[0,0]:5d}   {cm[0,1]:5d}")
    print(f"   Actual Illicit: {cm[1,0]:5d}   {cm[1,1]:5d}")

    
    # Save final model
    model_data = {
        'model_state_dict': model.state_dict(), 'scaler': scaler,
        'addr_to_idx': addr_to_idx, 'feature_cols': feature_cols,
        'optimal_threshold': optimal_threshold, 'class_weights': class_weight_dict,
        'model_config': {
            'num_features': data.num_features, 'hidden_dim': 128,
            'num_classes': 2, 'dropout': 0.3
        },
        'final_metrics': {
            'test_accuracy': test_acc, 'auc_score': auc_score,
            'f1_score': test_f1, 'optimal_threshold': optimal_threshold
        }
    }
    
    save_checkpoint('final_advanced_model', model_data, "Final advanced model saved")
    
    print("\nüéâ ADVANCED GNN TRAINING COMPLETED!")
    print(f"‚úÖ Best F1 Score: {best_f1:.4f}")
    print(f"‚úÖ Final AUC Score: {auc_score:.4f}")
    print(f"‚úÖ Model saved to: {CHECKPOINT_DIR}/final_advanced_model.pkl")
    
    return {
        'model': model, 'data': data, 'results': results,
        'metrics': {
            'accuracy': test_acc, 'f1_score': test_f1,
            'auc_score': auc_score, 'threshold': optimal_threshold
        }
    }

# ============================================================================
# FEATURE IMPORTANCE ANALYSIS (CPU OPTIMIZED)
# ============================================================================
def analyze_feature_importance(model, data, feature_cols, device):
    """
    Analyze feature importance on the CPU to prevent VRAM issues.
    """
    print("\nüîç FEATURE IMPORTANCE ANALYSIS (Running on CPU)")
    print("="*50)
    
    # Move model and data to CPU for this analysis
    model.to('cpu')
    data.to('cpu')
    model.eval()
    
    # Get baseline prediction
    with torch.no_grad():
        baseline_output = model(data.x, data.edge_index)
        baseline_probs = torch.softmax(baseline_output, dim=1)[:, 1]
    
    feature_importance = []
    
    # Permutation importance
    for i in range(data.x.shape[1]):
        original_col = data.x[:, i].clone()
        
        # Permute feature i
        permuted_indices = torch.randperm(data.x.shape[0])
        data.x[:, i] = data.x[permuted_indices, i]
        
        with torch.no_grad():
            permuted_output = model(data.x, data.edge_index)
            permuted_probs = torch.softmax(permuted_output, dim=1)[:, 1]
        
        # Calculate importance and restore original column
        importance = torch.mean(torch.abs(baseline_probs - permuted_probs)).item()
        feature_importance.append(importance)
        data.x[:, i] = original_col # Restore for next iteration
        
        if (i+1) % 10 == 0:
            print(f"  ... analyzed {i+1}/{data.x.shape[1]} features")

    # Move model back to original device
    model.to(device)

    # Sort by importance
    feature_names = feature_cols + ['pagerank', 'betweenness', 'degree_centrality']
    importance_pairs = list(zip(feature_names, feature_importance))
    importance_pairs.sort(key=lambda x: x[1], reverse=True)
    
    print("\nTop 10 Most Important Features:")
    for i, (feature, importance) in enumerate(importance_pairs[:10]):
        print(f"{i+1:2d}. {feature:30s}: {importance:.6f}")
    
    return importance_pairs


# ============================================================================
# MAIN EXECUTION
# ============================================================================
if __name__ == "__main__":
    try:
        # Run the main training function
        results = main_advanced_gnn_training()
        
        print("\n" + "="*70)
        print("üöÄ ADDITIONAL ANALYSIS")
        print("="*70)
        
        # Feature importance analysis
        if 'model' in results and 'data' in results:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            # We need the original, un-enhanced feature columns list
            graph_checkpoint = load_checkpoint('graph_built_advanced')
            if graph_checkpoint:
                 feature_cols = graph_checkpoint['feature_cols']
                 importance_results = analyze_feature_importance(
                    results['model'], 
                    results['data'], 
                    feature_cols, 
                    device
                )
        
        print("\nüéâ ALL ANALYSIS COMPLETED!")
        print("Check the checkpoints directory for saved models and data.")
        
    except Exception as e:
        print(f"‚ùå Error occurred: {str(e)}")
        import traceback
        traceback.print_exc()
        print("\nüí° Check if all required data files are available and paths are correct.")

  import torch_geometric.typing
  from .autonotebook import tqdm as notebook_tqdm


üöÄ ADVANCED GNN RANSOMWARE DETECTION WITH LEAKAGE FIXES
[RESUME] Loading data from checkpoint...
[SUCCESS] Resumed with 1268260 wallet records and 2868964 edges
[RESUME] Loading feature engineering from checkpoint...
[SUCCESS] Resumed with 68 columns
[RESUME] Loading cleaned data from checkpoint...
[SUCCESS] Resumed with 367472 cleaned addresses
[RESUME] Loading advanced graph from checkpoint...
[SUCCESS] Resumed with graph: 265354 nodes, 2236444 edges
[CRITICAL FIX] Creating proper train/test split...
Train nodes: 212283
Test nodes: 53071
[INFO] SMOTE will not be directly applied to the graph.
[INFO] Using class weights and Focal Loss to handle imbalance.
[ADVANCED] Initializing advanced GNN model...
[INFO] Using device: cuda
[ADVANCED] Setting up Focal Loss and optimizer...
[INFO] Creating NeighborLoader...
[TRAINING] Starting advanced training...
Epoch 000, Loss: 0.1545, Test Acc: 0.9546, F1: 0.5326, Threshold: 0.733
