<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/unified_ai_agent_policy_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# unified_ai/agent/policy.py
from __future__ import annotations
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F

@dataclass
class PolicyConfig:
    """Hyperparameters consumed by ``QNetwork`` and ``AgentPolicy``.

    All fields are required and positional, in the order declared below.
    """

    # Size of the Q-network's input layer (features per observation).
    obs_dim: int
    # Width of each of the Q-network's two hidden layers.
    hidden_dim: int
    # Size of the Q-network's output layer (one Q-value per action index).
    actions: int
    # Learning rate handed to the Adam optimizer.
    lr: float
    # Discount factor multiplying the next state's maximum Q-value in the
    # temporal-difference target.
    gamma: float
    # Initial probability of taking a random action instead of the greedy one.
    epsilon: float
    # Floor that ``AgentPolicy.decay_epsilon`` will not go below.
    epsilon_min: float
    # Multiplicative factor applied to epsilon on each ``decay_epsilon`` call.
    epsilon_decay: float

class QNetwork(nn.Module):
    """Fully connected network that maps an observation to one value per action.

    The stack is Linear -> ReLU -> Linear -> ReLU -> Linear, stored in
    ``self.net`` as an ``nn.Sequential`` (sub-module indices 0 through 4).
    """

    def __init__(self, obs_dim: int, hidden: int, actions: int):
        """Build the layer stack.

        Args:
            obs_dim: Number of input features.
            hidden: Width of both hidden layers.
            actions: Number of output values.
        """
        super().__init__()
        # Consecutive pairs of widths give the (in, out) sizes of each Linear.
        widths = (obs_dim, hidden, hidden, actions)
        stack = []
        for position, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            if position > 0:
                # A non-linearity sits between every pair of Linear layers,
                # but not after the final one.
                stack.append(nn.ReLU())
            stack.append(nn.Linear(fan_in, fan_out))
        self.net = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the network output for ``x`` (last dim must be ``obs_dim``)."""
        values = self.net(x)
        return values

class AgentPolicy:
    """Epsilon-greedy Q-learning agent backed by a single ``QNetwork``.

    Actions seen by callers are offset by -1 from the network's output index:
    output index ``i`` corresponds to action ``i - 1``. With
    ``cfg.actions == 3`` the action set is therefore ``{-1, 0, 1}``.
    """

    # Ceiling that ``boost_epsilon`` will not raise epsilon above.
    _BOOST_CAP = 0.9

    def __init__(self, cfg: PolicyConfig, device: torch.device):
        """Create the Q-network and its Adam optimizer on ``device``.

        Args:
            cfg: Hyperparameters (network sizes, learning rate, discount and
                exploration schedule).
            device: Device that holds the network; inputs are moved to it
                before each forward pass.
        """
        self.cfg = cfg
        self.device = device
        self.q = QNetwork(cfg.obs_dim, cfg.hidden_dim, cfg.actions).to(device)
        self.opt = torch.optim.Adam(self.q.parameters(), lr=cfg.lr)
        self.epsilon = cfg.epsilon

    def select_action(self, obs: torch.Tensor) -> int:
        """Pick an action for a single, unbatched observation.

        With probability ``self.epsilon`` a uniformly random action is
        returned; otherwise the action with the highest Q-value.

        Args:
            obs: One observation; a leading batch dimension is added here.

        Returns:
            The chosen action, i.e. the network output index minus one.
        """
        if torch.rand(()) < self.epsilon:
            # Sample over every network output rather than a fixed count of 3,
            # so exploration covers the same action set as the greedy branch.
            a_idx = int(torch.randint(0, self.cfg.actions, (1,)).item())
            return a_idx - 1
        with torch.no_grad():
            q = self.q(obs.to(self.device).unsqueeze(0))
            a_idx = int(q.argmax(dim=-1).item())
        return a_idx - 1

    def update(self, s: torch.Tensor, a: int, r: float, s2: torch.Tensor, done: bool) -> float:
        """Run one temporal-difference update on a single transition.

        The target is ``r`` when ``done`` is true, otherwise
        ``r + gamma * max_a' Q(s2, a')`` computed with the same network and
        without gradient tracking. The loss is smooth L1 (Huber).

        Args:
            s: Observation before the action (unbatched).
            a: Action taken, in the offset convention (output index minus one).
            r: Reward received.
            s2: Observation after the action (unbatched); ignored when ``done``.
            done: Whether the transition ended the episode.

        Returns:
            The scalar loss value as a Python float.

        Raises:
            IndexError: If ``a`` does not map to a valid network output.
        """
        a_idx = a + 1  # undo the -1 offset applied by select_action
        if not 0 <= a_idx < self.cfg.actions:
            # Without this check a negative index would silently wrap around
            # and train the wrong action's Q-value.
            raise IndexError(
                f"action {a} is outside the valid range [-1, {self.cfg.actions - 2}]"
            )
        q_sa = self.q(s.to(self.device).unsqueeze(0))[0, a_idx]
        with torch.no_grad():
            # Pin the dtype to the network's so an int or float64 reward does
            # not produce a target whose dtype differs from the prediction.
            target = torch.tensor(r, dtype=q_sa.dtype, device=self.device)
            if not done:
                q2 = self.q(s2.to(self.device).unsqueeze(0)).max(dim=-1).values[0]
                target = target + self.cfg.gamma * q2
        loss = F.smooth_l1_loss(q_sa, target)
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        return float(loss.item())

    def decay_epsilon(self) -> None:
        """Multiply epsilon by ``cfg.epsilon_decay``, not going below ``cfg.epsilon_min``."""
        self.epsilon = max(self.cfg.epsilon_min, self.epsilon * self.cfg.epsilon_decay)

    def boost_epsilon(self, amount: float) -> None:
        """Add ``amount`` to epsilon, capped at 0.9.

        If epsilon is already above the cap, the cap is not applied, so a
        boost never reduces exploration on its own.
        """
        ceiling = max(self._BOOST_CAP, self.epsilon)
        self.epsilon = min(ceiling, self.epsilon + amount)