In [1]:
%matplotlib tk

In [2]:
from acorl.envs.seeker.seeker_exploration import SeekerExplorationEnvConfig
import gymnasium as gym

from acorl.env_wrapper.adaption_fn import ConditionalAdaptionEnvWrapper
from acorl.acrl_algos.alpha_projection.mapping import alpha_projection_interface_fn
from acorl.envs.constraints.seeker import SeekerInputSetPolytopeCalculator
from rl_competition.competition.environment import create_exploration_seeker  # for env_config

from plot_utils import plot_seeker_obs, decode_obs, plot_mcts_tree_xy_limited

import numpy as np
from dataclasses import dataclass
from network import SeekerAlphaZeroNet
import torch
import torch.optim as optim
from MCTS_AC import MCTSPlanner_AC
from utils import (
    env_set_state,
)
import random


In [3]:
# --- Reproducibility ---
def set_global_seeds(seed: int):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices) with ``seed``, then sets ``cudnn.deterministic = True``
    and ``cudnn.benchmark = False``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Single seed shared by the global RNGs and, further below, the planner RNG.
RNG_SEED = 42
set_global_seeds(RNG_SEED)

In [4]:
# --- Device ---
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: cpu


In [5]:
# === Config ===

@dataclass
class Config:
    """Hyperparameters for the MCTS planner, the training loop and evaluation.

    Every field has a default, so ``Config()`` yields the baseline setup used
    in this notebook; individual values can be overridden by keyword.
    """

    # ========================
    # Environment
    # ========================
    # Max steps in the environment before cut-off (goal not reached, obstacle
    # not crashed into) --> prevents stepping forever.
    max_episode_steps: int = 300

    # ========================
    # MCTS core
    # ========================
    # Number of MCTS simulations per real environment step.
    num_simulations: int = 400
    # Exploration vs exploitation trade-off in PUCT;
    # higher -> more exploration guided by the policy prior.
    cpuct: float = 4.0
    # Safety cap on tree depth during a simulation.
    max_depth: int = 64

    # Action sampling temperature at the root (root action selection):
    # >1.0 = more stochastic, 1.0 = proportional to visits, ~0 = greedy.
    temperature: float = 1.0

    # ========================
    # Progressive Widening
    # ========================
    # Controls how many actions are allowed per node:
    #   K_max = pw_k * N(s)^pw_alpha
    pw_k: float = 1.5
    # Growth rate of the branching factor.
    # 0.5 is common; smaller = more conservative expansion.
    pw_alpha: float = 0.5

    # ========================
    # Action sampling (baseline, non-fancy, but no duplicates)
    # ========================
    # --- Uniform warmstart ---
    # With the defaults below the uniform warmstart is ENABLED; set the two
    # values to 0 to switch it off.
    # First K children per node are sampled uniformly in [-1,1]^2.
    # Set to 0 to disable.
    K_uniform_per_node: int = 8
    # Number of *training iterations* during which ALL nodes use uniform sampling.
    # 0 disables the global warmstart; use this if you want uniform sampling
    # only early in training.
    warmstart_iters: int = 20

    # --- Novelty reject (hard deduplication) ---
    # Deduplicate actions (keep this ON to satisfy "no duplicate actions").
    # Minimum distance between actions to be considered "new"; set <=0 to disable.
    # In [-1,1]^2, values around 0.05-0.15 are reasonable for real spacing; the
    # default here is small but > 0, i.e. it only rejects near-identical actions.
    novelty_eps: float = 1e-3
    # Distance metric for the novelty check:
    #   "linf" = max(|dx|, |dy|)  (good for box action spaces)
    #   "l2"   = Euclidean distance
    novelty_metric: str = "l2"

    # --- Diversity scoring (soft repulsion) ---
    # Candidate scoring / diversity is disabled with the defaults below.
    # Number of candidate actions sampled before choosing the best;
    # <=1 disables diversity scoring.
    num_candidates: int = 1
    # Strength of the diversity penalty; higher -> stronger push away from
    # already-sampled actions. Set <=0 to disable.
    diversity_lambda: float = 0.0
    # Length scale for the diversity penalty: roughly how far apart actions
    # must be before they stop "repelling" each other.
    # Marked unused by the author (diversity scoring is off).
    diversity_sigma: float = 0.25
    # Weight of the policy log-probability in candidate scoring;
    # higher -> follow the policy more closely, lower -> prioritise diversity.
    # Marked unused by the author (diversity scoring is off).
    policy_beta: float = 1.0

    # --- Resampling control ---
    # How many times expansion may retry to find a novel action.
    # If all attempts fail, expansion is declined and MCTS falls back to selection.
    max_resample_attempts: int = 16

    # ========================
    # Training
    # ========================
    batch_size: int = 8
    learning_rate: float = 3e-4
    weight_decay: float = 1e-4
    # Gradient updates per outer iteration.
    train_steps_per_iter: int = 50

    # (Only used by our baseline loss function)
    value_loss_weight: float = 1.0
    # Applies to mu/log_std regression.
    policy_loss_weight: float = 1.0

    # ========================
    # Data collection
    # ========================
    # Number of real env episodes collected per training iteration.
    collect_episodes_per_iter: int = 10
    # NOTE: `batch_size` is read once, while the class body executes, so this
    # default is the class-level default of batch_size (8). It does NOT follow
    # a per-instance override such as Config(batch_size=64).
    replay_buffer_capacity: int = batch_size
    # Discount factor for return backup in MCTS.
    gamma_mcts: float = 0.85

    # ========================
    # Logging / evaluation
    # ========================
    eval_every: int = 25
    # Use 10 fixed seeds for smoother eval curves.
    eval_episodes: int = 10

    # ========================
    # Second discount factor
    # ========================
    # Previously an un-annotated class attribute, which @dataclass ignores: it
    # could not be passed to Config(...) and was missing from repr/asdict.
    # Declared last so the positional order of all earlier fields is unchanged.
    # Presumably the discount for Monte-Carlo return targets -- TODO confirm
    # against the training code.
    gamma_mc: float = 0.9


cfg = Config()

In [6]:
def action_towards_closest_obstacle(agent, obstacles, *, normalize=True):
    """
    Compute an action vector pointing from the agent directly toward
    the closest obstacle center.

    Parameters
    ----------
    agent : array-like, shape (dim,)
        Agent position.
    obstacles : array-like, shape (N, dim+1) or (dim+1,)
        Obstacle positions + radius. Only positions are used (the last column
        is dropped). A single obstacle may be passed as a flat vector.
    normalize : bool
        If True, return a unit direction vector (all zeros when the agent sits
        exactly on the closest center).
        If False, return the raw displacement vector.

    Returns
    -------
    action : np.ndarray, shape (dim,)
        Action vector pointing toward the closest obstacle.
    idx : int
        Index of the closest obstacle.
    dist : float
        Distance to the closest obstacle center.

    Raises
    ------
    ValueError
        If ``obstacles`` contains no obstacle.
    """
    agent = np.asarray(agent)
    # Promote a single flat obstacle to one row so the column slicing below works.
    obstacles = np.atleast_2d(np.asarray(obstacles))
    if obstacles.size == 0:
        raise ValueError("obstacles must contain at least one obstacle")

    centers = obstacles[:, :-1]  # drop radius
    diffs = centers - agent      # vectors agent -> obstacle
    dists = np.linalg.norm(diffs, axis=1)

    # Convert to builtin scalars so the return types match the documented contract.
    idx = int(np.argmin(dists))
    direction = diffs[idx]
    dist = float(dists[idx])

    if normalize:
        if dist > 0:
            direction = direction / dist
        else:
            # Agent is exactly on the center: no direction is defined.
            direction = np.zeros_like(direction)

    return direction, idx, dist

# ENVS

In [7]:
from rl_competition.competition.environment import create_exploration_seeker
from acorl.envs.constraints.seeker import SeekerInputSetPolytopeCalculator
from acorl.env_wrapper.adaption_fn import ConditionalAdaptionEnvWrapper
from acorl.acrl_algos.alpha_projection.mapping import alpha_projection_interface_fn

# --- Shared environment configuration ---
env_config = SeekerExplorationEnvConfig(randomize=True, num_obstacles=10, dim=2, log=False)


def _make_seeker_env():
    """Build a fresh environment from ``env_config`` via ``gym.make``."""
    return gym.make(env_config.id, **env_config.model_dump(exclude={'id'}))


# --- Real environments for rollouts / data collection ---
env_real = _make_seeker_env()
env_real_AC = _make_seeker_env()  # receives the action-constraint wrapper at the end of this cell

obs_dim, action_dim = (
    env_real.observation_space.shape[0],
    env_real.action_space.shape[0],
)

print("obs_dim:", obs_dim, "action_dim:", action_dim)
print("action_space:", env_real.action_space)

# --- Simulation environment for MCTS step_fn ---
# Only the unwrapped base env is kept for simulation.
env_sim = _make_seeker_env().unwrapped

constraint_calculator = SeekerInputSetPolytopeCalculator(env_config=env_config)


def _with_action_constraints(env):
    """Wrap ``env`` in a ConditionalAdaptionEnvWrapper fed by ``constraint_calculator``."""
    return ConditionalAdaptionEnvWrapper(
        env,
        constraint_calculator.compute_relevant_input_set,
        constraint_calculator.compute_fail_safe_input,
        constraint_calculator.get_set_representation(),
        alpha_projection_interface_fn,
    )


env_sim_AC = _with_action_constraints(env_sim)
env_real_AC = _with_action_constraints(env_real_AC)

def sync_conditional_adaption_wrapper(
    env_wrapped,
    obs,
    *,
    constraint_calculator,
):
    """
    Sync ConditionalAdaptionEnvWrapper cache after env_set_state().

    Builds an ``info`` dict for ``obs`` (boundary size, relevant input set,
    fail-safe input) and stores ``obs`` / ``info`` on the wrapper as
    ``_previous_obs`` / ``_previous_info``.

    Parameters
    ----------
    env_wrapped
        Wrapped environment; must expose ``.unwrapped``.
    obs
        Observation matching the state the base env was just set to.
        Stored by reference, not copied.
    constraint_calculator
        Object providing ``compute_relevant_input_set(obs, info)`` and
        ``compute_fail_safe_input(obs, info)``.
    """
    base_env = env_wrapped.unwrapped

    # Resolve the boundary size BEFORE computing the constraints, so the
    # calculator sees the same value that ends up in the cached info.
    # `_boundary_size` wins when present; otherwise fall back to `_size`
    # (SeekerEnv uses _size for the boundary), defaulting to 10.0.
    if hasattr(base_env, "_boundary_size"):
        boundary_size = base_env._boundary_size
    else:
        boundary_size = float(getattr(base_env, "_size", 10.0))

    info = {"boundary_size": boundary_size}

    # Compute constraint info for THIS obs
    info["relevant_input_set"] = constraint_calculator.compute_relevant_input_set(obs, info)
    info["fail_safe_input"] = constraint_calculator.compute_fail_safe_input(obs, info)

    # Presumably the wrapper reads these two attributes on its next step() --
    # verify against ConditionalAdaptionEnvWrapper.
    env_wrapped._previous_obs = obs
    env_wrapped._previous_info = info

def step_fn(node, action):
    """
    MCTS transition function on the ACTION-CONSTRAINED simulation environment.

    Sets ``env_sim_AC`` to the state held by ``node``, refreshes the wrapper
    cache for that state, then applies ``action`` through the wrapper.

    Returns
    -------
    tuple
        ``(next_obs, next_coin_collected, reward, done, info)`` where
        ``next_obs`` is a copy of the observation returned by the env,
        ``next_coin_collected`` is the base env's ``_coin_collected`` flag
        (False if the attribute is missing) and ``done`` is
        ``terminated or truncated``.
    """
    # 1) teleport the base env to the node's state
    env_set_state(env_sim_AC, node, num_obstacles=env_config.num_obstacles)

    # 2) rebuild the wrapper cache for that state
    state_obs = np.asarray(node.state, dtype=env_sim.unwrapped._dtype)
    sync_conditional_adaption_wrapper(
        env_sim_AC, state_obs, constraint_calculator=constraint_calculator
    )

    # 3) step the WRAPPED env with the action cast to the env dtype
    typed_action = np.asarray(action, dtype=env_sim_AC.unwrapped._dtype)
    step_result = env_sim_AC.step(typed_action)
    raw_next_obs, reward, terminated, truncated, info = step_result

    # Copy so the returned observation does not alias a buffer the env might reuse.
    next_obs = np.array(raw_next_obs, copy=True)
    coin_flag = getattr(env_sim_AC.unwrapped, "_coin_collected", False)

    return next_obs, bool(coin_flag), reward, bool(terminated or truncated), info

obs_dim: 36 action_dim: 2
action_space: Box(-1.0, 1.0, (2,), float64)


In [8]:
np.random.seed(42)
obs0, info0 = env_real.reset(seed=42)

# NETWORK

In [9]:
# --- Network ---
net = SeekerAlphaZeroNet(obs_dim=obs_dim, action_dim=action_dim)
net = net.to(device)

# Sanity check: push the initial observation through the network once
# (as a batch of size 1, without tracking gradients) and show the outputs.
obs_batch = torch.from_numpy(obs0).float()[None]
obs_t = obs_batch.to(device)
with torch.no_grad():
    mu_t, log_std_t, v_t = net(obs_t)

for label, tensor in (("mu:", mu_t), ("log_std:", log_std_t)):
    print(label, tensor.cpu().numpy())
print("v:", v_t.item())

# --- Optimizer (we'll use later) ---
optimizer = optim.AdamW(
    net.parameters(),
    lr=cfg.learning_rate,
    weight_decay=cfg.weight_decay,
)

mu: [[0.3417295  0.20982334]]
log_std: [[ 0.27098858 -0.38152447]]
v: 0.8045704364776611


# PLANNER

In [10]:
# === PLANNER ===
# Config fields that are handed to the planner under their own name.
_PLANNER_CFG_FIELDS = (
    # MCTS core / progressive widening
    "num_simulations",
    "cpuct",
    "pw_k",
    "pw_alpha",
    "max_depth",
    "temperature",
    # action sampling
    "K_uniform_per_node",
    "warmstart_iters",
    "novelty_eps",
    "novelty_metric",
    "num_candidates",
    "diversity_lambda",
    "diversity_sigma",
    "policy_beta",
    "max_resample_attempts",
)

planner = MCTSPlanner_AC(
    net=net,
    device=str(device),
    step_fn=step_fn,
    gamma=cfg.gamma_mcts,  # the only config field passed under a different name
    rng=np.random.default_rng(RNG_SEED),
    **{field_name: getattr(cfg, field_name) for field_name in _PLANNER_CFG_FIELDS},
)

In [11]:
env_real

<TimeLimit<OrderEnforcing<PassiveEnvChecker<SeekerExplorationEnv<acorl-envs/SeekerExplorationEnv>>>>>

In [12]:
# Reset the constrained and the unconstrained env with the same seed.
# `obs` / `info` end up holding the result of the last reset (env_real).
for _env in (env_real_AC, env_real):
    np.random.seed(42)
    obs, info = _env.reset(seed=42)

In [13]:
plot_seeker_obs(obs, info, num_obstacles=10, env=env_real_AC, title="test")

In [14]:
# Apply the same action twice on the action-constrained env and show the last transition.
action = [0, 1]
for _step in range(2):
    obs, reward, terminated, truncated, info = env_real_AC.step(action)

for label, value in (
    ("obs: ", obs),
    ("reward: ", reward),
    ("terminated: ", terminated),
    ("truncated: ", truncated),
    ("info: ", info),
):
    print(label, value)

obs:  [-2.26430661  9.99999056  4.17589095  1.77585272 -6.87962719 -6.88010959
  0.41783754 -8.83832776  7.32352292  1.60190286  2.02230023  4.16145156
  0.73762034 -9.58831011  9.39819704  1.11814671  6.64885282 -5.75321779
  2.55870768 -6.36350066 -6.3319098   0.2221866  -3.91515514  0.49512863
  1.81570856 -1.36109963 -4.1754172   0.65304704  2.23705789 -7.21012279
  0.45900391 -4.15710703 -2.67276313  2.72099101  9.31264066  6.16794696]
reward:  0
terminated:  False
truncated:  False
info:  {'distance': 10.445697103260219, 'boundary_size': 10, 'require_fail_safe': False, 'relevant_input_set': <acorl.convexsets.hpolytope.HPolytope object at 0x000002A044B713D0>, 'fail_safe_input': array([None, None], dtype=object), 'adapted_action': array([-0.00602875,  0.88713305])}


In [15]:
# Apply the same action twice on the unconstrained env and show the last transition.
action = [0, 1]
for _step in range(2):
    obs, reward, terminated, truncated, info = env_real.step(action)

for label, value in (
    ("obs: ", obs),
    ("reward: ", reward),
    ("terminated: ", terminated),
    ("truncated: ", truncated),
    ("info: ", info),
):
    print(label, value)

obs:  [-2.25827786 10.11285752  4.17589095  1.77585272 -6.87962719 -6.88010959
  0.41783754 -8.83832776  7.32352292  1.60190286  2.02230023  4.16145156
  0.73762034 -9.58831011  9.39819704  1.11814671  6.64885282 -5.75321779
  2.55870768 -6.36350066 -6.3319098   0.2221866  -3.91515514  0.49512863
  1.81570856 -1.36109963 -4.1754172   0.65304704  2.23705789 -7.21012279
  0.45900391 -4.15710703 -2.67276313  2.72099101  9.31264066  6.16794696]
reward:  -100.0
terminated:  True
truncated:  False
info:  {'distance': 10.531105229326139, 'boundary_size': 10}


In [16]:
# Re-seed and reset, then take three steps straight toward the closest obstacle center.
np.random.seed(42)
obs, info = env_real.reset(seed=42)
for i in range(3):
    decoded = decode_obs(obs, num_obstacles=10)
    agent, goal, obstacles, coin, dim = decoded
    direction, idx, dist = action_towards_closest_obstacle(agent, obstacles, normalize=False)
    # The raw displacement can exceed the action bounds, so clamp each component to [-1, 1].
    action = direction.clip(-1.0, 1.0)
    step_result = env_real.step(action)
    obs, reward, terminated, truncated, info = step_result
    print(info)

{'distance': 7.616679783569992, 'boundary_size': 10}
{'distance': 6.202536876075355, 'boundary_size': 10}
{'distance': 4.788435702054005, 'boundary_size': 10}


In [17]:
# One more step toward the closest obstacle center, at half the clamped magnitude.
decoded = decode_obs(obs, num_obstacles=10)
agent, goal, obstacles, coin, dim = decoded
direction, idx, dist = action_towards_closest_obstacle(agent, obstacles, normalize=False)
clamped = np.clip(direction, -1.0, 1.0)
action = clamped * 0.5
obs, reward, terminated, truncated, info = env_real.step(action)
print(info)

{'distance': 4.0983404849984035, 'boundary_size': 10}


In [22]:
plot_seeker_obs(obs, info, num_obstacles=10, env=env_real, title="test")

In [23]:
root = planner.search(obs, coin_collected=bool(getattr(env_real.unwrapped, "_coin_collected", False)))

In [24]:
plot_mcts_tree_xy_limited(
    root,
    num_obstacles=10,
    title="MCTS tree near obstacle (2D)",
    max_depth=1,
    top_k_per_node=99,
    L=None,
)

<Axes: title={'center': 'MCTS tree near obstacle (2D)'}>

In [21]:
# Root edges whose child node is terminal.
unsafe = list(filter(lambda edge: edge.child_node.is_terminal, root.children))
print(len(unsafe), "terminal children")

8 terminal children
