<a href="https://colab.research.google.com/github/OneFineStarstuff/Pinn/blob/main/self_reflective_cartpole_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
"""
Self-Reflective Advantage Actor-Critic on CartPole-v1

Features:
- Collects episodes into a memory buffer, then "reflects" via multi-epoch replay:
  1) Critic-only epochs to fit V to returns (stabilize targets).
  2) Policy epochs using advantages with entropy regularization.
- Reproducibility: seeds, deterministic torch flags (where possible).
- Safe defaults: grad clipping, advantage normalization, separate optimizers.
- Checkpointing: periodic save for policy and critic.
- Minimal dependencies: torch, gymnasium (or gym), numpy.

Run:
  python self_reflective_cartpole.py

Notes:
- If you have 'gym' instead of 'gymnasium', the script will fall back automatically.
"""

import os
import math
import time
import random
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Gym import with fallback
try:
    import gymnasium as gym
except ImportError:
    import gym


# -----------------------------
# Config and utilities
# -----------------------------

@dataclass
class Config:
    """Hyperparameters and run settings for the self-reflective actor-critic.

    Every field has a default, so ``Config()`` is a complete, runnable setup.
    Field order is part of the positional constructor and must not change.
    """

    # --- Environment and collection schedule ---
    env_id: str = "CartPole-v1"         # id handed to gym.make
    max_episodes: int = 500             # training stops after this many episodes
    episodes_per_batch: int = 8         # episodes collected before each reflection

    # --- Optimisation ---
    gamma: float = 0.99                 # discount factor for returns
    policy_lr: float = 3e-4             # Adam learning rate for the policy
    critic_lr: float = 5e-4             # Adam learning rate for the critic
    entropy_coef: float = 0.01          # weight of the entropy bonus in the loss
    value_coef: float = 0.5             # weight of the value loss in the loss
    grad_clip: float = 0.5              # max gradient norm for both networks

    # --- Reflection (replay) schedule ---
    reflection_epochs_critic: int = 4   # critic-only passes over the batch
    reflection_epochs_policy: int = 2   # policy passes over the batch

    # --- Model / reproducibility / hardware ---
    hidden_sizes: tuple = (128, 128)    # hidden layer widths of both MLPs
    seed: int = 42                      # base seed for RNGs and env resets
    device: str = "cpu"  # "cuda" if available and desired

    # --- Logging and checkpoints (intervals are counted in episodes) ---
    log_interval: int = 10
    checkpoint_every: int = 100
    out_dir: str = "checkpoints_self_reflect"


def set_seed(seed: int):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to every generator below.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Determinism flags: trade some speed/compatibility for repeatable runs.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def to_tensor(x, device):
    """Return ``x`` as a float32 tensor placed on ``device``.

    Delegates to ``torch.as_tensor``, so array-likes and existing tensors are
    both accepted.
    """
    target = {"dtype": torch.float32, "device": device}
    return torch.as_tensor(x, **target)


# -----------------------------
# Models
# -----------------------------

class MLP(nn.Module):
    """Feed-forward stack: one (Linear, activation) pair per hidden size,
    followed by a final Linear projecting to ``out_dim``.
    """

    def __init__(self, in_dim, out_dim, hidden_sizes=(128, 128), act=nn.Tanh):
        super().__init__()
        # widths[i] -> widths[i + 1] describes each hidden Linear layer.
        widths = [in_dim, *hidden_sizes]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(act())
        # Output layer has no activation; callers interpret the raw values.
        modules.append(nn.Linear(widths[-1], out_dim))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the sequential stack to ``x``."""
        return self.net(x)


class PolicyNet(nn.Module):
    """Policy network: maps observations to one unnormalised score per action."""

    def __init__(self, obs_dim, n_actions, hidden_sizes=(128, 128)):
        super().__init__()
        self.backbone = MLP(
            obs_dim,
            n_actions,
            hidden_sizes=hidden_sizes,
            act=nn.Tanh,
        )

    def forward(self, obs):
        """Return action logits, suitable for ``Categorical(logits=...)``."""
        return self.backbone(obs)


class ValueNet(nn.Module):
    """State-value network: maps observations to a single scalar estimate."""

    def __init__(self, obs_dim, hidden_sizes=(128, 128)):
        super().__init__()
        self.backbone = MLP(
            obs_dim,
            1,
            hidden_sizes=hidden_sizes,
            act=nn.Tanh,
        )

    def forward(self, obs):
        """Return value estimates with the trailing size-1 dimension removed."""
        estimate = self.backbone(obs)
        return estimate.squeeze(-1)


# -----------------------------
# Memory buffer
# -----------------------------

class TrajectoryBuffer:
    """Append-only store of per-step rollout data, kept as parallel lists.

    ``as_tensors`` turns the step lists into tensors on ``device``;
    ``episode_returns`` keeps one summed return per finished episode.
    """

    # Names of the list attributes that ``clear`` (re)creates.
    _COLUMNS = (
        "states",
        "actions",
        "rewards",
        "dones",
        "logps",
        "values",
        "episode_returns",  # per-episode sum for logging
    )

    def __init__(self, device):
        self.device = device
        self.clear()

    def clear(self):
        """Replace every column with a fresh empty list."""
        for column in self._COLUMNS:
            setattr(self, column, [])

    def store_step(self, state, action, reward, done, logp, value):
        """Append one transition's fields to their respective columns."""
        pairs = (
            (self.states, state),
            (self.actions, action),
            (self.rewards, reward),
            (self.dones, done),
            (self.logps, logp),
            (self.values, value),
        )
        for column, item in pairs:
            column.append(item)

    def finalize_episode(self, ep_return):
        """Record the total return of an episode that just ended."""
        self.episode_returns.append(ep_return)

    def as_tensors(self):
        """Return ``(states, actions, rewards, dones, logps, values)`` tensors.

        States and log-probs are stacked along a new leading dimension; the
        remaining columns are converted with explicit dtypes.
        """
        dev = self.device

        def column(data, dtype):
            return torch.as_tensor(data, dtype=dtype, device=dev)

        stacked_states = torch.stack([to_tensor(s, dev) for s in self.states], dim=0)
        stacked_logps = torch.stack(self.logps, dim=0).to(dev)
        return (
            stacked_states,
            column(self.actions, torch.int64),
            column(self.rewards, torch.float32),
            column(self.dones, torch.bool),
            stacked_logps,
            column(self.values, torch.float32),
        )


# -----------------------------
# Self-reflective agent
# -----------------------------

class SelfReflectiveAgent:
    """Actor-critic agent that learns by replaying a buffered batch of steps.

    Owns a policy network and a value network with separate Adam optimizers,
    plus a ``TrajectoryBuffer`` that the training loop fills step by step.
    ``reflect_and_update`` first fits the critic to discounted returns, then
    recomputes advantages with the refreshed critic, runs the policy epochs,
    and finally clears the buffer.
    """

    def __init__(self, obs_dim, n_actions, cfg: "Config"):
        """Build networks, optimizers and an empty buffer.

        Args:
            obs_dim: Input width of both networks.
            n_actions: Number of discrete actions (policy output width).
            cfg: Run configuration; ``device``, ``hidden_sizes``, ``policy_lr``
                and ``critic_lr`` are read here, the rest during updates.
        """
        self.cfg = cfg
        self.device = torch.device(cfg.device)
        self.policy = PolicyNet(obs_dim, n_actions, cfg.hidden_sizes).to(self.device)
        self.critic = ValueNet(obs_dim, cfg.hidden_sizes).to(self.device)

        self.pi_opt = optim.Adam(self.policy.parameters(), lr=cfg.policy_lr)
        self.vf_opt = optim.Adam(self.critic.parameters(), lr=cfg.critic_lr)

        self.buffer = TrajectoryBuffer(self.device)
        self.n_actions = n_actions

    @torch.no_grad()
    def act(self, obs):
        """Sample an action for a single observation.

        Returns:
            ``(action, logp, value)``: the sampled action as ``int``, its
            log-probability as a tensor, and the critic's estimate as ``float``.
        """
        obs_t = to_tensor(obs, self.device).unsqueeze(0)  # add batch dimension
        dist = torch.distributions.Categorical(logits=self.policy(obs_t))
        action = dist.sample()
        logp = dist.log_prob(action).squeeze(0)
        value = self.critic(obs_t).squeeze(0)
        return int(action.item()), logp, float(value.item())

    def compute_returns_and_advantages(self, rewards, dones, values, gamma):
        """Compute discounted returns and normalised advantages for a batch.

        The batch is a flat concatenation of episodes; a true ``dones[t]``
        marks the last step of an episode, so the running return restarts
        there and no reward leaks across episode boundaries.

        Args:
            rewards: 1-D tensor of per-step rewards.
            dones: 1-D bool tensor, true on each episode's final step.
            values: 1-D tensor of value estimates aligned with ``rewards``.
            gamma: Discount factor.

        Returns:
            ``(returns, advantages)``, where ``advantages`` is
            ``returns - values`` shifted to zero mean and scaled to unit
            (population) standard deviation.
        """
        # Accumulate in plain Python floats rather than indexing tensors
        # element by element inside the loop.
        running = 0.0
        discounted = []
        for reward, done in zip(reversed(rewards.tolist()), reversed(dones.tolist())):
            if done:
                running = 0.0
            running = reward + gamma * running
            discounted.append(running)
        discounted.reverse()
        returns = torch.as_tensor(discounted, dtype=rewards.dtype, device=self.device)

        advantages = returns - values.detach()
        # Normalize advantages for stability; epsilon guards a zero std.
        adv_mean = advantages.mean()
        adv_std = advantages.std(unbiased=False) + 1e-8
        advantages = (advantages - adv_mean) / adv_std
        return returns, advantages

    def reflect_and_update(self):
        """Replay the buffered batch to update critic and policy, then clear it.

        Does nothing when the buffer holds no steps.
        """
        if not self.buffer.states:
            # Nothing collected: stacking empty lists below would raise.
            return

        states, actions, rewards, dones, _, values = self.buffer.as_tensors()
        returns, _ = self.compute_returns_and_advantages(
            rewards, dones, values, self.cfg.gamma
        )

        # Critic reflection: multiple epochs fitting V to returns.
        for _ in range(self.cfg.reflection_epochs_critic):
            v_loss = torch.nn.functional.mse_loss(self.critic(states), returns)
            self.vf_opt.zero_grad()
            v_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.cfg.grad_clip)
            self.vf_opt.step()

        # Second look: advantages from the refreshed critic. The returns depend
        # only on rewards/dones, so the ones computed above are reused.
        with torch.no_grad():
            values_refined = self.critic(states)
        _, advantages_refined = self.compute_returns_and_advantages(
            rewards, dones, values_refined, self.cfg.gamma
        )

        # Policy reflection: multiple epochs using the refined advantages.
        # The value loss is part of the combined loss, so the critic keeps
        # training here too (through its own optimizer).
        for _ in range(self.cfg.reflection_epochs_policy):
            dist = torch.distributions.Categorical(logits=self.policy(states))
            logps = dist.log_prob(actions)
            entropy = dist.entropy().mean()

            policy_loss = -(logps * advantages_refined).mean()
            value_loss = torch.nn.functional.mse_loss(self.critic(states), returns)
            loss = (
                policy_loss
                + self.cfg.value_coef * value_loss
                - self.cfg.entropy_coef * entropy
            )

            self.pi_opt.zero_grad()
            self.vf_opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.cfg.grad_clip)
            torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.cfg.grad_clip)
            self.pi_opt.step()
            self.vf_opt.step()

        # Clear memory after reflection.
        self.buffer.clear()

    def save(self, path_prefix):
        """Write ``<path_prefix>_policy.pt`` and ``<path_prefix>_critic.pt``.

        The parent directory is created when ``path_prefix`` has one.
        """
        parent = os.path.dirname(path_prefix)
        if parent:
            # os.makedirs("") raises, so skip it for a bare file prefix.
            os.makedirs(parent, exist_ok=True)
        torch.save(self.policy.state_dict(), f"{path_prefix}_policy.pt")
        torch.save(self.critic.state_dict(), f"{path_prefix}_critic.pt")

    def load(self, path_prefix):
        """Load both state dicts written by ``save`` onto ``self.device``."""
        self.policy.load_state_dict(
            torch.load(f"{path_prefix}_policy.pt", map_location=self.device)
        )
        self.critic.load_state_dict(
            torch.load(f"{path_prefix}_critic.pt", map_location=self.device)
        )


# -----------------------------
# Training loop
# -----------------------------

def _reset_env(env, seed):
    """Reset ``env`` and return ``(obs, info)`` across gym/gymnasium versions."""
    try:
        result = env.reset(seed=seed)
    except TypeError:
        # Older gym releases do not accept a ``seed`` keyword on reset().
        result = env.reset()
    # Newer APIs return (obs, info); older ones return the observation alone.
    if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
        return result
    return result, {}


def _step_env(env, action):
    """Step ``env`` and return ``(next_obs, reward, done)`` for 4- or 5-tuple APIs."""
    step_out = env.step(action)
    if len(step_out) == 5:
        next_obs, reward, terminated, truncated, _ = step_out
        return next_obs, float(reward), bool(terminated or truncated)
    next_obs, reward, done, _ = step_out
    return next_obs, float(reward), bool(done)


def _crossed_interval(before, after, interval):
    """Return True when a multiple of ``interval`` lies in ``(before, after]``."""
    return interval > 0 and after // interval > before // interval


def train(cfg: Config):
    """Train a ``SelfReflectiveAgent`` on ``cfg.env_id`` and return it.

    Episodes are collected in batches of ``cfg.episodes_per_batch`` (the last
    batch may be shorter); after each batch the agent replays its buffer via
    ``reflect_and_update``. Progress is printed and checkpoints are written
    whenever the episode counter passes a multiple of ``cfg.log_interval`` /
    ``cfg.checkpoint_every``; an interval of 0 or less disables that output.

    Args:
        cfg: Run configuration.

    Returns:
        The trained agent.

    Raises:
        ValueError: If ``cfg.episodes_per_batch`` is smaller than 1.
    """
    from collections import deque  # function-scope import: only used here

    if cfg.episodes_per_batch < 1:
        raise ValueError("episodes_per_batch must be at least 1")

    set_seed(cfg.seed)
    env = gym.make(cfg.env_id)
    try:
        obs_dim = env.observation_space.shape[0]
        n_actions = env.action_space.n

        agent = SelfReflectiveAgent(obs_dim, n_actions, cfg)

        ep_count = 0
        best_avg = -math.inf
        rolling = deque(maxlen=50)  # returns of the most recent 50 episodes

        while ep_count < cfg.max_episodes:
            # Collect a batch of episodes.
            batch_start = ep_count
            batch_end = min(batch_start + cfg.episodes_per_batch, cfg.max_episodes)
            # Kept locally because reflect_and_update() clears the buffer,
            # including its per-episode returns, before logging happens.
            batch_returns = []

            while ep_count < batch_end:
                # A distinct seed per episode keeps resets reproducible.
                obs, _ = _reset_env(env, cfg.seed + ep_count)
                done = False
                ep_return = 0.0

                while not done:
                    action, logp, value = agent.act(obs)
                    next_obs, reward, done = _step_env(env, action)
                    agent.buffer.store_step(
                        state=to_tensor(obs, agent.device),
                        action=action,
                        reward=reward,
                        done=done,
                        logp=logp,
                        value=value,
                    )
                    obs = next_obs
                    ep_return += reward

                agent.buffer.finalize_episode(ep_return)
                batch_returns.append(ep_return)
                rolling.append(ep_return)
                ep_count += 1

            # Reflection and updates over the collected batch.
            agent.reflect_and_update()

            # Logging. ep_count advances a whole batch at a time, so test for
            # crossing an interval boundary rather than landing exactly on it.
            avg_return = sum(rolling) / len(rolling) if rolling else 0.0
            if _crossed_interval(batch_start, ep_count, cfg.log_interval):
                last_batch_mean = np.mean(batch_returns) if batch_returns else float("nan")
                print(f"[{ep_count:4d}] avg_return(50)={avg_return:7.2f}  last_batch_mean={last_batch_mean}")

            # Checkpointing.
            if _crossed_interval(batch_start, ep_count, cfg.checkpoint_every):
                agent.save(os.path.join(cfg.out_dir, f"ep{ep_count}"))
                # Track best by rolling average (evaluated at checkpoint time).
                if avg_return > best_avg:
                    best_avg = avg_return
                    agent.save(os.path.join(cfg.out_dir, "best"))
    finally:
        # Release the environment even if training raises.
        env.close()

    print("Training complete.")
    return agent


if __name__ == "__main__":
    # Script entry point: default config, GPU when one is visible to torch.
    cfg = Config()
    cfg.device = "cuda" if torch.cuda.is_available() else cfg.device
    # Make sure the checkpoint directory exists before training starts.
    os.makedirs(cfg.out_dir, exist_ok=True)
    trained_agent = train(cfg)