In [2]:
# =============================================================================
# Two-Stage SCOT vs Random (GLOBAL POOL) — FULL EXPERIMENT
# =============================================================================

import argparse
import json
import os
import sys
import time
import numpy as np
from concurrent.futures import ProcessPoolExecutor

# -----------------------------------------------------------------------------
# Path setup
# -----------------------------------------------------------------------------
module_path = os.path.abspath(os.path.join(".."))
if module_path not in sys.path:
    sys.path.append(module_path)

# -----------------------------------------------------------------------------
# Imports (UNCHANGED CORE PIPELINE)
# -----------------------------------------------------------------------------
from utils import (
    generate_random_gridworld_envs,
    compute_successor_features_family,
    derive_constraints_from_q_family,
    derive_constraints_from_atoms,
    compute_Q_from_weights_with_VI,
    remove_redundant_constraints,
    parallel_value_iteration,
    GenerationSpec,
    DemoSpec,
    FeedbackSpec,
)

from utils.successor_features import max_q_sa_pairs
from utils.common_helper import calculate_expected_value_difference
from utils.feedback_budgeting import generate_candidate_atoms_for_scot
from reward_learning.multi_env_atomic_birl import MultiEnvAtomicBIRL
from gridworld_env_layout import GridWorldMDPFromLayoutEnv

from teaching.two_stage_scot import two_stage_scot


# =============================================================================
# Ground-truth reward generator
# =============================================================================
def generate_w_true(d, seed=None):
    rng = np.random.default_rng(seed)
    w = rng.normal(size=d)
    return w / np.linalg.norm(w)

Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
Users of this version of Gym should be able to simply replace 'import gym' with 'import gymnasium as gym' in the vast majority of cases.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.


In [3]:
seed=11000
feature_dim=3
n_envs=20
mdp_size=3

# --------------------------------------------------
W_TRUE = generate_w_true(feature_dim, seed=seed)

# --------------------------------------------------
# 2. Environments
# --------------------------------------------------
color_to_feature_map = {
    f"f{i}": [1 if j == i else 0 for j in range(feature_dim)]
    for i in range(feature_dim)
}

envs, _ = generate_random_gridworld_envs(
    n_envs=n_envs,
    rows=mdp_size,
    cols=mdp_size,
    color_to_feature_map=color_to_feature_map,
    palette=list(color_to_feature_map.keys()),
    p_color_range={c: (0.3, 0.8) for c in color_to_feature_map},
    terminal_policy=dict(kind="random_k", k_min=1, k_max=1),
    gamma_range=(0.99, 0.99),
    noise_prob_range=(0.0, 0.0),
    w_mode="fixed",
    W_fixed=W_TRUE,
    seed=seed,
    GridEnvClass=GridWorldMDPFromLayoutEnv,
)

envs = np.array(envs)[[0,4]]

In [4]:
for env in envs:
    env.print_mdp_info()


Grid size           : 3 x 3
Num states          : 9
Num actions         : 4
Discount factor γ   : 0.99
Noise probability  : 0.0
Num features        : 3
Terminal states     : [3]
Start location      : (0, 0)

Feature weights (normalized):
[-0.0944  0.5151  0.8519]

Layout (colors):
f0 f0 f1
f1 f0 f2
f0 f1 f0

Sample transition check (state 0):
  Action 0: [(0, np.float64(1.0))]
  Action 1: [(3, np.float64(1.0))]
  Action 2: [(0, np.float64(1.0))]
  Action 3: [(1, np.float64(1.0))]


Grid size           : 3 x 3
Num states          : 9
Num actions         : 4
Discount factor γ   : 0.99
Noise probability  : 0.0
Num features        : 3
Terminal states     : [8]
Start location      : (0, 0)

Feature weights (normalized):
[-0.0944  0.5151  0.8519]

Layout (colors):
f1 f0 f1
f1 f2 f1
f1 f1 f1

Sample transition check (state 0):
  Action 0: [(0, np.float64(1.0))]
  Action 1: [(3, np.float64(1.0))]
  Action 2: [(0, np.float64(1.0))]
  Action 3: [(1, np.float64(1.0))]



In [5]:
# --------------------------------------------------
# 3. Optimal Q
# --------------------------------------------------
Q_list = parallel_value_iteration(envs, epsilon=1e-10)

# --------------------------------------------------
# 4. Successor features
# --------------------------------------------------
SFs = compute_successor_features_family(
    envs,
    Q_list,
    convention="entering",
    zero_terminal_features=True,
)

# --------------------------------------------------
# 5. CONSTRAINT + ATOM GENERATION (SPEC-BASED)
# --------------------------------------------------
print("GENERATING CONSTRAINTS")

# Q-based constraints
U_per_env_q, U_q = derive_constraints_from_q_family(
    SFs,
    Q_list,
    envs,
    skip_terminals=False,
    normalize=True,
)

[3/12] Running Value Iteration on all MDPs... (parallel)
       VI progress: 1/2 MDPs solved...
       VI progress: 2/2 MDPs solved...
       ✔ VI completed in 0.25s

GENERATING CONSTRAINTS


In [None]:
def derive_constraints_from_q_ties(
    mu_sa,
    q_values,
    env,
    tie_eps=1e-10,
    skip_terminals=True,
    normalize=True,
    tol=1e-12,
    mode="optimal",   # "optimal" (old behavior) | "all"
):
    """
    Derive linear reward constraints of the form:
        w · (mu(s,a) - mu(s,b)) >= 0

    Returns:
        list of (v, s, a, b)
    """

    S, A, d = mu_sa.shape
    q = np.asarray(q_values, float)

    # Identify optimal actions per state
    m = np.max(q, axis=1, keepdims=True)
    argmax_mask = np.abs(q - m) <= tie_eps

    if skip_terminals and getattr(env, "terminal_states", None) is not None:
        terms = np.array(env.terminal_states, dtype=int)
        argmax_mask[terms] = False

    constraints = []

    for s in range(S):

        # Skip terminal states
        if skip_terminals and getattr(env, "terminal_states", None) is not None:
            if s in env.terminal_states:
                continue

        psi_s = mu_sa[s]

        if mode == "optimal":
            A_src = np.where(argmax_mask[s])[0]
            B_fn = lambda a: np.where(~argmax_mask[s])[0]

        elif mode == "all":
            A_src = np.arange(A)
            B_fn = lambda a: [b for b in range(A) if b != a]

        else:
            raise ValueError(f"Unknown mode: {mode}")

        for a in A_src:
            B = B_fn(a)
            if len(B) == 0:
                continue

            diffs = psi_s[a][None, :] - psi_s[B]
            norms = np.linalg.norm(diffs, axis=1)

            for i, b in enumerate(B):
                if norms[i] <= tol:
                    continue

                v = diffs[i] / norms[i] if normalize else diffs[i]
                constraints.append((v, s, a, b))

    return constraints

In [7]:
for i in SFs[0][0]:
    print(i)
    print()

[[ 1.99        0.9801     97.02989999]
 [ 0.          0.          0.        ]
 [ 1.99        0.9801     97.02989999]
 [ 1.          0.99       98.00999999]]

[[ 1.          0.99       98.00999999]
 [ 1.          0.         98.99999999]
 [ 1.99        0.9801     97.02989999]
 [ 0.          1.         98.99999999]]

[[ 0.          1.         98.99999999]
 [ 0.          0.         99.99999999]
 [ 1.          0.99       98.00999999]
 [ 0.          1.         98.99999999]]

[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]

[[ 1.          0.99       98.00999999]
 [ 0.99        1.         98.00999999]
 [ 0.          0.          0.        ]
 [ 0.          0.         99.99999999]]

[[ 0.          1.         98.99999999]
 [ 1.          0.         98.99999999]
 [ 1.          0.         98.99999999]
 [ 0.          0.         99.99999999]]

[[ 0.          0.          0.        ]
 [ 1.9801      0.99       97.02989999]
 [ 1.9801      0.99       97.02989999]
 [ 0.99        1.         98.00999999]]

[[

In [10]:
for i in remove_redundant_constraints(U_q):
    print(i)

[0.01020199 0.01009997 0.99989695]
[-0.70710678  0.70710678  0.        ]
[ 0.         -0.70710678  0.70710678]
[0.         0.7035446  0.71065111]


In [9]:
for u in U_per_env_q:
    print(len(remove_redundant_constraints(u)))
    for i in remove_redundant_constraints(u):
        print(i)
    #print(remove_redundant_constraints(u))
    print()

3
[0.01020199 0.01009997 0.99989695]
[-0.70710678  0.70710678  0.        ]
[ 0.00710633 -0.71063316  0.70352683]

3
[ 0.         -0.70710678  0.70710678]
[-0.70710678  0.70710678  0.        ]
[0.         0.7035446  0.71065111]



In [17]:
for i in U_per_env_q[0]:
    print(i)

[-0.71063316  0.00710633  0.70352683]
[0.01020199 0.01009997 0.99989695]
[-0.71063316  0.00710633  0.70352683]
[-0.71063316  0.00710633  0.70352683]
[-0.70710678  0.70710678  0.        ]
[-0.71063316  0.00710633  0.70352683]
[ 0.         -0.70710678  0.70710678]
[-0.41029806 -0.40619508  0.81649314]
[ 0.         -0.70710678  0.70710678]
[-0.41029806 -0.40619508  0.81649314]
[-0.40619508 -0.41029806  0.81649314]
[0. 0. 1.]
[ 0.         -0.70710678  0.70710678]
[-0.70710678  0.          0.70710678]
[-0.70710678  0.          0.70710678]
[0.01009997 0.01020199 0.99989695]
[-0.71066833  0.00717774  0.70349059]
[-0.71066833  0.00717774  0.70349059]
[ 0.00710633 -0.71063316  0.70352683]
[-0.40619508 -0.41029806  0.81649314]
[ 0.00710633 -0.71063316  0.70352683]
[-0.40619508 -0.41029806  0.81649314]
[-0.70710678  0.          0.70710678]
[-0.40619508 -0.41029806  0.81649314]
[-0.70710678  0.          0.70710678]


In [43]:
const_env_0 = derive_constraints_from_q_ties(SFs[0][0], Q_list[0], envs[0])
const_env_1 = derive_constraints_from_q_ties(SFs[1][0], Q_list[1], envs[1])


In [44]:
for i in const_env_0:
    print(i)

(array([-0.71063316,  0.00710633,  0.70352683]), 0, np.int64(3), np.int64(0))
(array([0.01020199, 0.01009997, 0.99989695]), 0, np.int64(3), np.int64(1))
(array([-0.71063316,  0.00710633,  0.70352683]), 0, np.int64(3), np.int64(2))
(array([-0.71063316,  0.00710633,  0.70352683]), 1, np.int64(3), np.int64(0))
(array([-0.70710678,  0.70710678,  0.        ]), 1, np.int64(3), np.int64(1))
(array([-0.71063316,  0.00710633,  0.70352683]), 1, np.int64(3), np.int64(2))
(array([ 0.        , -0.70710678,  0.70710678]), 2, np.int64(1), np.int64(0))
(array([-0.41029806, -0.40619508,  0.81649314]), 2, np.int64(1), np.int64(2))
(array([ 0.        , -0.70710678,  0.70710678]), 2, np.int64(1), np.int64(3))
(array([-0.41029806, -0.40619508,  0.81649314]), 4, np.int64(3), np.int64(0))
(array([-0.40619508, -0.41029806,  0.81649314]), 4, np.int64(3), np.int64(1))
(array([0., 0., 1.]), 4, np.int64(3), np.int64(2))
(array([ 0.        , -0.70710678,  0.70710678]), 5, np.int64(3), np.int64(0))
(array([-0.70710

In [45]:
for i in const_env_1:
    print(i)

(array([ 0.        , -0.70710678,  0.70710678]), 0, np.int64(1), np.int64(0))
(array([ 0.        , -0.70710678,  0.70710678]), 0, np.int64(1), np.int64(2))
(array([-0.70710678,  0.70710678,  0.        ]), 0, np.int64(1), np.int64(3))
(array([-0.81649314,  0.40619508,  0.41029806]), 1, np.int64(1), np.int64(0))
(array([ 0.        , -0.70710678,  0.70710678]), 1, np.int64(1), np.int64(2))
(array([ 0.        , -0.70710678,  0.70710678]), 1, np.int64(1), np.int64(3))
(array([ 0.        , -0.70710678,  0.70710678]), 2, np.int64(1), np.int64(0))
(array([-0.70710678,  0.70710678,  0.        ]), 2, np.int64(1), np.int64(2))
(array([ 0.        , -0.70710678,  0.70710678]), 2, np.int64(1), np.int64(3))
(array([ 0.        , -0.70710678,  0.70710678]), 3, np.int64(3), np.int64(0))
(array([ 0.        , -0.70710678,  0.70710678]), 3, np.int64(3), np.int64(1))
(array([ 0.        , -0.70710678,  0.70710678]), 3, np.int64(3), np.int64(2))
(array([-0.70710678,  0.70710678,  0.        ]), 4, np.int64(1),

In [None]:
spec = GenerationSpec(
    seed=seed,

    demo=DemoSpec(
        enabled=True,
        env_fraction=1.0,
        max_steps=1,
        state_fraction=1,
        alloc_method="uniform",
    ),

    # pairwise=FeedbackSpec(
    #     enabled=("pairwise" in enabled),
    #     total_budget=total_budget if "pairwise" in enabled else 0,
    #     alloc_method="uniform",
    # ),

    # estop=FeedbackSpec(
    #     enabled=("estop" in enabled),
    #     total_budget=total_budget if "estop" in enabled else 0,
    #     alloc_method="uniform",
    # ),

    # improvement=FeedbackSpec(
    #     enabled=("improvement" in enabled),
    #     total_budget=total_budget if "improvement" in enabled else 0,
    #     alloc_method="uniform",
    # ),
)

candidates_per_env = generate_candidate_atoms_for_scot(
    envs,
    Q_list,
    spec=spec,
)

In [14]:
for i in candidates_per_env[0]:
    print(i.data)


[(5, 3)]
[(7, 3)]
[(1, 3)]
[(6, 3)]
[(0, 3)]
[(4, 3)]
[(2, 1)]
[(8, 0)]


In [15]:
for i in candidates_per_env[1]:
    print(i.data)

[(2, 1)]
[(1, 1)]
[(5, 2)]
[(6, 0)]
[(3, 3)]
[(4, 3)]
[(0, 1)]
[(7, 0)]
