# GR00T Probe Training Notebook

This notebook trains the GR00T probe model by calling `train_probe.main` with the parameters set in the configuration cell below.

## Features:
- **Linear Regression**: A linear probe with no hidden layers
- **Feature Types**: Choose between `mean_pooled` or `last_vector` features
- **Input Shape**: 2048-dimensional feature vectors (shape `[2048]`)
- **Easy Configuration**: Set all training parameters in the Configuration cell below
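
This notebook does not define the two feature types itself. The following is a minimal sketch of one plausible reading, assuming `mean_pooled` averages per-token hidden states and `last_vector` keeps only the final token's state. Both the interpretation and the helper functions are assumptions for illustration; the actual definitions are in `getting_started/extract_probe_training_data.ipynb`. A 4-dimensional toy input stands in for the 2048-dimensional features.

```python
# Hypothetical illustration of the two feature types. The real extraction
# happens in the data-preparation notebook and may differ from this sketch.

def mean_pooled(hidden_states):
    """Average a [num_tokens][dim] list of vectors over the token axis."""
    num_tokens = len(hidden_states)
    dim = len(hidden_states[0])
    return [sum(tok[d] for tok in hidden_states) / num_tokens for d in range(dim)]


def last_vector(hidden_states):
    """Return a copy of the final token's vector."""
    return list(hidden_states[-1])


# Three tokens with 4-dimensional states (toy size, not 2048).
hidden_states = [
    [1.0, 2.0, 3.0, 4.0],
    [3.0, 2.0, 1.0, 0.0],
    [5.0, 2.0, 2.0, 5.0],
]

print("mean_pooled:", mean_pooled(hidden_states))
print("last_vector:", last_vector(hidden_states))
```

Either way, each sample is reduced to one fixed-size vector, which is what lets a single linear layer serve as the probe.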

## Requirements:
- Processed training data: `probe_training_data_150k_processed.parquet`
  - Generated by running `getting_started/extract_probe_training_data.ipynb`

## 🔧 Configuration

Modify these parameters to customize your training:

In [None]:
# Training Configuration
FEATURE_TYPE = "mean_pooled"  # Options: "mean_pooled" or "last_vector"
DATA_PATH = "probe_training_data_150k_processed.parquet"  # Path to processed data
BATCH_SIZE = 32
NUM_EPOCHS = 100

print("📊 Configuration:")
print(f"   • Feature Type: {FEATURE_TYPE}")
print(f"   • Data Path: {DATA_PATH}")
print(f"   • Batch Size: {BATCH_SIZE}")
print(f"   • Epochs: {NUM_EPOCHS}")
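
Training can take a while, so it may help to catch a mistyped feature type or a missing data file before starting. The sketch below is one way to do that; `validate_config` and `VALID_FEATURE_TYPES` are hypothetical helpers and are not part of `train_probe`. The literal values mirror the defaults above so the block runs on its own.

```python
from pathlib import Path

# The two options listed in the configuration cell.
VALID_FEATURE_TYPES = ("mean_pooled", "last_vector")


def validate_config(feature_type, data_path, batch_size, num_epochs):
    """Return a list of problems; an empty list means the config looks usable."""
    problems = []
    if feature_type not in VALID_FEATURE_TYPES:
        problems.append(f"unknown feature type: {feature_type!r}")
    if not Path(data_path).is_file():
        problems.append(f"data file not found: {data_path}")
    if batch_size <= 0:
        problems.append(f"batch size must be positive, got {batch_size}")
    if num_epochs <= 0:
        problems.append(f"number of epochs must be positive, got {num_epochs}")
    return problems


problems = validate_config(
    "mean_pooled", "probe_training_data_150k_processed.parquet", 32, 100
)
if problems:
    for problem in problems:
        print(f"Problem: {problem}")
else:
    print("Configuration looks usable")
```

Inside the notebook, you would pass `FEATURE_TYPE`, `DATA_PATH`, `BATCH_SIZE`, and `NUM_EPOCHS` instead of the literals.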

## 🚀 Start Training

Run the cell below to start training the probe model:

In [None]:
# Import and run training
import sys
import os

# Add current directory to path
sys.path.append(os.getcwd())

# Import training function
from train_probe import main as train_main

print("🏁 Starting probe training...")
print("=" * 60)

# Run training with specified parameters
train_main(
    feature_type=FEATURE_TYPE,
    data_path=DATA_PATH,
    batch_size=BATCH_SIZE,
    num_epochs=NUM_EPOCHS
)

print("=" * 60)
print("🎉 Training completed!")

## 📊 Check Training Results

After training, check that the output files were created:

In [None]:
import os

# Check output files
output_files = [
    "probe/best_probe_model.pth",
    "probe/training_history.pkl"
]

print("📁 Output Files:")
for file_path in output_files:
    if os.path.exists(file_path):
        size_mb = os.path.getsize(file_path) / (1024 * 1024)
        print(f"   ✅ {file_path} ({size_mb:.2f} MB)")
    else:
        print(f"   ❌ {file_path} (not found)")

print(f"\n🎯 Feature type used: {FEATURE_TYPE}")
print("\n📝 Next steps:")
print("   1. Run evaluate_probe.ipynb to evaluate the trained model")
print("   2. Make sure to use the same feature type for evaluation")
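
Beyond checking that the files exist, you may want a quick look at what `probe/training_history.pkl` contains. Its structure is not documented in this notebook, so the sketch below only reports the top-level type and, if the object is a dict, the type of each value. `summarize_history` is a hypothetical helper. Only unpickle files you created yourself, and note that loading may require the libraries used at save time (for example, if tensors were stored).

```python
import os
import pickle


def summarize_history(path):
    """Load a pickled object and describe its top-level structure.

    Returns None if the file does not exist, a {key: type name} dict if the
    object is a dict, and the type name of the object otherwise.
    """
    if not os.path.exists(path):
        return None
    with open(path, "rb") as f:
        history = pickle.load(f)
    if isinstance(history, dict):
        return {key: type(value).__name__ for key, value in history.items()}
    return type(history).__name__


print(summarize_history("probe/training_history.pkl"))
```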

## 🔄 Quick Feature Type Comparison

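To compare both feature types, set `COMPARE_BOTH = True` in the cell below and run it. The cell trains one model per feature type and renames each run's output files so the second run does not overwrite the first: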

In [None]:
# Train both feature types for comparison
import os
import shutil

COMPARE_BOTH = False  # Set to True to train both feature types

if COMPARE_BOTH:
    print("🔄 Training both feature types for comparison...")

    feature_types = ["mean_pooled", "last_vector"]

    for ft in feature_types:
        print(f"\n🚀 Training with {ft} features...")
        print("=" * 50)

        # Train model (train_main is imported in the training cell above)
        train_main(
            feature_type=ft,
            data_path=DATA_PATH,
            batch_size=BATCH_SIZE,
            num_epochs=NUM_EPOCHS
        )

        # Rename output files so the next run does not overwrite them
        for name, ext in [("best_probe_model", "pth"), ("training_history", "pkl")]:
            src = f"probe/{name}.{ext}"
            if os.path.exists(src):
                shutil.move(src, f"probe/{name}_{ft}.{ext}")
            
    print("\n🎉 Both models trained! Check probe/ directory for outputs.")
else:
    print("ℹ️  Set COMPARE_BOTH = True to train both feature types")