# VALOR Demo

This notebook walks through training VALOR on simple MuJoCo environments and, at the end, renders the learned contexts. Install the requirements listed in requirements.txt, then run all cells in order.

### Imports & Set-up

In [None]:
import os
import numpy as np

import torch
import torch.nn.functional as F
from torch import optim

import matplotlib.pyplot as plt

import gymnasium as gym

from envs import make_vec_env, ContextConcatWrapper
from models import ActorCritic, TrajectoryDecoder
from ppo import PPOConfig, compute_gae, ppo_update

def strip_ctx(x, K):
    """Return ``x`` without its trailing ``K`` context features (last axis).

    Works for NumPy arrays and torch tensors alike. ``K == 0`` returns all
    features (the naive ``x[..., :-0]`` would return an empty slice), and
    ``K`` larger than the feature dimension yields an empty last axis.
    """
    keep = max(x.shape[-1] - K, 0)
    return x[..., :keep]

# Use the GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

### Hyperparameters

In [None]:
# ---- training hyperparams ----
ENV_ID = "HalfCheetah-v5"   # other options: 'Hopper-v5', 'Ant-v5'
NUM_ENVS = 16               # parallel environments per rollout
SEED = 1                    # seed for torch / numpy / env resets
K = 4                       # number of discrete contexts
EP_LEN = 256                # rollout length (steps) per training iteration
TOTAL_ITERS = 300           # number of training iterations
DECODER_STEPS = 20          # decoder gradient steps per iteration
ALPHA_INT = 3.0             # weight of the intrinsic (decoder) reward
ENTROPY_COEF = 0.001        # PPO entropy coefficient

# ---- plotting options ----
# Two state dims (context features excluded) used for the 2D trajectory plots.
PLOT_STATE_IDXS = (0, 1)

### Training loop 

In [None]:
def run_training(
    env_id=ENV_ID,
    num_envs=NUM_ENVS,
    seed=SEED,
    K=K,
    ep_len=EP_LEN,
    total_iters=TOTAL_ITERS,
    decoder_steps=DECODER_STEPS,
    alpha_int=ALPHA_INT,
    entropy_coef=ENTROPY_COEF,
    device=device
):
    """Train a context-conditioned policy and a trajectory decoder (VALOR-style).

    Every iteration:
      1. collects ``ep_len`` steps from ``num_envs`` vectorized envs whose
         observations end with ``K`` context features;
      2. takes ``decoder_steps`` full-batch Adam steps training the decoder to
         classify each env's context id from its context-free trajectory;
      3. turns the decoder's log-probability of the true context into a
         standardized dense intrinsic reward;
      4. runs one PPO update on ``extrinsic + alpha_int * intrinsic`` reward;
      5. resets the envs and re-reads their context ids for the next batch.

    Args:
        env_id: environment id forwarded to ``make_vec_env``.
        num_envs: number of parallel environments (N).
        seed: seed for torch, numpy and the first env reset.
        K: number of contexts (= trailing context features per observation).
        ep_len: rollout length per iteration (T).
        total_iters: number of training iterations.
        decoder_steps: decoder gradient steps per iteration.
        alpha_int: weight of the intrinsic reward in the PPO reward.
        entropy_coef: PPO entropy coefficient (copied into ``PPOConfig``).
        device: torch device for the networks and training tensors.

    Returns:
        ``(net, dec, history)``. ``history`` holds per-iteration arrays
        ``iters``, ``dec_loss`` and ``acc`` (from the last decoder step),
        ``avg_int`` (mean decoder log-probability of the true context, before
        standardization), plus ``last_logits`` [N, K] and ``last_ctx_ids`` [N]
        from the final iteration (``None`` when ``total_iters`` is 0) and
        ``obs_dim_wo``.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    # env setup -- NOTE(review): assumes make_vec_env returns a gymnasium-style
    # VectorEnv (it exposes single_action_space / reset / step / close).
    vec = make_vec_env(env_id, num_envs, seed, K=K, ep_len=ep_len)
    try:
        obs_full, infos = vec.reset(seed=seed)
        obs_dim_full = obs_full.shape[-1]
        obs_dim_wo = obs_dim_full - K
        act_dim = vec.single_action_space.shape[0]

        # network setup
        net = ActorCritic(obs_dim_full, act_dim).to(device)
        dec = TrajectoryDecoder(obs_dim_wo_ctx=obs_dim_wo, K=K).to(device)

        opt = optim.Adam(net.parameters(), lr=3e-4)
        dec_opt = optim.Adam(dec.parameters(), lr=1e-3)

        cfg = PPOConfig(obs_dim=obs_dim_full, act_dim=act_dim, device=device)
        cfg.entropy_coef = entropy_coef

        # rollout buffers, time-major: [T, N, ...]
        T, N = ep_len, num_envs
        obs_buf = np.zeros((T, N, obs_dim_full), np.float32)
        act_buf = np.zeros((T, N, act_dim), np.float32)
        logp_buf = np.zeros((T, N), np.float32)
        rew_buf = np.zeros((T, N), np.float32)
        val_buf = np.zeros((T, N), np.float32)
        done_buf = np.zeros((T, N), np.float32)

        # One context id per env for the whole batch, taken from the reset infos.
        # NOTE(review): if an env terminates mid-rollout and auto-resets, its
        # context may change while this label stays fixed -- confirm in envs.py.
        ctx_ids_for_batch = np.asarray(infos["context_id"], dtype=np.int64)

        # histories for logging
        hist_iters, hist_dec_loss, hist_acc, hist_avg_int = [], [], [], []
        last_logits = None
        last_ctx_ids = None

        for it in range(1, total_iters + 1):

            # ---- collect rollout ----
            for t in range(T):
                with torch.no_grad():
                    # Networks are built in the default float32; env observations
                    # may be float64, so cast explicitly.
                    o = torch.as_tensor(obs_full, dtype=torch.float32, device=device)
                    a, logp, _ = net.sample_action(o)
                    v = net.value(o)
                a_np = a.cpu().numpy()

                next_obs, r, term, trunc, infos = vec.step(a_np)
                done = np.logical_or(term, trunc)

                obs_buf[t] = obs_full
                act_buf[t] = a_np
                logp_buf[t] = logp.cpu().numpy()
                rew_buf[t] = r
                val_buf[t] = v.cpu().numpy()
                done_buf[t] = done.astype(np.float32)

                obs_full = next_obs

            # ---- decoder batch: env-major [N, T, D] with context stripped ----
            with_ctx = torch.as_tensor(obs_buf.transpose(1, 0, 2))
            wo_ctx = strip_ctx(with_ctx, K).to(device)
            ctx_ids = torch.as_tensor(ctx_ids_for_batch, dtype=torch.long, device=device)

            # ---- train decoder (full-batch) ----
            dec.train()
            last_dec_loss = 0.0
            last_acc = 0.0
            for _ in range(decoder_steps):
                logits = dec(wo_ctx)
                loss = F.cross_entropy(logits, ctx_ids)
                dec_opt.zero_grad(set_to_none=True)
                loss.backward()
                dec_opt.step()
                with torch.no_grad():
                    # accuracy of the logits that produced this step's loss
                    last_acc = (logits.argmax(-1) == ctx_ids).float().mean().item()
                    last_dec_loss = loss.item()

            # ---- intrinsic reward ----
            with torch.no_grad():
                dec.eval()
                logits = dec(wo_ctx)
                # log q(true context | trajectory) for each env: [N]
                logp_c = F.log_softmax(logits, dim=-1).gather(1, ctx_ids.unsqueeze(1)).squeeze(1)
                # Spread each env's score evenly over its T steps, then standardize
                # across the whole [N, T] batch.
                dense = (logp_c / T).unsqueeze(1).expand(N, T)
                dense = (dense - dense.mean()) / (dense.std() + 1e-8)
                r_int = dense.transpose(1, 0).contiguous().to(dtype=torch.float32)  # [T, N]
                # Logged metric: raw (unstandardized) mean log-probability. The
                # standardized reward always averages to ~0, and the total reward
                # mean mostly reflects the extrinsic reward.
                avg_int_r = logp_c.mean().item()

            # ---- PPO update ----
            rewards_total = torch.as_tensor(rew_buf, device=device) + alpha_int * r_int.to(device)
            # NOTE(review): no bootstrap value for the state after step T is passed
            # to compute_gae; confirm how it treats the truncated rollout tail.
            adv, ret = compute_gae(
                rewards_total.to(device),
                torch.as_tensor(val_buf).to(device),
                torch.as_tensor(done_buf).to(device),
                cfg.gamma, cfg.gae_lambda
            )

            obs_flat = torch.as_tensor(obs_buf).to(device).reshape(T * N, -1)
            act_flat = torch.as_tensor(act_buf).to(device).reshape(T * N, -1)
            logp_flat = torch.as_tensor(logp_buf).to(device).reshape(T * N)
            adv_flat = adv.reshape(T * N)
            ret_flat = ret.reshape(T * N)

            ppo_update(cfg, net, (obs_flat, act_flat, logp_flat, adv_flat, ret_flat), opt)

            # log training metrics
            hist_iters.append(it)
            hist_dec_loss.append(last_dec_loss)
            hist_acc.append(last_acc)
            hist_avg_int.append(avg_int_r)

            # keep the eval-mode logits of this batch for the confusion matrix
            last_logits = logits.detach().cpu().numpy()
            last_ctx_ids = ctx_ids.detach().cpu().numpy()

            # reset for next batch
            obs_full, infos = vec.reset()
            ctx_ids_for_batch = np.asarray(infos["context_id"], dtype=np.int64)

            print(f"[iter {it}/{total_iters}] dec_loss={last_dec_loss:.4f} | acc={last_acc:.3f} | avg_int_r={avg_int_r:.5f}")
    finally:
        # Release the vectorized env (worker processes / simulators, if any).
        vec.close()

    history = {
        "iters": np.array(hist_iters),
        "dec_loss": np.array(hist_dec_loss),
        "acc": np.array(hist_acc),
        "avg_int": np.array(hist_avg_int),
        "last_logits": last_logits,
        "last_ctx_ids": last_ctx_ids,
        "obs_dim_wo": obs_dim_wo,
    }
    return net, dec, history


In [None]:
# Train using run_training's defaults (bound when the function was defined).
net, dec, history = run_training()
print("Finished")

In [None]:
# SAVE MODEL WEIGHTS
# Only the state_dicts are saved (recommended for portability); reload them
# into freshly constructed ActorCritic / TrajectoryDecoder instances.
os.makedirs("weights", exist_ok=True)

for module, path in ((net, "weights/policy.pt"), (dec, "weights/decoder.pt")):
    torch.save(module.state_dict(), path)

### Decoder Loss and Accuracy graphs
The loss should trend downward and the accuracy upward.

In [None]:
# One figure per decoder metric: loss first, then accuracy.
for key, ylabel, title in (
    ("dec_loss", "CE Loss", "Decoder Loss"),
    ("acc", "Accuracy", "Decoder Accuracy"),
):
    plt.figure()
    plt.plot(history["iters"], history[key])
    plt.xlabel("Iteration")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.show()

### Final Decoder confusion matrix
Rows = true context, Cols = predicted

A well-trained decoder should produce a (near-)diagonal matrix.

In [None]:
# Decoder predictions vs. true contexts for the final training batch.
logits = history["last_logits"]    # [N, K]
ctx_ids = history["last_ctx_ids"]  # [N]
K_current = logits.shape[1]
pred = logits.argmax(axis=1)

# cm[true, predicted] = count; np.add.at accumulates repeated index pairs.
cm = np.zeros((K_current, K_current), dtype=int)
np.add.at(cm, (ctx_ids, pred), 1)

plt.figure()
plt.imshow(cm, aspect="auto")
plt.colorbar()
plt.xlabel("Predicted context")
plt.ylabel("True context")
plt.title("Decoder Confusion Matrix (last batch)")
plt.show()

cm

## 7) Visualize trajectories for different contexts

We roll out the trained policy in single-env mode with a **fixed** context and plot the trajectory projected onto two chosen state dimensions (`PLOT_STATE_IDXS`; the first two by default).

Most contexts should form visually distinct clusters. If no clustering is apparent, the contexts are probably separating along other state dimensions; try different `PLOT_STATE_IDXS`.

In [None]:
def rollout_under_context(net, env_id, ctx_id, K, steps=EP_LEN):
    """Roll out ``net`` in one env whose context is pinned to ``ctx_id``.

    The env is reset with the global ``SEED`` and stepped until ``steps``
    actions have been taken or the episode terminates/truncates. The env is
    always closed, even if the rollout raises.

    Args:
        net: policy exposing ``sample_action(obs_batch)``.
        env_id: gymnasium environment id.
        ctx_id: context id passed to ``ContextConcatWrapper``.
        K: number of context features appended to each observation.
        steps: maximum number of environment steps.

    Returns:
        Array of shape [T, D_wo]: the observations *after* each step with the
        trailing ``K`` context features removed (T may be 0).
    """
    env = ContextConcatWrapper(gym.make(env_id), K=K, context_id=ctx_id)  # fixed context
    obs_wo = []
    try:
        obs, info = env.reset(seed=SEED)
        d_wo = max(obs.shape[-1] - K, 0)  # slice bound; avoids obs[:-0] == [] when K == 0
        for t in range(steps):
            with torch.no_grad():
                # networks use float32 parameters; env obs may be float64
                o = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
                a, _, _ = net.sample_action(o)
            obs, r, term, trunc, info = env.step(a.squeeze(0).cpu().numpy())
            obs_wo.append(obs[:d_wo].copy())  # strip context here
            if term or trunc:
                break
    finally:
        env.close()

    if not obs_wo:
        # keep a 2-D shape so callers can still index traj[:, i]
        return np.empty((0, d_wo))
    return np.array(obs_wo)  # [T, D_wo]

# State dims (context excluded) used for the 2D projection.
i, j = PLOT_STATE_IDXS

# One scatter plot per context, each from a fresh fixed-context rollout.
for ctx in range(K):
    traj = rollout_under_context(net, ENV_ID, ctx, K, steps=EP_LEN)
    xs, ys = traj[:, i], traj[:, j]
    plt.figure()
    plt.scatter(xs, ys, s=4)
    plt.xlabel(f"state[{i}]")
    plt.ylabel(f"state[{j}]")
    plt.title(f"Trajectory projection for context {ctx}")
    plt.show()

## 8) Render Contexts
Short rollouts of the trained policy are rendered for each fixed context.

In [None]:
import imageio
from PIL import Image, ImageDraw
from IPython.display import Image as IPyImage, display

# RGB border colour per context; contexts beyond the list length wrap around
# (indexed as ctx_id % len(CTX_COLORS)).
#   0 red-ish, 1 blue-ish, 2 green-ish, 3 orange,
#   4 purple,  5 emerald,  6 sunflower, 7 teal
CTX_COLORS = [
    (235, 64, 52), (52, 140, 235), (52, 199, 89), (255, 159, 10),
    (155, 89, 182), (46, 204, 113), (241, 196, 15), (26, 188, 156),
]

def make_rgb_env(env_id, K, ctx_id, seed=1):
    """Create a single ``rgb_array``-rendering env with its context pinned.

    The env is wrapped in ``ContextConcatWrapper`` with ``context_id=ctx_id``
    and reset with ``seed``.

    Returns:
        ``(env, first_obs)``: the wrapped env and its first observation.
    """
    wrapped = ContextConcatWrapper(
        gym.make(env_id, render_mode="rgb_array"),
        K=K,
        context_id=ctx_id,
    )
    first_obs, _info = wrapped.reset(seed=seed)
    return wrapped, first_obs

def _frame_with_border_and_text(frame, ctx_id, step, K, border_px=6):
    """Frame ``frame`` in a context-coloured border and stamp a text label.

    Args:
        frame: H x W x 3 image from ``env.render()`` in ``rgb_array`` mode
            (assumed uint8 -- TODO confirm for other envs).
        ctx_id: context id; selects the border colour from ``CTX_COLORS``
            (modulo its length) and appears in the label.
        step: step index shown in the label.
        K: number of contexts, shown in the label.
        border_px: border thickness in pixels on every side.

    Returns:
        uint8 array of shape (H + 2*border_px, W + 2*border_px, 3).
    """
    h, w, _ = frame.shape
    color = CTX_COLORS[ctx_id % len(CTX_COLORS)]

    # Fill the whole canvas with the context colour, then paste the frame inside.
    canvas = np.empty((h + 2 * border_px, w + 2 * border_px, 3), dtype=np.uint8)
    canvas[:] = color
    canvas[border_px:border_px + h, border_px:border_px + w] = frame

    pil_img = Image.fromarray(canvas)
    # "RGBA" draw mode on an RGB image blends RGBA fills into it. With the
    # default mode the alpha of the label background below is ignored and the
    # box is drawn fully opaque.
    draw = ImageDraw.Draw(pil_img, "RGBA")

    txt = f"ctx={ctx_id} / K={K} | step={step}"
    try:
        tw = draw.textlength(txt)
    except (AttributeError, ValueError):
        # Older Pillow may lack textlength or be unable to measure the default
        # font; fall back to a rough ~8 px per character estimate.
        tw = 8 * len(txt)

    th = 12  # approximate text height in pixels
    # semi-transparent black box behind the label
    draw.rectangle([8, 8, 8 + int(tw) + 8, 8 + th + 8], fill=(0, 0, 0, 160))
    draw.text((12, 12), txt, fill=(255, 255, 255, 255))
    return np.array(pil_img)

def rollout_gif(net, env_id, K, ctx_id, steps=300, fps=30, seed=1, out_dir="videos"):
    """Record a fixed-context rollout of ``net`` and save it as a GIF.

    Steps the env for up to ``steps`` actions, stopping early on
    termination/truncation or when ``env.render()`` returns ``None``. Each
    frame gets a context-coloured border and a label. The env is always
    closed, even if the rollout raises.

    Args:
        net: policy exposing ``sample_action(obs_batch)``.
        env_id: gymnasium environment id.
        K: number of contexts.
        ctx_id: context to pin for this rollout.
        steps: maximum number of environment steps (= frames).
        fps: GIF frame rate passed to ``imageio.mimsave``.
        seed: reset seed.
        out_dir: directory for the GIF (created if missing).

    Returns:
        Path of the written ``<out_dir>/<env_id>_ctx<ctx_id>.gif``, or ``None``
        if no frame was rendered.
    """
    os.makedirs(out_dir, exist_ok=True)
    env, obs = make_rgb_env(env_id, K, ctx_id, seed=seed)
    frames = []

    try:
        for t in range(steps):
            with torch.no_grad():
                # networks use float32 parameters; env obs may be float64
                o = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
                a, _, _ = net.sample_action(o)

            obs, r, term, trunc, info = env.step(a.squeeze(0).cpu().numpy())

            frame = env.render()
            if frame is None:
                break

            frames.append(_frame_with_border_and_text(frame, ctx_id, t, K))

            if term or trunc:
                break
    finally:
        env.close()

    if not frames:
        return None

    path = os.path.join(out_dir, f"{env_id}_ctx{ctx_id}.gif")
    # NOTE(review): recent imageio versions route GIFs through the Pillow plugin,
    # which may reject/deprecate `fps` in favour of `duration` -- confirm against
    # the pinned imageio version.
    imageio.mimsave(path, frames, fps=fps, loop=0)
    return path

def show_context_videos(net, env_id, K, steps=300, fps=30, seed=1):
    """Write one GIF per context in ``range(K)``, then display them in order.

    All GIFs are recorded before any is displayed. A context whose rollout
    produced no frames gets a printed notice instead of an image.
    """
    gif_paths = [
        rollout_gif(net, env_id, K, ctx, steps=steps, fps=fps, seed=seed)
        for ctx in range(K)
    ]

    for gif in gif_paths:
        if gif is None:
            print("No frames for one of the contexts (render returned None).")
            continue
        display(IPyImage(filename=gif))

A GIF for each context is saved to `videos/` and displayed below.

In [None]:
show_context_videos(net, ENV_ID, K, steps=128, fps=10, seed=SEED)  # This might take a couple of minutes