<a href="https://colab.research.google.com/github/OneFineStarstuff/Pinn/blob/main/agi_self_reflect_cartpole_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
"""
AGIAgent: environment-ready, memory-augmented, self-reflective RL on CartPole-v1

What this does
- Policy + value networks collect episodes in CartPole.
- Experiences are encoded and written into a simple differentiable memory.
- A lightweight "reasoner" attends over memory for queries.
- A planner wrapper selects actions from the policy (extensible to lookahead).
- After each batch, the agent "reflects" by replaying memory to refine the critic,
  then updates the policy with refined advantages.

Safe defaults
- Seeding, gradient clipping, entropy bonus, normalized advantages, checkpoints.

Run
  python agi_self_reflect_cartpole.py
"""

import os
import math
import time
import random
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Gym import with fallback
try:
    import gymnasium as gym
except ImportError:
    import gym


# -----------------------------
# Config and utilities
# -----------------------------

@dataclass
class Config:
    """Hyperparameters and run settings for training; every field has a default."""
    env_id: str = "CartPole-v1"        # environment id handed to gym.make() in main()
    max_episodes: int = 400            # total number of episodes to collect
    episodes_per_batch: int = 8        # episodes gathered before each reflect/update step
    gamma: float = 0.99                # discount factor for the return computation
    policy_lr: float = 3e-4            # Adam learning rate for the policy network
    critic_lr: float = 5e-4            # Adam learning rate for the value network
    entropy_coef: float = 0.01         # weight of the entropy bonus subtracted from the loss
    value_coef: float = 0.5            # weight of the value loss in the joint policy/value loss
    grad_clip: float = 0.5             # max gradient norm passed to clip_grad_norm_
    reflection_epochs_critic: int = 4  # critic-only regression passes per batch
    reflection_epochs_policy: int = 2  # joint policy/value passes per batch
    hidden_sizes: tuple = (128, 128)   # hidden layer widths of the policy and value MLPs
    seed: int = 42                     # base seed for set_seed(); also offset per episode for env.reset
    device: str = "cpu"  # torch device string; main() switches it to "cuda" when CUDA is available
    log_interval: int = 10             # progress-print period, measured in episodes
    checkpoint_every: int = 100        # checkpoint period, measured in episodes
    out_dir: str = "checkpoints_agi"   # directory that receives the checkpoint files

def set_seed(seed: int):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) and pin cuDNN flags.

    cuDNN is switched to deterministic mode with autotuning (benchmark) disabled.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def to_tensor(x, device):
    """Convert *x* to a float32 tensor placed on *device*."""
    tensor_kwargs = {"dtype": torch.float32, "device": device}
    return torch.as_tensor(x, **tensor_kwargs)


# -----------------------------
# Models and encoders
# -----------------------------

class MLP(nn.Module):
    """Feed-forward stack: (Linear -> activation) per hidden size, then a final Linear."""

    def __init__(self, in_dim, out_dim, hidden_sizes=(128, 128), act=nn.Tanh):
        super().__init__()
        widths = [in_dim, *hidden_sizes]
        modules = []
        # Pair consecutive widths to get each hidden layer's (fan_in, fan_out).
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.extend((nn.Linear(fan_in, fan_out), act()))
        # Output projection from the last hidden width (or in_dim when there are no hidden layers).
        modules.append(nn.Linear(widths[-1], out_dim))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        out = self.net(x)
        return out

class PolicyNet(nn.Module):
    """Maps observations to unnormalized action scores (logits) with a Tanh MLP."""

    def __init__(self, obs_dim, n_actions, hidden_sizes=(128, 128)):
        super().__init__()
        self.backbone = MLP(
            in_dim=obs_dim,
            out_dim=n_actions,
            hidden_sizes=hidden_sizes,
            act=nn.Tanh,
        )

    def forward(self, obs):
        # One logit per action; callers wrap these in a Categorical distribution.
        logits = self.backbone(obs)
        return logits

class ValueNet(nn.Module):
    """Maps observations to a scalar value estimate with a Tanh MLP."""

    def __init__(self, obs_dim, hidden_sizes=(128, 128)):
        super().__init__()
        self.backbone = MLP(
            in_dim=obs_dim,
            out_dim=1,
            hidden_sizes=hidden_sizes,
            act=nn.Tanh,
        )

    def forward(self, obs):
        # The MLP emits a trailing dimension of size 1; drop it.
        raw = self.backbone(obs)
        return torch.squeeze(raw, dim=-1)

class Encoder(nn.Module):
    """Projects inputs to an `emb_dim`-sized embedding using a ReLU MLP."""

    def __init__(self, in_dim, emb_dim=64, hidden_sizes=(64,)):
        super().__init__()
        self.proj = MLP(
            in_dim=in_dim,
            out_dim=emb_dim,
            hidden_sizes=hidden_sizes,
            act=nn.ReLU,
        )

    def forward(self, x):
        embedding = self.proj(x)
        return embedding


# -----------------------------
# Simple differentiable memory
# -----------------------------

class DifferentiableMemory(nn.Module):
    """
    Content-addressable memory:
    - write: encode the experience's 'state' -> embedding, append to the memory bank
    - read: rank stored embeddings by cosine similarity to a query embedding

    Stored embeddings are detached snapshots taken under no_grad, so no gradient
    flows back through writes; gradients can only flow through the query passed
    to `read`.

    Args:
        encoder: module mapping a [1, obs_dim] float tensor to a [1, D] embedding.
        device: device on which states are encoded; it must be the device the
            encoder's parameters live on.
    """
    def __init__(self, encoder: nn.Module, device="cpu"):
        super().__init__()
        self.encoder = encoder
        self.device = torch.device(device)
        self.clear()

    def clear(self):
        """Drop every stored embedding and payload."""
        self.embeds = []      # list of tensors [D]
        self.payloads = []    # raw experiences, index-aligned with self.embeds

    @torch.no_grad()
    def write(self, experience: dict):
        """
        Encode `experience['state']` and append it (with the raw experience) to memory.

        The state may be any tensor-like value. It is always converted to float32
        on the memory's device, so a CPU tensor can be stored even when the encoder
        lives on an accelerator.

        Raises:
            ValueError: if the experience has no 'state' entry.
        """
        state = experience.get("state")
        if state is None:
            raise ValueError("experience must contain a 'state' entry")
        # as_tensor is a no-op for tensors already matching dtype/device, and otherwise
        # converts; without this, CPU payload tensors hit a device mismatch in the encoder.
        state = torch.as_tensor(state, dtype=torch.float32, device=self.device)
        emb = self.encoder(state.unsqueeze(0)).squeeze(0).detach()
        self.embeds.append(emb)
        self.payloads.append(experience)

    def read(self, query_embedding: torch.Tensor, topk=3):
        """
        Return the `topk` most similar stored experiences and their cosine similarities.

        Args:
            query_embedding: 1-D embedding of size D.
            topk: maximum number of matches; capped at the number of stored items.

        Returns:
            (payloads, similarities), best match first. An empty memory yields
            ([], empty tensor).
        """
        if not self.embeds:
            return [], torch.empty(0, device=self.device)
        E = torch.stack(self.embeds, dim=0)              # [N, D]
        q = query_embedding.to(E.device).unsqueeze(0)    # [1, D], aligned with the bank's device
        # cosine similarity
        E_norm = torch.nn.functional.normalize(E, dim=-1)
        q_norm = torch.nn.functional.normalize(q, dim=-1)
        sims = (E_norm @ q_norm.t()).squeeze(-1)  # [N]
        vals, idx = torch.topk(sims, k=min(topk, E.shape[0]))
        results = [self.payloads[i] for i in idx.tolist()]
        return results, vals


# -----------------------------
# Reasoner and planner
# -----------------------------

class SimpleReasoner:
    """
    Queries memory for the top-k matches and condenses them into a small report.
    Extendable to structured logic or graph query.
    """
    def __init__(self, memory: DifferentiableMemory):
        self.memory = memory

    def __call__(self, query_embedding, topk=3):
        """
        Build a report with the match count, the similarity weights, and a
        per-field view of the best match (non-primitive values shown as "tensor").
        """
        matches, scores = self.memory.read(query_embedding, topk=topk)

        printable = (int, float, str)
        samples = []
        if matches:
            # Only the best match is expanded, one single-key dict per field.
            for field, value in matches[0].items():
                shown = value if isinstance(value, printable) else "tensor"
                samples.append({field: shown})

        return {
            "count": len(matches),
            "weights": scores.detach().cpu().tolist(),
            "samples": samples,
        }

class PolicyPlanner:
    """
    Thin planning layer: samples an action from the temperature-scaled policy
    distribution. Intended as a drop-in point for MCTS/model-based lookahead.
    """
    def __init__(self, policy: PolicyNet, temperature: float = 1.0, device="cpu"):
        self.policy = policy
        # Floor the temperature at 1e-3 so the logits are never divided by ~0.
        requested = float(temperature)
        self.temperature = requested if requested > 1e-3 else 1e-3
        self.device = torch.device(device)

    @torch.no_grad()
    def __call__(self, state_tensor: torch.Tensor):
        """Sample one action (Python int) for a single, unbatched state."""
        batched_state = state_tensor.unsqueeze(0)
        scaled_logits = self.policy(batched_state) / self.temperature
        sampler = torch.distributions.Categorical(logits=scaled_logits)
        return sampler.sample().item()


# -----------------------------
# Self-reflective RL core
# -----------------------------

class TrajectoryBuffer:
    """Per-step rollout storage plus a list of finished-episode returns."""

    def __init__(self, device):
        self.device = device
        self.clear()

    def clear(self):
        """Reset every per-step column and the episode-return list."""
        self.states, self.actions, self.rewards = [], [], []
        self.dones, self.logps, self.values = [], [], []
        self.episode_returns = []

    def store_step(self, state, action, reward, done, logp, value):
        """Append one transition; all columns stay index-aligned."""
        columns_and_items = (
            (self.states, state),
            (self.actions, action),
            (self.rewards, reward),
            (self.dones, done),
            (self.logps, logp),
            (self.values, value),
        )
        for column, item in columns_and_items:
            column.append(item)

    def finalize_episode(self, ep_return):
        """Record the total return of an episode that just ended."""
        self.episode_returns.append(ep_return)

    def as_tensors(self):
        """
        Return (states, actions, rewards, dones, logps, values) as tensors on
        the buffer's device. States and logps are stacked from stored tensors;
        the remaining columns are built from Python scalars.
        """
        target = self.device

        def column(data, dtype):
            return torch.as_tensor(data, dtype=dtype, device=target)

        return (
            torch.stack(self.states, dim=0).to(target),
            column(self.actions, torch.int64),
            column(self.rewards, torch.float32),
            column(self.dones, torch.bool),
            torch.stack(self.logps, dim=0).to(target),
            column(self.values, torch.float32),
        )

class SelfReflectiveRL:
    """
    Actor-critic core with a two-stage "reflection" update.

    After a batch of episodes is stored in `self.buffer`, `reflect_and_update`
    first regresses the critic onto Monte-Carlo returns, then recomputes the
    advantages with the refined critic and takes policy-gradient steps (with an
    auxiliary value loss and an entropy bonus).
    """
    def __init__(self, obs_dim, n_actions, cfg: Config):
        self.cfg = cfg
        self.device = torch.device(cfg.device)
        self.policy = PolicyNet(obs_dim, n_actions, cfg.hidden_sizes).to(self.device)
        self.critic = ValueNet(obs_dim, cfg.hidden_sizes).to(self.device)
        self.pi_opt = optim.Adam(self.policy.parameters(), lr=cfg.policy_lr)
        self.vf_opt = optim.Adam(self.critic.parameters(), lr=cfg.critic_lr)
        self.buffer = TrajectoryBuffer(self.device)

    @torch.no_grad()
    def act(self, obs):
        """
        Sample an action for a single observation.

        Returns:
            (action as int, log-prob of that action as a 0-d tensor, value estimate as float).
        """
        obs_t = to_tensor(obs, self.device).unsqueeze(0)
        logits = self.policy(obs_t)
        dist = torch.distributions.Categorical(logits=logits)
        action = dist.sample()
        logp = dist.log_prob(action).squeeze(0)
        value = self.critic(obs_t).squeeze(0)
        return int(action.item()), logp, float(value.item())

    def _discounted_returns(self, rewards, dones, gamma):
        """Monte-Carlo discounted returns; the running sum restarts at every done flag."""
        returns = torch.zeros_like(rewards, device=self.device)
        # Read the flags into Python once instead of pulling one device boolean per step.
        done_flags = dones.tolist() if torch.is_tensor(dones) else list(dones)
        G = 0.0
        for t in reversed(range(len(rewards))):
            if done_flags[t]:
                # Step t ends an episode, so nothing from later steps leaks into it.
                G = 0.0
            G = rewards[t] + gamma * G
            returns[t] = G
        return returns

    @staticmethod
    def _normalized_advantages(returns, values):
        """Advantages (returns - values), standardized to zero mean and unit std over the batch."""
        advantages = returns - values.detach()
        return (advantages - advantages.mean()) / (advantages.std(unbiased=False) + 1e-8)

    def compute_returns_and_advantages(self, rewards, dones, values, gamma):
        """
        Compute discounted returns and batch-normalized advantages.

        Args:
            rewards: 1-D float tensor of per-step rewards for the whole batch.
            dones: per-step flags marking the last step of each episode.
            values: critic estimates aligned with `rewards`.
            gamma: discount factor.

        Returns:
            (returns, advantages), both shaped like `rewards`.
        """
        returns = self._discounted_returns(rewards, dones, gamma)
        return returns, self._normalized_advantages(returns, values)

    def reflect_and_update(self):
        """
        Run the critic and policy reflection passes on the buffered batch, then clear the buffer.

        Returns:
            dict with 'batch_ep_mean' (mean episode return, NaN if no episode was
            finalized) and 'batch_ep_count'.
        """
        episode_returns = self.buffer.episode_returns
        stats = {
            "batch_ep_mean": float(np.mean(episode_returns)) if episode_returns else float("nan"),
            "batch_ep_count": len(episode_returns),
        }
        if not self.buffer.states:
            # No transitions were stored: stacking an empty list would raise, so skip the update.
            self.buffer.clear()
            return stats

        # The stored log-probs and values are not needed here: the update is an
        # on-policy gradient (no importance ratio) and values are re-estimated below.
        states, actions, rewards, dones, _logps_old, _values = self.buffer.as_tensors()

        # Returns depend only on rewards/dones, so they are computed once and reused.
        returns = self._discounted_returns(rewards, dones, self.cfg.gamma)

        # Critic reflection
        for _ in range(self.cfg.reflection_epochs_critic):
            v_pred = self.critic(states)
            v_loss = torch.nn.functional.mse_loss(v_pred, returns)
            self.vf_opt.zero_grad()
            v_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.cfg.grad_clip)
            self.vf_opt.step()

        # Recompute advantages with the refined critic
        with torch.no_grad():
            values_refined = self.critic(states)
        advantages_refined = self._normalized_advantages(returns, values_refined)

        # Policy reflection (+ small auxiliary value fit and entropy)
        for _ in range(self.cfg.reflection_epochs_policy):
            logits = self.policy(states)
            dist = torch.distributions.Categorical(logits=logits)
            logps = dist.log_prob(actions)
            entropy = dist.entropy().mean()

            policy_loss = -(logps * advantages_refined).mean()
            value_loss = torch.nn.functional.mse_loss(self.critic(states), returns)
            loss = policy_loss + self.cfg.value_coef * value_loss - self.cfg.entropy_coef * entropy

            self.pi_opt.zero_grad()
            self.vf_opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.cfg.grad_clip)
            torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.cfg.grad_clip)
            self.pi_opt.step()
            self.vf_opt.step()

        self.buffer.clear()
        return stats


# -----------------------------
# Optional: lightweight MAML stub
# -----------------------------

class MAMLStub:
    """
    Stand-in for a meta-learner: yields a small supervised loss on embeddings.
    A real MAML implementation can replace it without touching AGIAgent.
    """
    def __init__(self, encoder: Encoder, lr_inner=0.01):
        self.lr_inner = lr_inner
        self.encoder = encoder

    def __call__(self, task_batch: torch.Tensor):
        """Return the mean squared distance between the batch embeddings and an all-zero target."""
        embeddings = self.encoder(task_batch)
        residual = embeddings - torch.zeros_like(embeddings)
        return residual.pow(2).mean()


# -----------------------------
# AGIAgent orchestration
# -----------------------------

class AGIAgent:
    """
    Ties together the RL core, an encoder-backed memory, a reasoner and a planner.

    The planner chooses the action that is executed in the environment. The
    log-prob and value recorded for learning are evaluated for that same action.
    """
    def __init__(self, obs_dim, n_actions, cfg: Config):
        self.cfg = cfg
        self.device = torch.device(cfg.device)

        # Core RL
        self.rl = SelfReflectiveRL(obs_dim, n_actions, cfg)

        # Shared encoder for memory/reasoning (decoupled from policy/value)
        self.encoder = Encoder(in_dim=obs_dim, emb_dim=64, hidden_sizes=(64,)).to(self.device)

        # Memory + reasoner + planner
        self.memory = DifferentiableMemory(self.encoder, device=self.device)
        self.reasoner = SimpleReasoner(self.memory)
        self.planner = PolicyPlanner(self.rl.policy, temperature=1.0, device=self.device)

        # Optional meta-learner on encoded states (stub)
        self.maml = MAMLStub(self.encoder, lr_inner=0.01)

    def encode(self, x):
        """Embed a tensor-like input with the shared encoder."""
        x_t = x if torch.is_tensor(x) else to_tensor(x, self.device)
        return self.encoder(x_t)

    def store_experience(self, experience: dict):
        """Write one experience dict (must carry a 'state') into memory."""
        self.memory.write(experience)

    def reason_about(self, query_vec):
        """
        Embed the query and return the reasoner's report on the closest memories.

        A 1-D query is embedded directly; a batched query is embedded row-wise
        and averaged into a single query embedding.
        """
        if not torch.is_tensor(query_vec):
            query_vec = to_tensor(query_vec, self.device)
        if query_vec.ndim == 1:
            q_emb = self.encoder(query_vec)
        else:
            q_emb = self.encoder(query_vec).mean(dim=0)
        return self.reasoner(q_emb, topk=3)

    def plan_action(self, state_tensor: torch.Tensor):
        """Ask the planner for the action to execute in the given state."""
        return self.planner(state_tensor)

    def learn_meta(self, task_batch: torch.Tensor):
        """
        Compute the meta-learner loss and back-propagate it.

        Only gradients are populated; no optimizer step is taken here, so the
        caller must step (and zero) an optimizer for this to change any weights.
        """
        loss = self.maml(task_batch)
        loss.backward()
        return float(loss.item())

    @staticmethod
    def _reset_env(env, seed):
        """
        Reset `env` and always return an (obs, info) pair.

        Handles both reset conventions: the newer one returning (obs, info) and
        the legacy one returning only obs. Detection is based on the returned
        value rather than on the environment's module name.
        """
        try:
            out = env.reset(seed=seed)
        except TypeError:
            # reset() without a `seed` keyword (legacy API): fall back to an unseeded reset.
            out = env.reset()
        if isinstance(out, tuple) and len(out) == 2 and isinstance(out[1], dict):
            return out
        return out, {}

    @staticmethod
    def _crossed_interval(prev_count, new_count, interval):
        """True when a multiple of `interval` lies in (prev_count, new_count]; False if interval <= 0."""
        return interval > 0 and new_count // interval > prev_count // interval

    @torch.no_grad()
    def _evaluate_action(self, state_tensor, action):
        """Return (log-prob of `action` under the RL policy as a 0-d tensor, critic value as float)."""
        obs_t = state_tensor.unsqueeze(0)
        dist = torch.distributions.Categorical(logits=self.rl.policy(obs_t))
        action_t = torch.as_tensor([action], dtype=torch.int64, device=obs_t.device)
        logp = dist.log_prob(action_t).squeeze(0)
        value = float(self.rl.critic(obs_t).squeeze(0).item())
        return logp, value

    def interact_and_learn(self, env, max_episodes=None):
        """
        Collect episodes in batches, reflect/update after each batch, and log and checkpoint.

        Args:
            env: a Gym/Gymnasium-style environment (4- or 5-tuple `step`, either reset convention).
            max_episodes: episode budget; falls back to `cfg.max_episodes` when None or 0.
        """
        cfg = self.cfg
        max_episodes = max_episodes or cfg.max_episodes
        # A non-positive batch size would never advance the episode counter.
        episodes_per_batch = max(1, cfg.episodes_per_batch)

        ep_count = 0
        rolling = []
        best_avg = -math.inf
        os.makedirs(cfg.out_dir, exist_ok=True)

        while ep_count < max_episodes:
            # Collect a batch of episodes
            batch_returns = []
            batch_start = ep_count

            while ep_count - batch_start < episodes_per_batch and ep_count < max_episodes:
                obs, info = self._reset_env(env, seed=cfg.seed + ep_count)
                done = False
                ep_return = 0.0

                while not done:
                    state_t = to_tensor(obs, self.device)
                    # Plan (wrapper over policy)
                    action = self.plan_action(state_t)

                    # Score the action that will actually be executed, so the stored
                    # log-prob/value describe this transition and not a second sample.
                    logp, value = self._evaluate_action(state_t, action)

                    # Step environment
                    step_out = env.step(action)
                    if len(step_out) == 5:
                        next_obs, reward, terminated, truncated, info = step_out
                        done = bool(terminated or truncated)
                    else:
                        next_obs, reward, done, info = step_out
                        done = bool(done)
                    reward = float(reward)

                    # Store into RL buffer (with the action we executed via planner)
                    self.rl.buffer.store_step(
                        state=state_t,
                        action=action,
                        reward=reward,
                        done=done,
                        logp=logp,
                        value=value,
                    )

                    # Write to external memory
                    self.store_experience({
                        "state": state_t.detach().cpu(),
                        "action": int(action),
                        "reward": reward,
                        "done": done,
                    })

                    obs = next_obs
                    ep_return += reward

                self.rl.buffer.finalize_episode(ep_return)
                batch_returns.append(ep_return)
                rolling.append(ep_return)
                if len(rolling) > 50:
                    rolling.pop(0)
                ep_count += 1

            # Reflection and updates over the collected batch
            reflect_stats = self.rl.reflect_and_update()

            # Logging. ep_count advances a whole batch at a time, so fire whenever the
            # batch crossed a multiple of the interval rather than on exact divisibility.
            avg_return = float(np.mean(rolling)) if rolling else 0.0
            last_batch_mean = float(np.mean(batch_returns)) if batch_returns else float("nan")
            if self._crossed_interval(batch_start, ep_count, cfg.log_interval):
                print(f"[{ep_count:4d}] avg_return(50)={avg_return:7.2f}  last_batch_mean={last_batch_mean:7.2f}  batch_ep_count={reflect_stats['batch_ep_count']}")

            # Checkpointing (same crossing rule as logging)
            if self._crossed_interval(batch_start, ep_count, cfg.checkpoint_every):
                torch.save(self.rl.policy.state_dict(), os.path.join(cfg.out_dir, f"ep{ep_count}_policy.pt"))
                torch.save(self.rl.critic.state_dict(), os.path.join(cfg.out_dir, f"ep{ep_count}_critic.pt"))
                if avg_return > best_avg:
                    best_avg = avg_return
                    torch.save(self.rl.policy.state_dict(), os.path.join(cfg.out_dir, "best_policy.pt"))
                    torch.save(self.rl.critic.state_dict(), os.path.join(cfg.out_dir, "best_critic.pt"))

        print("Training complete.")


# -----------------------------
# Entrypoint
# -----------------------------

def main():
    """Build the config, environment and agent, run a meta warm-up, then train.

    The environment is closed even if setup or training raises.
    """
    cfg = Config()
    if torch.cuda.is_available():
        cfg.device = "cuda"
    set_seed(cfg.seed)

    env = gym.make(cfg.env_id)
    try:
        obs_dim = env.observation_space.shape[0]
        n_actions = env.action_space.n

        agent = AGIAgent(obs_dim, n_actions, cfg)

        # Optional: quick meta-learning warmup on random "tasks"
        task_batch = torch.randn(16, obs_dim, device=cfg.device)
        agent.learn_meta(task_batch)

        # Train in the environment with self-reflection and memory
        agent.interact_and_learn(env, max_episodes=cfg.max_episodes)
    finally:
        # Release environment resources on both the success and the failure path.
        env.close()

# Run training only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()