# Blocksworld PPO Benchmark - Pointer-Over-Heads Transformer

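This notebook runs the Blocksworld PPO benchmark with Pointer-over-Heads (PoT) iterative refinement. It compares SimplePoT and HybridPoT models, each trained with and without sub-trajectory augmentation.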

**Training modes** (selected with `--mode`; every run below uses `ppo`):
- `supervised`: Cross-entropy on good trajectories
- `ppo`: PPO with good/bad trajectory contrastive learning
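
The exact `ppo` objective lives in `experiments/blocksworld_ppo_benchmark.py` and is not reproduced here. As a minimal sketch, assuming the good/bad contrast is expressed through advantages (+1 for actions from good trajectories, -1 for actions from bad ones), the standard PPO clipped surrogate looks like this. The labelling scheme, `ppo_clip_loss`, and `eps=0.2` (a common default) are illustrative assumptions. The `supervised` mode would instead apply plain cross-entropy to the good trajectories only.

```python
import math


def ppo_clip_loss(logp_new, logp_old, advantages, eps=0.2):
    """Clipped PPO surrogate, negated so that it is a loss to minimize."""
    terms = []
    for new, old, adv in zip(logp_new, logp_old, advantages):
        ratio = math.exp(new - old)                      # pi_new(a|s) / pi_old(a|s)
        clipped = min(max(ratio, 1.0 - eps), 1.0 + eps)  # keep the ratio near 1
        terms.append(min(ratio * adv, clipped * adv))    # pessimistic (clipped) bound
    return -sum(terms) / len(terms)


# Hypothetical good/bad labelling: good trajectories push their actions up (+1),
# bad trajectories push theirs down (-1).
labels = ["good", "bad"]
advantages = [1.0 if label == "good" else -1.0 for label in labels]

logp_old = [math.log(0.25), math.log(0.25)]
logp_new = [math.log(0.50), math.log(0.25)]  # the good action became twice as likely
print(ppo_clip_loss(logp_new, logp_old, advantages))  # about -0.1: the gain is clipped at 1.2
```

Clipping caps how much a single update can profit from making good actions more likely, which keeps the policy close to the one that generated the trajectories.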

**Augmentations:** Sub-trajectory extraction from training trajectories (enabled by default; disabled with `--no-augmentation`)
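
Every contiguous slice of a valid plan is itself a valid plan from its first state to its last, because Blocksworld moves are deterministic. If the original plan is optimal, every slice is optimal too, since a shorter route between two intermediate states would shorten the whole plan. A minimal sketch of this kind of extraction, assuming a trajectory is stored as a list of states plus a list of actions; `sub_trajectories` is hypothetical, and the script's real data format and windowing rules may differ (it might, for example, keep only suffixes that end at the original goal):

```python
def sub_trajectories(states, actions, min_len=1):
    """Enumerate every contiguous sub-plan of one trajectory.

    states[i] --actions[i]--> states[i + 1], so len(states) == len(actions) + 1.
    The window (i, j) yields (start=states[i], goal=states[j], plan=actions[i:j]).
    """
    if len(states) != len(actions) + 1:
        raise ValueError("expected exactly one more state than actions")
    examples = []
    for i in range(len(actions)):
        for j in range(i + min_len, len(actions) + 1):
            examples.append((states[i], states[j], actions[i:j]))
    return examples


# A 3-step trajectory yields 3 + 2 + 1 = 6 sub-trajectories.
states = ["s0", "s1", "s2", "s3"]
actions = ["a0", "a1", "a2"]
for start, goal, plan in sub_trajectories(states, actions):
    print(start, "->", goal, plan)
```

A trajectory with n actions yields n(n+1)/2 examples at `min_len=1`, so raising `min_len` is a simple way to limit the number of very short plans.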


In [None]:
# Clone repository and install dependencies
!git clone https://github.com/Eran-BA/PoT.git
%cd PoT
!pip install -q torch numpy tqdm datasets wandb

# Log in to W&B (optional: only needed if you keep the --wandb flag in the runs below)
import wandb
wandb.login()


In [None]:
# Check GPU
import torch
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")


## 1. SimplePoT PPO WITH Augmentations


In [None]:
# SimplePoT PPO WITH sub-trajectory augmentations
!python experiments/blocksworld_ppo_benchmark.py \
    --mode ppo \
    --epochs 5 \
    --batch-size 32 \
    --max-blocks 6 \
    --model-type simple \
    --R 4 \
    --max-depth 32 \
    --d-model 128 \
    --n-heads 4 \
    --n-layers 2 \
    --controller-type transformer \
    --good-bad-ratio 1.0 \
    --eval-interval 1 \
    --wandb \
    --project blocksworld-ppo \
    --run-name simple-with-aug \
    --output-dir experiments/results/blocksworld_simple_aug
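
The name "Pointer-over-Heads" suggests a controller that points over (routes between) attention heads, and `--R 4` presumably sets the number of refinement iterations. The toy loop below illustrates that reading under those assumptions; it is not the repository's implementation. At each iteration a hypothetical `controller` produces softmax weights over heads, the head outputs are mixed with those weights, and the mix is added back to the state as a residual update.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def refine(state, heads, controller, R=4):
    """Toy PoT-style loop (assumption): R rounds of routed head updates."""
    for _ in range(R):
        weights = controller(state)                 # one routing weight per head
        outputs = [head(state) for head in heads]   # each head proposes an update
        mixed = [sum(w * out[i] for w, out in zip(weights, outputs))
                 for i in range(len(state))]
        state = [s + m for s, m in zip(state, mixed)]  # residual update
    return state


def uniform_controller(state):
    # Hypothetical controller that ignores the state and weights both heads equally
    return softmax([0.0, 0.0])


# Two hypothetical heads: each round adds the average of their proposals (2.0).
heads = [lambda s: [1.0], lambda s: [3.0]]
print(refine([0.0], heads, uniform_controller, R=4))  # [8.0]
```

In a learned model the controller would depend on the current state, so the routing can change from one refinement iteration to the next.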


## 2. SimplePoT PPO WITHOUT Augmentations


In [None]:
# SimplePoT PPO WITHOUT augmentations
!python experiments/blocksworld_ppo_benchmark.py \
    --mode ppo \
    --epochs 5 \
    --batch-size 32 \
    --max-blocks 6 \
    --model-type simple \
    --R 4 \
    --max-depth 32 \
    --d-model 128 \
    --n-heads 4 \
    --n-layers 2 \
    --controller-type transformer \
    --no-augmentation \
    --good-bad-ratio 1.0 \
    --eval-interval 1 \
    --wandb \
    --project blocksworld-ppo \
    --run-name simple-no-aug \
    --output-dir experiments/results/blocksworld_simple_no_aug


## 3. HybridPoT PPO WITH Augmentations


In [None]:
# HybridPoT PPO WITH augmentations
!python experiments/blocksworld_ppo_benchmark.py \
    --mode ppo \
    --epochs 5 \
    --batch-size 32 \
    --max-blocks 6 \
    --model-type hybrid \
    --controller-type transformer \
    --d-ctrl 128 \
    --max-depth 128 \
    --d-model 256 \
    --n-heads 8 \
    --H-cycles 2 \
    --L-cycles 6 \
    --H-layers 2 \
    --L-layers 2 \
    --halt-max-steps 2 \
    --good-bad-ratio 1.0 \
    --eval-interval 1 \
    --wandb \
    --project blocksworld-ppo \
    --run-name hybrid-with-aug \
    --output-dir experiments/results/blocksworld_hybrid_aug
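
The HybridPoT flags (`--H-cycles`, `--L-cycles`, `--H-layers`, `--L-layers`, `--halt-max-steps`) appear to mirror the two-timescale loop of the Hierarchical Reasoning Model (HRM): a fast low-level module updates several times per update of a slow high-level module, and an outer halting loop may repeat the whole cycle. Assuming that structure and no early halting, the hypothetical `hybrid_schedule` enumerates the update order, which makes the compute budget of the settings above explicit.

```python
def hybrid_schedule(H_cycles=2, L_cycles=6, halt_max_steps=2):
    """Hypothetical update order for one example, assuming the model never halts early."""
    trace = []
    for step in range(halt_max_steps):      # outer (ACT-style) halting steps
        for h_i in range(H_cycles):         # slow, high-level updates
            for l_i in range(L_cycles):     # fast, low-level updates
                trace.append(("L", step, h_i, l_i))
            trace.append(("H", step, h_i))  # high level reads the refined low-level state
    return trace


trace = hybrid_schedule()
n_low = sum(1 for t in trace if t[0] == "L")
n_high = sum(1 for t in trace if t[0] == "H")
print(n_low, n_high)  # 24 4
```

Under these assumptions, each low-level update would run `--L-layers 2` blocks, so the worst case is 24 x 2 = 48 low-level block applications per example.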


## 4. HybridPoT PPO WITHOUT Augmentations


In [None]:
# HybridPoT PPO WITHOUT augmentations
!python experiments/blocksworld_ppo_benchmark.py \
    --mode ppo \
    --epochs 5 \
    --batch-size 32 \
    --max-blocks 6 \
    --model-type hybrid \
    --controller-type transformer \
    --d-ctrl 128 \
    --max-depth 128 \
    --d-model 256 \
    --n-heads 8 \
    --H-cycles 2 \
    --L-cycles 6 \
    --H-layers 2 \
    --L-layers 2 \
    --halt-max-steps 2 \
    --no-augmentation \
    --good-bad-ratio 1.0 \
    --eval-interval 1 \
    --wandb \
    --project blocksworld-ppo \
    --run-name hybrid-no-aug \
    --output-dir experiments/results/blocksworld_hybrid_no_aug
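
The four runs above differ only in their model flags, the `--no-augmentation` switch, the run name, and the output directory. A sketch that rebuilds the same commands from those pieces; `build_command` is a hypothetical helper, and every flag is copied from the cells above rather than taken from the script's own documentation:

```python
import shlex
import sys

SCRIPT = "experiments/blocksworld_ppo_benchmark.py"

# Flags shared by all four runs above.
COMMON = [
    "--mode", "ppo", "--epochs", "5", "--batch-size", "32", "--max-blocks", "6",
    "--good-bad-ratio", "1.0", "--eval-interval", "1",
    "--wandb", "--project", "blocksworld-ppo",
]

# Model-specific flags, copied from the SimplePoT and HybridPoT cells.
MODEL_ARGS = {
    "simple": [
        "--model-type", "simple", "--R", "4", "--max-depth", "32",
        "--d-model", "128", "--n-heads", "4", "--n-layers", "2",
        "--controller-type", "transformer",
    ],
    "hybrid": [
        "--model-type", "hybrid", "--controller-type", "transformer",
        "--d-ctrl", "128", "--max-depth", "128", "--d-model", "256", "--n-heads", "8",
        "--H-cycles", "2", "--L-cycles", "6", "--H-layers", "2", "--L-layers", "2",
        "--halt-max-steps", "2",
    ],
}


def build_command(model, augment):
    """Return the argv list for one run, using the same naming as the cells above."""
    cmd = [sys.executable, SCRIPT, *COMMON, *MODEL_ARGS[model]]
    if not augment:
        cmd.append("--no-augmentation")
    run_name = f"{model}-with-aug" if augment else f"{model}-no-aug"
    out_dir = f"experiments/results/blocksworld_{model}_{'aug' if augment else 'no_aug'}"
    cmd += ["--run-name", run_name, "--output-dir", out_dir]
    return cmd


for model in ("simple", "hybrid"):
    for augment in (True, False):
        print(shlex.join(build_command(model, augment)))
```

To launch a run from the repository root, pass the list to `subprocess.run(build_command("simple", True), check=True)`; keeping the shared flags in one place makes it harder for the with/without-augmentation pairs to drift apart.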


## 5. Display Results


In [None]:
import json
from pathlib import Path

# (display name, --output-dir used by the corresponding run above)
result_dirs = [
    ('SimplePoT + Aug', 'experiments/results/blocksworld_simple_aug'),
    ('SimplePoT - No Aug', 'experiments/results/blocksworld_simple_no_aug'),
    ('HybridPoT + Aug', 'experiments/results/blocksworld_hybrid_aug'),
    ('HybridPoT - No Aug', 'experiments/results/blocksworld_hybrid_no_aug'),
]


def fmt_pct(value):
    # Metrics are assumed to be fractions in [0, 1]; show "n/a" instead of a misleading 0% when missing
    return f"{value:.2%}" if isinstance(value, (int, float)) else "n/a"


print(f"{'Model':<25} {'Slot Acc':<12} {'Exact Match':<12}")
print("-" * 51)

for name, d in result_dirs:
    results_file = Path(d) / 'results.json'
    if not results_file.exists():
        print(f"{name:<25} (not run yet)")
        continue
    with open(results_file) as f:
        final = json.load(f).get('final_metrics', {})
    slot_acc = fmt_pct(final.get('val_slot_acc'))
    exact = fmt_pct(final.get('val_exact_match'))
    print(f"{name:<25} {slot_acc:<12} {exact:<12}")
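
The table reports slot accuracy and exact match. The script's metric code is not shown here, so the definitions below are an assumption: each prediction is a fixed-length sequence of slots (for example, one entry per block), slot accuracy is the fraction of individual slots predicted correctly, and exact match is the fraction of examples whose every slot is correct. `slot_and_exact` is a hypothetical helper. Under these definitions, exact match can never exceed slot accuracy, which is a quick sanity check when reading the table.

```python
def slot_and_exact(preds, targets):
    """Hypothetical metrics over equal-length slot sequences.

    Returns (slot_accuracy, exact_match), both as fractions in [0, 1].
    """
    if len(preds) != len(targets) or not targets:
        raise ValueError("need the same, non-zero number of predictions and targets")
    correct_slots = total_slots = exact = 0
    for pred, target in zip(preds, targets):
        if len(pred) != len(target):
            raise ValueError("each prediction must have one value per target slot")
        matches = sum(p == t for p, t in zip(pred, target))
        correct_slots += matches
        total_slots += len(target)
        exact += int(matches == len(target))
    return correct_slots / total_slots, exact / len(targets)


# Two examples with three slots each; the second gets one slot wrong.
slot_acc, exact = slot_and_exact([[1, 2, 3], [1, 0, 3]], [[1, 2, 3], [1, 2, 3]])
print(f"slot acc {slot_acc:.2%}, exact match {exact:.2%}")  # slot acc 83.33%, exact match 50.00%
```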
