In [1]:
from __future__ import annotations
import numpy as np

from minigrid.core.grid import Grid
from minigrid.core.mission import MissionSpace
from minigrid.core.world_object import Goal, Wall, Lava
from minigrid.minigrid_env import MiniGridEnv
from minigrid.manual_control import ManualControl

# ======================================================
# Utils
# ======================================================
def l2_normalize(w, eps=1e-8):
    n = np.linalg.norm(w)
    return w if n < eps else w / n
# ======================================================
# Environment
# ======================================================
class SimpleEnv(MiniGridEnv):
    def __init__(
        self,
        size=5,
        agent_start_pos=(1, 1),
        agent_start_dir=0,
        max_steps=None,
        **kwargs,
    ):
        self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir

        if max_steps is None:
            max_steps = 4 * size**2

        super().__init__(
            mission_space=MissionSpace(lambda: "grand mission"),
            grid_size=size,
            see_through_walls=True,
            max_steps=max_steps,
            **kwargs,
        )

    def _gen_grid(self, width, height):
        self.grid = Grid(width, height)
        self.grid.wall_rect(0, 0, width, height)

        # vertical lava column at x=2, y=1..height-2
        for y in range(1, height - 1):
            self.grid.set(2, y, Lava())

        # goal
        self.put_obj(Goal(), width - 2, height - 2)

        # agent
        self.agent_pos = self.agent_start_pos  # (x,y)
        self.agent_dir = self.agent_start_dir
        self.mission = "grand mission"

# ======================================================
# Feature extraction (state-based, canonical)
# ======================================================

DIR_TO_VEC = {
    0: (1, 0),   # right
    1: (0, 1),   # down
    2: (-1, 0),  # left
    3: (0, -1),  # up
}

W_MAP = {
    "L1.2": np.array([-0.05, -2.0, -0.01]),                       # [dist, on_lava, step]
    "L1.3": np.array([-0.8, -0.1, -5.0, -0.05]),                 # [dist, lava_ahead, on_lava, step]
    "L2.1": np.array([-0.05, -0.3, -2.0, -0.01]),                 # [dist, num_lava4, on_lava, step]
    "L2.3": np.array([-0.05, -0.4, -2.0, -0.01]),                 # [dist, lava_nearby8, on_lava, step]
    "L3.1": np.array([-0.02, -0.02, -0.5, -2.0, -0.01]),          # [dx, dy, lava_ahead, on_lava, step]
    "L4.1": np.array([-0.05,  0.10, -2.0, -0.01]),                # [dist, min_dist_lava, on_lava, step]
}

FEATURE_SET = "L1.3"   # <<< CHANGE HERE

def manhattan(p, q):
    return abs(p[0] - q[0]) + abs(p[1] - q[1])

def lava_ahead_state(lava_mask: np.ndarray, y: int, x: int, direction: int) -> int:
    dx, dy = DIR_TO_VEC[direction]
    ny, nx = y + dy, x + dx
    if 0 <= ny < lava_mask.shape[0] and 0 <= nx < lava_mask.shape[1]:
        return int(lava_mask[ny, nx])
    return 0

def on_lava_state(lava_mask: np.ndarray, y: int, x: int) -> int:
    return int(lava_mask[y, x])

def num_lava_4_state(lava_mask: np.ndarray, y: int, x: int) -> int:
    cnt = 0
    for dy, dx in [(-1,0),(1,0),(0,-1),(0,1)]:
        ny, nx = y+dy, x+dx
        if 0 <= ny < lava_mask.shape[0] and 0 <= nx < lava_mask.shape[1]:
            cnt += int(lava_mask[ny, nx])
    return cnt

def lava_nearby_8_state(lava_mask: np.ndarray, y: int, x: int) -> int:
    for dy in [-1,0,1]:
        for dx in [-1,0,1]:
            if dy == 0 and dx == 0:
                continue
            ny, nx = y+dy, x+dx
            if 0 <= ny < lava_mask.shape[0] and 0 <= nx < lava_mask.shape[1]:
                if lava_mask[ny, nx]:
                    return 1
    return 0

def min_dist_to_lava_state(lava_cells: np.ndarray, y: int, x: int, size: int) -> float:
    if lava_cells.size == 0:
        return float(2 * size)
    return float(np.min(np.abs(lava_cells[:, 0] - y) + np.abs(lava_cells[:, 1] - x)))

def phi_from_state(state, goal_yx, lava_mask, lava_cells, size):
    """
    state = (y, x, dir)
    goal_yx = (gy, gx)
    """
    y, x, direction = state
    gy, gx = goal_yx

    dist = manhattan((y, x), (gy, gx))
    step = 1.0

    if FEATURE_SET == "L1.2":
        return np.array([dist, on_lava_state(lava_mask, y, x), step], dtype=float)

    if FEATURE_SET == "L1.3":
        return np.array([dist,
                         lava_ahead_state(lava_mask, y, x, direction),
                         on_lava_state(lava_mask, y, x),
                         step], dtype=float)

    if FEATURE_SET == "L2.1":
        return np.array([dist,
                         num_lava_4_state(lava_mask, y, x),
                         on_lava_state(lava_mask, y, x),
                         step], dtype=float)

    if FEATURE_SET == "L2.3":
        return np.array([dist,
                         lava_nearby_8_state(lava_mask, y, x),
                         on_lava_state(lava_mask, y, x),
                         step], dtype=float)

    if FEATURE_SET == "L3.1":
        dx = gx - x
        dy = gy - y
        return np.array([dx, dy,
                         lava_ahead_state(lava_mask, y, x, direction),
                         on_lava_state(lava_mask, y, x),
                         step], dtype=float)

    if FEATURE_SET == "L4.1":
        return np.array([dist,
                         min_dist_to_lava_state(lava_cells, y, x, size),
                         on_lava_state(lava_mask, y, x),
                         step], dtype=float)

    raise ValueError(f"Unknown FEATURE_SET: {FEATURE_SET}")

def reward_from_state(state, goal_yx, lava_mask, lava_cells, size) -> float:
    phi = phi_from_state(state, goal_yx, lava_mask, lava_cells, size)
    W = l2_normalize(W_MAP[FEATURE_SET])
    return float(W @ phi)

# ======================================================
# Planning dynamics model (left/right/forward)
# ======================================================
ACT_LEFT = 0
ACT_RIGHT = 1
ACT_FORWARD = 2
ACTIONS = [ACT_LEFT, ACT_RIGHT, ACT_FORWARD]

def build_static_maps(env: SimpleEnv):
    """
    Returns everything in NumPy (y,x) convention.
    """
    size = env.width
    wall_mask = np.zeros((size, size), dtype=bool)
    lava_mask = np.zeros((size, size), dtype=bool)
    goal_yx = None

    for y in range(size):
        for x in range(size):
            obj = env.grid.get(x, y)  # MiniGrid uses (x,y)
            if obj is None:
                continue
            if isinstance(obj, Wall):
                wall_mask[y, x] = True
            elif isinstance(obj, Lava):
                lava_mask[y, x] = True
            elif isinstance(obj, Goal):
                goal_yx = (y, x)

    if goal_yx is None:
        raise RuntimeError("Goal not found in grid")

    lava_cells = np.argwhere(lava_mask)
    return size, wall_mask, lava_mask, lava_cells, goal_yx

def is_terminal_state(state, goal_yx, lava_mask) -> bool:
    y, x, _ = state
    return (y, x) == goal_yx or bool(lava_mask[y, x])
    #return False

def step_model(state, action, wall_mask, goal_yx, lava_mask):
    """
    Deterministic transition for planning.
    state = (y, x, dir)
    """
    y, x, direction = state

    if is_terminal_state(state, goal_yx, lava_mask):
        return state, True

    if action == ACT_LEFT:
        nstate = (y, x, (direction - 1) % 4)
        return nstate, is_terminal_state(nstate, goal_yx, lava_mask)

    if action == ACT_RIGHT:
        nstate = (y, x, (direction + 1) % 4)
        return nstate, is_terminal_state(nstate, goal_yx, lava_mask)

    if action == ACT_FORWARD:
        dx, dy = DIR_TO_VEC[direction]
        ny, nx = y + dy, x + dx

        # bounds/wall => no move
        if ny < 0 or ny >= wall_mask.shape[0] or nx < 0 or nx >= wall_mask.shape[1]:
            nstate = (y, x, direction)
            return nstate, is_terminal_state(nstate, goal_yx, lava_mask)

        if wall_mask[ny, nx]:
            nstate = (y, x, direction)
            return nstate, is_terminal_state(nstate, goal_yx, lava_mask)

        nstate = (ny, nx, direction)
        return nstate, is_terminal_state(nstate, goal_yx, lava_mask)

    raise ValueError(f"Unknown action: {action}")

def enumerate_states(size, wall_mask):
    """
    All non-wall cells for all 4 directions.
    """
    states = []
    for y in range(size):
        for x in range(size):
            if wall_mask[y, x]:
                continue
            for d in range(4):
                states.append((y, x, d))
    return states

def build_tabular_mdp(states, wall_mask, goal_yx, lava_mask, lava_cells, size, gamma=0.99):
    """
    Tabular MDP for BIRL.
    """
    S = len(states)
    A = len(ACTIONS)

    idx_of = {s: i for i, s in enumerate(states)}
    T = np.zeros((S, A, S))
    terminal_mask = np.zeros(S, dtype=bool)
    Phi = np.zeros((S, len(W_MAP[FEATURE_SET])))

    for i, s in enumerate(states):
        terminal_mask[i] = is_terminal_state(s, goal_yx, lava_mask)
        Phi[i] = phi_from_state(s, goal_yx, lava_mask, lava_cells, size)

        for a_idx, a in enumerate(ACTIONS):
            sp, _ = step_model(s, a, wall_mask, goal_yx, lava_mask)
            T[i, a_idx, idx_of[sp]] = 1.0

    return {
        "states": states,
        "idx_of": idx_of,
        "T": T,
        "Phi": Phi,
        "terminal": terminal_mask,
        "gamma": gamma,
        "goal_yx": goal_yx,
        "lava_mask": lava_mask,
        "wall_mask": wall_mask,
        "lava_cells": lava_cells,
        "size": size,
    }

# ======================================================
# Value Iteration
# ======================================================

def value_iteration(
    states,
    wall_mask,
    goal_yx,
    lava_mask,
    lava_cells,
    size,
    gamma=0.99,
    theta=1e-8,
    max_iters=20000,
):
    """
    Reward computed from next state s' (as in your earlier code).
    """
    idx_of = {s: i for i, s in enumerate(states)}
    V = np.zeros(len(states), dtype=float)

    for _ in range(max_iters):
        delta = 0.0
        for i, s in enumerate(states):
            if is_terminal_state(s, goal_yx, lava_mask):
                continue

            best = -1e18
            for a in ACTIONS:
                sp, done = step_model(s, a, wall_mask, goal_yx, lava_mask)
                r = reward_from_state(sp, goal_yx, lava_mask, lava_cells, size)
                j = idx_of[sp]
                q = r + (0.0 if done else gamma * V[j])
                if q > best:
                    best = q

            delta = max(delta, abs(best - V[i]))
            V[i] = best

        if delta < theta:
            break

    pi = np.zeros(len(states), dtype=int)
    for i, s in enumerate(states):
        if is_terminal_state(s, goal_yx, lava_mask):
            pi[i] = ACT_FORWARD
            continue

        best_a = ACTIONS[0]
        best_q = -1e18
        for a in ACTIONS:
            sp, done = step_model(s, a, wall_mask, goal_yx, lava_mask)
            r = reward_from_state(sp, goal_yx, lava_mask, lava_cells, size)
            j = idx_of[sp]
            q = r + (0.0 if done else gamma * V[j])
            if q > best_q:
                best_q = q
                best_a = a
        pi[i] = best_a

    return V, pi, idx_of

def generate_state_action_demos(states, pi, terminal_mask):
    demos = []
    for i, _s in enumerate(states):
        if terminal_mask[i]:
            continue
        demos.append((i, int(pi[i])))
    return demos

# ======================================================
# BIRL (demo-only, MCMC over theta)
# ======================================================
class DemoOnlyBIRL:
    def __init__(self, T, Phi, gamma, demos, beta=1.0):
        self.T = T
        self.Phi = Phi
        self.gamma = gamma
        self.demos = demos
        self.beta = beta
        self.S, self.A, _ = T.shape
        self.D = Phi.shape[1]

    def _q_values(self, theta):
        """
        CONSISTENT with next-state reward value iteration:
        Q(s,a) = sum_{s'} T[s,a,s'] * ( r[s'] + gamma * V[s'] )
        """
        r = self.Phi @ theta            # r[s]
        Q = np.zeros((self.S, self.A))
        V = np.zeros(self.S)

        # --- Value iteration for V under next-state reward ---
        while True:
            delta = 0.0
            for s in range(self.S):
                q_vals = []
                for a in range(self.A):
                    q_sa = np.sum(self.T[s, a] * (r + self.gamma * V))
                    q_vals.append(q_sa)

                v_new = max(q_vals)
                delta = max(delta, abs(V[s] - v_new))
                V[s] = v_new

            if delta < 1e-8:
                break

        # --- Compute Q using converged V ---
        for s in range(self.S):
            for a in range(self.A):
                Q[s, a] = np.sum(self.T[s, a] * (r + self.gamma * V))

        return Q


    def log_likelihood(self, theta):
        theta = l2_normalize(theta)
        Q = self._q_values(theta)

        ll = 0.0
        for s, a in self.demos:
            logits = self.beta * Q[s]
            # stable log-sum-exp would be nicer, but keep simple:
            ll += self.beta * Q[s, a] - np.log(np.sum(np.exp(logits)))
        return ll

    def run(self, num_samples=1000, step_size=0.3, seed=0):
        rng = np.random.default_rng(seed)
        theta = l2_normalize(rng.normal(0, 1, self.D))
        ll = self.log_likelihood(theta)

        best_theta = theta.copy()
        best_ll = ll

        for _ in range(num_samples):
            prop = l2_normalize(theta + rng.normal(0, step_size, self.D))
            prop_ll = self.log_likelihood(prop)

            if prop_ll > ll or rng.random() < np.exp(prop_ll - ll):
                theta, ll = prop, prop_ll

            if ll > best_ll:
                best_theta, best_ll = theta.copy(), ll

        return best_theta

# ======================================================
# Evaluation
# ======================================================
# ======================================================
# Evaluation (CONSISTENT with your Value Iteration)
# ======================================================
def expected_value_difference(T, gamma, r_true, pi_true, pi_eval, terminal_mask):
    """
    Make policy evaluation consistent with YOUR value_iteration:
      - reward is on NEXT state s' (like reward_from_state(sp))
      - terminal transitions cut off future (0 if done)
      - terminal states are absorbing in step_model (already), but we also skip updating them
    Inputs:
      T: (S,A,S) deterministic/stochastic transitions
      r_true: length S, interpreted as r_next[sp] (reward of NEXT state)
      terminal_mask: length S boolean, True if state is terminal
    """
    S = T.shape[0]
    def policy_eval(policy):
        V = np.zeros(S, dtype=float)
        while True:
            delta = 0.0
            for s in range(S):
                if terminal_mask[s]:
                    # consistent with VI: terminal value is 0 and no future
                    continue

                a = int(policy[s])

                # Bellman expectation with NEXT-state reward:
                # V(s) = sum_{s'} T[s,a,s'] * ( r_true[s'] + gamma * V(s') )
                v_new = float(np.sum(T[s, a] * (r_true + gamma * V)))

                delta = max(delta, abs(V[s] - v_new))
                V[s] = v_new

            if delta < 1e-8:
                break
        return V

    V_opt  = policy_eval(pi_true)
    V_eval = policy_eval(pi_eval)
    return float(np.mean(V_opt) - np.mean(V_eval))


# ======================================================
# Debug printing (state-based)
# =====================================================
def debug_print_values_and_demos(
    states,
    idx_of,
    V,
    pi,
    wall_mask,
    goal_yx,
    lava_mask,
    lava_cells,
    size,
    gamma=0.99,
):
    print("\n===== DEBUG: V(s), Q(s,a), DEMOS =====")

    for i, s in enumerate(states):
        y, x, d = s

        if wall_mask[y, x]:
            cell_type = "WALL"
        elif (y, x) == goal_yx:
            cell_type = "GOAL"
        elif lava_mask[y, x]:
            cell_type = "LAVA"
        else:
            cell_type = "EMPTY"

        term = is_terminal_state(s, goal_yx, lava_mask)
        term_str = "TERMINAL" if term else ""

        print(f"\nState {i:03d} (y={y}, x={x}, d={d}) [{cell_type}] {term_str}")
        print(f"  V(s) = {V[i]: .6f}")

        if cell_type == "WALL":
            print("  ⚠️ ERROR: WALL state present in state space!")
            continue
        if term:
            print("  (terminal state, no actions)")
            continue

        for a in ACTIONS:
            sp, done = step_model(s, a, wall_mask, goal_yx, lava_mask)
            r = reward_from_state(sp, goal_yx, lava_mask, lava_cells, size)
            j = idx_of[sp]
            q = r + (0.0 if done else gamma * V[j])
            print(f"  Q(s,a={a}) = {q: .6f}")

        print(f"  OPTIMAL ACTION (demo) = {pi[i]}")

# ======================================================
# Optional manual control (no SymbolicObsWrapper)
# ======================================================

class ManualControlStateReward(ManualControl):
    def __init__(self, env, wall_mask, lava_mask, lava_cells, goal_yx, size):
        super().__init__(env)
        self.wall_mask = wall_mask
        self.lava_mask = lava_mask
        self.lava_cells = lava_cells
        self.goal_yx = goal_yx
        self.size = size

    def step(self, action):
        _obs, _rew, terminated, truncated, _info = self.env.step(action)

        x, y = self.env.unwrapped.agent_pos  # env uses (x,y)
        d = int(self.env.unwrapped.agent_dir)
        s = (y, x, d)

        if self.wall_mask[y, x]:
            cell_type = "WALL"
        elif (y, x) == self.goal_yx:
            cell_type = "GOAL"
        elif self.lava_mask[y, x]:
            cell_type = "LAVA"
        else:
            cell_type = "EMPTY"

        phi = phi_from_state(s, self.goal_yx, self.lava_mask, self.lava_cells, self.size)
        R = reward_from_state(s, self.goal_yx, self.lava_mask, self.lava_cells, self.size)

        print(f"STATE {s} [{cell_type}]  φ(s)={phi}  R(s)={R:.4f}")

        return _obs, 0.0, terminated, truncated, _info

# ======================================================
# Main (BIRL experiment)
# ======================================================

def birl_main():
    # --- build env once for map extraction / planning
    planning_env = SimpleEnv(render_mode=None)
    planning_env.reset(seed=0)

    size, wall_mask, lava_mask, lava_cells, goal_yx = build_static_maps(planning_env)
    states = enumerate_states(size, wall_mask)

    # --- ground truth planning under your FEATURE_SET and W_MAP
    V_true, pi_true, idx_of = value_iteration(
        states, wall_mask, goal_yx, lava_mask, lava_cells, size
    )

    # --- tabular mdp
    mdp = build_tabular_mdp(states, wall_mask, goal_yx, lava_mask, lava_cells, size)

    # --- demos from optimal policy
    demos = generate_state_action_demos(states, pi_true, mdp["terminal"])

    # --- run BIRL
    birl = DemoOnlyBIRL(mdp["T"], mdp["Phi"], mdp["gamma"], demos, beta=1.0)
    theta_hat = birl.run(num_samples=1000, step_size=0.25, seed=0)

    # --- learned reward and greedy policy (simple one-step greedy rollout using VI-style Q)
    r_hat = mdp["Phi"] @ theta_hat

    # compute a policy by running VI with learned weights (cleanest)
    # but you can also compute greedy wrt V_hat; keep simple: reuse value_iteration by temporarily swapping W_MAP
    # We'll do a small helper instead:
    V_hat = np.zeros(len(states))
    pi_hat = np.zeros(len(states), dtype=int)

    # Evaluate greedy using Bellman backup with fixed V_hat iteration (NEXT-STATE reward)
    for _ in range(500):
        delta = 0.0
        for i, s in enumerate(states):
            if mdp["terminal"][i]:
                continue

            best = -1e18
            for a in ACTIONS:
                sp, done = step_model(s, a, wall_mask, goal_yx, lava_mask)
                j = idx_of[sp]

                q = r_hat[j] + (0.0 if done else mdp["gamma"] * V_hat[j])

                if q > best:
                    best = q

            delta = max(delta, abs(best - V_hat[i]))
            V_hat[i] = best

        if delta < 1e-10:
            break

    for i, s in enumerate(states):
        if mdp["terminal"][i]:
            pi_hat[i] = ACT_FORWARD
            continue

        best_a = ACTIONS[0]
        best_q = -1e18
        for a in ACTIONS:
            sp, done = step_model(s, a, wall_mask, goal_yx, lava_mask)
            j = idx_of[sp]

            q = r_hat[j] + (0.0 if done else mdp["gamma"] * V_hat[j])

            if q > best_q:
                best_q = q
                best_a = a

        pi_hat[i] = best_a

    # --- true reward vector over states (consistent with normalized theta used in reward_from_state)
    r_true = mdp["Phi"] @ l2_normalize(W_MAP[FEATURE_SET])

    # --- EVD
    evd = expected_value_difference(mdp["T"], mdp["gamma"], r_true, pi_true, pi_hat, mdp["terminal"])
    print("\n===== BIRL RESULTS (SYMBOLIC-FREE) =====")
    print("FEATURE_SET :", FEATURE_SET)
    print("True θ      :", l2_normalize(W_MAP[FEATURE_SET]))
    print("Learned θ   :", theta_hat)
    print("EVD         :", evd)

    # Optional: debug state-wise V/Q/demos
    debug_print_values_and_demos(
        states=states,
        idx_of=idx_of,
        V=V_true,
        pi=pi_true,
        wall_mask=wall_mask,
        goal_yx=goal_yx,
        lava_mask=lava_mask,
        lava_cells=lava_cells,
        size=size,
        gamma=mdp["gamma"],
    )

    # Optional: interactive manual play with state-based printing
    # env = SimpleEnv(render_mode="human")
    # env.reset(seed=0)
    # manual = ManualControlStateReward(env, wall_mask, lava_mask, lava_cells, goal_yx, size)
    # manual.start()

if __name__ == "__main__":
    birl_main()


===== BIRL RESULTS (SYMBOLIC-FREE) =====
FEATURE_SET : L1.3
True θ      : [-0.157952   -0.019744   -0.98720002 -0.009872  ]
Learned θ   : [-1.16070895e-04 -9.99483190e-01  2.91536387e-02 -1.35427352e-02]
EVD         : 0.05865421972337925

===== DEBUG: V(s), Q(s,a), DEMOS =====

State 000 (y=1, x=1, d=0) [EMPTY] 
  V(s) = -1.470928
  Q(s,a=0) = -2.738146
  Q(s,a=1) = -2.738146
  Q(s,a=2) = -1.470928
  OPTIMAL ACTION (demo) = 2

State 001 (y=1, x=1, d=1) [EMPTY] 
  V(s) = -2.117643
  Q(s,a=0) = -2.117643
  Q(s,a=1) = -3.352445
  Q(s,a=2) = -2.265609
  OPTIMAL ACTION (demo) = 0

State 002 (y=1, x=1, d=2) [EMPTY] 
  V(s) = -2.738146
  Q(s,a=0) = -2.738146
  Q(s,a=1) = -2.738146
  Q(s,a=2) = -3.352445
  OPTIMAL ACTION (demo) = 0

State 003 (y=1, x=1, d=3) [EMPTY] 
  V(s) = -2.117643
  Q(s,a=0) = -3.352445
  Q(s,a=1) = -2.117643
  Q(s,a=2) = -2.738146
  OPTIMAL ACTION (demo) = 1

State 004 (y=1, x=2, d=0) [LAVA] TERMINAL
  V(s) =  0.000000
  (terminal state, no actions)

State 005 (y=1, x=2

In [None]:
# ======================================================
# TEST: Value Iteration Debug Print
# (Run this in a NEW notebook cell)
# ======================================================

def test_value_iteration_debug(
    size=5,
    gamma=0.99,
    max_states=None,   # set to int to limit printing
):
    # --- Build env (no render)
    env = SimpleEnv(render_mode=None, size=size)
    env.reset(seed=0)

    # --- Extract static maps
    size, wall_mask, lava_mask, lava_cells, goal_yx = build_static_maps(env)

    # --- Enumerate state space
    states = enumerate_states(size, wall_mask)

    # --- Run Value Iteration
    V, pi, idx_of = value_iteration(
        states,
        wall_mask,
        goal_yx,
        lava_mask,
        lava_cells,
        size,
        gamma=gamma,
    )

    print("\n=====================================================")
    print(" VALUE ITERATION DEBUG OUTPUT")
    print(" FEATURE_SET =", FEATURE_SET)
    print(" gamma =", gamma)
    print(" #states =", len(states))
    print("=====================================================")

    printed = 0

    for i, s in enumerate(states):
        if max_states is not None and printed >= max_states:
            break

        y, x, d = s

        # --- semantic cell type
        if wall_mask[y, x]:
            cell = "WALL"
        elif (y, x) == goal_yx:
            cell = "GOAL"
        elif lava_mask[y, x]:
            cell = "LAVA"
        else:
            cell = "EMPTY"

        terminal = is_terminal_state(s, goal_yx, lava_mask)

        print("\n---------------------------------------------")
        print(f"State {i:03d} : (y={y}, x={x}, dir={d})")
        print(f"Cell Type   : {cell}")
        print(f"Terminal    : {terminal}")
        print(f"V(s)        : {V[i]: .6f}")

        if terminal:
            print("  (terminal state — no Q-values)")
            printed += 1
            continue

        # --- Q(s,a)
        for a in ACTIONS:
            sp, done = step_model(s, a, wall_mask, goal_yx, lava_mask)
            j = idx_of[sp]
            r = reward_from_state(sp, goal_yx, lava_mask, lava_cells, size)
            q = r + (0.0 if done else gamma * V[j])

            print(
                f"  Q(s,a={a}) = {q: .6f}   "
                f"-> s'={sp}, R(s')={r: .4f}, done={done}"
            )

        print(f"  Optimal a* : {pi[i]}")
        printed += 1

    print("\n=====================================================")
    print(" END VALUE ITERATION DEBUG")
    print("=====================================================")


# ======================================================
# Run the test
# ======================================================

test_value_iteration_debug(
    size=5,
    gamma=0.99,
    max_states=None,
)

In [None]:
from __future__ import annotations
import numpy as np

from minigrid.core.grid import Grid
from minigrid.core.mission import MissionSpace
from minigrid.core.world_object import Goal, Wall, Lava
from minigrid.minigrid_env import MiniGridEnv
from minigrid.wrappers import SymbolicObsWrapper
from minigrid.manual_control import ManualControl


def l2_normalize(w, eps=1e-8):
    n = np.linalg.norm(w)
    return w if n < eps else w / n

DIR_TO_VEC = {
    0: (1, 0),    # right
    1: (0, 1),    # down
    2: (-1, 0),   # left
    3: (0, -1),   # up
}


FEATURE_SET = "L1.3"

W_MAP = {
    "L1.3": np.array([   # [dist, lava_ahead, on_lava, step]
        -0.05,
        -0.5,
        -2.0,
        -0.01
    ])
}

class LinearRewardEnv(MiniGridEnv):
    def __init__(
        self,
        size=5,
        agent_start_pos=(1, 1),
        agent_start_dir=0,
        **kwargs,
    ):
        self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir

        super().__init__(
            mission_space=MissionSpace(lambda: "Reach goal, avoid lava"),
            grid_size=size,
            see_through_walls=True,
            max_steps=4 * size * size,
            **kwargs,
        )

    def _gen_grid(self, width, height):
        self.grid = Grid(width, height)
        self.grid.wall_rect(0, 0, width, height)

        # vertical lava column
        for i in range(height - 2):
            if i == 0:
                continue
            else:
                self.grid.set(2, i, Lava())

        self.put_obj(Goal(), width - 2, height - 2)

        self.agent_pos = self.agent_start_pos
        self.agent_dir = self.agent_start_dir
        self.mission = "Linear reward MiniGrid"



In [None]:
env = LinearRewardEnv()

env.reset(seed=0)

In [None]:
print(dir(env))

In [None]:

def build_static_maps(env: SimpleEnv):
    """
    Extract:
      - size
      - wall mask
      - lava mask
      - goal position
    """
    size = env.width  # square
    wall_mask = np.zeros((size, size), dtype=bool)
    lava_mask = np.zeros((size, size), dtype=bool)
    goal_yx = None

    for y in range(size):
        for x in range(size):
            obj = env.grid.get(x, y)
            if obj is None:
                continue
            if isinstance(obj, Wall):
                wall_mask[y, x] = True
            elif isinstance(obj, Lava):
                lava_mask[y, x] = True
            elif isinstance(obj, Goal):
                goal_yx = (y, x)

    if goal_yx is None:
        raise RuntimeError("Goal not found in grid")

    lava_cells = np.argwhere(lava_mask)
    return size, wall_mask, lava_mask, lava_cells, goal_yx

In [None]:
size, wall_mask, lava_mask, lava_cells, goal_yx = build_static_maps(env)

In [None]:
goal_yx