# NextStat + Gymnasium: RL Agent Optimising Analysis Cuts

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nextstat/nextstat.io/blob/main/notebooks/02_gymnasium_rl_agent.ipynb)

This notebook trains a **Reinforcement Learning agent** (PPO via Stable-Baselines3) to optimise a signal histogram's yields, maximising discovery significance.

NextStat's `nextstat.gym` module wraps a HistFactory workspace as a **Gymnasium environment**:
- **Observation**: current signal histogram yields
- **Action**: multiplicative updates to each bin
- **Reward**: change in the discovery test statistic q₀ (with `reward_metric="q0"`; Z₀ = √q₀ is the corresponding significance)

The agent learns to reshape the signal histogram so as to maximise q₀.
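To make the reward concrete, the sketch below computes a simplified stand-in for q₀: the standard Asimov approximation for independent Poisson bins, q₀ = 2 Σᵢ [(sᵢ + bᵢ) ln(1 + sᵢ/bᵢ) − sᵢ], with no nuisance parameters. It therefore ignores the `bkg_norm` systematic that NextStat profiles. `asimov_q0` and `step_reward` are hypothetical helpers for illustration, not part of `nextstat.gym`.

```python
import numpy as np

def asimov_q0(signal, background):
    """Asimov discovery test statistic summed over independent Poisson bins.

    Simplification (assumption): no nuisance parameters, so the bkg_norm
    normsys in the workspace is ignored.
    """
    s = np.asarray(signal, dtype=float)
    b = np.asarray(background, dtype=float)
    return float(2.0 * np.sum((s + b) * np.log1p(s / b) - s))

def step_reward(old_signal, new_signal, background):
    """Hypothetical per-step reward: change in q0 caused by one action."""
    return asimov_q0(new_signal, background) - asimov_q0(old_signal, background)

# Same toy histograms as in Section 1
N_BINS = 8
edges = np.linspace(0.0, 1.0, N_BINS + 1)
centers = 0.5 * (edges[:-1] + edges[1:])
width = edges[1] - edges[0]
signal = 40.0 * np.exp(-0.5 * ((centers - 0.5) / 0.12) ** 2) * width
background = 200.0 * np.exp(-2.0 * centers) * width

q0 = asimov_q0(signal, background)
print(f"q0 = {q0:.3f}, Z0 = sqrt(q0) = {np.sqrt(q0):.3f}")

# Scaling every signal bin up by 2% increases q0, so this reward is positive
print(f"reward = {step_reward(signal, 1.02 * signal, background):+.4f}")
```

Because each step's reward is a difference, the cumulative episode reward telescopes to q₀(final) − q₀(initial) under this reading of the reward, which is what the cumulative-reward curves later in the notebook track.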

```python
# Install dependencies (Colab)
!pip install -q nextstat gymnasium stable-baselines3 numpy matplotlib
```

## 1. Build a Workspace

```python
import json
import numpy as np

N_BINS = 8
edges = np.linspace(0.0, 1.0, N_BINS + 1)
centers = 0.5 * (edges[:-1] + edges[1:])
width = edges[1] - edges[0]

# Gaussian signal peak at 0.5 on a falling exponential background (yields per bin)
signal = 40.0 * np.exp(-0.5 * ((centers - 0.5) / 0.12) ** 2) * width
background = 200.0 * np.exp(-2.0 * centers) * width

workspace = {
    "channels": [{
        "name": "SR",
        "samples": [
            {
                "name": "signal",
                "data": signal.tolist(),
                "modifiers": [
                    {"name": "mu", "type": "normfactor", "data": None}
                ],
            },
            {
                "name": "background",
                "data": background.tolist(),
                "modifiers": [
                    {"name": "bkg_norm", "type": "normsys",
                     "data": {"hi": 1.08, "lo": 0.92}},
                ],
            },
        ],
    }],
    # Observed data set to the nominal S+B expectation (Asimov-like)
    "observations": [{"name": "SR", "data": (signal + background).tolist()}],
    "measurements": [{
        "name": "meas",
        "config": {"poi": "mu", "parameters": []},
    }],
    "version": "1.0.0",
}

ws_json = json.dumps(workspace)
print(f"Signal:     {np.round(signal, 2)}")
print(f"Background: {np.round(background, 2)}")
# Crude single-bin estimate from integrated yields (ignores shape and systematics)
print(f"S/sqrt(B):  {signal.sum() / background.sum()**0.5:.2f}")
```
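Before handing the spec to NextStat, a quick structural check that every sample and observation has one entry per bin can catch typos early. The sketch below is plain Python; `check_bin_counts` is a hypothetical helper, not a NextStat or pyhf function. In the notebook you would call `check_bin_counts(workspace)`; here it runs on a two-bin toy spec with the same layout.

```python
def check_bin_counts(spec):
    """Return {channel: n_bins}; raise ValueError if any sample or observation disagrees."""
    n_bins = {}
    for channel in spec["channels"]:
        # All samples in a channel must share the same binning
        lengths = {len(sample["data"]) for sample in channel["samples"]}
        if len(lengths) != 1:
            raise ValueError(
                f"channel {channel['name']!r}: samples have bin counts {sorted(lengths)}"
            )
        n_bins[channel["name"]] = lengths.pop()
    for obs in spec["observations"]:
        # Each observation must match an existing channel and its bin count
        expected = n_bins.get(obs["name"])
        if expected is None:
            raise ValueError(f"observation {obs['name']!r} has no matching channel")
        if len(obs["data"]) != expected:
            raise ValueError(
                f"observation {obs['name']!r}: {len(obs['data'])} bins, expected {expected}"
            )
    return n_bins

# Minimal toy spec with the same layout as the workspace above
toy = {
    "channels": [{"name": "SR", "samples": [
        {"name": "signal", "data": [1.0, 2.0]},
        {"name": "background", "data": [5.0, 4.0]},
    ]}],
    "observations": [{"name": "SR", "data": [6.0, 6.0]}],
}
print(check_bin_counts(toy))  # {'SR': 2}
```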

## 2. Create the Gymnasium Environment

```python
from nextstat.gym import make_histfactory_env

env = make_histfactory_env(
    ws_json,
    channel="SR",
    sample="signal",
    reward_metric="q0",       # maximise discovery test statistic
    max_steps=64,              # episode length
    action_scale=0.02,         # small perturbations per step
    action_mode="logmul",      # multiplicative updates in log-space
    init_noise=0.0,            # start from nominal
)

print(f"Observation space: {env.observation_space}")
print(f"Action space:      {env.action_space}")
```
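One plausible reading of `action_mode="logmul"` with `action_scale=0.02` is that each action component nudges the log of the corresponding bin yield, i.e. yᵢ ← yᵢ · exp(action_scale · aᵢ). The sketch below implements that reading only; the formula and the clipping to [-1, 1] are assumptions, so check the printed action space and the `nextstat.gym` source for the actual update rule. `logmul_update` is a hypothetical name.

```python
import numpy as np

def logmul_update(yields, action, action_scale=0.02):
    """Hypothetical log-multiplicative update: y_i <- y_i * exp(scale * a_i).

    Assumes actions live in [-1, 1] (as for a Box(-1, 1) action space) and
    clips them to that range.
    """
    a = np.clip(np.asarray(action, dtype=float), -1.0, 1.0)
    return np.asarray(yields, dtype=float) * np.exp(action_scale * a)

y = np.array([1.0, 4.0, 2.5])
y_new = logmul_update(y, [1.0, -1.0, 0.0])
# Per-bin factors are exp(+0.02), exp(-0.02) and 1
print(np.round(y_new / y, 4))
```

Under this reading, multiplicative steps keep yields strictly positive and make the step size relative to each bin, and a 64-step episode can change any bin by at most a factor of exp(±64 × 0.02) = exp(±1.28), roughly ×3.6 up or ×0.28 down.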

## 3. Random Agent Baseline

First, let's see how a random agent performs:

```python
obs, info = env.reset(seed=42)
# action_space.sample() draws from the space's own RNG, not np.random,
# so seed it explicitly for a reproducible baseline
env.action_space.seed(42)
total_reward = 0.0
rewards_random = []

for step in range(64):  # matches max_steps
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    total_reward += float(reward)
    rewards_random.append(total_reward)
    if terminated or truncated:
        break

print(f"Random agent — total reward: {total_reward:.3f}")
print(f"Random agent — final yields:  {np.round(obs, 2)}")
```

## 4. Train PPO Agent (Stable-Baselines3)

```python
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

# Wrap for SB3
def make_env():
    return make_histfactory_env(
        ws_json,
        channel="SR",
        sample="signal",
        reward_metric="q0",
        max_steps=64,
        action_scale=0.02,
        action_mode="logmul",
        init_noise=0.01,  # slight noise for exploration
    )

vec_env = make_vec_env(make_env, n_envs=4)

agent = PPO(
    "MlpPolicy",
    vec_env,
    learning_rate=3e-4,
    n_steps=128,
    batch_size=64,
    n_epochs=10,
    verbose=1,
)

print("Training PPO agent...")
agent.learn(total_timesteps=20_000)
print("Done!")
```
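It helps to know how much data PPO actually sees with these settings. As I understand Stable-Baselines3's on-policy loop, each update collects `n_steps` transitions from each of the `n_envs` environments, then runs `n_epochs` passes of minibatch gradient descent over that buffer, and `learn()` keeps collecting rollouts until it has reached at least `total_timesteps`. The arithmetic below follows that description; the episode count additionally assumes no episode terminates before `max_steps`.

```python
import math

n_envs, n_steps, batch_size, n_epochs = 4, 128, 64, 10
total_timesteps = 20_000
max_steps = 64

rollout_size = n_envs * n_steps                         # transitions per PPO update
minibatches_per_epoch = rollout_size // batch_size      # 512 / 64
grad_steps_per_update = n_epochs * minibatches_per_epoch
n_updates = math.ceil(total_timesteps / rollout_size)   # learn() overshoots to a full rollout

print("rollout size:", rollout_size)
print("gradient steps per update:", grad_steps_per_update)
print("PPO updates:", n_updates)
print("timesteps actually collected:", n_updates * rollout_size)
print("episodes (if all run to max_steps):", n_updates * rollout_size // max_steps)
```

Because 512 is a multiple of `batch_size=64`, every minibatch is full; SB3's PPO warns when the rollout buffer size is not divisible by the batch size.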

## 5. Evaluate Trained Agent

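```python
# Evaluate on the same noise-free env (init_noise=0.0) and reset seed as the
# random baseline, so both agents start from identical nominal yields.
eval_env = env
obs, info = eval_env.reset(seed=42)
total_reward_ppo = 0.0
rewards_ppo = []
yields_history = [obs.copy()]

for step in range(64):
    action, _ = agent.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = eval_env.step(action)
    total_reward_ppo += float(reward)
    rewards_ppo.append(total_reward_ppo)
    yields_history.append(obs.copy())
    if terminated or truncated:
        break

print(f"PPO agent — total reward: {total_reward_ppo:.3f}")
print(f"Random agent — total reward: {total_reward:.3f}")
delta = total_reward_ppo - total_reward
print(f"Difference (PPO - random): {delta:+.3f}")
# A percentage is only meaningful when the baseline reward is not ~0
if abs(total_reward) > 1e-9:
    print(f"Relative improvement: {delta / abs(total_reward) * 100:.1f}%")
```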
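A single evaluation episode per agent is a noisy comparison. The sketch below averages undiscounted episode returns over several reset seeds; `mean_episode_return` is a hypothetical helper that works with any object exposing the Gymnasium `reset(seed=...)` / `step(action)` interface, and it is demonstrated on a tiny stand-in environment so the block runs without NextStat. Stable-Baselines3 also provides `stable_baselines3.common.evaluation.evaluate_policy` for the trained-agent side.

```python
import numpy as np

def mean_episode_return(env, policy, seeds):
    """Mean and std of undiscounted episode returns over the given reset seeds."""
    returns = []
    for seed in seeds:
        obs, info = env.reset(seed=seed)
        total, done = 0.0, False
        while not done:
            obs, reward, terminated, truncated, info = env.step(policy(obs))
            total += float(reward)
            done = terminated or truncated
        returns.append(total)
    return float(np.mean(returns)), float(np.std(returns))

class ToyEnv:
    """Stand-in env for illustration: reward = action; truncates after 4 steps."""
    def reset(self, seed=None):
        self.t = 0
        return np.zeros(1), {}

    def step(self, action):
        self.t += 1
        return np.zeros(1), float(action), False, self.t >= 4, {}

mean, std = mean_episode_return(ToyEnv(), policy=lambda obs: 0.5, seeds=range(5))
print(mean, std)  # 2.0 0.0
```

In the notebook, the trained policy would be `lambda obs: agent.predict(obs, deterministic=True)[0]` and the baseline `lambda obs: env.action_space.sample()` (after seeding the action space), both evaluated on `env` with the same `seeds`.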

## 6. Visualise Results

```python
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Cumulative reward
axes[0].plot(rewards_random, label="Random", color="gray", alpha=0.7)
axes[0].plot(rewards_ppo, label="PPO", color="#D4AF37", linewidth=2)
axes[0].set_xlabel("Step")
axes[0].set_ylabel("Cumulative Reward")
axes[0].set_title("Reward: Random vs PPO")
axes[0].legend()
axes[0].grid(alpha=0.2)

# Yield evolution
yields_arr = np.array(yields_history)
for i in range(N_BINS):
    axes[1].plot(yields_arr[:, i], alpha=0.6, linewidth=0.8)
axes[1].set_xlabel("Step")
axes[1].set_ylabel("Yield")
axes[1].set_title("Bin Yields Over Time")
axes[1].grid(alpha=0.2)

# Final vs initial histogram
axes[2].bar(centers - width * 0.2, yields_history[0], width=width * 0.35,
            color="gray", alpha=0.6, label="Initial")
axes[2].bar(centers + width * 0.2, yields_history[-1], width=width * 0.35,
            color="#D4AF37", alpha=0.8, label="Optimised (PPO)")
axes[2].set_xlabel("Bin Center")
axes[2].set_ylabel("Signal Yield")
axes[2].set_title("Initial vs Optimised Signal")
axes[2].legend()
axes[2].grid(alpha=0.2)

plt.tight_layout()
plt.show()
```
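The bar chart can be complemented by a few numbers: how much the total signal yield changed, the per-bin scale factors, and whether the yield-weighted mean position moved. `reshaping_summary` is a hypothetical helper; in the notebook you would pass `yields_history[0]`, `yields_history[-1]` and `centers`, assuming the observation holds raw signal yields as described in the introduction.

```python
import numpy as np

def reshaping_summary(initial, final, centers):
    """Summarise how a signal histogram changed between two snapshots."""
    initial = np.asarray(initial, dtype=float)
    final = np.asarray(final, dtype=float)
    centers = np.asarray(centers, dtype=float)
    return {
        # > 1 means the overall signal normalisation went up
        "total_ratio": final.sum() / initial.sum(),
        # Multiplicative factor applied to each bin over the episode
        "per_bin_ratio": final / initial,
        # Shift of the yield-weighted mean bin centre (a pure shape change)
        "mean_shift": np.average(centers, weights=final)
        - np.average(centers, weights=initial),
    }

summary = reshaping_summary([1.0, 1.0], [1.0, 3.0], centers=[0.0, 1.0])
print(summary["total_ratio"], summary["per_bin_ratio"], summary["mean_shift"])
```

This distinction matters because, in the Asimov form of q₀, every bin's term grows with its signal yield. A large `total_ratio` therefore suggests the agent is mostly inflating the normalisation rather than reshaping, unless the environment constrains the total yield (which this notebook does not state).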

## 7. Systematic Impact Before/After

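Rank the nuisance parameters by their impact on the original model. Comparing against the RL-optimised yields requires rebuilding the workspace with the agent's final signal yields and re-running the same ranking.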

```python
import nextstat
from nextstat.interpret import rank_impact

model = nextstat.from_pyhf(workspace)

print("Ranking (original model):")
for r in rank_impact(model, top_n=5):
    print(f"  {r['rank']}. {r['name']:20s} impact={r['total_impact']:.4f}")

# Note: to see the impact with optimised yields, you'd rebuild
# the workspace with the agent's final yields and re-run ranking.
```
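The rebuild mentioned in the note can be sketched as a pure dictionary edit. `with_signal_yields` is a hypothetical helper; recomputing the observed data as the sum of all samples is an assumption that mirrors how the toy workspace in Section 1 built its (Asimov-like) observations. It is shown on a two-bin toy spec so it runs on its own.

```python
import copy
import numpy as np

def with_signal_yields(spec, channel, sample, new_yields, update_observations=True):
    """Return a deep copy of a pyhf-style spec with one sample's yields replaced.

    Assumption: if update_observations is True, the channel's observed data is
    reset to the sum of all samples, as the Section 1 workspace was built.
    """
    new_spec = copy.deepcopy(spec)  # never mutate the original workspace
    new_yields = [float(y) for y in np.asarray(new_yields, dtype=float).ravel()]
    ch = next((c for c in new_spec["channels"] if c["name"] == channel), None)
    if ch is None:
        raise KeyError(f"no channel {channel!r}")
    smp = next((s for s in ch["samples"] if s["name"] == sample), None)
    if smp is None:
        raise KeyError(f"no sample {sample!r} in channel {channel!r}")
    if len(new_yields) != len(smp["data"]):
        raise ValueError(f"expected {len(smp['data'])} bins, got {len(new_yields)}")
    smp["data"] = new_yields
    if update_observations:
        total = np.sum([s["data"] for s in ch["samples"]], axis=0)
        for obs in new_spec["observations"]:
            if obs["name"] == channel:
                obs["data"] = total.tolist()
    return new_spec

toy_ws = {
    "channels": [{"name": "SR", "samples": [
        {"name": "signal", "data": [1.0, 1.0]},
        {"name": "background", "data": [10.0, 5.0]},
    ]}],
    "observations": [{"name": "SR", "data": [11.0, 6.0]}],
}
new = with_signal_yields(toy_ws, "SR", "signal", [2.0, 3.0])
print(new["observations"][0]["data"])  # [12.0, 8.0]
```

In the notebook, `with_signal_yields(workspace, "SR", "signal", yields_history[-1])` would produce a spec that can go through `nextstat.from_pyhf(...)` and `rank_impact(..., top_n=5)` exactly as in the cell above.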

---

## Summary

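We trained a PPO agent to reshape signal histogram yields, maximising the discovery test statistic q₀ (Z₀ = √q₀). The cumulative-reward comparison above shows how the trained policy performs against random actions on one evaluation episode; average over several seeds before concluding that it consistently outperforms random exploration.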

### Key APIs used:
- `nextstat.gym.make_histfactory_env()` — Gymnasium environment wrapper
- `reward_metric="q0"` — reward = change in test statistic
- `action_mode="logmul"` — multiplicative actions in log-space
- `rank_impact()` — systematic impact analysis

### Next steps:
- Try `reward_metric="z0"` for significance-based reward
- Try `reward_metric="qmu"` with custom `mu_test` for exclusion limits
- Combine with the [PyTorch notebook](./01_pytorch_significance_loss.ipynb) for end-to-end NN + RL
- Scale up with [Optuna](https://nextstat.io/docs/optuna) for hyperparameter search over RL config