In [None]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))

if module_path not in sys.path:
    sys.path.append(module_path)
#from utils import remove_redundant_constraints

#from __future__ import annotations
import numpy as np
from multiprocessing import Pool, cpu_count

from minigrid.core.world_object import Goal, Wall, Lava
from minigrid.manual_control import ManualControl
from minigrid.minigrid_env import MiniGridEnv
from minigrid.core.mission import MissionSpace

# test_lavaworld_manual.py
from utils import (generate_lavaworld, 
                    generate_demos_from_policies_multi, 
                    constraints_from_demos_next_state_multi,
                    generate_trajectory_pools_multi,
                    generate_feedback_multi)

from minigrid.manual_control import ManualControl

In [None]:
# Integer ids of the three discrete actions used by this notebook.
# NOTE(review): presumably these mirror MiniGrid's left/right/forward action ids — confirm.
ACT_LEFT = 0
ACT_RIGHT = 1
# ACT_FORWARD is also the placeholder action that value_iteration_next_state
# writes into the greedy policy for terminal states.
ACT_FORWARD = 2
ACTIONS = [ACT_LEFT, ACT_RIGHT, ACT_FORWARD]
# Build 5 environments of size 8 from a fixed seed. The helper returns the
# environments, one MDP dict per environment, and a metadata object.
envs, mdps, meta = generate_lavaworld(
    n_envs=5,
    size=8,
    seed=42,
)

print("Generated", len(envs), "envs")

# Optional manual-control snippet, kept disabled.
# NOTE(review): envs[9] is out of range when n_envs=5 above; pick an index
# below len(envs) before uncommenting.
#env = envs[9]
#env.reset()

# print("Goal at:", meta["goals"][0])
# print("Lava cells:")
# print(meta["lava_masks"][0].astype(int))

# manual = ManualControl(env, seed=0)
# manual.start()

In [None]:
import numpy as np

# Ground-truth reward weights: one standard-normal vector per environment,
# with one entry per feature column of that environment's Phi matrix.
# A dedicated seeded generator keeps these weights (and everything derived
# from them) reproducible under Restart Kernel -> Run All; the seed matches
# the one passed to generate_lavaworld.
theta_rng = np.random.default_rng(42)

theta_true_list = []
for mdp in mdps:
    D = mdp["Phi"].shape[1]
    theta_true_list.append(theta_rng.standard_normal(D))

In [None]:
# Build one pool of candidate trajectories per MDP.
# NOTE(review): generate_trajectory_pools_multi is defined in utils and not
# shown here; the argument names suggest 3 rollouts per state, each capped at
# 25 steps, computed with 4 worker processes — confirm against utils.
traj_pools = generate_trajectory_pools_multi(
    mdps=mdps,
    n_trajs_per_state=3,
    max_horizon=25,
    n_jobs=4,
)

# Derive three kinds of feedback from the pools, using the ground-truth
# weights in theta_true_list.
# NOTE(review): judging by the returned names these are pairwise comparisons,
# corrections and e-stop signals; the meaning of n_pairs, num_random_trajs and
# estop_beta is defined in utils — confirm there.
pairwise_list, correction_list, estop_list = generate_feedback_multi(
    traj_pools=traj_pools,
    mdps=mdps,
    theta_true_list=theta_true_list,
    gamma=0.99,
    n_pairs=200,
    num_random_trajs=8,
    estop_beta=10.0,
    n_jobs=4,
)

In [None]:
def policy_evaluation_next_state(
    T: np.ndarray,
    r_next: np.ndarray,
    policy: np.ndarray,
    terminal_mask: np.ndarray,
    gamma: float,
    theta: float = 1e-8,
    max_iters: int = 200000,
) -> np.ndarray:
    """
    Iteratively evaluate a deterministic policy under the NEXT-state reward
    convention (the reward is collected for the state being entered):

      V(s) = Σ_{s'} T[s, π(s), s'] * ( r_next[s'] + gamma * 1[~terminal(s')] * V(s') )

    Args:
      T             : (S, A, S) transition probabilities.
      r_next        : (S,) reward received on entering each state.
      policy        : (S,) action index chosen in each state.
      terminal_mask : (S,) boolean, True for terminal states.
      gamma         : discount factor.
      theta         : stop once the largest per-sweep change falls below this.
      max_iters     : upper bound on the number of sweeps.

    Returns:
      (S,) value vector. Terminal states are never updated and stay at 0.
    """
    n_states, _, n_next = T.shape
    assert n_states == n_next
    values = np.zeros(n_states, dtype=float)

    # 1.0 where the episode continues after entering a state, 0.0 where it ends.
    keep_going = (~terminal_mask).astype(float)
    # Only non-terminal states are ever backed up.
    active_states = [s for s in range(n_states) if not terminal_mask[s]]

    sweep = 0
    while sweep < max_iters:
        sweep += 1
        largest_change = 0.0
        for s in active_states:
            # In-place sweep: states updated earlier in this sweep are already
            # visible here, so the target is rebuilt for every state.
            backup = r_next + gamma * (keep_going * values)
            updated = float(np.sum(T[s, int(policy[s])] * backup))
            largest_change = max(largest_change, abs(updated - values[s]))
            values[s] = updated
        if largest_change < theta:
            break
    return values

def value_iteration_next_state(
    T: np.ndarray,
    r_next: np.ndarray,
    terminal_mask: np.ndarray,
    gamma: float,
    theta: float = 1e-8,
    max_iters: int = 200000,
):
    """
    Value iteration under the NEXT-state reward convention:

      Q(s,a) = Σ_{s'} T[s,a,s'] * ( r_next[s'] + gamma * 1[~terminal(s')] * V(s') )
      V(s)   = max_a Q(s,a)

    Args:
      T             : (S, A, S) transition probabilities.
      r_next        : (S,) reward received on entering each state.
      terminal_mask : (S,) boolean, True for terminal states.
      gamma         : discount factor.
      theta         : stop once the largest per-sweep change in V falls below this.
      max_iters     : upper bound on the number of sweeps.

    Returns:
      V  : (S,) state values; terminal states stay at 0.
      Q  : (S, A) action values; rows of terminal states stay at 0.
      pi : (S,) greedy action per state; terminal states get ACT_FORWARD,
           a constant defined at notebook level.
    """
    n_states, n_actions, n_next = T.shape
    assert n_states == n_next
    values = np.zeros(n_states, dtype=float)
    q_values = np.zeros((n_states, n_actions), dtype=float)

    # 1.0 where the episode continues after entering a state, 0.0 where it ends.
    keep_going = (~terminal_mask).astype(float)
    active_states = [s for s in range(n_states) if not terminal_mask[s]]

    for _sweep in range(max_iters):
        largest_change = 0.0
        for s in active_states:
            # values does not change while the actions of one state are
            # scored, so a single target vector serves all of them.
            backup = r_next + gamma * (keep_going * values)
            for a in range(n_actions):
                q_values[s, a] = float(np.sum(T[s, a] * backup))

            best = float(np.max(q_values[s]))
            largest_change = max(largest_change, abs(best - values[s]))
            values[s] = best

        if largest_change < theta:
            break

    # Greedy policy read off the final Q table.
    greedy = np.zeros(n_states, dtype=int)
    for s in range(n_states):
        greedy[s] = ACT_FORWARD if terminal_mask[s] else int(np.argmax(q_values[s]))

    return values, q_values, greedy

def compute_successor_features_from_q_next_state(
    T: np.ndarray,
    Phi: np.ndarray,
    Q: np.ndarray,
    terminal_mask: np.ndarray,
    gamma: float,
    tol: float = 1e-10,
    max_iters: int = 100000,
):
    """
    Successor features of the greedy policy under the NEXT-state (entering)
    convention.

    Definitions:
      π(s)   = argmax_a Q(s,a)
      ψ(s)   = E_π [ sum_t γ^t φ(s_{t+1}) | s0 = s ]
      ψ(s,a) = E   [ φ(s1) + γ ψ(s1) | s0 = s, a0 = a ]

    Bellman equation solved by fixed-point iteration:
      ψ(s) = Σ_{s'} P_π(s,s') [ φ(s') + γ ψ(s') ]   for non-terminal s
      ψ(s) = 0                                       for terminal s
    Because ψ is pinned to zero at terminal states, nothing is bootstrapped
    through a terminal successor.

    Args:
      T             : (S, A, S) transition probabilities.
      Phi           : (S, D) state feature matrix φ(s).
      Q             : (S, A) action values, used only to pick the greedy action.
      terminal_mask : (S,) boolean, True for terminal states.
      gamma         : discount factor.
      tol           : stop once the largest entry-wise change of ψ(s) is below this.
      max_iters     : upper bound on the number of sweeps.

    Returns:
      Psi_sa : (S, A, D) state-action successor features.
      Psi_s  : (S, D)    state successor features of the greedy policy.

    NOTE(review): for a terminal source state s, Psi_sa[s, a] is T[s, a] @ Phi
    only — the bootstrap term is switched off by the mask of the *source*
    state. That is the behaviour this function has always had and it is kept
    unchanged here, but it is neither zero (like V and Q at terminal states)
    nor the successor-masked backup; confirm which is intended before relying
    on terminal rows of Psi_sa.
    """
    S, A, S2 = T.shape
    assert S == S2
    D = Phi.shape[1]

    terminal = np.asarray(terminal_mask, dtype=bool)
    active = ~terminal
    cont = active.astype(float)  # 1.0 for non-terminal states, 0.0 for terminal

    # -----------------------------
    # Greedy policy transition matrix: P_pi[s] = T[s, argmax_a Q[s, a]]
    # (terminal states take no action, so their rows start out empty)
    # -----------------------------
    greedy = np.argmax(Q, axis=1)
    P_pi = np.zeros((S, S), dtype=float)
    P_pi[active] = T[active, greedy[active]]

    # absorbing fallback (safety): rows without any outgoing mass self-loop
    empty_rows = np.flatnonzero(P_pi.sum(axis=1) == 0)
    P_pi[empty_rows, empty_rows] = 1.0

    # -----------------------------
    # Iterative policy SFs ψ(s), all states updated at once from the previous
    # iterate. The expected next-state feature does not change between
    # sweeps, so it is computed a single time.
    # -----------------------------
    exp_phi_next = P_pi @ Phi
    Psi_s = np.zeros((S, D), dtype=float)

    for _ in range(max_iters):
        Psi_new = exp_phi_next + gamma * (P_pi @ Psi_s)
        Psi_new[terminal] = 0.0  # terminal states are pinned at zero

        # initial=0.0 keeps an empty state space from raising inside np.max
        change = np.max(np.abs(Psi_new - Psi_s), initial=0.0)
        Psi_s = Psi_new
        if change < tol:
            break

    # -----------------------------
    # State–action successor features ψ(s,a): one-step lookahead through every
    # action. T @ X contracts the next-state axis, giving (S, A, D).
    # -----------------------------
    Psi_sa = T @ Phi + gamma * cont[:, None, None] * (T @ Psi_s)

    return Psi_sa, Psi_s

In [None]:
def _policy_eval_worker(args):
    """Pool task: unpack one 7-tuple and run policy_evaluation_next_state on it.

    The tuple order is (T, r_next, policy, terminal_mask, gamma, theta, max_iters).
    """
    (
        transitions,
        rewards,
        chosen_actions,
        terminals,
        discount,
        tolerance,
        sweep_cap,
    ) = args
    task = dict(
        T=transitions,
        r_next=rewards,
        policy=chosen_actions,
        terminal_mask=terminals,
        gamma=discount,
        theta=tolerance,
        max_iters=sweep_cap,
    )
    return policy_evaluation_next_state(**task)

def policy_evaluation_next_state_multi(
    mdps,
    r_next_list,
    policy_list,
    gamma,
    theta=1e-8,
    max_iters=200000,
    n_jobs=None,
):
    """
    Parallel policy evaluation over multiple envs, one pool task per env.

    Args:
        mdps        : list of mdp dicts; the "T" and "terminal" entries are read.
        r_next_list : list of r_next vectors (one per env).
        policy_list : list of policies (one per env).
        gamma       : discount factor shared by all envs.
        theta       : convergence threshold forwarded to each evaluation.
        max_iters   : sweep limit forwarded to each evaluation.
        n_jobs      : number of worker processes; defaults to cpu_count().
                      Never more workers than envs are started.

    Returns:
        List of value vectors, in the same order as ``mdps``. An empty input
        yields an empty list without starting any worker process.

    Raises:
        ValueError: if the three input lists differ in length.
    """
    mdps = list(mdps)
    r_next_list = list(r_next_list)
    policy_list = list(policy_list)

    # zip() would silently drop the surplus entries of the longer lists.
    if not (len(mdps) == len(r_next_list) == len(policy_list)):
        raise ValueError(
            "mdps, r_next_list and policy_list must have the same length "
            f"(got {len(mdps)}, {len(r_next_list)}, {len(policy_list)})"
        )
    if not mdps:
        return []

    if n_jobs is None:
        n_jobs = cpu_count()

    args = [
        (
            mdp["T"],
            r_next,
            policy,
            mdp["terminal"],
            gamma,
            theta,
            max_iters,
        )
        for mdp, r_next, policy in zip(mdps, r_next_list, policy_list)
    ]

    # Extra workers beyond the number of tasks would only sit idle.
    with Pool(min(n_jobs, len(args))) as pool:
        Vs = pool.map(_policy_eval_worker, args)

    return Vs

def _vi_worker(args):
    """Pool task: unpack one 6-tuple and run value_iteration_next_state on it.

    The tuple order is (T, r_next, terminal_mask, gamma, theta, max_iters).
    """
    transitions, rewards, terminals, discount, tolerance, sweep_cap = args
    task = dict(
        T=transitions,
        r_next=rewards,
        terminal_mask=terminals,
        gamma=discount,
        theta=tolerance,
        max_iters=sweep_cap,
    )
    return value_iteration_next_state(**task)

def value_iteration_next_state_multi(
    mdps,
    r_next_list,
    gamma,
    theta=1e-8,
    max_iters=200000,
    n_jobs=None,
):
    """
    Parallel value iteration over multiple envs, one pool task per env.

    Args:
        mdps        : list of mdp dicts; the "T" and "terminal" entries are read.
        r_next_list : list of r_next vectors (one per env).
        gamma       : discount factor shared by all envs.
        theta       : convergence threshold forwarded to each run.
        max_iters   : sweep limit forwarded to each run.
        n_jobs      : number of worker processes; defaults to cpu_count().
                      Never more workers than envs are started.

    Returns:
        V_list, Q_list, pi_list — three lists in the same order as ``mdps``.
        An empty input yields three empty lists without starting any worker.

    Raises:
        ValueError: if ``mdps`` and ``r_next_list`` differ in length.
    """
    mdps = list(mdps)
    r_next_list = list(r_next_list)

    # zip() would silently drop the surplus entries of the longer list.
    if len(mdps) != len(r_next_list):
        raise ValueError(
            "mdps and r_next_list must have the same length "
            f"(got {len(mdps)} and {len(r_next_list)})"
        )
    # With no results, zip(*results) cannot be unpacked into three names.
    if not mdps:
        return [], [], []

    if n_jobs is None:
        n_jobs = cpu_count()

    args = [
        (
            mdp["T"],
            r_next,
            mdp["terminal"],
            gamma,
            theta,
            max_iters,
        )
        for mdp, r_next in zip(mdps, r_next_list)
    ]

    # Extra workers beyond the number of tasks would only sit idle.
    with Pool(min(n_jobs, len(args))) as pool:
        results = pool.map(_vi_worker, args)

    V_list, Q_list, pi_list = zip(*results)
    return list(V_list), list(Q_list), list(pi_list)

def _sf_worker(args):
    """Pool task: unpack one 7-tuple and compute successor features for it.

    The tuple order is (T, Phi, Q, terminal_mask, gamma, tol, max_iters).
    """
    (
        transitions,
        features,
        action_values,
        terminals,
        discount,
        tolerance,
        sweep_cap,
    ) = args
    task = dict(
        T=transitions,
        Phi=features,
        Q=action_values,
        terminal_mask=terminals,
        gamma=discount,
        tol=tolerance,
        max_iters=sweep_cap,
    )
    return compute_successor_features_from_q_next_state(**task)

def compute_successor_features_multi(
    mdps,
    Q_list,
    gamma,
    tol=1e-10,
    max_iters=100000,
    n_jobs=None,
):
    """
    Parallel successor feature computation, one pool task per env.

    Args:
        mdps      : list of mdp dicts; the "T", "Phi" and "terminal" entries are read.
        Q_list    : list of Q tables (one per env).
        gamma     : discount factor shared by all envs.
        tol       : convergence threshold forwarded to each computation.
        max_iters : sweep limit forwarded to each computation.
        n_jobs    : number of worker processes; defaults to cpu_count().
                    Never more workers than envs are started.

    Returns:
        Psi_sa_list, Psi_s_list — two lists in the same order as ``mdps``.
        An empty input yields two empty lists without starting any worker.

    Raises:
        ValueError: if ``mdps`` and ``Q_list`` differ in length.
    """
    mdps = list(mdps)
    Q_list = list(Q_list)

    # zip() would silently drop the surplus entries of the longer list.
    if len(mdps) != len(Q_list):
        raise ValueError(
            "mdps and Q_list must have the same length "
            f"(got {len(mdps)} and {len(Q_list)})"
        )
    # With no results, zip(*results) cannot be unpacked into two names.
    if not mdps:
        return [], []

    if n_jobs is None:
        n_jobs = cpu_count()

    args = [
        (
            mdp["T"],
            mdp["Phi"],
            Q,
            mdp["terminal"],
            gamma,
            tol,
            max_iters,
        )
        for mdp, Q in zip(mdps, Q_list)
    ]

    # Extra workers beyond the number of tasks would only sit idle.
    with Pool(min(n_jobs, len(args))) as pool:
        results = pool.map(_sf_worker, args)

    Psi_sa_list, Psi_s_list = zip(*results)
    return list(Psi_sa_list), list(Psi_s_list)

In [None]:
def test_pipeline_end_to_end(mdps, r_next_list, gamma=0.99, tol=1e-8, n_jobs=4):
    """
    End-to-end test:
    VI → SF → demos → constraints
    Passing outputs strictly forward.

    Args:
        mdps        : list of mdp dicts; "Phi" and "terminal" are read here,
                      the remaining entries by the stages being tested.
        r_next_list : list of NEXT-state reward vectors (one per env).
        gamma       : discount factor used for VI and successor features.
        tol         : every constraint vector must have a norm above this.
        n_jobs      : worker processes handed to every parallel stage.

    Raises:
        AssertionError: naming the environment in which a check failed.
    """
    n_envs = len(mdps)

    # --------------------------------------------------
    # 1) Value Iteration
    # --------------------------------------------------
    V_list, Q_list, pi_list = value_iteration_next_state_multi(
        mdps=mdps,
        r_next_list=r_next_list,
        gamma=gamma,
        n_jobs=n_jobs,
    )

    assert len(V_list) == n_envs, f"expected {n_envs} value vectors, got {len(V_list)}"
    assert len(Q_list) == n_envs, f"expected {n_envs} Q tables, got {len(Q_list)}"
    assert len(pi_list) == n_envs, f"expected {n_envs} policies, got {len(pi_list)}"

    # --------------------------------------------------
    # 2) Successor Features (from Q)
    # --------------------------------------------------
    Psi_sa_list, Psi_s_list = compute_successor_features_multi(
        mdps=mdps,
        Q_list=Q_list,
        gamma=gamma,
        n_jobs=n_jobs,
    )

    assert len(Psi_sa_list) == n_envs, f"expected {n_envs} Psi_sa, got {len(Psi_sa_list)}"
    assert len(Psi_s_list) == n_envs, f"expected {n_envs} Psi_s, got {len(Psi_s_list)}"

    # --------------------------------------------------
    # 3) Demos (from policy)
    # --------------------------------------------------
    demos_list = generate_demos_from_policies_multi(
        mdps=mdps,
        pi_list=pi_list,
        n_jobs=n_jobs,
    )

    assert len(demos_list) == n_envs, f"expected {n_envs} demo sets, got {len(demos_list)}"

    # Each demo must agree with π and avoid terminals
    for i, (mdp, demos, pi) in enumerate(zip(mdps, demos_list, pi_list)):
        terminal = mdp["terminal"]
        for s, a in demos:
            assert not terminal[s], f"env {i}: demo starts in terminal state {s}"
            assert a == pi[s], f"env {i}: demo action {a} != policy action {pi[s]} in state {s}"

    # --------------------------------------------------
    # 4) Constraints (from demos + SFs)
    # --------------------------------------------------
    terminal_mask_list = [mdp["terminal"] for mdp in mdps]

    constraints_per_env = constraints_from_demos_next_state_multi(
        demos_list=demos_list,
        Psi_sa_list=Psi_sa_list,
        terminal_mask_list=terminal_mask_list,
        normalize=True,
        n_jobs=n_jobs,
    )

    assert len(constraints_per_env) == n_envs, (
        f"expected {n_envs} constraint sets, got {len(constraints_per_env)}"
    )

    # --------------------------------------------------
    # 5) Constraint sanity checks: one entry per feature, and not (numerically) zero
    # --------------------------------------------------
    for i, (mdp, constraints) in enumerate(zip(mdps, constraints_per_env)):
        D = mdp["Phi"].shape[1]

        for c in constraints:
            assert c.shape == (D,), f"env {i}: constraint shape {c.shape}, expected {(D,)}"
            assert np.linalg.norm(c) > tol, f"env {i}: constraint with norm <= {tol}"

    print("✅ END-TO-END PIPELINE TEST PASSED")

# Random NEXT-state rewards, one value per state of each environment.
# A dedicated seeded generator makes the end-to-end test reproducible under
# Restart Kernel -> Run All; the seed matches the one passed to generate_lavaworld.
reward_rng = np.random.default_rng(42)

r_next_list = [
    reward_rng.standard_normal(mdp["T"].shape[0])
    for mdp in mdps
]

test_pipeline_end_to_end(
    mdps=mdps,
    r_next_list=r_next_list,
    gamma=0.99,
)

### testing above function

In [None]:
import numpy as np

def max_abs(x):
    """Return the largest absolute value found in ``x`` as a Python float."""
    magnitudes = np.abs(x)
    return float(magnitudes.max())

def assert_close(a, b, tol, name=""):
    """
    Assert that ``a`` and ``b`` differ by less than ``tol`` in every entry.

    The largest absolute difference is printed before the check so that
    passing runs still report how close the two values are.

    Args:
        a, b : arrays (or scalars) to compare.
        tol  : strict upper bound on the largest absolute difference.
        name : label used in the printed line and in failure messages.

    Raises:
        AssertionError: if two arrays have different shapes, or if the largest
        absolute difference is not below ``tol`` (including when it is NaN).
    """
    a_arr = np.asarray(a)
    b_arr = np.asarray(b)

    # Subtraction broadcasts, so without this check e.g. (S,) against (S, 1)
    # or (1,) would be compared entry-by-entry in a way nobody intended.
    # Comparing an array against a scalar remains allowed.
    if a_arr.ndim > 0 and b_arr.ndim > 0:
        assert a_arr.shape == b_arr.shape, (
            f"{name} failed (shape {a_arr.shape} != {b_arr.shape})"
        )

    # initial=0.0 makes two empty arrays compare as identical instead of
    # raising inside np.max.
    err = float(np.max(np.abs(a_arr - b_arr), initial=0.0))
    print(f"[{name}] max error = {err:.3e}")
    assert err < tol, f"{name} failed (err={err})"

In [None]:
def test_vi_vs_policy_eval(mdps, gamma, tol=1e-6):
    """
    For every env, draw a random NEXT-state reward, solve it with value
    iteration, then evaluate the resulting greedy policy and check that both
    value vectors agree to within ``tol``.
    """
    print("\n=== Test: VI vs Policy Evaluation ===")

    for env_idx, env_mdp in enumerate(mdps):
        transitions = env_mdp["T"]
        n_states = transitions.shape[0]

        # Arguments common to the solver and the evaluator.
        shared = dict(
            T=transitions,
            r_next=np.random.randn(n_states),  # random NEXT-state reward
            terminal_mask=env_mdp["terminal"],
            gamma=gamma,
        )

        optimal_values, _, greedy_policy = value_iteration_next_state(**shared)
        evaluated_values = policy_evaluation_next_state(policy=greedy_policy, **shared)

        assert_close(optimal_values, evaluated_values, tol, name=f"env {env_idx}")

    print("✅ Passed: VI and policy evaluation match")

def test_bellman_optimality(mdps, gamma, tol=1e-6):
    """
    Check that the value function returned by value iteration satisfies the
    Bellman optimality equation, V(s) = max_a Q(s, a), at every non-terminal
    state of every env, for a random NEXT-state reward.

    Args:
        mdps  : list of mdp dicts; the "T" and "terminal" entries are read.
        gamma : discount factor.
        tol   : largest allowed Bellman residual.

    Raises:
        AssertionError: naming the env and state with a residual >= ``tol``.
    """
    print("\n=== Test: Bellman Optimality ===")

    for i, mdp in enumerate(mdps):
        T = mdp["T"]
        terminal = mdp["terminal"]
        r_next = np.random.randn(T.shape[0])

        V, Q, pi = value_iteration_next_state(
            T=T,
            r_next=r_next,
            terminal_mask=terminal,
            gamma=gamma,
        )

        cont = (~terminal).astype(float)
        # V is final here, so the one-step target is the same for every
        # state and action and only needs to be built once per env.
        target = r_next + gamma * cont * V

        for s in range(T.shape[0]):
            if terminal[s]:
                continue
            # T[s] is (A, S); contracting with the target gives Q(s, ·).
            q_star = np.max(T[s] @ target)
            residual = abs(V[s] - q_star)
            assert residual < tol, (
                f"env {i}, state {s}: Bellman residual {residual:.3e} >= {tol}"
            )

    print("✅ Passed: Bellman optimality")

def test_successor_features_value_reconstruction(mdps, gamma, tol=1e-5):
    """
    For every env, make the reward linear in the features (r = Phi @ w for a
    random w), solve it with value iteration, and check that the state
    successor features reproduce the value function: Psi_s @ w == V.
    """
    print("\n=== Test: Successor Features → Value ===")

    for env_idx, env_mdp in enumerate(mdps):
        transitions = env_mdp["T"]
        features = env_mdp["Phi"]
        terminals = env_mdp["terminal"]

        # Random reward weights, one per feature column.
        weights = np.random.randn(features.shape[1])

        values, action_values, _ = value_iteration_next_state(
            T=transitions,
            r_next=features @ weights,
            terminal_mask=terminals,
            gamma=gamma,
        )

        _, state_sf = compute_successor_features_from_q_next_state(
            T=transitions,
            Phi=features,
            Q=action_values,
            terminal_mask=terminals,
            gamma=gamma,
        )

        assert_close(values, state_sf @ weights, tol, name=f"env {env_idx}")

    print("✅ Passed: SF value reconstruction")

def test_terminal_states(mdps, gamma=0.99):
    """
    Check that value iteration leaves V and every Q entry exactly at zero for
    terminal states, using a random NEXT-state reward per env.

    Args:
        mdps  : list of mdp dicts; the "T" and "terminal" entries are read.
        gamma : discount factor (defaults to the value this test always used).

    Raises:
        AssertionError: naming the env whose terminal states are not zero.
    """
    print("\n=== Test: Terminal States ===")

    for i, mdp in enumerate(mdps):
        terminal = mdp["terminal"]
        T = mdp["T"]
        r_next = np.random.randn(T.shape[0])

        V, Q, pi = value_iteration_next_state(
            T=T,
            r_next=r_next,
            terminal_mask=terminal,
            gamma=gamma,
        )

        assert np.all(V[terminal] == 0.0), f"env {i}: nonzero V at a terminal state"
        assert np.all(Q[terminal] == 0.0), f"env {i}: nonzero Q at a terminal state"

    print("✅ Passed: terminal states fixed at zero")

def test_multi_vs_single(mdps, gamma, tol=1e-6):
    """
    Solve every env once in-process and once through the parallel wrapper
    (2 workers) with the same random rewards, and check that V and Q agree to
    within ``tol`` and that the greedy policies are identical.
    """
    print("\n=== Test: Multi-env vs Single-env ===")

    rewards = [np.random.randn(env_mdp["T"].shape[0]) for env_mdp in mdps]

    # Reference results: one (V, Q, pi) triple per env, computed sequentially.
    reference = [
        value_iteration_next_state(
            T=env_mdp["T"],
            r_next=reward,
            terminal_mask=env_mdp["terminal"],
            gamma=gamma,
        )
        for env_mdp, reward in zip(mdps, rewards)
    ]

    V_multi, Q_multi, pi_multi = value_iteration_next_state_multi(
        mdps=mdps,
        r_next_list=rewards,
        gamma=gamma,
        n_jobs=2,
    )

    for i, (V_ref, Q_ref, pi_ref) in enumerate(reference):
        assert_close(V_ref, V_multi[i], tol, f"V env {i}")
        assert_close(Q_ref, Q_multi[i], tol, f"Q env {i}")
        assert np.all(pi_ref == pi_multi[i])

    print("✅ Passed: multi-env consistency")

In [None]:
def run_all_tests(mdps, gamma=0.99):
    """Run every solver test on ``mdps``, in a fixed order, stopping at the first failure."""
    # (test, positional arguments) in execution order; the terminal-state
    # test is the only one that is not given gamma.
    schedule = [
        (test_vi_vs_policy_eval, (mdps, gamma)),
        (test_bellman_optimality, (mdps, gamma)),
        (test_successor_features_value_reconstruction, (mdps, gamma)),
        (test_terminal_states, (mdps,)),
        (test_multi_vs_single, (mdps, gamma)),
    ]
    for check, check_args in schedule:
        check(*check_args)

In [None]:
# Generate a larger batch (10 environments, same size and seed as before) and
# run the whole solver test suite on it.
# NOTE: this rebinds envs, mdps and meta, replacing the 5-environment values
# created earlier in the notebook; any earlier cell re-run after this one
# will see the 10-environment versions.
envs, mdps, meta = generate_lavaworld(
    n_envs=10,
    size=8,
    seed=42,
)
run_all_tests(mdps, gamma=0.99)