In [None]:
import os
from os.path import join
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.distributions import Categorical

from rl_api.environments.toy_env import ToyVectorEnv
from rl_api.environments.wrappers import AutoResetVectorEnv
from rl_api.environments.types import BatchObs

from rl_api.agents.ppg_utils.configs import (
    DimensionConfig,
    PPOConfig,
    PPGConfig,
    BufferConfig,
    EntropySchedulerConfig,
    TrainingConfig,
    EvalConfig,
    LoggingConfig,
    SavingConfig,
)
from rl_api.agents.ppg_utils.buffer import RolloutBufferVec
from rl_api.agents.ppg_utils.vectorized_agent import VectorizedPPGAgent
from rl_api.agents.ppg_utils.factories import build_ppg_agent

from rl_api.networks.networks import ActorNetwork, CriticNetwork, PPGPolicyNetwork, PPGValueNetwork
from rl_api.networks.networks_factory import build_ppg_optimizers


In [None]:
# fixed params
agent_dir_path = "./toy_ppg_agent"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global numpy / torch RNGs so torch-based randomness (e.g. default
# nn.Linear weight init) is reproducible under Restart & Run All.
# NOTE(review): ToyVectorEnv may keep its own RNG — check whether it accepts a seed.
# Bit-exact CUDA determinism additionally needs torch.use_deterministic_algorithms.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, CPU and CUDA

In [None]:
# --------- MLP encoder for fixed-shape observations, e.g. (seq_len, feature_dim) ---------

class MLPEncoder(nn.Module):
    """Two-hidden-layer ReLU MLP that flattens each observation into an embedding.

    Args:
        obs_shape: Per-sample observation shape; inputs are expected batched as
            ``(B, *obs_shape)`` because ``nn.Flatten`` keeps dim 0.
        hidden_dim: Width of both hidden layers; also exposed as ``output_dim``,
            the size of the returned embedding.
    """

    def __init__(self, obs_shape, hidden_dim: int = 64):
        super().__init__()
        self.obs_shape = obs_shape
        flat_dim = int(np.prod(obs_shape))

        # Layer order is fixed (Flatten, Linear, ReLU, Linear, ReLU) so the
        # state_dict keys stay net.1.* / net.3.*.
        layers = [nn.Flatten(), nn.Linear(flat_dim, hidden_dim), nn.ReLU()]
        layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
        self.net = nn.Sequential(*layers)

        self.output_dim = hidden_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(B, *obs_shape)`` observations to ``(B, output_dim)`` embeddings."""
        return self.net(x)

In [None]:
# ---------- envs ----------
n_envs = 16
seq_len = 8
feature_dim = 1
obs_shape = (seq_len, feature_dim)
context_dim = 1
action_dim = 3

# Single source of truth for ToyVectorEnv settings, so the train and eval envs
# cannot silently drift apart when one of them is edited.
toy_env_kwargs = dict(
    n_envs=n_envs,
    seq_len=seq_len,
    step_size=0.1,
    max_steps=50,
    x_limit=2.0,
    device=device,
)


def make_toy_env():
    """Return a fresh (un-wrapped) ToyVectorEnv built from ``toy_env_kwargs``."""
    return ToyVectorEnv(**toy_env_kwargs)


base_train_env = make_toy_env()
train_env = AutoResetVectorEnv(base_train_env)

# Eval uses a *separate* env instance with identical settings (not the training
# env object itself), so eval rollouts do not share env objects with training.
eval_env = AutoResetVectorEnv(make_toy_env())


In [None]:
# ---------- configs ----------
# Problem sizes, taken from the env cell.
dims = DimensionConfig(
    obs_shape=obs_shape,
    context_dim=context_dim,
    action_dim=action_dim,
)

# PPO phase: discounting / GAE, clipping, loss weights, gradient clipping.
ppo_cfg = PPOConfig(
    gamma=0.99,
    gae_lambda=0.95,
    clip_eps=0.1,
    clip_vf=None,
    vf_loss_clip=False,
    vf_coef=0.5,
    entropy_coef=0.01,
    target_kl=None,
    grad_clip=0.5,
)

# PPG phase schedule and auxiliary-phase settings.
ppg_cfg = PPGConfig(
    n_pi=8,
    policy_epochs=1,
    critic_epochs=1,
    aux_epochs=3,
    beta_kl=0.01,
)

buf_cfg = BufferConfig(buffer_size=1024, ppo_batch_size=64, aux_batch_size=64)

entropy_sched_cfg = EntropySchedulerConfig(use_scheduler=False)

# Deliberately tiny run for the demo.
train_cfg = TrainingConfig(total_updates=50)

eval_cfg = EvalConfig(eval_method="sample", n_steps=256)

# Logging and checkpoint paths below live under agent_dir_path; create it first.
os.makedirs(agent_dir_path, exist_ok=True)

logging_cfg = LoggingConfig(
    current_update=0,
    log_interval=2,
    eval_interval=10,
    verbose=True,
    html_log_path=os.path.join(agent_dir_path, "html_logs"),
    tensorboard_path=os.path.join(agent_dir_path, "tensorboard_logs"),
)

saving_cfg = SavingConfig(save_agent_path=agent_dir_path, save_interval=25)


In [None]:
# ---------- networks ----------
# Embedding width of both encoders. Per the head definitions, hidden_units must
# start with the encoder's hidden_dim, so the heads derive their first entry from
# encoder.output_dim instead of repeating the literal.
encoder_hidden_dim = 32

policy_encoder = MLPEncoder(obs_shape=obs_shape, hidden_dim=encoder_hidden_dim).to(device)
value_encoder = MLPEncoder(obs_shape=obs_shape, hidden_dim=encoder_hidden_dim).to(device)

# Heads on the policy encoder: action head and auxiliary value head.
actor_net = ActorNetwork(
    obs_embed_dim=policy_encoder.output_dim,
    context_dim=context_dim,
    action_dim=action_dim,
    hidden_units=[policy_encoder.output_dim, 16],
    dropout_rate=0.0,
).to(device)

aux_net = CriticNetwork(
    obs_embed_dim=policy_encoder.output_dim,
    context_dim=context_dim,
    hidden_units=[policy_encoder.output_dim, 16],
    dropout_rate=0.0,
).to(device)

# Head on the separate value encoder. Named value_head_net so it is not confused
# with (or shadowed by) the value_net wrapper built from it below.
value_head_net = CriticNetwork(
    obs_embed_dim=value_encoder.output_dim,
    context_dim=context_dim,
    hidden_units=[value_encoder.output_dim, 16],
    dropout_rate=0.0,
).to(device)

policy_net = PPGPolicyNetwork(
    encoder=policy_encoder,
    action_head=actor_net,
    aux_value_head=aux_net,
).to(device)

value_net = PPGValueNetwork(
    encoder=value_encoder,
    value_head=value_head_net,
).to(device)


optimizers = build_ppg_optimizers(
    policy_network=policy_net,
    value_network=value_net,
    enc_lr=1e-4,
    actor_lr=1e-4,
    critic_lr=1e-4,
    weight_decay=0.0
)

In [None]:
# ---------- agent ----------
# Assemble the PPG agent from the envs, networks, optimizers and configs above.
agent: VectorizedPPGAgent = build_ppg_agent(
    # environments
    train_env=train_env,
    eval_env=eval_env,
    # networks and their optimizers
    policy_net=policy_net,
    value_net=value_net,
    policy_optimizer=optimizers["policy_optimizer"],
    aux_optimizer=optimizers["aux_optimizer"],
    critic_optimizer=optimizers["critic_optimizer"],
    # hyperparameter / bookkeeping configs
    dims=dims,
    ppo_cfg=ppo_cfg,
    ppg_cfg=ppg_cfg,
    buf_cfg=buf_cfg,
    entropy_sched_cfg=entropy_sched_cfg,
    eval_cfg=eval_cfg,
    logging_cfg=logging_cfg,
    saving_cfg=saving_cfg,
    # placement
    device=device,
)



In [None]:
# Launch training. Run length is presumably train_cfg.total_updates (50 here);
# logging / eval / checkpoint cadence presumably comes from the logging_cfg and
# saving_cfg that were passed to build_ppg_agent — confirm against the agent API.
agent.train(train_cfg=train_cfg)