<a href="https://colab.research.google.com/github/NNehmer/stc-alberta/blob/main/STC_Alberta_Agent_V1_0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%pip -q install --upgrade torch

In [None]:
# Report the installed PyTorch build and the interpreter version.
import sys

import torch

python_version = sys.version.split()[0]
print("Torch:", torch.__version__, "Python:", python_version)

Torch: 2.8.0+cu126 Python: 3.12.11


In [None]:
"""
STC–Alberta Agent — Colab-sichere Version (nur Definitionen)
- konsistente Einrückung (4 Spaces)
- keine Tabs
- kein automatischer Startblock
- imagine() akzeptiert Index- oder One-Hot-Aktionen
- Entropie-Bonus, EMA-Target, korrekter Kommutator, saubere Logs
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import deque
import random
from typing import Dict, Tuple, List

# ------------------------------------------------------------
# 1) STC Bausteine
# ------------------------------------------------------------
class SpectralProjector(nn.Module):
    """Intentional subspace projector Π_S with a QR-orthonormalised basis.

    Holds a learnable basis of shape [latent_dim, intent_rank]. Every forward
    pass orthonormalises its columns with a (reduced) QR decomposition and
    builds the orthogonal projector Π_S = Q Qᵀ onto their span.
    """

    def __init__(self, latent_dim: int, intent_rank: int):
        super().__init__()
        self.latent_dim, self.intent_rank = latent_dim, intent_rank
        initial_basis = torch.randn(latent_dim, intent_rank) * 0.1
        self.basis = nn.Parameter(initial_basis)

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Project ``z`` onto the intent subspace.

        Args:
            z: latents whose last dimension has size ``latent_dim``.

        Returns:
            ``(z_S, Pi_S)``: the projected latents (same shape as ``z``) and
            the [latent_dim, latent_dim] projector matrix.
        """
        orthonormal = torch.linalg.qr(self.basis).Q   # [D, r]
        projector = orthonormal @ orthonormal.T       # [D, D], symmetric
        # Π_S is symmetric, so right-multiplying row vectors equals applying Π_S.
        return z @ projector, projector

def coherence(z: torch.Tensor, z_S: torch.Tensor) -> torch.Tensor:
    """Coherence κ(ψ) = ||Π_S ψ||² / ||ψ||², computed over the last dimension.

    Args:
        z: full latent vectors ψ.
        z_S: their projections Π_S ψ, same shape as ``z``.

    Returns:
        Tensor with the last dimension kept (size 1). A small epsilon (1e-8)
        is added to ||ψ|| so an all-zero latent with an all-zero projection
        yields 0 rather than NaN.
    """
    numerator = z_S.norm(dim=-1, keepdim=True)
    denominator = z.norm(dim=-1, keepdim=True) + 1e-8
    return (numerator / denominator).pow(2)

class ValueOperator(nn.Module):
    """Symmetric value operator V with quadratic value v(ψ) = ⟨ψ|V|ψ⟩.

    Stores an unconstrained square parameter ``W`` and symmetrises it on
    every forward pass: V = 0.5 * (W + Wᵀ).
    """

    def __init__(self, latent_dim: int):
        super().__init__()
        self.W = nn.Parameter(torch.randn(latent_dim, latent_dim) * 0.01)

    def forward(self, z_S: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(v, V)``: per-row values of shape [B, 1] and the symmetric matrix.

        ``z_S`` must be 2-D ([B, latent_dim]); the einsum is written for that rank.
        """
        weight = self.W
        symmetric = 0.5 * (weight + weight.T)
        quadratic = torch.einsum('bi,ij,bj->b', z_S, symmetric, z_S)  # ⟨ψ|V|ψ⟩ per row
        return quadratic[:, None], symmetric

# ------------------------------------------------------------
# 2) Agent
# ------------------------------------------------------------
class STCAlbertaAgent(nn.Module):
    """Actor-critic agent with an intent subspace, a value operator and a world model.

    Sub-modules (created in this order, which fixes parameter initialisation
    and the ordering of ``parameters()``):
        encoder      obs -> latent z (MLP with LayerNorm)
        projector    z -> z_S = Π_S z plus the projector matrix
        value_op     quadratic value v = <z_S|V|z_S>
        value_target EMA copy of ``value_op``, blended towards it by
                     ``update_value_target``
        policy       z_S -> action logits
        transition   [z_S, one-hot action] -> predicted next latent
        reward_head  [z_S, one-hot action] -> predicted reward
    """

    def __init__(self, obs_dim: int, action_dim: int, latent_dim: int = 64, intent_rank: int = 16):
        super().__init__()
        self.obs_dim, self.action_dim = obs_dim, action_dim
        self.latent_dim, self.intent_rank = latent_dim, intent_rank

        hidden = 128
        self.encoder = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.LayerNorm(hidden),
            nn.ReLU(),
            nn.Linear(hidden, latent_dim),
        )
        self.projector = SpectralProjector(latent_dim, intent_rank)
        self.value_op = ValueOperator(latent_dim)
        # The target starts as an exact copy of the online operator.
        self.value_target = ValueOperator(latent_dim)
        self.value_target.load_state_dict(self.value_op.state_dict())
        self.value_tau = 0.005  # EMA blending factor used by update_value_target

        self.policy = nn.Sequential(
            nn.Linear(latent_dim, hidden), nn.ReLU(), nn.Linear(hidden, action_dim),
        )
        joint_dim = latent_dim + action_dim
        self.transition = nn.Sequential(
            nn.Linear(joint_dim, hidden), nn.ReLU(), nn.Linear(hidden, latent_dim),
        )
        self.reward_head = nn.Sequential(
            nn.Linear(joint_dim, 64), nn.ReLU(), nn.Linear(64, 1),
        )

    @torch.no_grad()
    def update_value_target(self):
        """Soft update: target <- (1 - tau) * target + tau * online, in place."""
        tau = self.value_tau
        pairs = zip(self.value_op.parameters(), self.value_target.parameters())
        for online, target in pairs:
            target.data.mul_(1 - tau).add_(tau * online.data)

    def forward(self, obs: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Encode ``obs`` and evaluate every head that acts on the latent.

        Returns a dict with keys ``latent``, ``latent_S``, ``projector``,
        ``value_matrix``, ``kappa``, ``value`` and ``logits``.
        """
        latent = self.encoder(obs)
        latent_S, projector = self.projector(latent)
        value, value_matrix = self.value_op(latent_S)
        return {
            'latent': latent,
            'latent_S': latent_S,
            'projector': projector,
            'value_matrix': value_matrix,
            'kappa': coherence(latent, latent_S),
            'value': value,
            'logits': self.policy(latent_S),
        }

    def imagine(self, z_S: torch.Tensor, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predict next latent and reward for ``z_S`` under ``actions``.

        ``actions`` is either [B] (indices) or [B, A] (one-hot).

        Raises:
            ValueError: for any other action shape.
        """
        is_index = actions.dim() == 1
        is_one_hot = actions.dim() == 2 and actions.size(-1) == self.action_dim
        if not (is_index or is_one_hot):
            raise ValueError("actions must be [B] (indices) or [B,A] (one-hot)")
        encoded = F.one_hot(actions.long(), num_classes=self.action_dim) if is_index else actions
        joint = torch.cat([z_S, encoded.float()], dim=-1)
        return self.transition(joint), self.reward_head(joint)

    def select_action(self, obs: torch.Tensor, epsilon: float = 0.0) -> int:
        """ε-greedy action for a single, unbatched observation.

        With probability ``epsilon`` a uniformly random action index is
        returned; otherwise the argmax of the policy logits.
        """
        with torch.no_grad():
            logits = self.forward(obs.unsqueeze(0))['logits']
        if random.random() < epsilon:
            return random.randint(0, self.action_dim - 1)
        return int(logits.argmax(dim=-1).item())

# ------------------------------------------------------------
# 3) Verluste
# ------------------------------------------------------------
def _value_without_param_grad(value_op: nn.Module, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Evaluate ``value_op`` on ``z`` with its parameters detached.

    Gradients still flow back into ``z`` (e.g. into the network that produced
    it), but the operator's own parameters receive no gradient.
    """
    detached = {name: p.detach() for name, p in value_op.named_parameters()}
    return torch.func.functional_call(value_op, detached, (z,))


def stc_loss(
    agent: "STCAlbertaAgent",
    obs: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor,
    next_obs: torch.Tensor, dones: torch.Tensor,
    gamma: float = 0.99, lambda_kappa: float = 0.1,
    lambda_comm: float = 0.01, lambda_cons: float = 0.5
) -> Dict[str, torch.Tensor]:
    """Combined STC–Alberta objective for one minibatch.

    total = L_RL + lambda_kappa * L_rel + lambda_comm * L_comm + lambda_cons * L_cons

    * L_RL: one-step actor-critic loss. The critic regresses onto the TD target
      r + gamma * (1 - done) * V_target(z_S'). The actor uses the detached TD
      error as its advantage, and an entropy bonus (weight 1e-3) is subtracted.
    * L_rel: -mean(kappa), rewarding latents that lie inside the intent subspace.
    * L_comm: squared Frobenius norm of the commutator [V, Π_S] / V.numel().
    * L_cons: MSE between the imagined Q-value
      r_hat + gamma * (1 - done) * V_target(z_hat') from ``agent.imagine`` and
      the TD target. This loss trains the reward head and the transition model.

    Args:
        agent: model exposing ``forward``, ``imagine`` and ``value_target``.
        obs, next_obs: observation batches (as produced by ``ReplayBuffer.sample``).
        actions: action indices, shape [B].
        rewards, dones: float tensors of shape [B]; ``dones`` is 1.0 where no
            bootstrapping should happen.
        gamma: discount factor.
        lambda_kappa, lambda_comm, lambda_cons: weights of the auxiliary losses.

    Returns:
        Dict with the total loss, its components and diagnostic scalars.
    """
    out = agent(obs)
    # The next state only feeds the (gradient-free) TD target, so skip its graph.
    with torch.no_grad():
        out_next = agent(next_obs)
        v_next_target, _ = agent.value_target(out_next['latent_S'])
        not_done = 1 - dones.unsqueeze(-1)
        td_target = rewards.unsqueeze(-1) + gamma * v_next_target * not_done

    z_S, P, V = out['latent_S'], out['projector'], out['value_matrix']
    kappa, value, logits = out['kappa'], out['value'], out['logits']

    dist = torch.distributions.Categorical(logits=logits)
    logp = dist.log_prob(actions)
    entropy = dist.entropy().mean()

    td_err = td_target - value
    policy_loss = -(logp * td_err.squeeze(-1).detach()).mean()
    value_loss = F.mse_loss(value, td_target)
    L_RL = policy_loss + value_loss - 1e-3 * entropy

    L_rel = -kappa.mean()
    comm = V @ P - P @ V
    L_comm = (torch.norm(comm, p='fro') ** 2) / V.numel()

    # World-model consistency. The target value net is evaluated with its
    # parameters detached rather than under no_grad. Otherwise the gradient
    # into z_S_next_pred is cut and the transition model is never trained.
    z_S_next_pred, reward_pred = agent.imagine(z_S, actions)
    v_next_pred, _ = _value_without_param_grad(agent.value_target, z_S_next_pred)
    q_pred = reward_pred + gamma * v_next_pred * not_done
    L_cons = F.mse_loss(q_pred, td_target)

    total = L_RL + lambda_kappa * L_rel + lambda_comm * L_comm + lambda_cons * L_cons
    return {
        'total_loss': total, 'L_RL': L_RL, 'L_rel': L_rel, 'L_comm': L_comm, 'L_cons': L_cons,
        'kappa': kappa.mean(), 'td_error': td_err.abs().mean(), 'policy_loss': policy_loss,
        'value_loss': value_loss, 'entropy': entropy
    }

# ------------------------------------------------------------
# 4) Einfache 2D-Nav-Umwelt (diskrete Aktionen)
# ------------------------------------------------------------
class SimpleControlEnv:
    """2D point-mass navigation with discrete accelerations.

    Observation: ``[x, y, goal_x, goal_y, vx, vy]``.
    Actions: 0 = up (+y), 1 = down (-y), 2 = left (-x), 3 = right (+x); each
    adds 0.1 to the velocity along one axis. The velocity is damped by 0.9 and
    clipped to [-0.5, 0.5], and the position is clipped to [-1, 1]. The reward
    is the negative Euclidean distance to the goal. An episode ends when that
    distance drops below 0.1 or after ``max_steps`` steps.

    ``step`` reports in its info dict why the episode ended: ``'terminated'``
    means the goal was reached, and ``'truncated'`` means only the step limit
    was hit. Learners should still bootstrap through a truncation.
    """

    # Velocity increment per action; identical on every step, so built once.
    _ACCELERATIONS = {
        0: np.array([0, 0.1]), 1: np.array([0, -0.1]),
        2: np.array([-0.1, 0]), 3: np.array([0.1, 0]),
    }

    def __init__(self):
        self.state_dim  = 6
        self.action_dim = 4
        self.max_steps  = 200
        self.reset()

    def reset(self) -> np.ndarray:
        """Draw a new random start position and goal in [-1, 1]², with zero velocity."""
        self.pos   = np.random.uniform(-1, 1, 2)
        self.vel   = np.zeros(2)
        self.goal  = np.random.uniform(-1, 1, 2)
        self.steps = 0
        return self._get_obs()

    def _get_obs(self) -> np.ndarray:
        return np.concatenate([self.pos, self.goal, self.vel])

    def step(self, action: int):
        """Advance the simulation by one step.

        Returns:
            ``(obs, reward, done, info)``. ``done`` is True when the goal is
            reached or the step limit is hit. ``info`` holds ``'distance'``,
            ``'terminated'`` (goal reached) and ``'truncated'`` (step limit hit
            without reaching the goal).

        Raises:
            KeyError: if ``action`` is not one of 0, 1, 2, 3.
        """
        acc = self._ACCELERATIONS[action]
        self.vel = np.clip(0.9 * self.vel + acc, -0.5, 0.5)
        self.pos = np.clip(self.pos + self.vel, -1,  1)
        dist = float(np.linalg.norm(self.pos - self.goal))
        reward = -dist
        self.steps += 1
        reached = dist < 0.1
        timed_out = self.steps >= self.max_steps
        done = reached or timed_out
        info = {
            'distance': dist,
            'terminated': reached,
            'truncated': timed_out and not reached,
        }
        return self._get_obs(), reward, done, info

# ------------------------------------------------------------
# 5) Replay Buffer
# ------------------------------------------------------------
class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform minibatch sampling.

    Once ``capacity`` transitions are held, each push evicts the oldest one.
    """

    def __init__(self, capacity: int = 10000):
        self.buffer = deque(maxlen=capacity)

    def __len__(self) -> int:
        return len(self.buffer)

    def push(self, obs, action, reward, next_obs, done):
        """Append one ``(obs, action, reward, next_obs, done)`` transition."""
        record = (obs, action, reward, next_obs, done)
        self.buffer.append(record)

    def sample(self, batch_size: int) -> Dict[str, torch.Tensor]:
        """Draw ``batch_size`` distinct stored transitions uniformly at random.

        Returns a dict of batched tensors: float32 for observations, rewards
        and done flags, int64 for actions.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored transitions.
        """
        drawn = random.sample(self.buffer, batch_size)
        obs_col, act_col, rew_col, next_col, done_col = zip(*drawn)
        as_float = torch.FloatTensor
        return {
            'obs':      as_float(np.array(obs_col)),
            'actions':  torch.LongTensor(act_col),
            'rewards':  as_float(rew_col),
            'next_obs': as_float(np.array(next_col)),
            'dones':    as_float(done_col),
        }

# ------------------------------------------------------------
# 6) Training (keine Ausführung hier)
# ------------------------------------------------------------
def train_stc_alberta(
    num_episodes: int = 60,
    batch_size: int = 64,
    buffer_size: int = 10000,
    learning_rate: float = 3e-4,
    epsilon_start: float = 1.0,
    epsilon_end: float = 0.05,
    epsilon_decay: float = 0.995,
    print_every: int = 10,
    updates_per_episode: int = 1,
):
    """Train an STC–Alberta agent on ``SimpleControlEnv`` with ε-greedy exploration.

    After every episode, once the buffer holds ``batch_size`` transitions, the
    function takes ``updates_per_episode`` gradient steps on ``stc_loss``
    (gradient norm clipped to 1.0). Each step is followed by a soft update of
    the EMA value target. ε decays multiplicatively per episode, down to
    ``epsilon_end``.

    Transitions are stored with the environment's ``'terminated'`` flag when
    the env reports it, so episodes cut off only by the step limit are still
    bootstrapped. Otherwise the ``done`` flag is stored.

    Args:
        num_episodes: number of training episodes.
        batch_size: minibatch size per gradient step.
        buffer_size: replay buffer capacity.
        learning_rate: Adam learning rate.
        epsilon_start, epsilon_end, epsilon_decay: ε-greedy schedule.
        print_every: episodes between progress lines (0 disables them).
        updates_per_episode: gradient steps taken after each episode.

    Returns:
        ``(agent, ep_rewards, ep_kappas, ep_dists)``: the trained agent plus
        per-episode total reward, mean κ and final goal distance.
    """
    env      = SimpleControlEnv()
    agent    = STCAlbertaAgent(obs_dim=env.state_dim, action_dim=env.action_dim,
                               latent_dim=64, intent_rank=16)
    optimizer= torch.optim.Adam(agent.parameters(), lr=learning_rate)
    buffer   = ReplayBuffer(capacity=buffer_size)

    epsilon  = epsilon_start
    ep_rewards, ep_kappas, ep_dists = [], [], []
    losses = None  # most recent stc_loss output, kept for logging

    # Header
    print("="*70)
    print("STC–Alberta Agent Training")
    print("="*70)
    print("Environment: SimpleControlEnv (2D)")
    print(f"State dim: {env.state_dim}, Action dim: {env.action_dim}")
    print(f"Agent latent dim: {agent.latent_dim}, Intent rank: {agent.intent_rank}")
    print("="*70)
    print()

    for episode in range(num_episodes):
        obs = env.reset()
        ep_reward = 0.0
        ep_k_list: List[float] = []
        done = False

        while not done:
            action = agent.select_action(torch.FloatTensor(obs), epsilon)
            next_obs, reward, done, info = env.step(action)
            # Only a true termination stops bootstrapping; a step-limit cutoff
            # must still bootstrap from V(next_obs).
            terminated = info.get('terminated', done)
            buffer.push(obs, action, reward, next_obs, float(terminated))

            # κ is logged for the observation the action was chosen from.
            with torch.no_grad():
                out = agent(torch.FloatTensor(obs).unsqueeze(0))
                ep_k_list.append(out['kappa'].item())

            obs = next_obs
            ep_reward += reward

        ep_rewards.append(ep_reward)
        ep_kappas.append(float(np.mean(ep_k_list)))
        ep_dists.append(info['distance'])

        if len(buffer) >= batch_size:
            for _ in range(updates_per_episode):
                batch = buffer.sample(batch_size)
                optimizer.zero_grad()
                losses = stc_loss(agent, **batch)
                losses['total_loss'].backward()
                torch.nn.utils.clip_grad_norm_(agent.parameters(), 1.0)
                optimizer.step()
                agent.update_value_target()

        epsilon = max(epsilon_end, epsilon * epsilon_decay)

        if print_every > 0 and (episode + 1) % print_every == 0:
            avgR = float(np.mean(ep_rewards[-print_every:]))
            avgK = float(np.mean(ep_kappas[-print_every:]))
            avgD = float(np.mean(ep_dists[-print_every:]))
            msg  = (f"Episode {episode+1:4d} | Reward: {avgR:7.2f} | κ: {avgK:.3f} "
                    f"| Dist: {avgD:.3f} | ε: {epsilon:.3f}")
            if losses is not None:
                msg += (f"\n             | L_RL: {losses['L_RL'].item():.4f} | L_rel: {losses['L_rel'].item():.4f} "
                        f"| L_comm: {losses['L_comm'].item():.6f} | L_cons: {losses['L_cons'].item():.4f} "
                        f"| H: {losses['entropy'].item():.3f}")
            print(msg)

    print()
    print("="*70)
    print("Training Complete")
    print("="*70)
    if ep_rewards:  # avoid NaN means (and warnings) when num_episodes == 0
        print(f"Final κ (last 30): {np.mean(ep_kappas[-30:]):.3f}")
        print(f"Final reward (last 30): {np.mean(ep_rewards[-30:]):.2f}")
        print(f"Final distance (last 30): {np.mean(ep_dists[-30:]):.3f}")

    return agent, ep_rewards, ep_kappas, ep_dists

# ------------------------------------------------------------
# 7) Ablation (A1: ohne Π_S)
# ------------------------------------------------------------
def run_ablation_no_projector(
    num_episodes: int = 60,
    batch_size: int = 64,
    buffer_size: int = 10000,
    learning_rate: float = 3e-4,
    epsilon_start: float = 1.0,
    epsilon_end: float = 0.05,
    epsilon_decay: float = 0.995,
) -> float:
    """Ablation A1: train an agent without the intent projector Π_S.

    The projector is replaced by the identity (z_S = z, Π_S = I, κ ≡ 1) and
    the κ regulariser is switched off. The exploration schedule (decaying ε),
    the gradient-norm clipping at 1.0 and the one-update-per-episode cadence
    use the same defaults as ``train_stc_alberta``. This way a reward
    difference reflects Π_S rather than a different exploration or
    optimisation setup. Transitions store the env's ``'terminated'`` flag
    when it is reported, and ``done`` otherwise.

    Returns:
        Mean episode reward over the last 30 episodes (NaN if none were run).
    """
    print("\n"+"="*70)
    print("ABLATION A1: No Projector (Full Latent Space)")
    print("="*70)

    class NoProjAgent(STCAlbertaAgent):
        """Agent whose projector is the identity: z_S = z, Π_S = I, κ = 1."""

        def forward(self, obs):
            z   = self.encoder(obs)
            z_S = z
            P   = torch.eye(self.latent_dim, device=z.device)
            k   = torch.ones(z.shape[0], 1, device=z.device)
            v, V = self.value_op(z_S)
            logits = self.policy(z_S)
            return {'latent': z, 'latent_S': z_S, 'projector': P,
                    'value_matrix': V, 'kappa': k, 'value': v, 'logits': logits}

    env = SimpleControlEnv()
    agent = NoProjAgent(obs_dim=env.state_dim, action_dim=env.action_dim,
                        latent_dim=64, intent_rank=16)
    optimizer = torch.optim.Adam(agent.parameters(), lr=learning_rate)
    buffer = ReplayBuffer(capacity=buffer_size)
    epsilon = epsilon_start
    ep_rewards: List[float] = []

    for _ in range(num_episodes):
        obs = env.reset()
        ep_reward = 0.0
        done = False
        while not done:
            action = agent.select_action(torch.FloatTensor(obs), epsilon)
            next_obs, reward, done, info = env.step(action)
            # Bootstrap through step-limit cutoffs when the env distinguishes them.
            terminated = info.get('terminated', done)
            buffer.push(obs, action, reward, next_obs, float(terminated))
            obs = next_obs
            ep_reward += reward

        ep_rewards.append(ep_reward)

        if len(buffer) >= batch_size:
            batch = buffer.sample(batch_size)
            optimizer.zero_grad()
            # κ is constant here, so its regulariser carries no signal.
            losses = stc_loss(agent, **batch, lambda_kappa=0.0)
            losses['total_loss'].backward()
            torch.nn.utils.clip_grad_norm_(agent.parameters(), 1.0)
            optimizer.step()
            agent.update_value_target()

        epsilon = max(epsilon_end, epsilon * epsilon_decay)

    final = float(np.mean(ep_rewards[-30:])) if ep_rewards else float('nan')
    print(f"Final reward (A1, no Π_S, last 30): {final:.2f}")
    return final


In [None]:
# Short run (Colab-friendly); increase num_episodes for real experiments.
import random

import numpy as np
import torch

SEED = 0


def _seed_everything(seed: int) -> None:
    """Seed Python, NumPy and PyTorch RNGs so a run is reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


# Both runs start from the same seed. This makes the comparison reproducible
# and presumably aligns the environment's start/goal draws (np.random) across
# the two runs.
# NOTE: a single seed is still anecdotal; average over several seeds before
# drawing conclusions.
_seed_everything(SEED)
agent, rewards, kappas, dists = train_stc_alberta(
    num_episodes=50,   # keep small for a quick start
    batch_size=64,
    learning_rate=3e-4,
    print_every=10
)

_seed_everything(SEED)
ablation_reward = run_ablation_no_projector(num_episodes=50)

stc_reward = float(np.mean(rewards[-30:]))
print("\n" + "="*70)
print("RESULTS SUMMARY")
print("="*70)
print(f"STC–Alberta (with Π_S): {stc_reward:.2f} (last 30 eps)")
print(f"Ablation A1 (no Π_S):   {ablation_reward:.2f} (last 30 eps)")
print(f"Improvement:            {stc_reward - ablation_reward:.2f}")
print("="*70)


STC–Alberta Agent Training
Environment: SimpleControlEnv (2D)
State dim: 6, Action dim: 4
Agent latent dim: 64, Intent rank: 16

Episode   10 | Reward: -262.87 | κ: 0.356 | Dist: 0.820 | ε: 0.951
             | L_RL: -0.0334 | L_rel: -0.4398 | L_comm: 0.000020 | L_cons: 2.0184 | H: 1.375
Episode   20 | Reward: -255.37 | κ: 0.550 | Dist: 1.188 | ε: 0.905
             | L_RL: 1.5352 | L_rel: -0.6049 | L_comm: 0.000022 | L_cons: 1.6860 | H: 1.357
Episode   30 | Reward: -286.02 | κ: 0.618 | Dist: 1.526 | ε: 0.860
             | L_RL: 0.0220 | L_rel: -0.5992 | L_comm: 0.000023 | L_cons: 1.7142 | H: 1.353
Episode   40 | Reward: -239.62 | κ: 0.653 | Dist: 1.109 | ε: 0.818
             | L_RL: 0.7786 | L_rel: -0.6725 | L_comm: 0.000023 | L_cons: 0.8779 | H: 1.347
Episode   50 | Reward: -237.62 | κ: 0.695 | Dist: 1.132 | ε: 0.778
             | L_RL: 0.0917 | L_rel: -0.7121 | L_comm: 0.000024 | L_cons: 0.7995 | H: 1.342

Training Complete
Final κ (last 30): 0.655
Final reward (last 30): -254.42