<a href="https://colab.research.google.com/github/aslestia/ACS_2025/blob/main/ACS_Week06.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Week 6 Lab — **QR-DQN** (Focus): Mean vs. Quantile Decisions

This lab demonstrates **Quantile Regression DQN (QR-DQN)** on a **simple simulated environment**.
We compare decisions based on the **mean** of the predicted quantiles against decisions based on a **selected quantile** (e.g., the median or the 75th percentile).

**What you'll do:**
1. Generate a synthetic environment (choose: `claims` or `stocks`).
2. Implement QR-DQN (PyTorch).
3. Train the agent to predict the return distribution.
4. Compare decisions using **mean** vs **quantile** targets.
5. Visualize predicted quantiles per action.


## 1) Setup

In [None]:

import math, random, collections, time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from dataclasses import dataclass
import matplotlib.pyplot as plt

# Reproducibility: seed every RNG the lab draws from, in the order
# Python `random` -> NumPy global RNG -> PyTorch.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device


## 2) Synthetic Environment (claims or stocks)

In [None]:

@dataclass
class Config:
    """Hyper-parameters shared by the simulated environment and the agent."""

    # Which synthetic series to simulate: 'claims' or 'stocks'.
    mode: str = 'claims'
    # Number of past observations stacked into one state vector.
    window: int = 8
    # Number of discrete actions (retention levels or allocation weights).
    n_actions: int = 3
    # Episode-length control read by the environment's step().
    episode_len: int = 64
    # Discount factor used in the bootstrapped QR-DQN target.
    gamma: float = 0.99


cfg = Config()
print(repr(cfg))



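### Environment dynamics
- **claims mode**: losses \( L_t \) are drawn from a heavy-tailed 50/50 mixture of a Lognormal body and a Pareto tail (shape 2, minimum 50). The reward is a simple underwriting-style return:
  \( r_t = \text{premium} - \text{retained loss} = 50 - \ell \cdot L_t \),
  where the action selects the retention level \( \ell \in \{0.3, 0.6, 0.9\} \).
- **stocks mode**: returns \( R_t \) follow a mean-reverting AR(1) signal (coefficient 0.9) plus Gaussian noise. The action selects a discrete allocation \( w \), and the reward is \( r_t = w \cdot R_t \). With the default 3 actions the code uses the same levels as claims mode, \( w \in \{0.3, 0.6, 0.9\} \); any other action count uses evenly spaced levels from 0.2 to 0.8.

The state is a rolling window of the last `window` observations (losses or returns), z-score normalized within that window.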


In [None]:

class SimpleSimEnv:
    """Minimal episodic simulator over a synthetic 'claims' or 'stocks' series.

    Every ``reset()`` draws a fresh series of 4096 observations from NumPy's
    global RNG. It then starts at index ``window``, so that a full state window
    exists. An episode lasts exactly ``cfg.episode_len`` steps, or ends earlier
    if the series runs out.

    Args:
        cfg: object exposing ``mode`` ('claims' or 'stocks'), ``window``,
            ``n_actions`` and ``episode_len`` (e.g. ``Config``).

    Raises:
        ValueError: if ``cfg.mode`` is not 'claims' or 'stocks'.
    """

    _MODES = ('claims', 'stocks')
    _SERIES_LEN = 4096

    def __init__(self, cfg: 'Config'):
        # Reject typos such as 'claim' instead of silently simulating stocks.
        if cfg.mode not in self._MODES:
            raise ValueError(f"unknown mode {cfg.mode!r}; expected one of {self._MODES}")
        self.cfg = cfg
        self.t = 0          # index of the next observation to consume
        self.steps = 0      # number of steps taken in the current episode
        self.window = cfg.window
        self.series = None
        self.reset()

    def _gen_series_claims(self, T):
        """Return ``T`` positive losses: a 50/50 Lognormal/Pareto mixture."""
        ln = np.random.lognormal(mean=2.5, sigma=1.0, size=T)  # moderate body
        pa = (np.random.pareto(a=2.0, size=T) + 1.0) * 50.0   # heavy tail, minimum 50
        u = np.random.rand(T)
        return np.where(u < 0.5, ln, pa)

    def _gen_series_stocks(self, T):
        """Return ``T`` signed returns: an AR(1) signal (rho=0.9) plus noise."""
        T0 = T + 100  # the first 100 AR(1) values are discarded as burn-in
        x = np.zeros(T0)
        rho = 0.9
        for t in range(1, T0):
            x[t] = rho * x[t - 1] + 0.05 * np.random.randn()
        return x[100:] + 0.2 * np.random.randn(T)

    def reset(self):
        """Draw a new series, start a new episode and return its first state."""
        T = self._SERIES_LEN
        if self.cfg.mode == 'claims':
            self.series = self._gen_series_claims(T)
        else:
            self.series = self._gen_series_stocks(T)
        # Start far enough in that a full window of history is available.
        self.t = self.window
        self.steps = 0
        return self._get_state()

    def _get_state(self):
        """Return the last ``window`` observations, z-scored within the window, as float32."""
        window = self.series[self.t - self.window:self.t]
        m = np.mean(window)
        s = np.std(window) + 1e-8  # epsilon guards against a constant window
        return ((window - m) / s).astype(np.float32)

    def step(self, action: int):
        """Apply ``action`` to the current observation and advance one step.

        Claims mode rewards ``50 - level * loss``; stocks mode rewards
        ``level * return``. With 3 actions the levels are (0.3, 0.6, 0.9);
        otherwise they are evenly spaced from 0.2 to 0.8.

        Returns:
            (next_state, reward, done, info) with ``reward`` a Python float.

        Raises:
            ValueError: if ``action`` is not in ``[0, n_actions)``.
        """
        n_actions = self.cfg.n_actions
        # A negative index would otherwise silently wrap to another level.
        if not 0 <= action < n_actions:
            raise ValueError(f"action {action} out of range [0, {n_actions})")
        if n_actions == 3:
            levels = [0.3, 0.6, 0.9]
        else:
            levels = np.linspace(0.2, 0.8, n_actions)
        level = levels[action]

        x = self.series[self.t]
        if self.cfg.mode == 'claims':
            premium = 50.0
            retained_loss = level * x
            r = premium - retained_loss
        else:  # stocks
            r = level * x

        self.t += 1
        self.steps += 1
        # Count steps per episode explicitly. The old `t % episode_len == 0`
        # rule ended episodes after episode_len - window steps, because reset()
        # starts at t = window.
        done = self.steps >= self.cfg.episode_len or self.t >= len(self.series)
        next_state = self._get_state()
        return next_state, float(r), done, {}

# Build the environment from the shared config and inspect one state vector.
env = SimpleSimEnv(cfg)
state = env.reset()
state_shape = state.shape
print("state shape:", state_shape)


## 3) Replay Buffer

In [None]:

Transition = collections.namedtuple('Transition', ('s', 'a', 'r', 'ns', 'd'))


class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions with uniform random sampling."""

    def __init__(self, capacity=100_000):
        # deque(maxlen=...) drops the oldest transition once capacity is reached.
        self.buf = collections.deque(maxlen=capacity)

    def push(self, *args):
        """Store one transition given positionally as (s, a, r, ns, d)."""
        self.buf.append(Transition(*args))

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions and return them as tensors on ``device``.

        Returns:
            (s, a, r, ns, d): float32 states, long actions, float32 rewards,
            float32 next states and float32 done flags.
        """
        picked = random.sample(self.buf, batch_size)
        states, actions, rewards, next_states, dones = zip(*picked)

        def as_f32(seq):
            return torch.tensor(seq, dtype=torch.float32, device=device)

        return (
            as_f32(np.stack(states)),
            torch.tensor(actions, dtype=torch.long, device=device),
            as_f32(rewards),
            as_f32(np.stack(next_states)),
            as_f32(dones),
        )

    def __len__(self):
        return len(self.buf)


buffer = ReplayBuffer()


## 4) QR-DQN Network & Quantile Loss

In [None]:

class QRQNetwork(nn.Module):
    """Two-hidden-layer MLP predicting ``n_quantiles`` return quantiles per action.

    ``forward`` maps a ``[batch, state_dim]`` input to a
    ``[batch, n_actions, n_quantiles]`` tensor.
    """

    def __init__(self, state_dim, n_actions, n_quantiles):
        super().__init__()
        self.n_actions = n_actions
        self.n_quantiles = n_quantiles
        width = 128
        # Layers are listed in the same order as before, so parameter init and
        # state_dict keys (net.0 / net.2 / net.4) are unchanged.
        layers = [
            nn.Linear(state_dim, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, n_actions * n_quantiles),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        flat = self.net(x)  # [B, A * Nq]
        return flat.view(x.size(0), self.n_actions, self.n_quantiles)

def huber_quantile_loss(pred, target, taus, kappa=1.0):
    """Quantile-Huber loss used by QR-DQN, averaged over all elements.

    For each predicted quantile ``pred[b, i]`` (at fraction ``taus[i]``) and each
    target sample ``target[b, j]``, the TD error is ``u = target[b, j] - pred[b, i]``.
    The per-element loss is ``|tau_i - 1{u < 0}| * huber_kappa(u) / kappa``.
    The result is the mean over batch, predicted quantiles and target samples.
    This differs from the paper's sum over ``i`` only by the constant factor Nq.

    Args:
        pred: ``[B, Nq]`` predicted quantiles (gradients flow through this).
        target: ``[B]`` (one target sample per row) or ``[B, Nt]`` (all
            target samples at once). It is treated as a constant (detached).
        taus: ``[Nq]`` quantile fractions matching ``pred``'s last dimension.
        kappa: Huber threshold; must be > 0.

    Returns:
        Scalar tensor loss.

    Raises:
        ValueError: if ``kappa <= 0``.
    """
    if kappa <= 0:
        raise ValueError(f"kappa must be > 0, got {kappa}")
    if target.dim() == 1:
        target = target.unsqueeze(1)  # [B] -> [B, 1]: a single target sample
    # Pairwise TD errors u[b, i, j] = target[b, j] - pred[b, i]  -> [B, Nq, Nt]
    u = target.detach().unsqueeze(1) - pred.unsqueeze(2)
    abs_u = torch.abs(u)
    huber = torch.where(abs_u <= kappa, 0.5 * u * u, kappa * (abs_u - 0.5 * kappa))
    # Asymmetric quantile weight; tau indexes the *predicted* quantile axis.
    tau = taus.to(pred.device).view(1, -1, 1)
    weight = torch.abs(tau - (u.detach() < 0).float())
    return (weight * huber / kappa).mean()


## 5) Agent, Training Loop

In [None]:

class QRDQNAgent:
    """QR-DQN agent: online/target quantile networks trained with Adam.

    Both networks predict ``n_quantiles`` return quantiles per action, at the
    fixed midpoint fractions ``taus[i] = (2i + 1) / (2 * n_quantiles)``.
    """

    def __init__(self, state_dim, n_actions, n_quantiles=51, gamma=0.99, lr=1e-3):
        self.n_actions = n_actions
        self.n_quantiles = n_quantiles
        self.gamma = gamma

        # Quantile midpoints: 1/(2N), 3/(2N), ..., 1 - 1/(2N).
        self.taus = torch.linspace(0.0 + 1/(2*n_quantiles), 1.0 - 1/(2*n_quantiles), n_quantiles, device=device)
        self.online = QRQNetwork(state_dim, n_actions, n_quantiles).to(device)
        self.target = QRQNetwork(state_dim, n_actions, n_quantiles).to(device)
        self.target.load_state_dict(self.online.state_dict())
        self.opt = optim.Adam(self.online.parameters(), lr=lr)

    def _quantile_index(self, mode):
        """Return the index of the quantile fraction in ``self.taus`` nearest to ``mode``.

        A numeric ``mode`` is matched against ``self.taus`` directly. Because the
        taus are midpoints, ``round(mode * (N - 1))`` is not always the nearest
        one (e.g. N=4, mode=0.2). Out-of-range values clamp to the first or last
        quantile instead of wrapping. Any non-numeric ``mode`` selects the median.
        """
        is_number = isinstance(mode, (int, float, np.integer, np.floating)) and not isinstance(mode, bool)
        if is_number:
            # torch.argmin returns the first index on ties, so this is deterministic.
            return int(torch.argmin(torch.abs(self.taus - float(mode))).item())
        return self.n_quantiles // 2

    def act(self, s, eps=0.05, mode='mean'):
        """Choose an action epsilon-greedily from the online network.

        Args:
            s: one state (array-like of length ``state_dim``).
            eps: probability of returning a uniformly random action.
            mode: 'mean' ranks actions by the mean of their quantiles. A number
                tau ranks them by the predicted quantile nearest to tau. Any
                other value ranks them by the median quantile.

        Returns:
            The chosen action index as a Python int.
        """
        if random.random() < eps:
            return random.randrange(self.n_actions)
        s = torch.tensor(s, dtype=torch.float32, device=self.taus.device).unsqueeze(0)
        with torch.no_grad():
            q = self.online(s)[0]  # [A, Nq]
            if isinstance(mode, str) and mode == 'mean':
                vals = q.mean(dim=1)
            else:
                vals = q[:, self._quantile_index(mode)]
            return int(torch.argmax(vals).item())

    def update(self, batch, double=True):
        """Take one gradient step on the quantile regression loss.

        Args:
            batch: tuple ``(s, a, r, ns, d)`` of tensors, as returned by
                ``ReplayBuffer.sample``.
            double: if True (default), the greedy next action is selected by the
                online network (Double DQN); if False, by the target network.
                Its quantiles are evaluated with the target network in both cases.

        Returns:
            The loss value as a Python float.
        """
        s, a, r, ns, d = batch
        B = s.size(0)
        q_pred = self.online(s)                                                   # [B, A, Nq]
        a_idx = a.view(B, 1, 1).expand(-1, -1, self.n_quantiles)
        q_pred_a = q_pred.gather(1, a_idx).squeeze(1)                             # [B, Nq]

        with torch.no_grad():
            # Rank next actions by the mean of their predicted quantiles.
            selector = self.online(ns) if double else None
            q_next_tgt = self.target(ns)                                          # [B, A, Nq]
            if selector is None:
                selector = q_next_tgt
            a_star = torch.argmax(selector.mean(dim=2), dim=1)                    # [B]
            q_next_star = q_next_tgt[torch.arange(B, device=ns.device), a_star]   # [B, Nq]
            # NOTE(review): d is also set when an episode is cut off by its
            # length limit, so bootstrapping stops there too. Consider
            # separating truncation from true termination.
            y = r.unsqueeze(1) + (1.0 - d.unsqueeze(1)) * self.gamma * q_next_star  # [B, Nq]

        # Each column of y is one target sample. Average the loss over them.
        loss = 0.0
        for j in range(self.n_quantiles):
            loss = loss + huber_quantile_loss(q_pred_a, y[:, j], self.taus)
        loss = loss / self.n_quantiles

        self.opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.online.parameters(), 10.0)
        self.opt.step()
        return float(loss.item())

    def sync_target(self):
        """Copy the online network's weights into the target network."""
        self.target.load_state_dict(self.online.state_dict())

# Build the agent: the state vector length is the environment's window size.
state_dim = env.window
agent = QRDQNAgent(
    state_dim,
    cfg.n_actions,
    n_quantiles=51,
    gamma=cfg.gamma,
    lr=1e-3,
)
agent


## 6) Training

In [None]:

def train(env, agent, buffer, steps=15000, warmup=1000, batch_size=64, sync_every=1000, eps_start=0.2, eps_end=0.02):
    """Run epsilon-greedy QR-DQN training for a fixed number of environment steps.

    Actions are chosen by the mean of the predicted quantiles. One gradient
    update happens per step once ``buffer`` holds at least ``warmup``
    transitions. Epsilon decays linearly from ``eps_start`` towards ``eps_end``
    with each update. The target network is synced every ``sync_every`` steps.

    Returns:
        (losses, rewards): the loss of every update, and the undiscounted return
        of every episode that finished during training.
    """
    state = env.reset()
    eps = eps_start
    losses, rewards = [], []
    episode_return = 0.0

    for step in range(1, steps + 1):
        action = agent.act(state, eps=eps, mode='mean')
        next_state, reward, done, _ = env.step(action)
        buffer.push(state, action, reward, next_state, float(done))
        episode_return += reward
        state = next_state

        if done:
            rewards.append(episode_return)
            episode_return = 0.0
            state = env.reset()

        if len(buffer) >= warmup:
            losses.append(agent.update(buffer.sample(batch_size)))
            # linear epsilon decay, one decrement per update
            eps = max(eps_end, eps - (eps_start - eps_end) / (steps - warmup + 1))

        if step % sync_every == 0:
            agent.sync_target()

    return losses, rewards

# Short training run; the trailing tuple displays how many updates and episodes were recorded.
losses, episodic_rewards = train(
    env, agent, buffer,
    steps=8000, warmup=800, batch_size=64, sync_every=1000,
)
len(losses), len(episodic_rewards)


## 7) Training Curves

In [None]:

def _plot_smoothed(values, n_points, title, xlabel, ylabel):
    """Plot a moving average of ``values`` with window ``max(1, len(values) // n_points)``."""
    plt.figure(figsize=(6.0, 3.5))
    if len(values):
        arr = np.array(values)
        k = max(1, len(arr) // n_points)
        plt.plot(np.convolve(arr, np.ones(k) / k, mode='valid'))
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.show()


_plot_smoothed(losses, 200, "Training Loss (smoothed)", "updates", "loss")
_plot_smoothed(episodic_rewards, 50, "Episodic Return (smoothed)", "episodes", "return")


## 8) Evaluation: Mean vs Quantile Decisions

In [None]:

def evaluate(env, agent, mode='mean', episodes=10, seed=None):
    """Run greedy (eps=0) episodes and summarise their undiscounted returns.

    Args:
        env: environment with ``reset()`` and ``step(a) -> (s, r, done, info)``.
        agent: object with ``act(s, eps=..., mode=...)``.
        mode: decision rule forwarded to ``agent.act`` ('mean' or a tau value).
        episodes: number of episodes to run.
        seed: if given, NumPy's global RNG is reseeded with ``seed + i`` before
            episode ``i`` is reset. The global RNG state is restored afterwards.
            The environment in this notebook draws its series from that RNG in
            ``reset()``, so evaluations sharing a seed are scored on identical
            episodes (a paired comparison). If None, episodes come from the
            current RNG state.

    Returns:
        (mean, std) of the per-episode returns.
    """
    saved_state = np.random.get_state() if seed is not None else None
    totals = []
    try:
        for ep in range(episodes):
            if seed is not None:
                np.random.seed(seed + ep)
            s = env.reset()
            done = False
            total = 0.0
            while not done:
                a = agent.act(s, eps=0.0, mode=mode)
                s, r, done, _ = env.step(a)
                total += r
            totals.append(total)
    finally:
        # Do not let evaluation reseeding leak into later cells.
        if saved_state is not None:
            np.random.set_state(saved_state)
    return np.mean(totals), np.std(totals)


# Same seed for every decision rule, so all three are scored on identical
# episodes. Otherwise heavy-tailed episode noise would swamp the comparison.
mean_avg, mean_std = evaluate(env, agent, mode='mean', episodes=20, seed=SEED)
med_avg, med_std  = evaluate(env, agent, mode=0.5,  episodes=20, seed=SEED)   # median
q75_avg, q75_std  = evaluate(env, agent, mode=0.75, episodes=20, seed=SEED)   # 75th quantile

print("Mean-based   : avg=%.3f  std=%.3f" % (mean_avg, mean_std))
print("Median-based : avg=%.3f  std=%.3f" % (med_avg,  med_std))
print("Q75-based    : avg=%.3f  std=%.3f" % (q75_avg,  q75_std))


## 9) Visualize Learned Quantiles per Action

In [None]:

# Collect 256 consecutive states from a freshly reset series, then average the
# online network's predicted quantiles per action over those states.
env.reset()
states = []
for _ in range(256):
    states.append(env._get_state())
    env.t += 1  # slide the window forward without taking an action
S = torch.tensor(np.stack(states), dtype=torch.float32, device=device)

with torch.no_grad():
    Q = agent.online(S)     # [B, A, Nq]
    q_mean = Q.mean(dim=0)  # [A, Nq]

taus = agent.taus.detach().cpu().numpy()
curves = q_mean.cpu().numpy()  # one quantile curve per action
plt.figure(figsize=(7.0, 4.0))
for a in range(cfg.n_actions):
    plt.plot(taus, curves[a], label=f"action {a}")
plt.title("Predicted Quantile Functions by Action (avg over states)")
plt.xlabel("tau")
plt.ylabel("quantile value")
plt.legend()
plt.show()
