In [1]:
import os
import sys
# Resolve the parent directory and make sure it is on the import path
# (presumably so the local `utils` module imported below can be found —
# confirm against the repository layout).
module_path = os.path.abspath(os.pardir)

if module_path not in sys.path:
    sys.path.append(module_path)
#from utils import remove_redundant_constraints

#from __future__ import annotations
import numpy as np
from multiprocessing import Pool, cpu_count

from minigrid.core.world_object import Goal, Wall, Lava
from minigrid.manual_control import ManualControl
from minigrid.minigrid_env import MiniGridEnv
from minigrid.core.mission import MissionSpace

# test_lavaworld_manual.py
from utils import (generate_lavaworld, 
                    generate_demos_from_policies_multi, 
                    constraints_from_demos_next_state_multi,
                    generate_trajectory_pools_multi,
                    generate_feedback_multi,
                    value_iteration_next_state_multi,
                    policy_evaluation_next_state_multi,
                    compute_successor_features_multi,
                    value_iteration_next_state,
                    policy_evaluation_next_state,
                    compute_successor_features_from_q_next_state,
                    l2_normalize,
                    enumerate_states,
                    constraints_from_demos_next_state,
                    constraints_from_demos_next_state_multi,
                    generate_state_action_demos,
                    trajectory_return)

from minigrid.manual_control import ManualControl

In [2]:
# Integer action ids; ACTIONS is the set the trajectory checks later in the
# notebook accept as valid.
ACT_LEFT, ACT_RIGHT, ACT_FORWARD = 0, 1, 2
ACTIONS = [ACT_LEFT, ACT_RIGHT, ACT_FORWARD]

# Generate a single lava-world environment with a fixed seed.
envs, mdps, meta = generate_lavaworld(n_envs=1, size=8, seed=420)

print("Generated", len(envs), "envs")

# Optional manual inspection, left disabled.
# NOTE: `envs[9]` requires at least 10 environments; with n_envs=1 above it
# would be out of range, so use `envs[0]` if this is re-enabled.
# env = envs[9]
# env.reset()
#
# print("Goal at:", meta["goals"][0])
# print("Lava cells:")
# print(meta["lava_masks"][0].astype(int))
#
# manual = ManualControl(env, seed=0)
# manual.start()

Generated 1 envs


In [3]:
# Display the `true_w` weight vector stored on the first generated MDP.
mdps[0]["true_w"]

array([-0.157952  , -0.019744  , -0.98720002, -0.009872  ])

### Testing Value Iteration, Policy Evaluation, and Successor Features

In [4]:
# ------------------------------------------------------------
# Smoke tests for the dynamic-programming helpers (value iteration,
# policy evaluation) and the successor-feature helpers, in both their
# single-env and multi-env forms.
# Prerequisite: the import cell at the top has been run, so numpy and the
# helpers from `utils` are in scope.
# ------------------------------------------------------------

print("Running DP + SF tests in notebook...")

# ---------------------------- Setup ----------------------------

envs, mdps, meta = generate_lavaworld(n_envs=5, size=6, seed=123)
mdp = mdps[0]  # only the first generated MDP is exercised below

gamma = 0.99

# Dimensions read directly off the MDP arrays.
S, A, _ = mdp["T"].shape
D = mdp["Phi"].shape[1]

# Seeded random weight vector, passed through l2_normalize.
rng = np.random.default_rng(0)
theta = l2_normalize(rng.normal(size=D))

# ------------- Test 1: value iteration output shapes -------------

V, Q, pi = value_iteration_next_state(mdp=mdp, theta=theta, gamma=gamma)

assert (V.shape, Q.shape, pi.shape) == ((S,), (S, A), (S,))

print("✓ Test 1 passed: value iteration shapes")

# ------------- Test 2: terminal states have zero value -------------

terminal = mdp["terminal"]
assert np.allclose(V[terminal], 0.0)

print("✓ Test 2 passed: terminal states have zero value")

# ------ Test 3: evaluating the greedy policy reproduces V ------

V_pe = policy_evaluation_next_state(
    mdp=mdp, theta=theta, policy=pi, gamma=gamma
)

assert np.allclose(V_pe, V, atol=1e-6)

print("✓ Test 3 passed: policy eval matches value iteration")

# ------ Test 4: successor features reconstruct V (Psi_s @ theta) ------

Psi_sa, Psi_s = compute_successor_features_from_q_next_state(
    T=mdp["T"],
    Phi=mdp["Phi"],
    Q=Q,
    terminal_mask=mdp["terminal"],
    gamma=gamma,
)

V_hat = Psi_s @ theta
assert np.allclose(V_hat, V, atol=1e-5)

print("✓ Test 4 passed: successor features reconstruct V")

# ------ Test 5: multi-env value iteration agrees with single-env ------

# n_jobs=1 keeps everything in-process; the original author flagged this as
# important when running inside a notebook.
V_list, Q_list, pi_list = value_iteration_next_state_multi(
    mdps=[mdp], theta=theta, gamma=gamma, n_jobs=1
)

assert np.allclose(V_list[0], V)
assert np.allclose(Q_list[0], Q)
assert np.all(pi_list[0] == pi)

print("✓ Test 5 passed: multi-env value iteration")

# ------ Test 6: multi-env policy evaluation agrees with single-env ------

V_multi_pe = policy_evaluation_next_state_multi(
    mdps=[mdp], theta=theta, policy_list=[pi], gamma=gamma, n_jobs=1
)[0]

assert np.allclose(V_multi_pe, V_pe)

print("✓ Test 6 passed: multi-env policy evaluation")

# ------ Test 7: multi-env successor features reconstruct V ------

Psi_sa_list, Psi_s_list = compute_successor_features_multi(
    mdps=[mdp], Q_list=[Q], gamma=gamma, n_jobs=1
)

V_hat_multi = Psi_s_list[0] @ theta
assert np.allclose(V_hat_multi, V, atol=1e-5)

print("✓ Test 7 passed: multi-env successor features")

# ---------------------------- Done ----------------------------

print("\nALL TESTS PASSED ✅")


Running DP + SF tests in notebook...
✓ Test 1 passed: value iteration shapes
✓ Test 2 passed: terminal states have zero value
✓ Test 3 passed: policy eval matches value iteration
✓ Test 4 passed: successor features reconstruct V
✓ Test 5 passed: multi-env value iteration
✓ Test 6 passed: multi-env policy evaluation
✓ Test 7 passed: multi-env successor features

ALL TESTS PASSED ✅


In [5]:
import numpy as np

print("Running demo + constraint pipeline tests...")

# ============================================================
# Setup: regenerate the environments and recompute the policy and
# successor features this cell needs.
# ============================================================

envs, mdps, meta = generate_lavaworld(n_envs=5, size=6, seed=123)
mdp = mdps[0]  # only the first generated MDP is exercised below

gamma = 0.99
D = mdp["Phi"].shape[1]
S, A, _ = mdp["T"].shape

# Seeded random weight vector, passed through l2_normalize.
rng = np.random.default_rng(0)
theta = l2_normalize(rng.normal(size=D))

# Policy from value iteration under theta, plus its successor features.
V, Q, pi = value_iteration_next_state(mdp=mdp, theta=theta, gamma=gamma)

Psi_sa, Psi_s = compute_successor_features_from_q_next_state(
    T=mdp["T"],
    Phi=mdp["Phi"],
    Q=Q,
    terminal_mask=mdp["terminal"],
    gamma=gamma,
)

# ============================================================
# Test 1: generate_state_action_demos (single env)
# ============================================================

states = enumerate_states(mdp["size"], mdp["wall_mask"])

demos = generate_state_action_demos(
    states=states,
    pi=pi,
    terminal_mask=mdp["terminal"],
    idx_of=mdp["idx_of"],
)

# 1a. every (state, action) pair is a valid index pair
assert all(0 <= s < S for s, _ in demos)
assert all(0 <= a < A for _, a in demos)

# 1b. no demo starts from a terminal state
assert all(not mdp["terminal"][s] for s, _ in demos)

# 1c. the demonstrated action is the policy's action
assert all(a == pi[s] for s, a in demos)

# 1d. every non-terminal enumerated state has its (index, action) demo
for state_key in states:
    state_idx = mdp["idx_of"][state_key]
    if mdp["terminal"][state_idx]:
        continue
    assert (state_idx, int(pi[state_idx])) in demos

print("✓ Test 1 passed: generate_state_action_demos")

# ============================================================
# Test 2: _generate_demos_only_worker  (disabled)
# ============================================================

# demos_worker = _generate_demos_only_worker((mdp, pi))
# assert demos_worker == demos

# print("✓ Test 2 passed: _generate_demos_only_worker")

# ============================================================
# Test 3: generate_demos_from_policies_multi
# ============================================================

# n_jobs=1 keeps the call in-process (notebook-safe per the original note).
demos_list = generate_demos_from_policies_multi(
    mdps=[mdp], pi_list=[pi], n_jobs=1
)

assert len(demos_list) == 1
assert demos_list[0] == demos

print("✓ Test 3 passed: generate_demos_from_policies_multi")

# ============================================================
# Test 4: constraints_from_demos_next_state (single env)
# ============================================================

constraints = constraints_from_demos_next_state(
    demos=demos,
    Psi_sa=Psi_sa,
    terminal_mask=mdp["terminal"],
    normalize=True,
)

# 4a. at least one constraint was produced
assert len(constraints) > 0

# 4b. each constraint is a D-vector
assert all(c.shape == (D,) for c in constraints)

# 4c. each constraint has unit L2 norm
assert all(np.isclose(np.linalg.norm(c), 1.0, atol=1e-6) for c in constraints)

print("✓ Test 4 passed: constraints_from_demos_next_state")

# ============================================================
# Test 5: constraint semantic validity
# (ψ(s,a*) − ψ(s,a)) · θ ≥ 0, checked with a small tolerance
# ============================================================

assert all(np.dot(v, theta) >= -1e-8 for v in constraints)

print("✓ Test 5 passed: constraint semantic correctness")

# ============================================================
# Test 6: constraints_from_demos_next_state_multi
# ============================================================

constraints_multi = constraints_from_demos_next_state_multi(
    demos_list=[demos],
    Psi_sa_list=[Psi_sa],
    terminal_mask_list=[mdp["terminal"]],
    normalize=True,
    n_jobs=1,
)

assert len(constraints_multi) == 1
assert len(constraints_multi[0]) == len(constraints)

# Element-wise agreement between the multi-env and single-env results.
assert all(
    np.allclose(c_multi, c_single)
    for c_multi, c_single in zip(constraints_multi[0], constraints)
)

print("✓ Test 6 passed: constraints_from_demos_next_state_multi")

# ============================================================
# Final confirmation
# ============================================================

print("\nALL DEMO + CONSTRAINT TESTS PASSED ✅")


Running demo + constraint pipeline tests...
✓ Test 1 passed: generate_state_action_demos
✓ Test 3 passed: generate_demos_from_policies_multi
✓ Test 4 passed: constraints_from_demos_next_state
✓ Test 5 passed: constraint semantic correctness
✓ Test 6 passed: constraints_from_demos_next_state_multi

ALL DEMO + CONSTRAINT TESTS PASSED ✅


In [6]:
# Build trajectory pools for every MDP currently bound to `mdps`
# (set by the previous cell), using 4 worker jobs.
traj_pools = generate_trajectory_pools_multi(
    mdps=mdps,
    max_horizon=25,
    n_trajs_per_state=1,
    n_jobs=4,
)

# Derive the three feedback collections (pairwise, corrections, e-stops)
# from those pools, again with 4 worker jobs.
pairwise_list, correction_list, estop_list = generate_feedback_multi(
    mdps=mdps,
    traj_pools=traj_pools,
    gamma=0.99,
    n_pairs=200,
    num_random_trajs=8,
    estop_beta=10.0,
    n_jobs=4,
)

In [7]:
import numpy as np

print("Running trajectory + feedback validity tests...")

# ============================================================
# Setup: a single environment, scored with its own `true_w` weights
# ============================================================

envs, mdps, meta = generate_lavaworld(n_envs=1, size=6, seed=123)
mdp = mdps[0]

gamma = 0.99
theta_true = mdp["true_w"]

# ============================================================
# Test 1: generate_trajectory_pools_multi
# ============================================================

# n_jobs=1 keeps the call in-process (notebook-safe per the original note).
traj_pools = generate_trajectory_pools_multi(
    mdps=[mdp],
    n_trajs_per_state=3,
    max_horizon=20,
    n_jobs=1,
)

assert len(traj_pools) == 1
trajectories = traj_pools[0]
assert len(trajectories) > 0

idx_of = mdp["idx_of"]

# --- structural checks ---
for traj in trajectories:
    # Empty trajectories carry nothing to validate.
    if len(traj) == 0:
        continue

    # The first transition must start in a known, non-terminal state.
    s0 = traj[0][0]
    assert s0 in idx_of
    assert not mdp["terminal"][idx_of[s0]]

    # Walk the (state, action, next_state) triples; once a next_state is
    # terminal, any further transition is an error.
    terminal_reached = False
    for (s, a, sp) in traj:
        assert s in idx_of
        assert sp in idx_of
        assert a in ACTIONS

        if terminal_reached:
            raise AssertionError("Transition after terminal state")

        terminal_reached = bool(mdp["terminal"][idx_of[sp]])

print("✓ Test 1 passed: trajectory pool structural validity")

# --- reward sanity: every trajectory return is a finite number ---
returns = [trajectory_return(traj, mdp, theta_true, gamma) for traj in trajectories]

assert all(np.isfinite(r) for r in returns)

print("✓ Test 1b passed: trajectory returns finite")

# ============================================================
# Test 2: generate_feedback_multi
# ============================================================

pairwise_list, correction_list, estop_list = generate_feedback_multi(
    traj_pools=traj_pools,
    mdps=[mdp],
    gamma=gamma,
    n_pairs=50,
    num_random_trajs=5,
    estop_beta=5.0,
    n_jobs=1,
)

# One entry per MDP; only one MDP was passed in.
pairwise, corrections, estops = pairwise_list[0], correction_list[0], estop_list[0]

# ============================================================
# Test 2a: pairwise preference validity
# The first trajectory of each pair must not score lower than the second.
# ============================================================

for tau_good, tau_bad in pairwise:
    ret_good = trajectory_return(tau_good, mdp, theta_true, gamma)
    ret_bad = trajectory_return(tau_bad, mdp, theta_true, gamma)

    assert ret_good >= ret_bad - 1e-8

print("✓ Test 2a passed: pairwise preferences reward-consistent")

# ============================================================
# Test 2b: correction feedback validity
# Same start state, and a strictly better return for the improved one.
# ============================================================

for tau_improved, tau_orig in corrections:
    assert tau_improved[0][0] == tau_orig[0][0]

    ret_new = trajectory_return(tau_improved, mdp, theta_true, gamma)
    ret_old = trajectory_return(tau_orig, mdp, theta_true, gamma)

    assert ret_new > ret_old + 1e-8

print("✓ Test 2b passed: correction feedback improves reward")

# ============================================================
# Correct Test 2c: E-stop feedback validity
# Prefix reward >= full trajectory reward (undiscounted sums of the
# next-state feature rewards, as in the original check)
# ============================================================

for traj, t_stop in estops:
    assert isinstance(t_stop, int)
    assert 0 <= t_stop < len(traj)

    # Reward of each transition, taken from the features of its next state.
    step_rewards = [
        mdp["Phi"][mdp["idx_of"][sp]] @ theta_true
        for (_, _, sp) in traj
    ]

    # Reward accumulated up to and including the stop index.
    prefix_reward = sum(step_rewards[: t_stop + 1], 0.0)

    # Reward of the whole trajectory.
    full_reward = sum(step_rewards)

    # E-stop semantics: stopping early is preferred to continuing.
    assert prefix_reward > full_reward - 1e-8, (
        "E-stop prefix reward is worse than full trajectory"
    )

print("✓ Test 2c passed: E-stop prefix reward ≥ full reward")


# ============================================================
# Done
# ============================================================

print("\nALL TRAJECTORY + FEEDBACK TESTS PASSED ✅")


Running trajectory + feedback validity tests...
✓ Test 1 passed: trajectory pool structural validity
✓ Test 1b passed: trajectory returns finite
✓ Test 2a passed: pairwise preferences reward-consistent
✓ Test 2b passed: correction feedback improves reward
✓ Test 2c passed: E-stop prefix reward ≥ full reward

ALL TRAJECTORY + FEEDBACK TESTS PASSED ✅


In [8]:
from utils import GenerationSpec, DemoSpec, FeedbackSpec
from utils import *

# Candidate-atom generation for SCOT.
#
# This cell relies on kernel state created by earlier cells:
#   mdps     - rebound by the most recent generate_lavaworld(...) cell
#   pi_list  - policies that are passed alongside `mdps` (presumably one per
#              MDP, in the same order); in this notebook it is only assigned
#              by the multi-env value-iteration check in the DP test cell
# plus these helpers from `utils`:
#   enumerate_states, generate_trajectory_pools_multi,
#   generate_pairwise_preferences, simulate_human_estop_one_mdp,
#   generate_correction_feedback
#
# Because `mdps` and `pi_list` come from different cells, re-running cells
# out of order can leave them misaligned. Fail fast with a clear message
# instead of silently pairing policies with the wrong environments.
assert len(pi_list) == len(mdps), (
    f"pi_list has {len(pi_list)} policies but mdps has {len(mdps)} MDPs; "
    "recompute pi_list for the current mdps before generating atoms"
)

# Which feedback types to generate and how much of each.
spec = GenerationSpec(
    seed=123,
    demo=DemoSpec(
        enabled=True,
        env_fraction=0.6,
        state_fraction=0.4,
    ),
    pairwise=FeedbackSpec(
        enabled=True,
        total_budget=10,
        alloc_method="dirichlet",
        alloc_params={"alpha": 0.3},
    ),
    estop=FeedbackSpec(
        enabled=True,
        total_budget=200,
        alloc_method="sparse_poisson",
        alloc_params={"p_active": 0.4, "mean": 400},
    ),
    improvement=FeedbackSpec(
        enabled=True,
        total_budget=300,
        alloc_method="dirichlet",
        alloc_params={"alpha": 0.5},
    ),
)

# The generator receives its helper functions as arguments.
atoms_per_env = generate_candidate_atoms_for_scot_minigrid(
    mdps=mdps,
    pi_list=pi_list,
    spec=spec,
    enumerate_states=enumerate_states,
    generate_trajectory_pools_multi=generate_trajectory_pools_multi,
    pairwise_fn=generate_pairwise_preferences,
    estop_fn=simulate_human_estop_one_mdp,
    improvement_fn=generate_correction_feedback,
)

# Total number of atoms across all environments.
print(sum(len(a) for a in atoms_per_env), "atoms generated")


310 atoms generated
