# VQ-VAE World Model: Learning Rules from Pixels

This notebook trains VQ-VAE world models from pixels and checks whether they learn to:
1. **Distinguish deterministic rules from stochastic chance**, without any labels saying which is which
2. **Expose interpretable rules** that can be extracted from the learned representations

## Key Insight

A VQ-VAE with a categorical transition prediction gives a per-transition **entropy learned from data**:
- The same (code, action) consistently leads to the same next_code = **RULE** (low entropy)
- The same (code, action) leads to varied next_codes = **CHANCE** (high entropy)
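
To make the rule/chance distinction concrete, here is a minimal sketch that computes the empirical entropy (in bits) of the next-code distribution for each (code, action) pair. The function name `transition_entropy` and the toy triples are illustrative assumptions, not part of the repository; the actual model predicts a categorical distribution rather than counting transitions.

```python
import math
from collections import Counter, defaultdict

def transition_entropy(transitions):
    """Return {(code, action): entropy in bits} from (code, action, next_code) triples.

    Hypothetical helper for illustration; not the repository's implementation.
    """
    counts = defaultdict(Counter)
    for code, action, next_code in transitions:
        counts[(code, action)][next_code] += 1

    entropies = {}
    for key, next_counts in counts.items():
        total = sum(next_counts.values())
        entropies[key] = sum(
            -(n / total) * math.log2(n / total) for n in next_counts.values()
        )
    return entropies

# Toy data: (3, 0) always leads to code 7; (3, 1) splits evenly between 8 and 9.
toy = [(3, 0, 7)] * 4 + [(3, 1, 8), (3, 1, 9)] * 2
result = transition_entropy(toy)
print(result)
```

In this toy data the pair (3, 0) behaves like a rule (zero entropy), while (3, 1) behaves like a coin flip (one bit).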

## Experiments

1. **2048**: Expect roughly 14x higher entropy at positions that change, because of the stochastic tile spawns
2. **Othello**: Expect near-zero entropy everywhere, since the game is fully deterministic
3. **Rule Extraction**: Visualize what the model learned as interpretable rules

---
## Setup

In [None]:
# Check GPU
!nvidia-smi -L || echo 'No GPU detected'
import torch
print(f'PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}')

In [None]:
# Mount Google Drive on Colab; otherwise use the current working directory
import os, sys
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    REPO_DIR = '/content/drive/MyDrive/Colab_Notebooks/tg_smn'
    OUT_DIR = '/content/drive/MyDrive/Colab_Notebooks/vq_world_model_outputs'
else:
    REPO_DIR = os.getcwd()
    OUT_DIR = os.path.join(REPO_DIR, 'outputs')

os.makedirs(OUT_DIR, exist_ok=True)
print(f'REPO_DIR: {REPO_DIR}')
print(f'OUT_DIR: {OUT_DIR}')

In [None]:
# Clone or update repo
import pathlib

if IN_COLAB:
    parent = pathlib.Path(REPO_DIR).parent
    parent.mkdir(parents=True, exist_ok=True)
    %cd {parent}

    if not os.path.exists(REPO_DIR):
        !git clone https://github.com/RespectableGlioma/tg_smn.git
    else:
        %cd {REPO_DIR}
        !git pull

%cd {REPO_DIR}
!ls -la world_models/stoch_muzero/

In [None]:
# Install dependencies
!pip install -q tqdm numpy matplotlib

---
## Experiment 1: 2048 (Stochastic Game)

2048 has:
- **Deterministic** slide/merge mechanics (THE RULES)
- **Stochastic** tile spawns (THE CHANCE)

The model should learn both, which should show up as a bimodal entropy distribution: a cluster near zero for the rules and a higher-entropy cluster for the spawns.
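
The split between rules and chance can be sketched on a single row. The code below is an illustrative assumption, not the repository's environment: `slide_row_left` is the deterministic part, and `spawn_tile` is the stochastic part, using the standard 2048 convention of spawning a 2 with probability 0.9 and a 4 otherwise (the environment used for training may differ).

```python
import random

def slide_row_left(row):
    """Deterministic part: slide and merge one row to the left."""
    tiles = [t for t in row if t != 0]
    merged = []
    i = 0
    while i < len(tiles):
        if i + 1 < len(tiles) and tiles[i] == tiles[i + 1]:
            merged.append(tiles[i] * 2)  # each tile merges at most once per move
            i += 2
        else:
            merged.append(tiles[i])
            i += 1
    return merged + [0] * (len(row) - len(merged))

def spawn_tile(row, rng):
    """Stochastic part: put a 2 (assumed 90%) or a 4 (assumed 10%) in a random empty cell."""
    empty = [i for i, t in enumerate(row) if t == 0]
    out = list(row)
    if empty:
        out[rng.choice(empty)] = 2 if rng.random() < 0.9 else 4
    return out

slid = slide_row_left([2, 2, 4, 0])  # always the same result for this input
outcomes = {tuple(spawn_tile(slid, random.Random(seed))) for seed in range(50)}
print(slid, len(outcomes))
```

The slide always yields the same row for the same input, while the spawn step fans out into several possible next states. That fan-out is what the model is expected to represent as high entropy.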

In [None]:
# Train VQ World Model on 2048
%cd {REPO_DIR}

!python -m world_models.stoch_muzero.train_vq_v2 \
    --game 2048 \
    --train_steps 20000 \
    --batch_size 32 \
    --codebook_size 512 \
    --n_trajectories 2000

---
## Experiment 2: Othello (Deterministic Game)

Othello has:
- **Only deterministic** transitions
- No randomness at all

The model should show near-zero entropy EVERYWHERE.
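
As a rough illustration of why every Othello transition is a rule, here is a one-dimensional sketch of the flipping mechanic. The function `place_in_row` is hypothetical and simplified: it handles a single row rather than eight directions on a board, and it does not check move legality (real Othello requires a move to flip at least one disc).

```python
def place_in_row(row, pos, player):
    """Place `player` (1 or -1) at `pos` in a 1D row and flip bracketed opponent discs."""
    if row[pos] != 0:
        raise ValueError("cell occupied")
    out = list(row)
    out[pos] = player
    for step in (-1, 1):
        run = []
        i = pos + step
        # Collect the contiguous run of opponent discs in this direction.
        while 0 <= i < len(out) and out[i] == -player:
            run.append(i)
            i += step
        # Flip only if the run is closed off by one of the player's own discs.
        if run and 0 <= i < len(out) and out[i] == player:
            for j in run:
                out[j] = player
    return out

flipped = place_in_row([1, -1, -1, 0], 3, 1)
unbracketed = place_in_row([0, -1, -1, 0], 3, 1)
print(flipped, unbracketed)
```

The next state is a pure function of the current state and the move, so the true entropy of every (state, action) pair is zero.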

In [None]:
# Train VQ World Model on Othello
%cd {REPO_DIR}

!python -m world_models.stoch_muzero.train_vq_v2 \
    --game othello \
    --train_steps 20000 \
    --batch_size 32 \
    --codebook_size 512 \
    --n_trajectories 2000

---
## Rule Extraction & Visualization

Now we extract and visualize what the models learned:
1. **Codebook Gallery**: What does each discrete code represent?
2. **Entropy Distribution**: How many rules vs chance transitions?
3. **Transition Graph**: Visual map of the learned dynamics
4. **Rule Summary**: List of discovered deterministic rules
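
The rule summary step can be sketched as follows, assuming transition counts are available as a nested dictionary. The helper `summarize_rules` is hypothetical and is not the `RuleExtractor` API; it reuses the 0.1-bit threshold and the minimum count of 5 that the comparison plot in this notebook applies.

```python
import math

def summarize_rules(counts, max_entropy=0.1, min_count=5):
    """Split (code, action) pairs into rules and chance nodes.

    counts: {(code, action): {next_code: n}} (assumed layout for this sketch).
    """
    rules, chance = [], []
    for (code, action), next_counts in sorted(counts.items()):
        total = sum(next_counts.values())
        if total < min_count:
            continue  # too few samples to judge either way
        entropy = sum(
            -(n / total) * math.log2(n / total) for n in next_counts.values()
        )
        if entropy < max_entropy:
            best = max(next_counts, key=next_counts.get)
            rules.append((code, action, best))
        else:
            chance.append((code, action, round(entropy, 3)))
    return rules, chance

toy_counts = {
    (0, 1): {5: 12},        # always leads to code 5
    (0, 2): {6: 4, 7: 4},   # evenly split between two codes
    (1, 0): {2: 3},         # below min_count, skipped
}
rules, chance = summarize_rules(toy_counts)
print(rules, chance)
```

Each rule reads as "in code c, action a leads to code c'", which is the form a human-readable rule list would take.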

In [None]:
# Setup for rule extraction
%cd {REPO_DIR}

import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# Import our modules
import sys
sys.path.insert(0, REPO_DIR)

from world_models.stoch_muzero.vq_model_v2 import VQWorldModel, VQWorldModelConfig
from world_models.stoch_muzero.rule_extraction import RuleExtractor, analyze_model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

In [None]:
# Generate fresh data for analysis
from world_models.stoch_muzero.train_vq_v2 import generate_trajectories

print("Generating 2048 data for analysis...")
obs_2048, actions_2048 = generate_trajectories('2048', n_trajectories=500, max_steps=30, img_size=64)
obs_2048_t = torch.from_numpy(obs_2048).to(device)
actions_2048_t = torch.from_numpy(actions_2048).to(device)
print(f"  2048: {obs_2048.shape[0]} trajectories, {obs_2048.shape[1]-1} steps")

print("\nGenerating Othello data for analysis...")
obs_othello, actions_othello = generate_trajectories('othello', n_trajectories=500, max_steps=30, img_size=64)
obs_othello_t = torch.from_numpy(obs_othello).to(device)
actions_othello_t = torch.from_numpy(actions_othello).to(device)
print(f"  Othello: {obs_othello.shape[0]} trajectories, {obs_othello.shape[1]-1} steps")

In [None]:
# Train and analyze 2048 model
print("="*60)
print("TRAINING AND ANALYZING 2048 MODEL")
print("="*60)

# Create and train model
cfg_2048 = VQWorldModelConfig(
    img_size=64,
    n_actions=4,
    codebook_size=512,
    code_dim=64,
    ema_decay=0.95,
    reset_threshold=2,
    reset_every=100,
)
model_2048 = VQWorldModel(cfg_2048).to(device)
optimizer = torch.optim.AdamW(model_2048.parameters(), lr=3e-4)

# Training
model_2048.train()
train_steps = 15000

pbar = tqdm(range(1, train_steps + 1), desc="Training 2048")
for step in pbar:
    # Sample batch indices on the same device as the data
    idx = torch.randint(0, obs_2048_t.shape[0], (32,), device=device)
    obs_batch = obs_2048_t[idx]
    action_batch = actions_2048_t[idx]
    
    losses = model_2048.compute_loss(obs_batch, action_batch, unroll_steps=5)
    
    optimizer.zero_grad()
    losses['total_loss'].backward()
    torch.nn.utils.clip_grad_norm_(model_2048.parameters(), 1.0)
    optimizer.step()
    
    if step % 500 == 0:
        pbar.set_postfix({
            'loss': f"{losses['total_loss'].item():.3f}",
            'ent': f"{losses['entropy'].item():.2f}",
            'codes': f"{losses['unique_codes']}",
        })

print("\nTraining complete. Running rule extraction...")

In [None]:
# Analyze 2048 model
os.makedirs(f"{OUT_DIR}/rule_analysis_2048", exist_ok=True)

extractor_2048 = analyze_model(
    model_2048,
    obs_2048_t,
    actions_2048_t,
    device,
    game_name='2048',
    save_dir=f"{OUT_DIR}/rule_analysis_2048"
)

In [None]:
# Train and analyze Othello model
print("="*60)
print("TRAINING AND ANALYZING OTHELLO MODEL")
print("="*60)

# Create and train model
cfg_othello = VQWorldModelConfig(
    img_size=64,
    n_actions=64,
    codebook_size=512,
    code_dim=64,
    ema_decay=0.95,
    reset_threshold=2,
    reset_every=100,
)
model_othello = VQWorldModel(cfg_othello).to(device)
optimizer = torch.optim.AdamW(model_othello.parameters(), lr=3e-4)

# Training
model_othello.train()
train_steps = 15000

pbar = tqdm(range(1, train_steps + 1), desc="Training Othello")
for step in pbar:
    # Sample batch indices on the same device as the data
    idx = torch.randint(0, obs_othello_t.shape[0], (32,), device=device)
    obs_batch = obs_othello_t[idx]
    action_batch = actions_othello_t[idx]
    
    losses = model_othello.compute_loss(obs_batch, action_batch, unroll_steps=5)
    
    optimizer.zero_grad()
    losses['total_loss'].backward()
    torch.nn.utils.clip_grad_norm_(model_othello.parameters(), 1.0)
    optimizer.step()
    
    if step % 500 == 0:
        pbar.set_postfix({
            'loss': f"{losses['total_loss'].item():.3f}",
            'ent': f"{losses['entropy'].item():.2f}",
            'codes': f"{losses['unique_codes']}",
        })

print("\nTraining complete. Running rule extraction...")

In [None]:
# Analyze Othello model
os.makedirs(f"{OUT_DIR}/rule_analysis_othello", exist_ok=True)

extractor_othello = analyze_model(
    model_othello,
    obs_othello_t,
    actions_othello_t,
    device,
    game_name='othello',
    save_dir=f"{OUT_DIR}/rule_analysis_othello"
)

---
## Side-by-Side Comparison

In [None]:
# Compare entropy distributions
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# 2048
ax = axes[0]
if extractor_2048.transition_entropy is not None:
    ent_2048 = []
    for code in range(cfg_2048.codebook_size):
        for action in range(cfg_2048.n_actions):
            count = extractor_2048.transition_counts[code, action].sum()
            if count >= 5:
                ent_2048.append(extractor_2048.transition_entropy[code, action])
    ent_2048 = np.array(ent_2048)
    
    ax.hist(ent_2048, bins=50, edgecolor='black', alpha=0.7, color='steelblue')
    ax.axvline(x=0.1, color='r', linestyle='--', label='Deterministic threshold')
    ax.set_xlabel('Entropy (bits)')
    ax.set_ylabel('Count')
    ax.set_title(f'2048: {(ent_2048 < 0.1).sum()}/{len(ent_2048)} deterministic ({100*(ent_2048 < 0.1).mean():.1f}%)')
    ax.legend()

# Othello
ax = axes[1]
if extractor_othello.transition_entropy is not None:
    ent_othello = []
    for code in range(cfg_othello.codebook_size):
        for action in range(cfg_othello.n_actions):
            count = extractor_othello.transition_counts[code, action].sum()
            if count >= 5:
                ent_othello.append(extractor_othello.transition_entropy[code, action])
    ent_othello = np.array(ent_othello)
    
    ax.hist(ent_othello, bins=50, edgecolor='black', alpha=0.7, color='coral')
    ax.axvline(x=0.1, color='r', linestyle='--', label='Deterministic threshold')
    ax.set_xlabel('Entropy (bits)')
    ax.set_ylabel('Count')
    ax.set_title(f'Othello: {(ent_othello < 0.1).sum()}/{len(ent_othello)} deterministic ({100*(ent_othello < 0.1).mean():.1f}%)')
    ax.legend()

plt.suptitle('Entropy Distribution Comparison: Rules vs Chance', fontsize=14)
plt.tight_layout()
plt.savefig(f'{OUT_DIR}/entropy_comparison.png', dpi=150)
plt.show()

print(f"\nSaved comparison to: {OUT_DIR}/entropy_comparison.png")

In [None]:
# Summary table
print("\n" + "="*70)
print("SUMMARY: LEARNED RULES VS CHANCE")
print("="*70)
print(f"{'Game':<15} {'Rules':<12} {'Chance':<12} {'% Rules':<12} {'Codebook':<15}")
print("-"*70)

# 2048
n_rules_2048 = len([r for r in extractor_2048.rules if r.is_deterministic])
n_chance_2048 = len([r for r in extractor_2048.rules if not r.is_deterministic])
pct_2048 = 100 * n_rules_2048 / max(1, n_rules_2048 + n_chance_2048)
codes_2048 = int((extractor_2048.codebook_usage > 0).sum()) if extractor_2048.codebook_usage is not None else 'N/A'
print(f"{'2048':<15} {n_rules_2048:<12} {n_chance_2048:<12} {pct_2048:<12.1f} {codes_2048}/{cfg_2048.codebook_size}")

# Othello
n_rules_othello = len([r for r in extractor_othello.rules if r.is_deterministic])
n_chance_othello = len([r for r in extractor_othello.rules if not r.is_deterministic])
pct_othello = 100 * n_rules_othello / max(1, n_rules_othello + n_chance_othello)
codes_othello = int((extractor_othello.codebook_usage > 0).sum()) if extractor_othello.codebook_usage is not None else 'N/A'
print(f"{'Othello':<15} {n_rules_othello:<12} {n_chance_othello:<12} {pct_othello:<12.1f} {codes_othello}/{cfg_othello.codebook_size}")

print("\n" + "="*70)
print("INTERPRETATION:")
print("-"*70)
print("• 2048 should have BOTH rules (slides) AND chance (tile spawns)")
print("• Othello should be ~100% rules (fully deterministic)")
print("• Higher codebook usage = richer representation")
print("="*70)

---
## Next Steps

If the experiments above show the expected patterns:
- **2048**: Mix of rules and chance
- **Othello**: Nearly 100% rules

Then we're ready for:

### 1. Planning Benchmark (VQ-MCTS)
- Implement tree search that only branches on high-entropy transitions
- Compare tree size vs standard MCTS
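
A minimal sketch of the idea, under the assumption that the model exposes a categorical distribution over next codes for each (code, action): low-entropy transitions are collapsed to their most likely successor, and only high-entropy transitions branch. The names `successors` and `count_nodes`, the toy `model` table, and the 0.1-bit threshold are illustrative choices, not an existing implementation.

```python
import math

def successors(probs, entropy_threshold=0.1):
    """Return [(next_code, prob)]: one child for a rule, all children for a chance node."""
    entropy = sum(-p * math.log2(p) for p in probs.values() if p > 0)
    if entropy < entropy_threshold:
        best = max(probs, key=probs.get)
        return [(best, 1.0)]
    return sorted(probs.items())

def count_nodes(model, code, depth, actions, collapse=True):
    """Count nodes in a depth-limited expansion starting from `code`."""
    if depth == 0:
        return 1
    total = 1
    for a in actions:
        probs = model[(code, a)]
        children = successors(probs) if collapse else sorted(probs.items())
        for next_code, _ in children:
            total += count_nodes(model, next_code, depth - 1, actions, collapse)
    return total

# Toy model: action 0 is near-deterministic, action 1 is a coin flip.
model = {
    (0, 0): {0: 0.01, 1: 0.99},
    (0, 1): {0: 0.5, 1: 0.5},
    (1, 0): {0: 0.99, 1: 0.01},
    (1, 1): {0: 0.5, 1: 0.5},
}
full = count_nodes(model, 0, depth=2, actions=[0, 1], collapse=False)
pruned = count_nodes(model, 0, depth=2, actions=[0, 1], collapse=True)
print(full, pruned)
```

On this toy table the collapsed tree has 13 nodes against 21 for full branching at depth 2. The gap grows with depth and with the number of low-probability successors the model assigns to near-deterministic transitions.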

### 2. Transfer Experiment
- Train on 2048 variant A (one visual style)
- Fine-tune only the encoder on variant B (a different visual style)
- Test whether the learned dynamics transfer (the hypothesis is that they should, since the rules are unchanged)

### 3. Connect to TG-SMN
- Rules = stable structure that should compress well
- Randomness = noise that should not be memorized

In [None]:
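# Save trained models for later use.
# Note: 'config' is stored as a pickled VQWorldModelConfig object, so loading it
# needs that class to be importable and, on PyTorch >= 2.6, torch.load(..., weights_only=False).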
torch.save({
    'model_state': model_2048.state_dict(),
    'config': cfg_2048,
}, f'{OUT_DIR}/model_2048.pt')

torch.save({
    'model_state': model_othello.state_dict(),
    'config': cfg_othello,
}, f'{OUT_DIR}/model_othello.pt')

print(f"Models saved to {OUT_DIR}/")