In [2]:
import time
from dataclasses import dataclass
from collections import deque
from typing import Tuple

import numpy as np
import gymnasium as gym

import torch
import torch.nn as nn
import torch.nn.functional as F


# -----------------------------
# Config
# -----------------------------
@dataclass
class Config:
    """Hyperparameters for A2C training with a discretized action space.

    The continuous action interval ``[action_low, action_high]`` is split
    into ``n_bins`` evenly spaced values; the policy picks one bin index
    per step.
    """

    # --- Environment / reproducibility ---
    env_id: str = "Pendulum-v1"
    seed: int = 0

    # --- Action discretization: [-2, 2] split into n_bins values ---
    # 11 is a good starting point (coarse, but easy to learn).
    n_bins: int = 11
    action_low: float = -2.0
    action_high: float = 2.0

    # --- Vectorized A2C (tends to be more stable) ---
    n_envs: int = 8          # number of parallel environments
    rollout_len: int = 256   # steps collected per environment per update

    # --- Training length ---
    total_updates: int = 3000

    # --- Return / advantage estimation ---
    gamma: float = 0.99
    gae_lambda: float = 0.95

    # --- Optimization ---
    lr: float = 3e-4
    value_coef: float = 0.5
    # Exploration tends to collapse with discrete actions, so 1e-3 is recommended.
    entropy_coef: float = 1e-3
    max_grad_norm: float = 1.0

    # --- Runtime ---
    device: str = "cuda"  # "cuda" or "cpu"
    log_every: int = 20   # print a log line every this many updates


def set_seed(seed: int):
    """Seed the NumPy global RNG and the PyTorch CPU and CUDA RNGs with ``seed``."""
    # The three generators are independent, so the call order is irrelevant.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)


# -----------------------------
# Actor-Critic Network (Discrete)
# -----------------------------
class DiscreteActorCritic(nn.Module):
    """Two-layer tanh MLP trunk shared by a categorical policy head and a value head."""

    def __init__(self, obs_dim: int, n_actions: int):
        super().__init__()
        width = 256

        layers = [
            nn.Linear(obs_dim, width),
            nn.Tanh(),
            nn.Linear(width, width),
            nn.Tanh(),
        ]
        self.trunk = nn.Sequential(*layers)
        self.pi_head = nn.Linear(width, n_actions)  # action logits
        self.v_head = nn.Linear(width, 1)           # scalar state value

        self._init_weights()

    def _init_weights(self):
        """Apply orthogonal weights and zero biases to every linear layer."""
        # Common RL initialization (not required, but tends to stabilize training).
        linear_layers = [m for m in self.modules() if isinstance(m, nn.Linear)]
        for layer in linear_layers:
            nn.init.orthogonal_(layer.weight, gain=1.0)
            nn.init.zeros_(layer.bias)

        # The heads are re-drawn afterwards: a small gain on the policy head
        # keeps the initial logits near zero (near-uniform policy).
        for head, gain in ((self.pi_head, 0.01), (self.v_head, 1.0)):
            nn.init.orthogonal_(head.weight, gain=gain)

    def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute policy logits and state values.

        obs: [B, obs_dim]
        returns:
          logits: [B, n_actions]
          v: [B]
        """
        features = self.trunk(obs)
        return self.pi_head(features), self.v_head(features).squeeze(-1)

    def _policy(self, obs: torch.Tensor):
        """Return the categorical action distribution and the value estimate for ``obs``."""
        logits, value = self.forward(obs)
        return torch.distributions.Categorical(logits=logits), value

    @torch.no_grad()
    def act(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Sample one action per observation without tracking gradients.

        obs: [B, obs_dim]
        returns:
          a_idx: [B] (int64)
          logp: [B]
          v: [B]
        """
        policy, value = self._policy(obs)
        choice = policy.sample()
        return choice, policy.log_prob(choice), value

    def evaluate(self, obs: torch.Tensor, a_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Score given actions under the current policy (gradients enabled).

        obs: [B, obs_dim]
        a_idx: [B] (int64)
        returns:
          logp: [B]
          entropy: [B]
          v: [B]
        """
        policy, value = self._policy(obs)
        return policy.log_prob(a_idx), policy.entropy(), value


# -----------------------------
# Rollout Buffer (vectorized)
# -----------------------------
class RolloutBuffer:
    """Fixed-size on-device storage for one rollout of ``T`` steps across ``n_envs`` environments."""

    def __init__(self, T: int, n_envs: int, obs_dim: int, device: torch.device):
        self.T = T
        self.n_envs = n_envs
        self.device = device

        def alloc(*shape, dtype=torch.float32):
            # All storage is zero-initialized on the target device.
            return torch.zeros(shape, dtype=dtype, device=device)

        # Per-step data written by add().
        self.obs = alloc(T, n_envs, obs_dim)
        self.actions = alloc(T, n_envs, dtype=torch.int64)
        self.rewards = alloc(T, n_envs)
        self.dones = alloc(T, n_envs)
        self.logps = alloc(T, n_envs)
        self.values = alloc(T, n_envs)

        # Derived quantities filled in by compute_gae().
        self.advantages = alloc(T, n_envs)
        self.returns = alloc(T, n_envs)

        # Index of the next time step to write.
        self.ptr = 0

    def add(self, obs, action, reward, done, logp, value):
        """Store one time step for all environments at the current write index, then advance it."""
        slot = self.ptr
        pairs = (
            (self.obs, obs),
            (self.actions, action),
            (self.rewards, reward),
            (self.dones, done),
            (self.logps, logp),
            (self.values, value),
        )
        for storage, item in pairs:
            storage[slot] = item
        self.ptr = slot + 1

    def compute_gae(self, last_value: torch.Tensor, gamma: float, lam: float):
        """
        Fill ``advantages`` (normalized) and ``returns`` using GAE(lambda).

        last_value: V(s_T) for each env, shape [n_envs]
        """
        running = torch.zeros((self.n_envs,), dtype=torch.float32, device=self.device)
        bootstrap = last_value
        for t in range(self.T - 1, -1, -1):
            # A done flag at step t blocks both the bootstrap value and the
            # accumulated advantage from later steps.
            mask = 1.0 - self.dones[t]
            td_error = self.rewards[t] + gamma * mask * bootstrap - self.values[t]
            running = td_error + gamma * lam * mask * running
            self.advantages[t] = running
            bootstrap = self.values[t]

        # Returns use the un-normalized advantages.
        self.returns = self.advantages + self.values

        # Normalize advantages over the full batch (T * n_envs).
        flat = self.advantages.view(-1)
        center = flat.mean()
        scale = flat.std(unbiased=False) + 1e-8
        self.advantages = (self.advantages - center) / scale

    def flatten(self):
        """Return (obs, actions, logps, values, advantages, returns) with time and env dims merged."""
        batch = self.T * self.n_envs
        per_step = (self.actions, self.logps, self.values, self.advantages, self.returns)
        return (self.obs.view(batch, -1),) + tuple(x.view(batch) for x in per_step)

    def reset(self):
        """Rewind the write index; stored data is overwritten by subsequent add() calls."""
        self.ptr = 0


# -----------------------------
# Vector env factory
# -----------------------------
def make_env(env_id: str, seed: int, idx: int):
    """Return a zero-argument factory that builds environment ``env_id`` seeded with ``seed + idx``."""
    env_seed = seed + idx

    def thunk():
        env = gym.make(env_id)
        # Seed the environment itself via reset, then both spaces.
        env.reset(seed=env_seed)
        for space in (env.action_space, env.observation_space):
            space.seed(env_seed)
        return env

    return thunk


# -----------------------------
# Train
# -----------------------------
def train(cfg: Config):
    """Train a discretized-action actor-critic with synchronous A2C and GAE.

    Each update collects ``cfg.rollout_len`` steps from ``cfg.n_envs``
    synchronous environments, computes GAE advantages, and takes a single
    full-batch gradient step. A progress line is printed every
    ``cfg.log_every`` updates.

    Args:
        cfg: Hyperparameters and environment settings.

    Returns:
        The trained ``DiscreteActorCritic`` network.

    The vector environment is always closed, including when training is
    interrupted (e.g. by ``KeyboardInterrupt``) or fails with an exception.
    """
    set_seed(cfg.seed)

    # Fall back to CPU when CUDA is requested but unavailable.
    device = torch.device(cfg.device if (cfg.device == "cpu" or torch.cuda.is_available()) else "cpu")
    envs = gym.vector.SyncVectorEnv([make_env(cfg.env_id, cfg.seed, i) for i in range(cfg.n_envs)])

    try:
        obs_np, _ = envs.reset()
        obs = torch.as_tensor(obs_np, dtype=torch.float32, device=device)

        obs_dim = obs.shape[1]
        n_actions = cfg.n_bins

        # Discrete bins -> continuous action value (Pendulum expects shape (1,))
        action_values = np.linspace(cfg.action_low, cfg.action_high, cfg.n_bins, dtype=np.float32)  # [n_bins]

        net = DiscreteActorCritic(obs_dim, n_actions).to(device)
        optim = torch.optim.Adam(net.parameters(), lr=cfg.lr)

        buf = RolloutBuffer(cfg.rollout_len, cfg.n_envs, obs_dim, device)

        ep_returns = np.zeros(cfg.n_envs, dtype=np.float32)  # running return per env
        recent_ep_returns = deque(maxlen=50)                 # last 50 finished episodes

        t0 = time.time()

        for update in range(1, cfg.total_updates + 1):
            buf.reset()

            # -------- rollout collection --------
            for _ in range(cfg.rollout_len):
                # act() already runs under torch.no_grad().
                a_idx, logp, v = net.act(obs)  # a_idx: [n_envs]

                # map discrete index -> continuous action, shape (n_envs, 1)
                a_np = action_values[a_idx.cpu().numpy()].reshape(cfg.n_envs, 1)

                next_obs_np, rewards_np, terminated, truncated, _ = envs.step(a_np)
                # NOTE(review): time-limit truncation is merged with termination,
                # so GAE does not bootstrap across truncated episodes — confirm
                # this is intended.
                done_np = np.logical_or(terminated, truncated).astype(np.float32)

                rewards = torch.as_tensor(rewards_np, dtype=torch.float32, device=device)
                dones = torch.as_tensor(done_np, dtype=torch.float32, device=device)
                next_obs = torch.as_tensor(next_obs_np, dtype=torch.float32, device=device)

                buf.add(
                    obs=obs,
                    action=a_idx,
                    reward=rewards,
                    done=dones,
                    logp=logp,
                    value=v,
                )

                # episode accounting: record and reset the return of every env
                # whose episode just ended (ascending env index).
                ep_returns += rewards_np.astype(np.float32)
                finished = np.flatnonzero(done_np > 0.5)
                recent_ep_returns.extend(float(r) for r in ep_returns[finished])
                ep_returns[finished] = 0.0

                obs = next_obs

            # bootstrap last value
            with torch.no_grad():
                _, last_v = net.forward(obs)  # [n_envs]
            buf.compute_gae(last_v, cfg.gamma, cfg.gae_lambda)

            # -------- update (A2C: single epoch, full batch) --------
            # Old log-probs / values are not needed for a plain A2C step.
            b_obs, b_actions, _, _, b_adv, b_ret = buf.flatten()

            logp_new, entropy, v_new = net.evaluate(b_obs, b_actions)

            # Policy loss: -E[A log pi(a|s)]
            policy_loss = -(b_adv * logp_new).mean()

            # Value loss: E[(V - R)^2]
            # NOTE(review): returns are unscaled; the captured logs show this
            # term in the thousands, which presumably dominates the clipped
            # gradient — consider reward/return scaling.
            value_loss = F.mse_loss(v_new, b_ret)

            # Entropy bonus (maximize entropy => minimize -entropy)
            entropy_loss = -entropy.mean()

            loss = policy_loss + cfg.value_coef * value_loss + cfg.entropy_coef * entropy_loss

            optim.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), cfg.max_grad_norm)
            optim.step()

            # -------- logging --------
            if (update % cfg.log_every) == 0:
                elapsed = time.time() - t0
                avg50 = np.mean(recent_ep_returns) if recent_ep_returns else float("nan")
                # Calling detach() before item() avoids a warning.
                print(
                    f"[upd {update:4d}/{cfg.total_updates}] "
                    f"loss={loss.detach().item(): .3f} "
                    f"pi={policy_loss.detach().item(): .3f} "
                    f"v={value_loss.detach().item(): .3f} "
                    f"ent={entropy.detach().mean().item(): .3f} "
                    f"avg_return(50ep)={avg50: .1f} "
                    f"elapsed={elapsed: .1f}s device={device} bins={cfg.n_bins}"
                )
    finally:
        # Release environment resources on success, error, or interrupt.
        envs.close()

    return net


if __name__ == "__main__":
    # Script entry point: build the run configuration and start training.
    cfg = Config(
        seed=0,
        device="cuda",
        n_envs=8,
        rollout_len=256,
        total_updates=3000,
        lr=3e-4,
        # Raising n_bins from 11 to 21 gives finer control but tends to make
        # learning harder.
        n_bins=11,
        # Helps maintain exploration with discrete actions.
        entropy_coef=1e-3,
        log_every=20,
    )
    train(cfg)


[upd   20/3000] loss= 4645.607 pi=-0.000 v= 9291.220 ent= 2.396 avg_return(50ep)=-1199.3 elapsed= 22.6s device=cuda bins=11
[upd   40/3000] loss= 3958.690 pi= 0.001 v= 7917.383 ent= 2.394 avg_return(50ep)=-1232.9 elapsed= 40.9s device=cuda bins=11
[upd   60/3000] loss= 3772.094 pi= 0.001 v= 7544.191 ent= 2.394 avg_return(50ep)=-1155.2 elapsed= 66.9s device=cuda bins=11
[upd   80/3000] loss= 4447.222 pi=-0.003 v= 8894.455 ent= 2.392 avg_return(50ep)=-1249.5 elapsed= 84.1s device=cuda bins=11
[upd  100/3000] loss= 4836.534 pi= 0.012 v= 9673.049 ent= 2.381 avg_return(50ep)=-1166.6 elapsed= 100.1s device=cuda bins=11
[upd  120/3000] loss= 4386.787 pi= 0.014 v= 8773.551 ent= 2.367 avg_return(50ep)=-1208.5 elapsed= 119.4s device=cuda bins=11
[upd  140/3000] loss= 4111.366 pi= 0.013 v= 8222.711 ent= 2.353 avg_return(50ep)=-1186.6 elapsed= 138.2s device=cuda bins=11
[upd  160/3000] loss= 4015.539 pi= 0.034 v= 8031.014 ent= 2.336 avg_return(50ep)=-1151.5 elapsed= 151.9s device=cuda bins=11
[upd

KeyboardInterrupt: 