Conditional Value-at-Risk (CVaR), also known as Expected Shortfall, is a risk management measure used to quantify the tail risk of a portfolio or a decision process. Unlike Value-at-Risk (VaR), which provides an upper bound for the potential loss at a given confidence level, CVaR estimates the average loss assuming that the loss exceeds the VaR threshold. This makes CVaR particularly useful in assessing extreme risk scenarios, which are often neglected in traditional risk measures. In the context of reinforcement learning, CVaR can be incorporated into the objective function to develop risk-sensitive algorithms that take into account not just the expected rewards but also the potential for high losses.

In our implementation, CVaR is employed as a risk-averse modification to traditional Q-learning, adjusting the value function by incorporating the worst-case scenarios. This allows the model to better handle environments where the uncertainty or volatility in rewards is high.

In [None]:
# ==================== QR-DQN + CVaR (risk-aware) ====================

import os, math, csv, time, sys, subprocess, json
from pathlib import Path
import numpy as np
import torch, torch.nn as nn
from torch.utils.data import IterableDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import pandas as pd
import matplotlib.pyplot as plt
from collections import Counter

try:
    from tqdm.auto import tqdm
except Exception:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "tqdm"])
    from tqdm.auto import tqdm

# ---------------- CONFIG ----------------
# Point to your large dataset run folder (created earlier)
RUN_DIR          = sorted((Path("/kaggle/input/breakout-offline-minatar-big/breakout_offline_minatar_big")).glob("run_*"))[-1]
VARIANT          = "atari"      # "atari" (84x84x4) or "native" (tiny MinAtar grids)
BATCH            = 128
LR               = 1e-4
N_QUANT          = 51           # number of quantile atoms (e.g., 21/51)
GAMMA            = 0.99
KAPPA            = 1.0          # Huber kappa for quantile regression
ALPHA            = 0.10         # <-- CVaR risk level (0 < α ≤ 1). Lower = more risk-averse.
EPOCHS           = 12
STEPS_PER_EPOCH  = 300          # adjust as you like; dataset is large enough
TARGET_UPDATE    = 600          # hard copy every N iters (set 0 to rely on EMA)
EMA_TAU          = 5e-3         # <-- EMA target update per step; set 0 to disable EMA
SEED             = 123
USE_AMP          = True         # AMP/mixed precision for speed
SAVE_DIR         = Path("/kaggle/working/qr_dqn_logs") / "qrdqn_cvar_logs"
PLOTS_DIR        = SAVE_DIR / "plots"
# ----------------------------------------

os.makedirs(SAVE_DIR, exist_ok=True); os.makedirs(PLOTS_DIR, exist_ok=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpus = torch.cuda.device_count() if device.type == "cuda" else 0
torch.manual_seed(SEED); np.random.seed(SEED)
torch.backends.cudnn.benchmark = True  # fixed-shape speedup

# --------------- Discover shards + split (80/20 by shards; robust to 1 shard) ---------------
shards = sorted((RUN_DIR).glob(f"*_{VARIANT}_shard_*.npz"))
assert len(shards) > 0, f"No {VARIANT} shards found in {RUN_DIR}"

if len(shards) >= 2:
    n_train_shards = max(1, min(int(0.8 * len(shards)), len(shards) - 1))
    train_shards = shards[:n_train_shards]
    eval_shards  = shards[n_train_shards:]
    single_shard = False
else:
    train_shards = shards
    eval_shards  = shards
    single_shard = True

print(f"Shards: {len(shards)} | variant={VARIANT} | Train shards: {len(train_shards)} | Eval shards: {len(eval_shards)}")

# --------------- Fast IterableDataset (loads each shard once/epoch) ---------------
class ShardBatcher(IterableDataset):
    """
    - If multiple shards: train/eval are split by shards (80/20 above).
    - If a single shard: we split indices (train ~90%, eval ~10%) inside this class.
    - Yields tensors: obs/next_obs float CHW in [0,1], act long, rew float, done float.
    """
    def __init__(self, shard_paths, batch_size, seed=123, eval_mode=False, split_frac=0.9):
        super().__init__()
        self.shards = list(map(str, shard_paths))
        self.bs = batch_size
        self.eval_mode = eval_mode
        self.rng = np.random.default_rng(seed)
        self.split_frac = split_frac
        d0 = np.load(self.shards[0]); H, W, C = d0["obs"].shape[1:]
        self.in_ch = int(C); self.n_actions = int(d0["act"].max()) + 1

    def _yield_batches_from_arrays(self, d):
        N = d["act"].shape[0]; idx = np.arange(N)
        if single_shard:
            split = max(1, int(self.split_frac * N))
            idx = idx[split:] if self.eval_mode else idx[:split]
        if not self.eval_mode:
            self.rng.shuffle(idx)
        for s in range(0, idx.size, self.bs):
            bi = idx[s:s+self.bs]
            o  = torch.from_numpy(d["obs"][bi]).permute(0,3,1,2).float()/255.0
            no = torch.from_numpy(d["next_obs"][bi]).permute(0,3,1,2).float()/255.0
            a  = torch.from_numpy(d["act"][bi]).long()
            r  = torch.from_numpy(d["rew"][bi]).float()
            dn = torch.from_numpy(d["done"][bi]).float()
            yield o, a, r, no, dn

    def __iter__(self):
        order = np.arange(len(self.shards))
        if not self.eval_mode:
            self.rng.shuffle(order)
        for si in order:
            d = np.load(self.shards[si])
            yield from self._yield_batches_from_arrays(d)

train_ds = ShardBatcher(train_shards, BATCH, SEED, eval_mode=False)
eval_ds  = ShardBatcher(eval_shards,  BATCH, SEED, eval_mode=True)

IN_CH, N_ACT = train_ds.in_ch, train_ds.n_actions
train_loader = DataLoader(train_ds, batch_size=None, num_workers=0)
eval_loader  = DataLoader(eval_ds,  batch_size=None, num_workers=0)
print(f"Channels: {IN_CH} | Actions: {N_ACT} | Device: {device} | GPUs: {n_gpus}")

# --------------- Model + CVaR helpers ---------------
class QRDQN(nn.Module):
    def __init__(self, in_ch, n_actions, n_quant=51):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, 32, 8, 4), nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 4, 2),    nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, 1),    nn.ReLU(inplace=True),
        )
        self.gap = nn.AdaptiveAvgPool2d((7,7))
        self.fc  = nn.Sequential(
            nn.Linear(64*7*7, 512), nn.ReLU(inplace=True),
            nn.Linear(512, n_actions*n_quant),
        )
        self.n_actions = n_actions
        self.n_quant   = n_quant
    def forward(self, x):
        z = self.conv(x)
        z = self.gap(z).reshape(z.size(0), -1)
        z = self.fc(z).reshape(-1, self.n_actions, self.n_quant)
        return z  # (B, A, Nq), atoms ordered by tau midpoints

def quantile_huber_loss(pred_tau, target_tau, taus, kappa=1.0):
    # pred_tau, target_tau: (B, Nq)
    u = target_tau.unsqueeze(1) - pred_tau.unsqueeze(2)         # (B,Nq,Nq)
    abs_u = torch.abs(u)
    huber = torch.where(abs_u <= kappa, 0.5*u.pow(2), kappa*(abs_u - 0.5*kappa))
    tau = taus.view(1, -1, 1)                                   # (1,Nq,1)
    weight = torch.abs(tau - (u.detach() < 0).float())
    return (weight * huber).mean()

def cvar_from_quantiles(qvals, alpha):
    """
    qvals: (B, A, Nq) quantile values (assumed ordered by tau asc).
    alpha in (0,1]; CVaR_α = mean of the lowest α-fraction of atoms.
    returns: (B, A) CVaR scores.
    """
    B, A, Nq = qvals.shape
    k = max(1, int(math.ceil(alpha * Nq)))
    # Take mean over the first k atoms (lower tail)
    return qvals[:, :, :k].mean(dim=2)

# Instantiate nets/opt
qnet_base   = QRDQN(IN_CH, N_ACT, N_QUANT).to(device)
target_base = QRDQN(IN_CH, N_ACT, N_QUANT).to(device)
target_base.load_state_dict(qnet_base.state_dict())

# Optional multi-GPU
qnet   = nn.DataParallel(qnet_base)   if n_gpus > 1 else qnet_base
target = nn.DataParallel(target_base) if n_gpus > 1 else target_base

opt    = torch.optim.Adam(qnet.parameters(), lr=LR)
from torch.amp import autocast, GradScaler
scaler = GradScaler('cuda', enabled=(USE_AMP and device.type == "cuda"))

taus = torch.linspace(0, 1, N_QUANT + 1, device=device)
taus = (taus[:-1] + taus[1:]) / 2.0  # midpoints, ascending

# --------------- Logging (train only; eval after) ---------------
writer  = SummaryWriter(log_dir=str(SAVE_DIR / "tb"))
csv_log = open(SAVE_DIR / "training_log.csv", "w", newline="")
csv_wr  = csv.writer(csv_log)
csv_wr.writerow(["epoch","iter","loss_train","q_mean","q_std","grad_norm","lr","time_sec"])

# --------------- Training (CVaR action selection) ---------------
global_iter = 0
t0 = time.time()

def ema_target_update(ema_tau):
    with torch.no_grad():
        for p_t, p in zip(target_base.parameters(), qnet_base.parameters()):
            p_t.data.mul_(1.0 - ema_tau).add_(ema_tau * p.data)

for epoch in range(1, EPOCHS+1):
    qnet.train()
    running = 0.0
    pbar = tqdm(total=STEPS_PER_EPOCH, desc=f"Epoch {epoch}", leave=False, dynamic_ncols=True, smoothing=0.2)
    i = 0
    for (o,a,r,no,dn) in train_loader:
        i += 1
        o,no = o.to(device, non_blocking=True), no.to(device, non_blocking=True)
        a,r,dn = a.to(device, non_blocking=True), r.to(device, non_blocking=True), dn.to(device, non_blocking=True)

        with autocast('cuda', enabled=(USE_AMP and device.type=="cuda")):
            q = qnet(o)  # (B, A, Nq)
            # CVaR-based target action on next states (Double-DQN)
            q_next_online = qnet(no)                    # (B, A, Nq)
            cvar_scores   = cvar_from_quantiles(q_next_online, ALPHA)  # (B, A)
            next_a        = cvar_scores.argmax(dim=1)   # (B,)

            # Gather online predicted quantiles for taken actions (current states)
            q_sel = q.gather(1, a.view(-1,1,1).expand(-1,1,N_QUANT)).squeeze(1)  # (B, Nq)

            # Target network quantiles for the selected next action
            tq_all = target(no)  # (B, A, Nq)
            tq     = tq_all.gather(1, next_a.view(-1,1,1).expand(-1,1,N_QUANT)).squeeze(1)  # (B, Nq)

            # Bellman target distribution
            tgt = r.view(-1,1) + GAMMA * (1.0 - dn.view(-1,1)) * tq

            # Distributional QR Huber loss
            loss = quantile_huber_loss(q_sel, tgt, taus, KAPPA)

        opt.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        grad_norm = float(nn.utils.clip_grad_norm_(qnet.parameters(), 10.0))
        scaler.step(opt)
        scaler.update()

        # Target updates: EMA each step; optional hard copy
        if EMA_TAU > 0.0:
            ema_target_update(EMA_TAU)
        elif TARGET_UPDATE > 0 and (global_iter % TARGET_UPDATE == 0):
            target_base.load_state_dict(qnet_base.state_dict())

        # Logging
        running += float(loss.item())
        global_iter += 1
        if (i % 25 == 0) or (i == STEPS_PER_EPOCH):
            lr = float(opt.param_groups[0]["lr"])
            q_std = float(q_sel.std().item())
            q_mean = float(q_sel.mean().item())
            elapsed = time.time() - t0
            avg_tr = running / max(1, i)

            writer.add_scalar("loss/train", avg_tr, global_iter)
            writer.add_scalar("q/pred_mean_train", q_mean, global_iter)
            writer.add_scalar("grad/norm", grad_norm, global_iter)
            writer.add_scalar("opt/lr", lr, global_iter)
            writer.add_scalar("cvar/alpha", ALPHA, global_iter)

            csv_wr.writerow([epoch, global_iter, avg_tr, q_mean, q_std, grad_norm, lr, elapsed])
            csv_log.flush()

        pbar.set_postfix_str(f"loss={running/max(1,i):.4g}")
        pbar.update(1)
        if i >= STEPS_PER_EPOCH:
            break
    pbar.close()

    # end-epoch snapshot
    torch.save({
        "model": qnet_base.state_dict(),
        "target": target_base.state_dict(),
        "config": dict(n_quant=N_QUANT, gamma=GAMMA, alpha=ALPHA, ema_tau=EMA_TAU,
                       in_ch=IN_CH, n_actions=N_ACT),
    }, SAVE_DIR / f"qrdqn_cvar_epoch_{epoch}.pt")
    print(f"[Epoch {epoch}] avg_train_loss={running/max(1,i):.6f}")

# final save
torch.save({
    "model": qnet_base.state_dict(),
    "target": target_base.state_dict(),
    "config": dict(n_quant=N_QUANT, gamma=GAMMA, alpha=ALPHA, ema_tau=EMA_TAU,
                   in_ch=IN_CH, n_actions=N_ACT),
}, SAVE_DIR / "qrdqn_cvar_last.pt")

writer.close()
csv_log.close()
print("Training logs saved to:", SAVE_DIR)

# --------------- Single evaluation after training (held-out only) ---------------
@torch.no_grad()
def full_eval():
    qnet.eval(); target.eval()
    total_loss=0.0; iters=0; q_means=[]; tq_means=[]
    agree_total=0; count_total=0
    ds_hist = Counter(); pi_hist = Counter()

    # Build a fresh eval loader to ensure held-out portion only
    eval_loader2 = DataLoader(eval_ds, batch_size=None, num_workers=0)
    pbar = tqdm(total=sum(math.ceil(np.load(sp)["obs"].shape[0] * (0.1 if single_shard else 1.0) / BATCH) for sp in eval_shards),
                desc="Eval", leave=True, dynamic_ncols=True)

    for (o,a,r,no,dn) in eval_loader2:
        o,no = o.to(device), no.to(device)
        a,r,dn = a.to(device), r.to(device), dn.to(device)

        with autocast('cuda', enabled=(USE_AMP and device.type=="cuda")):
            q = qnet(o)  # (B, A, Nq)
            # policy (for agreement/hist): CVaR on current states
            cvar_now = cvar_from_quantiles(q, ALPHA)  # (B, A)
            pi_a = cvar_now.argmax(1)

            # gather predicted quantiles for logged actions
            q_sel = q.gather(1, a.view(-1,1,1).expand(-1,1,N_QUANT)).squeeze(1)

            # Double-DQN target based on CVaR on next states
            q_next = qnet(no)
            next_a = cvar_from_quantiles(q_next, ALPHA).argmax(1)
            tq_all = target(no)
            tq = tq_all.gather(1, next_a.view(-1,1,1).expand(-1,1,N_QUANT)).squeeze(1)
            tgt = r.view(-1,1) + GAMMA * (1.0 - dn.view(-1,1)) * tq
            loss = quantile_huber_loss(q_sel, tgt, taus, KAPPA)

        total_loss += float(loss.item()); iters += 1
        q_means.append(float(q_sel.mean().item())); tq_means.append(float(tq.mean().item()))
        agree_total += int((pi_a == a).sum().item()); count_total += int(a.numel())
        for aa in a.detach().cpu().numpy(): ds_hist[int(aa)] += 1
        for pa in pi_a.detach().cpu().numpy():    pi_hist[int(pa)] += 1

        pbar.set_postfix_str(f"loss={total_loss/max(1,iters):.4g}")
        pbar.update(1)
    pbar.close()

    return dict(
        eval_loss = total_loss / max(1,iters),
        q_mean    = float(np.mean(q_means)) if q_means else 0.0,
        tq_mean   = float(np.mean(tq_means)) if tq_means else 0.0,
        num_batches = iters,
        policy_dataset_agreement = (agree_total / max(1,count_total)),
        ds_hist = dict(ds_hist), pi_hist = dict(pi_hist),
    )

eval_stats = full_eval()
print("[Final Eval]", {k:v for k,v in eval_stats.items() if k not in ("ds_hist","pi_hist")})

# Save eval CSV + histograms
with open(SAVE_DIR / "final_eval_cvar.csv", "w", newline="") as f:
    w = csv.writer(f)
    w.writerow(["eval_loss","q_mean","tq_mean","policy_dataset_agreement","num_batches","alpha","checkpoint"])
    w.writerow([eval_stats["eval_loss"], eval_stats["q_mean"], eval_stats["tq_mean"],
                eval_stats["policy_dataset_agreement"], eval_stats["num_batches"], ALPHA,
                str(SAVE_DIR / "qrdqn_cvar_last.pt")])

with open(SAVE_DIR / "eval_action_histograms_cvar.json", "w") as f:
    json.dump({"dataset": eval_stats["ds_hist"], "policy": eval_stats["pi_hist"]}, f, indent=2)

# Quick plots for report
df = pd.read_csv(SAVE_DIR / "training_log.csv")
plt.figure(); df.plot(x="iter", y=["loss_train"]); plt.title("QR-DQN CVaR: Train Loss"); plt.xlabel("iter"); plt.ylabel("loss")
plt.savefig(PLOTS_DIR/"loss_train_cvar.png", bbox_inches="tight"); plt.close()
plt.figure(); df.plot(x="iter", y=["q_mean"]); plt.title("QR-DQN CVaR: Train Q Mean"); plt.xlabel("iter"); plt.ylabel("Q")
plt.savefig(PLOTS_DIR/"q_train_cvar.png", bbox_inches="tight"); plt.close()
print("Artifacts:", SAVE_DIR)


# **CVaR eval**

In [4]:
# =================== EVAL-ONLY with PROGRESS (held-out data) — CVaR QR-DQN ===================
# - Matches the baseline eval structure but uses CVaR policy (alpha from checkpoint config)


import os, csv, json, math, time
from pathlib import Path
import numpy as np
import torch, torch.nn as nn
from torch.utils.data import IterableDataset, DataLoader
import pandas as pd
import matplotlib.pyplot as plt
from collections import Counter

try:
    from tqdm.auto import tqdm
except Exception:
    import sys, subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "tqdm"])
    from tqdm.auto import tqdm

# ---------------- CONFIG (match training) ----------------
RUN_DIR   = sorted((Path("/kaggle/input/breakout-offline-minatar-big/breakout_offline_minatar_big")).glob("run_*"))[-1]
VARIANT   = "atari"        # "atari" or "native"
BATCH     = 128
SEED      = 123
CKPT_PATH = Path("/kaggle/working/qr_dqn_logs/qrdqn_cvar_logs/qrdqn_cvar_epoch_5.pt")  
OUT_DIR   = Path("/kaggle/working/qr_dqn_logs")                                    
USE_AMP   = True
EVAL_MAX_BATCHES = None    
# ---------------------------------------------------------

OUT_DIR.mkdir(parents=True, exist_ok=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpus = torch.cuda.device_count() if device.type == "cuda" else 0
torch.backends.cudnn.benchmark = True

# ---------------- Discover eval shards ----------------
shards = sorted((RUN_DIR).glob(f"*_{VARIANT}_shard_*.npz"))
assert shards, f"No {VARIANT} shards found in {RUN_DIR}"
multi = len(shards) >= 2
if multi:
    n_train_shards = max(1, min(int(0.8 * len(shards)), len(shards) - 1))
    eval_shards    = shards[n_train_shards:]
else:
    eval_shards    = shards

# ---------------- Dataset (eval-only) ----------------
class EvalBatcher(IterableDataset):
    def __init__(self, shard_paths, batch_size, seed=123, single_shard=not multi):
        super().__init__()
        self.shards = list(map(str, shard_paths))
        self.bs = batch_size
        self.single_shard = single_shard
        self.rng = np.random.default_rng(seed)
        d0 = np.load(self.shards[0])
        H,W,C = d0["obs"].shape[1:]
        self.in_ch = int(C)
        self.n_actions = int(d0["act"].max()) + 1

    def __iter__(self):
        for sp in self.shards:
            d = np.load(sp)
            if self.single_shard:
                N = d["obs"].shape[0]
                split = max(1, int(0.9 * N))          # last 10% = eval
                idx = np.arange(split, N)
            else:
                idx = np.arange(d["obs"].shape[0])    # whole shard = eval
            for s in range(0, idx.size, self.bs):
                bi = idx[s:s+self.bs]
                o  = torch.from_numpy(d["obs"][bi]).permute(0,3,1,2).float()/255.0
                no = torch.from_numpy(d["next_obs"][bi]).permute(0,3,1,2).float()/255.0
                a  = torch.from_numpy(d["act"][bi]).long()
                r  = torch.from_numpy(d["rew"][bi]).float()
                dn = torch.from_numpy(d["done"][bi]).float()
                yield o, a, r, no, dn

eval_ds = EvalBatcher(eval_shards, BATCH)
eval_loader = DataLoader(eval_ds, batch_size=None, num_workers=0)
IN_CH, N_ACT = eval_ds.in_ch, eval_ds.n_actions

def estimate_batches(shard_paths, batch, single_shard=not multi):
    total = 0
    for sp in shard_paths:
        d = np.load(sp)
        if single_shard:
            N = d["obs"].shape[0]; split = max(1, int(0.9*N))
            n = N - split
        else:
            n = d["obs"].shape[0]
        total += math.ceil(n / batch)
    return total

expected_batches = estimate_batches(eval_shards, BATCH)
if EVAL_MAX_BATCHES is not None:
    expected_batches = min(expected_batches, EVAL_MAX_BATCHES)

# ---------------- Model (same as training arch) ----------------
class QRDQN(nn.Module):
    def __init__(self, in_ch, n_actions, n_quant=51):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, 32, 8, 4), nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 4, 2),    nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, 1),    nn.ReLU(inplace=True),
        )
        self.gap = nn.AdaptiveAvgPool2d((7,7))
        self.fc = nn.Sequential(
            nn.Linear(64*7*7, 512), nn.ReLU(inplace=True),
            nn.Linear(512, n_actions*51),
        )
        self.n_actions = n_actions
        self.n_quant   = 51
    def forward(self, x):
        z = self.conv(x); z = self.gap(z).reshape(z.size(0), -1)
        z = self.fc(z).reshape(-1, self.n_actions, self.n_quant)
        return z

def quantile_huber_loss(pred_tau, target_tau, taus, kappa=1.0):
    u = target_tau.unsqueeze(1) - pred_tau.unsqueeze(2)
    abs_u = torch.abs(u)
    huber = torch.where(abs_u <= kappa, 0.5*u.pow(2), kappa*(abs_u - 0.5*kappa))
    tau = taus.view(1,-1,1)
    weight = torch.abs(tau - (u.detach() < 0).float())
    return (weight*huber).mean()

def cvar_from_quantiles(qvals, alpha):
    """qvals: (B,A,Nq) -> CVaR_α (mean of lowest α-fraction of atoms)."""
    B, A, Nq = qvals.shape
    k = max(1, int(math.ceil(alpha * Nq)))
    return qvals[:, :, :k].mean(dim=2)

# ----- Load CVaR checkpoint -----
ckpt = torch.load(CKPT_PATH, map_location=device)
N_QUANT = ckpt["config"].get("n_quant", 51)
GAMMA   = ckpt["config"].get("gamma", 0.99)
ALPHA   = ckpt["config"].get("alpha", 0.10)   # CVaR level used at train time
KAPPA   = 1.0

qnet_base   = QRDQN(IN_CH, N_ACT, N_QUANT).to(device)
target_base = QRDQN(IN_CH, N_ACT, N_QUANT).to(device)
qnet_base.load_state_dict(ckpt["model"])
target_base.load_state_dict(ckpt["target"])

# DataParallel if multiple GPUs
qnet   = nn.DataParallel(qnet_base)   if n_gpus > 1 else qnet_base
target = nn.DataParallel(target_base) if n_gpus > 1 else target_base
qnet.eval(); target.eval()

taus = torch.linspace(0, 1, N_QUANT + 1, device=device); taus = (taus[:-1] + taus[1:]) / 2.0

# ---------------- Eval loop with progress (CVaR policy + CVaR Double-DQN targets) ----------------
@torch.no_grad()
def run_eval():
    total_loss=0.0; iters=0; seen=0
    q_means=[]; tq_means=[]
    agree_total=0; count_total=0
    ds_hist = Counter(); pi_hist = Counter()

    start = time.time()
    pbar = tqdm(total=expected_batches, desc="Eval (CVaR)", leave=True, dynamic_ncols=True)

    for (o,a,r,no,dn) in eval_loader:
        if (EVAL_MAX_BATCHES is not None) and (iters >= EVAL_MAX_BATCHES):
            break

        o,no = o.to(device, non_blocking=True), no.to(device, non_blocking=True)
        a,r,dn = a.to(device, non_blocking=True), r.to(device, non_blocking=True), dn.to(device, non_blocking=True)

        with torch.amp.autocast('cuda', enabled=(USE_AMP and device.type=="cuda")):
            q = qnet(o)  # (B,A,Nq)

            # Policy for agreement/hist: CVaR on CURRENT states
            pi_scores = cvar_from_quantiles(q, ALPHA)  
            pi_a = pi_scores.argmax(1)                

            # Online quantiles for the logged action
            q_sel = q.gather(1, a.view(-1,1,1).expand(-1,1,N_QUANT)).squeeze(1)  

            # CVaR Double-DQN target:
            q_next_online = qnet(no)                             
            next_a = cvar_from_quantiles(q_next_online, ALPHA).argmax(1) 
            tq_all = target(no)                                  
            tq = tq_all.gather(1, next_a.view(-1,1,1).expand(-1,1,N_QUANT)).squeeze(1)  

            tgt = r.view(-1,1) + GAMMA*(1.0 - dn.view(-1,1))*tq
            loss = quantile_huber_loss(q_sel, tgt, taus, KAPPA)

        bsz = a.numel()
        seen += bsz
        total_loss += float(loss.item()); iters += 1
        q_means.append(float(q_sel.mean().item())); tq_means.append(float(tq.mean().item()))

        # agreement & histograms
        agree_total += int((pi_a == a).sum().item())
        count_total += int(a.numel())
        for aa in a.detach().cpu().numpy(): ds_hist[int(aa)] += 1
        for pa in pi_a.detach().cpu().numpy(): pi_hist[int(pa)] += 1

        # progress
        elapsed = time.time() - start
        pbar.set_postfix_str(f"loss={total_loss/max(1,iters):.4g} | {seen/max(1,elapsed):.1f} samp/s")
        pbar.update(1)

    pbar.close()
    return dict(
        eval_loss = total_loss / max(1,iters),
        q_mean    = float(np.mean(q_means)) if q_means else 0.0,
        tq_mean   = float(np.mean(tq_means)) if tq_means else 0.0,
        num_batches = iters,
        samples_seen = seen,
        policy_dataset_agreement = (agree_total / max(1,count_total)),
        ds_hist = dict(ds_hist), pi_hist = dict(pi_hist),
        eval_source = "held_out_from_training_run",
        device = str(device), gpus = n_gpus, alpha = ALPHA
    )

stats = run_eval()
print("\n[Eval CVaR report]", {k: v for k,v in stats.items() if k not in ("ds_hist","pi_hist")})

out_csv = OUT_DIR / "final_eval_report_cvar.csv"
with open(out_csv, "w", newline="") as f:
    w = csv.writer(f)
    w.writerow(["eval_source","device","gpus","alpha","eval_loss","q_mean","tq_mean",
                "policy_dataset_agreement","num_batches","samples_seen","checkpoint"])
    w.writerow([stats["eval_source"], stats["device"], stats["gpus"], stats["alpha"], stats["eval_loss"],
                stats["q_mean"], stats["tq_mean"], stats["policy_dataset_agreement"],
                stats["num_batches"], stats["samples_seen"], str(CKPT_PATH)])
print("Saved CSV:", out_csv)

# Histograms 
with open(OUT_DIR / "eval_action_histograms_cvar.json", "w") as f:
    json.dump({"dataset": stats["ds_hist"], "policy": stats["pi_hist"], "alpha": stats["alpha"]}, f, indent=2)
print("Saved hist:", OUT_DIR / "eval_action_histograms_cvar.json")

ds_total = sum(stats["ds_hist"].values()) or 1
pi_total = sum(stats["pi_hist"].values()) or 1
actions = sorted(set(list(stats["ds_hist"].keys()) + list(stats["pi_hist"].keys())))
df = pd.DataFrame({
    "action": actions,
    "dataset_freq": [stats["ds_hist"].get(a,0)/ds_total for a in actions],
    "policy_freq":  [stats["pi_hist"].get(a,0)/pi_total for a in actions],
})
ax = df.plot(kind="bar", x="action", y=["dataset_freq","policy_freq"], figsize=(7,4), rot=0,
             title=f"Action frequencies (dataset vs CVaR policy, α={stats['alpha']})")
plt.tight_layout(); plt.savefig(OUT_DIR / "eval_support_cvar.png"); plt.close()
print("Saved plot:", OUT_DIR / "eval_support_cvar.png")


Eval (CVaR):   0%|          | 0/132 [00:00<?, ?it/s]


[Eval CVaR report] {'eval_loss': 0.001983130349827027, 'q_mean': 0.2605646306818182, 'tq_mean': 0.27534346147017047, 'num_batches': 132, 'samples_seen': 16800, 'policy_dataset_agreement': 0.1686904761904762, 'eval_source': 'held_out_from_training_run', 'device': 'cuda', 'gpus': 1, 'alpha': 0.1}
Saved CSV: /kaggle/working/qr_dqn_logs/final_eval_report_cvar.csv
Saved hist: /kaggle/working/qr_dqn_logs/eval_action_histograms_cvar.json
Saved plot: /kaggle/working/qr_dqn_logs/eval_support_cvar.png
