In [1]:
import os
from os.path import join
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.distributions import Categorical

from rl_api.environments.toy_env import ToyVectorEnv
from rl_api.environments.wrappers import AutoResetVectorEnv
from rl_api.environments.types import BatchObs

from rl_api.agents.ppo_utils.configs import (
    DimensionConfig,
    PPOConfig,
    BufferConfig,
    EntropySchedulerConfig,
    TrainingConfig,
    EvalConfig,
    LoggingConfig,
    SavingConfig,
)
from rl_api.agents.ppo_utils.buffer import RolloutBufferVec
from rl_api.agents.ppo_utils.vectorized_agent import VectorizedPPOAgent
from rl_api.agents.ppo_utils.factories import build_ppo_agent

from rl_api.networks.networks import ActorNetwork, CriticNetwork, PPOPolicyNetwork


2025-12-14 14:57:35.285548: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1765717055.378889    2009 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1765717055.404869    2009 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-12-14 14:57:35.608907: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [None]:
# fixed params
agent_dir_path = "./toy_ppo_agent"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global RNGs once, up front, so that torch weight initialisation and
# any torch / numpy based sampling repeat after Restart Kernel -> Run All.
# NOTE(review): ToyVectorEnv may keep its own random state; confirm it draws
# from these generators if fully deterministic rollouts are required.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU generator and all CUDA devices

In [3]:
# --------- simple encoder for (seq_len, 1) obs ---------

class MLPEncoder(nn.Module):
    """Two-layer ReLU MLP that embeds a flattened observation.

    Args:
        obs_shape: Per-sample observation shape (without the batch axis); its
            element count is the input width of the first linear layer.
        hidden_dim: Width of both linear layers and of the returned embedding.

    Attributes:
        obs_shape: The shape passed to the constructor, stored as-is.
        net: ``Flatten -> Linear -> ReLU -> Linear -> ReLU`` stack.
        output_dim: Size of the embedding (equal to ``hidden_dim``).
    """

    def __init__(self, obs_shape, hidden_dim: int = 64):
        super().__init__()
        self.obs_shape = obs_shape
        self.output_dim = hidden_dim

        flat_features = int(np.prod(obs_shape))
        # Layer positions are fixed so parameter names stay net.1.* / net.3.*.
        stages = [
            nn.Flatten(),  # (B, *obs_shape) -> (B, flat_features)
            nn.Linear(flat_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch ``(B, *obs_shape)`` to embeddings ``(B, hidden_dim)``."""
        embedding = self.net(x)
        return embedding

In [4]:
# ---------- envs ----------
n_envs = 16
seq_len = 8
feature_dim = 1
context_dim = 1
action_dim = 3
obs_shape = (seq_len, feature_dim)


def _new_toy_env():
    """Return a fresh ToyVectorEnv built with this notebook's fixed settings."""
    return ToyVectorEnv(
        n_envs=n_envs,
        seq_len=seq_len,
        step_size=0.1,
        max_steps=50,
        x_limit=2.0,
        device=device,
    )


# Training env: keep a handle on the unwrapped env as well as on the wrapper.
base_train_env = _new_toy_env()
train_env = AutoResetVectorEnv(base_train_env)

# Evaluation uses its own env instance, configured identically to training.
eval_env = AutoResetVectorEnv(_new_toy_env())


In [None]:
# ---------- configs ----------
# Problem dimensions come straight from the env cell.
dims = DimensionConfig(obs_shape=obs_shape, action_dim=action_dim, context_dim=context_dim)

ppo_cfg = PPOConfig(
    # update loop
    num_epochs=2,
    target_kl=None,
    grad_clip=0.5,
    # discounting / GAE
    gamma=0.99,
    gae_lambda=0.95,
    # policy-side coefficients
    clip_eps=0.2,
    entropy_coef=0.01,
    # value-side coefficients
    vf_coef=0.5,
    clip_vf=None,
    vf_loss_clip=False,
)

buf_cfg = BufferConfig(buffer_size=1024, batch_size=64)
entropy_sched_cfg = EntropySchedulerConfig(use_scheduler=False)
train_cfg = TrainingConfig(total_updates=150)  # small number just for the demo
eval_cfg = EvalConfig(eval_method="sample", n_steps=256)

# Create the agent directory first; the logging and saving paths point into it.
os.makedirs(agent_dir_path, exist_ok=True)
tensorboard_dir = join(agent_dir_path, "tensorboard_logs")

logging_cfg = LoggingConfig(
    current_update=0,
    log_interval=2,
    eval_interval=10,
    html_log_path=None,  # keep it simple for the example
    tensorboard_path=tensorboard_dir,
    verbose=True,
)

saving_cfg = SavingConfig(save_interval=25, save_agent_path=agent_dir_path)


In [6]:
# ---------- networks ----------
# Observation encoder; its output width is the embedding size both heads take.
encoder = MLPEncoder(obs_shape=obs_shape, hidden_dim=32).to(device)
embed_dim = encoder.output_dim

# Action head.
actor_net = ActorNetwork(
    obs_embed_dim=embed_dim,
    context_dim=context_dim,
    action_dim=action_dim,
    hidden_units=[32, 16],
    dropout_rate=0.0,
).to(device)

# Value head: same arguments as the action head, minus action_dim.
critic_net = CriticNetwork(
    obs_embed_dim=embed_dim,
    context_dim=context_dim,
    hidden_units=[32, 16],
    dropout_rate=0.0,
).to(device)

# Bundle encoder and both heads into the policy network handed to the agent.
policy_net = PPOPolicyNetwork(encoder=encoder, action_head=actor_net, value_head=critic_net).to(device)

# A single optimizer over every parameter exposed by the combined network.
trainable_params = policy_net.parameters()
policy_optimizer = Adam(trainable_params, lr=3e-4)


In [7]:
# ---------- agent ----------
# Gather every factory argument in one mapping, grouped by role, then unpack it.
agent_components = dict(
    # model, optimizer and device
    policy_net=policy_net,
    policy_optimizer=policy_optimizer,
    device=device,
    # environments
    train_env=train_env,
    eval_env=eval_env,
    # configuration objects
    dims=dims,
    ppo_cfg=ppo_cfg,
    buf_cfg=buf_cfg,
    entropy_sched_cfg=entropy_sched_cfg,
    eval_cfg=eval_cfg,
    logging_cfg=logging_cfg,
    saving_cfg=saving_cfg,
)
agent: VectorizedPPOAgent = build_ppo_agent(**agent_components)



In [8]:
# Start PPO training with the settings held in train_cfg (total_updates=150).
# NOTE(review): the saved output below ends in KeyboardInterrupt around update 8,
# so this recorded run was stopped by hand well before the configured total.
agent.train(train_cfg=train_cfg)


[Upd   2 | step: 2048]
| pi_loss: -0.0020  v_loss: +0.4564  ent: 1.082  ent_coef: 0.0100  kl: 0.0011  batches_used: 100.000%  expl_var: 0.10  clip: 0.00%
| adv_mean: -6.4228  adv_std: 4.9774  rets_mean: -6.2440  rets_std: 4.9890
| reward_sum: -1265.587  reward_pre_step: -0.61796
[Grads] enc: 0.299  act: 0.125  critic: 0.343

[Upd   4 | step: 4096]
| pi_loss: -0.0030  v_loss: +0.3751  ent: 1.078  ent_coef: 0.0100  kl: 0.0003  batches_used: 100.000%  expl_var: 0.27  clip: 0.00%
| adv_mean: -4.6439  adv_std: 4.0470  rets_mean: -4.4959  rets_std: 4.1836
| reward_sum: -950.296  reward_pre_step: -0.46401
[Grads] enc: 0.279  act: 0.155  critic: 0.335

[Upd   6 | step: 6144]
| pi_loss: -0.0050  v_loss: +0.2994  ent: 1.061  ent_coef: 0.0100  kl: 0.0028  batches_used: 100.000%  expl_var: 0.41  clip: 0.17%
| adv_mean: -4.1427  adv_std: 3.4636  rets_mean: -4.0416  rets_std: 3.7431
| reward_sum: -846.979  reward_pre_step: -0.41356
[Grads] enc: 0.307  act: 0.154  critic: 0.313

[Upd   8 | step: 819

KeyboardInterrupt: 