# Universal Control Model Training Notebook

This notebook trains a single universal model that can control multiple dynamical systems. It follows a progressive training approach:
1. Train on the Double Integrator (DI) first
2. Extend the model to the Van der Pol (VDP) oscillator by training on a mix of DI and VDP data
3. Evaluate the resulting universal model on both systems

**Sections:**
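1. Setup & Configuration
2. Data Loading & Verification
3. Model Setup
4. Progressive Training: Double Integrator First
5. Progressive Training: Extend to Van der Pol
6. Universal Model Evaluation
7. Comparative Analysis & Visualization
8. Model Loading for Further Use
9. Summary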

## 1. Setup & Configuration

In [None]:
import sys
import os
from pathlib import Path

# Add parent directory to path
sys.path.append('..')

from config import ALL_CONFIG, AVAILABLE_SYSTEMS
from core.model_manager import UniversalModelManager
from core.data_pipeline import UniversalDataGenerator
from training.sft_training import train_sft_model, setup_universal_chat_template, save_sft_model
from training.grpo_training import train_grpo_model, save_grpo_model
from evaluation.inference import run_batch_inference
from evaluation.metrics import compute_batch_metrics
from evaluation.visualization import plot_comparison, plot_metrics_comparison
from environments import get_system
from data_utils import load_train_eval_datasets, list_available_datasets
from gpu_utils import auto_gpu_config
import matplotlib.pyplot as plt
import numpy as np

print("✅ All modules loaded successfully!")
print(f"Available systems: {AVAILABLE_SYSTEMS}")

In [None]:
# Configuration
SYSTEMS = ["double_integrator", "van_der_pol"]
DATASET_NAMES = ["di", "vdp"]  # Simple clean names
LORA_RANK = 8
MAX_SEQ_LENGTH = 1024

print(f"🎯 Training systems: {', '.join(SYSTEMS)}")
print(f"📊 Datasets: {', '.join(DATASET_NAMES)}")
print(f"🔧 LoRA rank: {LORA_RANK}")
print(f"📏 Max sequence length: {MAX_SEQ_LENGTH}")

## 2. Data Loading & Verification

In [None]:
# Check available datasets
print("📂 Available datasets:")
datasets = list_available_datasets("../datasets") or []  # guard against a None return
if datasets:
    for dataset in datasets:
        print(f"   • {dataset}")
else:
    print("   No datasets found")
    
# Check if we have both required datasets
missing_datasets = []
for dataset_name in DATASET_NAMES:
    if dataset_name not in datasets:
        missing_datasets.append(dataset_name)

if missing_datasets:
    print(f"\n⚠️  Missing datasets: {', '.join(missing_datasets)}")
    print("💡 Run the individual training notebooks first to generate datasets")
else:
    print("\n✅ All required datasets are available!")

In [None]:
# Load all datasets
all_datasets = {}

for system_name, dataset_name in zip(SYSTEMS, DATASET_NAMES):
    try:
        train_data, eval_data, dataset_info = load_train_eval_datasets(
            dataset_name, "../datasets", system_name
        )
        all_datasets[system_name] = {
            'train': train_data,
            'eval': eval_data,
            'info': dataset_info
        }
        print(f"✅ {system_name}: {len(train_data)} train + {len(eval_data)} eval samples")
    except Exception as e:
        print(f"❌ Failed to load {system_name} dataset: {e}")

if len(all_datasets) == len(SYSTEMS):
    print("\n🎉 All datasets loaded successfully!")
else:
    print("\n❌ Some datasets failed to load. Please check the individual training notebooks.")

## 3. Model Setup

In [None]:
# Setup GPU and model manager
print("🎯 Setting up GPU and model...")

# Auto-select best GPU
gpu_config = auto_gpu_config()
print(f"🖥️  Selected GPU: {gpu_config['gpu_id']}")

# Create model manager
manager = UniversalModelManager(ALL_CONFIG["model"]["base_model_name"])

# Setup model
model, tokenizer = manager.setup_model(
    max_seq_length=MAX_SEQ_LENGTH,
    lora_rank=LORA_RANK,
    gpu_id=gpu_config['gpu_id'],
    auto_select_gpu=False
)

print("✅ Model setup complete!")
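`LORA_RANK` sets the size of the adapters that `setup_model` trains. This assumes standard LoRA; the notebook does not show which layers are adapted. Under standard LoRA, each frozen weight W of shape d_out × d_in gets a learned update ΔW = BA, where B has shape d_out × r and A has shape r × d_in. That adds r(d_in + d_out) trainable parameters per adapted matrix. Here is a quick sketch of the arithmetic. The 4096 × 4096 layer shape is hypothetical.

```python
# Trainable parameters LoRA adds to one frozen weight matrix (assumes standard
# LoRA): delta_W = B @ A with B: (d_out, r) and A: (r, d_in),
# so each adapted matrix contributes r * (d_in + d_out) parameters.

def lora_params(d_in: int, d_out: int, rank: int) -> int:
    return rank * (d_in + d_out)


def lora_fraction(d_in: int, d_out: int, rank: int) -> float:
    """Adapter size relative to the full d_out x d_in matrix."""
    return lora_params(d_in, d_out, rank) / (d_in * d_out)


LORA_RANK = 8
# Hypothetical 4096 x 4096 projection; real shapes depend on the base model.
print(lora_params(4096, 4096, LORA_RANK))  # 65536
print(f"{lora_fraction(4096, 4096, LORA_RANK):.4%} of the full matrix")
```

At rank 8 the adapter is well under 1% of such a layer. This is why the base model can stay frozen while each phase updates only the adapters.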

## 4. Progressive Training: Double Integrator First

Start by training the base model on the Double Integrator system: supervised fine-tuning (SFT) first, then GRPO reinforcement learning.
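GRPO (Group Relative Policy Optimization) samples several completions per prompt and scores each one with the reward functions. It then uses each completion's reward relative to its group as the advantage, so no separate value model is needed. Below is a minimal sketch of that normalization. It assumes population standard deviation and a small epsilon; the trainer behind `train_grpo_model` may handle these details differently. The rewards are hypothetical.

```python
# Sketch of GRPO's group-relative advantage: completions sampled for the same
# prompt are scored, and each reward is normalized against that group.
import statistics


def group_advantages(rewards: list[float], eps: float = 1e-6) -> list[float]:
    """(r_i - mean(r)) / (std(r) + eps) over one prompt's completions."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std; some trainers use sample std
    return [(r - mean) / (std + eps) for r in rewards]


# Hypothetical rewards for 4 completions of one control prompt.
print(group_advantages([1.0, 0.5, 0.0, 0.5]))
```

Completions that beat their group's average get a positive advantage and are reinforced. If every completion in a group gets the same reward, all advantages are zero and that prompt contributes no learning signal.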

In [None]:
# Phase 1: Train on Double Integrator
TRAIN_DI_PHASE = True  # Set to True to train DI phase

if TRAIN_DI_PHASE and 'double_integrator' in all_datasets:
    print("🚀 Phase 1: Training on Double Integrator...")
    
    # Setup chat template for DI only first
    setup_universal_chat_template(
        manager, ["double_integrator"],
        ALL_CONFIG["system"]["reasoning_start"],
        ALL_CONFIG["system"]["reasoning_end"],
        ALL_CONFIG["system"]["solution_start"],
        ALL_CONFIG["system"]["solution_end"]
    )
    
    # Get DI data
    di_train = all_datasets['double_integrator']['train']
    di_eval = all_datasets['double_integrator']['eval']
    
    # === DI SFT Phase ===
    print("\n" + "="*60)
    print("📚 DOUBLE INTEGRATOR SFT TRAINING")
    print("="*60)
    
    sft_config = ALL_CONFIG["sft"].copy()
    sft_config["output_dir"] = "../temp_training/universal/di_sft"
    
    di_sft_result = train_sft_model(
        manager, di_train, di_eval, sft_config
    )
    
    di_sft_path = save_sft_model(
        manager, ["double_integrator"], di_sft_result["metrics"], "di_base_sft"
    )
    
    print(f"✅ DI SFT model saved to: {di_sft_path}")
    
    # === DI GRPO Phase ===
    print("\n" + "="*60)
    print("🎮 DOUBLE INTEGRATOR GRPO TRAINING")
    print("="*60)
    
    grpo_config = ALL_CONFIG["grpo"].copy()
    grpo_config["output_dir"] = "../temp_training/universal/di_grpo"
    
    di_grpo_result = train_grpo_model(
        manager, di_train, di_eval, grpo_config,
        ALL_CONFIG["system"]["reasoning_start"],
        ALL_CONFIG["system"]["reasoning_end"],
        ALL_CONFIG["system"]["solution_start"],
        ALL_CONFIG["system"]["solution_end"]
    )
    
    di_grpo_path = save_grpo_model(
        manager, ["double_integrator"], di_grpo_result["metrics"], "di_base_grpo"
    )
    
    print(f"✅ DI GRPO model saved to: {di_grpo_path}")
    print("\n🎉 Phase 1 Complete: Double Integrator model ready!")
    
else:
    print("⏭️  Skipping DI phase (set TRAIN_DI_PHASE=True and ensure DI dataset is loaded)")
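Both training stages receive four markers from `ALL_CONFIG["system"]`: `reasoning_start`/`reasoning_end` and `solution_start`/`solution_end`. These delimit the model's reasoning and its final answer. Below is a sketch of how an evaluation or reward step might pull the answer out of a completion. The `<SOLUTION>` marker strings and the completion text are hypothetical placeholders, not the configured values.

```python
# Sketch: extract the text between solution markers in a model completion.
# The marker strings are hypothetical placeholders; the notebook reads the
# real ones from ALL_CONFIG["system"].
import re
from typing import Optional

SOLUTION_START = "<SOLUTION>"
SOLUTION_END = "</SOLUTION>"


def extract_solution(text: str, start: str = SOLUTION_START,
                     end: str = SOLUTION_END) -> Optional[str]:
    """Return the stripped text inside the last start/end pair, or None."""
    pattern = re.escape(start) + r"(.*?)" + re.escape(end)
    matches = re.findall(pattern, text, flags=re.DOTALL)
    return matches[-1].strip() if matches else None


completion = "<REASONING>brake early</REASONING><SOLUTION>[0.5, -0.2, 0.0]</SOLUTION>"
print(extract_solution(completion))  # [0.5, -0.2, 0.0]
```

Taking the last match means a marker echoed earlier, for example inside the reasoning, does not shadow the final answer.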

## 5. Progressive Training: Extend to Van der Pol

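Next, extend the DI-trained model to the Van der Pol oscillator. To help retain Double Integrator skill, this phase trains on a shuffled mix of both datasets rather than on VDP data alone.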
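The extension cell concatenates every system's training samples and shuffles them with the global `random` module, so the mix changes from run to run. Below is a sketch of a reproducible alternative that uses a private `random.Random(seed)`. `mix_datasets` and the toy samples are hypothetical.

```python
# Sketch: reproducible mixing of per-system datasets (hypothetical helper).
import random
from collections import Counter


def mix_datasets(datasets_by_system: dict, seed: int = 42) -> list:
    """Concatenate every system's samples and shuffle with a private RNG."""
    combined = [s for samples in datasets_by_system.values() for s in samples]
    random.Random(seed).shuffle(combined)  # leaves the global RNG untouched
    return combined


# Toy stand-ins for all_datasets[name]["train"].
toy = {
    "double_integrator": [{"system": "double_integrator", "id": i} for i in range(6)],
    "van_der_pol": [{"system": "van_der_pol", "id": i} for i in range(4)],
}
mixed = mix_datasets(toy)
print(Counter(s["system"] for s in mixed))
```

Plain concatenation weights each system by its dataset size. With the toy counts above, DI makes up 60% of the mix.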

In [None]:
# Phase 2: Extend to Van der Pol
TRAIN_VDP_EXTENSION = True  # Set to True to train VDP extension

if TRAIN_VDP_EXTENSION and len(all_datasets) == len(SYSTEMS):
    print("🚀 Phase 2: Extending to Van der Pol...")
    
    # Setup universal chat template for both systems
    setup_universal_chat_template(
        manager, SYSTEMS,
        ALL_CONFIG["system"]["reasoning_start"],
        ALL_CONFIG["system"]["reasoning_end"],
        ALL_CONFIG["system"]["solution_start"],
        ALL_CONFIG["system"]["solution_end"]
    )
    
    # Combine datasets for universal training
    print("🔄 Combining datasets...")
    
    # Mix training data from both systems
    combined_train = []
    combined_eval = []
    
    for system_name in SYSTEMS:
        combined_train.extend(all_datasets[system_name]['train'])
        combined_eval.extend(all_datasets[system_name]['eval'])
    
    # Shuffle the combined data
    import random
    random.shuffle(combined_train)
    random.shuffle(combined_eval)
    
    print(f"📊 Combined training data: {len(combined_train)} samples")
    print(f"📊 Combined eval data: {len(combined_eval)} samples")
    
    # === Universal SFT Phase ===
    print("\n" + "="*60)
    print("📚 UNIVERSAL SFT TRAINING (DI + VDP)")
    print("="*60)
    
    universal_sft_config = ALL_CONFIG["sft"].copy()
    universal_sft_config["output_dir"] = "../temp_training/universal/combined_sft"
    universal_sft_config["num_train_epochs"] = 2  # Fewer epochs, since this phase extends an already-trained model
    
    universal_sft_result = train_sft_model(
        manager, combined_train, combined_eval, universal_sft_config
    )
    
    universal_sft_path = save_sft_model(
        manager, SYSTEMS, universal_sft_result["metrics"], "universal_sft"
    )
    
    print(f"✅ Universal SFT model saved to: {universal_sft_path}")
    
    # === Universal GRPO Phase ===
    print("\n" + "="*60)
    print("🎮 UNIVERSAL GRPO TRAINING (DI + VDP)")
    print("="*60)
    
    universal_grpo_config = ALL_CONFIG["grpo"].copy()
    universal_grpo_config["output_dir"] = "../temp_training/universal/combined_grpo"
    universal_grpo_config["max_steps"] = 150  # More steps for universal training
    
    universal_grpo_result = train_grpo_model(
        manager, combined_train, combined_eval, universal_grpo_config,
        ALL_CONFIG["system"]["reasoning_start"],
        ALL_CONFIG["system"]["reasoning_end"],
        ALL_CONFIG["system"]["solution_start"],
        ALL_CONFIG["system"]["solution_end"]
    )
    
    universal_grpo_path = save_grpo_model(
        manager, SYSTEMS, universal_grpo_result["metrics"], "universal_grpo"
    )
    
    print(f"✅ Universal GRPO model saved to: {universal_grpo_path}")
    print("\n🎉 Phase 2 Complete: Universal model ready!")
    
else:
    print("⏭️  Skipping VDP extension (set TRAIN_VDP_EXTENSION=True and ensure both datasets are loaded)")
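Each phase derives its config with `ALL_CONFIG[...].copy()`, which is a shallow copy. Overwriting top-level keys such as `output_dir` or `max_steps`, as the cells above do, is safe. Mutating a nested value, however, would also change `ALL_CONFIG`. Below is a sketch of a deep-copying helper. `BASE_CONFIG` and its nested `optimizer` key are hypothetical stand-ins for whatever `ALL_CONFIG["sft"]` actually contains.

```python
# Sketch: derive a per-phase config without mutating the shared base config.
import copy

BASE_CONFIG = {  # hypothetical stand-in for ALL_CONFIG["sft"]
    "num_train_epochs": 3,
    "output_dir": "../temp_training/default",
    "optimizer": {"lr": 2e-4},
}


def phase_config(base: dict, **overrides) -> dict:
    """Deep-copy the base so nested dicts are not shared, then apply overrides."""
    cfg = copy.deepcopy(base)
    cfg.update(overrides)
    return cfg


shallow = BASE_CONFIG.copy()
deep = phase_config(BASE_CONFIG, num_train_epochs=2)

shallow["optimizer"]["lr"] = 1e-5      # the nested dict is shared with BASE_CONFIG
print(BASE_CONFIG["optimizer"]["lr"])  # 1e-05
print(deep["optimizer"]["lr"])         # 0.0002 (deep copy is unaffected)
```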

## 6. Universal Model Evaluation

Test the universal model on 10 random initial states per system to verify that it can control both, reporting the success rate and mean performance score for each.

In [None]:
# Universal Model Evaluation
RUN_UNIVERSAL_EVALUATION = True  # Set to True to run evaluation

if RUN_UNIVERSAL_EVALUATION:
    print("📊 Starting Universal Model Evaluation...")
    
    # Load universal model for evaluation (same base model as in training)
    eval_manager = UniversalModelManager(ALL_CONFIG["model"]["base_model_name"])
    
    try:
        universal_model, universal_tokenizer, universal_lora, universal_metadata = eval_manager.load_universal_model()
        
        print(f"✅ Loaded universal model trained on: {universal_metadata.get('trained_systems', SYSTEMS)}")
        
        # Evaluation results storage
        evaluation_results = {}
        
        # Sampling settings shared by every system (created once, outside the loop)
        from vllm import SamplingParams
        sampling_params = SamplingParams(
            temperature=0.7,
            top_k=50,
            max_tokens=1024
        )
        
        # Test on each system
        for system_name in SYSTEMS:
            print(f"\n🔍 Testing universal model on {system_name.upper()}...")
            
            # Generate test cases
            system = get_system(system_name)()
            test_cases = []
            for _ in range(10):  # 10 test cases per system
                initial_state = system.generate_random_initial_state()
                test_cases.append(tuple(initial_state))
            
            # Run inference
            results = run_batch_inference(
                universal_model, universal_tokenizer, system_name, test_cases,
                lora_request=universal_lora,
                sampling_params=sampling_params
            )
            
            # Compute metrics
            metrics = compute_batch_metrics(results)
            evaluation_results[system_name] = {
                'results': results,
                'metrics': metrics
            }
            
            print(f"   Success rate: {metrics['success_rate']:.2%}")
            print(f"   Mean performance: {metrics['mean_performance_score']:.4f}")
            
        print("\n✅ Universal model evaluation complete!")
        
    except Exception as e:
        print(f"❌ Universal model evaluation failed: {e}")
        print("💡 Make sure the universal training phase completed successfully")
        
else:
    print("⏭️  Skipping universal evaluation (set RUN_UNIVERSAL_EVALUATION=True to run)")
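Each system is scored on only 10 random initial states, so a single test case moves the success rate by 10 percentage points. Below is a sketch of a Wilson score interval for that rate. It assumes the test cases are independent success/failure trials.

```python
# Sketch: Wilson score interval for a success rate measured on n test cases.
import math


def wilson_interval(successes: int, n: int, z: float = 1.96) -> tuple[float, float]:
    """Approximate 95% confidence interval (z=1.96) for a binomial proportion."""
    if n == 0:
        return (0.0, 1.0)
    p = successes / n
    denom = 1 + z**2 / n
    center = (p + z**2 / (2 * n)) / denom
    half = z * math.sqrt(p * (1 - p) / n + z**2 / (4 * n**2)) / denom
    return (max(0.0, center - half), min(1.0, center + half))


lo, hi = wilson_interval(7, 10)  # e.g. 7 of 10 runs succeeded
print(f"70% success, 95% CI ~ [{lo:.0%}, {hi:.0%}]")
```

With 7 of 10 successes, the interval runs from roughly 40% to 89%. Small gaps between systems, or between specialist and universal models, are therefore not necessarily meaningful at this sample size.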

## 7. Comparative Analysis & Visualization

In [None]:
# Comparative Analysis
if RUN_UNIVERSAL_EVALUATION and 'evaluation_results' in locals():
    print("📈 Generating comparative analysis...")
    
    # Performance summary
    print("\n" + "="*70)
    print("🏆 UNIVERSAL MODEL PERFORMANCE SUMMARY")
    print("="*70)
    
    for system_name in SYSTEMS:
        if system_name in evaluation_results:
            metrics = evaluation_results[system_name]['metrics']
            print(f"\n📊 {system_name.upper().replace('_', ' ')}:")
            print(f"   ✓ Success Rate: {metrics['success_rate']:.1%}")
            print(f"   ⚡ Mean Performance: {metrics['mean_performance_score']:.4f}")
            if 'mean_final_error' in metrics:
                print(f"   🎯 Mean Final Error: {metrics['mean_final_error']:.6f}")
    
    # Visualization for each system
    for system_name in SYSTEMS:
        if system_name in evaluation_results:
            print(f"\n📈 Generating plots for {system_name}...")
            
            results = evaluation_results[system_name]['results']
            
            # Plot trajectory comparison
            fig1 = plot_comparison(results)
            if fig1:
                fig1.suptitle(f"Universal Model on {system_name.replace('_', ' ').title()}", fontsize=16)
                fig1.tight_layout()
                plt.show()
            
            # Plot metrics comparison
            fig2 = plot_metrics_comparison(results)
            if fig2:
                fig2.suptitle(f"{system_name.replace('_', ' ').title()} Performance Metrics", fontsize=16)
                fig2.tight_layout()
                plt.show()
    
    print("\n✅ Comparative analysis complete!")
    
else:
    print("⏭️  No evaluation results to analyze")
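The summary cell prints each system in its own block. Below is a sketch that lays the same metrics out as one aligned table. It assumes `evaluation_results` has the shape built in the evaluation cell, `{system: {"results": ..., "metrics": {...}}}`. The demo values are placeholders, not measured results.

```python
# Sketch: collect per-system metrics into one comparable text table.
def metrics_table(evaluation_results: dict) -> str:
    header = f"{'system':<20}{'success':>10}{'mean perf':>12}{'final err':>12}"
    rows = [header, "-" * len(header)]
    for name, entry in evaluation_results.items():
        m = entry["metrics"]
        err = m.get("mean_final_error")  # optional, as in the summary cell
        err_str = f"{err:.6f}" if err is not None else "n/a"
        rows.append(
            f"{name:<20}{m['success_rate']:>10.1%}"
            f"{m['mean_performance_score']:>12.4f}{err_str:>12}"
        )
    return "\n".join(rows)


# Placeholder input shaped like the notebook's evaluation_results.
demo = {
    "double_integrator": {
        "results": [],
        "metrics": {"success_rate": 0.9, "mean_performance_score": 0.8123,
                    "mean_final_error": 0.001234},
    },
    "van_der_pol": {
        "results": [],
        "metrics": {"success_rate": 0.6, "mean_performance_score": 0.5},
    },
}
print(metrics_table(demo))
```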

## 8. Model Loading for Further Use

In [None]:
# Load trained models for interactive use
print("🔧 Loading models for interactive use...")

# Load universal model
try:
    final_manager = UniversalModelManager(ALL_CONFIG["model"]["base_model_name"])
    universal_model, universal_tokenizer, universal_lora, universal_metadata = final_manager.load_universal_model()
    
    print("✅ Universal model loaded and ready for use!")
    print(f"   📍 Trained on systems: {universal_metadata.get('trained_systems', SYSTEMS)}")
    print(f"   📅 Training date: {universal_metadata.get('timestamp', 'Unknown')}")
    
    # Example usage
    print("\n💡 Example usage:")
    print("```python")
    print("# Generate a test case for double integrator")
    print("di_system = get_system('double_integrator')()")
    print("initial_state = di_system.generate_random_initial_state()")
    print("")
    print("# Run inference")
    print("results = run_batch_inference(")
    print("    universal_model, universal_tokenizer, 'double_integrator', [initial_state],")
    print("    lora_request=universal_lora")
    print(")")
    print("```")
    
except Exception as e:
    print(f"❌ Failed to load universal model: {e}")
    print("💡 Make sure the training completed successfully")

## 9. Summary

This notebook implements the complete universal control model training pipeline:

### 🎯 **Training Phases:**
1. **Phase 1**: Train on Double Integrator (base knowledge)
2. **Phase 2**: Extend to Van der Pol (universal capability)

### 📊 **Key Features:**
- **Progressive Learning**: Build from simple to complex systems
- **Knowledge Retention**: Keep DI samples in the Phase 2 training mix so DI skill is preserved while learning VDP
- **Universal Control**: Single model controls multiple systems
- **Cross-System Evaluation**: Test performance on both systems

### 🏆 **Model Outputs:**
- **DI Specialist**: `models/single_system/double_integrator/grpo/latest/`
- **Universal Model**: `models/universal/grpo/latest/`

### 🔬 **Research Benefits:**
- Compare specialist vs universal performance
- Study knowledge transfer in control tasks
- Evaluate generalization across system types
- Test scalability to additional systems

### 🚀 **Next Steps:**
- Add more control systems (pendulum, cartpole, etc.)
- Experiment with different training orders
- Test few-shot adaptation to new systems
- Implement continual learning strategies

The universal model is now ready for control research! 🎛️🤖