In [None]:
"""
Acknowledgements:
1. We referenced the skeleton code of Assignment 7, the original SAC paper (Haarnoja et al., 2018, https://arxiv.org/abs/1812.05905),
and the PyTorch implementation of Discor (https://github.com/toshikwa/discor.pytorch) to implement the SAC algorithm.
2. We referenced the original BRO paper (Nauman et al., 2024, https://arxiv.org/abs/2405.16158), its JAX implementation
(https://github.com/naumix/BiggerRegularizedOptimistic), and its PyTorch implementation (https://github.com/naumix/BiggerRegularizedOtimistic_Torch)
to build upon our SAC implementation and implement the BRO algorithm.
We also note that the authors' PyTorch implementation of BRO is not consistent with the original paper or its JAX implementation;
rather, it is an over-simplified version of the JAX implementation.
3. We did NOT directly use the aforementioned open-source implementations of SAC and BRO for our implementation,
although they provided us with insights and hints for debugging purposes.
We implemented the neural network classes and the update function ourselves.
"""

In [None]:
pip install --upgrade "gymnasium[mujoco]" torch torchvision torchaudio tqdm matplotlib "gymnasium[other]"

In [None]:
import os, math, random, numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.distributions import Normal
import gymnasium as gym
from tqdm.auto import trange
import matplotlib.pyplot as plt
from pathlib import Path

# -----------------  EXPERIMENT  SWITCH-BOARD  ------------------
LARGE_SCALE   = False   # BroNet (residual + LayerNorm) actor/critic instead of small MLPs
USE_CDQ       = True    # combine the two critics with min (clipped double-Q) instead of their mean
REPLAY_RATIO  = 1       # agent updates per environment step
DUAL_ACTORS   = False   # separate optimistic (training) and pessimistic (evaluation) actors
USE_QUANTILE  = False   # 100-quantile critics instead of a single Q-value
USE_WD        = False   # non-zero weight-decay coefficient for actor/critic optimizers
HARD_RESET    = False   # periodic reset triggered every RESET_EVERY agent updates
# ---------------------------------------------------------------

# ------------ reproducibility ------------
SEED = 42
torch.manual_seed(SEED); np.random.seed(SEED); random.seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('is_available:', torch.cuda.is_available())
print('PyTorch built for CUDA', torch.version.cuda)
if torch.cuda.is_available():
    gpu_props = torch.cuda.get_device_properties(0)
    # major/minor describe the GPU's compute capability, not a driver version
    print(f'compute capability {gpu_props.major}.{gpu_props.minor}')
    print('device:', torch.cuda.get_device_name(0))

# ------------ environment ---------------
ENV_ID = "Ant-v5"
env    = gym.make(ENV_ID); env.reset(seed=SEED)
# env.reset(seed=...) does not seed the action space; its sample() produces the
# START_RANDOM warm-up actions, so seed it too for a reproducible warm-up phase.
env.action_space.seed(SEED)
OBS_DIM = env.observation_space.shape[0]
ACT_DIM = env.action_space.shape[0]
ACT_MAX = float(env.action_space.high[0])   # assumes symmetric, identical bounds on every action dim

In [None]:
# ---- network sizes ----
# BroNet (LARGE_SCALE) critic / actor: hidden width and depth
CRITIC_WID   = 512
CRITIC_DEPTH = 3
ACTOR_WID    = 256
ACTOR_DEPTH  = 2
# small-MLP critic / actor hidden width
CRITIC_MLP_WID = 256
ACTOR_MLP_WID  = 256

# number of critic outputs: quantiles with USE_QUANTILE, otherwise one Q-value
if USE_QUANTILE:
    N_QUANT = 100
else:
    N_QUANT = 1

# ---- optimisation ----
BATCH     = 128            # minibatch size drawn from the replay buffer
UTD       = REPLAY_RATIO   # agent updates per environment step
LR        = 3e-4
TAU       = 0.005          # Polyak coefficient for the target critics
GAMMA     = 0.99           # discount factor
GRAD_CLIP = 10.0           # max gradient norm applied to each critic
if USE_WD:
    WD_COEFF = 1e-4
else:
    WD_COEFF = 0.0

# ---- policy / entropy ----
LOG_STD_MIN = -5
LOG_STD_MAX = 2
TARGET_ENT  = -ACT_DIM / 2   # entropy target for the temperature update
ALPHA_INIT  = 1.0            # initial entropy temperature

# ---- quantile Huber loss ----
KAPPA = 1.0

# ---- schedule ----
RESET_EVERY   = 50000    # agent updates between periodic resets (HARD_RESET)
TOTAL_STEPS   = 500000   # environment steps
EVAL_EVERY    = 5000
START_RANDOM  = 2500     # uniformly random warm-up steps; updates begin once the buffer holds this many
EVAL_EPISODES = 50

def mlp(in_dim, out_dim, hidden=128):
    """Build Linear -> ReLU -> Linear -> ReLU -> Linear (no output activation).

    Args:
        in_dim: input feature size.
        out_dim: output feature size.
        hidden: width of both hidden layers.
    Returns:
        nn.Sequential whose Linear layers sit at indices 0, 2 and 4.
    """
    sizes = [in_dim, hidden, hidden, out_dim]
    layers = []
    for i, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        if i > 0:
            layers.append(nn.ReLU())   # activation between linear layers only
        layers.append(nn.Linear(n_in, n_out))
    return nn.Sequential(*layers)

# residual block used by BRO
class Residual(nn.Module):
    """Width-preserving residual block: x + LN(Linear(ReLU(LN(Linear(x)))))."""
    def __init__(self, width: int):
        super().__init__()
        layers = [
            nn.Linear(width, width, bias=True),
            nn.LayerNorm(width),
            nn.ReLU(inplace=True),
            nn.Linear(width, width, bias=True),
            nn.LayerNorm(width),
        ]
        # a single Sequential keeps state_dict keys as "block.<index>.*"
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        branch = self.block(x)
        return x + branch

# large BRO critic with quantile head
class BroCriticLarge(nn.Module):
    """BroNet critic: Linear-LayerNorm-ReLU stem, CRITIC_DEPTH - 1 residual blocks, linear head.

    Produces N_QUANT outputs per (s, a) pair. With N_QUANT == 1 that is a plain Q-value;
    otherwise the outputs are quantile estimates at the fractions stored in `taus`
    (midpoints (i + 0.5) / N_QUANT, shape (1, N_QUANT)).
    """
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.inp = nn.Sequential(
            nn.Linear(obs_dim + act_dim, CRITIC_WID),
            nn.LayerNorm(CRITIC_WID),
            nn.ReLU(inplace=True)
        )
        self.body = nn.Sequential(*[Residual(CRITIC_WID) for _ in range(CRITIC_DEPTH - 1)])
        self.out  = nn.Linear(CRITIC_WID, N_QUANT)
        taus = (torch.arange(N_QUANT)+0.5)/N_QUANT
        self.register_buffer("taus", taus.view(1,N_QUANT))

    def forward(self,s,a):
        """Return the N_QUANT critic outputs for states s and actions a (last dim = N_QUANT)."""
        h = self.inp(torch.cat([s, a], dim=-1))
        h = self.body(h)
        return self.out(h)

    def optimistic(self, s, a, p: float = 0.84):
        """Return the output at quantile fraction p as a (batch, 1) tensor.

        With a single output the Q-value itself is returned. The index is clamped to
        [0, N - 1], so p >= 1 selects the last quantile rather than an empty slice,
        and p < 0 selects the first quantile rather than wrapping around.
        """
        q = self.forward(s, a)
        n = q.shape[-1]
        if n == 1:
            return q
        idx = min(max(int(p * n), 0), n - 1)
        return q[:, idx:idx + 1]

    def mean(self, s, a):
        """Return the mean over the N_QUANT outputs as a (batch, 1) tensor."""
        q = self.forward(s, a)
        return q.mean(-1, keepdim=True)

# small BRO critic with quantile head
class CriticSmall(nn.Module):
    """Small MLP critic with N_QUANT outputs per (s, a) pair.

    With N_QUANT == 1 the single output is a Q-value; otherwise the outputs are quantile
    estimates at the fractions stored in `taus`.
    """
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.net = mlp(obs_dim+act_dim, N_QUANT, CRITIC_MLP_WID)
        # Quantile midpoints (i + 0.5) / N_QUANT with shape (1, N_QUANT), as in BroCriticLarge.
        # A constant 1/N_QUANT here would make the quantile-Huber loss fit every head
        # to the same low quantile, collapsing the distribution.
        taus = (torch.arange(N_QUANT) + 0.5) / N_QUANT
        self.register_buffer("taus", taus.view(1, N_QUANT))

    def forward(self,s,a):
        """Return the N_QUANT critic outputs for states s and actions a (last dim = N_QUANT)."""
        return self.net(torch.cat([s,a],-1))

    def optimistic(self, s, a, p: float = 0.84):
        """Return the output at quantile fraction p as a (batch, 1) tensor.

        With a single output the Q-value itself is returned. The index is clamped to
        [0, N - 1], so p >= 1 selects the last quantile rather than an empty slice,
        and p < 0 selects the first quantile rather than wrapping around.
        """
        q = self.forward(s, a)
        n = q.shape[-1]
        if n == 1:
            return q
        idx = min(max(int(p * n), 0), n - 1)
        return q[:, idx:idx + 1]

    def mean(self, s, a):
        """Return the mean over the N_QUANT outputs as a (batch, 1) tensor."""
        q = self.forward(s, a)
        return q.mean(-1, keepdim=True)

# critic architecture chosen by the LARGE_SCALE switch
if LARGE_SCALE:
    Critic = BroCriticLarge
else:
    Critic = CriticSmall

# actor
class BroActorLarge(nn.Module):
    """BroNet-style tanh-squashed Gaussian policy.

    Trunk: Linear -> LayerNorm -> ReLU followed by ACTOR_DEPTH - 1 residual blocks;
    separate linear heads give the Gaussian mean and log standard deviation.
    """
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        stem = [
            nn.Linear(obs_dim, ACTOR_WID),
            nn.LayerNorm(ACTOR_WID),
            nn.ReLU(inplace=True),
        ]
        blocks = [Residual(ACTOR_WID) for _ in range(ACTOR_DEPTH-1)]
        self.trunk = nn.Sequential(*stem, *blocks)
        self.mu_head   = nn.Linear(ACTOR_WID, act_dim)
        self.log_head  = nn.Linear(ACTOR_WID, act_dim)

    def forward(self, s, det=False, logp=True):
        """Return (action in (-1, 1), log-prob with trailing dim 1, or None if logp is False).

        det=True uses the Gaussian mean instead of a reparameterised sample.
        """
        feats = self.trunk(s)
        mu = self.mu_head(feats)
        log_std = self.log_head(feats).clamp(LOG_STD_MIN, LOG_STD_MAX)
        dist = Normal(mu, log_std.exp())
        pre_tanh = dist.rsample() if not det else mu
        action = pre_tanh.tanh()
        if not logp:
            return action, None
        # log(1 - tanh(x)^2) written in the numerically stable form 2*(log 2 - x - softplus(-2x))
        squash = 2*(math.log(2) - pre_tanh - F.softplus(-2*pre_tanh))
        log_prob = dist.log_prob(pre_tanh).sum(-1, keepdim=True) - squash.sum(-1, keepdim=True)
        return action, log_prob

class ActorSmall(nn.Module):
    """Small MLP tanh-squashed Gaussian policy; one network emits both mean and log-std."""
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        # outputs: first act_dim values are the mean, last act_dim the log-std
        self.net = mlp(obs_dim, act_dim*2, ACTOR_MLP_WID)

    def forward(self, s, det=False, logp=True):
        """Return (action in (-1, 1), log-prob with trailing dim 1, or None if logp is False).

        det=True uses the Gaussian mean instead of a reparameterised sample.
        """
        raw = self.net(s)
        mu, log_std = raw.chunk(2, -1)
        dist = Normal(mu, log_std.clamp(LOG_STD_MIN, LOG_STD_MAX).exp())
        pre_tanh = dist.rsample() if not det else mu
        action = pre_tanh.tanh()
        if not logp:
            return action, None
        # log(1 - tanh(x)^2) written in the numerically stable form 2*(log 2 - x - softplus(-2x))
        squash = 2*(math.log(2) - pre_tanh - F.softplus(-2*pre_tanh))
        log_prob = dist.log_prob(pre_tanh).sum(-1, keepdim=True) - squash.sum(-1, keepdim=True)
        return action, log_prob

# actor architecture chosen by the LARGE_SCALE switch
if LARGE_SCALE:
    Actor = BroActorLarge
else:
    Actor = ActorSmall

# QR‑Huber loss
def qhuber(pred, target, taus):
    """Quantile-regression Huber loss; falls back to plain MSE when N_QUANT == 1.

    Args:
        pred:   predicted quantiles, last dim N_QUANT (in update(): the critic output).
        target: target values; every target column is compared with every predicted quantile.
        taus:   quantile fractions of pred, shape (1, N_QUANT).
    Returns:
        scalar loss, averaged over all (batch, target, quantile) entries.
    """
    if N_QUANT == 1:
        return F.mse_loss(pred, target)
    # td[b, j, i] = target[b, j] - pred[b, i]
    td = target.unsqueeze(2) - pred.unsqueeze(1)
    abs_td = td.abs()
    huber = torch.where(abs_td <= KAPPA, 0.5 * td**2, KAPPA * (abs_td - 0.5 * KAPPA))
    # asymmetric quantile weight |tau_i - 1{td < 0}|
    weight = (taus.unsqueeze(0) - (td < 0).float()).abs()
    return (weight * huber / KAPPA).mean()

In [None]:
from collections import deque, namedtuple
Transition = namedtuple("T", "s a r s2 d")

class Replay:
    """Fixed-capacity FIFO replay buffer of Transition(s, a, r, s2, d) records.

    Stored in a Python list used as a ring buffer: random.sample indexes the container
    at random positions, which is O(1) for a list but O(n) in the middle of a deque.
    """
    def __init__(self, cap=int(2e6)):
        if cap < 1:
            raise ValueError(f"replay capacity must be positive, got {cap}")
        self.cap = cap
        self.buf = []
        self._next = 0   # slot overwritten by the next add() once the buffer is full

    def add(self, *args):
        """Store one transition, evicting the oldest one when at capacity."""
        item = Transition(*args)
        if len(self.buf) < self.cap:
            self.buf.append(item)
        else:
            self.buf[self._next] = item
        self._next = (self._next + 1) % self.cap

    def sample(self, n):
        """Draw n distinct transitions uniformly at random.

        Returns float32 tensors on `device`: (s, a, r[:, None], s2, d[:, None]).
        """
        b = Transition(*zip(*random.sample(self.buf, n)))
        to = lambda x: torch.as_tensor(x, dtype=torch.float32, device=device)
        return (to(np.stack(b.s)),
                to(np.stack(b.a)),
                to(b.r).unsqueeze(1),
                to(np.stack(b.s2)),
                to(b.d).unsqueeze(1))

    def __len__(self): return len(self.buf)

class BRO(nn.Module):
    """SAC-based agent driven by the BRO ablation switches in the configuration cell.

    Holds one tanh-Gaussian actor (or a pessimistic/optimistic pair when DUAL_ACTORS),
    two online critics with frozen Polyak-averaged target copies, and a learned
    entropy temperature exp(log_alpha).
    """
    def __init__(self):
        super().__init__()
        self.act_max = ACT_MAX
        self.step_ctr = 0   # number of update() calls so far
        self._build()

    # ---------- construction / reset ----------
    def _build(self):
        """(Re)create actors, critics, target critics, temperature and all optimizers.

        Called by __init__ and by the periodic hard reset. The replay buffer is not
        owned by the agent, so collected experience survives a reset.
        """
        # ----- actors -----
        if DUAL_ACTORS:
            self.pi_p = Actor(OBS_DIM, ACT_DIM).to(device)
            self.pi_o = Actor(OBS_DIM, ACT_DIM).to(device)
        else:
            self.pi = Actor(OBS_DIM, ACT_DIM).to(device)

        # ----- critics (targets start as exact copies and never receive gradients) -----
        self.q1 = Critic(OBS_DIM, ACT_DIM).to(device)
        self.q2 = Critic(OBS_DIM, ACT_DIM).to(device)
        self.t1 = Critic(OBS_DIM, ACT_DIM).to(device); self.t1.load_state_dict(self.q1.state_dict())
        self.t2 = Critic(OBS_DIM, ACT_DIM).to(device); self.t2.load_state_dict(self.q2.state_dict())
        for p in (*self.t1.parameters(), *self.t2.parameters()):
            p.requires_grad = False

        # ----- optimizers -----
        # AdamW applies decoupled weight decay; Adam's `weight_decay` is an L2 penalty that
        # the adaptive step rescales. With WD_COEFF == 0 (USE_WD off) the two are identical.
        opt_cls = torch.optim.AdamW
        self.opt_q = opt_cls([*self.q1.parameters(), *self.q2.parameters()], lr=LR, weight_decay=WD_COEFF)
        if DUAL_ACTORS:
            self.opt_p = opt_cls(self.pi_p.parameters(), lr=LR, weight_decay=WD_COEFF)
            self.opt_o = opt_cls(self.pi_o.parameters(), lr=LR, weight_decay=WD_COEFF)
        else:
            self.opt_pi = opt_cls(self.pi.parameters(), lr=LR, weight_decay=WD_COEFF)

        # ----- temperature -----
        self.log_alpha = torch.tensor(math.log(ALPHA_INIT), requires_grad=True, device=device)
        # plain Adam: AdamW's default weight_decay would otherwise shrink log_alpha
        self.opt_alpha = torch.optim.Adam([self.log_alpha], lr=LR)

    @staticmethod
    def _combine(qa, qb):
        """Merge two critics' estimates: element-wise min with USE_CDQ, else their mean."""
        return torch.min(qa, qb) if USE_CDQ else 0.5*(qa + qb)

    # ---------- act ----------
    def _acting_policy(self, eval):
        """Actor used to act: pessimistic for eval and optimistic for training when DUAL_ACTORS."""
        if not DUAL_ACTORS:
            return self.pi
        return self.pi_p if eval else self.pi_o

    @torch.no_grad()
    def act(self, s, eval=False):
        """Return a numpy action in (-act_max, act_max) for a single observation s.

        eval=True uses the deterministic (mean) action of the evaluation policy.
        """
        obs = torch.as_tensor(s, dtype=torch.float32, device=device).unsqueeze(0)
        a, _ = self._acting_policy(eval)(obs, det=eval, logp=False)
        return (a*self.act_max).squeeze(0).cpu().numpy()

    # ---------- update ----------
    def update(self, batch):
        """One optimisation step on critics, actor(s) and temperature, then target updates.

        Args:
            batch: (s, a, r, s2, d) tensors as produced by Replay.sample.
        """
        s, a, r, s2, d = batch
        # detached: alpha is a constant in the critic/actor losses; only the temperature
        # loss should produce gradients for log_alpha
        alpha = self.log_alpha.exp().detach()

        self._update_critics(s, a, r, s2, d, alpha)
        ent = self._update_actors(s, alpha)
        self._update_temperature(ent)
        self._polyak_update()

        # ---- periodic hard reset ----
        self.step_ctr += 1
        if HARD_RESET and self.step_ctr % RESET_EVERY == 0:
            # re-initialise all networks and optimizer state; copying online weights
            # into the targets (the previous behaviour) is not a parameter reset
            self._build()

    def _update_critics(self, s, a, r, s2, d, alpha):
        """Regress both online critics onto the soft Bellman target."""
        with torch.no_grad():
            actor = self.pi_p if DUAL_ACTORS else self.pi
            a2, lp2 = actor(s2)
            # Bootstrap the full set of target outputs (the single Q-value when N_QUANT == 1).
            # Bootstrapping only the optimistic upper quantile would feed optimism back
            # into the critics' value estimates.
            q_t = self._combine(self.t1(s2, a2), self.t2(s2, a2))
            tgt = r + GAMMA*(1 - d)*(q_t - alpha*lp2)

        q1, q2 = self.q1(s, a), self.q2(s, a)
        loss_q = qhuber(q1, tgt, self.q1.taus) + qhuber(q2, tgt, self.q2.taus)
        self.opt_q.zero_grad(); loss_q.backward()
        nn.utils.clip_grad_norm_(self.q1.parameters(), GRAD_CLIP)
        nn.utils.clip_grad_norm_(self.q2.parameters(), GRAD_CLIP)
        self.opt_q.step()

    def _update_actors(self, s, alpha):
        """Policy-gradient step(s); returns the detached entropy estimate used for alpha."""
        if DUAL_ACTORS:
            # pessimistic actor: maximises the mean (over quantiles) critic estimate
            a_p, lp_p = self.pi_p(s)
            q_p = self._combine(self.q1.mean(s, a_p), self.q2.mean(s, a_p))
            loss_p = (alpha*lp_p - q_p).mean()
            # separate forward passes below build a new graph, so no retain_graph needed
            self.opt_p.zero_grad(); loss_p.backward(); self.opt_p.step()

            # optimistic actor: maximises the upper-quantile estimate (exploration)
            a_o, lp_o = self.pi_o(s)
            q_o = self._combine(self.q1.optimistic(s, a_o), self.q2.optimistic(s, a_o))
            loss_o = (alpha*lp_o - q_o).mean()
            self.opt_o.zero_grad(); loss_o.backward(); self.opt_o.step()
            return -lp_p.mean().detach()

        # single actor: no optimistic exploration, so use the mean estimate
        # (identical to the previous behaviour when N_QUANT == 1)
        a_pi, lp = self.pi(s)
        q_pi = self._combine(self.q1.mean(s, a_pi), self.q2.mean(s, a_pi))
        loss_pi = (alpha*lp - q_pi).mean()
        self.opt_pi.zero_grad(); loss_pi.backward(); self.opt_pi.step()
        return -lp.mean().detach()

    def _update_temperature(self, ent):
        """Move alpha down when entropy exceeds TARGET_ENT and up when it falls below."""
        alpha_loss = (self.log_alpha.exp()*(ent - TARGET_ENT)).mean()
        self.opt_alpha.zero_grad(); alpha_loss.backward(); self.opt_alpha.step()

    def _polyak_update(self):
        """Soft-update target critics: target <- (1 - TAU) * target + TAU * online."""
        with torch.no_grad():
            for online, target in ((self.q1, self.t1), (self.q2, self.t2)):
                for p, pt in zip(online.parameters(), target.parameters()):
                    pt.data.mul_(1 - TAU).add_(TAU*p.data)

In [None]:
agent  = BRO()
replay = Replay()
o, _   = env.reset(seed=SEED)
returns, steps, alpha = [], [], []


def evaluate(agent, n_episodes=EVAL_EPISODES):
    """Mean undiscounted return of agent.act(..., eval=True) over n_episodes on a fresh env.

    Only the first reset is seeded. Re-seeding every reset would start all episodes
    from the same initial state; this way the episodes differ, but every evaluation
    round still sees the same sequence of start states.
    """
    eval_env = gym.make(ENV_ID)
    try:
        total = 0.0
        for ep in range(n_episodes):
            s, _ = eval_env.reset(seed=SEED if ep == 0 else None)
            ep_ret = 0.0
            for _ in range(eval_env.spec.max_episode_steps):
                s, rew, done, truncated, _ = eval_env.step(agent.act(s, eval=True))
                ep_ret += rew
                if done or truncated:
                    break
            total += ep_ret
    finally:
        eval_env.close()
    return total / n_episodes


for t in trange(TOTAL_STEPS, ncols=80):
    a = env.action_space.sample() if t < START_RANDOM else agent.act(o)
    o2, r, term, trunc, _ = env.step(a)
    # Store only true termination as `d`: a time-limit truncation is not a terminal state,
    # so the critic target (masked by 1 - d in update()) must still bootstrap from o2.
    replay.add(o, a, r, o2, term)
    o = o2 if not (term or trunc) else env.reset()[0]

    if len(replay) >= START_RANDOM:
        for _ in range(UTD):
            agent.update(replay.sample(BATCH))

    # ---------- quick eval ----------
    if (t+1) % EVAL_EVERY == 0:
        alpha.append(agent.log_alpha.exp().item())
        mean_R = evaluate(agent)
        returns.append(mean_R); steps.append(t+1)
        print(f"step {t+1}: mean eval return {mean_R:.2f}, alpha {alpha[-1]:.2f}")


In [None]:
import matplotlib.pyplot as plt

# Plotting and the CSV need at least one evaluation point from the training cell.
assert steps, "no evaluation results recorded - run the training cell first"

fig, ax1 = plt.subplots(figsize=(6,4))

# reward curve
line1 = ax1.plot(
    steps, returns,
    marker='o', lw=2, label=f"mean return ({EVAL_EPISODES}‑episode eval)"
)
ax1.set_xlabel(r"environment steps")
ax1.set_ylabel(r"mean episodic return", color=line1[0].get_color())
ax1.tick_params(axis='y', labelcolor=line1[0].get_color())
ax1.grid(alpha=.3)

# temperature curve
ax2 = ax1.twinx()
line2 = ax2.plot(
    steps, alpha,
    marker='s', ls='--', color='tab:orange', label=r"$\alpha$ (entropy temperature)"
)

# tighten alpha‑axis limits so its variations are visible
lo, hi = min(alpha), max(alpha)
margin = 0.1 * (hi - lo if hi > lo else 1)
ax2.set_ylim(lo - margin, hi + margin)

ax2.set_ylabel(r"entropy temperature  $\alpha$", color=line2[0].get_color())
ax2.tick_params(axis='y', labelcolor=line2[0].get_color())

# legend and title (set on ax1 explicitly rather than via pyplot's current-axes state)
lines = line1 + line2
labels = [l.get_label() for l in lines]
ax1.legend(lines, labels, loc="right")
ax1.set_title(f"BRO (Fast) on {ENV_ID}")
plt.tight_layout()
plt.show()

# save returns as a csv file at the current folder
path = Path.cwd() / f"returns_{ENV_ID}.csv"
np.savetxt(path, np.array(returns), delimiter=",", header="mean episodic return", comments="")
# print the path to the saved file
print(f"Returns saved to {path}")



In [None]:
from gymnasium.wrappers import RecordVideo
# Separate render-enabled env; do not rebind the global `env` used by the training loop.
render_env = gym.make(ENV_ID, render_mode="rgb_array")      # no “human” window
vid_env = RecordVideo(render_env, video_folder="videos",
                      episode_trigger=lambda ep: ep == 0,
                      disable_logger=True)
s, _ = vid_env.reset(seed=SEED)
done = trunc = False
while not (done or trunc):
    s, _, done, trunc, _ = vid_env.step(agent.act(s, eval=True))
vid_env.close()

from IPython.display import Video, HTML
# show the most recently written clip, in case the folder holds recordings from earlier runs
latest_video = max(Path("videos").glob("*.mp4"), key=lambda p: p.stat().st_mtime)
Video(latest_video.as_posix(), embed=True)
