In [None]:
from __future__ import annotations
import numpy as np

from minigrid.core.grid import Grid
from minigrid.core.mission import MissionSpace
from minigrid.core.world_object import Goal, Wall, Lava
from minigrid.minigrid_env import MiniGridEnv
from minigrid.wrappers import SymbolicObsWrapper
from minigrid.manual_control import ManualControl


def l2_normalize(w, eps=1e-8):
    """Scale ``w`` to unit Euclidean length.

    Args:
        w: Weight vector (NumPy array).
        eps: Norm threshold below which ``w`` is treated as the zero vector.

    Returns:
        ``w / ||w||``, or ``w`` itself (unchanged) when its norm is below
        ``eps`` so that a near-zero vector is never divided by ~0.
    """
    norm = np.linalg.norm(w)
    return w if norm < eps else w / norm


# ======================================================
# Environment
# ======================================================

class SimpleEnv(MiniGridEnv):
    """Square walled room with a short lava column and a goal in the far corner.

    The grid is ``size`` x ``size``. The outer ring is wall, lava occupies
    column x=2 for rows y=1..height-3, and the goal sits at
    (width-2, height-2). The agent always starts at ``agent_start_pos``
    facing ``agent_start_dir``.
    """

    def __init__(
        self,
        size=5,
        agent_start_pos=(1, 1),
        agent_start_dir=0,
        max_steps=None,
        **kwargs,
    ):
        """Store the start pose and forward the remaining setup to MiniGridEnv.

        Args:
            size: Side length of the square grid.
            agent_start_pos: Start cell as (x, y).
            agent_start_dir: Start heading (key into ``DIR_TO_VEC``).
            max_steps: Episode step limit; defaults to ``4 * size**2``.
            **kwargs: Passed through to ``MiniGridEnv`` (e.g. ``render_mode``).
        """
        self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir

        # Default step budget grows with the grid area.
        step_budget = 4 * size**2 if max_steps is None else max_steps

        super().__init__(
            mission_space=MissionSpace(lambda: "grand mission"),
            grid_size=size,
            see_through_walls=True,
            max_steps=step_budget,
            **kwargs,
        )

    def _gen_grid(self, width, height):
        """Build the static layout and place the agent at its start pose."""
        self.grid = Grid(width, height)
        self.grid.wall_rect(0, 0, width, height)

        # Lava column at x=2 covering rows 1..height-3; the row just above
        # the bottom wall stays free so the goal side remains reachable.
        for row in range(1, height - 2):
            self.grid.set(2, row, Lava())

        self.put_obj(Goal(), width - 2, height - 2)

        self.agent_pos = self.agent_start_pos
        self.agent_dir = self.agent_start_dir
        self.mission = "grand mission"


# ======================================================
# Feature extraction (state-based, NumPy-only)
# ======================================================

# Heading index -> unit displacement (dx, dy). y grows downward, so "down" is
# +y. Used by lava_ahead_state, step_model and lava_ahead to find the cell
# directly in front of the agent.
DIR_TO_VEC = {
    0: (1, 0),   # right
    1: (0, 1),   # down
    2: (-1, 0),  # left
    3: (0, -1),  # up
}

# Hand-specified ground-truth reward weights, one vector per feature set.
# Each vector's order matches the feature vector built for the same key by
# phi_from_state (state-based) and PHI_MAP (observation-based); the trailing
# comment on every entry lists that order. reward_from_state L2-normalizes
# the selected vector before taking the dot product with the features.
W_MAP = {
    "L1.2": np.array([   # [dist, on_lava, step]
        -0.05,
        -2.0,
        -0.01
    ]),

    "L1.3": np.array([   # [dist, lava_ahead, on_lava, step]
        -0.05,
        -0.5,
        -2.0,
        -0.01
    ]),

    "L2.1": np.array([   # [dist, num_lava4, on_lava, step]
        -0.05,
        -0.3,
        -2.0,
        -0.01
    ]),

    "L2.3": np.array([   # [dist, lava_nearby8, on_lava, step]
        -0.05,
        -0.4,
        -2.0,
        -0.01
    ]),

    "L3.1": np.array([   # [dx, dy, lava_ahead, on_lava, step]
        -0.02,
        -0.02,
        -0.5,
        -2.0,
        -0.01
    ]),

    "L4.1": np.array([   # [dist, min_dist_lava, on_lava, step]
        -0.05,
         0.10,   # farther from lava = good
        -2.0,
        -0.01
    ]),
}

# ======================================================
# Feature set selector
# ======================================================

# Active feature set. Must be a key of W_MAP (and PHI_MAP); it is read at call
# time by phi_from_state, reward_from_state, build_tabular_mdp and the
# observation-based helpers, so it selects both the features and the weights.
FEATURE_SET = "L1.3"   # <<< CHANGE HERE

# ======================================================
# Utilities for reward/features from (x,y,dir)
# ======================================================

def manhattan(p, q):
    """Return the L1 (taxicab) distance between 2-D points ``p`` and ``q``.

    Only the first two components of each point are used.
    """
    first_gap = abs(p[0] - q[0])
    second_gap = abs(p[1] - q[1])
    return first_gap + second_gap

def lava_ahead_state(lava_mask: np.ndarray, y: int, x: int, direction: int) -> int:
    """Return 1 if the cell directly in front of the agent is lava, else 0.

    Args:
        lava_mask: Boolean grid indexed ``[y, x]``.
        y, x: Agent cell.
        direction: Heading, a key of ``DIR_TO_VEC``.

    Returns:
        ``int(lava_mask[front])``; 0 when the front cell is outside the grid.
    """
    step_x, step_y = DIR_TO_VEC[direction]
    row, col = y + step_y, x + step_x
    inside = 0 <= row < lava_mask.shape[0] and 0 <= col < lava_mask.shape[1]
    return int(lava_mask[row, col]) if inside else 0

def on_lava_state(lava_mask: np.ndarray, y: int, x: int) -> int:
    """Return the lava flag of cell ``(y, x)`` as an int (1 = lava, 0 = not)."""
    cell = lava_mask[y, x]
    return int(cell)

def num_lava_4_state(lava_mask: np.ndarray, y: int, x: int) -> int:
    """Count lava cells among the 4-connected neighbours of ``(y, x)``.

    Neighbours that fall outside ``lava_mask`` are ignored.
    """
    n_rows, n_cols = lava_mask.shape[0], lava_mask.shape[1]
    neighbours = ((y - 1, x), (y + 1, x), (y, x - 1), (y, x + 1))
    return sum(
        int(lava_mask[r, c])
        for r, c in neighbours
        if 0 <= r < n_rows and 0 <= c < n_cols
    )

def lava_nearby_8_state(lava_mask: np.ndarray, y: int, x: int) -> int:
    """Return 1 if any of the 8 surrounding cells of ``(y, x)`` is lava, else 0.

    The centre cell itself is not inspected and out-of-grid neighbours are
    skipped.
    """
    n_rows, n_cols = lava_mask.shape[0], lava_mask.shape[1]
    found = any(
        lava_mask[r, c]
        for r in (y - 1, y, y + 1)
        for c in (x - 1, x, x + 1)
        if (r, c) != (y, x) and 0 <= r < n_rows and 0 <= c < n_cols
    )
    return int(found)

def min_dist_to_lava_state(lava_cells: np.ndarray, y: int, x: int, size: int) -> float:
    """Return the Manhattan distance from ``(y, x)`` to the closest lava cell.

    Args:
        lava_cells: Array of lava coordinates, one ``(y, x)`` row per cell.
        y, x: Query cell.
        size: Grid side length, used only for the no-lava fallback.

    Returns:
        The minimum L1 distance as a float, or ``2 * size`` when there is no
        lava at all.
    """
    if lava_cells.size == 0:
        return float(2 * size)
    row_gap = np.abs(lava_cells[:, 0] - y)
    col_gap = np.abs(lava_cells[:, 1] - x)
    return float(np.min(row_gap + col_gap))

def phi_from_state(state, goal_yx, lava_mask, lava_cells, size):
    """Build the feature vector of ``state`` for the active ``FEATURE_SET``.

    Args:
        state: ``(y, x, dir)``.
        goal_yx: ``(gy, gx)``.
        lava_mask: Boolean lava grid indexed ``[y, x]``.
        lava_cells: Array of lava ``(y, x)`` coordinates.
        size: Grid side length (only used by the "L4.1" set).

    Returns:
        Float array whose layout matches the ``W_MAP`` entry of the same key:
        leading goal-distance term(s), an optional lava-proximity term, the
        on-lava flag, and a constant step feature of 1.0.

    Raises:
        ValueError: If ``FEATURE_SET`` is not one of the known keys.
    """
    y, x, direction = state
    gy, gx = goal_yx
    dist = manhattan((y, x), (gy, gx))
    selected = FEATURE_SET

    # Every set is [lead..., proximity..., on_lava, step]; only the first two
    # groups vary from set to set.
    lead = [gx - x, gy - y] if selected == "L3.1" else [dist]

    if selected == "L1.2":
        proximity = []
    elif selected in ("L1.3", "L3.1"):
        proximity = [lava_ahead_state(lava_mask, y, x, direction)]
    elif selected == "L2.1":
        proximity = [num_lava_4_state(lava_mask, y, x)]
    elif selected == "L2.3":
        proximity = [lava_nearby_8_state(lava_mask, y, x)]
    elif selected == "L4.1":
        proximity = [min_dist_to_lava_state(lava_cells, y, x, size)]
    else:
        raise ValueError(f"Unknown FEATURE_SET: {FEATURE_SET}")

    tail = [on_lava_state(lava_mask, y, x), 1.0]
    return np.array(lead + proximity + tail, dtype=float)


def reward_from_state(state, goal_yx, lava_mask, lava_cells, size) -> float:
    """Return the linear reward of ``state``: unit-normalized weights . features.

    The weights are ``W_MAP[FEATURE_SET]`` scaled to unit L2 norm, and the
    features come from ``phi_from_state`` with the same arguments.
    """
    features = phi_from_state(state, goal_yx, lava_mask, lava_cells, size)
    weights = l2_normalize(W_MAP[FEATURE_SET])
    return float(weights @ features)



# ======================================================
# MiniGrid dynamics model for planning (left/right/forward)
# ======================================================

# MiniGrid actions (common): left=0, right=1, forward=2
# Only these three movement actions are modelled by step_model. ACTIONS fixes
# the column order of the action axis in build_tabular_mdp's transition
# tensor; here each action id equals its index in the list.
ACT_LEFT = 0
ACT_RIGHT = 1
ACT_FORWARD = 2
ACTIONS = [ACT_LEFT, ACT_RIGHT, ACT_FORWARD]


def build_static_maps(env: SimpleEnv):
    """Scan the environment grid once and return its static layout.

    The grid is assumed square: ``env.width`` is used for both axes.

    Returns:
        ``(size, wall_mask, lava_mask, lava_cells, goal_yx)`` where the masks
        are boolean arrays indexed ``[y, x]``, ``lava_cells`` lists the lava
        coordinates as ``(y, x)`` rows and ``goal_yx`` is the goal cell.

    Raises:
        RuntimeError: If the grid contains no ``Goal`` object.
    """
    size = env.width  # square
    wall_mask = np.zeros((size, size), dtype=bool)
    lava_mask = np.zeros_like(wall_mask)
    goal_yx = None

    # Row-major scan; env.grid.get takes (x, y) while the masks use [y, x].
    for y, x in np.ndindex(size, size):
        cell = env.grid.get(x, y)
        if isinstance(cell, Wall):
            wall_mask[y, x] = True
        elif isinstance(cell, Lava):
            lava_mask[y, x] = True
        elif isinstance(cell, Goal):
            goal_yx = (y, x)

    if goal_yx is None:
        raise RuntimeError("Goal not found in grid")

    return size, wall_mask, lava_mask, np.argwhere(lava_mask), goal_yx

def is_terminal_state(state, goal_yx, lava_mask) -> bool:
    """Return True when ``state`` sits on the goal cell or on a lava cell.

    Args:
        state: ``(y, x, dir)``; the heading is ignored.
        goal_yx: Goal cell as ``(y, x)``.
        lava_mask: Boolean lava grid indexed ``[y, x]``.
    """
    y, x, _ = state
    return (y, x) == goal_yx or bool(lava_mask[y, x])

def step_model(state, action, wall_mask, goal_yx, lava_mask):
    """Deterministic one-step transition used for planning.

    Args:
        state: ``(y, x, dir)``.
        action: One of ``ACT_LEFT``, ``ACT_RIGHT``, ``ACT_FORWARD``.
        wall_mask: Boolean wall grid indexed ``[y, x]``.
        goal_yx: Goal cell as ``(y, x)``.
        lava_mask: Boolean lava grid indexed ``[y, x]``.

    Returns:
        ``(next_state, done)`` where ``done`` says whether ``next_state`` is
        terminal. Terminal states are absorbing: any action returns the same
        state with ``done=True``.

    Raises:
        ValueError: For an action outside the three modelled ones.
    """
    row, col, heading = state

    # Goal and lava cells absorb, whatever the action is.
    if is_terminal_state(state, goal_yx, lava_mask):
        return state, True

    if action in (ACT_LEFT, ACT_RIGHT):
        # Turning changes only the heading: left is -1, right is +1 (mod 4).
        turn = -1 if action == ACT_LEFT else 1
        successor = (row, col, (heading + turn) % 4)
    elif action == ACT_FORWARD:
        step_x, step_y = DIR_TO_VEC[heading]
        new_row, new_col = row + step_y, col + step_x
        inside = (
            0 <= new_row < wall_mask.shape[0]
            and 0 <= new_col < wall_mask.shape[1]
        )
        # Leaving the grid or bumping into a wall leaves the agent in place.
        blocked = (not inside) or wall_mask[new_row, new_col]
        successor = (row, col, heading) if blocked else (new_row, new_col, heading)
    else:
        raise ValueError(f"Unknown action: {action}")

    return successor, is_terminal_state(successor, goal_yx, lava_mask)

def enumerate_states(size, wall_mask):
    """List every ``(y, x, dir)`` state whose cell is not a wall.

    States are ordered row by row, then column by column, then by heading
    0..3. Lava and goal cells are included.
    """
    return [
        (y, x, heading)
        for y in range(size)
        for x in range(size)
        if not wall_mask[y, x]
        for heading in range(4)
    ]

def build_tabular_mdp(states, wall_mask, goal_yx, lava_mask, lava_cells, size, gamma=0.99):
    """Tabulate the planning model as dense arrays for BIRL.

    Returns:
        A dict with:
          - "states": the input state list,
          - "idx_of": state -> row index,
          - "T": one-hot transition tensor of shape (S, A, S), with the action
            axis ordered like ``ACTIONS``,
          - "Phi": (S, D) feature matrix from ``phi_from_state``,
          - "terminal": boolean mask of terminal states,
          - "gamma": the discount factor passed in.
    """
    n_states, n_actions = len(states), len(ACTIONS)
    idx_of = {s: i for i, s in enumerate(states)}

    T = np.zeros((n_states, n_actions, n_states))
    Phi = np.zeros((n_states, len(W_MAP[FEATURE_SET])))
    terminal_mask = np.zeros(n_states, dtype=bool)

    for row, s in enumerate(states):
        terminal_mask[row] = is_terminal_state(s, goal_yx, lava_mask)
        Phi[row] = phi_from_state(s, goal_yx, lava_mask, lava_cells, size)

        # Deterministic dynamics: exactly one successor per (state, action).
        for col, a in enumerate(ACTIONS):
            successor = step_model(s, a, wall_mask, goal_yx, lava_mask)[0]
            T[row, col, idx_of[successor]] = 1.0

    return {
        "states": states,
        "idx_of": idx_of,
        "T": T,
        "Phi": Phi,
        "terminal": terminal_mask,
        "gamma": gamma,
    }

# ======================================================
# Value Iteration
# ======================================================

def value_iteration(
    states,
    wall_mask,
    goal_yx,
    lava_mask,
    lava_cells,
    size,
    gamma=0.99,
    theta=1e-8,
    max_iters=20000,
):
    """Solve the planning MDP with in-place (Gauss-Seidel) value iteration.

    Reward is computed from the next state (s'), like the ManualControl
    printout after ``env.step()``: R = reward of the state just entered.
    Terminal states have value 0, and a transition into a terminal state
    contributes its reward with no bootstrapped future value.

    Args:
        states: List of ``(y, x, dir)`` states.
        wall_mask, goal_yx, lava_mask, lava_cells, size: Static layout, as
            returned by ``build_static_maps``.
        gamma: Discount factor.
        theta: Stop once the largest value change in a sweep is below this.
        max_iters: Upper bound on the number of sweeps.

    Returns:
        ``(V, pi, idx_of)``: state values, the greedy action per state
        (``ACT_FORWARD`` as an arbitrary placeholder on terminal states) and
        the state -> index map.
    """
    idx_of = {s: i for i, s in enumerate(states)}
    n_states, n_actions = len(states), len(ACTIONS)

    # The model is deterministic and the reward depends only on s', so the
    # successor, reward and bootstrap discount of every (s, a) are tabulated
    # once instead of being recomputed on every sweep.
    terminal = np.zeros(n_states, dtype=bool)
    successor = np.zeros((n_states, n_actions), dtype=int)
    reward = np.zeros((n_states, n_actions))
    discount = np.zeros((n_states, n_actions))

    for i, s in enumerate(states):
        if is_terminal_state(s, goal_yx, lava_mask):
            terminal[i] = True
            continue
        for k, a in enumerate(ACTIONS):
            sp, done = step_model(s, a, wall_mask, goal_yx, lava_mask)
            successor[i, k] = idx_of[sp]
            reward[i, k] = reward_from_state(sp, goal_yx, lava_mask, lava_cells, size)
            # Entering a terminal state ends the episode: no future value.
            discount[i, k] = 0.0 if done else gamma

    live = np.flatnonzero(~terminal)
    V = np.zeros(n_states, dtype=float)

    for _ in range(max_iters):
        delta = 0.0
        for i in live:
            # V is updated in place, so later states in the same sweep
            # already see the refreshed values.
            new_v = np.max(reward[i] + discount[i] * V[successor[i]])
            delta = max(delta, abs(new_v - V[i]))
            V[i] = new_v
        if delta < theta:
            break

    # Greedy policy; np.argmax keeps the first maximizer, so ties break
    # towards the earliest action in ACTIONS.
    pi = np.full(n_states, ACT_FORWARD, dtype=int)
    for i in live:
        q_row = reward[i] + discount[i] * V[successor[i]]
        pi[i] = ACTIONS[int(np.argmax(q_row))]

    return V, pi, idx_of

def generate_state_action_demos(states, pi, terminal_mask):
    """Turn a policy into per-state demonstrations (no trajectories).

    Returns:
        A list of ``(state_index, action)`` pairs, one for every non-terminal
        state, with the action cast to a plain ``int``.
    """
    return [
        (i, int(pi[i]))
        for i, _ in enumerate(states)
        if not terminal_mask[i]
    ]

# ======================================================
# Manual control with live feature + reward printing (unchanged)
# ======================================================

class ManualControlWithReward(ManualControl):
    """ManualControl variant that prints features and reward after each step."""

    def step(self, action):
        """Step the wrapped env, print phi/R of the new observation, return it.

        The environment's own reward is discarded; 0.0 is returned instead.
        """
        step_result = self.env.step(action)
        obs, terminated, truncated, info = (
            step_result[0],
            step_result[2],
            step_result[3],
            step_result[4],
        )

        # Observation-based pipeline: the symbolic obs carries "image" and
        # "direction", which phi_state_numpy / reward_numpy consume.
        features = phi_state_numpy(obs)
        reward = reward_numpy(obs)

        print("φ(s) =", features, "R(s) =", reward)

        return obs, 0.0, terminated, truncated, info

# ======================================================
# Your original obs-based feature pipeline (kept for printing)
# ======================================================

# Object-type codes compared against channel 2 of obs["image"] by the
# observation-based helpers below.
# NOTE(review): presumably these mirror minigrid's OBJECT_TO_IDX table —
# confirm against the installed minigrid version.
LAVA_IDX = 9
AGENT_IDX = 10
GOAL_IDX = 8

def get_agent_goal(idx):
    """Locate the agent and goal cells in a 2-D object-type grid.

    Returns:
        ``((ay, ax), goal)`` where ``(ay, ax)`` is the first cell equal to
        ``AGENT_IDX`` and ``goal`` is the first cell equal to ``GOAL_IDX``,
        or ``None`` when no goal cell is present in ``idx``.

    NOTE(review): the names assume ``idx`` is indexed ``[y, x]``. If the
    observation wrapper lays the grid out as ``[x, y]``, the two coordinates
    are swapped relative to their names — confirm against the wrapper.
    """
    agent_hits = np.argwhere(idx == AGENT_IDX)
    goal_hits = np.argwhere(idx == GOAL_IDX)

    ay, ax = agent_hits[0]

    if goal_hits.size:
        gy, gx = goal_hits[0]
        return (ay, ax), (gy, gx)
    return (ay, ax), None

def lava_ahead(idx, ay, ax, direction):
    """Return 1 if the grid cell in front of ``(ay, ax)`` holds lava, else 0.

    The front cell is ``(ay + dy, ax + dx)`` with ``(dx, dy)`` taken from
    ``DIR_TO_VEC``; a front cell outside ``idx`` yields 0.

    NOTE(review): this treats the first axis of ``idx`` as y. If the
    observation grid is actually indexed ``[x, y]``, the lookup is transposed
    relative to the agent's heading — confirm against the wrapper.
    """
    step_x, step_y = DIR_TO_VEC[direction]
    front_y, front_x = ay + step_y, ax + step_x
    inside = 0 <= front_y < idx.shape[0] and 0 <= front_x < idx.shape[1]
    return int(idx[front_y, front_x] == LAVA_IDX) if inside else 0

def on_lava(idx, ay, ax):
    """Return 1 if ``idx[ay, ax]`` equals ``LAVA_IDX``, else 0.

    NOTE(review): the phi_* callers pass the cell that get_agent_goal found
    by matching ``AGENT_IDX``, so for those calls the cell holds AGENT_IDX by
    construction and this returns 0 — confirm whether that is intended.
    """
    is_lava = idx[ay, ax] == LAVA_IDX
    return int(is_lava)

def num_lava_4(idx, ay, ax):
    """Count lava cells among the 4-connected neighbours of ``(ay, ax)``.

    Neighbours outside ``idx`` are ignored.
    """
    n_rows, n_cols = idx.shape[0], idx.shape[1]
    neighbours = ((ay - 1, ax), (ay + 1, ax), (ay, ax - 1), (ay, ax + 1))
    return sum(
        int(idx[r, c] == LAVA_IDX)
        for r, c in neighbours
        if 0 <= r < n_rows and 0 <= c < n_cols
    )

def lava_nearby_8(idx, ay, ax):
    """Return 1 if any of the 8 cells around ``(ay, ax)`` holds lava, else 0.

    The centre cell is skipped, as are neighbours outside ``idx``.
    """
    n_rows, n_cols = idx.shape[0], idx.shape[1]
    found = any(
        idx[r, c] == LAVA_IDX
        for r in (ay - 1, ay, ay + 1)
        for c in (ax - 1, ax, ax + 1)
        if (r, c) != (ay, ax) and 0 <= r < n_rows and 0 <= c < n_cols
    )
    return int(found)

def min_dist_to_lava(idx, ay, ax):
    """Return the Manhattan distance from ``(ay, ax)`` to the nearest lava cell.

    When ``idx`` contains no lava, the fallback is the sum of the two grid
    dimensions (as a float).
    """
    lava_cells = np.argwhere(idx == LAVA_IDX)
    if lava_cells.size == 0:
        return float(idx.shape[0] + idx.shape[1])
    distances = [abs(ay - r) + abs(ax - c) for r, c in lava_cells]
    return min(distances)

def phi_L12(obs):
    """Observation-based features for set "L1.2": [dist, on_lava, step].

    ``dist`` is the Manhattan distance from agent to goal, or 0 when no goal
    cell is present in the observation.
    """
    type_grid = obs["image"][:, :, 2]
    agent_cell, goal_cell = get_agent_goal(type_grid)
    goal_dist = manhattan(agent_cell, goal_cell) if goal_cell is not None else 0
    return np.array([goal_dist, on_lava(type_grid, *agent_cell), 1.0])

def phi_L13(obs):
    """Observation-based features for set "L1.3": [dist, lava_ahead, on_lava, step].

    ``dist`` is the Manhattan distance from agent to goal, or 0 when no goal
    cell is present in the observation.
    """
    type_grid = obs["image"][:, :, 2]
    agent_cell, goal_cell = get_agent_goal(type_grid)
    goal_dist = manhattan(agent_cell, goal_cell) if goal_cell is not None else 0
    ahead = lava_ahead(type_grid, *agent_cell, obs["direction"])
    standing = on_lava(type_grid, *agent_cell)
    return np.array([goal_dist, ahead, standing, 1.0])

def phi_L21(obs):
    """Observation-based features for set "L2.1": [dist, num_lava4, on_lava, step].

    ``dist`` is the Manhattan distance from agent to goal, or 0 when no goal
    cell is present in the observation.
    """
    type_grid = obs["image"][:, :, 2]
    agent_cell, goal_cell = get_agent_goal(type_grid)
    goal_dist = manhattan(agent_cell, goal_cell) if goal_cell is not None else 0
    adjacent = num_lava_4(type_grid, *agent_cell)
    standing = on_lava(type_grid, *agent_cell)
    return np.array([goal_dist, adjacent, standing, 1.0])

def phi_L23(obs):
    """Observation-based features for set "L2.3": [dist, lava_nearby8, on_lava, step].

    ``dist`` is the Manhattan distance from agent to goal, or 0 when no goal
    cell is present in the observation.
    """
    type_grid = obs["image"][:, :, 2]
    agent_cell, goal_cell = get_agent_goal(type_grid)
    goal_dist = manhattan(agent_cell, goal_cell) if goal_cell is not None else 0
    nearby = lava_nearby_8(type_grid, *agent_cell)
    standing = on_lava(type_grid, *agent_cell)
    return np.array([goal_dist, nearby, standing, 1.0])

def phi_L31(obs):
    """Observation-based features for set "L3.1": [dx, dy, lava_ahead, on_lava, step].

    ``dx``/``dy`` are the signed goal offsets (goal minus agent) along the
    second/first grid axis, both 0 when no goal cell is present.

    NOTE(review): which offset is "x" depends on the axis order of
    ``obs["image"]`` — confirm against the observation wrapper.
    """
    type_grid = obs["image"][:, :, 2]
    (ay, ax), goal_cell = get_agent_goal(type_grid)

    dy, dx = (0, 0) if goal_cell is None else (goal_cell[0] - ay, goal_cell[1] - ax)

    ahead = lava_ahead(type_grid, ay, ax, obs["direction"])
    standing = on_lava(type_grid, ay, ax)
    return np.array([dx, dy, ahead, standing, 1.0])

def phi_L41(obs):
    """Observation-based features for set "L4.1": [dist, min_dist_lava, on_lava, step].

    ``dist`` is the Manhattan distance from agent to goal, or 0 when no goal
    cell is present in the observation.
    """
    type_grid = obs["image"][:, :, 2]
    agent_cell, goal_cell = get_agent_goal(type_grid)
    goal_dist = manhattan(agent_cell, goal_cell) if goal_cell is not None else 0
    lava_gap = min_dist_to_lava(type_grid, *agent_cell)
    standing = on_lava(type_grid, *agent_cell)
    return np.array([goal_dist, lava_gap, standing, 1.0])

# Observation-based feature extractor for each feature-set key. The keys
# mirror W_MAP so that FEATURE_SET selects a matching extractor/weight pair.
PHI_MAP = {
    "L1.2": phi_L12,
    "L1.3": phi_L13,
    "L2.1": phi_L21,
    "L2.3": phi_L23,
    "L3.1": phi_L31,
    "L4.1": phi_L41,
}

def phi_state_numpy(obs):
    """Return the feature vector of ``obs`` for the active ``FEATURE_SET``."""
    extractor = PHI_MAP[FEATURE_SET]
    return extractor(obs)

def reward_numpy(obs):
    """Return the linear reward of an observation for the active feature set.

    The weights are L2-normalized exactly as in ``reward_from_state``, so the
    value printed during manual control or a policy rollout is on the same
    scale as the reward the planner optimizes.

    Raises:
        ValueError: If the weight and feature vectors differ in length.
    """
    phi = phi_state_numpy(obs)
    W = l2_normalize(W_MAP[FEATURE_SET])
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if len(W) != len(phi):
        raise ValueError("Weight/feature size mismatch")
    return float(W @ phi)

# ======================================================
# Run optimal policy on the env (rendered)
# ======================================================

def run_optimal_policy(env_wrapped, pi, idx_of, wall_mask, goal_yx, lava_mask, lava_cells, size, max_steps=1000):
    """Roll out the tabular policy ``pi`` in the real environment, printing each step.

    Args:
        env_wrapped: SymbolicObsWrapper(SimpleEnv(render_mode="human")).
        pi: Action per state index, as returned by ``value_iteration``.
        idx_of: Map from ``(y, x, dir)`` to state index.
        wall_mask, goal_yx, lava_mask, lava_cells, size: Static layout; not
            used by the rollout itself, kept for call-site compatibility.
        max_steps: Hard cap on the number of environment steps.

    The episode is reset with seed 42 and stops early once the environment
    reports ``terminated`` or ``truncated``.
    """
    obs, info = env_wrapped.reset(seed=42)
    terminated = truncated = False
    base_env = env_wrapped.unwrapped

    for t in range(max_steps):
        # The agent pose is read from the unwrapped env rather than the
        # observation. MiniGrid stores agent_pos as (x, y); the planner's
        # states are (y, x, dir).
        x, y = base_env.agent_pos
        d = int(base_env.agent_dir)

        s = (y, x, d)
        a = int(pi[idx_of[s]])

        obs, _, terminated, truncated, info = env_wrapped.step(a)

        # print same style as manual
        phi = phi_state_numpy(obs)
        R = reward_numpy(obs)
        print(f"t={t:03d} a={a}  φ(s')={phi}  R(s')={R:.4f}")

        if terminated or truncated:
            break

# ======================================================
# Run BIRL to learn reward 
# ======================================================

class DemoOnlyBIRL:
    """Bayesian IRL from (state, action) demonstrations on a tabular MDP.

    The reward is linear in the state features, ``r(s) = Phi[s] @ theta``, and
    demonstrations are scored with a Boltzmann (softmax) policy over the
    optimal Q-values at inverse temperature ``beta``. ``run`` performs a
    Metropolis random walk over unit-norm weight vectors and returns the
    best-scoring sample.

    NOTE(review): Q-values here use the reward of the *current* state and
    treat whatever self-loops ``T`` contains as ordinary transitions; confirm
    this matches the model that generated the demonstrations.
    """

    def __init__(self, T, Phi, gamma, demos, beta=10.0):
        """Store the MDP and demonstrations.

        Args:
            T: Transition tensor of shape (S, A, S).
            Phi: Feature matrix of shape (S, D).
            gamma: Discount factor.
            demos: Iterable of ``(state_index, action_index)`` pairs.
            beta: Boltzmann rationality (inverse temperature).
        """
        self.T = T
        self.Phi = Phi
        self.gamma = gamma
        self.demos = demos
        self.beta = beta
        self.S, self.A, _ = T.shape
        self.D = Phi.shape[1]

    @staticmethod
    def _logsumexp(logits):
        """Return ``log(sum(exp(logits)))`` without overflowing on large logits."""
        peak = np.max(logits)
        if not np.isfinite(peak):
            # All -inf, or an inf/nan entry: shifting would produce NaN.
            return float(peak)
        # Shifting by the maximum keeps every exponent <= 0.
        return float(peak + np.log(np.sum(np.exp(logits - peak))))

    def _q_values(self, theta):
        """Return the optimal Q-table (S, A) for reward ``Phi @ theta``.

        Runs in-place value iteration until the largest update in a sweep is
        below 1e-8.
        """
        r = self.Phi @ theta
        V = np.zeros(self.S)

        while True:
            delta = 0.0
            for s in range(self.S):
                # Backup over all actions at once: T[s] is (A, S).
                v_new = np.max(r[s] + self.gamma * (self.T[s] @ V))
                delta = max(delta, abs(V[s] - v_new))
                V[s] = v_new
            if delta < 1e-8:
                break

        # (S, A, S) @ (S,) -> (S, A) expected next-state values.
        return r[:, None] + self.gamma * (self.T @ V)

    def log_likelihood(self, theta):
        """Return the log-likelihood of the demos under L2-normalized ``theta``."""
        theta = l2_normalize(theta)
        Q = self._q_values(theta)

        ll = 0.0
        for s, a in self.demos:
            logits = self.beta * Q[s]
            # log softmax(logits)[a], computed stably.
            ll += float(logits[a]) - self._logsumexp(logits)
        return float(ll)

    def run(self, num_samples=400, step_size=0.3):
        """Sample reward weights with Metropolis and return the best one.

        Args:
            num_samples: Number of proposals.
            step_size: Standard deviation of the Gaussian proposal noise.

        Returns:
            The unit-norm weight vector with the highest log-likelihood seen.
            The random generator is seeded with 0, so runs are reproducible.
        """
        rng = np.random.default_rng(0)
        # Start on the unit sphere, like every later proposal, so that
        # step_size means the same thing from the first step on and the
        # returned vector is always unit-norm.
        theta = l2_normalize(rng.normal(0, 1, self.D))
        ll = self.log_likelihood(theta)

        best_theta = theta.copy()
        best_ll = ll

        for _ in range(num_samples):
            prop = l2_normalize(theta + rng.normal(0, step_size, self.D))
            prop_ll = self.log_likelihood(prop)

            accept = prop_ll > ll
            if not accept:
                log_ratio = prop_ll - ll
                u = rng.random()
                # log_ratio <= 0 here, so exp cannot overflow; an undefined
                # ratio (NaN) is rejected.
                accept = (not np.isnan(log_ratio)) and u < np.exp(log_ratio)

            if accept:
                theta, ll = prop, prop_ll

            if ll > best_ll:
                best_theta, best_ll = theta.copy(), ll

        return best_theta

def expected_value_difference(T, gamma, r_true, pi_true, pi_eval):
    """Return how much value ``pi_eval`` loses against ``pi_true`` under ``r_true``.

    Both policies are evaluated by iterative policy evaluation with the
    per-state reward ``r_true`` and transitions ``T[s, policy[s]]``.

    Returns:
        ``mean(V^{pi_true}) - mean(V^{pi_eval})``, averaged uniformly over all
        states.
    """
    n_states = len(r_true)

    def policy_eval(policy):
        # In-place sweeps until the largest update falls below 1e-8.
        values = np.zeros(n_states)
        converged = False
        while not converged:
            largest_change = 0.0
            for s in range(n_states):
                backup = r_true[s] + gamma * np.sum(T[s, policy[s]] * values)
                largest_change = max(largest_change, abs(values[s] - backup))
                values[s] = backup
            converged = largest_change < 1e-8
        return values

    reference_values = policy_eval(pi_true)
    candidate_values = policy_eval(pi_eval)
    return np.mean(reference_values) - np.mean(candidate_values)

def birl_main():
    """Run the full BIRL experiment and print the recovered reward and its EVD.

    Steps: extract the static layout, plan with the ground-truth weights,
    turn the optimal policy into (state, action) demos, fit reward weights
    with ``DemoOnlyBIRL``, then compare the greedy policy under the learned
    reward with the ground-truth policy via the expected value difference.

    NOTE(review): the demonstrations come from ``value_iteration`` (reward on
    the next state, zero value at terminals) while BIRL and the EVD use the
    tabular model (reward on the current state, absorbing terminals) —
    confirm that mixing the two conventions is intended.
    """
    planning_env = SimpleEnv(render_mode=None)
    planning_env.reset(seed=0)

    size, wall_mask, lava_mask, lava_cells, goal_yx = build_static_maps(planning_env)
    states = enumerate_states(size, wall_mask)

    # Ground-truth planning
    _, pi_true, _ = value_iteration(
        states, wall_mask, goal_yx, lava_mask, lava_cells, size
    )

    mdp = build_tabular_mdp(
        states, wall_mask, goal_yx, lava_mask, lava_cells, size
    )

    demos = generate_state_action_demos(
        states, pi_true, mdp["terminal"]
    )

    # Run BIRL
    birl = DemoOnlyBIRL(
        mdp["T"], mdp["Phi"], mdp["gamma"], demos
    )
    theta_hat = birl.run()

    # Learned policy: greedy with respect to Q-values actually solved under
    # the learned reward. (Acting greedily against an all-zero value table
    # makes every action tie, so the policy collapses to ACTIONS[0].)
    Q_hat = birl._q_values(l2_normalize(theta_hat))
    # Q columns follow the order of ACTIONS; map indices back to action ids.
    pi_hat = np.array(ACTIONS, dtype=int)[np.argmax(Q_hat, axis=1)]

    # Evaluation under the (normalized) ground-truth reward
    r_true = mdp["Phi"] @ l2_normalize(W_MAP[FEATURE_SET])

    evd = expected_value_difference(
        mdp["T"], mdp["gamma"], r_true, pi_true, pi_hat
    )

    print("\n===== BIRL RESULTS =====")
    print("True θ     :", W_MAP[FEATURE_SET])
    print("Learned θ  :", theta_hat)
    print("EVD        :", evd)

# ======================================================
# Main
# ======================================================

def main():
    """Plan with the ground-truth reward and replay the policy in a rendered env."""
    # A non-rendering env instance is enough to extract the static layout.
    planning_env = SimpleEnv(render_mode=None)
    planning_env.reset(seed=42)

    size, wall_mask, lava_mask, lava_cells, goal_yx = build_static_maps(planning_env)
    states = enumerate_states(size, wall_mask)

    V, pi, idx_of = value_iteration(
        states, wall_mask, goal_yx, lava_mask, lava_cells, size,
        gamma=0.99,
        theta=1e-8,
    )

    # Short summary. The env stores the start as (x, y); states are (y, x, dir).
    start_pos = planning_env.agent_start_pos
    start_state = (start_pos[1], start_pos[0], planning_env.agent_start_dir)
    print("FEATURE_SET =", FEATURE_SET)
    print("Start state =", start_state, "V(start) =", V[idx_of[start_state]])

    # Replay the computed policy in a rendered, symbolic-observation env.
    env = SymbolicObsWrapper(SimpleEnv(render_mode="human"))

    run_optimal_policy(
        env_wrapped=env,
        pi=pi,
        idx_of=idx_of,
        wall_mask=wall_mask,
        goal_yx=goal_yx,
        lava_mask=lava_mask,
        lava_cells=lava_cells,
        size=size,
        max_steps=4 * size * size,
    )

    # If you want manual mode instead, comment out run_optimal_policy and use:
    # manual_control = ManualControlWithReward(env, seed=42)
    # manual_control.start()

# Script entry point: runs the BIRL experiment. main() (planning plus a
# rendered policy rollout) is defined above but is not invoked here.
if __name__ == "__main__":
    birl_main()

  ll += self.beta * Q[s, a] - np.log(np.sum(np.exp(logits)))
  if prop_ll > ll or rng.random() < np.exp(prop_ll - ll):
  ll += self.beta * Q[s, a] - np.log(np.sum(np.exp(logits)))



===== BIRL RESULTS =====
True θ     : [-0.05 -0.5  -2.   -0.01]
Learned θ  : [-0.31953372 -0.69286049  0.43314024 -0.75588681]
EVD        : 13.872706296171962
