In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dm_control import suite
from collections import deque
import random
import time
import math

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ---------------------- Utilities ----------------------
class RunningMeanStd:
    """Online running mean/std for observation normalization."""
    def __init__(self, shape, eps=1e-4, clip=10.0):
        self.mean = np.zeros(shape, dtype=np.float64)
        self.var = np.ones(shape, dtype=np.float64)
        self.count = eps
        self.clip = clip

    def update(self, x: np.ndarray):
        # x is (N, dim) or (dim,)
        x = np.asarray(x)
        if x.ndim == 1:
            x = x[None, :]
        batch_mean = np.mean(x, axis=0)
        batch_var = np.var(x, axis=0)
        batch_count = x.shape[0]
        self._update_from_moments(batch_mean, batch_var, batch_count)

    def _update_from_moments(self, batch_mean, batch_var, batch_count):
        # Welford-like update
        delta = batch_mean - self.mean
        tot_count = self.count + batch_count

        new_mean = self.mean + delta * batch_count / tot_count
        m_a = self.var * (self.count)
        m_b = batch_var * (batch_count)
        M2 = m_a + m_b + (delta ** 2) * self.count * batch_count / tot_count
        new_var = M2 / (tot_count)

        self.mean = new_mean
        self.var = new_var
        self.count = tot_count

    def normalize(self, x: np.ndarray):
        return np.clip((x - self.mean) / (np.sqrt(self.var) + 1e-8), -self.clip, self.clip)

def flatten_obs(obs):
    """Flatten observation dict to single array."""
    if isinstance(obs, dict):
        parts = [np.asarray(v).ravel() for v in obs.values()]
        return np.concatenate(parts).astype(np.float32)
    else:
        return np.asarray(obs).astype(np.float32).ravel()

# ---------------------- Replay Buffer ----------------------
class ReplayBuffer:
    """Efficient replay buffer with numpy storage."""
    def __init__(self, size=200000):
        self.max_size = size
        self.ptr = 0
        self.size = 0
        
        # Pre-allocate arrays (filled on first push)
        self.states = None
        self.actions = None
        self.rewards = None
        self.next_states = None
        self.dones = None

    def push(self, s, a, r, s2, d):
        """Store transition."""
        # Initialize arrays on first push
        if self.states is None:
            self.states = np.zeros((self.max_size, len(s)), dtype=np.float32)
            self.actions = np.zeros((self.max_size, len(a)), dtype=np.float32)
            self.rewards = np.zeros(self.max_size, dtype=np.float32)
            self.next_states = np.zeros((self.max_size, len(s2)), dtype=np.float32)
            self.dones = np.zeros(self.max_size, dtype=np.float32)
        
        self.states[self.ptr] = s
        self.actions[self.ptr] = a
        self.rewards[self.ptr] = r
        self.next_states[self.ptr] = s2
        self.dones[self.ptr] = d
        
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        """Sample batch and convert to GPU tensors once."""
        idx = np.random.randint(0, self.size, size=batch_size)
        
        return (
            torch.from_numpy(self.states[idx]).to(device),
            torch.from_numpy(self.actions[idx]).to(device),
            torch.from_numpy(self.rewards[idx]).to(device).unsqueeze(-1),
            torch.from_numpy(self.next_states[idx]).to(device),
            torch.from_numpy(self.dones[idx]).to(device).unsqueeze(-1)
        )

    def __len__(self):
        return self.size

# ---------------------- Networks ----------------------
def orthogonal_init(m):
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
        nn.init.constant_(m.bias, 0.0)

class Actor(nn.Module):
    """Gaussian policy network."""
    def __init__(self, obs_dim, act_dim, hidden_dim=256, log_std_min=-5, log_std_max=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim), 
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), 
            nn.ReLU(),
        )
        self.mu = nn.Linear(hidden_dim, act_dim)
        self.log_std = nn.Linear(hidden_dim, act_dim)
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max
        self.apply(orthogonal_init)

    def forward(self, obs):
        h = self.net(obs)
        mu = self.mu(h)
        log_std = torch.tanh(self.log_std(h))  # bound to (-1,1)
        # scale to range
        log_std = self.log_std_min + 0.5 * (log_std + 1.0) * (self.log_std_max - self.log_std_min)
        std = torch.exp(log_std)
        return mu, std, log_std

    def sample(self, obs):
        mu, std, log_std = self.forward(obs)
        eps = torch.randn_like(mu)
        a = mu + eps * std
        var = std ** 2
        log_prob = -0.5 * (((a - mu) ** 2) / var + 2 * log_std + math.log(2 * math.pi))
        log_prob = log_prob.sum(-1, keepdim=True)
        return a, log_prob

    def log_prob(self, obs, a):
        mu, std, log_std = self.forward(obs)
        var = std ** 2
        log_prob = -0.5 * (((a - mu) ** 2) / var + 2 * log_std + math.log(2 * math.pi))
        return log_prob.sum(-1, keepdim=True)

    def kl(self, obs, other_mu, other_std):
        mu, std, _ = self.forward(obs)
        var = std ** 2
        other_var = other_std ** 2
        kl_div = (torch.log(std / other_std) + (other_var + (other_mu - mu) ** 2) / (2 * var) - 0.5)
        return kl_div.sum(-1, keepdim=True)

class Critic(nn.Module):
    """Q-function network."""
    def __init__(self, obs_dim, act_dim, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim + act_dim, hidden_dim), 
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), 
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        self.apply(orthogonal_init)
    
    def forward(self, s, a):
        return self.net(torch.cat([s, a], dim=-1))

# ---------------------- PMPO Trainer (stabilized) ----------------------
class PMPO:
    """Preference Matching Policy Optimization with twin critics and stability fixes."""
    def __init__(self, obs_dim, act_dim, 
                 actor_lr=1e-5, critic_lr=1e-4,
                 gamma=0.99, tau=0.995,
                 alpha=0.5, beta=1.5,
                 grad_clip=1.0,
                 target_update_freq=250,
                 critic_warmup_steps=5000,
                 actor_update_every=2,
                 q_clip=1e6):
        
        # Hyperparameters
        self.gamma = gamma
        self.tau = tau  # used only for soft updates if desired
        self.alpha = alpha
        self.beta = beta
        self.grad_clip = grad_clip
        self.target_update_freq = target_update_freq
        self.critic_warmup_steps = critic_warmup_steps
        self.actor_update_every = actor_update_every
        self.q_clip = q_clip

        # Networks: actor + reference actor
        self.actor = Actor(obs_dim, act_dim).to(device)
        self.ref_actor = Actor(obs_dim, act_dim).to(device)
        self.ref_actor.load_state_dict(self.actor.state_dict())
        self.ref_actor.eval()

        # Twin critics (Q1, Q2) and twin targets
        self.critic1 = Critic(obs_dim, act_dim).to(device)
        self.critic2 = Critic(obs_dim, act_dim).to(device)
        self.target_critic1 = Critic(obs_dim, act_dim).to(device)
        self.target_critic2 = Critic(obs_dim, act_dim).to(device)
        self.target_critic1.load_state_dict(self.critic1.state_dict())
        self.target_critic2.load_state_dict(self.critic2.state_dict())

        # Optimizers
        self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic1_opt = torch.optim.Adam(self.critic1.parameters(), lr=critic_lr)
        self.critic2_opt = torch.optim.Adam(self.critic2.parameters(), lr=critic_lr)

        # Counters
        self._total_steps = 0
        self._critic_updates = 0

    def update_critic(self, batch):
        """Update twin critics with min-target to avoid overestimation."""
        s, a, r, s2, d = batch

        # clamp rewards moderately (optional) to keep scale manageable
        r = torch.clamp(r, -1e3, 1e3)

        with torch.no_grad():
            # next actions from current actor
            a2, _ = self.actor.sample(s2)
            q1_next = self.target_critic1(s2, a2)
            q2_next = self.target_critic2(s2, a2)
            q_next = torch.min(q1_next, q2_next)

            y = r + self.gamma * (1.0 - d) * q_next
            y = torch.clamp(y, -self.q_clip, self.q_clip)

        # Critic 1
        q1 = self.critic1(s, a)
        loss1 = F.mse_loss(q1, y)
        self.critic1_opt.zero_grad()
        loss1.backward()
        torch.nn.utils.clip_grad_norm_(self.critic1.parameters(), self.grad_clip)
        self.critic1_opt.step()

        # Critic 2
        q2 = self.critic2(s, a)
        loss2 = F.mse_loss(q2, y)
        self.critic2_opt.zero_grad()
        loss2.backward()
        torch.nn.utils.clip_grad_norm_(self.critic2.parameters(), self.grad_clip)
        self.critic2_opt.step()

        # Hard target sync periodically (more stable than immediate polyak if tuned)
        self._critic_updates += 1
        if (self._critic_updates % self.target_update_freq) == 0:
            self.target_critic1.load_state_dict(self.critic1.state_dict())
            self.target_critic2.load_state_dict(self.critic2.state_dict())

        return (loss1.item(), loss2.item())

    def update_actor(self, s, K=16):
        """Update policy with preference-based objective (PMPO)."""
        B = s.size(0)

        # Sample K actions per state from reference policy (stay on device)
        s_rep = s[:, None, :].expand(B, K, -1).reshape(B * K, -1)
        
        with torch.no_grad():
            mu_ref, std_ref, _ = self.ref_actor.forward(s_rep)
            eps = torch.randn_like(mu_ref)
            acts = mu_ref + eps * std_ref
            qs = self.critic1(s_rep, acts).reshape(B, K)  # could use avg of critics, using critic1 for ranking is fine

        # Rank actions
        idx = torch.argsort(qs, dim=1)  # ascending
        top2_idx = idx[:, -2:]  # highest Q-values
        bot2_idx = idx[:, :2]   # lowest Q-values

        # Advanced indexing to gather actions/states (no Python loops)
        batch_idx = torch.arange(B, device=device)[:, None].expand(B, 2)
        top_flat = (batch_idx * K + top2_idx).reshape(-1)
        bot_flat = (batch_idx * K + bot2_idx).reshape(-1)

        pref_s = s.repeat_interleave(2, dim=0)
        pref_a = acts[top_flat]
        rej_s = s.repeat_interleave(2, dim=0)
        rej_a = acts[bot_flat]

        # Compute PMPO terms
        logp_pref = self.actor.log_prob(pref_s, pref_a).mean()
        logp_rej = self.actor.log_prob(rej_s, rej_a).mean()

        with torch.no_grad():
            mu_r, std_r, _ = self.ref_actor.forward(s)
        kl = self.actor.kl(s, mu_r, std_r).mean()

        # PMPO objective: maximize J -> minimize -J
        J = self.alpha * logp_pref - (1.0 - self.alpha) * logp_rej - self.beta * kl
        loss = -J

        self.actor_opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.grad_clip)
        self.actor_opt.step()

        return loss.item(), kl.item()

    def update_reference(self):
        """Update reference policy to current policy."""
        self.ref_actor.load_state_dict(self.actor.state_dict())

    @torch.no_grad()
    def act(self, obs, deterministic=False):
        """Select action (for evaluation)."""
        obs_t = torch.from_numpy(obs).float().to(device).unsqueeze(0)
        if deterministic:
            mu, _, _ = self.actor.forward(obs_t)
            return mu.cpu().numpy()[0]
        else:
            a, _ = self.actor.sample(obs_t)
            return a.cpu().numpy()[0]

# ---------------------- Training Loop ----------------------
def train(domain="cheetah", task="run", 
          total_steps=500000,
          batch_size=256,
          start_steps=1000,
          ref_update_freq=1000,
          eval_freq=5000,
          eval_episodes=5,
          seed=0):
    """
    Main training loop with stabilization changes.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    env = suite.load(domain, task)
    ts = env.reset()
    s = flatten_obs(ts.observation)

    # obs normalization
    obs_rms = RunningMeanStd(len(s))

    obs_dim = s.size
    act_dim = env.action_spec().shape[0]
    
    print(f"Environment: {domain}-{task}")
    print(f"Obs dim: {obs_dim}, Act dim: {act_dim}")
    
    # Initialize
    rb = ReplayBuffer(size=200000)
    algo = PMPO(obs_dim, act_dim,
                actor_lr=1e-5,    # smaller actor lr
                critic_lr=1e-4,   # smaller critic lr
                gamma=0.99,
                tau=0.995,
                alpha=0.5,
                beta=1.5,
                grad_clip=1.0,
                target_update_freq=250,
                critic_warmup_steps=5000,
                actor_update_every=2,
                q_clip=1e5)

    # Action bounds
    act_low = env.action_spec().minimum
    act_high = env.action_spec().maximum
    
    # Metrics
    episode_reward = 0.0
    episode_step = 0
    episode_count = 0
    start_time = time.time()
    last_eval = 0

    for t in range(1, total_steps + 1):
        # Select action
        if len(rb) < start_steps:
            a = np.random.uniform(act_low, act_high, size=act_dim).astype(np.float32)
        else:
            # normalize obs
            obs_norm = obs_rms.normalize(s)
            a = algo.act(obs_norm, deterministic=False)
            a = np.clip(a, act_low, act_high)

        # Environment step
        ts = env.step(a)
        s2 = flatten_obs(ts.observation)
        k = ts.reward
        if k is None:
            k = 0.0
        r = float(k)
        d = float(ts.last())
        
        episode_reward += r
        episode_step += 1
        
        # Store transition (store raw obs, normalize when sampling/training)
        rb.push(s, a, r, s2, d)

        # Update running obs stats (do it on CPU numpy)
        obs_rms.update(s2)

        # Reset handling
        if d:
            episode_count += 1
            if episode_count % 10 == 0:
                elapsed = time.time() - start_time
                print(f"Step {t:6d} | Episode {episode_count:4d} | "
                      f"Reward: {episode_reward:8.2f} | Length: {episode_step:4d} | Time: {elapsed:.1f}s")
            
            ts = env.reset()
            s = flatten_obs(ts.observation)
            episode_reward = 0.0
            episode_step = 0
        else:
            s = s2

        # Training updates
        if len(rb) >= start_steps:
            batch = rb.sample(batch_size)

            # BEFORE feeding to networks, normalize states and next_states in-tensor
            bs, ba, br, bs2, bd = batch
            # Convert to numpy to normalize with obs_rms then back (we avoid CPU-GPU ping-pong by normalizing on CPU once per batch)
            bs_np = bs.cpu().numpy()
            bs2_np = bs2.cpu().numpy()
            bs_norm = torch.from_numpy(np.clip((bs_np - obs_rms.mean) / (np.sqrt(obs_rms.var) + 1e-8), -obs_rms.clip, obs_rms.clip)).float().to(device)
            bs2_norm = torch.from_numpy(np.clip((bs2_np - obs_rms.mean) / (np.sqrt(obs_rms.var) + 1e-8), -obs_rms.clip, obs_rms.clip)).float().to(device)

            # Replace states in batch with normalized tensors for training
            train_batch = (bs_norm, ba, br, bs2_norm, bd)

            # Update critics (twin)
            loss1, loss2 = algo.update_critic(train_batch)
            critic_loss = (loss1 + loss2) / 2.0

            # Actor update schedule & warmup
            if (t > algo.critic_warmup_steps) and (t % algo.actor_update_every == 0):
                # pass normalized states
                actor_loss, kl = algo.update_actor(bs_norm, K=16)
            else:
                actor_loss, kl = float('nan'), float('nan')

            # Periodic status print
            if t % 1000 == 0:
                elapsed = time.time() - start_time
                print(f"[{t:6d}] Critic Loss1: {loss1:.4e} | Critic Loss2: {loss2:.4e} | "
                      f"Actor Loss: {actor_loss:.4f} | KL: {kl:.4f} | Time: {elapsed:.1f}s")

        # Update reference policy
        if t % ref_update_freq == 0 and t > 0:
            algo.update_reference()
            print(f"[{t:6d}] Updated reference policy")

        # Evaluation
        if t % eval_freq == 0 and t > 0 and (t - last_eval) >= eval_freq:
            eval_reward = evaluate(env, algo, eval_episodes, act_low, act_high, obs_rms)
            print(f"[{t:6d}] ===== EVAL: {eval_reward:.2f} =====")
            last_eval = t

    print("\nTraining finished!")
    return algo

def evaluate(env, algo, num_episodes, act_low, act_high, obs_rms):
    """Evaluate policy deterministically (uses obs normalization)."""
    total_reward = 0.0
    
    for _ in range(num_episodes):
        ts = env.reset()
        s = flatten_obs(ts.observation)
        ep_reward = 0.0
        
        while not ts.last():
            s_norm = obs_rms.normalize(s)
            a = algo.act(s_norm, deterministic=True)
            a = np.clip(a, act_low, act_high)
            ts = env.step(a)
            s = flatten_obs(ts.observation)
            r = ts.reward
            if r is None:
                r = 0.0
            ep_reward += r
        
        total_reward += ep_reward
    
    return total_reward / num_episodes

if __name__ == "__main__":
    trained_agent = train(
        domain="cartpole",
        task="swingup",
        total_steps=500000,
        batch_size=256,
        start_steps=1000,
        ref_update_freq=1000,
        eval_freq=5000,
        eval_episodes=5,
        seed=0
    )


/home/naman/.venv/lib/python3.10/site-packages/glfw/__init__.py:917: GLFWError: (65550) b'X11: The DISPLAY environment variable is missing'


Using device: cuda
Environment: cartpole-swingup
Obs dim: 5, Act dim: 1
[  1000] Critic Loss1: 9.4578e-03 | Critic Loss2: 1.5406e-01 | Actor Loss: nan | KL: nan | Time: 4.5s
[  1000] Updated reference policy
[  2000] Critic Loss1: 4.0722e-02 | Critic Loss2: 4.0519e-02 | Actor Loss: nan | KL: nan | Time: 43.4s
[  2000] Updated reference policy
[  3000] Critic Loss1: 2.3887e-02 | Critic Loss2: 2.3661e-02 | Actor Loss: nan | KL: nan | Time: 81.9s
[  3000] Updated reference policy
[  4000] Critic Loss1: 1.8952e-02 | Critic Loss2: 1.8933e-02 | Actor Loss: nan | KL: nan | Time: 120.4s
[  4000] Updated reference policy
[  5000] Critic Loss1: 1.2607e-02 | Critic Loss2: 1.2739e-02 | Actor Loss: nan | KL: nan | Time: 158.9s
[  5000] Updated reference policy
[  5000] ===== EVAL: 12.25 =====
[  6000] Critic Loss1: 1.0848e-02 | Critic Loss2: 1.1305e-02 | Actor Loss: -0.0910 | KL: 0.9486 | Time: 237.8s
[  6000] Updated reference policy
[  7000] Critic Loss1: 1.4104e-02 | Critic Loss2: 1.4927e-02 | A