In [8]:
import torch
import torch.nn as nn
from pathlib import Path

INPUT_DIM = 116  # 108 lidar beams + attacker (x, y, theta, flag) + target (x, y, theta, flag)
HID = 256        # width of each of the two hidden layers
OUT_DIM = 2      # action dimension


class ActorMLP(nn.Module):
    """Actor network: Linear -> ReLU -> Linear -> ReLU -> Linear.

    Sizes are ``in_dim -> hid -> hid -> out_dim``. All layers live in
    ``self.net`` (an ``nn.Sequential``), so the checkpoint keys are
    ``net.0.*``, ``net.2.*`` and ``net.4.*``. Keep that layout stable or
    previously saved weights will stop loading.

    The final layer is a plain ``nn.Linear``, so outputs are unbounded.
    NOTE(review): if the policy was trained with a tanh-squashed output,
    the caller must apply the squashing; confirm against the training code.
    """

    def __init__(self, in_dim=INPUT_DIM, hid=HID, out_dim=OUT_DIM):
        super().__init__()
        widths = (in_dim, hid, hid)
        layers = []
        # Hidden stack: one Linear + ReLU pair per consecutive width pair.
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Output head, no activation.
        layers.append(nn.Linear(hid, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map inputs with trailing dimension ``in_dim`` to ``out_dim`` outputs."""
        return self.net(x)

def load_actor_flex(ckpt_path, model: nn.Module, device="cpu"):
    """Load actor weights into *model* from a checkpoint of unknown layout.

    Three layouts are tried, in order:

    1. A pickled ``nn.Module``: its ``state_dict()`` is loaded strictly.
    2. A plain state_dict (all keys are ``str`` and at least one starts with
       ``"net."`` or contains ``"weight"``): loaded strictly.
    3. A container dict (e.g. a TD3/PPO training checkpoint) that holds the
       actor's state_dict under one of ``"actor"``, ``"policy"``,
       ``"model"``, ``"actor_state_dict"`` or ``"policy_state_dict"``.
       A candidate is accepted only if it provides *every* key of
       ``model.state_dict()``. Extra keys in the candidate are ignored.

    Args:
        ckpt_path: File passed to ``torch.load``.
        model: Module whose parameters are overwritten in place.
        device: ``map_location`` forwarded to ``torch.load``.

    Returns:
        *model* (the same object), with weights loaded.

    Raises:
        RuntimeError: if the checkpoint matches none of the layouts, or a
            strict load in case 1/2 fails (missing/unexpected keys, shape
            mismatch).
    """
    # NOTE(review): newer PyTorch releases default torch.load to
    # weights_only=True, which refuses pickled nn.Module objects (case 1).
    # Passing weights_only=False unpickles arbitrary code, so only do that
    # for trusted files.
    ckpt = torch.load(ckpt_path, map_location=device)

    # Case 1: full pickled module.
    if isinstance(ckpt, nn.Module):
        model.load_state_dict(ckpt.state_dict())
        return model

    if not isinstance(ckpt, dict):
        raise RuntimeError(
            "Unrecognized checkpoint structure: expected an nn.Module or a "
            f"dict, got {type(ckpt).__name__}"
        )

    # Case 2: plain state_dict (tensor names -> tensors).
    if all(isinstance(k, str) for k in ckpt) and any(
        k.startswith("net.") or "weight" in k for k in ckpt
    ):
        model.load_state_dict(ckpt)
        return model

    # Case 3: container dict (TD3/PPO style) - try common keys.
    # strict=False alone would silently accept a candidate that matches *no*
    # parameter and leave the model randomly initialised. So every model key
    # must be present up front; only extra keys are tolerated.
    expected = set(model.state_dict())
    rejected = []
    for key in ("actor", "policy", "model", "actor_state_dict", "policy_state_dict"):
        candidate = ckpt.get(key)
        if not isinstance(candidate, dict):
            continue
        missing = expected.difference(candidate)
        if missing:
            rejected.append(f"{key!r} is missing {sorted(missing)}")
            continue
        try:
            model.load_state_dict(candidate, strict=False)
        except RuntimeError as err:
            # e.g. shape mismatch. Matching tensors may already have been
            # copied into the model, but we only return on full success.
            rejected.append(f"{key!r} failed to load: {err}")
            continue
        return model

    # No layout matched: report the keys (and why candidates failed) to help debug.
    detail = f"; rejected candidates: {'; '.join(rejected)}" if rejected else ""
    raise RuntimeError(
        f"Unrecognized checkpoint structure. Top-level keys: {list(ckpt.keys())}{detail}"
    )

def make_state(
    lidar108, attacker_xyz_theta_coll, target_xyz_theta_coll
):
    """
    Build the policy input as a float32 tensor of shape (1, 116).

    Layout, in order: 108 lidar values, then the attacker's 4 values, then
    the target's 4 values.

    - lidar108: anything ``np.asarray`` accepts with exactly 108 elements
    - attacker_xyz_theta_coll: 4 values (x, y, theta, collision_flag)
    - target_xyz_theta_coll: 4 values (x, y, theta, collision_flag)

    A segment with the wrong number of elements raises ValueError.
    """
    import numpy as np

    segments = (
        (lidar108, 108),
        (attacker_xyz_theta_coll, 4),
        (target_xyz_theta_coll, 4),
    )
    flat = np.concatenate(
        [np.asarray(values, dtype=np.float32).reshape(count) for values, count in segments]
    )
    # Add the leading batch axis; from_numpy shares memory with the array.
    return torch.from_numpy(flat.reshape(1, -1))

if __name__ == "__main__":
    device = "cpu"  # switch to "cuda" when a GPU is available and wanted

    # Checkpoint files; adjust the paths for your setup.
    td3_path = Path("td3_gaplock.pt")  # e.g. "/mnt/data/td3_gaplock.pt"
    ppo_path = Path("ppo_gaplock.pt")  # optional: only tried if the file exists

    # Model skeleton (ReLU MLP 116-256-256-2, per Aaron's confirmation),
    # then filled from the TD3 checkpoint. Module.eval() returns the module
    # itself, so it can be chained.
    actor = ActorMLP().to(device).eval()
    actor = load_actor_flex(td3_path, actor, device=device).eval()

    # Placeholder observation -- replace with real lidar + VICON readings.
    lidar108 = [0.1] * 108          # 108 downsampled beams
    attacker = [0.0, 0.0, 0.0, 0]   # x, y, theta, collision_flag (0/1)
    target = [1.0, 0.3, 0.1, 0]     # x, y, theta, collision_flag (0/1)
    state = make_state(lidar108, attacker, target).to(device)  # (1, 116)

    with torch.no_grad():
        action = actor(state)  # (1, 2)
    print("TD3 action:", action.cpu().numpy())

    # Optional: push the same observation through the PPO checkpoint. This
    # assumes its actor has the same architecture. Any failure is reported,
    # not raised.
    if ppo_path.exists():
        ppo_actor = ActorMLP().to(device)
        try:
            ppo_actor = load_actor_flex(ppo_path, ppo_actor, device=device).eval()
            with torch.no_grad():
                ppo_action = ppo_actor(state)
            print("PPO action:", ppo_action.cpu().numpy())
        except Exception as e:
            print("PPO load failed:", e)


TD3 action: [[-0.03003868  0.00363896]]
PPO action: [[-0.00662935 -0.04710494]]
