# 🧠 Enhanced World Model - Testing Notebook

This notebook demonstrates the improvements made to the world model.

**What's New:**
- ✅ A2C training with advantages from Generalized Advantage Estimation (GAE)
- ✅ Improved MLP controllers with planning
- ✅ Fixed gradient flow in the memory model
- ✅ Memory prediction loss
- ✅ Debug output enabled

**Runtime:** Use a GPU runtime for faster training!
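The GAE item in the list above refers to Generalized Advantage Estimation: each advantage is a discounted sum of one-step TD errors, with `gae_lambda` trading bias for variance. The following is a minimal NumPy sketch of the standard backward recursion for a single environment's rollout. It is not the code in `train_a2c`, which may batch across `NUM_ENVS` or handle episode boundaries differently. `compute_gae` is a hypothetical name; its defaults match the `gamma=0.99` and `gae_lambda=0.95` passed to `train_a2c` in Test 2.

```python
import numpy as np

def compute_gae(rewards, values, dones, last_value, gamma=0.99, gae_lambda=0.95):
    """Generalized Advantage Estimation for one rollout (hypothetical helper).

    rewards, values, dones: 1-D sequences of length T.
    dones[t] = 1 if the episode ended at step t (no bootstrapping past it).
    last_value: value estimate of the state reached after the final step.
    Returns (advantages, returns), both of length T.
    """
    rewards = np.asarray(rewards, dtype=np.float64)
    values = np.asarray(values, dtype=np.float64)
    dones = np.asarray(dones, dtype=np.float64)
    advantages = np.zeros(len(rewards))
    gae = 0.0
    next_value = last_value
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - dones[t]
        # TD error: r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        # A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
        gae = delta + gamma * gae_lambda * not_done * gae
        advantages[t] = gae
        next_value = values[t]
    returns = advantages + values  # regression targets for the value head
    return advantages, returns

adv, ret = compute_gae([1, 1, 1], [0, 0, 0], [0, 0, 1], last_value=5.0,
                       gamma=1.0, gae_lambda=1.0)
print(adv)  # [3. 2. 1.]
```

With `gamma=1` and `gae_lambda=1` the advantages reduce to Monte Carlo returns (bootstrapped at the rollout's end) minus the value baseline; with `gae_lambda=0` they reduce to the one-step TD errors.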

## 📦 Setup & Installation

In [None]:
# Check if GPU is available
import torch
import sys

# Enable immediate output flushing
import functools
print = functools.partial(print, flush=True)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    device = torch.device('cuda')
else:
    print("Using CPU (training will be slower)")
    device = torch.device('cpu')

print(f"Python version: {sys.version}")

In [None]:
# Install dependencies (uncomment if needed)
# !pip install "gymnasium[classic-control]" tensorboard pygame swig
# !pip install "gymnasium[box2d]"
# !pip install matplotlib opencv-python

print("Dependencies ready")

In [None]:
# Add src to Python path
import os
import sys

# Get current directory
current_dir = os.getcwd()
print(f"Current directory: {current_dir}")

# Add src to path if not already there
src_path = os.path.join(current_dir, 'src')
if src_path not in sys.path:
    sys.path.insert(0, src_path)
    print(f"Added {src_path} to Python path")

# Verify imports work
try:
    from WorldModel import WorldModel
    from train_a2c import train_a2c
    print("✅ Successfully imported world model components")
except Exception as e:
    print(f"❌ Import failed: {e}")
    import traceback
    traceback.print_exc()

## ⚙️ Configuration & Imports

In [None]:
import torch
import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt
from collections import deque
import sys

# Enable immediate output flushing for debug prints
import functools
print = functools.partial(print, flush=True)

# Import world model components
from vision.VQ_VAE import VQ_VAE
from vision.Identity import Identity
from memory.TemporalTransformer import TemporalTransformer
from controller.DiscreteModelPredictiveController import DiscreteModelPredictiveController
from controller.ContinuousModelPredictiveController import ContinuousModelPredictiveController
from controller.ImprovedDiscreteController import ImprovedDiscreteController
from controller.ImprovedContinuousController import ImprovedContinuousController
from WorldModel import WorldModel
from train_a2c import train_a2c
from reward_predictor.LinearPredictor import LinearPredictorModel

print("✅ All imports successful")
print("Debug output is ENABLED - you will see [DEBUG] messages")

In [None]:
# Configuration
ENV_NAME = "CartPole-v1"  # Change to "CarRacing-v3" for visual environment
NUM_ENVS = 4  # Parallel environments
MAX_EPOCHS = 50  # Increase to 200+ for full training
LEARNING_RATE = 3e-4
N_STEPS = 128  # Steps per A2C update
PLANNING_HORIZON = 5

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_default_device(device)

print("=" * 60)
print("CONFIGURATION")
print("=" * 60)
print(f"Environment: {ENV_NAME}")
print(f"Device: {device}")
print(f"Parallel envs: {NUM_ENVS}")
print(f"Max epochs: {MAX_EPOCHS}")
print(f"Learning rate: {LEARNING_RATE}")
print(f"Planning horizon: {PLANNING_HORIZON}")
print("=" * 60)
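Two quantities follow from these settings and from the memory arguments set later in `create_world_model` (`d_model=128`, `nhead=8`). Assuming `N_STEPS` counts steps per environment, as in standard A2C rollouts, each update consumes `NUM_ENVS * N_STEPS` transitions. And if `TemporalTransformer` is built on PyTorch's `nn.MultiheadAttention` or `nn.TransformerEncoderLayer` (an assumption), `d_model` must be divisible by `nhead`. A hypothetical `check_config` helper, not part of the repository, makes both explicit:

```python
def check_config(num_envs, n_steps, d_model, nhead):
    """Hypothetical sanity check for the notebook's hyperparameters."""
    # Assumes n_steps is counted per environment.
    transitions_per_update = num_envs * n_steps
    # Multi-head attention splits d_model evenly across the heads.
    if d_model % nhead != 0:
        raise ValueError(f"d_model={d_model} is not divisible by nhead={nhead}")
    head_dim = d_model // nhead
    return transitions_per_update, head_dim

# Values copied from the configuration cell and create_world_model.
print(check_config(num_envs=4, n_steps=128, d_model=128, nhead=8))  # (512, 16)
```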

## 🛠️ Helper Functions

In [None]:
def create_world_model(env_name, num_envs, use_improved_controller=True):
    """Create a world model for the given environment."""
    print(f"\n{'='*60}")
    print(f"Creating World Model")
    print(f"{'='*60}")
    
    # Create environment to inspect spaces
    envs = gym.make_vec(env_name, num_envs=num_envs, render_mode='rgb_array')
    obs_space = envs.single_observation_space
    action_space = envs.single_action_space
    
    print(f"Environment: {env_name}")
    print(f"Observation space: {obs_space}")
    print(f"Action space: {action_space}")
    
    is_image_based = len(obs_space.shape) == 3
    
    # Configure vision model
    if is_image_based:
        obs_shape = obs_space.shape
        input_shape = (obs_shape[2], obs_shape[0], obs_shape[1])  # (C, H, W)
        vision_model = VQ_VAE
        vision_args = {"output_dim": input_shape[0], "embed_dim": 64}
        print(f"Vision: VQ_VAE (image-based)")
    else:
        input_shape = obs_space.shape
        vision_model = Identity
        vision_args = {"embed_dim": obs_space.shape[0]}
        print(f"Vision: Identity (state-based)")
    
    # Configure controller
    if isinstance(action_space, gym.spaces.Discrete):
        action_dim = action_space.n
        if use_improved_controller:
            controller_model = ImprovedDiscreteController
            controller_args = {
                "action_dim": action_dim,
                "use_planning": True,
                "planning_horizon": PLANNING_HORIZON
            }
            print(f"Controller: ImprovedDiscreteController (with planning)")
        else:
            controller_model = DiscreteModelPredictiveController
            controller_args = {"action_dim": action_dim}
            print(f"Controller: DiscreteModelPredictiveController (legacy)")
    else:
        action_dim = action_space.shape[0]
        if use_improved_controller:
            controller_model = ImprovedContinuousController
            controller_args = {
                "action_dim": action_dim,
                "use_planning": True,
                "planning_horizon": PLANNING_HORIZON
            }
            print(f"Controller: ImprovedContinuousController (with planning)")
        else:
            controller_model = ContinuousModelPredictiveController
            controller_args = {"action_dim": action_dim}
            print(f"Controller: ContinuousModelPredictiveController (legacy)")
    
    print(f"Action dimension: {action_dim}")
    
    # Configure memory
    memory_args = {
        "d_model": 128,
        "latent_dim": vision_args["embed_dim"],
        "action_dim": action_dim,
        "nhead": 8
    }
    print(f"Memory: TemporalTransformer (d_model={memory_args['d_model']})")
    
    # Create world model
    print("\nInitializing world model...")
    world_model = WorldModel(
        vision_model=vision_model,
        memory_model=TemporalTransformer,
        controller_model=controller_model,
        input_shape=input_shape,
        vision_args=vision_args,
        memory_args=memory_args,
        controller_args=controller_args,
    ).to(device)
    
    # Add reward predictor
    world_model.set_reward_predictor(LinearPredictorModel)
    print("Reward predictor: LinearPredictorModel")
    
    total_params = sum(p.numel() for p in world_model.parameters())
    trainable_params = sum(p.numel() for p in world_model.parameters() if p.requires_grad)
    
    print(f"\n{'='*60}")
    print("✅ World Model Created Successfully")
    print(f"{'='*60}")
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"{'='*60}\n")
    
    return world_model, envs


def plot_training_results(rewards, title="Training Progress"):
    """Plot training results."""
    plt.figure(figsize=(12, 5))
    
    plt.subplot(1, 2, 1)
    plt.plot(rewards, alpha=0.7)
    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.title(title)
    plt.grid(True, alpha=0.3)
    
    plt.subplot(1, 2, 2)
    window = 10
    if len(rewards) >= window:
        smooth = np.convolve(rewards, np.ones(window)/window, mode='valid')
        plt.plot(smooth, linewidth=2)
        plt.xlabel('Episode')
        plt.ylabel('Average Reward')
        plt.title(f'Smoothed Progress (window={window})')
        plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Print statistics
    print("\n📊 Training Statistics:")
    print(f"Mean Reward: {np.mean(rewards):.2f}")
    print(f"Max Reward: {np.max(rewards):.2f}")
    print(f"Min Reward: {np.min(rewards):.2f}")
    print(f"Std Reward: {np.std(rewards):.2f}")

print("✅ Helper functions defined")
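The right-hand panel of `plot_training_results` uses a trailing moving average. With `mode='valid'`, `np.convolve` keeps only positions where the window fully overlaps the data, so the smoothed curve has `len(rewards) - window + 1` points and starts `window - 1` episodes later than the raw curve. A small standalone sketch of the same computation, returning x positions that align each average with the last episode in its window (that alignment is a suggestion; the notebook plots the smoothed values from index 0):

```python
import numpy as np

def moving_average(values, window=10):
    """Trailing moving average, same as np.convolve(..., mode='valid') above."""
    values = np.asarray(values, dtype=np.float64)
    if len(values) < window:
        return np.arange(0), np.array([])  # not enough points to smooth
    smooth = np.convolve(values, np.ones(window) / window, mode='valid')
    # Episode index of the last element in each window, for plotting against raw rewards.
    x = np.arange(window - 1, len(values))
    return x, smooth

x, smooth = moving_average([1, 2, 3, 4], window=2)
print(x, smooth)  # [1 2 3] [1.5 2.5 3.5]
```

Passing `x` to `plt.plot(x, smooth)` would overlay the smoothed curve on the raw rewards without a horizontal shift.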

## 🧪 Test 1: Quick Sanity Check

Verify that the improved controller can run a forward pass and plan. **You should see DEBUG output here!**
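The observation handling in this test is repeated in the evaluation loop of Test 3. A minimal sketch of a shared helper, assuming batched observations from a Gymnasium vector environment: image observations arrive as `(N, H, W, C)` `uint8` arrays and become `(N, C, H, W)` floats in `[0, 1]`, while state vectors are cast to `float32`. The name `preprocess_obs` is hypothetical, and the sketch stays in NumPy so the conversion can be checked without a GPU.

```python
import numpy as np

def preprocess_obs(state, is_image_based):
    """Convert a batched observation to the layout the world model expects (hypothetical helper)."""
    if is_image_based:
        # (N, H, W, C) uint8 -> (N, C, H, W) float32 in [0, 1]
        return state.transpose(0, 3, 1, 2).astype(np.float32) / 255.0
    # State vectors: (N, obs_dim) -> float32
    return np.asarray(state, dtype=np.float32)

frames = np.zeros((4, 96, 96, 3), dtype=np.uint8)  # e.g. a batch of CarRacing frames
print(preprocess_obs(frames, is_image_based=True).shape)  # (4, 3, 96, 96)
```

In the notebook the result would then be moved to the device with `torch.from_numpy(preprocess_obs(state, is_image_based)).to(device)`.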

In [None]:
print("\n" + "="*60)
print("SANITY CHECK - Testing Forward Pass")
print("="*60 + "\n")

print("Creating world model...")
model, envs = create_world_model(ENV_NAME, num_envs=1, use_improved_controller=True)

# Test forward pass
print("\nTesting forward pass...")
state, _ = envs.reset()
is_image_based = len(envs.single_observation_space.shape) == 3

print(f"State shape: {state.shape}")
print(f"Is image based: {is_image_based}")

if is_image_based:
    state_tensor = torch.from_numpy(state.transpose(0, 3, 1, 2)).float().to(device) / 255.0
else:
    state_tensor = torch.from_numpy(state).float().to(device)

print(f"State tensor shape: {state_tensor.shape}")
print("\nPerforming forward pass (watch for DEBUG output)...\n")

with torch.no_grad():
    output = model(state_tensor, action_space=envs.single_action_space,
                   is_image_based=is_image_based, return_losses=True)

print("\n" + "="*60)
print("✅ Forward pass successful!")
print("="*60)
print(f"Action shape: {output['action'].shape}")
print(f"Action value: {output['action']}")
print(f"Value estimate: {output['value'].item():.4f}")
print(f"Log probability: {output['log_probs'].item():.4f}")
print(f"Reconstruction loss: {output['recon_loss'].mean().item():.4f}")
print(f"VQ loss: {output['vq_loss'].mean().item():.4f}")
print(f"Total loss: {output['total_loss'].item():.4f}")

# Test planning (if available)
if hasattr(model.controller, 'use_planning') and model.controller.use_planning:
    print(f"\n✅ Planning enabled with horizon: {model.controller.planning_horizon}")
else:
    print("\n⚠️ Legacy controller (no planning)")

envs.close()
print("\n✅ Sanity check passed!")

## 🏋️ Test 2: Train with A2C (New System)

Train with the improved A2C algorithm, using GAE to estimate advantages.

**Note:** Debug output will be visible during training!
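The coefficients passed to `train_a2c` below (`value_coef=0.5`, `entropy_coef=0.01`, `memory_coef=0.1`) weight the terms of the training objective. The sketch below shows one common way such terms combine in A2C. It treats the memory prediction loss as an auxiliary term added with weight `memory_coef`, which is an assumption; `train_a2c` may weight, normalize, or sign these terms differently (for example, by normalizing advantages). It is written in NumPy for readability. `max_grad_norm=0.5` acts on gradients rather than on the loss: in PyTorch this is typically `torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)` after `loss.backward()`.

```python
import numpy as np

def a2c_loss(log_probs, advantages, values, returns, entropies, memory_loss,
             value_coef=0.5, entropy_coef=0.01, memory_coef=0.1):
    """Illustrative A2C objective to be minimized (hypothetical helper).

    log_probs, advantages, values, returns, entropies: 1-D arrays over the rollout.
    memory_loss: scalar auxiliary loss from the memory model.
    """
    log_probs = np.asarray(log_probs, dtype=np.float64)
    advantages = np.asarray(advantages, dtype=np.float64)
    values = np.asarray(values, dtype=np.float64)
    returns = np.asarray(returns, dtype=np.float64)
    entropies = np.asarray(entropies, dtype=np.float64)
    # Policy-gradient term; advantages act as constants (detached in PyTorch).
    policy_loss = -np.mean(log_probs * advantages)
    # Value head regresses toward the GAE returns.
    value_loss = np.mean((returns - values) ** 2)
    # Entropy bonus is subtracted so that higher entropy lowers the loss.
    entropy_bonus = np.mean(entropies)
    total = (policy_loss
             + value_coef * value_loss
             - entropy_coef * entropy_bonus
             + memory_coef * memory_loss)
    return total, {"policy": policy_loss, "value": value_loss, "entropy": entropy_bonus}
```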

In [None]:
print("\n" + "="*60)
print("🚀 TRAINING WITH A2C (NEW SYSTEM)")
print("="*60 + "\n")

# Create model with improved controller
model_new, envs_new = create_world_model(ENV_NAME, num_envs=NUM_ENVS, use_improved_controller=True)

# Create directory for checkpoints
import os
os.makedirs('./checkpoints', exist_ok=True)
print("Checkpoints directory: ./checkpoints/\n")

# Train with A2C
print(f"Starting A2C training for {MAX_EPOCHS} epochs...")
print("Watch for DEBUG output during training!\n")

try:
    train_a2c(
        model=model_new,
        envs=envs_new,
        max_epochs=MAX_EPOCHS,
        n_steps=N_STEPS,
        device=device,
        learning_rate=LEARNING_RATE,
        gamma=0.99,
        gae_lambda=0.95,
        value_coef=0.5,
        entropy_coef=0.01,
        memory_coef=0.1,
        max_grad_norm=0.5,
        use_tensorboard=False,  # Disable for notebook
        save_path='./checkpoints/',
        save_prefix='a2c_notebook'
    )
except KeyboardInterrupt:
    print("\n⚠️ Training interrupted by user")
else:
    # Runs only if train_a2c returned without an interruption
    print("\n" + "="*60)
    print("✅ A2C Training complete!")
    print("="*60)
finally:
    envs_new.close()

print("📁 Checkpoints (if any were written) are in ./checkpoints/")

## 📊 Test 3: Evaluate Trained Model

Evaluate the trained model and visualize its performance.
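The evaluation cell drives a one-environment vector env and reads `terminated[0]` and `truncated[0]`. The same loop, stripped to its essentials, is sketched here against a single (non-vector) environment using the Gymnasium `reset()`/`step()` return convention. `CountdownEnv` and `run_episode` are hypothetical stand-ins so the logic can be checked without Gymnasium. An episode ends when either flag is true; checking only `terminated` would miss the time-limit truncation that ends CartPole-v1 episodes at 500 steps.

```python
class CountdownEnv:
    """Hypothetical toy env: reward 1 per step, truncated after max_steps."""
    def __init__(self, max_steps=5):
        self.max_steps = max_steps
        self.t = 0

    def reset(self, seed=None):
        self.t = 0
        return 0.0, {}  # (observation, info)

    def step(self, action):
        self.t += 1
        terminated = False                    # the task itself never fails
        truncated = self.t >= self.max_steps  # time limit, like CartPole's 500 steps
        return float(self.t), 1.0, terminated, truncated, {}

def run_episode(env, policy):
    """Roll out one episode; returns (total_reward, steps)."""
    obs, _ = env.reset()
    total_reward, steps, done = 0.0, 0, False
    while not done:
        obs, reward, terminated, truncated, _ = env.step(policy(obs))
        done = terminated or truncated
        total_reward += reward
        steps += 1
    return total_reward, steps

print(run_episode(CountdownEnv(max_steps=5), policy=lambda obs: 0))  # (5.0, 5)
```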

In [None]:
print("\n" + "="*60)
print("📊 EVALUATING TRAINED MODEL")
print("="*60 + "\n")

# Load best model
model_eval, envs_eval = create_world_model(ENV_NAME, num_envs=1, use_improved_controller=True)

try:
    checkpoint_path = f'./checkpoints/a2c_notebook_{ENV_NAME}_best.pt'
    if os.path.exists(checkpoint_path):
        model_eval.load(checkpoint_path, 
                       obs_space=envs_eval.single_observation_space,
                       action_space=envs_eval.single_action_space)
        print(f"✅ Loaded checkpoint: {checkpoint_path}\n")
    else:
        print("⚠️ No checkpoint found, evaluating the untrained model\n")
except Exception as e:
    print(f"⚠️ Could not load checkpoint: {e}\n")

model_eval.eval()

# Run evaluation episodes
num_eval_episodes = 10
eval_rewards = []
eval_lengths = []

is_image_based = len(envs_eval.single_observation_space.shape) == 3

print(f"Running {num_eval_episodes} evaluation episodes...\n")

for ep in range(num_eval_episodes):
    state, _ = envs_eval.reset()
    done = False
    total_reward = 0
    steps = 0
    
    while not done:
        # Prepare state
        if is_image_based:
            state_tensor = torch.from_numpy(state.transpose(0, 3, 1, 2)).float().to(device) / 255.0
        else:
            state_tensor = torch.from_numpy(state).float().to(device)
        
        # Get an action. model.eval() switches dropout/batch-norm layers to
        # inference mode but does not by itself make action sampling deterministic.
        with torch.no_grad():
            action = model_eval(state_tensor,
                                action_space=envs_eval.single_action_space,
                                is_image_based=is_image_based)

        # Execute action (the same conversion works for discrete and continuous spaces)
        action_np = action.cpu().numpy()
        
        state, reward, terminated, truncated, _ = envs_eval.step(action_np)
        done = terminated[0] or truncated[0]
        total_reward += reward[0]
        steps += 1
    
    eval_rewards.append(total_reward)
    eval_lengths.append(steps)
    print(f"Episode {ep+1}/{num_eval_episodes}: Reward = {total_reward:.2f}, Steps = {steps}")
    
    # Reset memory for next episode
    model_eval.reset_env_memory(0)

envs_eval.close()

# Print statistics
print("\n" + "="*60)
print("📈 EVALUATION RESULTS")
print("="*60)
print(f"Mean Reward: {np.mean(eval_rewards):.2f} ± {np.std(eval_rewards):.2f}")
print(f"Max Reward: {np.max(eval_rewards):.2f}")
print(f"Min Reward: {np.min(eval_rewards):.2f}")
print(f"Mean Episode Length: {np.mean(eval_lengths):.1f} ± {np.std(eval_lengths):.1f}")
print("="*60)

# Plot
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.bar(range(len(eval_rewards)), eval_rewards, alpha=0.7)
plt.axhline(np.mean(eval_rewards), color='r', linestyle='--', label='Mean')
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('Evaluation Episode Rewards')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
plt.bar(range(len(eval_lengths)), eval_lengths, alpha=0.7, color='green')
plt.axhline(np.mean(eval_lengths), color='r', linestyle='--', label='Mean')
plt.xlabel('Episode')
plt.ylabel('Episode Length')
plt.title('Evaluation Episode Lengths')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 🔍 Test 4: Inspect Model Architecture

Detailed inspection of the model components and their dimensions.
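A quick way to sanity-check the printed counts: `nn.Linear(in_features, out_features)` with a bias has `in_features * out_features + out_features` parameters, and an affine `nn.LayerNorm(d)` adds another `2 * d`. The hypothetical helper below applies the linear-layer formula to a stack of layer widths. The widths in the example (`4 -> 64 -> 64 -> 2`) are made up for illustration; the real `ImprovedDiscreteController` layout should be read from the architecture printed by this cell.

```python
def mlp_param_count(widths, bias=True):
    """Parameters in a stack of nn.Linear layers with the given widths (hypothetical helper)."""
    total = 0
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        # weight matrix (fan_in x fan_out) plus an optional bias vector
        total += fan_in * fan_out + (fan_out if bias else 0)
    return total

print(mlp_param_count([4, 64, 64, 2]))  # 4610
```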

In [None]:
print("\n" + "="*60)
print("🔍 MODEL ARCHITECTURE INSPECTION")
print("="*60 + "\n")

# Create a fresh model for inspection
model_inspect, envs_inspect = create_world_model(ENV_NAME, num_envs=1, use_improved_controller=True)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print("\n📐 Parameter Counts by Component:")
print(f"Vision: {count_parameters(model_inspect.vision):,}")
print(f"Memory: {count_parameters(model_inspect.memory):,}")
print(f"Controller: {count_parameters(model_inspect.controller):,}")
if model_inspect.reward_predictor is not None:
    print(f"Reward Predictor: {count_parameters(model_inspect.reward_predictor):,}")
print(f"\nTotal: {count_parameters(model_inspect):,}")

print("\n🏗️ Controller Architecture:")
print(model_inspect.controller)

print("\n🧠 Memory Architecture:")
print(model_inspect.memory)

print("\n👁️ Vision Architecture:")
print(model_inspect.vision)

if model_inspect.reward_predictor is not None:
    print("\n🎁 Reward Predictor Architecture:")
    print(model_inspect.reward_predictor)

envs_inspect.close()
print("\n" + "="*60)

## 📝 Summary

### What We Tested:
1. ✅ Model creation and forward pass with DEBUG output
2. ✅ A2C training with improved controller
3. ✅ Evaluation and performance metrics
4. ✅ Architecture inspection

### Key Improvements:
- **Controller**: Single linear layer ‚Üí Multi-layer MLP with planning
- **Training**: Simple REINFORCE ‚Üí A2C with GAE
- **Gradient Flow**: Detached memory ‚Üí Proper gradient flow
- **Losses**: Vision only ‚Üí Vision + Memory + Policy + Value
- **Debug Output**: Now properly visible in notebook!

### Expected Performance (CartPole-v1):
- **Goal**: Mean reward of 475 or higher, CartPole-v1's registered reward threshold (episodes are capped at 500 steps, so 500 is the maximum)
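Whether a run "solves" CartPole-v1 is usually judged against the environment's registered `reward_threshold` (475), averaged over recent episodes, commonly the last 100. The 10 evaluation episodes in Test 3 are a much smaller sample, so a mean near the threshold there is suggestive rather than conclusive. A sketch of the check; `is_solved` is a hypothetical name:

```python
import numpy as np

def is_solved(episode_rewards, threshold=475.0, window=100):
    """True if the mean of the last `window` episode rewards reaches `threshold`."""
    if len(episode_rewards) < window:
        return False  # not enough episodes to judge
    return float(np.mean(episode_rewards[-window:])) >= threshold

print(is_solved([500.0] * 100))  # True
```

For the notebook's short evaluation run, `is_solved(eval_rewards, window=len(eval_rewards))` applies the same threshold to all 10 episodes.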

### Next Steps:
1. **Longer Training**: Increase `MAX_EPOCHS` to 200+ for full convergence
2. **Different Environments**: Try `CarRacing-v3` (requires `gymnasium[box2d]`) or other environments
3. **Hyperparameter Tuning**: Adjust `LEARNING_RATE`, `N_STEPS`, and `PLANNING_HORIZON`

---

**Happy Training! 🚀**