# Imports, device and seeding

In [1]:
# --- Standard libs ---
import math
import os
import random
from dataclasses import dataclass
from pprint import pprint
from typing import Dict, Any

# --- Scientific stack ---
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# --- ACORL and RL ---
from acorl.envs.seeker.seeker import SeekerEnv, SeekerEnvConfig

# --- Own Code ---
from MCTS import MCTSPlanner
from network import SeekerAlphaZeroNet

# --- Device ---
# Use the GPU when PyTorch reports CUDA as available, otherwise fall back to the CPU.
# `device` is reused below when the network and its input tensor are created.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


# --- Reproducibility ---
def set_global_seeds(seed: int):
    """Seed every global random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU generator plus all CUDA devices), then sets the cuDNN
    ``deterministic`` flag and clears the cuDNN ``benchmark`` flag.

    Args:
        seed: Integer seed handed unchanged to each generator.
    """
    # Seeding order is kept fixed: Python -> NumPy -> torch (CPU) -> torch (CUDA).
    for seed_rng in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_rng(seed)

    # cuDNN switches: request deterministic behaviour and disable benchmarking.
    cudnn_flags = torch.backends.cudnn
    cudnn_flags.deterministic = True
    cudnn_flags.benchmark = False


# One seed for the whole notebook: applied to the global RNGs here and reused
# further down for the planner's own NumPy generator.
SEED = 42
set_global_seeds(SEED)

Using device: cpu


# Config

In [2]:
@dataclass
class Config:
    """Hyperparameters for this notebook, grouped by concern.

    Sections: environment, MCTS core, progressive widening, action-sampling
    improvements, training, data collection and logging/evaluation.
    Each field is documented by the comment directly below it.
    """

    # ========================
    # Environment
    # ========================
    max_episode_steps: int = 200
    # Maximum steps per real rollout episode


    # ========================
    # MCTS core
    # ========================
    num_simulations: int = 200
    # Number of MCTS simulations per real environment step

    cpuct: float = 1.5
    # Exploration vs exploitation tradeoff in PUCT
    # Higher -> more exploration guided by policy prior

    gamma: float = 0.99
    # Discount factor for return backup in MCTS

    max_depth: int = 64
    # Safety cap on tree depth during a simulation

    temperature: float = 1.0
    # Action sampling temperature at root
    # >1.0 = more stochastic, 1.0 = proportional to visits, ~0 = greedy


    # ========================
    # Progressive Widening
    # ========================
    pw_k: float = 2.0
    # Controls how many actions are allowed per node:
    #   K_max = pw_k * N(s)^pw_alpha

    pw_alpha: float = 0.5
    # Growth rate of branching factor
    # 0.5 is common; smaller = more conservative expansion


    # ========================
    # Action Sampling Improvements
    # ========================

    # --- Uniform warmstart ---
    K_uniform_per_node: int = 5
    # First K children per node are sampled uniformly in [-1,1]^2
    # Set to 0 to disable

    warmstart_iters: int = 0
    # Number of *training iterations* during which ALL nodes use uniform sampling
    # 0 disables global warmstart; use this if you want uniform sampling only early in training


    # --- Novelty reject (hard deduplication) ---
    novelty_eps: float = 0.1
    # Minimum distance between actions to be considered "new"
    # In [-1,1]^2, values around 0.05-0.15 are reasonable
    # Set <=0 to disable

    novelty_metric: str = "linf"
    # Distance metric for novelty check:
    # "linf" = max(|dx|, |dy|)  (good for box action spaces)
    # "l2"   = Euclidean distance


    # --- Diversity scoring (soft repulsion) ---
    num_candidates: int = 16
    # Number of candidate actions sampled before choosing the best
    # <=1 disables diversity scoring

    diversity_lambda: float = 1.0
    # Strength of diversity penalty
    # Higher -> stronger push away from already-sampled actions
    # Set <=0 to disable

    diversity_sigma: float = 0.25
    # Length scale for diversity penalty
    # Roughly: how far actions must be before they stop "repelling" each other

    policy_beta: float = 1.0
    # Weight of policy log-probability in candidate scoring
    # Higher -> follow policy more closely
    # Lower -> prioritize diversity more


    # --- Resampling control ---
    max_resample_attempts: int = 8
    # How many times expansion may retry to find a novel action
    # If all fail, expansion is declined and MCTS falls back to selection


    # ========================
    # Training
    # ========================
    batch_size: int = 128
    learning_rate: float = 3e-4
    weight_decay: float = 1e-4
    # learning_rate and weight_decay are the values handed to the AdamW optimizer

    train_steps_per_iter: int = 200
    # Gradient updates per outer iteration

    value_loss_weight: float = 1.0
    policy_loss_weight: float = 1.0


    # ========================
    # Data collection
    # ========================
    episodes_per_iter: int = 10
    # Number of real env episodes collected per training iteration

    replay_buffer_capacity: int = 50_000


    # ========================
    # Logging / evaluation
    # ========================
    eval_every: int = 5
    eval_episodes: int = 5


# Default configuration instance used by the rest of the notebook.
cfg = Config()

## Sanity test

In [3]:
# Pretty-print the resolved configuration so the cell output records the
# hyperparameters of this run. `pprint` comes from the notebook's top import
# cell, keeping every import in one place.
pprint(cfg)

Config(max_episode_steps=200,
       num_simulations=200,
       cpuct=1.5,
       gamma=0.99,
       max_depth=64,
       temperature=1.0,
       pw_k=2.0,
       pw_alpha=0.5,
       K_uniform_per_node=5,
       warmstart_iters=0,
       novelty_eps=0.1,
       novelty_metric='linf',
       num_candidates=16,
       diversity_lambda=1.0,
       diversity_sigma=0.25,
       policy_beta=1.0,
       max_resample_attempts=8,
       batch_size=128,
       learning_rate=0.0003,
       weight_decay=0.0001,
       train_steps_per_iter=200,
       value_loss_weight=1.0,
       policy_loss_weight=1.0,
       episodes_per_iter=10,
       replay_buffer_capacity=50000,
       eval_every=5,
       eval_episodes=5)


# Create env_real, env_sim, dims (for network), and step_fn (for MCTSPlanner)

In [4]:
# --- Env config ---
env_config = SeekerEnvConfig(randomize=False, num_obstacles=1)

# --- Real environment for rollouts / data collection ---
# The config is dumped to keyword arguments with its "id" field left out
# (presumably because SeekerEnv's constructor does not accept it - TODO confirm).
env_real = SeekerEnv(**env_config.model_dump(exclude={"id"}))
# NOTE(review): reset() is called without an explicit seed; confirm the initial
# layout is reproducible with randomize=False, or pass a seed here.
obs0, info0 = env_real.reset()

# Vector sizes read from the env's spaces; used below to build the network.
obs_dim = env_real.observation_space.shape[0]
action_dim = env_real.action_space.shape[0]

print("obs_dim:", obs_dim, "action_dim:", action_dim)
print("action_space:", env_real.action_space)

# --- Simulation environment for MCTS step_fn ---
# Second instance built from the same config; step_fn steps this one, not env_real.
# NOTE(review): env_sim is never reset() in this cell - step_fn overwrites its
# state before every step instead. Verify SeekerEnv.step tolerates that.
env_sim = SeekerEnv(**env_config.model_dump(exclude={"id"}))


def set_env_state_from_obs(sim_env: "SeekerEnv", obs: np.ndarray):
    """
    Overwrite SeekerEnv internal state to match a flat observation vector.

    Expected layout of `obs`:
      [agent_x, agent_y, goal_x, goal_y, (obs_x, obs_y, obs_r) * N]   with N >= 0

    Args:
        sim_env: Environment whose private `_agent_position`, `_goal_position`,
            `_obstacle_position` and `_obstacle_radius` attributes are replaced.
            Its `_dtype` attribute decides the dtype of the stored arrays.
        obs: 1-D observation (anything `np.asarray` accepts) in the layout above.

    Raises:
        ValueError: If `obs` is not 1-D, has fewer than 4 entries, or the part
            after the goal is not a whole number of (x, y, r) triples. Nothing
            on `sim_env` is modified in that case.
    """
    obs = np.asarray(obs, dtype=sim_env._dtype)

    # Validate before touching the env: plain slicing would otherwise silently
    # store truncated or mis-shaped arrays and corrupt the simulator state.
    if obs.ndim != 1:
        raise ValueError(
            f"obs must be a flat 1-D vector, got shape {obs.shape}"
        )
    if obs.size < 4 or (obs.size - 4) % 3 != 0:
        raise ValueError(
            "obs must hold 4 agent/goal values followed by (x, y, r) obstacle "
            f"triples, got {obs.size} values"
        )

    # agent and goal (copied so the env never aliases the caller's array)
    sim_env._agent_position = obs[0:2].copy()
    sim_env._goal_position = obs[2:4].copy()

    # obstacles: one row per obstacle -> columns x, y, radius
    obstacles = obs[4:].reshape(-1, 3)
    sim_env._obstacle_position = obstacles[:, 0:2].copy()
    sim_env._obstacle_radius = obstacles[:, 2].copy()

def step_fn(state: np.ndarray, action: np.ndarray):
    """
    One-step transition model handed to MCTS.

    Loads `state` into the shared `env_sim`, applies `action` there and reports
    the outcome.

    Returns: next_state, reward, done, info  (the tuple MCTSPlanner expects),
    where `done` is True when the step was either terminated or truncated.
    """
    # Rewind the shared simulator to the queried state before stepping it.
    set_env_state_from_obs(env_sim, state)

    sim_action = np.asarray(action, dtype=env_sim._dtype)
    outcome = env_sim.step(sim_action)
    next_obs, reward, terminated, truncated, info = outcome

    # Collapse the two end-of-episode flags into a single plain bool.
    episode_over = True if (terminated or truncated) else False
    return next_obs, float(reward), episode_over, info

obs_dim: 7 action_dim: 2
action_space: Box(-1.0, 1.0, (2,), float64)


# Instantiate neural network and optimizer

In [5]:
# --- Network ---
# Sized from the env's observation/action dimensions and moved to `device`.
net = SeekerAlphaZeroNet(obs_dim=obs_dim, action_dim=action_dim).to(device)

# Optional: print one forward pass sanity
# The initial observation is cast to float32 and given a leading batch axis of 1.
obs_t = torch.from_numpy(obs0).float().unsqueeze(0).to(device)
# NOTE(review): the net is not switched to eval() for this check; do that first
# if SeekerAlphaZeroNet contains dropout or batch-norm layers.
with torch.no_grad():
    # Three outputs; the names suggest policy mean, policy log-std and state value.
    mu_t, log_std_t, v_t = net(obs_t)

print("mu:", mu_t.cpu().numpy())
print("log_std:", log_std_t.cpu().numpy())
print("v:", v_t.item())

# --- Optimizer (we'll use later) ---
optimizer = optim.AdamW(net.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)

mu: [[ 0.05101037 -0.15922213]]
log_std: [[0.1659213 0.4990744]]
v: 0.1396380364894867


# Instantiate MCTSPlanner

In [6]:
# Build the planner from the shared network, the simulator-backed step_fn and
# the MCTS / sampling hyperparameters in `cfg`.
# The device is passed as a string, and the planner gets its own NumPy
# generator seeded with SEED.
planner = MCTSPlanner(
    net=net,
    device=str(device),
    step_fn=step_fn,
    num_simulations=cfg.num_simulations,
    cpuct=cfg.cpuct,
    gamma=cfg.gamma,
    pw_k=cfg.pw_k,
    pw_alpha=cfg.pw_alpha,
    max_depth=cfg.max_depth,
    temperature=cfg.temperature,
    rng=np.random.default_rng(SEED),
    K_uniform_per_node=cfg.K_uniform_per_node,
    warmstart_iters=cfg.warmstart_iters,
    novelty_eps=cfg.novelty_eps,
    novelty_metric=cfg.novelty_metric,
    num_candidates=cfg.num_candidates,
    diversity_lambda=cfg.diversity_lambda,
    diversity_sigma=cfg.diversity_sigma,
    policy_beta=cfg.policy_beta,
    max_resample_attempts=cfg.max_resample_attempts,
)

# Smoke test: one MCTS search, inspect root, pick action, step env_real

In [7]:
# Reset real env
# NOTE(review): reset() is unseeded here as well; see whether SeekerEnv needs a seed
# for this cell's output to be reproducible.
obs, info = env_real.reset()

# Run one MCTS search from the current observation
root = planner.search(obs)

print("Root visit count N:", root.N)
print("Root children K:", len(root.children))

# Show a few children stats
# (per child: N_sa, Q_sa, P_sa_raw, P_sa and the action, as printed below)
for i, ch in enumerate(root.children[:5]):
    print(
        f"[{i}] N_sa={ch.N_sa:4d}  Q_sa={ch.Q_sa:+.4f}  "
        f"P_raw={ch.P_sa_raw:.3e}  P={ch.P_sa:.3f}  action={ch.action}"
    )

# Pick an action from MCTS policy (training=True samples from visit counts)
action = planner.act(root, training=True)
print("Chosen action:", action)

# Step the real environment once
next_obs, reward, terminated, truncated, info = env_real.step(action)
done = bool(terminated or truncated)

print("Step result -> reward:", reward, "done:", done)
# L2 norm of the observation change caused by this single step.
print("Obs delta L2:", np.linalg.norm(next_obs - obs))


Root visit count N: 200
Root children K: 29
[0] N_sa=   2  Q_sa=-1.8590  P_raw=7.835e-02  P=0.043  action=[-0.1131716 -0.5455226]
[1] N_sa=   2  Q_sa=-1.4264  P_raw=7.958e-02  P=0.044  action=[-0.08216845  0.13748239]
[2] N_sa=   2  Q_sa=-1.3088  P_raw=7.689e-02  P=0.042  action=[ 0.4447187  -0.07624554]
[3] N_sa=   6  Q_sa=-1.0696  P_raw=7.059e-02  P=0.039  action=[0.51703906 0.43892592]
[4] N_sa=   3  Q_sa=-1.2001  P_raw=6.342e-02  P=0.035  action=[-0.12980588  0.9847511 ]
Chosen action: [1. 1.]
Step result -> reward: 0.4142126045006318 done: False
Obs delta L2: 1.4142135623730951


In [9]:
# Show all children stats
# Uses `root` produced by the smoke-test search cell; run that cell first.
for i, ch in enumerate(root.children):
    print(
        f"[{i}] N_sa={ch.N_sa:4d}  Q_sa={ch.Q_sa:+.4f}  "
        f"P_raw={ch.P_sa_raw:.3e}  P={ch.P_sa:.3f}  action={ch.action}"
    )

# Stack the child actions into one array (one row per child) and count how many
# rows are exactly distinct, to spot duplicate actions at the root.
actions = np.stack([ch.action for ch in root.children], axis=0)
print("unique rows:", np.unique(actions, axis=0).shape[0], " / ", actions.shape[0])

[0] N_sa=   2  Q_sa=-1.8590  P_raw=7.835e-02  P=0.043  action=[-0.1131716 -0.5455226]
[1] N_sa=   2  Q_sa=-1.4264  P_raw=7.958e-02  P=0.044  action=[-0.08216845  0.13748239]
[2] N_sa=   2  Q_sa=-1.3088  P_raw=7.689e-02  P=0.042  action=[ 0.4447187  -0.07624554]
[3] N_sa=   6  Q_sa=-1.0696  P_raw=7.059e-02  P=0.039  action=[0.51703906 0.43892592]
[4] N_sa=   3  Q_sa=-1.2001  P_raw=6.342e-02  P=0.035  action=[-0.12980588  0.9847511 ]
[5] N_sa=   1  Q_sa=-1.4354  P_raw=7.030e-02  P=0.038  action=[ 0.2602494 -1.       ]
[6] N_sa=   1  Q_sa=-1.6558  P_raw=5.476e-02  P=0.030  action=[-1.         -0.10073842]
[7] N_sa=   1  Q_sa=-2.0700  P_raw=5.946e-02  P=0.033  action=[-0.6640612 -1.       ]
[8] N_sa=   4  Q_sa=-1.2808  P_raw=5.875e-02  P=0.032  action=[ 1.        -0.2990462]
[9] N_sa=  18  Q_sa=-1.0355  P_raw=5.998e-02  P=0.033  action=[0.46458045 1.        ]
[10] N_sa=   2  Q_sa=-1.5904  P_raw=5.176e-02  P=0.028  action=[ 1. -1.]
[11] N_sa=   2  Q_sa=-1.7778  P_raw=4.965e-02  P=0.027  act