# Custom DQN Training for Distance Heuristic

This notebook tests the custom DQN implementation on a tiny gridworld environment.

## Goals
1. Collect training data (state pairs from planning graph)
2. Train distance heuristic using transparent custom DQN
3. Evaluate learned distances vs true distances
4. Debug any issues with full visibility into training

In [1]:
%load_ext autoreload
%autoreload 2

## Setup

In [36]:
import sys
import numpy as np
import torch
from pathlib import Path

# Add src to path if needed
sys.path.insert(0, str(Path.cwd().parent / "src"))

from tamp_improv.benchmarks.gridworld import GridworldTAMPSystem
from tamp_improv.approaches.improvisational.base import ImprovisationalTAMPApproach
from tamp_improv.approaches.improvisational.policies.multi_rl import MultiRLPolicy
from tamp_improv.approaches.improvisational.policies.rl import RLConfig
from tamp_improv.approaches.improvisational.collection import collect_total_shortcuts
from tamp_improv.approaches.improvisational.distance_heuristic_v3 import (
    DistanceHeuristicV3,
    DistanceHeuristicV3Config,
)
from tamp_improv.approaches.improvisational.graph_training import compute_graph_distances

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

PyTorch version: 2.9.0+cu128
CUDA available: False
Using device: cpu


## Step 1: Create Gridworld Environment

In [58]:
# Create tiny gridworld
system = GridworldTAMPSystem.create_default(
    num_cells=2,
    num_states_per_cell=5,
    num_teleporters=0,
    seed=42,
    max_episode_steps=200,
)

print(f"✓ Created GridworldTAMPSystem")
print(f"  System name: {system.name}")
print(f"  Action space: {system.env.action_space}")
print(f"  Observation space: {system.env.observation_space}")

✓ Created GridworldTAMPSystem
  System name: TAMPSystem
  Action space: Discrete(5)
  Observation space: Graph(Box(-inf, inf, (6,), float32), None)


## Step 2: Collect Training Data

This builds the planning graph and collects state pairs.

In [59]:
# Create approach for collection
rl_config = RLConfig(device=device)
policy = MultiRLPolicy(seed=42, config=rl_config)
approach = ImprovisationalTAMPApproach(system, policy, seed=42)
approach.training_mode = True

print("✓ Created approach")

✓ Created approach


In [60]:
# Collection config
config_dict = {
    "seed": 42,
    "collect_episodes": 5,
    "use_random_rollouts": True,
    "num_rollouts_per_node": 5,
    "max_steps_per_rollout": 100,
    "shortcut_success_threshold": 0.5,
}

# Collect data
print("\nCollecting training data...")
rng = np.random.default_rng(42)
training_data = collect_total_shortcuts(system, approach, config_dict, rng=rng)

print(f"\n✓ Collection complete!")
print(f"  Nodes with states: {len(training_data.node_states)}")
print(f"  Total shortcuts: {len(training_data.valid_shortcuts)}")
print(f"  Planning graph: {len(training_data.graph.nodes)} nodes, {len(training_data.graph.edges)} edges")


Collecting training data...

Collecting total planning graph from 5 episodes
Building total planning graph from 5 episodes...
Planning graph with 5 nodes and 5 edges
  Step 1/1: 0 -> 1
  Step 1/1: 0 -> 2
  Step 1/2: 0 -> 1
  Step 2/2: 1 -> 3
  Step 1/3: 0 -> 1
  Step 2/3: 1 -> 3
  Step 3/3: 3 -> 4
Planning graph with 5 nodes and 5 edges
  Step 1/1: 0 -> 1
  Step 1/1: 0 -> 2
  Step 1/2: 0 -> 1
  Step 2/2: 1 -> 3
  Step 1/3: 0 -> 1
  Step 2/3: 1 -> 3
  Step 3/3: 3 -> 4
Planning graph with 5 nodes and 5 edges
  Step 1/1: 0 -> 1
  Step 1/1: 0 -> 2
  Step 1/2: 0 -> 1
  Step 2/2: 1 -> 3
  Step 1/3: 0 -> 1
  Step 2/3: 1 -> 3
  Step 1/3: 0 -> 1
  Step 2/3: 1 -> 3
  Step 3/3: 3 -> 4
Planning graph with 5 nodes and 5 edges
  Step 1/1: 0 -> 1
  Step 1/1: 0 -> 2
  Step 1/2: 0 -> 1
  Step 2/2: 1 -> 3
  Step 1/2: 0 -> 1
  Step 2/2: 1 -> 3
  Step 1/2: 0 -> 1
  Step 2/2: 1 -> 3
  Step 1/3: 0 -> 1
  Step 2/3: 1 -> 3
  Step 3/3: 3 -> 4
Planning graph with 5 nodes and 5 edges
  Step 1/1: 0 -> 1
  Step 1

## Step 3: Prepare State Pairs for Training

In [61]:
# Extract state pairs from training data
planning_graph = training_data.graph
all_node_states = training_data.node_states

# Collect all (source_state, target_state) pairs
state_pairs = []
for source_node in planning_graph.nodes:
    for target_node in planning_graph.nodes:
        if source_node.id == target_node.id:
            continue
        
        if source_node.id not in all_node_states or target_node.id not in all_node_states:
            continue
        
        source_states = all_node_states[source_node.id]
        target_states = all_node_states[target_node.id]
        
        if not source_states or not target_states:
            continue
        
        # Pair every state in the source node with every state in the target node
        for source_state in source_states:
            for target_state in target_states:
                state_pairs.append((source_state, target_state))

print(f"\n✓ Prepared {len(state_pairs)} state pairs for training")

# Sample a few pairs to inspect
print("\nSample state pairs:")
for i, (source, target) in enumerate(state_pairs[:3]):
    source_atoms = system.perceiver.step(source)
    target_atoms = system.perceiver.step(target)
    print(f"  Pair {i}: {len(source_atoms)} source atoms -> {len(target_atoms)} target atoms")



✓ Prepared 500 state pairs for training

Sample state pairs:
  Pair 0: 2 source atoms -> 2 target atoms
  Pair 1: 2 source atoms -> 2 target atoms
  Pair 2: 2 source atoms -> 2 target atoms


## Step 4: Train Distance Heuristic with Custom DQN

Training prints a detailed log for each epoch, showing:
- Replay buffer size
- Current learning rate
- Policy performance (success rate, average trajectory length)
- Training metrics (total loss, alignment loss, uniformity loss, accuracy, L2 norm, lambda)

In [62]:
# Create the distance heuristic config
heuristic_config = DistanceHeuristicV3Config(
    latent_dim=4,
    hidden_dims=[64, 64],
    learning_rate=5e-3,
    batch_size=128,  # Smaller batch for tiny dataset
    buffer_size=10000,  # Smaller buffer
    c_target=1,
    gamma=0.9,
    iters_per_epoch=1,
    device=device,  # Reuse the device detected in Setup (CUDA is unavailable here)
    repetition_factor=1,
    policy_temperature=1,
    learn_frequency=1
)

In [63]:
# Create and train heuristic
heuristic = DistanceHeuristicV3(config=heuristic_config, seed=42)

print("\nStarting training...\n")
heuristic.train(
    env=system.env,
    state_pairs=state_pairs,
    perceiver=system.perceiver,
    num_epochs=300,
    trajectories_per_epoch=10,
    max_episode_steps=50
)

print("\n✓ Training complete!")


Starting training...


Training distance heuristic V3 on 500 state pairs...
Device: cpu
State dimension: 12
Latent dimension: 4
Computed observation statistics: mean=2.323, std=0.694

Starting training for 300 epochs...
Collecting 10 trajectories per epoch
Using cosine annealing LR scheduler: 0.005 -> 5e-05

[Epoch 12/300]
  Buffer size: 130
  Learning rate: 0.004980
  --- Policy Performance ---
  Success rate: 0.00% (0/10)
  Avg trajectory length: 51.0
  --- Training Metrics ---
  Total loss: 5.0047
  Alignment loss: 0.0017
  Uniformity loss: 4.8368
  Accuracy: 53.12%
  L2 norm: 0.1661
  Lambda: 1.0000

[Epoch 13/300]
  Buffer size: 140
  Learning rate: 0.004977
  --- Policy Performance ---
  Success rate: 0.00% (0/10)
  Avg trajectory length: 51.0
  --- Training Metrics ---
  Total loss: 4.8449
  Alignment loss: 0.0023
  Uniformity loss: 4.8353
  Accuracy: 46.09%
  L2 norm: 0.0123
  Lambda: 0.9950

[Epoch 14/300]
  Buffer size: 150
  Learning rate: 0.004973
  --- Policy Performance 
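
The learning rates in the log are consistent with a standard cosine annealing schedule from 0.005 to 5e-05 over 300 epochs. The following is a minimal sketch, assuming the closed form used by schedulers such as `torch.optim.lr_scheduler.CosineAnnealingLR` and assuming the scheduler step equals the epoch number. `cosine_lr` is a hypothetical helper for illustration, not part of `DistanceHeuristicV3`.

```python
import math


def cosine_lr(epoch: int, total_epochs: int = 300,
              lr_max: float = 5e-3, lr_min: float = 5e-5) -> float:
    """Closed-form cosine annealing: lr_max at epoch 0, lr_min at total_epochs."""
    cosine = math.cos(math.pi * epoch / total_epochs)
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + cosine)


# Compare against the "Learning rate" lines logged for epochs 12 to 14
for epoch in (12, 13, 14):
    print(f"epoch {epoch}: {cosine_lr(epoch):.6f}")
```

Under these assumptions the sketch gives 0.004980, 0.004977, and 0.004973, matching the logged values, which is a quick way to confirm the schedule is advancing once per epoch.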

In [None]:
heuristic.collect_trajectories(env, )

In [64]:
heuristic.replay_buffer.trajectories[-1]

[array([0., 5., 0., 1., 0., 0., 1., 9., 7., 1., 1., 1.], dtype=float32),
 array([0., 5., 1., 1., 0., 0., 1., 9., 7., 1., 1., 1.], dtype=float32),
 array([0., 5., 2., 1., 0., 0., 1., 9., 7., 1., 1., 1.], dtype=float32),
 array([0., 5., 3., 1., 0., 0., 1., 9., 7., 1., 1., 1.], dtype=float32),
 array([0., 5., 4., 1., 0., 0., 1., 9., 7., 1., 1., 1.], dtype=float32),
 array([0., 6., 4., 1., 0., 0., 1., 9., 7., 1., 1., 1.], dtype=float32),
 array([0., 7., 4., 1., 0., 0., 1., 9., 7., 1., 1., 1.], dtype=float32),
 array([0., 8., 4., 1., 0., 0., 1., 9., 7., 1., 1., 1.], dtype=float32),
 array([0., 8., 5., 1., 1., 0., 1., 9., 7., 1., 1., 1.], dtype=float32),
 array([0., 8., 5., 1., 1., 0., 1., 9., 7., 1., 1., 1.], dtype=float32),
 array([0., 8., 5., 1., 1., 0., 1., 9., 7., 1., 1., 1.], dtype=float32),
 array([0., 8., 6., 1., 1., 0., 1., 9., 7., 1., 1., 1.], dtype=float32),
 array([0., 8., 5., 1., 1., 0., 1., 9., 7., 1., 1., 1.], dtype=float32),
 array([0., 8., 5., 1., 1., 0., 1., 9., 7., 1., 1.,

In [65]:
heuristic.replay_buffer.trajectory_metadata[-1]

{'start_state': array([0., 5., 0., 1., 0., 0., 1., 9., 7., 1., 1., 1.], dtype=float32),
 'goal_state': array([0., 9., 6., 1., 1., 0., 1., 9., 6., 1., 1., 1.], dtype=float32)}

In [47]:
# Encode the source state of every pair
encodings = [heuristic.s_encode(s) for s, _ in state_pairs]

In [48]:
print(np.array(encodings).std(axis=0))

[1.3767556 1.0953312 1.1957561 1.1943041]


## Step 5: Evaluate Distance Heuristic

Compare learned distances f(s,s') with:
1. True optimal distances (from rollouts)
2. Graph distances (from planning graph)
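
As a rough illustration of what a graph distance is, here is a minimal sketch that assumes it is the shortest-path cost between nodes of the planning graph. The edge list reuses the node transitions printed during collection (0 -> 1, 0 -> 2, 1 -> 3, 3 -> 4) with made-up unit costs, and `all_pairs_shortest_paths` is a hypothetical helper. It is not the implementation of `compute_graph_distances`.

```python
import heapq
from collections import defaultdict


def all_pairs_shortest_paths(edges):
    """Return {(source, target): cost} for every reachable ordered pair."""
    adjacency = defaultdict(list)
    nodes = set()
    for source, target, cost in edges:
        adjacency[source].append((target, cost))
        nodes.update((source, target))

    distances = {}
    for start in nodes:
        # Dijkstra's algorithm from `start`
        best = {start: 0.0}
        frontier = [(0.0, start)]
        while frontier:
            cost_so_far, node = heapq.heappop(frontier)
            if cost_so_far > best[node]:
                continue  # Stale queue entry
            for neighbor, edge_cost in adjacency[node]:
                new_cost = cost_so_far + edge_cost
                if new_cost < best.get(neighbor, float("inf")):
                    best[neighbor] = new_cost
                    heapq.heappush(frontier, (new_cost, neighbor))
        for target, cost in best.items():
            if target != start:
                distances[(start, target)] = cost
    return distances


# Hypothetical unit costs on the transitions seen in the collection log
toy_edges = [(0, 1, 1.0), (0, 2, 1.0), (1, 3, 1.0), (3, 4, 1.0)]
toy_distances = all_pairs_shortest_paths(toy_edges)
print(toy_distances[(0, 4)])                    # 3.0
print(toy_distances.get((2, 4), float("inf")))  # inf (unreachable pair is absent)
```

Unreachable pairs are simply missing from the dictionary, which is why a lookup with a default of infinity is a natural way to read such a table.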

In [49]:
# Compute graph distances
print("Computing graph distances...")
graph_distances = compute_graph_distances(planning_graph, exclude_shortcuts=True)
print(f"✓ Computed {len(graph_distances)} pairwise distances")

Computing graph distances...
✓ Computed 46 pairwise distances


In [50]:
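# Convert latent distances into an estimated step count for a single pair
a, b = state_pairs[350]
d_sg = heuristic.latent_dist(a, b)  # source -> goal
d_gg = heuristic.latent_dist(b, b)  # goal -> goal
# log(gamma) is negative, so the radicand is positive only when d_sg > d_gg
print(np.sqrt(1 / (2 * np.log(heuristic.config.gamma)) * (d_gg**2 - d_sg**2)))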

8.175434184641874
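
The conversion in the cell above can be wrapped in a small helper so that invalid inputs fail loudly instead of producing `nan`. This is a minimal sketch that applies the same expression to plain floats. `latent_to_steps` is a hypothetical name, and the inputs stand in for the values returned by `heuristic.latent_dist`.

```python
import math


def latent_to_steps(d_sg: float, d_gg: float, gamma: float) -> float:
    """Evaluate sqrt((d_gg^2 - d_sg^2) / (2 * ln(gamma))) with input checks."""
    if not 0.0 < gamma < 1.0:
        raise ValueError("gamma must lie strictly between 0 and 1")
    radicand = (d_gg**2 - d_sg**2) / (2.0 * math.log(gamma))
    if radicand < 0.0:
        # Happens when d_gg > d_sg, because ln(gamma) is negative
        raise ValueError("d_sg must be at least as large as d_gg")
    return math.sqrt(radicand)


# With d_gg = 0 the result grows linearly with the source-to-goal latent distance
print(latent_to_steps(1.0, 0.0, gamma=0.9))
print(latent_to_steps(2.0, 0.0, gamma=0.9))
```

A discount closer to 1 makes `ln(gamma)` smaller in magnitude, so the same latent distance maps to a larger step estimate.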


In [51]:
heuristic.estimate_distance(a, b)

16.35086836928375

In [52]:
heuristic._flatten_state(a)

array([10.,  0.], dtype=float32)

In [53]:
heuristic._flatten_state(b)

array([20., 10.], dtype=float32)

In [31]:
from tamp_improv.approaches.improvisational.analyze import compute_true_distance


compute_true_distance(system, a, system.perceiver.step(b))

0.0

In [54]:
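# Evaluate on a random sample of 100 state pairs (drawn with replacement)
import random

print("\nEvaluating learned distances...")
results = []

sampled_state_pairs = [random.choice(state_pairs) for _ in range(100)]
for i, (source_state, target_state) in enumerate(sampled_state_pairs):
    # Get learned distance. For the pair checked above, estimate_distance returned
    # twice the manually converted latent distance, hence the division by 2.
    learned_dist = heuristic.estimate_distance(source_state, target_state) / 2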
    
    # Get true distance (from rollouts)
    target_atoms = system.perceiver.step(target_state)
    true_dist = compute_true_distance(system, source_state, target_atoms)
    
    # Get graph distance (approximation)
    # Find which nodes these states belong to
    source_atoms = system.perceiver.step(source_state)
    source_node_id = None
    target_node_id = None
    
    for node in planning_graph.nodes:
        if node.atoms == source_atoms:
            source_node_id = node.id
        if node.atoms == target_atoms:
            target_node_id = node.id
    
    if source_node_id is not None and target_node_id is not None:
        graph_dist = graph_distances.get((source_node_id, target_node_id), float('inf'))
    else:
        graph_dist = float('inf')
    
    results.append({
        'source_idx': i,
        'learned_distance': learned_dist,
        'true_distance': true_dist,
        'graph_distance': graph_dist,
    })
    
    if (i + 1) % 10 == 0:
        print(f"  Evaluated {i + 1}/{len(sampled_state_pairs)} pairs...")

print(f"\n✓ Evaluated {len(results)} state pairs")


Evaluating learned distances...
  Evaluated 10/2072 pairs...
  Evaluated 20/2072 pairs...
  Evaluated 30/2072 pairs...
  Evaluated 40/2072 pairs...
  Evaluated 50/2072 pairs...
  Evaluated 60/2072 pairs...
  Evaluated 70/2072 pairs...
  Evaluated 80/2072 pairs...
  Evaluated 90/2072 pairs...
  Evaluated 100/2072 pairs...

✓ Evaluated 100 state pairs


## Step 6: Analyze Results

In [55]:
# Filter finite results
finite_results = [r for r in results if r['graph_distance'] != float('inf')]
infinite_results = [r for r in results if r['graph_distance'] == float('inf')]

print(f"\nResults Summary:")
print(f"  Total pairs: {len(results)}")
print(f"  Finite distance pairs: {len(finite_results)}")
print(f"  Infinite distance pairs: {len(infinite_results)}")


Results Summary:
  Total pairs: 100
  Finite distance pairs: 42
  Infinite distance pairs: 58


In [56]:
# Compute statistics for finite distance pairs
if finite_results:
    true_dists = np.array([r['true_distance'] for r in finite_results])
    learned_dists = np.array([r['learned_distance'] for r in finite_results])
    graph_dists = np.array([r['graph_distance'] for r in finite_results])
    
    # vs True distance
    mae_true = np.mean(np.abs(true_dists - learned_dists))
    rmse_true = np.sqrt(np.mean((true_dists - learned_dists) ** 2))
    correlation_true = np.corrcoef(true_dists, learned_dists)[0, 1]
    
    # vs Graph distance
    mae_graph = np.mean(np.abs(graph_dists - learned_dists))
    correlation_graph = np.corrcoef(graph_dists, learned_dists)[0, 1]
    
    print(f"\nStatistics (vs True Distance):")
    print(f"  MAE: {mae_true:.2f}")
    print(f"  RMSE: {rmse_true:.2f}")
    print(f"  Correlation: {correlation_true:.3f}")
    
    print(f"\nStatistics (vs Graph Distance):")
    print(f"  MAE: {mae_graph:.2f}")
    print(f"  Correlation: {correlation_graph:.3f}")
    
    print(f"\nDistance Ranges:")
    print(f"  True:    [{true_dists.min():.1f}, {true_dists.max():.1f}]")
    print(f"  Graph:   [{graph_dists.min():.1f}, {graph_dists.max():.1f}]")
    print(f"  Learned: [{learned_dists.min():.1f}, {learned_dists.max():.1f}]")
else:
    print("\nNo finite distance pairs to analyze!")


Statistics (vs True Distance):
  MAE: 10.60
  RMSE: 13.90
  Correlation: 0.692

Statistics (vs Graph Distance):
  MAE: 9.95
  Correlation: 0.636

Distance Ranges:
  True:    [1.0, 47.0]
  Graph:   [2.6, 37.0]
  Learned: [1.0, 12.2]
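
The learned range is much narrower than the true range, so MAE and RMSE mostly reflect a scale mismatch. If the heuristic is only used to order candidates, a rank correlation may be a more informative check. The following is a minimal NumPy sketch of Spearman correlation; `average_ranks` and `spearman_correlation` are hypothetical helpers, the toy arrays are made up, and `scipy.stats.spearmanr` is a library alternative when SciPy is installed.

```python
import numpy as np


def average_ranks(values):
    """Rank values from 0 upward, giving tied values their mean rank."""
    values = np.asarray(values, dtype=float)
    order = np.argsort(values, kind="stable")
    ordinal = np.empty(len(values), dtype=float)
    ordinal[order] = np.arange(len(values))
    # Average the ordinal ranks within each group of equal values
    _, inverse = np.unique(values, return_inverse=True)
    rank_sums = np.bincount(inverse, weights=ordinal)
    counts = np.bincount(inverse)
    return (rank_sums / counts)[inverse]


def spearman_correlation(x, y):
    """Pearson correlation of the tie-averaged ranks."""
    return float(np.corrcoef(average_ranks(x), average_ranks(y))[0, 1])


# Made-up data: learned values are compressed but perfectly ordered
toy_true = np.array([1.0, 4.0, 10.0, 20.0, 47.0])
toy_learned = np.array([1.0, 2.5, 6.0, 9.0, 12.2])
print(spearman_correlation(toy_true, toy_learned))  # close to 1.0
```

In the notebook, the same function could be called on `true_dists` and `learned_dists` from the statistics cell.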


In [57]:
# Show sample comparisons
if finite_results:
    print("\nSample Comparisons (sorted by true distance):")
    print(f"{'Idx':>4} | {'True':>6} | {'Graph':>6} | {'Learned':>8} | {'Error':>6}")
    print("-" * 45)
    
    sorted_results = sorted(finite_results, key=lambda r: r['true_distance'])
    for r in sorted_results:  # Show every finite pair, smallest true distance first
        error = abs(r['true_distance'] - r['learned_distance'])
        print(
            f"{r['source_idx']:>4} | "
            f"{r['true_distance']:>6.1f} | "
            f"{r['graph_distance']:>6.1f} | "
            f"{r['learned_distance']:>8.1f} | "
            f"{error:>6.1f}"
        )


Sample Comparisons (sorted by true distance):
 Idx |   True |  Graph |  Learned |  Error
---------------------------------------------
  46 |    1.0 |    3.8 |      1.0 |    0.0
  39 |    3.0 |    2.6 |      2.9 |    0.1
  95 |    4.0 |    6.2 |      5.6 |    1.6
   3 |    7.0 |    8.0 |      5.3 |    1.7
   7 |    7.0 |    6.4 |      4.9 |    2.1
  35 |    7.0 |    8.0 |      2.3 |    4.7
  52 |    9.0 |    5.6 |      7.9 |    1.1
   6 |   10.0 |   10.0 |      5.8 |    4.2
  26 |   10.0 |   10.0 |      7.9 |    2.1
  34 |   10.0 |   10.0 |      7.5 |    2.5
  44 |   10.0 |   10.0 |      7.5 |    2.5
  50 |   10.0 |   10.0 |      5.8 |    4.2
  69 |   11.0 |   12.6 |      6.4 |    4.6
  71 |   11.0 |   12.6 |      7.5 |    3.5
  81 |   11.0 |   16.4 |      7.6 |    3.4
  92 |   11.0 |   12.6 |      7.5 |    3.5
  49 |   13.0 |   16.8 |      7.9 |    5.1
  72 |   14.0 |   16.2 |      8.0 |    6.0
  97 |   14.0 |   16.8 |      7.9 |    6.1
  82 |   15.0 |   13.8 |      7.7 |    7.3
  83

In [220]:
sampled_state_pairs[68]

(GraphInstance(nodes=array([[0., 5., 1., 1., 0., 0.],
        [1., 5., 5., 1., 1., 1.]], dtype=float32), edges=None, edge_links=None),
 GraphInstance(nodes=array([[0., 5., 9., 1., 1., 0.],
        [1., 5., 9., 1., 1., 1.]], dtype=float32), edges=None, edge_links=None))

## Step 7: Visualize Results

In [None]:
import matplotlib.pyplot as plt

if finite_results:
    true_dists = np.array([r['true_distance'] for r in finite_results])
    learned_dists = np.array([r['learned_distance'] for r in finite_results])
    
    plt.figure(figsize=(10, 5))
    
    # Scatter plot
    plt.subplot(1, 2, 1)
    plt.scatter(true_dists, learned_dists, alpha=0.6)
    plt.plot([true_dists.min(), true_dists.max()], 
             [true_dists.min(), true_dists.max()], 
             'r--', label='y = x (perfect agreement)')
    plt.xlabel('True Distance')
    plt.ylabel('Learned Distance')
    plt.title(f'Learned vs True Distance\n(Correlation: {correlation_true:.3f})')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    # Error distribution
    plt.subplot(1, 2, 2)
    errors = true_dists - learned_dists
    plt.hist(errors, bins=20, alpha=0.7, edgecolor='black')
    plt.xlabel('Error (True - Learned)')
    plt.ylabel('Count')
    plt.title(f'Error Distribution\n(MAE: {mae_true:.2f}, RMSE: {rmse_true:.2f})')
    plt.axvline(0, color='r', linestyle='--', label='Zero error')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
else:
    print("No results to visualize!")

## Step 8: Save Heuristic (Optional)

In [None]:
# Save trained heuristic
save_path = Path.cwd() / "outputs" / "custom_dqn_test"
save_path.mkdir(parents=True, exist_ok=True)

heuristic.save(str(save_path / "distance_heuristic"))
print(f"\n✓ Saved heuristic to {save_path}")

## Next Steps

### If the heuristic is learning well:
- Increase training (`num_epochs` or `trajectories_per_epoch`)
- Try larger environments (more cells, more teleporters)
- Test on real benchmarks (obstacle2d, tower, etc.)

### If the heuristic is not learning:
- Check the training logs above - is the total loss decreasing and is accuracy improving?
- Check success rate - is it improving over time?
- Try different hyperparameters:
  - Increase `custom_dqn_her_k` (more hindsight goals)
  - Decrease `learning_starts` (start learning sooner)
  - Increase `buffer_size` (more experience)
  - Adjust `custom_dqn_target_update_freq`
- Check state pairs - are they feasible?
- Add more logging in custom_dqn.py to debug specific issues