In [1]:
"""
Complete K-Fold GNN Pipeline for Elliptic Dataset with Embedding Extraction
===========================================================================
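This cell defines a function that trains five GNN architectures (GCN, Skip-GCN,
GAT, GATv2, GraphSAGE) on one fold of a stratified K-fold split and saves their
node embeddings for use as extra features in baseline ML models.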
"""

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv, GATv2Conv, SAGEConv
from torch_geometric.utils import to_undirected, add_self_loops
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
import warnings
warnings.filterwarnings('ignore')

def _load_elliptic_custom():
    """Load the raw Elliptic CSV files from the working directory (no normalization).

    Returns:
        (df_features, df_edges)
        ``df_features`` has columns ``tx_id``, ``time_step``, ``local_1..local_93``,
        ``aggregated_1..aggregated_72`` plus the ``class`` column merged in with a
        left join (transactions absent from the classes file get NaN).
        ``df_edges`` has integer ``txId1`` / ``txId2`` columns.
    """
    features_path = 'elliptic_txs_features.csv'
    classes_path = 'elliptic_txs_classes.csv'
    edges_path = 'elliptic_txs_edgelist.csv'

    # The features file has no header row; assign the 2 + 93 + 72 column names.
    df_features = pd.read_csv(features_path, header=None)
    col_names = ['tx_id', 'time_step']
    col_names += [f'local_{i}' for i in range(1, 94)]
    col_names += [f'aggregated_{i}' for i in range(1, 73)]
    df_features.columns = col_names
    df_features['tx_id'] = df_features['tx_id'].astype(int)
    df_features['time_step'] = df_features['time_step'].astype(int)

    # Left join keeps the row order of the features file, which defines node order.
    df_classes = pd.read_csv(classes_path)
    df_features = df_features.merge(
        df_classes.rename(columns={'txId': 'tx_id'}),
        on='tx_id',
        how='left'
    )

    df_edges = pd.read_csv(edges_path)
    df_edges['txId1'] = df_edges['txId1'].astype(int)
    df_edges['txId2'] = df_edges['txId2'].astype(int)
    return df_features, df_edges


def _edge_pairs_to_indices(df_edges, tx_id_to_idx):
    """Map raw ``(txId1, txId2)`` edge rows to node-index pairs.

    Edges with an endpoint missing from ``tx_id_to_idx`` are dropped. Row order
    is preserved.

    Returns:
        np.ndarray of dtype int64 and shape (2, E_kept); (2, 0) when nothing is kept.
    """
    # Vectorised replacement for a per-row iterrows() loop; unmapped ids become NaN.
    src = df_edges['txId1'].map(tx_id_to_idx)
    dst = df_edges['txId2'].map(tx_id_to_idx)
    keep = (src.notna() & dst.notna()).to_numpy()
    return np.vstack([
        src.to_numpy()[keep].astype(np.int64),
        dst.to_numpy()[keep].astype(np.int64),
    ])


def _create_pyg_data_kfold(df_features, df_edges, n_splits, fold_id, val_ratio, random_state):
    """Build a PyG ``Data`` object with stratified K-fold train/val/test masks.

    Only labelled nodes (class '1' = illicit, '2' = licit) enter the split.
    Unknown nodes keep label 2 and stay in the graph for message passing.
    Features are z-scored with statistics computed on the training mask only.
    """
    N = len(df_features)
    tx_id_to_idx = {tx_id: idx for idx, tx_id in enumerate(df_features['tx_id'].values)}

    feature_cols = [col for col in df_features.columns
                    if col.startswith(('local_', 'aggregated_'))]
    x = torch.tensor(df_features[feature_cols].values, dtype=torch.float)

    # Labels: 0 = licit, 1 = illicit, 2 = unknown.
    y = torch.full((N,), 2, dtype=torch.long)
    class_values = df_features['class'].astype(str).values
    y[class_values == '2'] = 0  # licit
    y[class_values == '1'] = 1  # illicit

    # Undirected graph with explicit self-loops.
    edge_index = torch.as_tensor(_edge_pairs_to_indices(df_edges, tx_id_to_idx))
    edge_index = to_undirected(edge_index, num_nodes=N)
    edge_index, _ = add_self_loops(edge_index, num_nodes=N)

    time_steps = torch.tensor(df_features['time_step'].values, dtype=torch.long)

    # ---- K-fold split over labelled nodes only ----
    labeled_indices = torch.where(y != 2)[0].numpy()
    y_labeled = y[labeled_indices].numpy()
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    splits = list(skf.split(np.zeros_like(y_labeled), y_labeled))
    train_fold_idx, test_fold_idx = splits[fold_id - 1]
    train_labeled = labeled_indices[train_fold_idx]
    test_labeled = labeled_indices[test_fold_idx]

    # Stratified validation holdout carved out of the training fold.
    if val_ratio > 0:
        sss = StratifiedShuffleSplit(n_splits=1, test_size=val_ratio, random_state=random_state)
        y_train_labeled = y[train_labeled].numpy()
        train_keep_idx, val_idx = next(sss.split(np.zeros_like(y_train_labeled), y_train_labeled))
        final_train = train_labeled[train_keep_idx]
        val_indices = train_labeled[val_idx]
    else:
        final_train = train_labeled
        val_indices = np.array([], dtype=int)

    train_mask = torch.zeros(N, dtype=torch.bool)
    val_mask = torch.zeros(N, dtype=torch.bool)
    test_mask = torch.zeros(N, dtype=torch.bool)
    train_mask[torch.as_tensor(final_train)] = True
    if val_indices.size > 0:
        val_mask[torch.as_tensor(val_indices)] = True
    test_mask[torch.as_tensor(test_labeled)] = True

    # Internal invariants of the split construction above.
    assert (train_mask & val_mask).sum() == 0, "Train and val masks overlap!"
    assert (train_mask & test_mask).sum() == 0, "Train and test masks overlap!"
    assert (val_mask & test_mask).sum() == 0, "Val and test masks overlap!"

    # Z-score with TRAIN statistics only, to avoid leaking val/test distribution.
    with torch.no_grad():
        train_features = x[train_mask]
        mu = train_features.mean(0, keepdim=True)
        std = train_features.std(0, keepdim=True).clamp_min(1e-6)
        x = (x - mu) / std

    data = Data(
        x=x,
        edge_index=edge_index,
        y=y,
        train_mask=train_mask,
        val_mask=val_mask,
        test_mask=test_mask,
        time_steps=time_steps,
        num_nodes=N
    )

    n_train = int(train_mask.sum())
    n_val = int(val_mask.sum())
    n_test = int(test_mask.sum())
    train_illicit = int((y[train_mask] == 1).sum())
    train_licit = int((y[train_mask] == 0).sum())
    illicit_pct = (train_illicit / (train_illicit + train_licit)) * 100

    print(f"\nK-Fold Configuration (Fold {fold_id}/{n_splits}):")
    print(f"  Train: {n_train:,} | Val: {n_val:,} | Test: {n_test:,}")
    print(f"  Train class balance: {train_illicit:,} illicit / {train_illicit + train_licit:,} total ({illicit_pct:.1f}%)")
    print(f"  Graph: {N:,} nodes, {edge_index.size(1):,} edges")
    return data


class _GCNNet(nn.Module):
    """Two-layer GCN. The 64-d activations after conv2 are exposed as ``self.embeddings``."""

    def __init__(self, in_dim, hidden=128, dropout=0.5):
        super().__init__()
        # cached=True reuses the normalised adjacency from the first call; valid
        # here because the same edge_index is passed on every forward.
        self.conv1 = GCNConv(in_dim, hidden, cached=True, normalize=True)
        self.bn1 = nn.BatchNorm1d(hidden)
        self.conv2 = GCNConv(hidden, 64, cached=True, normalize=True)
        self.bn2 = nn.BatchNorm1d(64)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.ReLU()
        self.head = nn.Linear(64, 2)

    def forward(self, x, edge_index):
        h = self.dropout(self.act(self.bn1(self.conv1(x, edge_index))))
        h = self.dropout(self.act(self.bn2(self.conv2(h, edge_index))))
        self.embeddings = h
        return self.head(h)


class _SkipGCNNet(nn.Module):
    """Two-layer GCN with a linear skip from the input features into the 64-d layer."""

    def __init__(self, in_dim, hidden=128, dropout=0.5):
        super().__init__()
        self.in_proj = nn.Linear(in_dim, 64, bias=False)
        self.conv1 = GCNConv(in_dim, hidden)
        self.bn1 = nn.BatchNorm1d(hidden)
        self.conv2 = GCNConv(hidden, 64)
        self.bn2 = nn.BatchNorm1d(64)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.ReLU()
        self.head = nn.Linear(64, 2)

    def forward(self, x, edge_index):
        skip = self.in_proj(x)
        h = self.dropout(self.act(self.bn1(self.conv1(x, edge_index))))
        h = self.bn2(self.conv2(h, edge_index))
        h = self.dropout(self.act(h + skip))
        self.embeddings = h
        return self.head(h)


class _GATFamilyNet(nn.Module):
    """Two-layer attention network, shared by GAT (``GATConv``) and GATv2 (``GATv2Conv``).

    NOTE(review): ``emb_proj``/``emb_ln`` are not on the path to the logits, so
    they receive no gradient and keep their initial (seeded) values. The 64-d
    embeddings are therefore a fixed random projection of the trained
    ``hidden * heads`` layer.
    """

    def __init__(self, in_dim, conv_cls, hidden=128, heads=4, dropout=0.5):
        super().__init__()
        # Self-loops are already present in edge_index, so the convs do not add them.
        self.conv1 = conv_cls(in_dim, hidden, heads=heads, concat=True,
                              dropout=dropout, add_self_loops=False)
        self.ln1 = nn.LayerNorm(hidden * heads)
        self.act = nn.ELU()
        self.dropout = nn.Dropout(dropout)
        self.conv2 = conv_cls(hidden * heads, 2, heads=1, concat=False,
                              dropout=dropout, add_self_loops=False)
        self.emb_proj = nn.Linear(hidden * heads, 64, bias=False)
        self.emb_ln = nn.LayerNorm(64)

    def forward(self, x, edge_index):
        h = self.dropout(self.act(self.ln1(self.conv1(x, edge_index))))
        logits = self.conv2(h, edge_index)
        self.embeddings = self.emb_ln(self.emb_proj(h))
        return logits


class _SAGENet(nn.Module):
    """Two-layer GraphSAGE. The 64-d activations after conv2 are exposed as ``self.embeddings``."""

    def __init__(self, in_dim, hidden=256, dropout=0.5):
        super().__init__()
        self.conv1 = SAGEConv(in_dim, hidden)
        self.bn1 = nn.BatchNorm1d(hidden)
        self.conv2 = SAGEConv(hidden, 64)
        self.bn2 = nn.BatchNorm1d(64)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.ReLU()
        self.head = nn.Linear(64, 2)

    def forward(self, x, edge_index):
        h = self.dropout(self.act(self.bn1(self.conv1(x, edge_index))))
        h = self.dropout(self.act(self.bn2(self.conv2(h, edge_index))))
        self.embeddings = h
        return self.head(h)


# Per-architecture model factory and default hyperparameters. Factories are
# called lazily, so only the requested model is ever constructed.
_ARCH_CONFIGS = {
    'gcn': {'factory': _GCNNet, 'lr': 0.01, 'grad_clip': 2.0},
    'skip_gcn': {'factory': _SkipGCNNet, 'lr': 0.01, 'grad_clip': 2.0},
    'gat': {'factory': lambda in_dim: _GATFamilyNet(in_dim, GATConv),
            'lr': 0.003, 'grad_clip': None},
    'gatv2': {'factory': lambda in_dim: _GATFamilyNet(in_dim, GATv2Conv),
              'lr': 0.003, 'grad_clip': None},
    'sage': {'factory': _SAGENet, 'lr': 0.01, 'grad_clip': 2.0},
}


def _find_best_threshold(logits, y_true):
    """Grid-search the illicit-probability threshold (37 steps over 0.05..0.95) maximising F1.

    Returns ``(best_threshold, best_f1)``. Returns ``(0.5, 0)`` if no threshold
    yields a positive F1. Ties keep the lowest threshold.
    """
    # Move to CPU once instead of once per threshold.
    probs = F.softmax(logits, dim=1)[:, 1].cpu().numpy()
    y_np = y_true.cpu().numpy()
    best_f1, best_thr = 0, 0.5
    for thr in torch.linspace(0.05, 0.95, 37).tolist():
        preds = (probs >= thr).astype(np.int64)
        f1 = f1_score(y_np, preds, zero_division=0)
        if f1 > best_f1:
            best_f1, best_thr = f1, thr
    return best_thr, best_f1


def _train_gnn_with_embeddings(data, device, arch='gcn', epochs=1000, patience=50,
                               lr=None, weight_decay=5e-4, grad_clip=None, seed=None):
    """Train one GNN full-batch with early stopping, then extract embeddings and test metrics.

    Validation (threshold search + early stopping) runs every 5 epochs on
    ``data.val_mask``, or on ``data.train_mask`` if the validation mask is
    empty. ``patience`` is counted in epochs. When ``seed`` is given,
    ``torch.manual_seed(seed)`` is called just before model construction, so
    each architecture's init and dropout stream do not depend on training order.
    """
    config = _ARCH_CONFIGS[arch.lower()]
    if seed is not None:
        torch.manual_seed(seed)
    model = config['factory'](data.x.size(1)).to(device)
    lr = config['lr'] if lr is None else lr
    grad_clip = config['grad_clip'] if grad_clip is None else grad_clip
    tag = arch.upper()

    # Up-weight the illicit class by the train-set licit/illicit ratio.
    y_train = data.y[data.train_mask]
    pos = (y_train == 1).sum().float()
    neg = (y_train == 0).sum().float()
    pos_weight = (neg / pos.clamp_min(1)).item()
    criterion = nn.CrossEntropyLoss(weight=torch.tensor([1.0, pos_weight], device=device))
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    # The mask used for threshold selection / early stopping is fixed for the run.
    eval_mask = data.val_mask if bool(data.val_mask.any()) else data.train_mask

    best_val_f1 = 0
    best_state = None
    best_threshold = 0.5
    wait = 0

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        if grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        if epoch % 5 != 0:
            continue

        model.eval()
        with torch.no_grad():
            out = model(data.x, data.edge_index)
            thr, f1 = _find_best_threshold(out[eval_mask], data.y[eval_mask])

        if f1 > best_val_f1:
            best_val_f1 = f1
            best_threshold = thr
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            wait = 0
            if epoch % 20 == 0:  # print less frequently
                print(f"[{tag}] Epoch {epoch:4d} | Loss: {loss.item():.4f} | Val F1: {f1:.4f} | Thr: {thr:.2f}")
        else:
            wait += 5  # validation runs every 5 epochs

        if wait >= patience:
            print(f"[{tag}] Early stopping at epoch {epoch}")
            break

    # load_state_dict copies into the existing (already on-device) parameters.
    if best_state is not None:
        model.load_state_dict(best_state)

    # One eval pass yields both the embeddings for ALL nodes and the test logits.
    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        embeddings = model.embeddings.detach().cpu().numpy()
        probs = F.softmax(out[data.test_mask], dim=1)[:, 1]
        y_pred = (probs >= best_threshold).long().cpu().numpy()
        y_true = data.y[data.test_mask].cpu().numpy()

    test_metrics = {
        'F1': f1_score(y_true, y_pred, zero_division=0),
        'P': precision_score(y_true, y_pred, zero_division=0),
        'R': recall_score(y_true, y_pred, zero_division=0),
        'Acc': accuracy_score(y_true, y_pred),
        'microF1': f1_score(y_true, y_pred, average='micro')
    }

    return {
        'test_metrics': test_metrics,
        'val_best_F1': best_val_f1,
        'best_threshold': best_threshold,
        'embeddings': embeddings,
        'embeddings_shape': embeddings.shape
    }


def run_elliptic_kfold_with_embeddings(
    n_splits=5,
    fold_id=1,          # 1..n_splits
    val_ratio=0.10,     # stratified holdout from train
    epochs=1000,
    patience=50,
    random_state=42,
    device=None,
    save_embeddings=True  # Whether to save embeddings globally
):
    """
    Train GCN, Skip-GCN, GAT, GATv2 and GraphSAGE on one K-fold split of Elliptic.

    Reads elliptic_txs_features.csv, elliptic_txs_classes.csv and
    elliptic_txs_edgelist.csv from the working directory, trains each
    architecture full-batch, and extracts one embedding row per transaction.

    Parameters:
    -----------
    n_splits : int
        Number of stratified K-fold splits (>= 2)
    fold_id : int
        Which fold to use as test (1 to n_splits)
    val_ratio : float
        Fraction of the training fold held out for validation, in [0, 1)
    epochs : int
        Maximum training epochs
    patience : int
        Early stopping patience, in epochs
    random_state : int
        Seed for the fold/validation splits and for torch (model init, dropout).
        Reset before each architecture is built.
    device : torch.device
        Device for computation (defaults to CUDA when available)
    save_embeddings : bool
        Whether to save node embeddings as module-level global variables

    Returns:
    --------
    df_results : DataFrame
        Test metrics for all models, sorted by F1 (descending)
    embeddings_dict : dict
        Dictionary of embeddings {model_name: embeddings_array}

    Raises:
    -------
    ValueError
        If n_splits < 2, fold_id is outside 1..n_splits, or val_ratio is
        outside [0, 1).

    Global Variables Created:
    ------------------------
    gcn_embeddings, skip_gcn_embeddings, gat_embeddings,
    gatv2_embeddings, sage_embeddings : numpy arrays of shape (N, 64)
    """
    # Validate up front so bad arguments fail before the slow CSV load.
    # Without this check, fold_id=0 would silently select the last fold via splits[-1].
    if n_splits < 2:
        raise ValueError(f"n_splits must be >= 2, got {n_splits}")
    if not 1 <= fold_id <= n_splits:
        raise ValueError(f"fold_id must be in 1..{n_splits}, got {fold_id}")
    if not 0 <= val_ratio < 1:
        raise ValueError(f"val_ratio must be in [0, 1), got {val_ratio}")

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # STEP 1: load raw data.
    print("Loading Elliptic dataset...")
    df_features, df_edges = _load_elliptic_custom()
    print(f"Loaded {len(df_features):,} transactions and {len(df_edges):,} edges")

    # STEP 2: build graph + K-fold masks (train-only normalisation).
    data = _create_pyg_data_kfold(df_features, df_edges, n_splits, fold_id, val_ratio, random_state)
    data = data.to(device)

    # STEP 3: train every architecture and collect embeddings.
    arch_list = ['gcn', 'skip_gcn', 'gat', 'gatv2', 'sage']
    results = {}
    embeddings_dict = {}

    print("\n" + "="*60)
    print("Training GNN Models with K-Fold Split")
    print("="*60)

    for arch in arch_list:
        print(f"\nTraining {arch.upper()}...")
        print("-" * 40)

        result = _train_gnn_with_embeddings(
            data,
            device,
            arch=arch,
            epochs=epochs,
            patience=patience,
            seed=random_state
        )

        results[arch] = result
        embeddings_dict[arch] = result['embeddings']

        if save_embeddings:
            var_name = f"{arch.lower().replace('-', '_')}_embeddings"
            globals()[var_name] = result['embeddings']
            print(f"✓ Saved embeddings to global variable: {var_name} {result['embeddings'].shape}")

    # STEP 4: results table.
    rows = []
    for arch, r in results.items():
        m = r['test_metrics']
        rows.append({
            'Model': arch.upper(),
            'F1': m['F1'],
            'Precision': m['P'],
            'Recall': m['R'],
            'Accuracy': m['Acc'],
            'Micro-F1': m['microF1'],
            'Val best F1': r['val_best_F1'],
            'Best thr': r['best_threshold'],
            'Embeddings': r['embeddings_shape']
        })

    df_results = pd.DataFrame(rows).sort_values('F1', ascending=False).reset_index(drop=True)
    for col in ['F1', 'Precision', 'Recall', 'Accuracy', 'Micro-F1', 'Val best F1']:
        if col in df_results.columns:
            df_results[col] = df_results[col].round(4)
    df_results['Best thr'] = df_results['Best thr'].round(2)

    # STEP 5: report.
    print("\n" + "="*80)
    print("FINAL RESULTS - K-Fold GNN Pipeline with Embeddings")
    print("="*80)
    print(df_results.to_string())

    print("\n" + "="*80)
    print("Embeddings Saved (Accessible as Global Variables)")
    print("="*80)

    for arch in arch_list:
        var_name = f"{arch.lower().replace('-', '_')}_embeddings"
        shape = embeddings_dict[arch].shape
        print(f"  {var_name}: shape {shape}")

    print("\n" + "="*80)
    print("Configuration Summary")
    print("="*80)
    print(f"  K-Fold splits: {n_splits}")
    print(f"  Current fold: {fold_id}/{n_splits}")
    print(f"  Validation ratio: {val_ratio}")
    print(f"  Train-only z-score normalization: ✓")
    print(f"  Class-weighted loss: ✓")
    print(f"  Architecture-specific hyperparameters: ✓")
    print(f"  Graph preprocessing (undirected + self-loops): ✓")

    print("\nUsage Example:")
    print("  # Access embeddings for baseline models")
    print("  X_enhanced = np.concatenate([X_original, sage_embeddings], axis=1)")
    print("  # Train baseline with enhanced features")
    print("  rf_enhanced = RandomForest().fit(X_enhanced[train], y[train])")

    return df_results, embeddings_dict


# ============================================
# USAGE EXAMPLE
# ============================================

if __name__ == "__main__":
    # Fold 1 of a 5-fold stratified split, with 10% of the training fold held
    # out for early stopping / threshold tuning; embeddings are also published
    # as module-level globals.
    df_results, embeddings = run_elliptic_kfold_with_embeddings(
        n_splits=5, fold_id=1, val_ratio=0.10,
        epochs=1000, patience=50,
        random_state=42, save_embeddings=True,
    )

    # The embeddings are now in the returned `embeddings` dict and in the
    # globals gcn_embeddings, skip_gcn_embeddings, gat_embeddings,
    # gatv2_embeddings and sage_embeddings (one row per transaction, in the
    # row order of elliptic_txs_features.csv).
    #
    # Sketch of feeding them to a baseline. `original_features`, `labels` and
    # `train_idx` are placeholders you must build yourself:
    #   from sklearn.ensemble import RandomForestClassifier
    #   X_with_sage = np.concatenate([original_features, sage_embeddings], axis=1)
    #   rf_enhanced = RandomForestClassifier(n_estimators=100)
    #   rf_enhanced.fit(X_with_sage[train_idx], labels[train_idx])

Loading Elliptic dataset...
Loaded 203,769 transactions and 234,355 edges

K-Fold Configuration (Fold 1/5):
  Train: 33,525 | Val: 3,726 | Test: 9,313
  Train class balance: 3,272 illicit / 33,525 total (9.8%)
  Graph: 203,769 nodes, 672,479 edges

Training GNN Models with K-Fold Split

Training GCN...
----------------------------------------
[GCN] Epoch    0 | Loss: 0.7259 | Val F1: 0.3827 | Thr: 0.47
[GCN] Epoch   20 | Loss: 0.2721 | Val F1: 0.7411 | Thr: 0.57
[GCN] Epoch   40 | Loss: 0.1940 | Val F1: 0.7791 | Thr: 0.70
[GCN] Epoch   60 | Loss: 0.1623 | Val F1: 0.8118 | Thr: 0.82
[GCN] Epoch  100 | Loss: 0.1227 | Val F1: 0.8493 | Thr: 0.82
[GCN] Epoch  120 | Loss: 0.1205 | Val F1: 0.8579 | Thr: 0.90
[GCN] Early stopping at epoch 200
✓ Saved embeddings to global variable: gcn_embeddings (203769, 64)

Training SKIP_GCN...
----------------------------------------
[SKIP_GCN] Epoch    0 | Loss: 0.7069 | Val F1: 0.5528 | Thr: 0.62
[SKIP_GCN] Epoch   20 | Loss: 0.2017 | Val F1: 0.8237 | Thr

In [2]:
"""
Baseline Models Enhancement with GNN Embeddings (K-Fold) - Optimized
=====================================================================
Updated with hyperparameters matching your high-performing baseline configuration.
"""

import pandas as pd
import numpy as np
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
from xgboost import XGBClassifier
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit
from sklearn.preprocessing import StandardScaler
import warnings
warnings.filterwarnings('ignore')

def train_baselines_with_embeddings_optimized(
    n_splits=5,
    fold_id=1,
    val_ratio=0.10,
    random_state=42,
    save_detailed_results=True
):
    """
    Train baseline models with and without GNN embeddings using optimized hyperparameters.
    
    This function uses the exact hyperparameters from your high-performing baseline models:
    - RF: max_depth=20, n_estimators=100
    - MLP: (100, 50) hidden layers with StandardScaler
    - XGBoost: Dynamic scale_pos_weight calculation
    
    Parameters:
    -----------
    n_splits : int
        Number of K-fold splits (should match GNN training)
    fold_id : int
        Which fold to use (should match GNN training)
    val_ratio : float
        Validation split ratio (should match GNN training)
    random_state : int
        Random seed for reproducibility
    save_detailed_results : bool
        Whether to save detailed results to CSV
        
    Returns:
    --------
    df_results : DataFrame
        Comprehensive results for all models and configurations
    detailed_results : dict
        Detailed results for each model/configuration
    """
    
    print("="*80)
    print("BASELINE MODELS ENHANCEMENT WITH GNN EMBEDDINGS (OPTIMIZED)")
    print("="*80)
    
    # ============================================
    # STEP 1: Load Data and Create Same K-Fold Split
    # ============================================
    
    print("\n1. Loading Elliptic dataset and creating K-fold splits...")
    
    # Load raw data
    features_path = 'elliptic_txs_features.csv'
    classes_path = 'elliptic_txs_classes.csv'
    
    # Load features
    df_features = pd.read_csv(features_path, header=None)
    col_names = ['tx_id', 'time_step']
    col_names += [f'local_{i}' for i in range(1, 94)]
    col_names += [f'aggregated_{i}' for i in range(1, 73)]
    df_features.columns = col_names
    
    df_features['tx_id'] = df_features['tx_id'].astype(int)
    df_features['time_step'] = df_features['time_step'].astype(int)
    
    # Load classes
    df_classes = pd.read_csv(classes_path)
    df_features = df_features.merge(
        df_classes.rename(columns={'txId': 'tx_id'}),
        on='tx_id',
        how='left'
    )
    
    # Extract features and labels
    feature_cols = [col for col in df_features.columns 
                   if col.startswith(('local_', 'aggregated_'))]
    X = df_features[feature_cols].values
    
    # Create labels: 0=licit, 1=illicit, -1=unknown
    y = np.full(len(df_features), -1, dtype=int)
    class_values = df_features['class'].astype(str).values
    y[class_values == '2'] = 0  # licit
    y[class_values == '1'] = 1  # illicit
    
    # Create same K-fold split as GNN training
    labeled_indices = np.where(y != -1)[0]
    y_labeled = y[labeled_indices]
    
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    splits = list(skf.split(np.zeros_like(y_labeled), y_labeled))
    
    train_fold_idx, test_fold_idx = splits[fold_id - 1]
    train_labeled = labeled_indices[train_fold_idx]
    test_labeled = labeled_indices[test_fold_idx]
    
    # Create validation split if needed
    if val_ratio > 0:
        sss = StratifiedShuffleSplit(n_splits=1, test_size=val_ratio, random_state=random_state)
        y_train_labeled = y[train_labeled]
        train_keep_idx, val_idx = next(sss.split(np.zeros_like(y_train_labeled), y_train_labeled))
        
        final_train = train_labeled[train_keep_idx]
        val_indices = train_labeled[val_idx]
    else:
        final_train = train_labeled
        val_indices = np.array([], dtype=int)
    
    # Create masks
    N = len(df_features)
    train_mask = np.zeros(N, dtype=bool)
    val_mask = np.zeros(N, dtype=bool)
    test_mask = np.zeros(N, dtype=bool)
    
    train_mask[final_train] = True
    if val_indices.size > 0:
        val_mask[val_indices] = True
    test_mask[test_labeled] = True
    
    # Calculate class balance for XGBoost
    y_train = y[train_mask]
    pos_count = (y_train == 1).sum()
    neg_count = (y_train == 0).sum()
    scale_pos_weight = float(neg_count) / float(max(pos_count, 1))
    
    print(f"  Train: {train_mask.sum():,} | Val: {val_mask.sum():,} | Test: {test_mask.sum():,}")
    print(f"  Features shape: {X.shape}")
    print(f"  Class balance - Illicit: {pos_count:,} / {pos_count+neg_count:,} ({pos_count/(pos_count+neg_count)*100:.1f}%)")
    print(f"  XGBoost scale_pos_weight: {scale_pos_weight:.2f}")
    
    # ============================================
    # STEP 2: Collect GNN Embeddings
    # ============================================
    
    print("\n2. Collecting GNN embeddings from global variables...")
    
    embeddings_dict = {}
    gnn_models = ['gcn', 'skip_gcn', 'gat', 'gatv2', 'sage']
    
    for model_name in gnn_models:
        var_name = f"{model_name}_embeddings"
        if var_name in globals():
            embeddings_dict[model_name] = globals()[var_name]
            print(f"  ✓ Found {var_name}: shape {embeddings_dict[model_name].shape}")
        else:
            print(f"  ✗ Warning: {var_name} not found in global variables")
    
    if not embeddings_dict:
        raise ValueError("No GNN embeddings found! Please run GNN training first.")
    
    # ============================================
    # STEP 3: Define Optimized Baseline Models
    # ============================================
    
    def get_baseline_models():
        """Return freshly constructed, unfitted baseline model configurations.

        Every call builds brand-new estimator instances, so calling this
        function again is how the training loop obtains an unfitted copy of
        a model.

        Returns
        -------
        dict
            Maps display name -> ``{'model': estimator, 'needs_scaling': bool}``.
            ``needs_scaling`` tells ``train_and_evaluate`` whether to apply a
            ``StandardScaler`` before fitting; only the MLP sets it.

        Notes
        -----
        Reads ``random_state`` and ``scale_pos_weight`` from the enclosing
        scope.
        """
        return {
            'Logistic Regression': {
                'model': LogisticRegression(
                    max_iter=1000,
                    class_weight='balanced',
                    random_state=random_state
                ),
                'needs_scaling': False
            },
            'Random Forest': {
                'model': RandomForestClassifier(
                    n_estimators=100,
                    max_depth=20,
                    class_weight='balanced',
                    random_state=random_state,
                    n_jobs=-1  # use all available CPU cores
                ),
                'needs_scaling': False
            },
            'MLP': {
                'model': MLPClassifier(
                    hidden_layer_sizes=(100, 50),
                    activation='relu',
                    solver='adam',
                    alpha=0.001,
                    max_iter=500,
                    # Hold out validation_fraction of the training rows and
                    # stop once the validation score stops improving.
                    early_stopping=True,
                    validation_fraction=0.1,
                    random_state=random_state
                ),
                # MLPs are sensitive to feature scale, so standardise first.
                'needs_scaling': True
            },
            'XGBoost': {
                'model': XGBClassifier(
                    n_estimators=100,
                    max_depth=6,
                    learning_rate=0.1,
                    # Up-weights the positive class, mirroring
                    # class_weight='balanced' on the sklearn models. Defined in
                    # the enclosing scope; presumably the negative/positive
                    # ratio of the training labels (the logged 9.25 matches
                    # 30,253 / 3,272) -- TODO confirm.
                    scale_pos_weight=scale_pos_weight,
                    # `use_label_encoder` is intentionally not passed: it was
                    # deprecated in XGBoost 1.6 and removed in 2.0, where it
                    # only triggers an "unused parameter" warning.
                    eval_metric='logloss',
                    random_state=random_state
                ),
                'needs_scaling': False
            }
        }
    
    # ============================================
    # STEP 4: Training Function
    # ============================================
    
    def train_and_evaluate(X_train, y_train, X_test, y_test, model_config):
        """Fit the configured estimator and score it on the held-out split.

        Parameters
        ----------
        X_train, y_train : array-like
            Training features and labels.
        X_test, y_test : array-like
            Evaluation features and labels.
        model_config : dict
            Entry from ``get_baseline_models()``. ``'model'`` is the estimator
            (fitted in place here); ``'needs_scaling'`` says whether to
            standardise the features first.

        Returns
        -------
        dict
            Keys, in order: 'Precision', 'Illicit Recall', 'F1' (all for the
            positive class, scikit-learn's default ``pos_label=1``),
            'Accuracy', and 'MicroF1' (micro-averaged F1).
        """
        estimator = model_config['model']
        scale_first = model_config['needs_scaling']

        if scale_first:
            # Fit the scaler on the training split only, then reuse its
            # statistics on the test split to avoid leakage.
            standardiser = StandardScaler()
            X_train = standardiser.fit_transform(X_train)
            X_test = standardiser.transform(X_test)

        estimator.fit(X_train, y_train)
        y_hat = estimator.predict(X_test)

        # zero_division=0 returns 0 (without a warning) when a score's
        # denominator is zero, e.g. when no positives are predicted.
        scorers = (
            ('Precision', precision_score, {'zero_division': 0}),
            ('Illicit Recall', recall_score, {'zero_division': 0}),
            ('F1', f1_score, {'zero_division': 0}),
            ('Accuracy', accuracy_score, {}),
            ('MicroF1', f1_score, {'average': 'micro'}),
        )
        return {label: score_fn(y_test, y_hat, **kwargs)
                for label, score_fn, kwargs in scorers}
    
    # ============================================
    # STEP 5: Train All Configurations
    # ============================================
    
    print("\n3. Training baseline models with optimized hyperparameters...")
    print("-" * 60)
    
    all_results = []
    detailed_results = {}
    
    # For each baseline model
    for model_name, model_config in get_baseline_models().items():
        print(f"\n{model_name}:")
        
        # 1. Train baseline (features only - AF)
        print(f"  Training baseline (AF only)...")
        X_train = X[train_mask]
        y_train = y[train_mask]
        X_test = X[test_mask]
        y_test = y[test_mask]
        
        metrics = train_and_evaluate(X_train, y_train, X_test, y_test, model_config)
        
        result_row = {
            'Method': f"{model_name}^AF",
            **metrics
        }
        all_results.append(result_row)
        detailed_results[f"{model_name}_AF"] = metrics
        
        print(f"    F1: {metrics['F1']:.4f} | Precision: {metrics['Precision']:.4f} | Recall: {metrics['Illicit Recall']:.4f}")
        
        # 2. Train with each GNN's embeddings (AF+NE)
        for gnn_name, embeddings in embeddings_dict.items():
            gnn_display = gnn_name.upper().replace('_', '-')
            print(f"  Training with {gnn_display} embeddings (AF+NE)...")
            
            # Concatenate features with embeddings
            X_enhanced = np.concatenate([X, embeddings], axis=1)
            
            X_train_enh = X_enhanced[train_mask]
            X_test_enh = X_enhanced[test_mask]
            
            # Get fresh model instance
            model_config_fresh = get_baseline_models()[model_name]
            
            metrics = train_and_evaluate(X_train_enh, y_train, X_test_enh, y_test, 
                                        model_config_fresh)
            
            result_row = {
                'Method': f"{model_name}^AF+NE ({gnn_display})",
                **metrics
            }
            all_results.append(result_row)
            detailed_results[f"{model_name}_AF+NE_{gnn_name}"] = metrics
            
            print(f"    F1: {metrics['F1']:.4f} | Precision: {metrics['Precision']:.4f} | Recall: {metrics['Illicit Recall']:.4f}")
    
    # ============================================
    # STEP 6: Create Results DataFrame
    # ============================================
    
    df_results = pd.DataFrame(all_results)
    
    # Round numeric columns
    numeric_cols = ['Precision', 'Illicit Recall', 'F1', 'Accuracy', 'MicroF1']
    for col in numeric_cols:
        df_results[col] = df_results[col].round(4)
    
    # ============================================
    # STEP 7: Display Results in Paper Format
    # ============================================
    
    print("\n" + "="*80)
    print("TABLE 1.1 FORMAT - ILLICIT CLASSIFICATION RESULTS")
    print("="*80)
    
    # Display columns matching the paper
    display_cols = ['Method', 'Precision', 'Illicit Recall', 'F1']
    df_display = df_results[display_cols].copy()
    
    print("\nLogistic Regression:")
    print("-" * 60)
    lr_results = df_display[df_display['Method'].str.startswith('Logistic')]
    print(lr_results.to_string(index=False))
    
    print("\nRandom Forest:")
    print("-" * 60)
    rf_results = df_display[df_display['Method'].str.startswith('Random')]
    print(rf_results.to_string(index=False))
    
    print("\nMLP:")
    print("-" * 60)
    mlp_results = df_display[df_display['Method'].str.startswith('MLP')]
    print(mlp_results.to_string(index=False))
    
    print("\nXGBoost:")
    print("-" * 60)
    xgb_results = df_display[df_display['Method'].str.startswith('XGBoost')]
    print(xgb_results.to_string(index=False))
    
    # ============================================
    # STEP 8: Summary of Improvements
    # ============================================
    
    print("\n" + "="*80)
    print("SUMMARY - ENHANCEMENT FROM GNN EMBEDDINGS")
    print("="*80)
    
    baseline_models = ['Logistic Regression', 'Random Forest', 'MLP', 'XGBoost']
    
    for baseline in baseline_models:
        # Get baseline F1
        baseline_f1 = df_results[df_results['Method'] == f"{baseline}^AF"]['F1'].values[0]
        
        # Get best enhanced F1
        enhanced_results = df_results[df_results['Method'].str.startswith(f"{baseline}^AF+NE")]
        if not enhanced_results.empty:
            best_idx = enhanced_results['F1'].idxmax()
            best_f1 = enhanced_results.loc[best_idx, 'F1']
            best_method = enhanced_results.loc[best_idx, 'Method']
            best_gnn = best_method.split('(')[1].replace(')', '')
            
            improvement = ((best_f1 - baseline_f1) / baseline_f1) * 100
            print(f"{baseline:20} | Baseline: {baseline_f1:.4f} → Best: {best_f1:.4f} ({best_gnn}) | +{improvement:.1f}%")
    
    # ============================================
    # STEP 9: Complete Table (All Results)
    # ============================================
    
    print("\n" + "="*80)
    print("COMPLETE RESULTS TABLE")
    print("="*80)
    print(df_display.to_string(index=False))
    
    if save_detailed_results:
        df_results.to_csv('baseline_enhancement_results_optimized.csv', index=False)
        print("\n✓ Results saved to 'baseline_enhancement_results_optimized.csv'")
    
    return df_results, detailed_results


# ============================================
# USAGE
# ============================================

if __name__ == "__main__":
    # Expects the GNN embedding globals (gcn_embeddings, skip_gcn_embeddings,
    # gat_embeddings, gatv2_embeddings, sage_embeddings) to already exist,
    # e.g. from an earlier run_elliptic_kfold_with_embeddings() call.
    # fold_id/random_state should match that GNN run so the splits agree.
    df_results, detailed_results = train_baselines_with_embeddings_optimized(
        n_splits=5, fold_id=1, val_ratio=0.10, random_state=42
    )

    # Reference numbers printed for manual comparison only; nothing here
    # checks them against df_results.
    print("\n".join([
        "\n" + "=" * 80,
        "Expected Performance (based on your K-fold baseline results):",
        "  Random Forest: ~0.93-0.94 F1 baseline",
        "  XGBoost: ~0.92-0.93 F1 baseline",
        "  MLP: ~0.89-0.90 F1 baseline",
        "  Logistic Regression: ~0.61-0.62 F1 baseline",
        "\nWith GNN embeddings, expect 5-70% relative improvement depending on model.",
        "=" * 80,
    ]))

BASELINE MODELS ENHANCEMENT WITH GNN EMBEDDINGS (OPTIMIZED)

1. Loading Elliptic dataset and creating K-fold splits...
  Train: 33,525 | Val: 3,726 | Test: 9,313
  Features shape: (203769, 165)
  Class balance - Illicit: 3,272 / 33,525 (9.8%)
  XGBoost scale_pos_weight: 9.25

2. Collecting GNN embeddings from global variables...
  ✓ Found gcn_embeddings: shape (203769, 64)
  ✓ Found skip_gcn_embeddings: shape (203769, 64)
  ✓ Found gat_embeddings: shape (203769, 64)
  ✓ Found gatv2_embeddings: shape (203769, 64)
  ✓ Found sage_embeddings: shape (203769, 64)

3. Training baseline models with optimized hyperparameters...
------------------------------------------------------------

Logistic Regression:
  Training baseline (AF only)...
    F1: 0.6146 | Precision: 0.4577 | Recall: 0.9351
  Training with GCN embeddings (AF+NE)...
    F1: 0.8434 | Precision: 0.7734 | Recall: 0.9274
  Training with SKIP-GCN embeddings (AF+NE)...
    F1: 0.8944 | Precision: 0.8724 | Recall: 0.9175
  Training w