# GR00T Probe Evaluation Notebook

This notebook evaluates a trained GR00T probe model: it checks that the required files exist, runs the evaluation from `evaluate_probe`, and summarizes the resulting metrics and plots.

## Features:
- **Detailed Metrics**: MSE, RMSE, MAE, correlation analysis
- **Visualizations**: Training curves and prediction plots
- **Feature Type Support**: Evaluate `mean_pooled` or `last_vector` models
- **Performance Analysis**: Per-dimension and overall performance
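
The metrics listed above follow their standard definitions. The sketch below is a dependency-free illustration of how the overall and per-dimension values can be computed. The function name `regression_metrics` and the list-of-rows input layout are assumptions made for illustration; this is not the code inside `evaluate_probe`.

```python
import math


def regression_metrics(preds, targets):
    """Compute overall MSE/RMSE/MAE and per-dimension Pearson correlation.

    preds, targets: equal-shaped lists of rows, one row per sample and
    one column per output dimension (hypothetical layout).
    """
    n_samples = len(preds)
    n_dims = len(preds[0])

    # Overall errors average over every sample and every dimension.
    errors = [p - t for pr, tr in zip(preds, targets) for p, t in zip(pr, tr)]
    mse = sum(e * e for e in errors) / len(errors)
    mae = sum(abs(e) for e in errors) / len(errors)

    correlations = []
    for d in range(n_dims):
        x = [row[d] for row in preds]
        y = [row[d] for row in targets]
        mx, my = sum(x) / n_samples, sum(y) / n_samples
        cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
        var_x = sum((a - mx) ** 2 for a in x)
        var_y = sum((b - my) ** 2 for b in y)
        denom = math.sqrt(var_x * var_y)
        # Correlation is undefined for a constant column; report NaN.
        correlations.append(cov / denom if denom > 0 else float("nan"))

    return {"mse": mse, "rmse": math.sqrt(mse), "mae": mae,
            "correlations": correlations}
```

A high correlation does not imply a low error: a dimension can track its target perfectly while being off by a constant scale, which is why both kinds of metric are reported.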

## Requirements:
- Trained model: `best_probe_model.pth` in `PROBE_OUTPUT_DIR` (produced by training)
- Processed data: `probe_training_data_150k_processed.parquet`
- Training history: `training_history.pkl` in `PROBE_OUTPUT_DIR` (for plots)

## 🔧 Configuration

**Important**: Make sure `FEATURE_TYPE` matches the one used during training!
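
For orientation, the two option names suggest two ways of reducing a sequence of hidden vectors to a single feature vector. The sketch below illustrates that reading. It is an assumption based on the names alone, and `pool_features` is a hypothetical helper, not the extraction code used to build this dataset.

```python
def pool_features(hidden_states, feature_type):
    """Reduce a sequence of vectors (list of equal-length lists) to one vector.

    Assumed semantics, inferred from the option names only:
      "mean_pooled" -> element-wise mean over the sequence
      "last_vector" -> the final vector in the sequence
    """
    if not hidden_states:
        raise ValueError("hidden_states must contain at least one vector")
    if feature_type == "mean_pooled":
        n = len(hidden_states)
        return [sum(column) / n for column in zip(*hidden_states)]
    if feature_type == "last_vector":
        return list(hidden_states[-1])
    raise ValueError(f"Unknown feature type: {feature_type!r}")


seq = [[1.0, 2.0], [3.0, 4.0], [5.0, 9.0]]
print(pool_features(seq, "mean_pooled"))  # [3.0, 5.0]
print(pool_features(seq, "last_vector"))  # [5.0, 9.0]
```

Under this reading both features have the same dimensionality but different values, so a probe trained on one would load and run on the other without any error while producing unreliable predictions.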

In [None]:
# Evaluation Configuration
FEATURE_TYPE = "mean_pooled"  # Options: "mean_pooled" or "last_vector" (MUST match training!)
DATA_PATH = "/content/drive/MyDrive/probe_training_data/probe_training_data_150k_processed.parquet"  # Path to processed data

# Use mounted drive structure
PROBE_OUTPUT_DIR = f"/content/drive/MyDrive/probes/{FEATURE_TYPE}"
MODEL_PATH = f"{PROBE_OUTPUT_DIR}/best_probe_model.pth"  # Path to trained model

print(f"📊 Configuration:")
print(f"   • Feature Type: {FEATURE_TYPE}")
print(f"   • Data Path: {DATA_PATH}")
print(f"   • Model Path: {MODEL_PATH}")
print(f"\n⚠️  Make sure FEATURE_TYPE matches the one used during training!")

## ✅ Check Required Files

Let's verify all required files exist before evaluation:

In [None]:
import os

# Check required files  
required_files = [
    (DATA_PATH, "Processed training data"),
    (MODEL_PATH, "Trained model"),
    (f"{PROBE_OUTPUT_DIR}/training_history.pkl", "Training history (for plots)")
]

print("📁 Required Files Check:")
all_files_exist = True

for file_path, description in required_files:
    if os.path.exists(file_path):
        size_mb = os.path.getsize(file_path) / (1024 * 1024)
        print(f"   ✅ {description}: {file_path} ({size_mb:.2f} MB)")
    else:
        print(f"   ❌ {description}: {file_path} (not found)")
        all_files_exist = False

if all_files_exist:
    print("\n🎉 All files found! Ready for evaluation.")
else:
    print("\n⚠️  Some files are missing. Please:")
    print("   1. Run the data extraction notebook if data is missing")
    print("   2. Run train_probe.ipynb if model is missing")

## 🔍 Run Evaluation

Execute the evaluation with the configured parameters:

In [None]:
# Import and run evaluation
import sys
import os

# Add current directory to path
sys.path.append(os.getcwd())

# Import evaluation function
from evaluate_probe import main as evaluate_main

print("🔍 Starting probe evaluation...")
print("=" * 60)

# Run evaluation with specified parameters
evaluate_main(
    feature_type=FEATURE_TYPE,
    data_path=DATA_PATH,
    model_path=MODEL_PATH
)

print("=" * 60)
print("🎉 Evaluation completed!")
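
Because the cell above imports `evaluate_probe` from the current working directory, a missing file surfaces only as a bare `ModuleNotFoundError`. A minimal standard-library sketch for checking importability first; `module_available` is a hypothetical helper name.

```python
import importlib.util


def module_available(name):
    """Return True if the top-level module `name` can be found on sys.path."""
    return importlib.util.find_spec(name) is not None
```

Calling `module_available("evaluate_probe")` after the `sys.path.append(os.getcwd())` line returns `False` when the notebook is running from a directory that does not contain `evaluate_probe.py`.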

## 📊 Check Evaluation Results

After evaluation, check the generated output files:

In [None]:
# Check evaluation output files
output_files = [
    f"{PROBE_OUTPUT_DIR}/evaluation_metrics.pkl",
    f"{PROBE_OUTPUT_DIR}/training_curves.png", 
    f"{PROBE_OUTPUT_DIR}/predictions_vs_targets.png"
]

print("📁 Evaluation Output Files:")
for file_path in output_files:
    if os.path.exists(file_path):
        size_kb = os.path.getsize(file_path) / 1024
        print(f"   ✅ {file_path} ({size_kb:.1f} KB)")
    else:
        print(f"   ❌ {file_path} (not found)")

print(f"\n🎯 Feature type evaluated: {FEATURE_TYPE}")
print("\n📈 Generated visualizations:")
print("   • Training curves show loss over epochs")
print("   • Prediction plots compare predictions against targets")

## 🔄 Compare Feature Types

If you have models trained with both feature types, compare them:

In [None]:
# Compare both feature types (if models exist)
COMPARE_BOTH = False  # Set to True to compare both feature types

if COMPARE_BOTH:
    print("🔄 Comparing both feature types...")

    # Each feature type has its own directory, mirroring PROBE_OUTPUT_DIR
    probes_root = os.path.dirname(PROBE_OUTPUT_DIR)
    evaluated = []

    for ft in ["mean_pooled", "last_vector"]:
        model_path = f"{probes_root}/{ft}/best_probe_model.pth"

        if os.path.exists(model_path):
            print(f"\n🔍 Evaluating {ft} model...")
            print("=" * 40)

            # Outputs are expected alongside the model (as in the output
            # check cell), so the two runs should not overwrite each other
            evaluate_main(
                feature_type=ft,
                data_path=DATA_PATH,
                model_path=model_path
            )
            evaluated.append(ft)
        else:
            print(f"⚠️  Model for {ft} not found: {model_path}")

    if evaluated:
        print(f"\n🎉 Comparison complete for: {', '.join(evaluated)}")
        print(f"📊 Check {probes_root}/<feature_type>/ for outputs")
    else:
        print("\n⚠️  No models found; nothing was evaluated.")

else:
    print("ℹ️  Set COMPARE_BOTH = True to compare both feature types")
    print("   (Requires models trained with both feature types)")
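
Once both runs have produced a metrics file, the numbers can be placed side by side. The sketch below assumes each file is a pickled dict with the `mse`, `rmse`, `mae`, and `correlations` entries that the summary cell in this notebook reads. `summarize_runs` and `print_comparison` are hypothetical helper names, and the caller supplies the file paths.

```python
import os
import pickle
from statistics import fmean


def summarize_runs(metrics_paths):
    """Load metrics pickles and return one summary row per feature type.

    metrics_paths: dict mapping a label (e.g. "mean_pooled") to the path of
    a pickled dict with 'mse', 'rmse', 'mae' and 'correlations' entries.
    Missing files are skipped.
    """
    rows = {}
    for label, path in metrics_paths.items():
        if not os.path.exists(path):
            continue
        with open(path, "rb") as f:
            metrics = pickle.load(f)
        rows[label] = {
            "mse": float(metrics["mse"]),
            "rmse": float(metrics["rmse"]),
            "mae": float(metrics["mae"]),
            "avg_corr": fmean(float(c) for c in metrics["correlations"]),
        }
    return rows


def print_comparison(rows):
    """Print the summary rows as a fixed-width table."""
    print(f"{'feature type':<14}{'MSE':>10}{'RMSE':>10}{'MAE':>10}{'avg corr':>10}")
    for label, r in rows.items():
        print(f"{label:<14}{r['mse']:>10.4f}{r['rmse']:>10.4f}"
              f"{r['mae']:>10.4f}{r['avg_corr']:>10.4f}")
```

Pass a dict such as `{"mean_pooled": <path>, "last_vector": <path>}` pointing at wherever the two metrics files were saved. As with any pickle, load only files you produced yourself.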

## 📈 Quick Results Summary

Load and display key metrics from the evaluation:

In [None]:
# Quick summary of results
try:
    import pickle
    import numpy as np
    
    metrics_file = f"{PROBE_OUTPUT_DIR}/evaluation_metrics.pkl"
    if os.path.exists(metrics_file):
        with open(metrics_file, "rb") as f:
            metrics = pickle.load(f)
        
        print("📊 Quick Results Summary:")
        print(f"   • MSE: {metrics['mse']:.6f}")
        print(f"   • RMSE: {metrics['rmse']:.6f}")
        print(f"   • MAE: {metrics['mae']:.6f}")
        print(f"   • Avg Correlation: {np.mean(metrics['correlations']):.4f}")
        
        # Quality assessment
        avg_corr = np.mean(metrics['correlations'])
        if avg_corr > 0.8:
            quality = "Excellent 🌟"
        elif avg_corr > 0.6:
            quality = "Good ✅"
        elif avg_corr > 0.4:
            quality = "Moderate ⚠️"
        else:
            quality = "Poor ❌"
        
        print(f"   • Quality: {quality}")
        print(f"   • Feature Type: {FEATURE_TYPE}")
        
    else:
        print("📊 No metrics file found. Run evaluation first.")
        
except Exception as e:
    print(f"⚠️  Error loading metrics: {e}")
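
The correlation thresholds in the cell above can be factored into a small function, which makes the boundary behavior easy to check. This is a sketch: `quality_label` is a hypothetical name, and the labels omit the emoji for brevity.

```python
def quality_label(avg_corr):
    """Map an average correlation to the coarse labels used in the summary."""
    if avg_corr > 0.8:
        return "Excellent"
    if avg_corr > 0.6:
        return "Good"
    if avg_corr > 0.4:
        return "Moderate"
    return "Poor"
```

The comparisons are strict, so a value of exactly 0.8 is labeled "Good" rather than "Excellent". A NaN average (for example, when some dimension has constant predictions) fails every comparison and falls through to "Poor".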