# 🧠 Enhanced World Model - Testing Notebook

This notebook demonstrates the improvements made to the world model and lets you compare the old (legacy) and new (improved) systems.

**What's New:**
- ✅ A2C training with proper advantages (GAE)
- ✅ Improved MLP controllers with planning
- ✅ Fixed gradient flow in memory model
- ✅ Memory prediction loss
- ✅ Stable, effective learning

**Runtime:** Use GPU runtime for faster training!
- Runtime → Change runtime type → GPU

## 📦 Setup & Installation

In [None]:
# Check if GPU is available
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    device = torch.device('cuda')
else:
    print("Using CPU (training will be slower)")
    device = torch.device('cpu')

In [None]:
# Install dependencies
!pip install "gymnasium[classic-control]" tensorboard pygame swig

# For CarRacing environment (quotes keep the shell from globbing the brackets)
!pip install "gymnasium[box2d]"

# Optional: For visualization
!pip install matplotlib opencv-python

In [None]:
# If running in Colab, clone the repository
import os
if 'google.colab' in str(get_ipython()):
    if not os.path.exists('Enhanced-World-Model'):
        # Clone your repo (replace with your actual repo URL)
        !git clone https://github.com/YOUR_USERNAME/Enhanced-World-Model.git
    os.chdir('Enhanced-World-Model')
    print("Working directory:", os.getcwd())
else:
    # Running locally
    print("Running locally")
    # Make sure we're in the right directory
    if 'Enhanced-World-Model' not in os.getcwd():
        print("Please run this notebook from the Enhanced-World-Model directory")

In [None]:
# Add src to Python path
import sys
sys.path.insert(0, './src')

# Verify imports work
try:
    from WorldModel import WorldModel
    from train_a2c import train_a2c
    print("✅ Successfully imported world model components")
except Exception as e:
    print(f"❌ Import failed: {e}")
    print("Make sure you're in the correct directory")

## ⚙️ Configuration

In [None]:
import torch
import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt
from collections import deque

# Import world model components
from vision.VQ_VAE import VQ_VAE
from vision.Identity import Identity
from memory.TemporalTransformer import TemporalTransformer
from controller.DiscreteModelPredictiveController import DiscreteModelPredictiveController
from controller.ContinuousModelPredictiveController import ContinuousModelPredictiveController
from controller.ImprovedDiscreteController import ImprovedDiscreteController
from controller.ImprovedContinuousController import ImprovedContinuousController
from WorldModel import WorldModel
from train import train as train_legacy
from train_a2c import train_a2c
from reward_predictor.LinearPredictor import LinearPredictorModel

print("✅ All imports successful")

In [None]:
# Configuration
ENV_NAME = "CartPole-v1"  # Change to "CarRacing-v3" for visual environment
NUM_ENVS = 4  # Parallel environments
MAX_EPOCHS = 50  # Increase to 200+ for full training
LEARNING_RATE = 3e-4
N_STEPS = 128  # Steps per A2C update
PLANNING_HORIZON = 5

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_default_device(device)

print(f"Environment: {ENV_NAME}")
print(f"Device: {device}")
print(f"Parallel envs: {NUM_ENVS}")
print(f"Max epochs: {MAX_EPOCHS}")

## 🛠️ Helper Functions

In [None]:
def create_world_model(env_name, num_envs, use_improved_controller=True):
    """Create a world model for the given environment."""
    
    # Create environment to inspect spaces
    envs = gym.make_vec(env_name, num_envs=num_envs, render_mode='rgb_array')
    obs_space = envs.single_observation_space
    action_space = envs.single_action_space
    
    is_image_based = len(obs_space.shape) == 3
    
    # Configure vision model
    if is_image_based:
        obs_shape = obs_space.shape
        input_shape = (obs_shape[2], obs_shape[0], obs_shape[1])  # (C, H, W)
        vision_model = VQ_VAE
        vision_args = {"output_dim": input_shape[0], "embed_dim": 64}
    else:
        input_shape = obs_space.shape
        vision_model = Identity
        vision_args = {"embed_dim": obs_space.shape[0]}
    
    # Configure controller
    if isinstance(action_space, gym.spaces.Discrete):
        action_dim = action_space.n
        if use_improved_controller:
            controller_model = ImprovedDiscreteController
            controller_args = {
                "action_dim": action_dim,
                "use_planning": True,
                "planning_horizon": PLANNING_HORIZON
            }
        else:
            controller_model = DiscreteModelPredictiveController
            controller_args = {"action_dim": action_dim}
    else:
        action_dim = action_space.shape[0]
        if use_improved_controller:
            controller_model = ImprovedContinuousController
            controller_args = {
                "action_dim": action_dim,
                "use_planning": True,
                "planning_horizon": PLANNING_HORIZON
            }
        else:
            controller_model = ContinuousModelPredictiveController
            controller_args = {"action_dim": action_dim}
    
    # Configure memory
    memory_args = {
        "d_model": 128,
        "latent_dim": vision_args["embed_dim"],
        "action_dim": action_dim,
        "nhead": 8
    }
    
    # Create world model
    world_model = WorldModel(
        vision_model=vision_model,
        memory_model=TemporalTransformer,
        controller_model=controller_model,
        input_shape=input_shape,
        vision_args=vision_args,
        memory_args=memory_args,
        controller_args=controller_args,
    ).to(device)
    
    # Add reward predictor
    world_model.set_reward_predictor(LinearPredictorModel)
    
    controller_type = "Improved" if use_improved_controller else "Legacy"
    print(f"✅ Created {controller_type} World Model")
    print(f"   Vision: {vision_model.__name__}")
    print(f"   Controller: {controller_model.__name__}")
    print(f"   Total parameters: {sum(p.numel() for p in world_model.parameters()):,}")
    
    return world_model, envs


def plot_training_results(rewards_old, rewards_new):
    """Plot comparison of old vs new training."""
    plt.figure(figsize=(12, 5))
    
    plt.subplot(1, 2, 1)
    plt.plot(rewards_old, label='Old System', alpha=0.7)
    plt.plot(rewards_new, label='New System (A2C)', alpha=0.7)
    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.title('Training Progress Comparison')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    plt.subplot(1, 2, 2)
    window = 10
    if len(rewards_old) >= window:
        old_smooth = np.convolve(rewards_old, np.ones(window)/window, mode='valid')
        plt.plot(old_smooth, label='Old System (smoothed)', linewidth=2)
    if len(rewards_new) >= window:
        new_smooth = np.convolve(rewards_new, np.ones(window)/window, mode='valid')
        plt.plot(new_smooth, label='New System (smoothed)', linewidth=2)
    plt.xlabel('Episode')
    plt.ylabel('Average Reward')
    plt.title(f'Smoothed Progress (window={window})')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Print statistics
    print("\n📊 Training Statistics:")
    print(f"Old System - Mean: {np.mean(rewards_old):.2f}, Max: {np.max(rewards_old):.2f}")
    print(f"New System - Mean: {np.mean(rewards_new):.2f}, Max: {np.max(rewards_new):.2f}")
    print(f"Improvement: {((np.mean(rewards_new) - np.mean(rewards_old)) / max(abs(np.mean(rewards_old)), 1e-6) * 100):.1f}%")

print("✅ Helper functions defined")

## 🧪 Test 1: Quick Sanity Check

Let's verify that the improved controller can run a forward pass and that planning is enabled.

In [None]:
print("Creating world model...")
model, envs = create_world_model(ENV_NAME, num_envs=1, use_improved_controller=True)

# Test forward pass
print("\nTesting forward pass...")
state, _ = envs.reset()
is_image_based = len(envs.single_observation_space.shape) == 3

if is_image_based:
    state_tensor = torch.from_numpy(state.transpose(0, 3, 1, 2)).float().to(device) / 255.0
else:
    state_tensor = torch.from_numpy(state).float().to(device)

with torch.no_grad():
    output = model(state_tensor, action_space=envs.single_action_space,
                   is_image_based=is_image_based, return_losses=True)

print("✅ Forward pass successful!")
print(f"   Action shape: {output['action'].shape}")
print(f"   Value: {output['value'].item():.4f}")
print(f"   Log prob: {output['log_probs'].item():.4f}")
print(f"   Vision loss: {output['total_loss'].item():.4f}")

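# Report planning status; getattr avoids claiming planning is on when the
# attribute exists but is set to False
if getattr(model.controller, 'use_planning', False):
    print(f"\n✅ Planning enabled with horizon: {model.controller.planning_horizon}")
else:
    print("\n⚠️ Planning not enabled (legacy controller or use_planning=False)")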

envs.close()
print("\n✅ Sanity check passed!")

## 🏋️ Test 2: Train with A2C (New System)

Train using the improved A2C algorithm with proper advantages.
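The `gamma` and `gae_lambda` arguments passed to `train_a2c` below control how advantages are estimated. As a rough illustration, here is a minimal sketch of Generalized Advantage Estimation for a single environment's rollout. The function name `compute_gae` and its signature are hypothetical and are not taken from `src/train_a2c.py`, whose implementation may differ.

```python
def compute_gae(rewards, values, dones, last_value, gamma=0.99, gae_lambda=0.95):
    """Generalized Advantage Estimation over one rollout (illustrative sketch).

    rewards, values, dones: equal-length lists for steps 0..T-1.
    last_value: critic estimate for the state reached after the final step.
    Returns (advantages, returns), both lists of length T.
    """
    advantages = [0.0] * len(rewards)
    gae = 0.0
    next_value = last_value
    # Walk backwards so each step can reuse the accumulated advantage of the next.
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0  # stop bootstrapping at episode ends
        delta = rewards[t] + gamma * next_value * not_done - values[t]  # TD error
        gae = delta + gamma * gae_lambda * not_done * gae
        advantages[t] = gae
        next_value = values[t]
    # Value targets are the advantages added back onto the value estimates.
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns
```

With `gae_lambda=0` this reduces to the one-step TD error, and with `gae_lambda=1` it becomes the discounted Monte Carlo return minus the value baseline; values in between trade bias against variance.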

In [None]:
print("="*60)
print("🚀 TRAINING WITH A2C (NEW SYSTEM)")
print("="*60)

# Create model with improved controller
model_new, envs_new = create_world_model(ENV_NAME, num_envs=NUM_ENVS, use_improved_controller=True)

# Create directory for checkpoints
import os
os.makedirs('./checkpoints', exist_ok=True)

# Train with A2C
print(f"\nStarting A2C training for {MAX_EPOCHS} epochs...\n")

try:
    train_a2c(
        model=model_new,
        envs=envs_new,
        max_epochs=MAX_EPOCHS,
        n_steps=N_STEPS,
        device=device,
        learning_rate=LEARNING_RATE,
        gamma=0.99,
        gae_lambda=0.95,
        value_coef=0.5,
        entropy_coef=0.01,
        memory_coef=0.1,
        max_grad_norm=0.5,
        use_tensorboard=False,  # Disable for notebook
        save_path='./checkpoints/',
        save_prefix='a2c_notebook'
    )
except KeyboardInterrupt:
    print("\n⚠️ Training interrupted by user")
finally:
    envs_new.close()

print("\n✅ A2C Training complete!")
print("📁 Checkpoints saved to ./checkpoints/")

## 📊 Test 3: Evaluate Trained Model

Evaluate the trained model and visualize its performance.

In [None]:
print("="*60)
print("📊 EVALUATING TRAINED MODEL")
print("="*60)

# Load best model
model_eval, envs_eval = create_world_model(ENV_NAME, num_envs=1, use_improved_controller=True)

try:
    checkpoint_path = f'./checkpoints/a2c_notebook_{ENV_NAME}_best.pt'
    if os.path.exists(checkpoint_path):
        model_eval.load(checkpoint_path, 
                       obs_space=envs_eval.single_observation_space,
                       action_space=envs_eval.single_action_space)
        print(f"✅ Loaded checkpoint: {checkpoint_path}")
    else:
        print("⚠️ No checkpoint found, using current model")
except Exception as e:
    print(f"⚠️ Could not load checkpoint: {e}")

model_eval.eval()

# Run evaluation episodes
num_eval_episodes = 10
eval_rewards = []
eval_lengths = []

is_image_based = len(envs_eval.single_observation_space.shape) == 3

print(f"\nRunning {num_eval_episodes} evaluation episodes...")

for ep in range(num_eval_episodes):
    state, _ = envs_eval.reset()
    done = False
    total_reward = 0
    steps = 0
    
    while not done:
        # Prepare state
        if is_image_based:
            state_tensor = torch.from_numpy(state.transpose(0, 3, 1, 2)).float().to(device) / 255.0
        else:
            state_tensor = torch.from_numpy(state).float().to(device)
        
        # Get action via a full forward pass so the controller receives proper
        # latents (calling the controller directly with z_t=None, h_t=None
        # would not give it valid inputs)
        with torch.no_grad():
            action = model_eval(state_tensor,
                               action_space=envs_eval.single_action_space,
                               is_image_based=is_image_based)
        
        # Execute action (the conversion is the same for discrete and continuous spaces)
        action_np = action.cpu().numpy()
        
        state, reward, terminated, truncated, _ = envs_eval.step(action_np)
        done = terminated[0] or truncated[0]
        total_reward += reward[0]
        steps += 1
    
    eval_rewards.append(total_reward)
    eval_lengths.append(steps)
    print(f"Episode {ep+1}/{num_eval_episodes}: Reward = {total_reward:.2f}, Steps = {steps}")
    
    # Reset memory for next episode
    model_eval.reset_env_memory(0)

envs_eval.close()

# Print statistics
print("\n" + "="*60)
print("📈 EVALUATION RESULTS")
print("="*60)
print(f"Mean Reward: {np.mean(eval_rewards):.2f} ± {np.std(eval_rewards):.2f}")
print(f"Max Reward: {np.max(eval_rewards):.2f}")
print(f"Min Reward: {np.min(eval_rewards):.2f}")
print(f"Mean Episode Length: {np.mean(eval_lengths):.1f} ± {np.std(eval_lengths):.1f}")

# Plot
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.bar(range(len(eval_rewards)), eval_rewards, alpha=0.7)
plt.axhline(np.mean(eval_rewards), color='r', linestyle='--', label='Mean')
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('Evaluation Episode Rewards')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
plt.bar(range(len(eval_lengths)), eval_lengths, alpha=0.7, color='green')
plt.axhline(np.mean(eval_lengths), color='r', linestyle='--', label='Mean')
plt.xlabel('Episode')
plt.ylabel('Episode Length')
plt.title('Evaluation Episode Lengths')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 🎮 Test 4: Visualize Agent Behavior (Optional)

Watch the trained agent in action!

In [None]:
# This section renders the environment
# Note: May not work perfectly in all Colab environments

from IPython import display
import matplotlib.pyplot as plt

print("Rendering trained agent...")

# Create environment with rendering
env_render = gym.make(ENV_NAME, render_mode='rgb_array')
model_render, envs_unused = create_world_model(ENV_NAME, num_envs=1, use_improved_controller=True)
envs_unused.close()  # only the model is needed; rendering uses env_render above

try:
    checkpoint_path = f'./checkpoints/a2c_notebook_{ENV_NAME}_best.pt'
    if os.path.exists(checkpoint_path):
        model_render.load(checkpoint_path,
                         obs_space=env_render.observation_space,
                         action_space=env_render.action_space)
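except Exception as e:
    # Report the failure instead of silently rendering an untrained model
    print(f"⚠️ Could not load checkpoint, rendering untrained model: {e}")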

model_render.eval()
is_image_based = len(env_render.observation_space.shape) == 3

# Run one episode with rendering
state, _ = env_render.reset()
frames = []
done = False
total_reward = 0

while not done and len(frames) < 500:  # Max 500 frames
    # Render
    frame = env_render.render()
    frames.append(frame)
    
    # Get action
    if is_image_based:
        state_tensor = torch.from_numpy(state.transpose(2, 0, 1)).float().unsqueeze(0).to(device) / 255.0
    else:
        state_tensor = torch.from_numpy(state).float().unsqueeze(0).to(device)
    
    with torch.no_grad():
        action = model_render(state_tensor,
                             action_space=env_render.action_space,
                             is_image_based=is_image_based)
    
    # Drop the batch dimension (same for discrete and continuous spaces)
    action_np = action.cpu().numpy()[0]
    
    state, reward, terminated, truncated, _ = env_render.step(action_np)
    done = terminated or truncated
    total_reward += reward

env_render.close()

print(f"Collected {len(frames)} frames, Total reward: {total_reward:.2f}")

# Display some frames
if len(frames) > 0:
    fig, axes = plt.subplots(1, min(5, len(frames)), figsize=(15, 3))
    if min(5, len(frames)) == 1:
        axes = [axes]
    
    step_indices = np.linspace(0, len(frames)-1, min(5, len(frames)), dtype=int)
    for idx, ax in enumerate(axes):
        ax.imshow(frames[step_indices[idx]])
        ax.set_title(f"Step {step_indices[idx]}")
        ax.axis('off')
    
    plt.tight_layout()
    plt.show()
    
    print("\n💡 Tip: For full video, you can save frames to a video file")

## 📊 Test 5: Compare Old vs New System (Optional)

Direct comparison between legacy and improved training.
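The commented-out cell below notes that a full comparison needs the training loops to return a reward history. One possible way to collect it is sketched here, assuming the loop has access to the per-step `rewards`, `terminated`, and `truncated` arrays returned by a vectorized `envs.step(...)`. The class `EpisodeReturnTracker` is hypothetical and is not part of `train.py` or `train_a2c.py`.

```python
class EpisodeReturnTracker:
    """Accumulates per-environment rewards and records a return when an episode ends."""

    def __init__(self, num_envs):
        self.running = [0.0] * num_envs  # reward accumulated in each env's current episode
        self.history = []                # completed episode returns, in order of completion

    def update(self, rewards, terminated, truncated):
        """Call once per vectorized step with that step's outputs."""
        for i, r in enumerate(rewards):
            self.running[i] += float(r)
            if terminated[i] or truncated[i]:
                self.history.append(self.running[i])
                self.running[i] = 0.0  # start a fresh episode for this env


# Usage sketch with two environments and hand-written step outputs
tracker = EpisodeReturnTracker(num_envs=2)
tracker.update([1.0, 1.0], [False, False], [False, False])
tracker.update([1.0, 1.0], [True, False], [False, False])
print(tracker.history)
```

If each training loop kept one tracker and returned its `history`, the two lists could be passed directly to `plot_training_results(rewards_old, rewards_new)`.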

In [None]:
# This is optional and will take longer
# Uncomment to run comparison

# COMPARISON_EPOCHS = 20

# print("="*60)
# print("⚔️ OLD vs NEW SYSTEM COMPARISON")
# print("="*60)

# # Train old system
# print("\n1️⃣ Training OLD system (legacy)...")
# model_old, envs_old = create_world_model(ENV_NAME, num_envs=NUM_ENVS, use_improved_controller=False)
# # Would need to track rewards in train_legacy function
# # This is left as an exercise

# # Train new system
# print("\n2️⃣ Training NEW system (A2C)...")
# model_new, envs_new = create_world_model(ENV_NAME, num_envs=NUM_ENVS, use_improved_controller=True)
# # Would need to track rewards in train_a2c function
# # This is left as an exercise

# print("\n⚠️ Full comparison requires modifying training loops to return reward history")
# print("See IMPROVEMENTS.md for expected performance improvements")

## 🔍 Analysis: Inspect Model Components

In [None]:
print("="*60)
print("🔍 MODEL ARCHITECTURE ANALYSIS")
print("="*60)

# Create both models for comparison
model_old, envs_old = create_world_model(ENV_NAME, num_envs=1, use_improved_controller=False)
model_new, envs_new = create_world_model(ENV_NAME, num_envs=1, use_improved_controller=True)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print("\n📐 Parameter Counts:")
print(f"Vision (both): {count_parameters(model_new.vision):,}")
print(f"Memory (both): {count_parameters(model_new.memory):,}")
print(f"Controller (OLD): {count_parameters(model_old.controller):,}")
print(f"Controller (NEW): {count_parameters(model_new.controller):,}")
print(f"\nTotal (OLD): {count_parameters(model_old):,}")
print(f"Total (NEW): {count_parameters(model_new):,}")
print(f"Increase: {count_parameters(model_new) - count_parameters(model_old):,} parameters")

print("\n🏗️ Controller Architecture:")
print("\nOLD Controller:")
print(model_old.controller)
print("\nNEW Controller:")
print(model_new.controller)

envs_old.close()
envs_new.close()

## 📝 Summary & Next Steps

### What We Tested:
1. ✅ Model creation and forward pass
2. ✅ A2C training with improved controller
3. ✅ Evaluation and performance metrics
4. ✅ Visualization of trained agent
5. ✅ Architecture comparison

### Key Improvements:
- **Controller**: Single linear layer → Multi-layer MLP with planning
- **Training**: Simple REINFORCE → A2C with GAE
- **Gradient Flow**: Detached memory → Proper gradient flow
- **Losses**: Vision only → Vision + Memory + Policy + Value
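
To make the combined objective concrete, the sketch below shows one common way such terms are weighted, reusing the coefficient names and defaults passed to `train_a2c` in Test 2 (`value_coef`, `entropy_coef`, `memory_coef`). The function `a2c_total_loss` is hypothetical, and the exact combination used in `src/train_a2c.py` may differ.

```python
def a2c_total_loss(policy_loss, value_loss, entropy, memory_loss, vision_loss=0.0,
                   value_coef=0.5, entropy_coef=0.01, memory_coef=0.1):
    """Weighted sum of the loss terms (illustrative; works on floats or tensors).

    The entropy term is subtracted, so higher policy entropy lowers the loss
    and encourages exploration.
    """
    return (policy_loss
            + value_coef * value_loss
            - entropy_coef * entropy
            + memory_coef * memory_loss
            + vision_loss)
```

Because the function only uses arithmetic operators, the same expression applies unchanged to PyTorch scalar tensors, in which case gradients flow to every component that contributed a term.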

### Expected Performance (CartPole-v1):
- **Old System**: Mean reward 20-50
- **New System**: Mean reward 400-500 (solves environment)
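
As a rough check of what "solves" means here: CartPole-v1 is registered with a reward threshold of 475, and a common convention is to average over the last 100 episodes. The helper below is a hypothetical sketch of that check, with the window size treated as an assumption rather than something defined by this notebook.

```python
def is_solved(episode_rewards, threshold=475.0, window=100):
    """Return True if the mean of the last `window` episode rewards reaches `threshold`."""
    if len(episode_rewards) < window:
        return False  # not enough episodes to judge yet
    recent = episode_rewards[-window:]
    return sum(recent) / window >= threshold
```

With only 10 evaluation episodes, as in Test 3, a smaller `window` could be passed explicitly, at the cost of a noisier estimate.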

### Next Steps:
1. **Longer Training**: Increase `MAX_EPOCHS` to 200+ for full convergence
2. **Different Environments**: Try `CarRacing-v3` or other environments
3. **Hyperparameter Tuning**: Adjust learning rate, n_steps, planning horizon
4. **Pre-training**: Use vision/memory pre-training for complex environments

### Save Your Work:
```python
# Download checkpoints from Colab
from google.colab import files
files.download('./checkpoints/a2c_notebook_CartPole-v1_best.pt')
```

### Resources:
- See `IMPROVEMENTS.md` for detailed documentation
- Check TensorBoard logs for training curves (if enabled)
- Refer to `src/train_a2c.py` for training algorithm details

---

**Happy Training! 🚀**