# NextStat + PyTorch: Training a Classifier with Significance Loss

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nextstat/nextstat.io/blob/main/notebooks/01_pytorch_significance_loss.ipynb)

This notebook demonstrates **end-to-end differentiable training** of a neural network classifier where the loss function is the **discovery significance Z₀** — computed by NextStat's profiled likelihood engine.

Instead of training with cross-entropy and *then* running a statistical test, we **differentiate through the statistical test itself**.
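For reference, the usual asymptotic definition of the discovery test statistic is given below. $\hat{\hat{\theta}}$ denotes the nuisance parameters profiled at $\mu = 0$. We assume NextStat's Z₀ follows this convention. The second line is the Asimov significance for a binned model with no systematics, where $s_k$ and $b_k$ are the expected signal and background in bin $k$. (The workspace below sets its observed data to signal + background, which is the Asimov dataset for $\mu = 1$.) Its derivative shows how the gradient rewards moving signal into bins with a high $s_k/b_k$:

```latex
q_0 = -2 \ln \frac{L(\mu = 0, \hat{\hat{\theta}})}{L(\hat{\mu}, \hat{\theta})} \quad (\hat{\mu} \ge 0;\ q_0 = 0 \text{ otherwise}), \qquad Z_0 = \sqrt{q_0}

Z_0^{\mathrm{A}} = \sqrt{2 \sum_k \left[ (s_k + b_k) \ln\!\left(1 + \frac{s_k}{b_k}\right) - s_k \right]}, \qquad
\frac{\partial Z_0^{\mathrm{A}}}{\partial s_k} = \frac{\ln\left(1 + s_k / b_k\right)}{Z_0^{\mathrm{A}}}
```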

## Pipeline
```
NN(features) → scores → SoftHistogram → SignificanceLoss(-Z₀) → backward → optimizer
```
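The gradient path can be traced without a GPU build of NextStat by mimicking the pipeline in plain PyTorch. This is a minimal sketch that makes two stated substitutions. First, the soft histogram uses differences of sigmoids (a logistic kernel, not necessarily what `SoftHistogram` does). Second, `SignificanceLoss` is replaced by the closed-form Asimov Z₀ with no systematics, so nothing is profiled. The helpers `soft_histogram` and `asimov_z0` are hypothetical.

```python
import torch

def soft_histogram(x, w, edges, bandwidth):
    """Hypothetical stand-in for SoftHistogram: logistic-kernel soft binning."""
    # Logistic-kernel mass of each event inside [edge_k, edge_{k+1}): shape (n_events, n_bins)
    upper = torch.sigmoid((edges[1:] - x[:, None]) / bandwidth)
    lower = torch.sigmoid((edges[:-1] - x[:, None]) / bandwidth)
    return ((upper - lower) * w[:, None]).sum(dim=0)

def asimov_z0(s, b):
    """Hypothetical stand-in for SignificanceLoss: Asimov Z0, no nuisance parameters."""
    q0 = 2.0 * ((s + b) * torch.log1p(s / b) - s)
    return torch.sqrt(q0.sum())

torch.manual_seed(0)
features = torch.randn(200, 5)                       # toy signal events
net = torch.nn.Sequential(torch.nn.Linear(5, 1), torch.nn.Sigmoid())
bin_edges = torch.linspace(0.0, 1.0, 11)
bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
bkg = 150.0 * torch.exp(-1.5 * bin_centers) * 0.1   # fixed falling background

scores = net(features).squeeze(-1)                   # NN(features) → scores
hist = soft_histogram(scores, torch.ones_like(scores), bin_edges, bandwidth=0.05)
sig = hist * (7.5 / hist.sum())                      # rescale to a toy signal yield of 7.5
loss = -asimov_z0(sig, bkg)                          # -Z0, to be minimised
loss.backward()                                      # gradients reach the NN weights
print(loss.item(), net[0].weight.grad)
```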

### Requirements
- `nextstat` (Rust-powered statistical inference, built with `--features cuda` or `--features metal`)
- `torch` (PyTorch)
- GPU required: CUDA (NVIDIA) or Metal (Apple Silicon)

In [None]:
# Install dependencies (Colab)
!pip install -q nextstat torch numpy matplotlib

## 1. Build a HistFactory Workspace

We create a simple signal + background model with one systematic uncertainty.
In a real analysis this would come from your ntuple processing pipeline.
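The systematic used below (`bkg_norm`) is a `normsys` modifier. It scales the whole background by a factor controlled by a single nuisance parameter α with a unit Gaussian constraint, and the factor reaches the `hi`/`lo` values at α = ±1. The following is a minimal sketch that assumes exponential interpolation, the scheme HistFactory calls interpolation code 1. Whether NextStat uses the same scheme is an assumption, and `normsys_factor` is a hypothetical helper.

```python
import numpy as np

def normsys_factor(alpha, hi=1.10, lo=0.90):
    """Hypothetical helper: exponential (code-1 style) normsys interpolation."""
    alpha = np.asarray(alpha, dtype=float)
    # hi**alpha above nominal, lo**(-alpha) below; equals 1 at alpha = 0
    return np.where(alpha >= 0, hi ** alpha, lo ** (-alpha))

alphas = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
print(dict(zip(alphas.tolist(), normsys_factor(alphas).round(4).tolist())))
```

Because the interpolation is exponential, the factor stays positive for any α, so the background yield can never become negative.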

In [None]:
import json
import numpy as np

N_BINS = 10
edges = np.linspace(0.0, 1.0, N_BINS + 1)
centers = 0.5 * (edges[:-1] + edges[1:])
width = edges[1] - edges[0]

# Signal: Gaussian peak at 0.5
signal = 30.0 * np.exp(-0.5 * ((centers - 0.5) / 0.10) ** 2) * width
# Background: falling exponential
background = 150.0 * np.exp(-1.5 * centers) * width

workspace = {
    "channels": [{
        "name": "SR",
        "samples": [
            {
                "name": "signal",
                "data": signal.tolist(),
                "modifiers": [
                    {"name": "mu", "type": "normfactor", "data": None}
                ],
            },
            {
                "name": "background",
                "data": background.tolist(),
                "modifiers": [
                    {"name": "bkg_norm", "type": "normsys",
                     "data": {"hi": 1.10, "lo": 0.90}},
                ],
            },
        ],
    }],
    "observations": [{"name": "SR", "data": (signal + background).tolist()}],
    "measurements": [{
        "name": "meas",
        "config": {"poi": "mu", "parameters": []},
    }],
    "version": "1.0.0",
}

print(f"Signal bins:     {np.round(signal, 2)}")
print(f"Background bins: {np.round(background, 2)}")
print(f"S/B ratio:       {signal.sum() / background.sum():.2%}")
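
Shape mistakes are easiest to catch before the dictionary is handed to NextStat. The helper below is hypothetical and not part of NextStat. It checks only three things: that bin counts agree with the observations, that yields are non-negative, and that the POI is declared as a modifier. If pyhf is installed, `pyhf.Workspace(workspace)` performs full JSON-schema validation instead. In the notebook, you would call `check_workspace(workspace)`; the sketch below uses a small standalone toy.

```python
def check_workspace(ws):
    """Hypothetical helper: minimal consistency checks on a pyhf-style workspace dict."""
    observed = {obs["name"]: obs["data"] for obs in ws["observations"]}
    modifier_names = set()
    for channel in ws["channels"]:
        if channel["name"] not in observed:
            raise ValueError(f"channel {channel['name']!r} has no observations")
        n_bins = len(observed[channel["name"]])
        for sample in channel["samples"]:
            if len(sample["data"]) != n_bins:
                raise ValueError(f"{channel['name']}/{sample['name']}: "
                                 f"{len(sample['data'])} bins, expected {n_bins}")
            if any(y < 0 for y in sample["data"]):
                raise ValueError(f"{channel['name']}/{sample['name']}: negative yield")
            modifier_names.update(mod["name"] for mod in sample["modifiers"])
    # Every measurement's POI must be declared as a modifier somewhere
    for meas in ws["measurements"]:
        if meas["config"]["poi"] not in modifier_names:
            raise ValueError(f"POI {meas['config']['poi']!r} not found among modifiers")
    return True

toy = {
    "channels": [{"name": "SR", "samples": [
        {"name": "signal", "data": [1.0, 2.0],
         "modifiers": [{"name": "mu", "type": "normfactor", "data": None}]},
        {"name": "background", "data": [5.0, 4.0], "modifiers": []},
    ]}],
    "observations": [{"name": "SR", "data": [6.0, 6.0]}],
    "measurements": [{"name": "meas", "config": {"poi": "mu", "parameters": []}}],
    "version": "1.0.0",
}
print(check_workspace(toy))
```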

## 2. Load Model & Create Loss Function

In [None]:
import nextstat
from nextstat.torch import SignificanceLoss, SoftHistogram

model = nextstat.from_pyhf(workspace)

# SignificanceLoss returns -Z₀ by default (for SGD minimisation)
loss_fn = SignificanceLoss(model, "signal", device="auto")

print(f"Expected signal bins: {loss_fn.n_bins}")
print(f"Nuisance parameters:  {loss_fn.n_params}")

## 3. Define a Simple Classifier + SoftHistogram

In [None]:
import torch
import torch.nn as nn

class SimpleClassifier(nn.Module):
    """Toy 2-layer NN that maps features → [0, 1] score."""
    def __init__(self, input_dim=5, hidden=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.net(x).squeeze(-1)

# Differentiable binning
soft_hist = SoftHistogram(
    bin_edges=torch.linspace(0.0, 1.0, N_BINS + 1),
    bandwidth=0.05,
    mode="kde",
)

print(f"SoftHistogram: {soft_hist.n_bins} bins, bandwidth={soft_hist._bw:.3f}")
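
With `mode="kde"`, `SoftHistogram` replaces each score with a smooth kernel. A common way to do this is assumed here rather than taken from NextStat's source. It uses a Gaussian kernel of width `bandwidth`, so that each event contributes the kernel mass $\Phi((e_{k+1}-x)/h) - \Phi((e_k-x)/h)$ to bin $k$. The sketch below uses a hypothetical helper `kde_soft_histogram` to show that every bin remains differentiable with respect to the scores.

```python
import math
import torch

def kde_soft_histogram(x, w, edges, bandwidth):
    """Hypothetical helper: Gaussian-kernel soft histogram."""
    # Normal CDF of each edge relative to each event: shape (n_events, n_edges)
    z = (edges[None, :] - x[:, None]) / (bandwidth * math.sqrt(2.0))
    cdf = 0.5 * (1.0 + torch.erf(z))
    # Kernel mass inside each bin, weighted and summed over events: shape (n_bins,)
    mass = cdf[:, 1:] - cdf[:, :-1]
    return (mass * w[:, None]).sum(dim=0)

kde_edges = torch.linspace(0.0, 1.0, 11)
toy_scores = torch.tensor([0.12, 0.50, 0.51, 0.88], requires_grad=True)
toy_hist = kde_soft_histogram(toy_scores, torch.ones(4), kde_edges, bandwidth=0.05)
toy_hist[5].backward()          # gradient of bin [0.5, 0.6) with respect to the scores
print(toy_hist.detach())
print(toy_scores.grad)
```

Smaller bandwidths approach a hard histogram but give sharper, more localised gradients. Events close to 0 or 1 leak a little weight outside the binning range, so the total slightly undershoots the sum of weights; rescaling to a fixed yield absorbs this.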

## 4. Generate Toy Data

In [None]:
torch.manual_seed(42)
np.random.seed(42)

N_SIGNAL = 500
N_BACKGROUND = 2000
INPUT_DIM = 5

# Signal: clustered around feature-space center
x_sig = torch.randn(N_SIGNAL, INPUT_DIM) * 0.5 + 0.5
# Background: spread wider
x_bkg = torch.randn(N_BACKGROUND, INPUT_DIM) * 1.0

# Combine
x_all = torch.cat([x_sig, x_bkg], dim=0)
w_all = torch.ones(N_SIGNAL + N_BACKGROUND)  # uniform weights

print(f"Signal events:     {N_SIGNAL}")
print(f"Background events: {N_BACKGROUND}")
print(f"Feature dim:       {INPUT_DIM}")
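
For comparison with the conventional route mentioned at the top (train with cross-entropy, then run the statistical test), here is a minimal sketch of a supervised baseline on the same kind of toy data. The data are regenerated with a local random generator so that the notebook's global seeds are not disturbed. The sketch uses the class labels directly, and all names are illustrative rather than part of NextStat.

```python
import torch
import torch.nn as nn

gen = torch.Generator().manual_seed(42)
xs = torch.randn(500, 5, generator=gen) * 0.5 + 0.5     # signal-like events
xb = torch.randn(2000, 5, generator=gen) * 1.0          # background-like events
x_ce = torch.cat([xs, xb])
y_ce = torch.cat([torch.ones(500), torch.zeros(2000)])  # 1 = signal, 0 = background

baseline = nn.Sequential(nn.Linear(5, 32), nn.ReLU(), nn.Linear(32, 1))  # outputs logits
opt_ce = torch.optim.Adam(baseline.parameters(), lr=1e-2)
# pos_weight balances the 500 vs 2000 class sizes in the loss
bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(2000 / 500))

ce_history = []
for step in range(200):
    opt_ce.zero_grad()
    loss_ce = bce(baseline(x_ce).squeeze(-1), y_ce)
    loss_ce.backward()
    opt_ce.step()
    ce_history.append(loss_ce.item())

with torch.no_grad():
    p = torch.sigmoid(baseline(x_ce).squeeze(-1))
print(f"mean score: signal {p[:500].mean():.3f}, background {p[500:].mean():.3f}")
```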

## 5. Training Loop

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Training on: {device}")

classifier = SimpleClassifier(input_dim=INPUT_DIM).to(device)
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)

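x_train = x_all.to(device)
w_train = w_all.to(device)

# The loss replaces the "signal" sample's yields, so only signal events are
# histogrammed; the background template stays fixed in the workspace.
x_sig_train = x_sig.to(device)
w_sig_train = torch.ones(N_SIGNAL, device=device)
signal_yield = float(signal.sum())  # expected signal yield from the workspace

history = []

for epoch in range(100):
    optimizer.zero_grad()

    # Forward: NN → scores → soft histogram → -Z₀
    scores = classifier(x_sig_train)
    histogram = soft_hist(scores, w_sig_train)

    # Rescale to the expected signal yield; keeping the sum in the graph
    # means only the histogram's shape is optimised
    hist_scaled = histogram * (signal_yield / (histogram.sum() + 1e-10))

    loss = loss_fn(hist_scaled.double())

    # Backward
    loss.backward()
    optimizer.step()

    z0 = -loss.item()
    history.append(z0)

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1:3d}: Z₀ = {z0:.3f}σ")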

print(f"\nFinal Z₀: {history[-1]:.3f}σ")
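
The training loop rescales the soft histogram to the workspace signal yield. If the normalising sum stays inside the autograd graph (rather than being converted to a Python float with `.item()`), the rescaled histogram always sums to the target. No gradient can then flow along the total-yield direction, so only the shape is optimised. The minimal sketch below checks this property using a hypothetical helper, `scale_to_yield`.

```python
import torch

def scale_to_yield(hist, total, eps=1e-10):
    """Hypothetical helper: rescale a histogram to `total` inside the autograd graph."""
    return hist * (total / (hist.sum() + eps))

h = torch.tensor([1.0, 3.0, 6.0], requires_grad=True)
scaled = scale_to_yield(h, 7.5)
# The output sum is pinned to 7.5, so its gradient w.r.t. every input bin vanishes
scaled.sum().backward()
print(scaled.detach(), h.grad)
```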

## 6. Visualise Training

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Training curve
axes[0].plot(history, color="#D4AF37", linewidth=1.5)
axes[0].set_xlabel("Epoch")
axes[0].set_ylabel("Z₀ (σ)")
axes[0].set_title("Discovery Significance vs Epoch")
axes[0].grid(alpha=0.2)

# Final learned signal histogram, rescaled to the workspace signal yield
with torch.no_grad():
    scores_final = classifier(x_sig.to(device))
    hist_final = soft_hist(scores_final, torch.ones(N_SIGNAL, device=device)).cpu().numpy()
hist_final = hist_final * signal.sum() / (hist_final.sum() + 1e-10)

axes[1].bar(centers, hist_final, width=width * 0.8, color="#D4AF37", alpha=0.8, label="Learned signal (NN)")
axes[1].step(centers, signal, where="mid", color="red", linewidth=1.5, label="Workspace signal template")
axes[1].set_xlabel("Classifier Score")
axes[1].set_ylabel("Events")
axes[1].set_title("Learned vs Workspace Signal Histogram")
axes[1].legend()
axes[1].grid(alpha=0.2)

plt.tight_layout()
plt.show()

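## 7. Systematic Impact (Nuisance-Parameter Ranking)

After training, rank the nuisance parameters by their impact on the signal strength `mu`. `rank_impact` is called on `model`, which was built from the original workspace, so the ranking reflects the input templates rather than the trained histogram: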

In [None]:
from nextstat.interpret import rank_impact

ranking = rank_impact(model, top_n=10)
for r in ranking:
    print(f"  {r['rank']:2d}. {r['name']:20s}  impact={r['total_impact']:.4f}  pull={r['pull']:.3f}")
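
By convention, impact rankings give the shift in the fitted signal strength when one nuisance parameter is fixed at its best-fit value ±1σ and everything else is refitted. We assume `rank_impact` follows this convention, although its exact definition is not shown here. A single-bin counting model with one `normsys` parameter makes the idea concrete, because every fit then has a closed form. `mu_hat` is a hypothetical helper, and the counts are toy values.

```python
import numpy as np

def mu_hat(n, s, b, alpha, hi=1.10, lo=0.90):
    """Hypothetical helper: best-fit mu in a one-bin model with alpha held fixed."""
    # Exponential normsys factor on the background (HistFactory code-1 style)
    f = hi ** alpha if alpha >= 0 else lo ** (-alpha)
    # With alpha fixed, the Poisson term is maximised when mu*s + b*f equals n
    return (n - b * f) / s

n_obs, s_nom, b_nom = 120.0, 20.0, 100.0   # toy counts: observed, signal, background
# Free fit: mu absorbs any excess, so the Gaussian constraint pins alpha-hat at 0
mu_best = mu_hat(n_obs, s_nom, b_nom, 0.0)
# One bin cannot constrain alpha, so its post-fit uncertainty stays at 1
impact_up = mu_hat(n_obs, s_nom, b_nom, +1.0) - mu_best
impact_down = mu_hat(n_obs, s_nom, b_nom, -1.0) - mu_best
print(f"mu_hat = {mu_best:.2f}, impact(+1σ) = {impact_up:+.2f}, impact(-1σ) = {impact_down:+.2f}")
```

A wider `hi`/`lo` spread or a smaller signal gives a larger impact, since here the shift is $-b\,(f-1)/s$. This is what pushes a systematic up the ranking.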

## 8. Log to W&B (Optional)

Uncomment to enable Weights & Biases logging:

In [None]:
# import wandb
# from nextstat.mlops import significance_metrics
#
# wandb.init(project="nextstat-colab")
# for i, z0 in enumerate(history):
#     wandb.log(significance_metrics(z0, prefix="train/"), step=i)
# wandb.finish()

---

## Summary

We trained a neural network whose loss is the **profiled discovery significance**, computed by NextStat's Rust engine on CUDA or Metal. The gradient flows from the statistical test, through the differentiable histogram, into the network weights.

### Key APIs used:
- `SignificanceLoss(model, "signal")` — differentiable Z₀ loss
- `SoftHistogram(bin_edges)` — differentiable binning (KDE)
- `rank_impact(model)` — systematic impact ranking

### Next steps:
- [ML Training Guide](https://nextstat.io/docs/ml-training)
- [Optuna Tutorial](https://nextstat.io/docs/optuna) — optimise binning
- [RL Notebook](./02_gymnasium_rl_agent.ipynb) — RL agent for cut optimisation