In [70]:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from torch.serialization import add_safe_globals

OBS_DIM = 126
ACT_DIM = 2
DEFAULT_PPO_DIMS = (512, 256, 128)
DEFAULT_SAC_DIMS = (512, 256, 128)
# NOTE(review): RainbowQNetwork and TD3Actor use DEFAULT_HIDDEN_DIMS as their default
# `hidden_dims`, but it was never defined anywhere in this notebook (it presumably leaked
# from a deleted cell), so "Restart & Run All" raised NameError at class-definition time.
# (512, 256, 128) matches the other defaults and the width the Rainbow checkpoint loads
# with under strict=True — confirm the value for TD3.
DEFAULT_HIDDEN_DIMS = (512, 256, 128)
# Discrete (steer, throttle) pairs used by the Rainbow agent when a checkpoint does not
# store its own "action_set".
DEFAULT_ACTION_SET = np.asarray(
    [
        [-0.35, 0.90],
        [-0.15, 0.80],
        [0.00, 0.80],
        [0.15, 0.80],
        [0.35, 0.90],
        [-0.20, 0.30],
        [0.00, 0.30],
        [0.20, 0.30],
        [0.00, 0.00],
        [0.00, -0.50],
    ],
    dtype=np.float32,
)


def _register_pickle_safe_globals() -> None:
    try:
        add_safe_globals([np.core.multiarray._reconstruct])
    except Exception:
        pass


_register_pickle_safe_globals()


def make_dummy_state(obs_dim: int = OBS_DIM) -> torch.Tensor:
    """Build an all-zeros float32 observation batch of shape ``(1, obs_dim)``."""
    batch_shape = (1, obs_dim)
    return torch.zeros(batch_shape, dtype=torch.float32)


def _infer_ppo_dims(state_dict: dict) -> tuple[int, ...]:
    dims = []
    layer_idx = 0
    while True:
        key = f"body.{layer_idx}.weight"
        weight = state_dict.get(key)
        if weight is None:
            break
        dims.append(int(weight.shape[0]))
        layer_idx += 2
    return tuple(dims)


def _infer_sac_dims(state_dict: dict) -> tuple[tuple[int, ...], int]:
    dims = []
    layer_idx = 0
    while True:
        key = f"net.{layer_idx}.weight"
        weight = state_dict.get(key)
        if weight is None:
            break
        dims.append(int(weight.shape[0]))
        layer_idx += 2
    if not dims:
        raise RuntimeError("Unable to infer SAC architecture from checkpoint")
    final_out = dims[-1]
    if final_out % 2 != 0:
        raise RuntimeError(f"Unexpected SAC final layer size {final_out}; expected even number for mean/log_std pairs")
    act_dim = final_out // 2
    hidden_dims = tuple(dims[:-1])
    return hidden_dims, act_dim


def _infer_rainbow_hidden_dims(q_state: dict) -> tuple[int, ...]:
    """Read RainbowQNetwork hidden widths from ``hidden_layers.{i}`` weights (empty tuple if none)."""
    dims = []
    idx = 0
    while True:
        # Noisy layers store ``weight_mu``; plain nn.Linear layers store ``weight``.
        weight = q_state.get(f"hidden_layers.{idx}.weight_mu")
        if weight is None:
            weight = q_state.get(f"hidden_layers.{idx}.weight")
        if weight is None:
            break
        dims.append(int(weight.shape[0]))
        idx += 1  # ModuleList indices are consecutive (no interleaved activations)
    return tuple(dims)


def load_rainbow_actor(ckpt_path: Path, *, device: str = "cpu"):
    """Load a Rainbow Q-network checkpoint for deterministic evaluation.

    The checkpoint must be a dict with the network weights under ``"q_net"``.
    Optional metadata keys (``action_set``, ``obs_dim``, ``atoms``, ``v_min``,
    ``v_max``, ``use_noisy``, ``noisy_sigma0``) override the defaults. Hidden-layer
    widths are read from the ``q_net`` weights, falling back to ``DEFAULT_PPO_DIMS``.

    Args:
        ckpt_path: path to the ``.pt`` checkpoint.
        device: device to map tensors and place the model on.

    Returns:
        ``(model, action_set, ckpt)``: the network in eval mode, the float32 action
        table indexed by the network's action dimension, and the raw checkpoint dict.

    Raises:
        RuntimeError: if the checkpoint is not a dict or has no ``"q_net"`` state dict.
    """
    # SECURITY: weights_only=False performs full unpickling (arbitrary code execution
    # on malicious files) — only load checkpoints from trusted sources.
    ckpt = torch.load(str(ckpt_path), map_location=device, weights_only=False)
    if not isinstance(ckpt, dict):
        raise RuntimeError(f"Unsupported Rainbow checkpoint format: type={type(ckpt).__name__}")
    q_state = ckpt.get("q_net")
    if not isinstance(q_state, dict):
        raise RuntimeError(f"Unsupported Rainbow checkpoint format: keys={list(ckpt.keys())}")

    action_set = np.asarray(ckpt.get("action_set", DEFAULT_ACTION_SET), dtype=np.float32)
    obs_dim = int(ckpt.get("obs_dim", OBS_DIM))
    atoms = int(ckpt.get("atoms", 51))
    v_min = float(ckpt.get("v_min", -50.0))
    v_max = float(ckpt.get("v_max", 50.0))
    noisy = bool(ckpt.get("use_noisy", True))
    sigma0 = float(ckpt.get("noisy_sigma0", 0.4))
    hidden_dims = _infer_rainbow_hidden_dims(q_state) or DEFAULT_PPO_DIMS
    model = RainbowQNetwork(
        obs_dim,
        action_set.shape[0],
        hidden_dims=hidden_dims,
        atoms=atoms,
        v_min=v_min,
        v_max=v_max,
        noisy=noisy,
        sigma0=sigma0,
    ).to(device).eval()
    model.load_state_dict(q_state, strict=True)
    # Eval mode already bypasses the noise terms; zeroing epsilon as well keeps the
    # network deterministic even if someone switches it back to train mode.
    for module in model.modules():
        if isinstance(module, NoisyLinear):
            module.weight_epsilon.zero_()
            module.bias_epsilon.zero_()
    return model, action_set, ckpt


def load_ppo_actor(ckpt_path: Path, *, device: str = "cpu") -> "PPOActor":
    """Load a PPO actor from a checkpoint.

    Accepts a pickled ``nn.Module``, a dict with the actor weights under ``"actor"``,
    or a bare state dict (a dict whose keys are all strings). Hidden widths come from
    ``body.*.weight`` (falling back to ``DEFAULT_PPO_DIMS``). ``obs_dim``/``act_dim``
    come from ``body.0.weight``/``mu_head.weight`` when present.

    Raises:
        RuntimeError: unrecognised checkpoint layout, or the checkpoint matched none
            of the actor's parameters (the actor would be purely random).

    Warns:
        RuntimeWarning: some actor parameters were absent and keep their random init.
    """
    import warnings

    # NOTE(review): on torch < 2.6 the default is weights_only=False (full unpickling);
    # only load checkpoints from trusted sources.
    ckpt = torch.load(str(ckpt_path), map_location=device)
    state = None
    if isinstance(ckpt, nn.Module):
        state = ckpt.state_dict()
    elif isinstance(ckpt, dict):
        actor_state = ckpt.get("actor")
        if isinstance(actor_state, dict):
            state = actor_state
        elif all(isinstance(k, str) for k in ckpt):
            state = ckpt
    if state is None:
        # ckpt is not necessarily a dict here, so only list keys when it is one.
        detail = f"keys={list(ckpt.keys())}" if isinstance(ckpt, dict) else f"type={type(ckpt).__name__}"
        raise RuntimeError(f"Unsupported PPO checkpoint format: {detail}")

    hidden_dims = _infer_ppo_dims(state) or DEFAULT_PPO_DIMS
    first_weight = state.get("body.0.weight")
    head_weight = state.get("mu_head.weight")
    obs_dim = int(first_weight.shape[1]) if first_weight is not None else OBS_DIM
    act_dim = int(head_weight.shape[0]) if head_weight is not None else ACT_DIM
    actor = PPOActor(obs_dim=obs_dim, hidden_dims=hidden_dims, act_dim=act_dim).to(device).eval()

    # strict=False tolerates extra keys (e.g. a critic in a bare dict), but missing
    # keys mean randomly initialised weights, which must not go unnoticed.
    result = actor.load_state_dict(state, strict=False)
    missing = set(result.missing_keys)
    if missing and missing == set(actor.state_dict()):
        raise RuntimeError(f"PPO checkpoint matched none of the actor parameters (missing: {sorted(missing)})")
    if missing:
        warnings.warn(
            f"PPO checkpoint is missing actor parameters {sorted(missing)}; they keep their random initialisation",
            RuntimeWarning,
        )
    return actor


def load_td3_actor(ckpt_path: Path, *, device: str = "cpu") -> "TD3Actor":
    """Load a TD3 actor from a checkpoint.

    Accepts a pickled ``nn.Module``, a dict wrapping the actor weights under one of
    ``"actor"``, ``"policy"``, ``"model"``, ``"actor_state_dict"``, ``"policy_state_dict"``,
    or a bare state dict (a dict whose keys are all strings). The architecture
    (observation, hidden and action widths) is read from the ``net.{i}.weight``
    tensors when present; otherwise ``TD3Actor()`` defaults are used.

    Raises:
        RuntimeError: unrecognised checkpoint layout, or the checkpoint matched none
            of the actor's parameters (the actor would be purely random).

    Warns:
        RuntimeWarning: some actor parameters were absent and keep their random init.
    """
    import warnings

    # NOTE(review): on torch < 2.6 the default is weights_only=False (full unpickling);
    # only load checkpoints from trusted sources.
    ckpt = torch.load(str(ckpt_path), map_location=device)
    state = None
    if isinstance(ckpt, nn.Module):
        state = ckpt.state_dict()
    elif isinstance(ckpt, dict):
        for key in ("actor", "policy", "model", "actor_state_dict", "policy_state_dict"):
            payload = ckpt.get(key)
            if isinstance(payload, dict):
                state = payload
                break
        if state is None and all(isinstance(k, str) for k in ckpt):
            state = ckpt
    if state is None:
        # ckpt is not necessarily a dict here, so only list keys when it is one.
        detail = f"keys={list(ckpt.keys())}" if isinstance(ckpt, dict) else f"type={type(ckpt).__name__}"
        raise RuntimeError(f"Unsupported TD3 checkpoint format: {detail}")

    # TD3Actor.net alternates Linear/activation, so Linear weights sit at even indices;
    # the last one is the action head.
    linear_weights = []
    idx = 0
    while True:
        weight = state.get(f"net.{idx}.weight")
        if weight is None:
            break
        linear_weights.append(weight)
        idx += 2
    if linear_weights:
        actor = TD3Actor(
            obs_dim=int(linear_weights[0].shape[1]),
            hidden_dims=tuple(int(w.shape[0]) for w in linear_weights[:-1]),
            act_dim=int(linear_weights[-1].shape[0]),
        )
    else:
        actor = TD3Actor()
    actor = actor.to(device).eval()

    # strict=False tolerates extra keys, but missing keys mean random weights.
    result = actor.load_state_dict(state, strict=False)
    missing = set(result.missing_keys)
    if missing and missing == set(actor.state_dict()):
        raise RuntimeError(f"TD3 checkpoint matched none of the actor parameters (missing: {sorted(missing)})")
    if missing:
        warnings.warn(
            f"TD3 checkpoint is missing actor parameters {sorted(missing)}; they keep their random initialisation",
            RuntimeWarning,
        )
    return actor


def load_sac_actor(ckpt_path: Path, *, device: str = "cpu") -> "GaussianPolicy":
    """Load a SAC Gaussian policy from a checkpoint dict holding its weights under ``"actor"``.

    The full architecture (observation width, hidden widths, action dimension) is
    inferred from the ``net.*.weight`` tensors, and the weights are loaded strictly.

    Raises:
        RuntimeError: if the checkpoint is not a dict with an ``"actor"`` state dict,
            or the architecture cannot be inferred.
    """
    # NOTE(review): on torch < 2.6 the default is weights_only=False (full unpickling);
    # only load checkpoints from trusted sources.
    ckpt = torch.load(str(ckpt_path), map_location=device)
    actor_state = ckpt.get("actor") if isinstance(ckpt, dict) else None
    if not isinstance(actor_state, dict):
        # ckpt is not necessarily a dict here, so only list keys when it is one.
        detail = f"keys={list(ckpt.keys())}" if isinstance(ckpt, dict) else f"type={type(ckpt).__name__}"
        raise RuntimeError(f"Unsupported SAC checkpoint format: {detail}")
    hidden_dims, act_dim = _infer_sac_dims(actor_state)
    # _infer_sac_dims raises unless net.0.weight exists, so this lookup is safe.
    obs_dim = int(actor_state["net.0.weight"].shape[1])
    # An empty hidden_dims is a legitimate single-layer policy; do not replace it with
    # the default stack, or the strict load below would fail on mismatched keys.
    policy = GaussianPolicy(obs_dim=obs_dim, hidden_dims=hidden_dims, act_dim=act_dim).to(device).eval()
    policy.load_state_dict(actor_state, strict=True)
    return policy



In [71]:
class NoisyLinear(nn.Module):
    """Linear layer with factorised Gaussian parameter noise.

    Learnable means (``weight_mu``/``bias_mu``) and scales (``weight_sigma``/
    ``bias_sigma``) are combined with non-learnable noise buffers
    (``weight_epsilon``/``bias_epsilon``). In training mode the effective weight
    is ``mu + sigma * epsilon``; in eval mode only ``mu`` is used.
    """

    def __init__(self, in_features: int, out_features: int, sigma0: float = 0.5) -> None:
        super().__init__()
        self.in_features = int(in_features)
        self.out_features = int(out_features)
        self.sigma0 = float(sigma0)
        w_shape = (self.out_features, self.in_features)
        # Registration order is kept as-is: it fixes the state_dict key order.
        self.weight_mu = nn.Parameter(torch.empty(w_shape))
        self.weight_sigma = nn.Parameter(torch.empty(w_shape))
        self.register_buffer("weight_epsilon", torch.zeros(w_shape))
        self.bias_mu = nn.Parameter(torch.empty(self.out_features))
        self.bias_sigma = nn.Parameter(torch.empty(self.out_features))
        self.register_buffer("bias_epsilon", torch.zeros(self.out_features))
        self.reset_parameters()
        self.reset_noise()

    def reset_parameters(self) -> None:
        """Uniform init for the means; constant init for the noise scales."""
        root_fan_in = np.sqrt(self.in_features)
        mu_bound = 1.0 / root_fan_in
        # Draw order (weight, then bias) is significant for seeded reproducibility.
        self.weight_mu.data.uniform_(-mu_bound, mu_bound)
        self.bias_mu.data.uniform_(-mu_bound, mu_bound)
        self.weight_sigma.data.fill_(self.sigma0 / root_fan_in)
        self.bias_sigma.data.fill_(self.sigma0 / np.sqrt(self.out_features))

    def reset_noise(self) -> None:
        """Resample the factorised noise: weight noise is the outer product of two vectors."""
        target_device = self.weight_mu.device
        noise_in = self._scale_noise(self.in_features, device=target_device)
        noise_out = self._scale_noise(self.out_features, device=target_device)
        self.weight_epsilon.copy_(noise_out.unsqueeze(1) * noise_in.unsqueeze(0))
        self.bias_epsilon.copy_(noise_out)

    @staticmethod
    def _scale_noise(size: int, *, device: torch.device) -> torch.Tensor:
        """Sample ``f(x) = sign(x) * sqrt(|x|)`` for ``x ~ N(0, 1)``."""
        raw = torch.randn(size, device=device)
        return torch.sign(raw) * torch.sqrt(torch.abs(raw))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if not self.training:
            return F.linear(input, self.weight_mu, self.bias_mu)
        noisy_weight = self.weight_mu + self.weight_sigma * self.weight_epsilon
        noisy_bias = self.bias_mu + self.bias_sigma * self.bias_epsilon
        return F.linear(input, noisy_weight, noisy_bias)


class RainbowQNetwork(nn.Module):
    """Dueling, distributional Q-network with optional noisy linear layers.

    A ReLU MLP trunk feeds two heads: a value head producing ``atoms`` logits and
    an advantage head producing ``n_actions * atoms`` logits. Advantages are
    mean-centred over actions and added to the value. ``forward`` returns the
    resulting per-action logits of shape ``(batch, n_actions, atoms)``.
    ``q_values`` turns them into expected returns over the fixed ``support``
    (``linspace(v_min, v_max, atoms)``).
    """

    def __init__(
        self,
        input_dim: int,
        n_actions: int,
        hidden_dims=DEFAULT_HIDDEN_DIMS,
        *,
        atoms: int = 51,
        v_min: float = -50.0,
        v_max: float = 50.0,
        noisy: bool = True,
        sigma0: float = 0.4,
    ) -> None:
        super().__init__()
        self.n_actions = int(n_actions)
        self.atoms = int(atoms)
        self.v_min = float(v_min)
        self.v_max = float(v_max)
        self.noisy = bool(noisy)
        self.sigma0 = float(sigma0)
        self.hidden_layers = nn.ModuleList()
        # Plain-list handle on the noisy layers (they are registered via the attributes above/below).
        self._noisy_layers: list[NoisyLinear] = []
        widths = [int(d) for d in hidden_dims]
        fan_in = input_dim
        for width in widths:
            self.hidden_layers.append(self._make_linear(fan_in, width))
            fan_in = width
        # Head creation order is preserved so seeded initialisation is unchanged.
        self.value_head = self._make_linear(fan_in, self.atoms)
        self.advantage_head = self._make_linear(fan_in, self.n_actions * self.atoms)
        self.register_buffer("support", torch.linspace(self.v_min, self.v_max, self.atoms))

    def _make_linear(self, in_dim: int, out_dim: int) -> nn.Module:
        """Build a NoisyLinear (tracked for noise resets) or a plain nn.Linear."""
        if not self.noisy:
            return nn.Linear(in_dim, out_dim)
        noisy_layer = NoisyLinear(in_dim, out_dim, sigma0=self.sigma0)
        self._noisy_layers.append(noisy_layer)
        return noisy_layer

    def reset_noise(self) -> None:
        """Resample noise in every noisy layer (no-op when ``noisy`` is False)."""
        if self.noisy:
            for noisy_layer in self._noisy_layers:
                noisy_layer.reset_noise()

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        features = obs
        for hidden in self.hidden_layers:
            features = F.relu(hidden(features))
        value_logits = self.value_head(features).view(-1, 1, self.atoms)
        adv_logits = self.advantage_head(features).view(-1, self.n_actions, self.atoms)
        centred_adv = adv_logits - adv_logits.mean(dim=1, keepdim=True)
        return value_logits + centred_adv

    def q_values(self, obs: torch.Tensor) -> torch.Tensor:
        """Expected return per action: softmax over atoms, weighted by ``support``."""
        atom_probs = self.forward(obs).softmax(dim=-1)
        return (atom_probs * self.support).sum(dim=-1)


In [72]:

class PPOActor(nn.Module):
    """Gaussian PPO policy: ReLU MLP body -> mean head, plus a state-independent log-std.

    ``forward`` returns ``(mu, std)`` where ``std = exp(clamp(log_std, -5, 2))``.
    """

    def __init__(self, obs_dim: int = OBS_DIM, hidden_dims=DEFAULT_PPO_DIMS, act_dim: int = ACT_DIM) -> None:
        super().__init__()
        modules = []
        in_width = obs_dim
        for out_width in hidden_dims:
            modules += [nn.Linear(in_width, out_width), nn.ReLU()]
            in_width = out_width
        self.body = nn.Sequential(*modules)
        self.mu_head = nn.Linear(in_width, act_dim)
        self.log_std = nn.Parameter(torch.zeros(act_dim))

    def forward(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        mean = self.mu_head(self.body(obs))
        bounded_log_std = self.log_std.clamp(-5.0, 2.0)
        return mean, bounded_log_std.exp()



In [73]:
class TD3Actor(nn.Module):
    """Deterministic TD3 policy: a ReLU MLP whose final ``tanh`` keeps actions in [-1, 1]."""

    def __init__(self, obs_dim: int = OBS_DIM, hidden_dims=DEFAULT_HIDDEN_DIMS, act_dim: int = ACT_DIM) -> None:
        super().__init__()
        stack = []
        width_in = obs_dim
        for width_out in hidden_dims:
            stack.append(nn.Linear(width_in, width_out))
            stack.append(nn.ReLU())
            width_in = width_out
        stack += [nn.Linear(width_in, act_dim), nn.Tanh()]
        self.net = nn.Sequential(*stack)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        action = self.net(obs)
        return action


In [74]:

class GaussianPolicy(nn.Module):
    """SAC actor: a ReLU MLP whose last layer emits ``2 * act_dim`` values.

    ``forward`` splits them into ``(mu, log_std)``, each of width ``act_dim``.
    ``log_std`` is returned raw; no clamping is applied here.
    """

    def __init__(self, obs_dim: int = OBS_DIM, hidden_dims=DEFAULT_SAC_DIMS, act_dim: int = ACT_DIM) -> None:
        super().__init__()
        stack = []
        width_in = obs_dim
        for width_out in hidden_dims:
            stack.extend((nn.Linear(width_in, width_out), nn.ReLU()))
            width_in = width_out
        stack.append(nn.Linear(width_in, 2 * act_dim))
        self.net = nn.Sequential(*stack)
        self.act_dim = int(act_dim)

    def forward(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        mean, log_std = self.net(obs).split(self.act_dim, dim=-1)
        return mean, log_std



In [80]:
# Rainbow smoke test: greedy discrete action for an all-zeros observation.
device = "cpu"
ckpt_path = Path("r_dqn_best_magic-sweep-1.pt")
if ckpt_path.exists():
    q_net, action_set, meta = load_rainbow_actor(ckpt_path, device=device)
    obs = make_dummy_state(int(meta.get("obs_dim", OBS_DIM))).to(device)
    with torch.no_grad():
        q_vals = q_net.q_values(obs)
    best_idx = int(q_vals.argmax(dim=1))
    print(f"Loaded Rainbow Q-network from {ckpt_path}")
    print("Best discrete action index:", best_idx)
    print("Action (steer, throttle):", action_set[best_idx])
else:
    print(f"Rainbow checkpoint missing: {ckpt_path.resolve()}")


Loaded Rainbow Q-network from r_dqn_best_magic-sweep-1.pt
Best discrete action index: 1
Action (steer, throttle): [-0.15  0.8 ]


In [82]:
# PPO smoke test: mean action for an all-zeros observation.
# PPOActor clamps log_std to [-5, 2], so the reported std can be at most e^2 ≈ 7.39.
device = "cpu"
ckpt_path = Path("ppo_best_vague-sweep-1.pt")
if ckpt_path.exists():
    actor = load_ppo_actor(ckpt_path, device=device)
    obs = make_dummy_state().to(device)
    with torch.no_grad():
        mu, std = actor(obs)
        action = mu.tanh().cpu().numpy()[0]
    print(f"Loaded PPO checkpoint from {ckpt_path}")
    print("Mean action (tanh-squashed):", action)
    print("Std dev:", std.cpu().numpy())
else:
    print(f"PPO checkpoint missing: {ckpt_path.resolve()}")


Loaded PPO checkpoint from ppo_best_vague-sweep-1.pt
Mean action (tanh-squashed): [ 0.99759793 -0.9907341 ]
Std dev: [7.389056 4.357516]


In [84]:
# TD3 smoke test: TD3Actor already ends in tanh, so its output needs no extra squashing.
device = "cpu"
ckpt_path = Path("td3_gaplock_young-sweep-2.pt")
if ckpt_path.exists():
    actor = load_td3_actor(ckpt_path, device=device)
    obs = make_dummy_state().to(device)
    with torch.no_grad():
        action = actor(obs).cpu().numpy()[0]
    print(f"Loaded TD3 checkpoint from {ckpt_path}")
    print("Action (steer, throttle):", action)
else:
    print(f"TD3 checkpoint missing: {ckpt_path.resolve()}")


Loaded TD3 checkpoint from td3_gaplock_young-sweep-2.pt
Action (steer, throttle): [-0.9998891 -0.9999323]


In [86]:
# SAC smoke test: tanh of the policy mean; log_std is printed raw (GaussianPolicy does not clamp it).
device = "cpu"
ckpt_path = Path("sac_best_honest-sweep-1.pt")
if ckpt_path.exists():
    policy = load_sac_actor(ckpt_path, device=device)
    obs = make_dummy_state().to(device)
    with torch.no_grad():
        mu, log_std = policy(obs)
        action = mu.tanh().cpu().numpy()[0]
    print(f"Loaded SAC checkpoint from {ckpt_path}")
    print("Mean action (tanh-squashed):", action)
    print("Log std:", log_std.cpu().numpy())
else:
    print(f"SAC checkpoint missing: {ckpt_path.resolve()}")


Loaded SAC checkpoint from sac_best_honest-sweep-1.pt
Mean action (tanh-squashed): [-0.9682152  -0.99223673]
Log std: [[-0.69195795 -0.8922051 ]]
