In [None]:
import os
from os.path import join
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.distributions import Categorical

from rl_api.environments.toy_env import ToyVectorEnv
from rl_api.environments.wrappers import AutoResetVectorEnv
from rl_api.environments.types import BatchObs

from rl_api.agents.ppo_utils.configs import (
    DimensionConfig,
    PPOConfig,
    BufferConfig,
    EntropySchedulerConfig,
    TrainingConfig,
    EvalConfig,
    LoggingConfig,
    SavingConfig,
)
from rl_api.agents.ppo_utils.buffer import RolloutBufferVec
from rl_api.agents.ppo_utils.vectorized_agent import VectorizedPPOAgent
from rl_api.agents.ppo_utils.factories import build_ppo_agent

from rl_api.networks.networks import ActorNetwork, CriticNetwork, PPOPolicyNetwork


In [None]:
# fixed params
agent_dir_path = "./toy_ppo_agent"  # root folder for checkpoints and TensorBoard logs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global RNGs so that Restart & Run All re-creates the same network
# initialisation and torch-based sampling. torch.manual_seed seeds the CPU and
# CUDA generators alike.
# NOTE(review): ToyVectorEnv may keep its own RNG -- confirm it draws from
# torch/numpy if bit-for-bit reproducible rollouts are required.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# --------- simple encoder for (seq_len, 1) obs ---------

class MLPEncoder(nn.Module):
    """Two-layer ReLU MLP that embeds a flattened observation.

    Args:
        obs_shape: Per-sample observation shape (without the batch dim);
            its element count is the input width of the first linear layer.
        hidden_dim: Width of both linear layers and of the returned embedding.

    Attributes:
        obs_shape: The shape passed at construction time.
        net: ``Flatten -> Linear -> ReLU -> Linear -> ReLU`` stack.
        output_dim: Embedding size, equal to ``hidden_dim``.
    """

    def __init__(self, obs_shape, hidden_dim: int = 64):
        super().__init__()
        self.obs_shape = obs_shape
        self.output_dim = hidden_dim

        flat_features = int(np.prod(obs_shape))
        layers = [
            nn.Flatten(),  # (B, *obs_shape) -> (B, flat_features)
            nn.Linear(flat_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch ``(B, *obs_shape)`` to embeddings ``(B, hidden_dim)``."""
        embedding = self.net(x)
        return embedding

In [None]:
# ---------- envs ----------
n_envs = 16
seq_len = 8
feature_dim = 1
obs_shape = (seq_len, feature_dim)
context_dim = 1
action_dim = 3

# Single source of truth for the ToyVectorEnv settings, so the training and
# evaluation environments cannot silently drift apart when one is edited.
toy_env_kwargs = dict(
    n_envs=n_envs,
    seq_len=seq_len,
    step_size=0.1,
    max_steps=50,
    x_limit=2.0,
    device=device,
)

# The unwrapped training env stays available as `base_train_env`.
base_train_env = ToyVectorEnv(**toy_env_kwargs)
train_env = AutoResetVectorEnv(base_train_env)

# Evaluation gets its own env instance, configured identically to training.
eval_env = AutoResetVectorEnv(ToyVectorEnv(**toy_env_kwargs))


In [None]:
# ---------- configs ----------
# Create the output folder up front; the logging and saving configs below
# both point into it.
os.makedirs(agent_dir_path, exist_ok=True)

# Problem sizes handed to the agent factory.
dims = DimensionConfig(
    obs_shape=obs_shape,
    context_dim=context_dim,
    action_dim=action_dim,
)

# PPO hyper-parameters (grouped here by what their names refer to).
ppo_cfg = PPOConfig(
    # optimisation loop
    num_epochs=2,
    grad_clip=0.5,
    target_kl=None,
    # policy / entropy terms
    clip_eps=0.2,
    entropy_coef=0.01,
    # value-function terms
    vf_coef=0.5,
    vf_loss_clip=False,
    clip_vf=None,
    # return / advantage estimation
    gamma=0.99,
    gae_lambda=0.95,
)

buf_cfg = BufferConfig(batch_size=64, buffer_size=1024)

entropy_sched_cfg = EntropySchedulerConfig(use_scheduler=False)

train_cfg = TrainingConfig(total_updates=150)  # small number just for the demo

eval_cfg = EvalConfig(n_steps=256, eval_method="sample")

logging_cfg = LoggingConfig(
    verbose=True,
    current_update=0,
    log_interval=2,
    eval_interval=10,
    tensorboard_path=join(agent_dir_path, "tensorboard_logs"),
    html_log_path=None,  # keep it simple for the example
)

saving_cfg = SavingConfig(save_agent_path=agent_dir_path, save_interval=25)


In [None]:
# ---------- networks ----------
# Encoder first: its output width sizes the input of both heads below.
encoder = MLPEncoder(hidden_dim=32, obs_shape=obs_shape).to(device)

# Action head: one output per action.
actor_net = ActorNetwork(
    obs_embed_dim=encoder.output_dim,
    context_dim=context_dim,
    action_dim=action_dim,
    dropout_rate=0.0,
    hidden_units=[32, 16],
).to(device)

# Value head: same hidden layout as the action head, no action_dim.
critic_net = CriticNetwork(
    obs_embed_dim=encoder.output_dim,
    context_dim=context_dim,
    dropout_rate=0.0,
    hidden_units=[32, 16],
).to(device)

# Bundle encoder and both heads into the module the agent optimises.
policy_net = PPOPolicyNetwork(
    encoder=encoder,
    value_head=critic_net,
    action_head=actor_net,
).to(device)

# One optimizer over every parameter registered on policy_net.
policy_optimizer = Adam(policy_net.parameters(), lr=3e-4)


In [None]:
# ---------- agent ----------
# Assemble the agent from the pieces built in the cells above.
agent: VectorizedPPOAgent = build_ppo_agent(
    # model and optimizer
    policy_net=policy_net,
    policy_optimizer=policy_optimizer,
    device=device,
    # environments
    train_env=train_env,
    eval_env=eval_env,
    # algorithm configuration
    dims=dims,
    ppo_cfg=ppo_cfg,
    buf_cfg=buf_cfg,
    entropy_sched_cfg=entropy_sched_cfg,
    # evaluation, logging and checkpointing
    eval_cfg=eval_cfg,
    logging_cfg=logging_cfg,
    saving_cfg=saving_cfg,
)



In [None]:
# Launch training with the settings in train_cfg (total_updates=150 as set in
# the configs cell). NOTE(review): presumably logs, evaluates and checkpoints
# at the intervals given to build_ppo_agent -- confirm against the agent code.
agent.train(train_cfg=train_cfg)