In [1]:
# model_and_tokenizer.py
import torch
import torch.nn as nn

# adjust the import path to your project layout if needed
from src.models.modules.modeling_pact import PACTBase
# PACTBase wraps PACTTokenizer internally and builds the GPT backbone. :contentReference[oaicite:6]{index=6}

def build_pact_model(state_dim=38, action_dim=2, ctx_tokens=16,
                     n_embd=128, n_layer=4, n_head=8):
    """
    ctx_tokens = number of tokens the transformer sees (must be even).
    Because GPT block_size = 2 * seq_len, set seq_len = ctx_tokens // 2.  :contentReference[oaicite:7]{index=7}
    """
    assert ctx_tokens % 2 == 0, "ctx_tokens must be even"
    seq_len = ctx_tokens // 2

    gpt_config = dict(
        n_embd=n_embd, n_layer=n_layer, n_head=n_head,
        embd_pdrop=0.1, resid_pdrop=0.1, attn_pdrop=0.1,
        seq_len=seq_len
    )

    # Different tokenizers for "state" vs "action"
    # state -> VectorStateTokenizer (mlp_state)
    # action -> ActionTokenizer (mlp_action)
    # The action tokenizer handles both continuous and discrete; we use "continuous" here. :contentReference[oaicite:8]{index=8} :contentReference[oaicite:9]{index=9}
    input_config = {
        "state": {
            "tokenizer": "mlp_state",
            "input_type": "continuous",
            "tokenizer_kwargs": {"state_dim": state_dim, "hidden": [256, 256], "use_ln": True},
        },
        "action": {
            "tokenizer": "mlp_action",
            "input_type": "continuous",  # change to "discrete" if your actions are ids
            "tokenizer_kwargs": {"action_dim": action_dim, "hidden": [128, 128], "use_ln": True},
        },
    }

    model = PACTBase(gpt_config=gpt_config, input_config=input_config)
    return model


  rank_zero_deprecation(_deprecate_registry_message)


In [5]:
from pact_online_agent_mpc_qp_gif_v3 import run_episode_and_save_png

# Runs one episode with the new architecture if available (falls back to nominal-only),
# and saves a static PNG similar to sa_di_4obs_from_init.png
out = run_episode_and_save_png(
    path="traj_new.png",
    device="cpu",   # switch to "cuda" when your CUDA is OK
    steps=256,
    H=8,            # MPC-like horizon in the new runner
    area_size=6.0,
    draw_trail=True,
    linewidth=2.0,
    seed=0
)
print(out)




{'steps': 256, 'png': 'traj_new.png'}


In [2]:
x, y = obs['state'][:2] if available
gx, gy = obs['goal'][:2] if available


SyntaxError: expected 'else' after 'if' expression (1813692315.py, line 1)

In [2]:
# context.py
import torch

def make_context_window(states, actions, t, ctx_tokens=16, pad_mode="repeat_first"):
    """
    states:  (N, S) float32
    actions: (N, A) float32  (or (N,) long for discrete actions)
    t:       current time index (0-based), we will produce a window ending at t (inclusive for state)
    Returns:
      state_seq:  (1, T, S)
      action_seq: (1, T, A) or (1, T) for discrete
    """
    assert ctx_tokens % 2 == 0, "ctx_tokens must be even"
    T = ctx_tokens // 2

    S = states.shape[-1]
    A = actions.shape[-1] if actions.ndim == 2 else 1

    # indices we want
    s_start, s_end = t - T + 1, t + 1     # [t-T+1, t] inclusive -> python slice [s_start:s_end)
    a_start, a_end = t - T,     t         # [t-T, t-1] inclusive  -> python slice [a_start:a_end)

    def left_pad_take(x, start, end, feature_dim, is_action=False):
        # x: (N, D)
        N = x.shape[0]
        if start >= 0:
            out = x[start:end]
        else:
            need = -start
            if pad_mode == "zeros":
                pad = torch.zeros(need, feature_dim, dtype=x.dtype, device=x.device)
            else:  # repeat_first
                pad = x[0:1].expand(need, feature_dim).clone()
            out = torch.cat([pad, x[0:end]], dim=0)
        # if we run past the end (rare at episode end), right-pad similarly
        if out.shape[0] < (end - start):
            need = (end - start) - out.shape[0]
            if pad_mode == "zeros":
                pad = torch.zeros(need, feature_dim, dtype=x.dtype, device=x.device)
            else:
                pad = x[-1:].expand(need, feature_dim).clone()
            out = torch.cat([out, pad], dim=0)
        return out

    state_win  = left_pad_take(states,  s_start, s_end, S)
    if actions.ndim == 2:  # continuous
        action_win = left_pad_take(actions, a_start, a_end, A)
    else:                  # discrete ids
        # treat as (N,1) then squeeze back later
        action_win = left_pad_take(actions.view(-1, 1), a_start, a_end, 1).view(-1)

    # add batch dim
    state_seq  = state_win.unsqueeze(0)                 # (1, T, S)
    action_seq = action_win.unsqueeze(0)                # (1, T, A) or (1, T)
    return state_seq, action_seq


In [11]:
import numpy as np
data = np.load("pact_dataset.npz", allow_pickle=True)
data.files

['episode',
 'timestep',
 'state',
 'goal',
 'lidar',
 'full_obs',
 'nominal_action',
 'taken_action',
 'perturbed',
 'mode_excited',
 'excitation_direction',
 'reward',
 'cost',
 'safe_mask',
 'unsafe_mask',
 'collision_mask',
 'finish_mask',
 'next_state',
 'next_full_obs',
 'next_safe_mask',
 'metadata']

In [19]:
data['full_obs'].shape, data['taken_action'], data['nominal_action'], data['safe_mask']

((512, 38),
 array([[-7.4633515e-01,  2.6657838e-01],
        [-6.8383753e-01,  2.4425530e-01],
        [-6.1303717e-01,  1.6914645e-01],
        ...,
        [-2.5319010e-03, -4.0341003e-04],
        [-2.4128556e-03, -3.6737768e-04],
        [-3.3637848e-02,  1.4022903e-02]], shape=(512, 2), dtype=float32),
 array([[-7.4633515e-01,  2.6657838e-01],
        [-6.8383753e-01,  2.4425530e-01],
        [-6.2352717e-01,  2.2271338e-01],
        ...,
        [-2.5319010e-03, -4.0341003e-04],
        [-2.4128556e-03, -3.6737768e-04],
        [-2.3188367e-03, -3.4411764e-04]], shape=(512, 2), dtype=float32),
 array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True, False, False, False, False,
        False, False, False, False, F

In [4]:
# call_tokenizer_only.py
import torch
from src.models.modules.tokenizer_pact import PACTTokenizer  # same module PACTBase uses internally
# from model_and_tokenizer import build_pact_model

# Build the model to get consistent n_embd and tokenizers
pact = build_pact_model(state_dim=38, action_dim=2, ctx_tokens=16)
tok: PACTTokenizer = pact.tokenizer  # grab the tokenizer PACTBase created

# Suppose you already built a window:
# state_seq:  (B, T, 38), action_seq: (B, T, 2) or (B, T) discrete
# For this example we make dummy data:
B, T, S, A = 4, 8, 38, 2
state_seq  = torch.randn(B, T, S)
action_seq = torch.randn(B, T, A)

emb = tok({"state": state_seq, "action": action_seq})
print({k: v.shape for k,v in emb.items()})
# {'state': (B, T, n_embd), 'action': (B, T, n_embd)}


{'state': torch.Size([4, 8, 128]), 'action': torch.Size([4, 8, 128])}


In [69]:
# get_embeddings_from_npz.py
import numpy as np
import torch

# --- adjust paths to your repo if needed ---
from src.models.modules.modeling_pact import PACTBase          # backbone (tokenizer + GPT)  :contentReference[oaicite:1]{index=1}
from src.models.modules.tokenizer_pact import PACTTokenizer     # tokenizer API               :contentReference[oaicite:2]{index=2}

# ---------------------------
# 1) Build the PACT backbone
# ---------------------------
def build_pact_model(state_dim=38, action_dim=2, ctx_tokens=16,
                     n_embd=128, n_layer=4, n_head=8):
    assert ctx_tokens % 2 == 0, "ctx_tokens must be even"
    seq_len = ctx_tokens // 2  # 16 tokens -> 8 pairs; GPT block_size = 2*seq_len  :contentReference[oaicite:3]{index=3}

    gpt_config = dict(
        n_embd=n_embd, n_layer=n_layer, n_head=n_head,
        embd_pdrop=0.1, resid_pdrop=0.1, attn_pdrop=0.1,
        seq_len=seq_len
    )

    # Different tokenizers for state vs action (both continuous here)
    input_config = {
        "state": {
            "tokenizer": "mlp_state",     # VectorStateTokenizer under the hood
            "input_type": "continuous",
            "tokenizer_kwargs": {"state_dim": state_dim, "hidden": [256, 256], "use_ln": True},
        },
        "action": {
            "tokenizer": "mlp_action",    # ActionTokenizer under the hood
            "input_type": "continuous",   # set to "discrete" if your actions are ids
            "tokenizer_kwargs": {"action_dim": action_dim, "hidden": [128, 128], "use_ln": True},
        },
    }

    # PACTBase = {tokenizer, positional embeddings, GPT(minGPT)}  :contentReference[oaicite:4]{index=4}
    model = PACTBase(gpt_config=gpt_config, input_config=input_config)
    model.eval()
    return model

# --------------------------------------------------------
# 2) Context slicing to realize [a_{t-8},..., a_{t-1}, s_t]
# --------------------------------------------------------
def make_context_window(states, actions, t, ctx_tokens=16, pad_mode="repeat_first"):
    """
    states : (L, S)
    actions: (L, A)   (or (L,) for discrete)
    Build 8 pairs (T=ctx_tokens//2):
        states  -> [s_{t-7}, ..., s_t]      length T
        actions -> [a_{t-8}, ..., a_{t-1}]  length T
    """
    assert ctx_tokens % 2 == 0
    T = ctx_tokens // 2
    S = states.shape[-1]
    A = actions.shape[-1] if actions.ndim == 2 else 1

    s_start, s_end = t - T + 1, t + 1
    a_start, a_end = t - T,     t

    def left_pad_take(x, start, end, feat):
        # x: (L, feat); returns exactly (end-start, feat)
        if start >= 0:
            out = x[start:end]
        else:
            need = -start
            pad = np.zeros((need, feat), dtype=x.dtype) if pad_mode == "zeros" else np.repeat(x[0:1], need, axis=0)
            out = np.concatenate([pad, x[0:end]], axis=0)
        if out.shape[0] < (end - start):  # right-pad at episode tail if needed
            need = (end - start) - out.shape[0]
            pad = np.zeros((need, feat), dtype=x.dtype) if pad_mode == "zeros" else np.repeat(x[-1:], need, axis=0)
            out = np.concatenate([out, pad], axis=0)
        return out

    s_win = left_pad_take(states,  s_start, s_end, S)
    if actions.ndim == 2:
        a_win = left_pad_take(actions, a_start, a_end, A)
    else:
        a_win = left_pad_take(actions.reshape(-1,1), a_start, a_end, 1).reshape(-1)

    # add batch dim -> (1, T, ·)
    s_win = torch.as_tensor(s_win, dtype=torch.float32).unsqueeze(0)
    a_win = torch.as_tensor(a_win, dtype=torch.float32 if actions.ndim == 2 else torch.long).unsqueeze(0)
    return s_win, a_win

# ------------------------------------------------------------
# 3) Utilities to split interleaved transformer output streams
# ------------------------------------------------------------
def split_state_action_embeddings(out_embd):
    """
    out_embd: (B, 2*T, d) interleaved as [state_0, action_0, state_1, action_1, ...]
    returns:  state_out:(B,T,d), action_out:(B,T,d)
    """
    return out_embd[:, 0::2, :], out_embd[:, 1::2, :]

# ---------------------------------------------------------
# 4) Load npz, group by episode, and compute embeddings
# ---------------------------------------------------------
def get_embeddings_from_npz(
    npz_path,
    state_key="full_obs",                 # (N, 38)
    action_key="taken_action",            # or "nominal_action" (N, 2)
    ctx_tokens=16,
    n_embd=128
):
    data = np.load(npz_path, allow_pickle=True)

    # group into (E, L, ·) using `episode` id
    ep_ids = data["episode"].astype(int)          # (N,)
    episodes = np.unique(ep_ids)
    E = len(episodes)
    # derive per-episode length robustly
    per_counts = [np.sum(ep_ids == e) for e in episodes]
    assert len(set(per_counts)) == 1, "episodes must have equal length for this simple loader"
    L = per_counts[0]

    states_all  = data[state_key].astype(np.float32)     # (N, 38)
    actions_all = data[action_key].astype(np.float32)    # (N, 2)

    # reshape to (E, L, ·)
    S = states_all.shape[-1]
    A = actions_all.shape[-1]
    states  = np.stack([states_all [ep_ids == e] for e in episodes], axis=0)  # (E, L, 38)
    actions = np.stack([actions_all[ep_ids == e] for e in episodes], axis=0)  # (E, L, 2)

    # build model (state tokenizer != action tokenizer)
    pact = build_pact_model(state_dim=S, action_dim=A, ctx_tokens=ctx_tokens, n_embd=n_embd)

    # we’ll build windows for every (episode, t)
    T = ctx_tokens // 2
    batch_state = []
    batch_action = []
    map_index = []  # (episode_id, t) for each row in the batch
    for e in range(E):
        for t in range(L):
            s_win, a_win = make_context_window(states[e], actions[e], t, ctx_tokens=ctx_tokens)
            batch_state.append(s_win)
            batch_action.append(a_win)
            map_index.append((episodes[e], t))

    batch_state  = torch.cat(batch_state,  dim=0)  # (B, T, 38)  with B=E*L
    batch_action = torch.cat(batch_action, dim=0)  # (B, T,  2)

    # 4a) TOKENIZER-ONLY: raw token embeddings for state and action (pre-transformer)
    # PACTTokenizer.forward expects a dict with tensors shaped (B, T, ·),
    # reshapes to each tokenizer's batch_input_size, runs the module, reshapes back to (B,T,n_embd). :contentReference[oaicite:5]{index=5}
    tok = pact.tokenizer
    tok_res = tok({"state": batch_state, "action": batch_action})
    state_tok = tok_res["state"]   # (B, T, n_embd)
    action_tok = tok_res["action"] # (B, T, n_embd)

    # 4b) FULL TRANSFORMER: contextualized embeddings
    # PACTBase.forward expects {"state":(B,T,S), "action":(B,T,A)} and returns:
    #   out_embd:(B, 2*T, n_embd) interleaved [state_0, action_0, ...]
    #   state_tokens:(B, T, n_embd) (pre-transformer state tokens)                                     :contentReference[oaicite:6]{index=6}
    with torch.no_grad():
        out_embd, state_tokens_in = pact({"state": batch_state, "action": batch_action})
    state_out, action_out = split_state_action_embeddings(out_embd)  # (B,T,d) each

    return {
        "map_index": map_index,              # list of (episode_id, t)
        "state_tokens": state_tok,           # pre-transformer
        "action_tokens": action_tok,         # pre-transformer
        "state_ctx": state_out,              # post-transformer, contextualized
        "action_ctx": action_out,            # post-transformer, contextualized
        "last_state_ctx": state_out[:, -1],  # (B, d): embedding of s_t  (useful for policy head later)
        "last_action_ctx": action_out[:, -1] # (B, d): embedding of a_{t-1} (useful for critic head later)
    }

if __name__ == "__main__":
    res = get_embeddings_from_npz("pact_dataset.npz",
                                  state_key="full_obs",
                                  action_key="taken_action",   # or "nominal_action"
                                  ctx_tokens=16,
                                  n_embd=128)
    print("State tokens (pre)      :", tuple(res["state_tokens"].shape))   # (B, 8, d)
    print("Action tokens (pre)     :", tuple(res["action_tokens"].shape))  # (B, 8, d)
    print("State ctx (post)        :", tuple(res["state_ctx"].shape))      # (B, 8, d)
    print("Action ctx (post)       :", tuple(res["action_ctx"].shape))     # (B, 8, d)
    print("Last state ctx (for π)  :", tuple(res["last_state_ctx"].shape)) # (B, d)
    print("Last action ctx (critic):", tuple(res["last_action_ctx"].shape))# (B, d)


State tokens (pre)      : (512, 8, 128)
Action tokens (pre)     : (512, 8, 128)
State ctx (post)        : (512, 8, 128)
Action ctx (post)       : (512, 8, 128)
Last state ctx (for π)  : (512, 128)
Last action ctx (critic): (512, 128)


In [70]:
res

{'map_index': [(np.int64(0), 0),
  (np.int64(0), 1),
  (np.int64(0), 2),
  (np.int64(0), 3),
  (np.int64(0), 4),
  (np.int64(0), 5),
  (np.int64(0), 6),
  (np.int64(0), 7),
  (np.int64(0), 8),
  (np.int64(0), 9),
  (np.int64(0), 10),
  (np.int64(0), 11),
  (np.int64(0), 12),
  (np.int64(0), 13),
  (np.int64(0), 14),
  (np.int64(0), 15),
  (np.int64(0), 16),
  (np.int64(0), 17),
  (np.int64(0), 18),
  (np.int64(0), 19),
  (np.int64(0), 20),
  (np.int64(0), 21),
  (np.int64(0), 22),
  (np.int64(0), 23),
  (np.int64(0), 24),
  (np.int64(0), 25),
  (np.int64(0), 26),
  (np.int64(0), 27),
  (np.int64(0), 28),
  (np.int64(0), 29),
  (np.int64(0), 30),
  (np.int64(0), 31),
  (np.int64(0), 32),
  (np.int64(0), 33),
  (np.int64(0), 34),
  (np.int64(0), 35),
  (np.int64(0), 36),
  (np.int64(0), 37),
  (np.int64(0), 38),
  (np.int64(0), 39),
  (np.int64(0), 40),
  (np.int64(0), 41),
  (np.int64(0), 42),
  (np.int64(0), 43),
  (np.int64(0), 44),
  (np.int64(0), 45),
  (np.int64(0), 46),
  (np.int6

In [71]:
res["state_ctx"][128,1]

tensor([ 2.0480, -0.0701, -0.5216, -0.1725, -0.3298,  0.4476,  1.1527,  1.1839,
        -1.2147,  0.6124,  1.0262,  0.4195, -0.5923, -1.0895, -1.2064,  2.7057,
         0.2375,  0.3636,  0.1373,  0.7642, -0.2678,  0.7462, -0.8305,  1.2875,
        -0.3518,  0.4703, -0.9122, -0.4408, -0.6101, -0.2000, -1.0604, -0.4307,
        -0.8507,  1.5148,  1.0204,  2.6176, -0.1205,  1.2434, -1.2164, -2.6246,
         0.3012, -0.6954,  0.9459,  1.3253, -2.5194,  0.5416, -0.5279, -0.6676,
        -0.3461, -0.7930, -0.4879,  1.2265, -0.2988,  0.6953, -1.1791,  1.3406,
        -0.4869, -0.0754,  0.0365,  0.0827, -1.4684, -0.3380,  0.9199,  0.8276,
        -1.1835, -0.7424,  0.1268,  1.8562, -1.5769,  0.0653,  0.7459, -0.3673,
        -0.4569,  0.0767,  0.3209, -0.3831, -0.1016, -0.4221,  1.2457,  1.1714,
         0.6332,  0.0674,  0.1465, -1.3814,  1.5664,  0.4780,  0.1390,  2.5111,
        -0.2360, -0.1777, -1.1005,  0.1730,  0.4403,  0.1814,  0.9782,  0.0600,
         0.6889,  1.2567, -0.1481, -0.61

In [72]:
res["state_ctx"][500,1]

tensor([ 1.9642e+00,  2.8456e-02, -3.4041e-01, -2.1633e-01, -4.0707e-01,
         2.7337e-01,  1.1559e+00,  1.3840e+00, -1.2088e+00,  5.0117e-01,
         1.1859e+00,  3.5788e-01, -4.7555e-01, -9.6765e-01, -9.4968e-01,
         2.6461e+00,  1.4122e-01,  5.6851e-01,  8.8709e-04,  7.4258e-01,
        -3.3957e-01,  8.2526e-01, -8.1111e-01,  1.1132e+00, -3.8673e-01,
         5.2711e-01, -8.1755e-01, -5.0024e-01, -4.2218e-01, -6.1392e-02,
        -1.0610e+00, -3.6089e-01, -6.5406e-01,  1.6485e+00,  8.4232e-01,
         2.6350e+00,  8.0555e-02,  1.1212e+00, -1.1863e+00, -2.6115e+00,
         2.4867e-01, -8.6539e-01,  9.6221e-01,  1.5060e+00, -2.4428e+00,
         6.1161e-01, -3.3675e-01, -4.8783e-01, -4.1349e-01, -7.7358e-01,
        -3.5168e-01,  1.2386e+00, -3.3050e-01,  5.5655e-01, -1.3928e+00,
         1.4202e+00, -7.1153e-01,  6.7624e-02, -3.2127e-02,  1.8324e-01,
        -1.5678e+00, -2.6022e-01,  1.1370e+00,  5.4777e-01, -1.1423e+00,
        -8.8378e-01, -1.4942e-02,  1.6684e+00, -1.4

In [97]:
res["state_tokens"][0,1].sum()
b = 0
cnt = 0
for i in range(511):
    a = (res["action_tokens"][i,6].sum())
    if b!=a:
        cnt+=1
    b=a
print(cnt)


501


In [74]:
res["state_tokens"][249,1]

tensor([-0.0556, -0.0401, -0.1370,  0.0920, -0.0862, -0.0241,  0.0740,  0.0038,
         0.0044, -0.0126,  0.0894,  0.1270, -0.0315, -0.0830, -0.1333,  0.0984,
         0.1279, -0.1858, -0.0250,  0.2726, -0.0262, -0.0331,  0.0440,  0.0454,
        -0.0016,  0.0407,  0.0292, -0.0842, -0.1044, -0.1120, -0.1021, -0.0383,
        -0.1403, -0.0019, -0.0170,  0.0473, -0.0024,  0.0903, -0.0812, -0.0914,
        -0.0434,  0.0774, -0.0287,  0.0032,  0.0018,  0.0413, -0.0979, -0.1425,
        -0.0312, -0.1261,  0.0267,  0.1814,  0.0154,  0.0062, -0.0456, -0.0879,
        -0.0392,  0.0677,  0.0890,  0.0982,  0.0568,  0.0206,  0.0543, -0.0850,
        -0.1078,  0.0638,  0.0238, -0.0575,  0.0347,  0.0992, -0.0069, -0.0034,
        -0.0231, -0.0983,  0.0294,  0.0342, -0.0528, -0.0892, -0.0783, -0.0182,
         0.0324,  0.0060, -0.0452, -0.0771,  0.0749,  0.0011, -0.0149,  0.0794,
        -0.0754, -0.0120,  0.0409,  0.0773,  0.0367, -0.0271, -0.0200, -0.0353,
        -0.0628,  0.1960,  0.0926,  0.00

In [75]:
# pact.eval()  # turn off dropout
# with torch.no_grad():
#     out1, state_tok1 = pact({"state": batch_state, "action": batch_action})
#     out2, state_tok2 = pact({"state": batch_state, "action": batch_action})

# assert out1.shape == (batch_state.size(0), 2*(batch_state.size(1)), pact.gpt.config["n_embd"])
# assert state_tok1.shape == (batch_state.size(0), batch_state.size(1), pact.gpt.config["n_embd"])
# assert torch.all(torch.isfinite(out1)) and torch.all(torch.isfinite(state_tok1))
# # deterministic in eval mode (no dropout)
# assert torch.allclose(out1, out2) and torch.allclose(state_tok1, state_tok2)


In [76]:
# def split_sa(y):  # y: (B, 2T, d)
#     return y[:, 0::2, :], y[:, 1::2, :]

# out_embd, _ = pact({"state": batch_state[:2], "action": batch_action[:2]})
# state_ctx, action_ctx = split_sa(out_embd)
# assert state_ctx.shape[:2] == action_ctx.shape[:2] == (2, batch_state.size(1))


In [77]:
batch = {"state": batch_state.float(), "action": batch_action.float()}
out1, state_tok1 = pact(batch)


NameError: name 'batch_state' is not defined

In [99]:
from src.models.modules import a
a.CFG.device = "cpu"           # override the default
res = a.run_demo("pact_dataset.npz",
                 state_key="full_obs",
                 action_key="taken_action")


=== Demo outputs ===
state_seq: (256, 8, 38) action_seq: (256, 8, 2)
policy_out: (256, 2) critic_out: (256, 1)
state_ctx: (256, 8, 128) action_ctx: (256, 8, 128)
last_state_ctx: (256, 128) last_action_ctx: (256, 128)


In [100]:
res

{'policy_out': tensor([[0.0605, 0.0651],
         [0.0605, 0.0651],
         [0.0605, 0.0651],
         [0.0605, 0.0651],
         [0.0605, 0.0651],
         [0.0605, 0.0651],
         [0.0605, 0.0651],
         [0.0605, 0.0651],
         [0.0605, 0.0651],
         [0.0605, 0.0651],
         [0.0605, 0.0651],
         [0.0605, 0.0651],
         [0.0605, 0.0651],
         [0.0605, 0.0651],
         [0.0605, 0.0651],
         [0.0605, 0.0651],
         [0.0605, 0.0651],
         [0.0605, 0.0651],
         [0.0605, 0.0651],
         [0.0946, 0.0522],
         [0.0945, 0.0516],
         [0.0944, 0.0509],
         [0.0940, 0.0670],
         [0.0919, 0.0645],
         [0.0898, 0.0614],
         [0.1121, 0.0690],
         [0.1102, 0.0641],
         [0.1102, 0.0636],
         [0.1110, 0.0628],
         [0.0963, 0.0664],
         [0.0969, 0.0660],
         [0.0975, 0.0652],
         [0.0994, 0.0633],
         [0.0988, 0.0625],
         [0.0981, 0.0615],
         [0.1192, 0.0441],
         [0.11