<a href="https://colab.research.google.com/github/OneFineStarstuff/Pinn/blob/main/training_loop_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# provenance: Copilot x Kyaw, 2025-08-23
# purpose: memory- and reasoning-conditioned actor-critic with pluggable fusion (concat, FiLM, gated)
# deps: torch>=2.0

from dataclasses import dataclass
from typing import Optional, Tuple, Dict
import torch
import torch.nn as nn
import torch.nn.functional as F

# ---------- utils ----------

def mlp(in_dim, hidden, out_dim, act=nn.Tanh):
    """Build a feed-forward ``nn.Sequential`` of Linear layers with activations between them.

    Args:
        in_dim: input feature size.
        hidden: sizes of the hidden layers, in order. Any iterable of ints is
            accepted (list, tuple, generator, ...); an empty iterable yields a
            single Linear layer.
        out_dim: output feature size.
        act: activation *class* (not an instance). A fresh instance is inserted
            after every Linear except the last, so the output is un-activated.

    Returns:
        ``nn.Sequential`` laid out as ``Linear, act, Linear, act, ..., Linear``.
    """
    # Unpacking accepts any iterable; list concatenation only accepted lists.
    dims = [in_dim, *hidden, out_dim]
    layers = []
    # Every (in, out) pair except the final one gets an activation after it.
    for d_in, d_out in zip(dims[:-2], dims[1:-1]):
        layers += [nn.Linear(d_in, d_out), act()]
    layers.append(nn.Linear(dims[-2], dims[-1]))
    return nn.Sequential(*layers)

# ---------- interfaces ----------

class MemoryModule(nn.Module):
    """Fixed-capacity ring buffer of projected state features with top-k soft retrieval.

    ``write`` stores L2-normalised key projections and raw value projections of
    incoming features. ``retrieve`` scores a query's normalised key projection
    against the filled slots by dot product (i.e. cosine similarity, since both
    sides are normalised), keeps the ``k`` best matches, and returns their
    softmax-weighted value average.

    Note: ``ptr`` / ``full`` are plain Python attributes, so unlike the
    ``keys`` / ``vals`` buffers they are not included in ``state_dict()``.
    """

    def __init__(self, d_model: int, d_mem: int, capacity: int = 1024, k: int = 8):
        super().__init__()
        self.key_proj = nn.Linear(d_model, d_mem, bias=False)
        self.val_proj = nn.Linear(d_model, d_mem, bias=False)
        self.register_buffer("keys", torch.zeros(capacity, d_mem))
        self.register_buffer("vals", torch.zeros(capacity, d_mem))
        self.capacity = capacity
        self.k = k
        self.ptr = 0        # next slot to overwrite
        self.full = False   # True once every slot has been written at least once

    @torch.no_grad()
    def write(self, state_feat: torch.Tensor):
        """Append a batch of features (rows of ``state_feat``) to the ring buffer.

        The oldest slots are overwritten once the buffer has wrapped. If the
        batch is larger than ``capacity``, only its last ``capacity`` rows are
        stored, in the slots a row-by-row write would have left them in. Runs
        without autograd, so stored keys/values carry no gradient.
        """
        B = state_feat.size(0)
        if B == 0:
            return  # nothing to store; must not touch ptr/full
        n = min(B, self.capacity)
        rows = state_feat[B - n:]                      # rows that survive the wrap, oldest first
        K = F.normalize(self.key_proj(rows), dim=-1)   # (n, d_mem)
        V = self.val_proj(rows)                        # (n, d_mem)
        # rows[0] lands where it would in a sequential write; this also keeps
        # `pos` free of duplicates, so each slot is assigned exactly once.
        start = self.ptr + (B - n)
        pos = (start + torch.arange(n, device=K.device)) % self.capacity
        self.keys[pos] = K
        self.vals[pos] = V
        # Any write that reaches the end of the buffer fills it, even when the
        # pointer does not land exactly back on slot 0.
        if self.ptr + B >= self.capacity:
            self.full = True
        self.ptr = int((self.ptr + B) % self.capacity)

    def retrieve(self, query_feat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Soft-attend over the ``k`` stored keys most similar to each query.

        Args:
            query_feat: batch of query features, one row per query.

        Returns:
            ``(ctx, attn)``. ``ctx`` is ``(B, d_mem)``: the attention-weighted
            mean of the selected values. ``attn`` holds the softmax weights over
            the selected slots, ``(B, min(k, N))`` for ``N`` filled slots, or an
            all-zero ``(B, k)`` tensor (with an all-zero ``ctx``) when nothing
            has been stored yet.
        """
        B = query_feat.size(0)
        n_filled = self.capacity if self.full else self.ptr
        K = self.keys[:n_filled]                               # (N, d_mem)
        V = self.vals[:n_filled]                               # (N, d_mem)
        if K.numel() == 0:
            return (torch.zeros(B, self.vals.size(-1), device=query_feat.device),
                    torch.zeros(B, self.k, device=query_feat.device))
        Kq = F.normalize(self.key_proj(query_feat), dim=-1)    # (B, d_mem)
        sims = Kq @ K.T                                        # (B, N)
        scores, idx = sims.topk(min(self.k, n_filled), dim=-1)  # (B, k')
        attn = F.softmax(scores, dim=-1)                       # (B, k')
        ctx = torch.einsum('bk,bkd->bd', attn, V[idx])         # (B, d_mem)
        return ctx, attn

class Reasoner(nn.Module):
    """Map encoded state features to a reasoning vector.

    The mapping is a small MLP: one ReLU hidden layer of width ``d_model``
    followed by a linear output of width ``d_reason`` (exposed as ``self.net``).
    """

    def __init__(self, d_model: int, d_reason: int):
        super().__init__()
        hidden_sizes = [d_model]
        self.net = mlp(d_model, hidden_sizes, d_reason, act=nn.ReLU)

    def forward(self, state_feat: torch.Tensor, extra: Optional[Dict]=None) -> torch.Tensor:
        """Return ``self.net(state_feat)``; ``extra`` is accepted but not used."""
        reasoning = self.net(state_feat)
        return reasoning

# ---------- fusion ----------

class MemoryConditioner(nn.Module):
    """Fuse state features with memory and reasoning contexts into a ``d_model`` vector.

    Modes (``fuse``):
      * ``"concat"`` -- ``tanh(Linear([state, mem, reason]))``.
      * ``"film"``   -- FiLM: ``(1 + gamma(c)) * state + beta(c)`` with ``c = [mem, reason]``.
      * ``"gated"``  -- ``g * state + (1 - g) * tanh(Linear([state, mem, reason]))``
        with ``g = sigmoid(MLP([state, mem, reason]))``.

    ``d_mem`` / ``d_reason`` may be 0, in which case the matching context is
    expected to be a zero-width tensor. Dropout is applied to the fused output.

    Raises:
        ValueError: if ``fuse`` is not one of the modes above.
    """

    def __init__(self, d_model: int, d_mem: int, d_reason: int, fuse: str = "film", dropout: float = 0.0):
        super().__init__()
        self.fuse = fuse
        self.dropout = nn.Dropout(dropout)
        d_in = d_model + d_mem + d_reason       # width of [state, mem, reason]
        d_hidden = max(64, d_model // 2)
        if fuse == "concat":
            self.proj = nn.Linear(d_in, d_model)
        elif fuse == "film":
            self.gamma = mlp(d_mem + d_reason, [d_hidden], d_model, act=nn.ReLU)
            self.beta = mlp(d_mem + d_reason, [d_hidden], d_model, act=nn.ReLU)
        elif fuse == "gated":
            self.gate = mlp(d_in, [d_hidden], d_model, act=nn.ReLU)
            # Candidate built from the contexts. Without it the gate could only
            # blend state_feat with itself, and mem/reason would never reach
            # the output (nor would the gate receive any gradient).
            self.cand = nn.Linear(d_in, d_model)
        else:
            raise ValueError(f"unknown fuse: {fuse}")

    def forward(self, state_feat: torch.Tensor, mem_ctx: torch.Tensor, reason_ctx: torch.Tensor) -> torch.Tensor:
        """Fuse ``state_feat`` with ``mem_ctx`` and ``reason_ctx`` according to ``self.fuse``."""
        if self.fuse == "film":
            cond = torch.cat([mem_ctx, reason_ctx], dim=-1)
            gamma = self.gamma(cond)
            beta = self.beta(cond)
            # Residual scaling (1 + gamma) so a zero gamma leaves state_feat unchanged.
            return self.dropout((1 + gamma) * state_feat + beta)
        x = torch.cat([state_feat, mem_ctx, reason_ctx], dim=-1)
        if self.fuse == "concat":
            return self.dropout(torch.tanh(self.proj(x)))
        # gated: convex blend of the raw state and a context-driven candidate.
        gate = torch.sigmoid(self.gate(x))
        cand = torch.tanh(self.cand(x))
        return self.dropout(gate * state_feat + (1 - gate) * cand)

# ---------- actor-critic ----------

class ActorCritic(nn.Module):
    """Shared-encoder actor-critic whose features are conditioned on memory and reasoning.

    Pipeline: ``obs -> encoder -> feat``; optionally ``feat -> memory.retrieve``
    and ``feat -> reasoner``; the three are fused by ``MemoryConditioner`` and
    fed to separate policy (``n_act`` logits) and value (1 output) heads.

    When ``use_memory`` / ``use_reasoner`` is False, the corresponding context
    is a zero-width ``(B, 0)`` tensor, matching the 0-sized context the
    conditioner is built with. The disabled submodules are still constructed.
    """

    def __init__(self, obs_dim: int, n_act: int, d_model: int = 128, d_mem: int = 128, d_reason: int = 64, fuse: str = "film",
                 use_memory: bool = True, use_reasoner: bool = True):
        super().__init__()
        self.encoder = mlp(obs_dim, [256, 256], d_model, act=nn.ReLU)
        self.use_memory = use_memory
        self.use_reasoner = use_reasoner
        self.memory = MemoryModule(d_model, d_mem)
        self.reasoner = Reasoner(d_model, d_reason)
        # Disabled contexts contribute 0 input columns to the fusion layer.
        self.conditioner = MemoryConditioner(d_model, d_mem if use_memory else 0, d_reason if use_reasoner else 0, fuse=fuse)
        self.policy = mlp(d_model, [128], n_act, act=nn.ReLU)
        self.value = mlp(d_model, [128], 1, act=nn.ReLU)

    def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
        """Run the network on a batch of observations.

        Returns:
            ``(logits, value, aux)`` where ``aux`` holds ``state_feat`` (encoder
            output), ``mem_ctx`` and ``reason_ctx`` (zero-width when disabled).
        """
        feat = self.encoder(obs)                                  # (B, d_model)
        B = feat.size(0)
        # Zero-width placeholders must match the widths the conditioner was
        # built with; full-width zeros would make the concatenation too wide.
        if self.use_memory:
            mem_ctx, _ = self.memory.retrieve(feat)
        else:
            mem_ctx = feat.new_zeros(B, 0)
        if self.use_reasoner:
            reason_ctx = self.reasoner(feat)
        else:
            reason_ctx = feat.new_zeros(B, 0)
        cond = self.conditioner(feat, mem_ctx, reason_ctx)
        logits = self.policy(cond)
        value = self.value(cond)
        aux = {"state_feat": feat, "mem_ctx": mem_ctx, "reason_ctx": reason_ctx}
        return logits, value, aux

    @torch.no_grad()
    def write_memory(self, feat: torch.Tensor):
        """Store ``feat`` in the memory buffer (no-op when memory is disabled)."""
        if self.use_memory:
            self.memory.write(feat)

# ---------- action selection ----------

def select_action(net: ActorCritic, obs: torch.Tensor, epsilon: float = 0.0):
    """Pick one action for a single (unbatched) observation.

    With probability ``epsilon`` the action is drawn uniformly from all logit
    indices; otherwise it is sampled from the softmax of the policy logits.
    The state features are written to the net's memory only after acting.

    Returns:
        ``(action_index, value, aux)``: ``action_index`` is a Python int,
        ``value`` is the critic output with the batch dim removed, and ``aux``
        is the dict returned by ``net``.
    """
    batched_obs = obs.unsqueeze(0)
    logits, value, aux = net(batched_obs)
    n_actions = logits.size(-1)
    explore = torch.rand(()) < epsilon
    if not explore:
        policy = torch.distributions.Categorical(probs=F.softmax(logits, dim=-1))
        chosen = policy.sample()
    else:
        chosen = torch.randint(0, n_actions, (1,), device=logits.device)
    # Written after the forward pass, so this step's retrieval never sees itself.
    net.write_memory(aux["state_feat"])
    return chosen.item(), value.squeeze(0), aux

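# training_loop.py (snippet)
# assumes: `env` (Gym-style: reset() -> obs, step(a) -> (obs, rew, done, info)) and
# `advantage_estimator(rew, val, done, gamma, lam) -> (returns, adv)` are defined elsewhere;
# `update` only builds the loss, so the caller's optimizer performs backward()/step().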

def rollout(env, net: ActorCritic, steps: int, gamma=0.99, lam=0.95, epsilon=0.0):
    """Collect ``steps`` policy-sampled transitions and compute returns/advantages.

    Assumes a Gym-style env whose ``reset()`` returns the observation and whose
    ``step(a)`` returns ``(next_obs, rew, done, info)``. When an episode ends,
    the env is reset in place; the net's memory is NOT flushed between episodes.

    Args:
        env: environment with the API above.
        net: actor-critic; its memory is written after every step.
        steps: number of transitions to collect. Must be > 0, otherwise
            ``torch.stack`` fails on the empty lists.
        gamma, lam: forwarded unchanged to ``advantage_estimator``
            (presumably discount and GAE lambda).
        epsilon: accepted for signature symmetry with ``select_action`` but
            unused -- actions are always sampled from the policy, so the stored
            log-probs match the behaviour policy.

    Returns:
        ``(traj, returns, adv)``: ``traj`` maps each key to a tensor stacked
        over time; ``returns`` / ``adv`` come from ``advantage_estimator``.

    NOTE(review): the value of the final ``obs`` is not passed to
    ``advantage_estimator``, so a truncated last segment cannot be bootstrapped
    unless the estimator handles that itself -- confirm.
    """
    obs = torch.tensor(env.reset(), dtype=torch.float32)
    keys = ["obs", "act", "logp", "val", "rew", "done", "mem_ctx", "reason_ctx"]
    traj = {k: [] for k in keys}
    for _ in range(steps):
        # Acting must not build an autograd graph: logp/val are the "old" policy
        # quantities PPO treats as constants. Keeping their graphs leaks memory
        # across steps and lets gradients flow through traj["logp"] in the ratio.
        with torch.no_grad():
            logits, value, aux = net(obs.unsqueeze(0))
            dist = torch.distributions.Categorical(probs=F.softmax(logits, dim=-1))
            act = dist.sample()
            logp = dist.log_prob(act)
        next_obs, rew, done, _ = env.step(act.item())
        net.write_memory(aux["state_feat"])
        traj["obs"].append(obs)
        traj["act"].append(act.squeeze(0))
        traj["logp"].append(logp.squeeze(0))
        traj["val"].append(value.squeeze(0))
        traj["rew"].append(torch.tensor(rew, dtype=torch.float32))
        traj["done"].append(torch.tensor(done, dtype=torch.float32))
        traj["mem_ctx"].append(aux["mem_ctx"].squeeze(0))
        traj["reason_ctx"].append(aux["reason_ctx"].squeeze(0))
        obs = torch.tensor(next_obs, dtype=torch.float32)
        if done:
            obs = torch.tensor(env.reset(), dtype=torch.float32)
            # optional: flush memory between episodes if desired
            # net.memory.ptr = 0; net.memory.full = False
    traj = {k: torch.stack(v) for k, v in traj.items()}
    returns, adv = advantage_estimator(traj["rew"], traj["val"], traj["done"], gamma, lam)
    return traj, returns, adv

def update(traj, returns, adv, net: ActorCritic, clip_ratio=0.2, vf_coef=0.5, ent_coef=0.01):
    """Compute the PPO clipped-surrogate loss for one batch (no optimizer step).

    Args:
        traj: dict with at least ``"obs"``, ``"act"`` and ``"logp"`` (log-probs
            of the taken actions under the acting policy), stacked over time.
        returns: value-regression targets, one per transition (compared with
            ``values.squeeze(-1)``).
        adv: advantages, one per transition.
        net: actor-critic being optimised.
        clip_ratio: PPO clipping range for the probability ratio.
        vf_coef: weight of the value loss.
        ent_coef: weight of the entropy bonus (subtracted from the loss).

    Returns:
        ``(loss, stats)``: ``loss`` is a scalar tensor to backpropagate (the
        caller owns zero_grad/backward/step); ``stats`` holds the float values
        of the policy loss, value loss and entropy.

    NOTE(review): ``net(traj["obs"])`` retrieves from the memory as it is now,
    not as it was when each action was taken, so recomputed log-probs can
    differ from ``traj["logp"]`` even before any parameter update. ``traj``
    also stores per-step ``mem_ctx``, which is not used here.
    """
    # Old log-probs, advantages and targets are constants in the PPO objective;
    # detach so no gradient can leak into whatever graph produced them.
    old_logp = traj["logp"].detach()
    adv = adv.detach()
    returns = returns.detach()
    logits, values, _ = net(traj["obs"])
    dist = torch.distributions.Categorical(logits=logits)
    logp = dist.log_prob(traj["act"])
    ratio = torch.exp(logp - old_logp)
    surr_unclipped = ratio * adv
    surr_clipped = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * adv
    policy_loss = -torch.min(surr_unclipped, surr_clipped).mean()
    value_loss = F.mse_loss(values.squeeze(-1), returns)
    entropy = dist.entropy().mean()
    loss = policy_loss + vf_coef * value_loss - ent_coef * entropy
    return loss, {"policy_loss": policy_loss.item(), "value_loss": value_loss.item(), "entropy": entropy.item()}