# Imports

In [1]:
import numpy as np
import torch
import torch.nn as nn

from acorl.envs.seeker.seeker import SeekerEnv, SeekerEnvConfig


# Dimensions for observations and actions

In [2]:
# Config: simple deterministic 2D, 1 obstacle
# Only `randomize` and `num_obstacles` are set explicitly here.
env_config = SeekerEnvConfig(randomize=False, num_obstacles=1)

# Create env directly from config: the config is dumped to a dict and passed
# as keyword arguments, with the "id" entry left out.
env = SeekerEnv(**env_config.model_dump(exclude={"id"}))

# reset() yields the initial observation plus an info object.
obs, info = env.reset()
# The first entry of each space's shape is taken as the vector size.
# NOTE(review): assumes both spaces are flat (1-D) -- confirm if the env changes.
obs_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

# Show the sizes, one example observation and the action space as a sanity check.
print("obs_dim:", obs_dim)
print("action_dim:", action_dim)
print("example obs:", obs)
print("action_space:", env.action_space)

obs_dim: 7
action_dim: 2
example obs: [-5.96127573 -6.03124677  6.          6.          0.          0.
  4.        ]
action_space: Box(-1.0, 1.0, (2,), float64)


# Neural Net

## Define

In [3]:
class SeekerAlphaZeroNet(nn.Module):
    """MLP with a shared trunk, a Gaussian policy head and a value head.

    The trunk is a stack of ``Linear -> ReLU`` pairs, one pair per entry of
    ``hidden_sizes``. Three linear heads read the trunk output: the policy
    mean, the policy log standard deviation and a scalar value.
    """

    def __init__(self, obs_dim: int, action_dim: int, hidden_sizes=(128, 128)):
        super().__init__()

        # Neighbouring entries of `widths` are the (in, out) sizes of each
        # trunk layer; with no hidden sizes the trunk is an empty Sequential.
        widths = [obs_dim, *hidden_sizes]
        trunk = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            trunk.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        self.body = nn.Sequential(*trunk)
        feature_dim = widths[-1]

        # Policy head: mean and log standard deviation, one value per action dim.
        self.mu_head = nn.Linear(feature_dim, action_dim)
        self.log_std_head = nn.Linear(feature_dim, action_dim)

        # Value head: a single output unit.
        self.v_head = nn.Linear(feature_dim, 1)

    def forward(self, obs: torch.Tensor):
        """Return ``(mu, log_std, v)`` for a batch of observations.

        ``obs`` is expected as ``[batch, obs_dim]``. ``mu`` and ``log_std``
        are ``[batch, action_dim]``; ``v`` has its last (size-1) dimension
        squeezed away, giving ``[batch]``.
        """
        features = self.body(obs)
        mu = self.mu_head(features)
        # Clamp log_std to [-5, 2] so exp(log_std) stays in a sane range.
        log_std = torch.clamp(self.log_std_head(features), min=-5, max=2)
        v = self.v_head(features).squeeze(-1)
        return mu, log_std, v

## Instantiate

In [4]:
# Build the network with its default hidden sizes, using the dimensions read
# from the environment.
net = SeekerAlphaZeroNet(obs_dim=obs_dim, action_dim=action_dim)

# quick sanity check: one forward pass on the current observation
obs_tensor = torch.from_numpy(obs).float().unsqueeze(0)  # [1, obs_dim]
mu, log_std, v = net(obs_tensor)

# Gradients are not disabled here, so the printed tensors carry a grad_fn.
print("mu:", mu)
print("log_std:", log_std)
print("v:", v)

mu: tensor([[-0.3429, -0.2585]], grad_fn=<AddmmBackward0>)
log_std: tensor([[ 0.0247, -0.0717]], grad_fn=<ClampBackward1>)
v: tensor([-0.5850], grad_fn=<SqueezeBackward1>)


## Distribution

### Helper functions that turn the network output into a normal (Gaussian) policy distribution and sample actions from it

In [5]:
from torch.distributions import Normal, Independent

def policy_dist(mu: torch.Tensor, log_std: torch.Tensor):
    """Build the diagonal Gaussian policy given by ``mu`` and ``log_std``.

    The standard deviation is ``exp(log_std)``. The last dimension is
    reinterpreted as the event dimension, so ``log_prob`` returns one value
    per batch entry rather than one per action component.
    """
    per_component = Normal(loc=mu, scale=torch.exp(log_std))
    return Independent(per_component, reinterpreted_batch_ndims=1)

In [6]:
def eval_policy(net, obs_np):
    """
    Evaluate the network on a single observation without tracking gradients.

    Args:
        net    -- callable returning (mu, log_std, v) for a [1, obs_dim]
                  float32 tensor
        obs_np -- numpy array (or array-like), shape (obs_dim,)

    Returns:
        dist  -- PyTorch distribution π(a|s) (batch size 1)
        v     -- value estimate (float)
        mu    -- mean vector (np)
        log_std -- log_std vector (np)
    """
    # as_tensor accepts plain array-likes as well as ndarrays.
    obs_t = torch.as_tensor(obs_np, dtype=torch.float32).unsqueeze(0)  # [1, obs_dim]

    # If `net` is a module with parameters, run the observation on its device.
    first_param = next(net.parameters(), None) if hasattr(net, "parameters") else None
    if first_param is not None:
        obs_t = obs_t.to(first_param.device)

    with torch.no_grad():
        mu, log_std, v = net(obs_t)

    dist = policy_dist(mu, log_std)

    # Tensor.numpy() only works on CPU tensors; .cpu() returns the tensor
    # itself when it already lives on the CPU (same handling as sample_action).
    return (
        dist,
        float(v.item()),
        mu.squeeze(0).cpu().numpy(),
        log_std.squeeze(0).cpu().numpy(),
    )

In [7]:
def sample_action(net, obs_np):
    """
    Draw one raw action from the policy for a single observation.

    The policy distribution comes from ``eval_policy``; the sample is
    returned as a numpy array of shape (action_dim,).
    """
    policy = eval_policy(net, obs_np)[0]
    drawn = policy.sample()  # tensor [1, action_dim]
    return drawn.squeeze(0).detach().cpu().numpy()


In [8]:
def mean_action(net, obs_np):
    """
    Return the policy mean for a single observation.

    Args:
        net    -- callable returning (mu, log_std, v) for a [1, obs_dim]
                  float32 tensor
        obs_np -- numpy array (or array-like), shape (obs_dim,)

    Returns:
        numpy array of shape (action_dim,) holding the mean action.
    """
    obs_t = torch.as_tensor(obs_np, dtype=torch.float32).unsqueeze(0)  # [1, obs_dim]

    # If `net` is a module with parameters, run the observation on its device.
    first_param = next(net.parameters(), None) if hasattr(net, "parameters") else None
    if first_param is not None:
        obs_t = obs_t.to(first_param.device)

    # Inference only: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        mu, _, _ = net(obs_t)

    # Tensor.numpy() requires a CPU tensor that does not require grad.
    return mu.squeeze(0).detach().cpu().numpy()


### Test

In [9]:
# Start from a fresh observation, then draw one action from the policy.
obs, info = env.reset()
a_raw = sample_action(net, obs)
print("sampled action:", a_raw)

# Evaluate the same observation again to inspect the mean and value estimate.
# Note: this rebinds the notebook-level names dist, v, mu and log_std.
dist, v, mu, log_std = eval_policy(net, obs)
print("policy mean:", mu)
print("value estimate:", v)


sampled action: [-0.64780056 -1.1480635 ]
policy mean: [-0.34804317 -0.25869226]
value estimate: -0.5856422185897827


# MCTS

## Child Class (Edge stats for PUCT) and Node Class

In [4]:
from __future__ import annotations
from dataclasses import dataclass, field
import numpy as np
from typing import Optional, List, Dict, Any


# eq=False: the generated field-wise __eq__ would compare the ndarray `action`
# with ==, which raises ValueError for multi-element arrays (breaking
# `child in children`, list.index/remove) and would make instances unhashable.
# Edges are compared and hashed by identity instead.
@dataclass(eq=False)
class Child:
    """Edge from a node to one of its children, with per-edge statistics."""

    # Action stored for this edge.
    action: np.ndarray
    # Node this edge leads to.
    child_node: MCTSNode
    # Per-edge statistics. NOTE(review): presumably visit count, action-value
    # estimate and prior used by PUCT -- confirm against the search loop.
    N_sa: int = 0
    Q_sa: float = 0.0
    P_sa: float = 0.0

# eq=False: the generated field-wise __eq__ would compare the ndarray `state`
# with == (ValueError for multi-element arrays) and walk the parent/children
# reference cycle. Nodes are compared and hashed by identity instead.
@dataclass(eq=False)
class MCTSNode:
    """Search-tree node holding a state, its outgoing edges and cached values."""

    # State stored at this node.
    state: np.ndarray
    # Back-reference to the parent node; left out of the repr so printing a
    # node does not print its whole ancestry.
    parent: Optional[MCTSNode] = field(default=None, repr=False)
    # Action passed to the parent's add_child() when this node was created.
    parent_action: Optional[np.ndarray] = None

    # Counter used by is_fully_expanded().
    # NOTE(review): presumably the node visit count -- confirm.
    N: int = 0
    # Outgoing edges; a fresh list per node.
    children: List[Child] = field(default_factory=list)

    # NOTE(review): presumably cached network outputs for `state` -- confirm.
    mu: Optional[np.ndarray] = None
    log_std: Optional[np.ndarray] = None
    v: Optional[float] = None

    is_terminal: bool = False
    terminal_value: Optional[float] = None

    def is_fully_expanded(self, k: float, alpha: float) -> bool:
        """Return True when the node already has at least ``k * N**alpha`` children.

        With ``N == 0`` and ``alpha > 0`` the bound is 0, so a node with no
        children counts as fully expanded until ``N`` is increased.
        """
        K_max = k * (self.N ** alpha)
        return len(self.children) >= K_max

    def add_child(self, action: np.ndarray, child_state: np.ndarray, prior: float) -> "MCTSNode":
        """Create a child node for ``child_state`` and append its edge.

        The new node records this node as ``parent`` and ``action`` as
        ``parent_action``; the appended ``Child`` edge stores ``prior`` as
        ``P_sa``. Returns the new child node.
        """
        child_node = MCTSNode(state=child_state, parent=self, parent_action=action)
        self.children.append(Child(action=action, child_node=child_node, P_sa=prior))
        return child_node
