# PLDM for MiniGrid - Demo Notebook

This notebook demonstrates all the functionality of the PLDM repository.

Contents:
1. Setup and Imports
2. Using the Python API directly
3. Training via command line
4. Evaluation via command line
5. Visualization via command line
6. Working with pre-trained models
7. Full training run (optional)

## 1. Setup and Imports

In [None]:
# Make the repository importable from this notebook.
import os
import sys

# The repo root is taken to be the notebook's working directory: the cells
# below import `models`/`utils` and run `train.py` relative to it.
repo_path = os.path.abspath('.')
# Check and insert the *same* absolute path, so re-running this cell never
# adds a duplicate entry and a later os.chdir() cannot redirect the imports.
if repo_path not in sys.path:
    sys.path.insert(0, repo_path)

print(f"Working directory: {os.getcwd()}")


In [None]:
# Core imports
import torch
import numpy as np
import matplotlib.pyplot as plt

# Import from our repo
from models import PLDM, FlexibleEncoder, Predictor
from utils import (
    make_env, 
    get_full_obs, 
    bfs_solve, 
    collect_trajectory,
    collect_dataset,
    TrajectoryDataset,
    vicreg_loss,
    CEMPlanner,
    make_custom_doorkey,
    bfs_solve_custom_env,
    generate_custom_configs
)

print("All imports successful!")


In [None]:
# Pick the compute backend in priority order: MPS first, then CUDA, else CPU.
if torch.backends.mps.is_available():
    backend_name = 'mps'
elif torch.cuda.is_available():
    backend_name = 'cuda'
else:
    backend_name = 'cpu'

device = torch.device(backend_name)
print(f"Using device: {device}")


## 2. Using the Python API Directly

This section shows how to use the modules directly in Python code.

### 2.1 Environment Utilities

In [None]:
# Build the DoorKey environment and reset it with a fixed seed
env = make_env("MiniGrid-DoorKey-5x5-v0")
env.reset(seed=42)

# Fetch the full observation and report its shape
obs = get_full_obs(env)
print(f"Observation shape: {obs.shape}")

# Draw the observation on an explicit figure/axes pair
fig, ax = plt.subplots(figsize=(4, 4))
ax.imshow(obs)
ax.set_title("MiniGrid-DoorKey-5x5-v0")
ax.axis('off')
plt.show()

env.close()


In [None]:
# Test BFS solver
# Solves the seed-1 DoorKey layout up front; the resulting `actions` sequence is
# replayed by the trajectory-collection cell that follows (same env, same seed).
actions = bfs_solve("MiniGrid-DoorKey-5x5-v0", seed=1)
print(f"BFS found solution with {len(actions)} actions: {actions}")


### 2.2 Data Collection

In [None]:
# Collect a single trajectory
# Replays the BFS `actions` found above in the same seed-1 layout they were
# solved for. Each step of the result exposes `.keys()`; later cells read the
# 'obs' entry of a step.
traj = collect_trajectory("MiniGrid-DoorKey-5x5-v0", seed=1, actions=actions)
print(f"Trajectory length: {len(traj)} steps")
print(f"Keys in each step: {traj[0].keys()}")


In [None]:
# Collect a small dataset (for demo - use more trajectories for real training)
# NOTE(review): no seed is passed here, so the collected trajectories may differ
# between runs — confirm whether collect_dataset seeds itself internally.
# `bfs_ratio` presumably sets the share of trajectories driven by the BFS solver
# — TODO confirm against utils.collect_dataset.
trajectories = collect_dataset(
    env_name="MiniGrid-DoorKey-5x5-v0",
    num_trajectories=50,  # Small for demo
    bfs_ratio=0.8
)


In [None]:
# Create PyTorch dataset
# Wraps the collected trajectories with sequence_length=8 (presumably the number
# of consecutive steps per training sample — TODO confirm in TrajectoryDataset).
dataset = TrajectoryDataset(trajectories, sequence_length=8)
print(f"Dataset size: {len(dataset)} sequences")

# Get a sample
# An item unpacks into (observations, actions, next observations); this same
# sample is reused by the forward-pass test further down.
obs_seq, actions_seq, next_obs_seq = dataset[0]
print(f"Observation sequence shape: {obs_seq.shape}")
print(f"Actions sequence shape: {actions_seq.shape}")
print(f"Next observation sequence shape: {next_obs_seq.shape}")


### 2.3 Model Architecture

In [None]:
# Create PLDM model
# action_dim=7 matches the seven entries of `action_names` used when executing
# plans later in this notebook.
# NOTE(review): the model-loading section calls init_encoder_fc() before
# load_state_dict(), which suggests some layers may be created lazily; if so the
# count printed here would not include them — confirm.
model = PLDM(latent_dim=128, action_dim=7).to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")


In [None]:
# Test forward pass
# Turn the single dataset sample into a batch of one (unsqueeze(0) adds the
# leading axis) and move it to the selected device.
obs_batch = obs_seq.unsqueeze(0).to(device)  # Add batch dimension
actions_batch = actions_seq.unsqueeze(0).to(device)
next_obs_batch = next_obs_seq.unsqueeze(0).to(device)

# The model returns three tensors; z_next and z_next_pred are the pair compared
# by the VICReg loss cell that follows.
z, z_next, z_next_pred = model(obs_batch, actions_batch, next_obs_batch)
print(f"Latent z shape: {z.shape}")
print(f"Predicted z_next shape: {z_next_pred.shape}")


In [None]:
# Evaluate the VICReg loss on the forward-pass outputs and report each term
loss_terms = vicreg_loss(z_next_pred, z_next)
total_loss, sim_loss, std_loss, cov_loss = loss_terms

loss_labels = ("Total loss", "Similarity loss", "Std loss", "Cov loss")
for label, term in zip(loss_labels, (total_loss, sim_loss, std_loss, cov_loss)):
    print(f"{label}: {term.item():.4f}")


### 2.4 CEM Planner


In [None]:
# Create planner
# Small search budget for a quick demo; the pre-trained-model section builds a
# second planner with a larger horizon and more iterations/samples/elites.
# NOTE(review): the parameter names suggest a cross-entropy-method search
# (num_samples candidates per iteration, best num_elites kept) — confirm in
# utils.CEMPlanner.
planner = CEMPlanner(
    model, 
    action_dim=7, 
    horizon=10,
    num_iterations=5,
    num_samples=100,
    num_elites=20
)
print("Planner created successfully")


In [None]:
# Test planning (with random untrained model - just to verify it works)
# Start and goal are the first and last observations of the first collected
# trajectory.
start_obs = trajectories[0][0]['obs']
goal_obs = trajectories[0][-1]['obs']

# permute(2, 0, 1) moves the last axis to the front (presumably HWC image ->
# CHW), unsqueeze(0) adds a batch axis, and / 255.0 rescales the values
# (presumably 0-255 pixels to [0, 1] — confirm against the training preprocessing).
start_tensor = torch.FloatTensor(start_obs).permute(2, 0, 1).unsqueeze(0) / 255.0
goal_tensor = torch.FloatTensor(goal_obs).permute(2, 0, 1).unsqueeze(0) / 255.0

# Inference only, so gradient tracking is disabled; squeeze(0) drops the batch
# axis again before the latents are handed to the planner.
with torch.no_grad():
    z_start = model.encode(start_tensor.to(device)).squeeze(0)
    z_goal = model.encode(goal_tensor.to(device)).squeeze(0)
    
    planned_actions = planner.plan(z_start, z_goal, verbose=True)

print(f"\nPlanned actions: {planned_actions.cpu().numpy()}")


### 2.5 Custom Environment

In [None]:
# Generate custom configurations
# The result is indexable, and each entry is read below via the keys 'key_pos',
# 'door_pos', 'goal_pos' and 'agent_start'.
# exclude_standard=True presumably drops layouts that match the stock DoorKey
# environment — TODO confirm in utils.generate_custom_configs.
custom_configs = generate_custom_configs(exclude_standard=True)
print(f"Generated {len(custom_configs)} custom configurations")


In [None]:
# Build one custom DoorKey layout, render it, then solve it with BFS
config = custom_configs[0]
key_pos, door_pos, goal_pos = config['key_pos'], config['door_pos'], config['goal_pos']
print(f"Config: Key={key_pos}, Door={door_pos}, Goal={goal_pos}")

custom_env = make_custom_doorkey(
    key_pos=key_pos,
    door_pos=door_pos,
    goal_pos=goal_pos,
    agent_start=config['agent_start'],
    agent_dir=0,
)
custom_env.reset()

# Render the layout on an explicit figure/axes pair
custom_obs = get_full_obs(custom_env)
fig, ax = plt.subplots(figsize=(4, 4))
ax.imshow(custom_obs)
ax.set_title(f"Custom DoorKey: K={key_pos}, D={door_pos}, G={goal_pos}")
ax.axis('off')
plt.show()

# Run the BFS solver against the custom environment
custom_actions = bfs_solve_custom_env(custom_env)
print(f"BFS solution: {len(custom_actions)} actions")

custom_env.close()


### 2.6 Environment Comparison (Generalization Environments)

In [None]:
# Visualize different MiniGrid environments used for generalization testing
import gymnasium as gym

env_names = [
    "MiniGrid-Empty-5x5-v0",
    "MiniGrid-Empty-Random-5x5-v0",
    "MiniGrid-Dynamic-Obstacles-5x5-v0",
    "MiniGrid-DoorKey-5x5-v0",
]

# Size the grid from the list so adding or removing an environment needs no
# other edit. squeeze=False keeps the returned array 2-D for any count; row 0
# is the 1-D row of axes used below.
n_envs = len(env_names)
fig, axes_grid = plt.subplots(1, n_envs, figsize=(4 * n_envs, 4), squeeze=False)
axes = axes_grid[0]

for ax, env_name in zip(axes, env_names):
    env = gym.make(env_name, render_mode="rgb_array")
    try:
        env.reset()
        frame = env.render()
    finally:
        # Release the environment even if reset/render raises.
        env.close()

    ax.imshow(frame)
    short_name = env_name.replace("MiniGrid-", "").replace("-v0", "")
    ax.set_title(short_name, fontsize=12, fontweight='bold')
    ax.axis('off')

plt.suptitle('MiniGrid 5x5 Environments', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()


## 3. Training via Command Line

The following cells demonstrate how to train a model using the command line interface.

In [None]:
# View training help
# Lists the command-line options train.py accepts; run from the repo root so
# the script is found.
!python train.py --help


In [None]:
# Train a model with minimal settings (quick demo)
# For real training, use more trajectories and epochs
# Everything is written under outputs/demo_run. The evaluation, visualization
# and model-loading cells below read checkpoints/best_model.pt and
# training_history.json from that directory, so run this cell first.
# NOTE(review): `!python` resolves through the shell PATH and may not be the
# kernel's interpreter; `!{sys.executable} train.py ...` would pin it — confirm
# for your environment.
!python train.py \
    --output_dir outputs/demo_run \
    --num_trajectories 100 \
    --epochs 10 \
    --batch_size 32 \
    --val_every 5


In [None]:
# Check what was saved
import os


def print_tree(root_dir):
    """Print the directory tree under ``root_dir``, two spaces of indent per level.

    Each directory is printed as ``name/`` with its files listed one level
    deeper. If ``root_dir`` is not an existing directory, a short message is
    printed instead of silently printing nothing.
    """
    if not os.path.isdir(root_dir):
        print(f"No output directory found at {root_dir}")
        return

    for current, _subdirs, files in os.walk(root_dir):
        # Depth relative to the root: '.' is the root itself, 'a' is depth 1,
        # 'a/b' is depth 2, and so on.
        rel = os.path.relpath(current, root_dir)
        level = 0 if rel == os.curdir else rel.count(os.sep) + 1
        indent = ' ' * 2 * level
        # normpath guards against a trailing separator yielding an empty name.
        print(f'{indent}{os.path.basename(os.path.normpath(current))}/')
        subindent = ' ' * 2 * (level + 1)
        for file in files:
            print(f'{subindent}{file}')


print_tree('outputs/demo_run')


## 4. Evaluation via Command Line

The following cells demonstrate how to evaluate a trained model.

In [None]:
# View evaluation help
# Lists the command-line options evaluate.py accepts (the cells below use
# --model_path, --num_episodes, --env_name, --replan_analysis, --custom_configs).
!python evaluate.py --help

In [None]:
# Basic evaluation
# Evaluates the checkpoint from the demo training run over 10 episodes and
# directs its outputs to outputs/demo_run/evaluation. Requires the training
# cell above to have produced checkpoints/best_model.pt.
!python evaluate.py \
    --model_path outputs/demo_run/checkpoints/best_model.pt \
    --num_episodes 10 \
    --output_dir outputs/demo_run/evaluation
    

In [None]:
# Replanning frequency analysis
# --replan_values lists the replanning settings to compare (presumably the
# number of steps executed between replans — confirm via evaluate.py --help).
# Only three values and 10 episodes are used to keep the demo quick.
!python evaluate.py \
    --model_path outputs/demo_run/checkpoints/best_model.pt \
    --replan_analysis \
    --replan_values 1 3 6 \
    --num_episodes 10 \
    --output_dir outputs/demo_run/replan_analysis


In [None]:
# Test generalization to simpler environment
# Same checkpoint, but evaluated on MiniGrid-Empty-5x5-v0 instead of the DoorKey
# environment used by the Python API cells above (presumably also the training
# environment — confirm train.py's default).
!python evaluate.py \
    --model_path outputs/demo_run/checkpoints/best_model.pt \
    --env_name MiniGrid-Empty-5x5-v0 \
    --num_episodes 10 \
    --output_dir outputs/demo_run/generalization


In [None]:
# Evaluate on custom DoorKey configurations
# --custom_configs switches evaluation to the custom DoorKey layouts;
# --num_trials_per_config 2 keeps the demo short.
!python evaluate.py \
    --model_path outputs/demo_run/checkpoints/best_model.pt \
    --custom_configs \
    --num_trials_per_config 2 \
    --output_dir outputs/demo_run/custom_eval


## 5. Visualization via Command Line

The following cells demonstrate how to generate visualizations.

In [None]:
# View visualization help
# Lists the command-line options visualize.py accepts, including the --mode
# values used by the cells below.
!python visualize.py --help


In [None]:
# Generate training curves
# Reads the training history JSON from the demo run; needs no model checkpoint.
!python visualize.py \
    --mode training_curves \
    --history_path outputs/demo_run/training_history.json \
    --output_dir outputs/demo_run/visualizations


In [None]:
# Generate latent space visualization
# Uses the demo checkpoint together with --num_trajectories 50 (presumably
# trajectories collected fresh and encoded by the model — confirm in visualize.py).
!python visualize.py \
    --mode latent_space \
    --model_path outputs/demo_run/checkpoints/best_model.pt \
    --num_trajectories 50 \
    --output_dir outputs/demo_run/visualizations


In [None]:
# Generate BFS trajectory visualizations
# No model is passed; --seeds 1 2 selects which environment seeds to draw.
!python visualize.py \
    --mode bfs_trajectories \
    --seeds 1 2 \
    --output_dir outputs/demo_run/visualizations


In [None]:
# Generate planning episode visualization
# Requires the demo checkpoint produced by the training cell.
!python visualize.py \
    --mode planning \
    --model_path outputs/demo_run/checkpoints/best_model.pt \
    --output_dir outputs/demo_run/visualizations


In [None]:
# Generate trajectory step-by-step visualization
# No model or history path is passed for this mode.
!python visualize.py \
    --mode trajectory \
    --output_dir outputs/demo_run/visualizations


In [None]:
# Generate environment comparison visualization (shows different MiniGrid environments)
# Command-line counterpart of the "Environment Comparison" cell in section 2.6;
# no model is passed for this mode.
!python visualize.py \
    --mode env_comparison \
    --output_dir outputs/demo_run/visualizations


In [None]:
# Generate episode execution with distance trajectory plots
# This shows both the step-by-step images AND the distance-to-goal plot
# Requires the demo checkpoint produced by the training cell.
!python visualize.py \
    --mode episode_distances \
    --model_path outputs/demo_run/checkpoints/best_model.pt \
    --output_dir outputs/demo_run/visualizations


In [None]:
# Report which files the visualization commands left in the output directory
import os
viz_dir = 'outputs/demo_run/visualizations'
if not os.path.exists(viz_dir):
    print("No visualizations directory found")
else:
    print("Generated visualizations:")
    for entry in os.listdir(viz_dir):
        print(f"  - {entry}")


## 6. Working with Pre-trained Models

This section shows how to load and use a pre-trained model.

In [None]:
# Load a trained model
# This is the checkpoint path the evaluation/visualization commands above use;
# it exists only after the demo training cell has been run.
model_path = 'outputs/demo_run/checkpoints/best_model.pt'

# Create model and initialize encoder
# Same hyperparameters as the model built in section 2.3.
# NOTE(review): init_encoder_fc((3, 40, 40), device) is called before
# load_state_dict — presumably it creates FC layers sized for a 3x40x40 input so
# the state-dict keys line up; confirm in models.PLDM.
loaded_model = PLDM(latent_dim=128, action_dim=7).to(device)
loaded_model.init_encoder_fc((3, 40, 40), device)  # Initialize FC layers

# Load weights
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust; where the installed PyTorch supports it, weights_only=True restricts
# what can be deserialized.
loaded_model.load_state_dict(torch.load(model_path, map_location=device))
loaded_model.eval()

print("Model loaded successfully!")


In [None]:
# Wrap the loaded model in a planner with a larger search budget than the
# demo planner from section 2.4
planner_loaded = CEMPlanner(
    loaded_model,
    action_dim=7,
    horizon=15,
    num_iterations=10,
    num_samples=200,
    num_elites=30,
)

# Solve the seed-5 layout with BFS, then replay that solution to obtain the
# start and goal observations
bfs_actions = bfs_solve("MiniGrid-DoorKey-5x5-v0", seed=5)
test_traj = collect_trajectory("MiniGrid-DoorKey-5x5-v0", seed=5, actions=bfs_actions)

start_obs = test_traj[0]['obs']
goal_obs = test_traj[-1]['obs']

# Show start and goal side by side
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
panels = zip(axes, (start_obs, goal_obs), ('Start State', 'Goal State'))
for ax, image, title in panels:
    ax.imshow(image)
    ax.set_title(title)
    ax.axis('off')
plt.tight_layout()
plt.show()


In [None]:
# Plan with the loaded model
# Same conversion as the planning test in section 2.4: permute(2, 0, 1) moves
# the last axis to the front (presumably HWC -> CHW), unsqueeze(0) adds a batch
# axis, and / 255.0 rescales the values (presumably 0-255 pixels to [0, 1] —
# confirm against the training preprocessing).
start_tensor = torch.FloatTensor(start_obs).permute(2, 0, 1).unsqueeze(0) / 255.0
goal_tensor = torch.FloatTensor(goal_obs).permute(2, 0, 1).unsqueeze(0) / 255.0

# Inference only, so gradient tracking is disabled; squeeze(0) drops the batch
# axis before the latents are handed to the planner. `planned_actions` is
# executed in the environment by the next cell.
with torch.no_grad():
    z_start = loaded_model.encode(start_tensor.to(device)).squeeze(0)
    z_goal = loaded_model.encode(goal_tensor.to(device)).squeeze(0)
    
    print("Planning...")
    planned_actions = planner_loaded.plan(z_start, z_goal, verbose=True)

print(f"\nPlanned actions: {planned_actions.cpu().numpy()}")


In [None]:
# Execute planned actions in environment
# Uses the same env name and seed (5) as the trajectory the start/goal
# observations were taken from, so the plan is run on the layout it was made for.
env = make_env("MiniGrid-DoorKey-5x5-v0")
execution_images = []
action_names = ['Left', 'Right', 'Fwd', 'Pick', 'Drop', 'Tog', 'Done']

try:
    env.reset(seed=5)
    execution_images.append(get_full_obs(env).copy())

    print("Executing planned actions:")
    for i, action in enumerate(planned_actions.cpu().numpy()[:10]):  # First 10 actions
        # `action` is a NumPy array element; convert once so stepping and the
        # name lookup use the same plain int (a float-typed element cannot
        # index a list).
        action_id = int(action)
        obs, reward, done, truncated, _ = env.step(action_id)
        execution_images.append(get_full_obs(env).copy())
        print(f"  Step {i+1}: {action_names[action_id]}, reward={reward}, done={done}")
        if done:
            print("  Goal reached!")
            break
        if truncated:
            # A truncated episode has ended too; stepping further without a
            # reset is not valid.
            print("  Episode truncated")
            break
finally:
    # Release the environment even if a step or render fails.
    env.close()

# Visualize execution
n_show = min(len(execution_images), 8)
# squeeze=False keeps the axes array 2-D, so indexing works even when only a
# single image is shown.
fig, axes_grid = plt.subplots(1, n_show, figsize=(2*n_show, 2), squeeze=False)
axes = axes_grid[0]
for i in range(n_show):
    axes[i].imshow(execution_images[i])
    axes[i].set_title(f'Step {i}')
    axes[i].axis('off')
plt.suptitle('Planned Action Execution')
plt.tight_layout()
plt.show()


## 7. Full Training Run (Optional)

The cells in this section are commented out. Uncomment and run them for a longer, more complete training run that gives better results than the quick demo above.

In [None]:
# Full training run (takes longer but gives better results)
# Uncomment to run

# !python train.py \
#     --output_dir outputs/full_run \
#     --num_trajectories 1200 \
#     --epochs 100 \
#     --batch_size 64 \
#     --lr 3e-4


In [None]:
# Full evaluation after full training
# Uncomment to run

# !python evaluate.py \
#     --model_path outputs/full_run/checkpoints/best_model.pt \
#     --replan_analysis \
#     --replan_values 1 3 6 9 12 15 \
#     --num_episodes 50 \
#     --output_dir outputs/full_run/evaluation


In [None]:
# Generate all visualizations after full training
# Uncomment to run

# !python visualize.py \
#     --mode all \
#     --model_path outputs/full_run/checkpoints/best_model.pt \
#     --history_path outputs/full_run/training_history.json \
#     --output_dir outputs/full_run/visualizations


## Summary

This notebook demonstrated:

1. **Python API**: Direct usage of models, utils, planners, and custom environments
2. **Training**: `python train.py --output_dir <dir> [options]`
3. **Evaluation**: `python evaluate.py --model_path <path> [options]`
4. **Visualization**: `python visualize.py --mode <mode> [options]`
5. **Loading models**: How to load and use pre-trained models

Key arguments:
- `--output_dir`: Where to save outputs (models, results, visualizations)
- `--model_path`: Path to a saved model checkpoint
- `--history_path`: Path to training history JSON
- `--replan_analysis`: Run replanning frequency analysis
- `--custom_configs`: Evaluate on custom DoorKey configurations
