In [1]:
"""
=================================================================================
INDIVIDUAL METHOD TESTER
=================================================================================
Quick testing script for individual method-dataset combinations.
Provides detailed output for debugging and analysis.
=================================================================================
"""

import sys
from pathlib import Path
import numpy as np
import pandas as pd
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# Setup: make the project root importable so `src.*` resolves. If the current
# working directory path contains "notebooks" anywhere, its parent is used as
# the project root; otherwise the working directory itself is.
PROJECT_ROOT = Path.cwd()
if 'notebooks' in str(PROJECT_ROOT):
    PROJECT_ROOT = PROJECT_ROOT.parent
sys.path.insert(0, str(PROJECT_ROOT))

from src.methods.method_runner import run_talent_method
from sklearn.metrics import (
    roc_auc_score, accuracy_score, f1_score, confusion_matrix,
    mean_squared_error, r2_score, mean_absolute_error
)

# =============================================================================
# CONFIGURATION FUNCTION
# =============================================================================

def _banner(title):
    """Print *title* between two 80-character '=' rules, preceded by a blank line."""
    print(f"\n{'='*80}")
    print(title)
    print(f"{'='*80}")


def _as_prediction_array(y_pred):
    """Return *y_pred* as an ndarray, collapsing a single-column (n, 1) matrix to shape (n,).

    Without the collapse, a (n, 1) column broadcasts against 1-D targets
    (e.g. in ``y_true - y_pred``) into an (n, n) matrix.
    """
    arr = np.asarray(y_pred)
    if arr.ndim == 2 and arr.shape[1] == 1:
        arr = arr[:, 0]
    return arr


def _positive_scores(y_pred):
    """Return the positive-class score vector from 1-D scores or a per-class probability matrix."""
    arr = _as_prediction_array(y_pred)
    # After the collapse above, a 2-D array has >= 2 columns; column 1 is taken as the
    # positive class. NOTE(review): assumes binary 0/1 labels with columns ordered by label.
    return arr[:, 1] if arr.ndim > 1 else arr


def _safe_auc(y_true, y_score):
    """Return ROC AUC, or NaN when *y_true* contains fewer than two distinct classes.

    ``roc_auc_score`` raises ValueError in that case, which is easy to hit with a
    small ``row_limit`` on an imbalanced dataset.
    """
    if np.unique(y_true).size < 2:
        return float('nan')
    return roc_auc_score(y_true, y_score)


def _report_classification(y_true, y_pred, threshold):
    """Print pooled classification metrics, confusion matrix and class balance; return the AUC."""
    y_score = _positive_scores(y_pred)
    # One cut-off for every prediction shape, so these metrics and the preview's
    # Pred_Class column always agree.
    y_class = (y_score > threshold).astype(int)

    auc = _safe_auc(y_true, y_score)
    acc = accuracy_score(y_true, y_class)
    # zero_division=0 makes the "no predicted positives" case explicit (F1 = 0).
    f1 = f1_score(y_true, y_class, average='binary', zero_division=0)
    # labels=[0, 1] keeps the matrix 2x2 even when a class is absent from the data.
    cm = confusion_matrix(y_true, y_class, labels=[0, 1])

    print("\n📊 Classification Metrics (All Folds):")
    print(f"  AUC:      {auc:.4f}")
    print(f"  Accuracy: {acc:.4f}")
    print(f"  F1 Score: {f1:.4f}")

    print("\n  Confusion Matrix:")
    print(f"    [[TN={cm[0,0]:4d}  FP={cm[0,1]:4d}]")
    print(f"     [FN={cm[1,0]:4d}  TP={cm[1,1]:4d}]]")

    n_total = len(y_true)
    n_positive = (y_true == 1).sum()
    n_negative = (y_true == 0).sum()
    print("\n  Class Distribution:")
    print(f"    Negative (0): {n_negative:5d} ({n_negative/n_total*100:.1f}%)")
    print(f"    Positive (1): {n_positive:5d} ({n_positive/n_total*100:.1f}%)")
    return auc


def _report_regression(y_true, y_pred):
    """Print pooled regression metrics and the target distribution; return R²."""
    mse = mean_squared_error(y_true, y_pred)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)

    print("\n📊 Regression Metrics (All Folds):")
    print(f"  R²:   {r2:.4f}")
    print(f"  RMSE: {rmse:.4f}")
    print(f"  MAE:  {mae:.4f}")
    print(f"  MSE:  {mse:.4f}")

    print("\n  Target Distribution:")
    print(f"    Mean: {y_true.mean():.4f}")
    print(f"    Std:  {y_true.std():.4f}")
    print(f"    Min:  {y_true.min():.4f}")
    print(f"    Max:  {y_true.max():.4f}")
    return r2


def _fold_breakdown(results, task):
    """Build a per-fold table: AUC (task 'pd') or R²/RMSE, plus sample count and train time."""
    rows = []
    for fold_id, fold_result in results.items():
        y_true = np.asarray(fold_result['y_true']).reshape(-1)
        y_pred = _as_prediction_array(fold_result['y_pred'])
        row = {'Fold': fold_id}
        if task == 'pd':
            row['AUC'] = _safe_auc(y_true, _positive_scores(y_pred))
        else:
            row['R²'] = r2_score(y_true, y_pred)
            row['RMSE'] = np.sqrt(mean_squared_error(y_true, y_pred))
        row['Samples'] = len(y_true)
        # A missing train_time becomes NaN, which pandas skips in the mean/sum printed later.
        row['Time (s)'] = fold_result.get('train_time', np.nan)
        rows.append(row)
    return pd.DataFrame(rows)


def _prediction_preview(all_y_true, all_y_pred, all_fold_ids, task, n_rows, threshold):
    """Return the first *n_rows* pooled predictions with absolute error (and correctness for 'pd')."""
    if task == 'pd':
        preview_pred = _positive_scores(all_y_pred)[:n_rows]
    else:
        preview_pred = all_y_pred[:n_rows]
    true_vals = all_y_true[:n_rows]

    preview_df = pd.DataFrame({
        'Fold': all_fold_ids[:n_rows],
        'True': true_vals,
        'Predicted': preview_pred,
        'Error': np.abs(true_vals - preview_pred),
    })
    if task == 'pd':
        preview_df['Pred_Class'] = (preview_pred > threshold).astype(int)
        preview_df['Correct'] = (preview_df['True'] == preview_df['Pred_Class']).map({True: '✓', False: '✗'})
    return preview_df


def _print_dataset_info(dataset, task, info, n_samples):
    """Print dataset metadata from a fold's 'info' mapping; missing keys are shown as 'N/A'."""
    print(f"\n  Dataset: {dataset}")
    print(f"  Task: {info.get('task_type', 'N/A')}")
    print(f"  Total samples: {n_samples}")
    print(f"  Numerical features: {info.get('n_num_features', 'N/A')}")
    print(f"  Categorical features: {info.get('n_cat_features', 'N/A')}")
    # `or 0` also covers keys that are present but set to None.
    n_features = (info.get('n_num_features') or 0) + (info.get('n_cat_features') or 0)
    print(f"  Total features: {n_features}")
    if task == 'pd':
        print(f"  Number of classes: {info.get('n_classes', 'N/A')}")


def test_method(
    task='pd',
    dataset='0001.gmsc',
    method='xgboost',
    tune=False,
    # Data params
    test_size=0.2,
    val_size=0.2,
    cv_splits=3,
    seed=42,
    row_limit=None,
    # Training params
    max_epoch=50,
    batch_size=256,
    n_trials=20,
    # Display params
    n_preview_rows=10,
    verbose=False,
    threshold=0.5,
):
    """
    Test a single method on a dataset with detailed output.

    Calls ``run_talent_method`` once, then prints pooled metrics, a per-fold
    breakdown, a prediction preview, dataset info and a short summary.

    Args:
        task: 'pd' (classification) or 'lgd' (regression); any value other
            than 'pd' is evaluated as regression.
        dataset: Dataset name (e.g., '0001.gmsc', '0001.heloc')
        method: TALENT method name (e.g., 'xgboost', 'mlp', 'tabpfn')
        tune: Whether to use hyperparameter optimization
        test_size: Test set fraction
        val_size: Validation set fraction
        cv_splits: Number of cross-validation folds
        seed: Random seed
        row_limit: Optional row limit for quick testing
        max_epoch: Max epochs for deep methods
        batch_size: Batch size for deep methods
        n_trials: HPO trials if tune=True
        n_preview_rows: Number of prediction rows to show
        verbose: Forwarded to run_talent_method; also dumps the raw results
            dict at the end.
        threshold: Positive-class probability cut-off for task='pd'. It drives
            accuracy, F1, the confusion matrix and the preview's Pred_Class.

    Returns:
        results: Dict mapping fold id to the fold result from
            ``run_talent_method``. This function reads each fold's 'y_true'
            and 'y_pred', plus 'train_time' and 'info' when present. The value
            is returned unchanged, including when it is empty. Returns None if
            ``run_talent_method`` raised.
    """
    print("="*80)
    print(f" TESTING: {method.upper()} on {dataset} ({task.upper()})")
    print("="*80)

    # Configuration summary
    hpo_desc = f"Yes ({n_trials} trials)" if tune else "No"
    print("\n📋 Configuration:")
    print(f"  Dataset: {dataset}")
    print(f"  Method: {method}")
    print(f"  HPO: {hpo_desc}")
    print(f"  CV Splits: {cv_splits}")
    print(f"  Row Limit: {row_limit if row_limit else 'None (full dataset)'}")
    print(f"  Seed: {seed}")
    if task == 'pd':
        print(f"  Threshold: {threshold}")

    # Only the runner call is guarded, so bugs in the reporting code below
    # surface as real tracebacks instead of a generic "Error" line.
    print(f"\n🔄 Running {method}...")
    start_time = pd.Timestamp.now()
    try:
        results = run_talent_method(
            task=task,
            dataset=dataset,
            test_size=test_size,
            val_size=val_size,
            cv_splits=cv_splits,
            seed=seed,
            row_limit=row_limit,
            method=method,
            max_epoch=max_epoch,
            batch_size=batch_size,
            tune=tune,
            n_trials=n_trials,
            verbose=verbose,
        )
    except Exception as e:
        # Top-level boundary of an interactive tester: report and return None.
        print(f"✗ Error: {e}")
        import traceback
        traceback.print_exc()
        return None
    elapsed = (pd.Timestamp.now() - start_time).total_seconds()
    print(f"✓ Completed in {elapsed:.1f}s")

    if not results:
        print("\n✗ No fold results returned; nothing to evaluate.")
        return results

    # Pool predictions across folds (fold order follows the results dict).
    fold_ids = list(results)
    y_true_parts = [np.asarray(results[f]['y_true']).reshape(-1) for f in fold_ids]
    y_pred_parts = [_as_prediction_array(results[f]['y_pred']) for f in fold_ids]
    all_y_true = np.concatenate(y_true_parts)
    all_y_pred = np.concatenate(y_pred_parts)
    all_fold_ids = []
    for fold_id, part in zip(fold_ids, y_true_parts):
        all_fold_ids.extend([fold_id] * len(part))

    if all_y_true.size == 0:
        print("\n✗ Fold results contain no predictions; nothing to evaluate.")
        return results

    _banner(" RESULTS SUMMARY")
    if task == 'pd':
        headline = ('AUC', _report_classification(all_y_true, all_y_pred, threshold))
    else:
        headline = ('R²', _report_regression(all_y_true, all_y_pred))

    _banner(" PER-FOLD BREAKDOWN")
    fold_df = _fold_breakdown(results, task)
    print(f"\n{fold_df.to_string(index=False)}")
    print(f"\n  Average training time: {fold_df['Time (s)'].mean():.2f}s per fold")
    print(f"  Total training time: {fold_df['Time (s)'].sum():.2f}s")

    _banner(f" PREDICTION PREVIEW (First {n_preview_rows} samples)")
    preview_df = _prediction_preview(
        all_y_true, all_y_pred, all_fold_ids, task, n_preview_rows, threshold
    )
    print(f"\n{preview_df.to_string(index=False)}")

    _banner(" DATASET INFO")
    info = next(iter(results.values())).get('info') or {}
    _print_dataset_info(dataset, task, info, len(all_y_true))

    _banner(" SUMMARY")
    metric_name, metric_value = headline
    print(f"\n✅ {method} achieved {metric_name} = {metric_value:.4f} on {dataset}")
    print(f"   Total time: {elapsed:.1f}s")
    print(f"   HPO: {'Yes' if tune else 'No'}")

    print(f"\n{'='*80}\n")
    if verbose:
        # The raw dump includes every prediction array; only useful when debugging.
        print(results)
    return results


# =============================================================================
# EXAMPLE USAGE
# =============================================================================

# Example 1: quick classification (PD) smoke test without HPO.
if __name__ == "__main__":
    demo_kwargs = {
        'task': 'pd',
        'dataset': '0001.gmsc',
        'method': 'tabpfn',
        'tune': False,
        'row_limit': 1000,  # small subset for a fast run
        'cv_splits': 3,
    }
    results_pd = test_method(**demo_kwargs)


 TESTING: TABPFN on 0001.gmsc (PD)

📋 Configuration:
  Dataset: 0001.gmsc
  Method: tabpfn
  HPO: No 
  CV Splits: 3
  Row Limit: 1000
  Seed: 42

🔄 Running tabpfn...
✓ Completed in 2.8s

 RESULTS SUMMARY

📊 Classification Metrics (All Folds):
  AUC:      0.8081
  Accuracy: 0.9410
  F1 Score: 0.0000

  Confusion Matrix:
    [[TN= 941  FP=   2]
     [FN=  57  TP=   0]]

  Class Distribution:
    Negative (0):   943 (94.3%)
    Positive (1):    57 (5.7%)

 PER-FOLD BREAKDOWN

 Fold      AUC  Samples  Time (s)
    1 0.828404      334       0.0
    2 0.791988      333       0.0
    3 0.822159      333       0.0

  Average training time: 0.00s per fold
  Total training time: 0.00s

 PREDICTION PREVIEW (First 10 samples)

 Fold  True  Predicted    Error  Pred_Class Correct
    1     1   0.170336 0.829664           0       ✗
    1     0   0.080200 0.080200           0       ✓
    1     0   0.033210 0.033210           0       ✓
    1     0   0.032777 0.032777           0      