# Custom DQN Training for Distance Heuristic

This notebook tests the custom DQN implementation on a tiny gridworld environment.

## Goals
1. Collect training data (state pairs from planning graph)
2. Train distance heuristic using transparent custom DQN
3. Evaluate learned distances vs true distances
4. Debug any issues with full visibility into training

In [1]:
%load_ext autoreload
%autoreload 2

## Setup

In [180]:
import sys
import numpy as np
import torch
from pathlib import Path

# Add src to path if needed
sys.path.insert(0, str(Path.cwd().parent / "src"))

from tamp_improv.benchmarks.gridworld import GridworldTAMPSystem
from tamp_improv.approaches.improvisational.base import ImprovisationalTAMPApproach
from tamp_improv.approaches.improvisational.policies.multi_rl import MultiRLPolicy
from tamp_improv.approaches.improvisational.policies.rl import RLConfig
from tamp_improv.approaches.improvisational.collection import collect_total_shortcuts
from tamp_improv.approaches.improvisational.distance_heuristic_v2 import (
    GoalConditionedDistanceHeuristicV2,
    DistanceHeuristicConfig,
    compute_reward
)
from tamp_improv.approaches.improvisational.graph_training import compute_graph_distances

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

PyTorch version: 2.9.0+cu128
CUDA available: False
Using device: cpu


## Configuration

In [170]:
# Environment configuration
SEED = 42
NUM_CELLS = 2  # 2x2 grid of cells
NUM_STATES_PER_CELL = 5  # 5x5 states per cell
NUM_TELEPORTERS = 0  # 1 portal pair

# Collection configuration
COLLECT_EPISODES = 10  # Very small for testing

# Training configuration
TRAINING_STEPS = 500000  # Short training for testing
MAX_EPISODE_STEPS = 50  # Max steps per episode
LEARNING_STARTS = 500  # Start learning after 500 steps
LOG_FREQ = 500  # Log every 500 steps
EVAL_FREQ = 2500  # Evaluate every 2500 steps
TRAIN_FREQ = 1000 # update every 1000 steps
TARGET_UPDATE_FREQ=2000

print("Configuration:")
print(f"  Environment: {NUM_CELLS}x{NUM_CELLS} cells, {NUM_STATES_PER_CELL}x{NUM_STATES_PER_CELL} states/cell")
print(f"  Total grid size: {NUM_CELLS * NUM_STATES_PER_CELL}x{NUM_CELLS * NUM_STATES_PER_CELL}")
print(f"  Teleporters: {NUM_TELEPORTERS}")
print(f"  Collection: {COLLECT_EPISODES} episode(s)")
print(f"  Training: {TRAINING_STEPS} steps")

Configuration:
  Environment: 2x2 cells, 5x5 states/cell
  Total grid size: 10x10
  Teleporters: 0
  Collection: 10 episode(s)
  Training: 500000 steps


## Step 1: Create Gridworld Environment

In [166]:
# Create tiny gridworld
system = GridworldTAMPSystem.create_default(
    num_cells=NUM_CELLS,
    num_states_per_cell=NUM_STATES_PER_CELL,
    num_teleporters=NUM_TELEPORTERS,
    seed=SEED,
    max_episode_steps=200,
)

print(f"✓ Created GridworldTAMPSystem")
print(f"  System name: {system.name}")
print(f"  Action space: {system.env.action_space}")
print(f"  Observation space: {system.env.observation_space}")

✓ Created GridworldTAMPSystem
  System name: TAMPSystem
  Action space: Discrete(5)
  Observation space: Graph(Box(-inf, inf, (6,), float32), None)


## Step 2: Collect Training Data

This builds the planning graph and collects state pairs.

In [167]:
# Create approach for collection
rl_config = RLConfig(device=device)
policy = MultiRLPolicy(seed=SEED, config=rl_config)
approach = ImprovisationalTAMPApproach(system, policy, seed=SEED)
approach.training_mode = True

print("✓ Created approach")

✓ Created approach


In [168]:
# Collection config
config_dict = {
    "seed": SEED,
    "collect_episodes": COLLECT_EPISODES,
    "use_random_rollouts": True,
    "num_rollouts_per_node": 5,
    "max_steps_per_rollout": 100,
    "shortcut_success_threshold": 0.5,
}

# Collect data
print("\nCollecting training data...")
rng = np.random.default_rng(SEED)
training_data = collect_total_shortcuts(system, approach, config_dict, rng=rng)

print(f"\n✓ Collection complete!")
print(f"  Nodes with states: {len(training_data.node_states)}")
print(f"  Total shortcuts: {len(training_data.valid_shortcuts)}")
print(f"  Planning graph: {len(training_data.graph.nodes)} nodes, {len(training_data.graph.edges)} edges")


Collecting training data...

Collecting total planning graph from 10 episodes
Building total planning graph from 10 episodes...
Planning graph with 5 nodes and 5 edges
  Step 1/1: 0 -> 1
  Step 1/1: 0 -> 2
  Step 1/2: 0 -> 1
  Step 2/2: 1 -> 3
  Step 1/3: 0 -> 1
  Step 2/3: 1 -> 3
  Step 3/3: 3 -> 4
Planning graph with 5 nodes and 5 edges
  Step 1/1: 0 -> 1
  Step 1/1: 0 -> 2
  Step 1/2: 0 -> 1
  Step 2/2: 1 -> 3
  Step 1/3: 0 -> 1
  Step 2/3: 1 -> 3
  Step 3/3: 3 -> 4
Planning graph with 5 nodes and 5 edges
  Step 1/1: 0 -> 1
  Step 1/1: 0 -> 2
  Step 1/2: 0 -> 1
  Step 2/2: 1 -> 3
  Step 1/3: 0 -> 1
  Step 2/3: 1 -> 3
  Step 1/3: 0 -> 1
  Step 2/3: 1 -> 3
  Step 3/3: 3 -> 4
Planning graph with 5 nodes and 5 edges
  Step 1/1: 0 -> 1
  Step 1/1: 0 -> 2
  Step 1/2: 0 -> 1
  Step 2/2: 1 -> 3
  Step 1/2: 0 -> 1
  Step 2/2: 1 -> 3
  Step 1/2: 0 -> 1
  Step 2/2: 1 -> 3
  Step 1/3: 0 -> 1
  Step 2/3: 1 -> 3
  Step 3/3: 3 -> 4
Planning graph with 5 nodes and 5 edges
  Step 1/1: 0 -> 1
  Step

## Step 3: Prepare State Pairs for Training

In [169]:
# Extract state pairs from training data
planning_graph = training_data.graph
all_node_states = training_data.node_states

# Collect all (source_state, target_state) pairs
state_pairs = []
for source_node in planning_graph.nodes:
    for target_node in planning_graph.nodes:
        if source_node.id == target_node.id:
            continue
        
        if source_node.id not in all_node_states or target_node.id not in all_node_states:
            continue
        
        source_states = all_node_states[source_node.id]
        target_states = all_node_states[target_node.id]
        
        if not source_states or not target_states:
            continue
        
        # Take first state from each node (for simplicity)
        # Could also take all combinations: for s in source_states for t in target_states
        for i in range(len(source_states)):
            for j in range(len(target_states)):
                state_pairs.append((source_states[i], target_states[j]))

print(f"\n✓ Prepared {len(state_pairs)} state pairs for training")

# Sample a few pairs to inspect
print("\nSample state pairs:")
for i, (source, target) in enumerate(state_pairs[:3]):
    source_atoms = system.perceiver.step(source)
    target_atoms = system.perceiver.step(target)
    print(f"  Pair {i}: {len(source_atoms)} source atoms -> {len(target_atoms)} target atoms")



✓ Prepared 1920 state pairs for training

Sample state pairs:
  Pair 0: 2 source atoms -> 2 target atoms
  Pair 1: 2 source atoms -> 2 target atoms
  Pair 2: 2 source atoms -> 2 target atoms


## Step 4: Train Distance Heuristic with Custom DQN

This is where the magic happens! The custom DQN will print detailed logs showing:
- Episode statistics (reward, length, success rate)
- Training metrics (loss, Q-values, TD errors)
- Replay buffer size
- Epsilon decay
- Periodic evaluation

In [324]:
# Create distance heuristic config with custom DQN enabled
heuristic_config = DistanceHeuristicConfig(
    learning_rate=1e-3,
    batch_size=128,  # Smaller batch for tiny dataset
    buffer_size=50000,  # Smaller buffer
    max_episode_steps=MAX_EPISODE_STEPS,
    learning_starts=LEARNING_STARTS,
    device=device,
    # Custom DQN specific
    log_freq=LOG_FREQ,
    train_freq=TRAIN_FREQ,
    eval_freq=EVAL_FREQ,
    target_update_freq=TARGET_UPDATE_FREQ,
    her_k=4,
    her_strategy="future"
    
)

print("✓ Created heuristic config")
print(f"  Hidden sizes: {heuristic_config.hidden_sizes}")
print(f"  HER k: {heuristic_config.her_k}")

✓ Created heuristic config
  Hidden sizes: None
  HER k: 4


In [325]:
# Create and train heuristic
heuristic = GoalConditionedDistanceHeuristicV2(config=heuristic_config, seed=SEED)

print("\nStarting training...\n")
heuristic.train(
    env=system.env,
    state_pairs=state_pairs,
    perceiver=system.perceiver,
    max_training_steps=TRAINING_STEPS,
)

print("\n✓ Training complete!")


Starting training...


Training distance heuristic V2 on 1920 state pairs...
Device: cpu
Computed observation statistics: mean=2.305, std=0.707

Using V2 self-contained DQN implementation
Observation dim: 12
Goal dim: 12
Number of actions: 5

Starting training for 500000 steps...
[DEBUG REWARD STEP 1] reward=-1.0
[DEBUG REWARD STEP 2] reward=-1.0
[DEBUG REWARD STEP 3] reward=-1.0
[DEBUG REWARD STEP 4] reward=-1.0
[DEBUG REWARD STEP 5] reward=-1.0
[DEBUG REWARD STEP 6] reward=-1.0
[DEBUG REWARD STEP 7] reward=-1.0
[DEBUG REWARD STEP 8] reward=-1.0
[DEBUG REWARD STEP 9] reward=-1.0
[DEBUG REWARD STEP 10] reward=-1.0
[DEBUG EPISODE 1] length=50, total_reward=-50.0, success=False
[DEBUG EPISODE 2] length=1, total_reward=-1.0, success=False
[DEBUG EPISODE 3] length=50, total_reward=-50.0, success=False

[Step 1500]
  Epsilon: 0.858
  Replay buffer size: 1644
  Avg Q-value: 0.637
Loss: 0.8171060085296631

[Step 22000]
  Epsilon: 0.050
  Replay buffer size: 24180
  Avg Q-value: 0.097

[Step 

In [326]:
t = heuristic.replay_buffer.transitions[77]

In [327]:
print(np.linalg.norm(t['obs'] - t['desired_goal']))

5.8309517


In [328]:
t['reward']

-1.0

In [329]:
for i in range(100):
    t = heuristic.replay_buffer.transitions[-i]
    if np.linalg.norm(t['obs'] - t['desired_goal']) == 0:
        print(i)

3
55
56
58


In [330]:
compute_reward(t['achieved_goal'], t['desired_goal'])

array([-1.], dtype=float32)

In [331]:
print(np.mean([t['reward'] for t in heuristic.replay_buffer.transitions]))

-0.94774


In [332]:
a = t['achieved_goal'].reshape(1, -1)
d = t['desired_goal'].reshape(1, -1)

In [333]:
compute_reward(a, d)

array([-1.], dtype=float32)

In [334]:
heuristic.model.replay_buffer

AttributeError: 'GoalConditionedDistanceHeuristicV2' object has no attribute 'model'

## Step 5: Evaluate Distance Heuristic

Compare learned distances f(s,s') with:
1. True optimal distances (from rollouts)
2. Graph distances (from planning graph)

In [335]:
# Compute graph distances
print("Computing graph distances...")
graph_distances = compute_graph_distances(planning_graph, exclude_shortcuts=True)
print(f"✓ Computed {len(graph_distances)} pairwise distances")

Computing graph distances...
✓ Computed 14 pairwise distances


In [336]:
a, b = state_pairs[3]
heuristic.estimate_distance(a, b)

0.0

In [337]:
a, system.perceiver.step(b)

(GraphInstance(nodes=array([[0., 0., 0., 0., 0., 0.],
        [1., 6., 8., 1., 1., 1.]], dtype=float32), edges=None, edge_links=None),
 {(InCol0 robot0), (InRow1 robot0)})

In [319]:
compute_true_distance(system, a, system.perceiver.step(b))

5.0

In [320]:
# Evaluate on all state pairs
from tamp_improv.approaches.improvisational.analyze import compute_true_distance
import random

print("\nEvaluating learned distances...")
results = []

sampled_state_pairs = [random.choice(state_pairs) for i in range(100)]
for i, (source_state, target_state) in enumerate(sampled_state_pairs):
    # Get learned distance
    learned_dist = heuristic.estimate_distance(source_state, target_state)
    
    # Get true distance (from rollouts)
    target_atoms = system.perceiver.step(target_state)
    true_dist = compute_true_distance(system, source_state, target_atoms)
    
    # Get graph distance (approximation)
    # Find which nodes these states belong to
    source_atoms = system.perceiver.step(source_state)
    source_node_id = None
    target_node_id = None
    
    for node in planning_graph.nodes:
        if node.atoms == source_atoms:
            source_node_id = node.id
        if node.atoms == target_atoms:
            target_node_id = node.id
    
    if source_node_id is not None and target_node_id is not None:
        graph_dist = graph_distances.get((source_node_id, target_node_id), float('inf'))
    else:
        graph_dist = float('inf')
    
    results.append({
        'source_idx': i,
        'learned_distance': learned_dist,
        'true_distance': true_dist,
        'graph_distance': graph_dist,
    })
    
    if (i + 1) % 10 == 0:
        print(f"  Evaluated {i + 1}/{len(state_pairs)} pairs...")

print(f"\n✓ Evaluated {len(results)} state pairs")


Evaluating learned distances...
  Evaluated 10/1920 pairs...
  Evaluated 20/1920 pairs...
  Evaluated 30/1920 pairs...
  Evaluated 40/1920 pairs...
  Evaluated 50/1920 pairs...
  Evaluated 60/1920 pairs...
  Evaluated 70/1920 pairs...
  Evaluated 80/1920 pairs...
  Evaluated 90/1920 pairs...
  Evaluated 100/1920 pairs...

✓ Evaluated 100 state pairs


## Step 6: Analyze Results

In [321]:
# Filter finite results
finite_results = [r for r in results if r['graph_distance'] != float('inf')]
infinite_results = [r for r in results if r['graph_distance'] == float('inf')]

print(f"\nResults Summary:")
print(f"  Total pairs: {len(results)}")
print(f"  Finite distance pairs: {len(finite_results)}")
print(f"  Infinite distance pairs: {len(infinite_results)}")


Results Summary:
  Total pairs: 100
  Finite distance pairs: 51
  Infinite distance pairs: 49


In [322]:
# Compute statistics for finite distance pairs
if finite_results:
    true_dists = np.array([r['true_distance'] for r in finite_results])
    learned_dists = np.array([r['learned_distance'] for r in finite_results])
    graph_dists = np.array([r['graph_distance'] for r in finite_results])
    
    # vs True distance
    mae_true = np.mean(np.abs(true_dists - learned_dists))
    rmse_true = np.sqrt(np.mean((true_dists - learned_dists) ** 2))
    correlation_true = np.corrcoef(true_dists, learned_dists)[0, 1]
    
    # vs Graph distance
    mae_graph = np.mean(np.abs(graph_dists - learned_dists))
    correlation_graph = np.corrcoef(graph_dists, learned_dists)[0, 1]
    
    print(f"\nStatistics (vs True Distance):")
    print(f"  MAE: {mae_true:.2f}")
    print(f"  RMSE: {rmse_true:.2f}")
    print(f"  Correlation: {correlation_true:.3f}")
    
    print(f"\nStatistics (vs Graph Distance):")
    print(f"  MAE: {mae_graph:.2f}")
    print(f"  Correlation: {correlation_graph:.3f}")
    
    print(f"\nDistance Ranges:")
    print(f"  True:    [{true_dists.min():.1f}, {true_dists.max():.1f}]")
    print(f"  Graph:   [{graph_dists.min():.1f}, {graph_dists.max():.1f}]")
    print(f"  Learned: [{learned_dists.min():.1f}, {learned_dists.max():.1f}]")
else:
    print("\nNo finite distance pairs to analyze!")


Statistics (vs True Distance):
  MAE: 4.24
  RMSE: 5.38
  Correlation: 0.174

Statistics (vs Graph Distance):
  MAE: 5.76
  Correlation: 0.075

Distance Ranges:
  True:    [1.0, 13.0]
  Graph:   [2.5, 17.2]
  Learned: [1.4, 2.0]


In [323]:
# Show sample comparisons
if finite_results:
    print("\nSample Comparisons (sorted by true distance):")
    print(f"{'Idx':>4} | {'True':>6} | {'Graph':>6} | {'Learned':>8} | {'Error':>6}")
    print("-" * 45)
    
    sorted_results = sorted(finite_results, key=lambda r: r['true_distance'])
    for r in sorted_results[:20]:  # Show first 20
        error = abs(r['true_distance'] - r['learned_distance'])
        print(
            f"{r['source_idx']:>4} | "
            f"{r['true_distance']:>6.1f} | "
            f"{r['graph_distance']:>6.1f} | "
            f"{r['learned_distance']:>8.1f} | "
            f"{error:>6.1f}"
        )


Sample Comparisons (sorted by true distance):
 Idx |   True |  Graph |  Learned |  Error
---------------------------------------------
  15 |    1.0 |    4.0 |      1.7 |    0.7
  19 |    1.0 |    2.5 |      1.7 |    0.7
  39 |    1.0 |    2.5 |      1.7 |    0.7
  63 |    1.0 |    3.7 |      1.7 |    0.7
  85 |    1.0 |    2.5 |      1.8 |    0.8
  90 |    1.0 |    3.3 |      1.8 |    0.8
  94 |    1.0 |    2.5 |      1.8 |    0.8
  11 |    2.0 |    3.3 |      1.7 |    0.3
  29 |    2.0 |   13.2 |      1.7 |    0.3
  52 |    2.0 |    2.5 |      1.6 |    0.4
  54 |    2.0 |    3.3 |      1.6 |    0.4
  95 |    2.0 |    3.3 |      1.6 |    0.4
  28 |    3.0 |    3.7 |      1.6 |    1.4
  48 |    3.0 |    3.7 |      1.7 |    1.3
  53 |    3.0 |    3.7 |      1.5 |    1.5
  59 |    3.0 |    2.5 |      1.9 |    1.1
   9 |    4.0 |    2.5 |      1.7 |    2.3
  24 |    4.0 |    3.3 |      1.4 |    2.6
  30 |    4.0 |   13.2 |      1.7 |    2.3
  66 |    4.0 |    3.7 |      1.7 |    2.3


In [220]:
sampled_state_pairs[68]

(GraphInstance(nodes=array([[0., 5., 1., 1., 0., 0.],
        [1., 5., 5., 1., 1., 1.]], dtype=float32), edges=None, edge_links=None),
 GraphInstance(nodes=array([[0., 5., 9., 1., 1., 0.],
        [1., 5., 9., 1., 1., 1.]], dtype=float32), edges=None, edge_links=None))

## Step 7: Visualize Results

In [None]:
import matplotlib.pyplot as plt

if finite_results:
    true_dists = np.array([r['true_distance'] for r in finite_results])
    learned_dists = np.array([r['learned_distance'] for r in finite_results])
    
    plt.figure(figsize=(10, 5))
    
    # Scatter plot
    plt.subplot(1, 2, 1)
    plt.scatter(true_dists, learned_dists, alpha=0.6)
    plt.plot([true_dists.min(), true_dists.max()], 
             [true_dists.min(), true_dists.max()], 
             'r--', label='Perfect correlation')
    plt.xlabel('True Distance')
    plt.ylabel('Learned Distance')
    plt.title(f'Learned vs True Distance\n(Correlation: {correlation_true:.3f})')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    # Error distribution
    plt.subplot(1, 2, 2)
    errors = true_dists - learned_dists
    plt.hist(errors, bins=20, alpha=0.7, edgecolor='black')
    plt.xlabel('Error (True - Learned)')
    plt.ylabel('Count')
    plt.title(f'Error Distribution\n(MAE: {mae_true:.2f}, RMSE: {rmse_true:.2f})')
    plt.axvline(0, color='r', linestyle='--', label='Zero error')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
else:
    print("No results to visualize!")

## Step 8: Save Heuristic (Optional)

In [None]:
# Save trained heuristic
save_path = Path.cwd() / "outputs" / "custom_dqn_test"
save_path.mkdir(parents=True, exist_ok=True)

heuristic.save(str(save_path / "distance_heuristic"))
print(f"\n✓ Saved heuristic to {save_path}")

## Next Steps

### If the heuristic is learning well:
- Increase training steps
- Try larger environments (more cells, more teleporters)
- Test on real benchmarks (obstacle2d, tower, etc.)

### If the heuristic is not learning:
- Check the training logs above - are Q-values decreasing?
- Check success rate - is it improving over time?
- Try different hyperparameters:
  - Increase `custom_dqn_her_k` (more hindsight goals)
  - Decrease `learning_starts` (start learning sooner)
  - Increase `buffer_size` (more experience)
  - Adjust `custom_dqn_target_update_freq`
- Check state pairs - are they feasible?
- Add more logging in custom_dqn.py to debug specific issues