In [None]:
# ============================================================
# run_universal_constraints.py — FULL WORKING SCRIPT
# ============================================================
import os
import sys

import numpy as np
from scipy.optimize import linprog

# Make the parent directory importable before any project-local import runs.
module_path = os.path.abspath(os.path.join('..'))

if module_path not in sys.path:
    sys.path.append(module_path)

from gridworld_env_layout import GridWorldMDPFromLayoutEnv
from gridworld_env import NoisyLinearRewardFeaturizedGridWorldEnv
from agent.q_learning_agent import ValueIteration, PolicyEvaluation
from utils import generate_random_gridworld_envs

from utils import simulate_all_feedback
from utils import (
    compute_successor_features_family,
    derive_constraints_from_q_family,
    derive_constraints_from_atoms,
    generate_candidate_atoms_for_scot,
    remove_redundant_constraints,
)

In [None]:
# Step 1: sample a family of random 3x3 gridworlds and solve each one with
# value iteration. W_TRUE is the unit-norm weight vector handed to the
# generator as the fixed weights (w_mode="fixed").
W_TRUE = np.array([-10, -2]) / np.linalg.norm([-10, -2])

envs, meta = generate_random_gridworld_envs(
    n_envs=50,
    rows=3,
    cols=3,
    color_to_feature_map={"red": [1.0, 0.0], "blue": [0.0, 1.0]},
    palette=("red", "blue"),
    p_color_range={"red": (0.2, 0.6), "blue": (0.4, 0.8)},
    terminal_policy={"kind": "random_k", "k_min": 0, "k_max": 1, "p_no_terminal": 0.1},
    gamma_range=(0.98, 0.995),
    noise_prob_range=(0.0, 0.0),
    w_mode="fixed",
    W_fixed=W_TRUE,
    seed=45,
    GridEnvClass=GridWorldMDPFromLayoutEnv,
)

# Build every solver first, then run them, then collect the Q-values
# (same three phases and the same order as before).
vis = list(map(ValueIteration, envs))
for v in vis:
    v.run_value_iteration(epsilon=1e-10)
Q_list = [solver.get_q_values() for solver in vis]


# 2) Successor features
# Computed for the whole family from the environments and their Q-values.
# Later cells index each entry as SFs[i][0] (called mu_sa in the SCOT function)
# and read its last axis as the feature dimension d.
# NOTE(review): the meaning of convention="entering" and zero_terminal_features
# is defined in utils — confirm there before changing them.
SFs = compute_successor_features_family(
    envs,
    Q_list,
    convention="entering",
    zero_terminal_features=True,
    tol=1e-10,
    max_iters=10000,
)


# 3) Q-only constraints
# TODO: this step is a candidate for parallelization.
# Going by the variable names, the call yields a per-environment result and a
# family-wide ("global") result; only U_q_global is used later (step 6).
# NOTE(review): precision / lp_epsilon are passed with the same values to the
# atom-based derivation in step 5 — presumably they should stay in sync; confirm.
U_q_per_env, U_q_global = derive_constraints_from_q_family(
    SFs,
    Q_list,
    envs,
    tie_eps=1e-10,
    skip_terminals=True,
    normalize=True,
    tol=1e-12,
    precision=1e-3,
    lp_epsilon=1e-4,
)



# 4) Simulate feedback atoms (pairwise, estop, improvement, demo)
# TODO: this step is a candidate for parallelization.
atoms_per_env = simulate_all_feedback(
    envs,
    Q_list,
    n_base_trajs=200,
    n_improvements=200,
    n_pairwise=200,
    n_estops=200,
)


# 5) Atom-based constraints
# TODO: this step is a candidate for parallelization.
# Derived from the simulated feedback atoms of step 4 together with the
# successor features; only U_atoms_global is used later (step 6).
U_atoms_per_env, U_atoms_global = derive_constraints_from_atoms(
    atoms_per_env,
    SFs,
    envs,
    precision=1e-3,
    lp_epsilon=1e-4,
)


# 6) Final universal set = union of the Q-only and the atom-derived constraints.
# `np` and `remove_redundant_constraints` are provided by the import cell at the
# top of the notebook, so nothing is (re-)imported here.

# Keep only the non-empty sources so np.vstack never sees an empty block.
all_global = [U for U in (U_q_global, U_atoms_global) if len(U) > 0]

if all_global:
    stacked = np.vstack(all_global)
    U_universal = remove_redundant_constraints(stacked, epsilon=1e-4)
else:
    # Neither source produced a constraint: fall back to an empty (0, d) array
    # whose width d is the last axis of the first successor-feature array.
    d = SFs[0][0].shape[-1]
    U_universal = np.zeros((0, d))

print("Universal constraint set size:", len(U_universal))

In [None]:
def scot_greedy_family_atoms_tracked(
    U_global,
    atoms_per_env,
    SFs,
    envs,
    *,
    normalize=True,
    round_decimals=12,
    constraint_fn=None,
):
    """
    SCOT greedy (set-cover) selection over atoms with full environment tracking.

    Every row of ``U_global`` is one element of the universe to cover. An atom
    covers a row when one of the constraint vectors produced for that atom has
    the same key as the row: the vector (divided by its norm when ``normalize``
    is true) rounded to ``round_decimals`` decimals. Each iteration picks the
    atom that covers the most still-uncovered rows; ties go to the first atom
    in (environment, atom) order. The loop stops when everything is covered or
    no remaining atom adds coverage.

    Args:
        U_global: sequence of constraint vectors (one per row).
        atoms_per_env: ``atoms_per_env[i]`` is the sequence of candidate atoms
            of environment ``i``.
        SFs: one entry per environment; only ``SFs[i][0]`` is used and it is
            forwarded unchanged to the constraint function.
        envs: one environment per entry of ``atoms_per_env``; forwarded
            unchanged to the constraint function.
        normalize: compare directions (unit vectors) instead of raw vectors.
        round_decimals: number of decimals kept when building comparison keys.
        constraint_fn: callable ``(atom, mu_sa, env) -> iterable of vectors``.
            When omitted, the notebook-level ``atom_to_constraints`` is used if
            it exists; otherwise it is imported from ``utils``.

    Returns:
        chosen_atoms: list of (env_idx, Atom) in selection order
        env_stats: {
            env_idx: {
                'atoms': [Atom, Atom, ...],
                'indices': [0, 5, 9, ...],   # SCOT iteration numbers
                'coverage_counts': [12, 4, ...],  # new constraints covered each time
                'total_coverage': int
            }
        }

    Raises:
        ValueError: if ``atoms_per_env``, ``SFs`` and ``envs`` differ in length.
    """
    n_envs = len(atoms_per_env)
    if len(SFs) != n_envs or len(envs) != n_envs:
        # zip() below would silently drop the extra entries otherwise.
        raise ValueError(
            "atoms_per_env, SFs and envs must have the same length "
            f"(got {n_envs}, {len(SFs)}, {len(envs)})"
        )

    if constraint_fn is None:
        try:
            # Resolved at call time so a definition made in any earlier cell works.
            constraint_fn = atom_to_constraints
        except NameError:
            # NOTE(review): presumably exported by utils next to
            # derive_constraints_from_atoms — confirm.
            from utils import atom_to_constraints as constraint_fn

    # ---------- Utility ----------
    def key_for(v):
        """Hashable comparison key of a constraint vector."""
        v = np.asarray(v, dtype=float)
        n = np.linalg.norm(v)
        # Zero-length and non-finite vectors all share one sentinel key.
        if n == 0.0 or not np.isfinite(n):
            return ("ZERO",)
        vv = v / n if normalize else v
        return tuple(np.round(vv, round_decimals))

    # ---------- Build U_global dictionary ----------
    # Several rows may share a key (same direction); all are covered together.
    key_to_uix = {}
    for idx, v in enumerate(U_global):
        key_to_uix.setdefault(key_for(v), []).append(idx)

    uncovered = set(range(len(U_global)))
    chosen = []

    # ---------- Tracking state ----------
    env_stats = {
        i: {
            "atoms": [],
            "indices": [],
            "coverage_counts": [],
            "total_coverage": 0,
        }
        for i in range(n_envs)
    }

    # ---------- Precompute coverage for each atom ----------
    # cov[i][j] = set of U_global row indices covered by atom j of env i.
    cov = []
    for atoms, sf, env in zip(atoms_per_env, SFs, envs):
        mu_sa = sf[0]
        cov_i = []
        for atom in atoms:
            covered_set = set()
            for v in constraint_fn(atom, mu_sa, env):
                covered_set.update(key_to_uix.get(key_for(v), ()))
            cov_i.append(covered_set)
        cov.append(cov_i)

    # ---------- Greedy Loop ----------
    iter_count = 0
    while uncovered:
        best_gain = 0
        best_atom = None
        best_new = None

        for i, cov_i in enumerate(cov):
            for j, covered_by_atom in enumerate(cov_i):
                new_cover = uncovered & covered_by_atom
                # Strict '>' keeps the first atom among equal gains.
                if len(new_cover) > best_gain:
                    best_gain = len(new_cover)
                    best_atom = (i, j)
                    best_new = new_cover

        if best_atom is None:
            # Remaining rows cannot be covered by any candidate atom.
            break

        i, j = best_atom
        atom = atoms_per_env[i][j]

        chosen.append((i, atom))
        uncovered -= best_new

        # ---------- Update env_stats ----------
        stats_i = env_stats[i]
        stats_i["atoms"].append(atom)
        stats_i["indices"].append(iter_count)
        stats_i["coverage_counts"].append(len(best_new))
        stats_i["total_coverage"] += len(best_new)

        iter_count += 1

    return chosen, env_stats

## How many envs got activated?
## Each env
## We can provide some examples where E-stop is more informative than others, specifically envs without terminals,
## but in expectation the story is different.

In [None]:
# Generate the candidate atoms for every environment (Q-demos, pairwise,
# e-stop and improvement feedback are all switched on), then pass them to
# SCOT to find the solution.
candidates_per_env = generate_candidate_atoms_for_scot(
    envs,
    Q_list,
    use_q_demos=True,
    use_pairwise=True,
    use_estop=True,
    use_improvement=True,
    n_pairwise=200,
    n_estops=200,
    n_improvements=200
)



# TODO: make a random selection instead of SCOT
## TODO: modify BIRL to support all types of feedback atoms
# TODO: compute the regret and visualize it

## TODO: run the whole loop with ... (note left unfinished)

In [None]:
# Run the greedy SCOT selection of candidate atoms against the universal
# constraint set; `stats` holds the per-environment selection record.
chosen_atoms, stats = scot_greedy_family_atoms_tracked(
    U_global=U_universal,
    atoms_per_env=candidates_per_env,
    SFs=SFs,
    envs=envs,
)