# Van der Pol Oscillator Training Notebook

This notebook trains and evaluates models that control the Van der Pol oscillator, using supervised fine-tuning (SFT), Group Relative Policy Optimization (GRPO), or both in sequence.

**Sections:**
1. Setup & Configuration
2. Data Generation & Loading
3. Model Setup
4. SFT Training (Optional)
5. SFT + GRPO Training (Combined)
6. Model Evaluation
7. Visualization & Analysis
8. Summary

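For reference, the uncontrolled Van der Pol oscillator is $\ddot{x} - \mu(1 - x^2)\dot{x} + x = 0$. The sketch below assumes that a control input $u$ enters additively as a force, $\ddot{x} = \mu(1 - x^2)\dot{x} - x + u$. It advances the state $(x, \dot{x})$ by one classical RK4 step of size `dt`. The value of `mu`, this control model, and the integrator are all assumptions. The project's own environment (`get_system("van_der_pol")`) may define the dynamics differently.

```python
from typing import Tuple

State = Tuple[float, float]  # (position x, velocity v)


def vdp_derivative(state: State, u: float, mu: float) -> State:
    """Time derivative of (x, v) for a controlled Van der Pol oscillator.

    Assumes the control u is an additive force: x'' = mu*(1 - x^2)*x' - x + u.
    """
    x, v = state
    return v, mu * (1.0 - x * x) * v - x + u


def vdp_rk4_step(state: State, u: float, dt: float, mu: float = 1.0) -> State:
    """Advance the state by one classical 4th-order Runge-Kutta step, holding u constant."""
    x, v = state
    k1 = vdp_derivative((x, v), u, mu)
    k2 = vdp_derivative((x + 0.5 * dt * k1[0], v + 0.5 * dt * k1[1]), u, mu)
    k3 = vdp_derivative((x + 0.5 * dt * k2[0], v + 0.5 * dt * k2[1]), u, mu)
    k4 = vdp_derivative((x + dt * k3[0], v + dt * k3[1]), u, mu)
    x_next = x + dt / 6.0 * (k1[0] + 2 * k2[0] + 2 * k3[0] + k4[0])
    v_next = v + dt / 6.0 * (k1[1] + 2 * k2[1] + 2 * k3[1] + k4[1])
    return x_next, v_next


# Example: roll out 50 uncontrolled steps from (1.0, 0.0) with a hypothetical dt = 0.1.
state: State = (1.0, 0.0)
for _ in range(50):
    state = vdp_rk4_step(state, u=0.0, dt=0.1)
print(state)
```

With `u = 0`, trajectories starting away from the origin settle onto the oscillator's limit cycle. The controller's job is to choose `u` so that this does not happen.
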
## 1. Setup & Configuration

In [None]:
import sys
import os
from pathlib import Path

# Add parent directory to path
sys.path.append('..')

from config import ALL_CONFIG, AVAILABLE_SYSTEMS
from core.model_manager import UniversalModelManager
from core.data_pipeline import UniversalDataGenerator
from training.sft_training import train_sft_model, setup_universal_chat_template, save_sft_model
from training.grpo_training import train_grpo_model, save_grpo_model
from evaluation.inference import run_batch_inference
from evaluation.metrics import compute_batch_metrics
from evaluation.visualization import plot_comparison, plot_metrics_comparison
from environments import get_system
from data_utils import load_train_eval_datasets, list_available_datasets
from gpu_utils import auto_gpu_config
import matplotlib.pyplot as plt
import numpy as np

print("✅ All modules loaded successfully!")
print(f"Available systems: {AVAILABLE_SYSTEMS}")

In [None]:
# Configuration
SYSTEM_NAME = "van_der_pol"
DATASET_NAME = "vdp"  # File prefix: ../datasets/vdp_train.pkl and vdp_eval.pkl
LORA_RANK = 8
MAX_SEQ_LENGTH = 1024

print(f"🎯 Training system: {SYSTEM_NAME}")
print(f"📊 Dataset: {DATASET_NAME}")
print(f"🔧 LoRA rank: {LORA_RANK}")
print(f"📏 Max sequence length: {MAX_SEQ_LENGTH}")

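`LORA_RANK` sets the rank $r$ of the low-rank adapters. LoRA expresses the update to a frozen weight $W \in \mathbb{R}^{d_{out} \times d_{in}}$ as $BA$, with $B \in \mathbb{R}^{d_{out} \times r}$ and $A \in \mathbb{R}^{r \times d_{in}}$. Each adapted matrix therefore adds $r(d_{in} + d_{out})$ trainable parameters. The helper below is a back-of-the-envelope sketch. The 4096×4096 projection size is a hypothetical example, not the dimension of the configured base model.

```python
def lora_trainable_params(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters LoRA adds to one d_out x d_in weight: A (rank x d_in) + B (d_out x rank)."""
    return rank * (d_in + d_out)


def full_params(d_in: int, d_out: int) -> int:
    """Parameters in the full weight matrix that LoRA leaves frozen."""
    return d_in * d_out


# Hypothetical 4096 x 4096 projection with the notebook's LORA_RANK = 8.
lora = lora_trainable_params(4096, 4096, rank=8)
full = full_params(4096, 4096)
print(lora, full, f"{lora / full:.4%}")
```

Doubling `LORA_RANK` doubles the adapter size. It remains a small fraction of the frozen weights either way.
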
## 2. Data Generation & Loading

In [None]:
# Check available datasets
print("📂 Available datasets:")
datasets = list_available_datasets("../datasets")
if datasets:
    for dataset in datasets:
        print(f"   • {dataset}")
else:
    print("   No datasets found")

In [None]:
# Generate VDP dataset (run this cell if you don't have the dataset)
GENERATE_NEW_DATA = False  # Set to True to generate new data

if GENERATE_NEW_DATA:
    print("🔄 Generating new Van der Pol dataset...")
    
    generator = UniversalDataGenerator(
        systems=[SYSTEM_NAME],
        dt=ALL_CONFIG["system"]["dt"],
        steps=ALL_CONFIG["system"]["steps"],
        reasoning_start=ALL_CONFIG["system"]["reasoning_start"],
        reasoning_end=ALL_CONFIG["system"]["reasoning_end"],
        solution_start=ALL_CONFIG["system"]["solution_start"],
        solution_end=ALL_CONFIG["system"]["solution_end"]
    )
    
    # Generate 2000 samples (1800 train + 200 eval)
    data = generator.generate_single_system_dataset(SYSTEM_NAME, 2000)
    train_data, eval_data = generator.split_dataset(data, 0.9)
    
    # Save dataset
    import pickle
    os.makedirs("../datasets", exist_ok=True)
    
    with open(f"../datasets/{DATASET_NAME}_train.pkl", 'wb') as f:
        pickle.dump(train_data, f)
    with open(f"../datasets/{DATASET_NAME}_eval.pkl", 'wb') as f:
        pickle.dump(eval_data, f)
    
    print(f"✅ Generated and saved dataset: {DATASET_NAME}")
    print(f"   📈 Train samples: {len(train_data)}")
    print(f"   📊 Eval samples: {len(eval_data)}")
else:
    print("⏭️  Skipping data generation (set GENERATE_NEW_DATA=True to generate)")

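The generation cell draws 2,000 samples and splits them 90/10 with `generator.split_dataset(data, 0.9)`, which gives 1,800 training and 200 evaluation samples. A minimal sketch of such a split is below. It assumes `split_dataset` shuffles before cutting, which the notebook does not confirm. `shuffle_split` and its `seed` parameter are hypothetical names.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def shuffle_split(data: Sequence[T], train_fraction: float, seed: int = 0) -> Tuple[List[T], List[T]]:
    """Shuffle a copy of `data` with a fixed seed, then cut it into train/eval lists."""
    if not 0.0 < train_fraction < 1.0:
        raise ValueError("train_fraction must be strictly between 0 and 1")
    items = list(data)
    random.Random(seed).shuffle(items)  # seeded RNG keeps the split reproducible
    cut = int(len(items) * train_fraction)
    return items[:cut], items[cut:]


train, eval_ = shuffle_split(range(2000), 0.9)
print(len(train), len(eval_))  # 1800 200
```

A fixed seed makes the split reproducible across reruns. This matters when comparing SFT and GRPO runs trained on regenerated data.
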
In [None]:
# Load existing dataset
try:
    train_data, eval_data, dataset_info = load_train_eval_datasets(
        DATASET_NAME, "../datasets", SYSTEM_NAME
    )
    print(f"✅ Loaded dataset: {DATASET_NAME}")
    print(f"   📈 Train samples: {len(train_data)}")
    print(f"   📊 Eval samples: {len(eval_data)}")
    print(f"   ℹ️  Dataset info: {dataset_info.get('config', {})}")
except Exception as e:
    print(f"❌ Failed to load dataset: {e}")
    print("💡 Set GENERATE_NEW_DATA=True in the cell above to generate the dataset")

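The generation cell writes `../datasets/{DATASET_NAME}_train.pkl` and `../datasets/{DATASET_NAME}_eval.pkl`. The sketch below shows a pickle round trip that uses those file names. It is a hypothetical stand-in, not `load_train_eval_datasets`, which also returns a `dataset_info` dict whose construction the notebook does not show.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Tuple


def save_split(train: Any, eval_: Any, name: str, root: Path) -> None:
    """Write train/eval splits as <root>/<name>_train.pkl and <root>/<name>_eval.pkl."""
    root.mkdir(parents=True, exist_ok=True)
    for suffix, obj in (("train", train), ("eval", eval_)):
        with open(root / f"{name}_{suffix}.pkl", "wb") as f:
            pickle.dump(obj, f)


def load_split(name: str, root: Path) -> Tuple[Any, Any]:
    """Read back the two pickles written by save_split (hypothetical helper)."""
    with open(root / f"{name}_train.pkl", "rb") as f:
        train = pickle.load(f)
    with open(root / f"{name}_eval.pkl", "rb") as f:
        eval_ = pickle.load(f)
    return train, eval_


# Round trip in a temporary directory so the real ../datasets folder is untouched.
with tempfile.TemporaryDirectory() as tmp:
    save_split([{"x0": 1.0}], [{"x0": -1.0}], "vdp", Path(tmp))
    loaded_train, loaded_eval = load_split("vdp", Path(tmp))
print(loaded_train, loaded_eval)
```

Only unpickle files you generated yourself. Loading a pickle can execute arbitrary code.
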
## 3. Model Setup

In [None]:
# Setup GPU and model manager
print("🎯 Setting up GPU and model...")

# Auto-select best GPU
gpu_config = auto_gpu_config()
print(f"🖥️  Selected GPU: {gpu_config['gpu_id']}")

# Create model manager
manager = UniversalModelManager(ALL_CONFIG["model"]["base_model_name"])

# Setup model
model, tokenizer = manager.setup_model(
    max_seq_length=MAX_SEQ_LENGTH,
    lora_rank=LORA_RANK,
    gpu_id=gpu_config['gpu_id'],
    auto_select_gpu=False
)

print("✅ Model setup complete!")

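The cell above relies on `auto_gpu_config()` to pick a device, and its selection rule is not shown here. One common heuristic is to choose the GPU with the most free memory. The sketch below applies that rule to a list of `(gpu_id, free_mib)` pairs. Both the heuristic and the input format are assumptions. Querying real memory figures (for example with `nvidia-smi`) is left out so the example runs without a GPU.

```python
from typing import Optional, Sequence, Tuple


def pick_gpu_by_free_memory(gpus: Sequence[Tuple[int, int]], min_free_mib: int = 0) -> Optional[int]:
    """Return the id of the GPU with the most free memory, or None if none has min_free_mib free.

    `gpus` holds (gpu_id, free_mib) pairs; ties go to the lowest id.
    """
    # Negating the id makes max() prefer the lower id when free memory ties.
    eligible = [(free, -gpu_id) for gpu_id, free in gpus if free >= min_free_mib]
    if not eligible:
        return None
    _, neg_id = max(eligible)
    return -neg_id


# Hypothetical snapshot of three GPUs.
print(pick_gpu_by_free_memory([(0, 8000), (1, 22000), (2, 22000)]))
```
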
In [None]:
# Setup chat template for single system
setup_universal_chat_template(
    manager, [SYSTEM_NAME],
    ALL_CONFIG["system"]["reasoning_start"],
    ALL_CONFIG["system"]["reasoning_end"],
    ALL_CONFIG["system"]["solution_start"],
    ALL_CONFIG["system"]["solution_end"]
)

print("✅ Chat template configured for Van der Pol Oscillator")

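The chat template is parameterized by four delimiters from `ALL_CONFIG["system"]`: `reasoning_start`, `reasoning_end`, `solution_start`, and `solution_end`. The sketch below shows how to build a completion wrapped in such delimiters and how to extract the solution span from it. The literal tag strings and the example solution text are placeholders, since this notebook does not show the configured values or the expected solution format.

```python
import re
from typing import Optional

# Placeholder delimiters; the real values come from ALL_CONFIG["system"].
REASONING_START, REASONING_END = "<REASONING>", "</REASONING>"
SOLUTION_START, SOLUTION_END = "<SOLUTION>", "</SOLUTION>"


def format_completion(reasoning: str, solution: str) -> str:
    """Wrap reasoning and solution text in the configured delimiters."""
    return f"{REASONING_START}{reasoning}{REASONING_END}{SOLUTION_START}{solution}{SOLUTION_END}"


def extract_solution(text: str) -> Optional[str]:
    """Return the text between the solution delimiters, or None if they are missing."""
    pattern = re.escape(SOLUTION_START) + r"(.*?)" + re.escape(SOLUTION_END)
    match = re.search(pattern, text, flags=re.DOTALL)
    return match.group(1).strip() if match else None


completion = format_completion("Damp the oscillation.", "0.5, -0.2, 0.0")
print(extract_solution(completion))  # 0.5, -0.2, 0.0
```

`re.escape` keeps the parser correct even if the configured tags contain regex metacharacters.
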
## 4. SFT Training (Optional)

Run this section to train only the SFT model.

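Both training cells start from `ALL_CONFIG["sft"].copy()` (or `["grpo"]`) and override `output_dir`. `dict.copy()` is shallow. Overwriting a top-level key leaves `ALL_CONFIG` untouched, but mutating a nested dict or list in the copy also changes the shared original. The sketch below demonstrates the difference with a hypothetical config whose `optim` key is invented for illustration.

```python
import copy

# Hypothetical config shaped like ALL_CONFIG["sft"]; the real keys may differ.
base = {"output_dir": "default", "optim": {"lr": 2e-4}}

shallow = base.copy()
shallow["output_dir"] = "../temp_training/van_der_pol/sft"  # top-level override: base unaffected
shallow["optim"]["lr"] = 1e-5  # nested mutation: also changes base["optim"]

deep = copy.deepcopy(base)
deep["optim"]["lr"] = 3e-4  # independent of base

print(base["output_dir"], base["optim"]["lr"], deep["optim"]["lr"])
```

The notebook only overrides `output_dir`, so `.copy()` is sufficient there. Use `copy.deepcopy` if you also edit nested settings.
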
In [None]:
# SFT Training
RUN_SFT_ONLY = False  # Set to True to run SFT training only

if RUN_SFT_ONLY:
    print("🚀 Starting SFT Training...")
    
    # Update SFT config
    sft_config = ALL_CONFIG["sft"].copy()
    sft_config["output_dir"] = f"../temp_training/{SYSTEM_NAME}/sft"
    
    # Train SFT
    sft_result = train_sft_model(
        manager, train_data, eval_data, sft_config
    )
    
    # Save SFT model
    sft_save_path = save_sft_model(
        manager, [SYSTEM_NAME], sft_result["metrics"]
    )
    
    print(f"✅ SFT model saved to: {sft_save_path}")
else:
    print("⏭️  Skipping SFT-only training (set RUN_SFT_ONLY=True to run)")

## 5. SFT + GRPO Training (Combined)

Run this section to train both SFT and GRPO models in sequence.

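GRPO (Group Relative Policy Optimization) samples a group of completions for each prompt and scores them with reward functions. It uses each completion's reward, standardized within its group, as the advantage, so no learned value model is needed. The sketch below computes those group-relative advantages for one prompt. The reward values are made up. The `eps` term and the use of the population standard deviation are choices made here, since implementations differ. The reward functions used by `train_grpo_model` are not shown in this notebook.

```python
from statistics import fmean, pstdev
from typing import List, Sequence


def group_relative_advantages(rewards: Sequence[float], eps: float = 1e-4) -> List[float]:
    """Standardize rewards within one prompt's group: (r_i - mean) / (std + eps).

    eps guards against division by zero when every completion gets the same reward.
    """
    mean = fmean(rewards)
    std = pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]


# Four sampled completions for one initial state, with hypothetical rewards.
advantages = group_relative_advantages([1.0, 0.5, 0.0, 0.5])
print([round(a, 3) for a in advantages])
```

Completions that beat their group's average get a positive advantage and are reinforced. If every completion in a group scores the same, all advantages are zero and that group contributes no learning signal.
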
In [None]:
# Combined SFT + GRPO Training
RUN_COMBINED_TRAINING = True  # Set to True to run full training pipeline

if RUN_COMBINED_TRAINING:
    print("🚀 Starting Combined SFT + GRPO Training...")
    
    # === SFT Phase ===
    print("\n" + "="*50)
    print("📚 SFT TRAINING PHASE")
    print("="*50)
    
    sft_config = ALL_CONFIG["sft"].copy()
    sft_config["output_dir"] = f"../temp_training/{SYSTEM_NAME}/sft"
    
    sft_result = train_sft_model(
        manager, train_data, eval_data, sft_config
    )
    
    sft_save_path = save_sft_model(
        manager, [SYSTEM_NAME], sft_result["metrics"]
    )
    
    print(f"✅ SFT model saved to: {sft_save_path}")
    
    # === GRPO Phase ===
    print("\n" + "="*50)
    print("🎮 GRPO TRAINING PHASE")
    print("="*50)
    
    grpo_config = ALL_CONFIG["grpo"].copy()
    grpo_config["output_dir"] = f"../temp_training/{SYSTEM_NAME}/grpo"
    
    grpo_result = train_grpo_model(
        manager, train_data, eval_data, grpo_config,
        ALL_CONFIG["system"]["reasoning_start"],
        ALL_CONFIG["system"]["reasoning_end"],
        ALL_CONFIG["system"]["solution_start"],
        ALL_CONFIG["system"]["solution_end"]
    )
    
    grpo_save_path = save_grpo_model(
        manager, [SYSTEM_NAME], grpo_result["metrics"]
    )
    
    print(f"✅ GRPO model saved to: {grpo_save_path}")
    
    print("\n" + "="*50)
    print("🎉 TRAINING COMPLETED")
    print("="*50)
    print(f"📍 SFT model: {sft_save_path}")
    print(f"📍 GRPO model: {grpo_save_path}")
    
else:
    print("⏭️  Skipping combined training (set RUN_COMBINED_TRAINING=True to run)")

## 6. Model Evaluation

Evaluate the trained SFT and GRPO models on randomly sampled initial states.

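The evaluation cell reports `success_rate` and `mean_performance_score` from `compute_batch_metrics`, whose definitions are not shown here. The sketch below is a hypothetical illustration only. It counts a rollout as successful when its final state lies within a tolerance of the origin, and it averages a per-rollout score. The result fields (`final_state`, `performance_score`), the success criterion, and the tolerance are all assumptions.

```python
import math
from typing import Dict, Sequence


def summarize_results(results: Sequence[Dict], tol: float = 0.1) -> Dict[str, float]:
    """Hypothetical batch summary: fraction of rollouts ending near the origin, plus mean score."""
    if not results:
        return {"success_rate": 0.0, "mean_performance_score": float("nan")}
    # Success here means the final (x, v) lies within `tol` of the origin (an assumption).
    successes = sum(1 for r in results if math.hypot(*r["final_state"]) <= tol)
    mean_score = sum(r["performance_score"] for r in results) / len(results)
    return {"success_rate": successes / len(results), "mean_performance_score": mean_score}


demo = [
    {"final_state": (0.02, -0.01), "performance_score": 0.9},
    {"final_state": (0.50, 0.30), "performance_score": 0.4},
]
metrics = summarize_results(demo)
print(f"Success rate: {metrics['success_rate']:.2%}")
```

With only 10 test cases, each success moves the success rate by 10 percentage points. Treat small differences between SFT and GRPO with caution.
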
In [None]:
# Model Evaluation
RUN_EVALUATION = True  # Set to True to run evaluation
EVALUATE_SFT = True   # Set to True to evaluate SFT model
EVALUATE_GRPO = True  # Set to True to evaluate GRPO model

if RUN_EVALUATION:
    print("📊 Starting Model Evaluation...")
    
    # Load models for evaluation
    eval_manager = UniversalModelManager()
    
    # Generate shared test cases up front so both models are scored on the same
    # initial states, and GRPO evaluation still works when EVALUATE_SFT is False
    system = get_system(SYSTEM_NAME)()
    test_cases = []
    for _ in range(10):  # 10 test cases
        initial_state = system.generate_random_initial_state()
        test_cases.append(tuple(initial_state))
    
    # Sampling settings shared by both models
    from vllm import SamplingParams
    sampling_params = SamplingParams(
        temperature=0.7,
        top_k=50,
        max_tokens=1024
    )
    
    if EVALUATE_SFT:
        print("\n🔍 Evaluating SFT Model...")
        try:
            sft_model, sft_tokenizer, sft_lora, sft_metadata = eval_manager.load_single_system_model(
                SYSTEM_NAME, model_type="sft"
            )
            
            # Run inference
            sft_results = run_batch_inference(
                sft_model, sft_tokenizer, SYSTEM_NAME, test_cases,
                lora_request=sft_lora,
                sampling_params=sampling_params
            )
            
            # Compute metrics
            sft_metrics = compute_batch_metrics(sft_results)
            
            print("✅ SFT Evaluation Results:")
            print(f"   Success rate: {sft_metrics['success_rate']:.2%}")
            print(f"   Mean performance: {sft_metrics['mean_performance_score']:.4f}")
            
        except Exception as e:
            print(f"❌ SFT evaluation failed: {e}")
    
    if EVALUATE_GRPO:
        print("\n🔍 Evaluating GRPO Model...")
        try:
            grpo_model, grpo_tokenizer, grpo_lora, grpo_metadata = eval_manager.load_single_system_model(
                SYSTEM_NAME, model_type="grpo"
            )
            
            # Run inference on the same test cases used for SFT
            grpo_results = run_batch_inference(
                grpo_model, grpo_tokenizer, SYSTEM_NAME, test_cases,
                lora_request=grpo_lora,
                sampling_params=sampling_params
            )
            
            # Compute metrics
            grpo_metrics = compute_batch_metrics(grpo_results)
            
            print("✅ GRPO Evaluation Results:")
            print(f"   Success rate: {grpo_metrics['success_rate']:.2%}")
            print(f"   Mean performance: {grpo_metrics['mean_performance_score']:.4f}")
            
        except Exception as e:
            print(f"❌ GRPO evaluation failed: {e}")
            
else:
    print("⏭️  Skipping evaluation (set RUN_EVALUATION=True to run)")

## 7. Visualization & Analysis

In [None]:
# Plot results if evaluation was run
if RUN_EVALUATION and EVALUATE_GRPO and 'grpo_results' in locals():
    print("📈 Generating visualizations...")
    
    # Plot trajectory comparison
    fig1 = plot_comparison(grpo_results)
    if fig1:
        plt.figure(fig1.number)
        plt.suptitle("Van der Pol GRPO Model Results", fontsize=16)
        plt.tight_layout()
        plt.show()
    
    # Plot metrics comparison
    fig2 = plot_metrics_comparison(grpo_results)
    if fig2:
        plt.figure(fig2.number)
        plt.suptitle("Van der Pol Performance Metrics", fontsize=16)
        plt.tight_layout()
        plt.show()
        
    print("✅ Visualizations complete!")
else:
    print("⏭️  No results to visualize (run evaluation first)")

## 8. Summary

This notebook provides a complete workflow for training and evaluating Van der Pol oscillator control models:

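- **Data Generation**: Generate a Van der Pol dataset (1800 train + 200 eval samples)
- **SFT Training**: Supervised fine-tuning on the generated training set
- **GRPO Training**: Reinforcement learning with Group Relative Policy Optimization, run after SFT in the combined pipeline
- **Evaluation**: Score the SFT and GRPO models on 10 randomly sampled initial states (success rate and mean performance score)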
- **Visualization**: Plot trajectories and performance metrics

**Model Outputs:**
- SFT model: `models/single_system/van_der_pol/sft/latest/`
- GRPO model: `models/single_system/van_der_pol/grpo/latest/`

**Next Steps:**
- Use the Double Integrator (DI) training notebook to train on that system
- Use the universal training notebook for multi-system models
- Load trained models in other notebooks for further analysis