In [17]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import DataLoader, RandomSampler
from utils.jet_dataset import JetDataset
from wnae import WNAE
from model_config.model_registry import MODEL_REGISTRY
import os, random
from utils.plotting_helpers import ensure_dir, plot_epoch_1d, plot_epoch_2d
import itertools
import json

In [18]:
# --- Model / dataset selection -------------------------------------------------
MODEL_NAME = "feat4_encoder32_deep_qcd"
model_config = MODEL_REGISTRY[MODEL_NAME]

# Read the dataset registry inside a context manager so the file handle is closed
# deterministically instead of being left open until garbage collection.
with open("data/dataset_config_small.json") as _dataset_config_file:
    DATA_PATH = json.load(_dataset_config_file)[model_config["process"]]["path"]
INPUT_DIM = model_config["input_dim"]
SAVEDIR = model_config["savedir"]+"_sinkhorn"
CHECKPOINT_PATH = f"{SAVEDIR}/wnae_checkpoint_{INPUT_DIM}.pth"
PLOT_DIR = f"{SAVEDIR}/plots/"

# --- Optimisation settings -----------------------------------------------------
BATCH_SIZE = 4098  # NOTE(review): not a power of two; possibly meant 4096 — confirm before changing
NUM_SAMPLES = 2 ** 16  # samples drawn per epoch by the RandomSampler in main()
LEARNING_RATE = 1e-3
N_EPOCHS = 100

#For plotting
PLOT_DISTRIBUTIONS = True
# NOTE(review): 200 is beyond N_EPOCHS and is only reached when resuming from a checkpoint.
PLOT_EPOCHS  = [10,100,200]  # Final epoch is always added automatically
BINS         = np.linspace(-5.0, 5.0, 101)
N_1D_SAMPLES = 4   # how many random features to plot for non-final epochs
N_2D_SAMPLES = 4    # how many 2D scatter plots to print
RNG_SEED     = 0

# Keyword arguments forwarded verbatim to the WNAE constructor.
WNAE_PARAMS = {
    "sampling": "pcd",
    "n_steps":20,
    "noise":0.05,
    "step_size":None,
    "temperature": 1.0,
    "bounds": (-4.,4.),
    "mh": False,
    "initial_distribution": "gaussian",
    "replay": True,
    "replay_ratio": 0.95,
    "distance":"sinkhorn"
}
# Prefer the GPU, but fall back to the CPU so the notebook still runs without CUDA.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [19]:

def _mean_or_nan(total, count):
    """Return ``total / count``, or NaN when nothing was accumulated."""
    return total / count if count else float("nan")


def _select_step_functions(model, loss_function):
    """Map a loss name to the model's (training step, validation step) callables."""
    if loss_function == "wnae":
        return model.train_step, model.validation_step
    if loss_function == "nae":
        return model.train_step_nae, model.validation_step_nae
    if loss_function == "ae":
        return (
            lambda x: model.train_step_ae(x, run_mcmc=True, mcmc_replay=True),
            lambda x: model.validation_step_ae(x, run_mcmc=True),
        )
    raise ValueError(
        f"Unknown loss_function {loss_function!r}; expected 'wnae', 'nae' or 'ae'"
    )


def _restore_replay_buffer(model, stored_buffer):
    """Put a checkpointed replay buffer back on the model, truncating it if it is too long."""
    if model.buffer.max_samples != len(stored_buffer):
        print(f'WARNING: stored buffer len ({len(stored_buffer)}) different from declared buffer size {model.buffer.max_samples}')
        model.buffer.buffer = stored_buffer[:model.buffer.max_samples]
    else:
        model.buffer.buffer = stored_buffer


def _plot_validation_snapshot(x, val_dict, epoch_number, plot_all_features):
    """Save 1D and 2D comparisons of one validation batch against its last MCMC samples."""
    mcmc = val_dict["mcmc_data"]["samples"][-1].detach().cpu().numpy()
    data = x.detach().cpu().numpy()
    ep_dir = ensure_dir(os.path.join(PLOT_DIR, f"epoch_{epoch_number}"))

    nfeat = data.shape[1]
    if plot_all_features:
        features = range(nfeat)
    else:
        # Clamp the request: random.sample raises if asked for more items than exist.
        features = random.Random(RNG_SEED).sample(range(nfeat), min(N_1D_SAMPLES, nfeat))

    all_pairs = list(itertools.combinations(range(nfeat), 2))
    pairs = random.Random(RNG_SEED).sample(all_pairs, min(N_2D_SAMPLES, len(all_pairs)))

    # 1D fixed-binning histograms
    plot_epoch_1d(data, mcmc, ep_dir, epoch_number, features, BINS)
    # a couple of 2D scatters for shape sanity
    plot_epoch_2d(data, mcmc, ep_dir, epoch_number, pairs)


def run_training(model, optimizer, loss_function, n_epochs, training_loader, validation_loader,checkpoint_path=None,save_every=20):
    """Train ``model`` for ``n_epochs`` epochs, validating and checkpointing along the way.

    If ``checkpoint_path`` points to an existing file, the model weights, optimizer
    state, loss/energy histories and (when present) the replay buffer are restored
    from it and training continues from the stored epoch.

    Args:
        model: object exposing the ``train_step*`` / ``validation_step*`` methods and a
            ``buffer`` attribute with ``buffer`` and ``max_samples``.
        optimizer: optimizer stepped once per training batch.
        loss_function: ``"wnae"``, ``"nae"`` or ``"ae"``; selects which step methods run.
        n_epochs: number of epochs to run in this call (added to any restored epoch).
        training_loader: iterable of batches; ``batch[0]`` is the input tensor.
        validation_loader: iterable of batches; ``batch[0]`` is the input tensor.
        checkpoint_path: file to resume from and save to; ``None`` disables both.
        save_every: checkpoint period in epochs. The check uses the 0-based epoch
            index, so saves happen after epochs 1, ``save_every`` + 1, ... and after
            the final epoch.

    Returns:
        ``(training_losses, validation_losses)``: per-epoch averages, including any
        entries restored from the checkpoint.

    Raises:
        ValueError: if ``loss_function`` is not one of the supported names.
    """
    # Resolve the step functions first so a typo fails immediately with a clear message.
    train_step, validation_step = _select_step_functions(model, loss_function)

    start_epoch = 0
    training_losses, validation_losses = [], []
    batch_pos_energies, batch_neg_energies = [], []

    if checkpoint_path and os.path.exists(checkpoint_path):
        print(f"Loading checkpoint from {checkpoint_path}")
        ckpt = torch.load(checkpoint_path, map_location=DEVICE)
        model.load_state_dict(ckpt["model_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        start_epoch = ckpt.get("epoch", 0)
        training_losses = ckpt.get("training_losses", [])
        validation_losses = ckpt.get("validation_losses", [])
        batch_pos_energies = ckpt.get("batch_pos_energies", [])
        batch_neg_energies = ckpt.get("batch_neg_energies", [])
        if "buffer" in ckpt:
            print("Loading replay buffer from checkpoint")
            _restore_replay_buffer(model, ckpt["buffer"])

    final_epoch = start_epoch + n_epochs
    # Local copy: the final epoch is always plotted, without mutating the module-level
    # PLOT_EPOCHS (which would accumulate stale entries across repeated calls).
    plot_epochs = set(PLOT_EPOCHS) | {final_epoch}

    for i_epoch in range(start_epoch, final_epoch):
        model.train()
        training_loss = 0
        n_batches = 0
        epoch_pos_energy = 0
        epoch_neg_energy = 0
        n_energy_batches = 0

        bar_format = f"Epoch {i_epoch+1}/{final_epoch}: " \
                     + "{l_bar}{bar:10}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]"

        for batch in tqdm(training_loader, bar_format=bar_format):
            x = batch[0].to(DEVICE, non_blocking=True)
            optimizer.zero_grad()

            loss, train_dict = train_step(x)

            loss.backward()
            optimizer.step()

            # Some step functions may not report energies; skip those batches instead
            # of failing on ``0 + None``.
            pos_e = train_dict.get("positive_energy", None)
            neg_e = train_dict.get("negative_energy", None)
            if pos_e is not None and neg_e is not None:
                batch_pos_energies.append(pos_e)
                batch_neg_energies.append(neg_e)
                epoch_pos_energy = epoch_pos_energy + pos_e
                epoch_neg_energy = epoch_neg_energy + neg_e
                n_energy_batches += 1

            training_loss += train_dict["loss"]
            n_batches += 1

        avg_pos_energy = _mean_or_nan(epoch_pos_energy, n_energy_batches)
        avg_neg_energy = _mean_or_nan(epoch_neg_energy, n_energy_batches)
        training_losses.append(_mean_or_nan(training_loss, n_batches))

        # Validation
        # NOTE(review): torch.no_grad() was already disabled here; presumably the MCMC
        # sampling inside the validation step needs gradients — confirm before enabling.
        model.eval()
        validation_loss = 0
        n_batches = 0
        for batch in validation_loader:
            x = batch[0].to(DEVICE, non_blocking=True)

            val_dict = validation_step(x)
            validation_loss += val_dict["loss"]

            # Plot features, positive and negative samples, only for the first batch.
            if n_batches == 0 and PLOT_DISTRIBUTIONS and (i_epoch + 1) in plot_epochs:
                _plot_validation_snapshot(
                    x, val_dict, i_epoch + 1,
                    plot_all_features=(i_epoch + 1 == final_epoch),
                )

            n_batches += 1

        validation_losses.append(_mean_or_nan(validation_loss, n_batches))

        print(f"Epoch {i_epoch+1}/{final_epoch} | "
              f"Train Loss: {training_losses[-1]:.4f} | "
              f"Val Loss: {validation_losses[-1]:.4f} | "
              f"Avg E+: {avg_pos_energy:.2f} | Avg E-: {avg_neg_energy:.2f}")

        # i_epoch is 0-based: periodic saves land after epochs 1, save_every+1, ...;
        # the last epoch of this call is always saved.
        save_epoch = i_epoch % save_every == 0 or i_epoch == final_epoch - 1
        if checkpoint_path and save_epoch:
            torch.save({
                "epoch": i_epoch + 1,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "training_losses": training_losses,
                "validation_losses": validation_losses,
                "batch_pos_energies": batch_pos_energies,
                "batch_neg_energies": batch_neg_energies,
                "buffer": model.buffer.buffer
            }, checkpoint_path)

    return training_losses, validation_losses

In [20]:

def plot_losses(training_losses, validation_losses, save_dir):
    """Save a log-scale plot of the training and validation loss per epoch.

    The figure is written to ``<save_dir>/training_loss_plot.png``; ``save_dir`` is
    passed through ``ensure_dir`` before saving.

    Args:
        training_losses: sequence of per-epoch training losses.
        validation_losses: sequence of per-epoch validation losses, plotted against
            the same x values as ``training_losses``.
        save_dir: directory that receives the PNG.
    """
    x_values = list(range(len(training_losses)))
    ensure_dir(save_dir)

    fig, ax = plt.subplots()
    ax.plot(x_values, training_losses, label="Training", color="red", linewidth=2)
    ax.plot(x_values, validation_losses, label="Validation", color="blue",
            linestyle="dashed", linewidth=2)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_yscale("log")
    ax.legend()

    fig.savefig(os.path.join(save_dir, "training_loss_plot.png"))
    plt.close(fig)

In [21]:
def main():
    """Run one full training job with the notebook-level configuration.

    Steps: shuffle the dataset indices with a fixed NumPy seed, split them 80/20 into
    training and validation subsets, build the data loaders, the WNAE model and the
    AdamW optimizer, train with the "wnae" loss, and finally save the loss curves.
    """
    full_dataset = JetDataset(DATA_PATH)
    ensure_dir(SAVEDIR)

    # Reproducible 80/20 split of the shuffled indices.
    np.random.seed(0)
    shuffled = np.arange(len(full_dataset))
    np.random.shuffle(shuffled)
    n_train = int(0.8 * len(shuffled))
    train_idx = shuffled[:n_train]
    val_idx = shuffled[n_train:]

    def build_loader(indices):
        # Each epoch draws NUM_SAMPLES items with replacement from the subset.
        subset = JetDataset(DATA_PATH, indices=indices, input_dim=INPUT_DIM)
        sampler = RandomSampler(subset, replacement=True, num_samples=NUM_SAMPLES)
        return DataLoader(subset, batch_size=BATCH_SIZE, sampler=sampler, pin_memory=True)

    train_loader = build_loader(train_idx)
    val_loader = build_loader(val_idx)

    encoder = model_config["encoder"]()
    decoder = model_config["decoder"]()
    model = WNAE(encoder=encoder, decoder=decoder, **WNAE_PARAMS).to(DEVICE)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    training_losses, validation_losses = run_training(
        model=model,
        optimizer=optimizer,
        loss_function="wnae",
        n_epochs=N_EPOCHS,
        training_loader=train_loader,
        validation_loader=val_loader,
        checkpoint_path=CHECKPOINT_PATH,
    )

    plot_losses(training_losses, validation_losses, PLOT_DIR)

In [None]:
# Launch the training job defined by main(): trains the model, writes checkpoints, and saves the loss plot.
main()

Epoch 1/100: 100%|██████████| 16/16 [00:16<00:00]


Epoch 1/100 | Train Loss: 0.1087 | Val Loss: 0.1261 | Avg E+: 1.02 | Avg E-: 1.14


Epoch 2/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 2/100 | Train Loss: 0.1522 | Val Loss: 0.1547 | Avg E+: 1.02 | Avg E-: 1.23


Epoch 3/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 3/100 | Train Loss: 0.1399 | Val Loss: 0.1280 | Avg E+: 1.01 | Avg E-: 1.25


Epoch 4/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 4/100 | Train Loss: 0.1380 | Val Loss: 0.1437 | Avg E+: 1.03 | Avg E-: 1.27


Epoch 5/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 5/100 | Train Loss: 0.1520 | Val Loss: 0.1475 | Avg E+: 1.81 | Avg E-: 2.12


Epoch 6/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 6/100 | Train Loss: 0.1644 | Val Loss: 0.1599 | Avg E+: 2.33 | Avg E-: 2.50


Epoch 7/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 7/100 | Train Loss: 0.1670 | Val Loss: 0.1585 | Avg E+: 1.52 | Avg E-: 1.72


Epoch 8/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 8/100 | Train Loss: 0.1588 | Val Loss: 0.1512 | Avg E+: 1.31 | Avg E-: 1.54


Epoch 9/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 9/100 | Train Loss: 0.1527 | Val Loss: 0.1561 | Avg E+: 1.50 | Avg E-: 1.71


Epoch 10/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 10/100 | Train Loss: 0.1621 | Val Loss: 0.1664 | Avg E+: 1.39 | Avg E-: 1.63


Epoch 11/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 11/100 | Train Loss: 0.1454 | Val Loss: 0.1359 | Avg E+: 1.99 | Avg E-: 2.26


Epoch 12/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 12/100 | Train Loss: 0.1429 | Val Loss: 0.1278 | Avg E+: 2.18 | Avg E-: 2.32


Epoch 13/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 13/100 | Train Loss: 0.1426 | Val Loss: 0.1431 | Avg E+: 1.08 | Avg E-: 1.31


Epoch 14/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 14/100 | Train Loss: 0.1499 | Val Loss: 0.1361 | Avg E+: 1.23 | Avg E-: 1.53


Epoch 15/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 15/100 | Train Loss: 0.1503 | Val Loss: 0.1442 | Avg E+: 1.95 | Avg E-: 2.23


Epoch 16/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 16/100 | Train Loss: 0.1325 | Val Loss: 0.1416 | Avg E+: 1.34 | Avg E-: 1.57


Epoch 17/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 17/100 | Train Loss: 0.1447 | Val Loss: 0.1462 | Avg E+: 1.39 | Avg E-: 1.63


Epoch 18/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 18/100 | Train Loss: 0.1407 | Val Loss: 0.1484 | Avg E+: 1.28 | Avg E-: 1.60


Epoch 19/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 19/100 | Train Loss: 0.1541 | Val Loss: 0.1498 | Avg E+: 1.89 | Avg E-: 2.15


Epoch 20/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 20/100 | Train Loss: 0.1506 | Val Loss: 0.1540 | Avg E+: 1.42 | Avg E-: 1.67


Epoch 21/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 21/100 | Train Loss: 0.1614 | Val Loss: 0.1382 | Avg E+: 1.49 | Avg E-: 1.81


Epoch 22/100: 100%|██████████| 16/16 [00:16<00:00]


Epoch 22/100 | Train Loss: 0.1340 | Val Loss: 0.1286 | Avg E+: 2.11 | Avg E-: 2.29


Epoch 23/100: 100%|██████████| 16/16 [00:16<00:00]


Epoch 23/100 | Train Loss: 0.1367 | Val Loss: 0.1259 | Avg E+: 1.50 | Avg E-: 1.82


Epoch 24/100: 100%|██████████| 16/16 [00:16<00:00]


Epoch 24/100 | Train Loss: 0.1351 | Val Loss: 0.1271 | Avg E+: 2.01 | Avg E-: 2.24


Epoch 25/100: 100%|██████████| 16/16 [00:16<00:00]


Epoch 25/100 | Train Loss: 0.1382 | Val Loss: 0.1398 | Avg E+: 1.17 | Avg E-: 1.42


Epoch 26/100: 100%|██████████| 16/16 [00:16<00:00]


Epoch 26/100 | Train Loss: 0.1378 | Val Loss: 0.1303 | Avg E+: 1.64 | Avg E-: 1.89


Epoch 27/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 27/100 | Train Loss: 0.1327 | Val Loss: 0.1161 | Avg E+: 2.62 | Avg E-: 2.75


Epoch 28/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 28/100 | Train Loss: 0.1307 | Val Loss: 0.1261 | Avg E+: 1.19 | Avg E-: 1.41


Epoch 29/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 29/100 | Train Loss: 0.1285 | Val Loss: 0.1315 | Avg E+: 1.71 | Avg E-: 1.99


Epoch 30/100: 100%|██████████| 16/16 [00:16<00:00]


Epoch 30/100 | Train Loss: 0.1234 | Val Loss: 0.1210 | Avg E+: 3.92 | Avg E-: 4.05


Epoch 31/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 31/100 | Train Loss: 0.1201 | Val Loss: 0.1099 | Avg E+: 1.89 | Avg E-: 2.03


Epoch 32/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 32/100 | Train Loss: 0.1187 | Val Loss: 0.1253 | Avg E+: 1.72 | Avg E-: 1.95


Epoch 33/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 33/100 | Train Loss: 0.1322 | Val Loss: 0.1323 | Avg E+: 2.17 | Avg E-: 2.31


Epoch 34/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 34/100 | Train Loss: 0.1434 | Val Loss: 0.1479 | Avg E+: 1.62 | Avg E-: 1.91


Epoch 35/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 35/100 | Train Loss: 0.1352 | Val Loss: 0.1233 | Avg E+: 4.77 | Avg E-: 4.98


Epoch 36/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 36/100 | Train Loss: 0.1223 | Val Loss: 0.1184 | Avg E+: 3.41 | Avg E-: 3.41


Epoch 37/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 37/100 | Train Loss: 0.1290 | Val Loss: 0.1370 | Avg E+: 1.88 | Avg E-: 2.05


Epoch 38/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 38/100 | Train Loss: 0.1325 | Val Loss: 0.1323 | Avg E+: 2.11 | Avg E-: 2.35


Epoch 39/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 39/100 | Train Loss: 0.1341 | Val Loss: 0.1431 | Avg E+: 3.05 | Avg E-: 2.92


Epoch 40/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 40/100 | Train Loss: 0.1504 | Val Loss: 0.1346 | Avg E+: 2.28 | Avg E-: 2.18


Epoch 41/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 41/100 | Train Loss: 0.1366 | Val Loss: 0.1445 | Avg E+: 1.15 | Avg E-: 1.39


Epoch 42/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 42/100 | Train Loss: 0.1436 | Val Loss: 0.1424 | Avg E+: 1.23 | Avg E-: 1.53


Epoch 43/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 43/100 | Train Loss: 0.1543 | Val Loss: 0.1573 | Avg E+: 1.39 | Avg E-: 1.61


Epoch 44/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 44/100 | Train Loss: 0.1511 | Val Loss: 0.1293 | Avg E+: 1.13 | Avg E-: 1.33


Epoch 45/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 45/100 | Train Loss: 0.1322 | Val Loss: 0.1374 | Avg E+: 1.16 | Avg E-: 1.40


Epoch 46/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 46/100 | Train Loss: 0.1766 | Val Loss: 0.1883 | Avg E+: 2.70 | Avg E-: 2.74


Epoch 47/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 47/100 | Train Loss: 0.1798 | Val Loss: 0.1778 | Avg E+: 1.27 | Avg E-: 1.44


Epoch 48/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 48/100 | Train Loss: 0.1764 | Val Loss: 0.1743 | Avg E+: 1.47 | Avg E-: 1.65


Epoch 49/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 49/100 | Train Loss: 0.1590 | Val Loss: 0.1518 | Avg E+: 2.80 | Avg E-: 2.98


Epoch 50/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 50/100 | Train Loss: 0.1355 | Val Loss: 0.1181 | Avg E+: 2.58 | Avg E-: 2.75


Epoch 51/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 51/100 | Train Loss: 0.1330 | Val Loss: 0.1197 | Avg E+: 1.50 | Avg E-: 1.79


Epoch 52/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 52/100 | Train Loss: 0.1158 | Val Loss: 0.1268 | Avg E+: 1.84 | Avg E-: 2.17


Epoch 53/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 53/100 | Train Loss: 0.1396 | Val Loss: 0.1215 | Avg E+: 3.68 | Avg E-: 4.04


Epoch 54/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 54/100 | Train Loss: 0.1099 | Val Loss: 0.1094 | Avg E+: 3.36 | Avg E-: 3.54


Epoch 55/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 55/100 | Train Loss: 0.1170 | Val Loss: 0.1168 | Avg E+: 2.37 | Avg E-: 2.58


Epoch 56/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 56/100 | Train Loss: 0.1353 | Val Loss: 0.1346 | Avg E+: 4.25 | Avg E-: 4.46


Epoch 57/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 57/100 | Train Loss: 0.1426 | Val Loss: 0.1523 | Avg E+: 3.85 | Avg E-: 3.92


Epoch 58/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 58/100 | Train Loss: 0.1596 | Val Loss: 0.1553 | Avg E+: 2.33 | Avg E-: 2.58


Epoch 59/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 59/100 | Train Loss: 0.1503 | Val Loss: 0.1367 | Avg E+: 3.05 | Avg E-: 3.26


Epoch 60/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 60/100 | Train Loss: 0.1422 | Val Loss: 0.1366 | Avg E+: 1.19 | Avg E-: 1.50


Epoch 61/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 61/100 | Train Loss: 0.1355 | Val Loss: 0.1324 | Avg E+: 1.36 | Avg E-: 1.69


Epoch 62/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 62/100 | Train Loss: 0.1369 | Val Loss: 0.1342 | Avg E+: 3.07 | Avg E-: 3.27


Epoch 63/100: 100%|██████████| 16/16 [00:15<00:00]


Epoch 63/100 | Train Loss: 0.1393 | Val Loss: 0.1205 | Avg E+: 2.68 | Avg E-: 2.68


Epoch 64/100:  25%|██▌       | 4/16 [00:03<00:11]

In [1]:
import os
import json
import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
from torch.utils.data import DataLoader, SequentialSampler
from matplotlib.transforms import Bbox
from utils.jet_dataset import JetDataset
from wnae import WNAE
from model_config.model_registry import MODEL_REGISTRY

import os
import matplotlib.pyplot as plt
import numpy as np

In [9]:

def compute_mse(dataloader):
    """Return the per-sample reconstruction MSE for every sample in ``dataloader``.

    Relies on the notebook-level ``model`` (its ``encoder`` and ``decoder``) and
    ``DEVICE``, which must be defined before this function is called.

    Args:
        dataloader: iterable of batches; ``batch[0]`` is the input tensor.
            Assumes inputs are (batch, features) so that the mean over ``dim=1``
            yields one value per sample — TODO confirm.

    Returns:
        np.ndarray with one score per sample, in loader order (empty if the loader
        yields no batches).
    """
    chunks = []
    # Pure inference: disable autograd so no graph is built and kept for each batch.
    with torch.no_grad():
        for batch in dataloader:
            x = batch[0].to(DEVICE, non_blocking=True)
            recon_x = model.decoder(model.encoder(x))
            per_sample_mse = torch.mean((x - recon_x) ** 2, dim=1)
            chunks.append(per_sample_mse.cpu().numpy())
    if not chunks:
        return np.array([])
    return np.concatenate(chunks)

def plot_checkpoint_energies(checkpoint, plot_dir="plots"):
    """
    Draw the per-batch positive and negative energies stored in a checkpoint.

    The curves are saved (log-scale y axis) as ``energies_per_batch.png`` and
    ``energies_per_batch.pdf`` inside ``plot_dir``, which is created if needed.
    If either energy list is absent from the checkpoint, a warning listing the
    available keys is printed and nothing is plotted.

    Args:
        checkpoint: dictionary returned from torch.load(checkpoint_path)
        plot_dir: directory to save plots
    """
    os.makedirs(plot_dir, exist_ok=True)

    stored = {
        "Positive Energy": checkpoint.get("batch_pos_energies", None),
        "Negative Energy": checkpoint.get("batch_neg_energies", None),
    }
    if any(series is None for series in stored.values()):
        print("[WARNING] Positive/Negative energies not found in checkpoint.")
        print("Checkpoint keys:", checkpoint.keys())
        return

    curves = [(label, np.array(series)) for label, series in stored.items()]

    # Plot energies per batch
    fig, ax = plt.subplots(figsize=(8,5))
    for label, values in curves:
        ax.plot(values, label=label)
    ax.set_xlabel("Batch number")
    ax.set_ylabel("Energy")
    ax.legend(frameon=False)
    ax.set_yscale('log')
    fig.tight_layout()

    plot_path = os.path.join(plot_dir, "energies_per_batch.png")
    for target in (plot_path, plot_path.replace(".png",".pdf")):
        fig.savefig(target)
    plt.close(fig)

    print(f"[INFO] Energy plot saved to: {plot_path}")

def load_dataset(file_path, key="Jets", max_jets=10000, pt_cut=None):
    """Open a JetDataset and randomly down-sample it to at most ``max_jets`` entries.

    Args:
        file_path: location of the dataset file, passed to ``JetDataset``.
        key: forwarded to ``JetDataset`` as ``key``.
        max_jets: maximum number of entries to keep.
        pt_cut: forwarded to ``JetDataset`` as ``pt_cut``.

    Returns:
        The ``JetDataset``; when it held more than ``max_jets`` entries its
        ``indices`` are replaced by a random draw (without replacement) of that size.
    """
    dataset = JetDataset(file_path, input_dim=INPUT_DIM, key=key, pt_cut=pt_cut)
    if not len(dataset) > max_jets:
        return dataset

    # Draw from the indices that survived the dataset's own selection.
    dataset.indices = np.random.choice(dataset.indices, size=max_jets, replace=False)
    return dataset

def _efficiency_per_pt_bin(scores, pts, bin_edges, threshold):
    """Fraction of ``scores`` above ``threshold`` in each half-open pT bin (NaN if the bin is empty)."""
    efficiencies = []
    for low, high in zip(bin_edges[:-1], bin_edges[1:]):
        mask = (pts >= low) & (pts < high)
        if np.sum(mask) > 0:
            efficiencies.append(np.mean(scores[mask] > threshold))
        else:
            efficiencies.append(np.nan)
    return efficiencies


def plot_eff_vs_pt(bkg_mses, sig_mses_dict, bkg_dataset, signal_loaders, wp=0.1, savedir="plots"):
    """
    Plot efficiency vs jet pT for a fixed working point defined by a background mistag rate.

    The selection threshold is the (1 - wp) quantile of the background scores; a jet
    passes when its score is above that threshold. The figure is saved to
    ``<savedir>/plots/eff_vs_pt_wp_<wp>.png`` (the directory is created if missing).
    The background legend entry uses the notebook-level ``BKG_NAME``.

    Args:
        bkg_mses (np.ndarray): background MSE scores
        sig_mses_dict (dict): {name: np.ndarray} of signal MSE scores
        bkg_dataset (JetDataset): background dataset (provides pT)
        signal_loaders (dict): {name: DataLoader} for signals (to get dataset pT)
        wp (float): background mistag working point (e.g. 0.1 for 10%)
        savedir (str): base directory; the plot goes into its ``plots`` subdirectory

    Note:
        Scores are masked with the pT arrays, so each score array must be aligned
        element-by-element with the corresponding ``get_pt()`` output.
    """
    # threshold from background: WP corresponds to (1 - wp) quantile
    threshold = np.percentile(bkg_mses, 100 * (1 - wp))

    bins_pt = np.linspace(150, 800, 50)  # adjust as needed
    bin_centers = 0.5 * (bins_pt[:-1] + bins_pt[1:])

    # background efficiency per pT bin
    bkg_eff_pt = _efficiency_per_pt_bin(bkg_mses, bkg_dataset.get_pt(), bins_pt, threshold)

    # signal efficiencies per pT bin
    sig_eff_pt_dict = {
        name: _efficiency_per_pt_bin(
            sig_mses, signal_loaders[name].dataset.get_pt(), bins_pt, threshold
        )
        for name, sig_mses in sig_mses_dict.items()
    }

    # plot
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(bin_centers, bkg_eff_pt, label=f"{BKG_NAME} mistag (WP={wp*100:.0f}%)", linestyle="--")
    for name, eff in sig_eff_pt_dict.items():
        ax.plot(bin_centers, eff, label=f"{name}")
    ax.set_xlabel("Jet $p_T$ [GeV]")
    ax.set_ylabel("Efficiency")
    ax.set_ylim(0, 1.3)
    ax.legend(ncol=2)
    ax.grid(alpha=0.3)
    fig.tight_layout()

    savefig = f"{savedir}/plots/eff_vs_pt_wp_{wp}.png"
    # The target lives in a "plots" subdirectory that nothing else here creates.
    os.makedirs(os.path.dirname(savefig), exist_ok=True)
    fig.savefig(savefig, dpi=200)
    plt.close(fig)
    print(f"Saved {savefig}")

def plot_sample_vs_reconstruction(model, bkg_loader, savedir, device=torch.device("cpu")):
    """
    Plot one MCMC jet and one validation jet with original and reconstructed features.

    Two figures are written to ``savedir`` (created if missing):
    ``sample_reconstruction.png`` (input features vs. reconstruction, annotated with
    the energies) and ``sample_latent.png`` (latent vectors of both jets).

    Args:
        model: trained WNAE model (should be in eval mode)
        bkg_loader: DataLoader for background/validation jets
        savedir: directory to save the plot
        device: device the two jets are moved to before being passed to the model

    Raises:
        ValueError: if the model's replay buffer is empty.
    """
    # Fail before doing any work if there is no MCMC sample to show.
    if len(model.buffer.buffer) == 0:
        raise ValueError("MCMC buffer is empty!")
    os.makedirs(savedir, exist_ok=True)

    # First jet of the first batch, and first jet stored in the replay buffer.
    val_batch = next(iter(bkg_loader))
    val_jet = val_batch[0][0:1].to(device)
    mcmc_jet = model.buffer.buffer[0].unsqueeze(0).to(device)

    # Plotting only: no gradients are needed for these forward passes.
    with torch.no_grad():
        val_energy, val_z, val_reco = model._WNAE__energy_with_samples(val_jet)
        mcmc_energy, mcmc_z, mcmc_reco = model._WNAE__energy_with_samples(mcmc_jet)

    features = range(val_jet.shape[1])

    fig, ax = plt.subplots(figsize=(10,5))

    ax.plot(features, val_jet[0].detach().cpu().numpy(), 'o-', color='C0', label='Val jet')
    ax.plot(features, val_reco[0].detach().cpu().numpy(), 's--', color='C0', label='Val jet reco.')

    ax.plot(features, mcmc_jet[0].detach().cpu().numpy(), 'o-', color='C1', label='MCMC jet')
    ax.plot(features, mcmc_reco[0].detach().cpu().numpy(), 's--', color='C1', label='MCMC jet reco.')

    ax.set_xlabel("Feature index")
    ax.set_ylabel("Feature value")

    ax.text(0.95, 0.95, f"Val energy: {val_energy.item():.1f}", transform=ax.transAxes,
            horizontalalignment='right', verticalalignment='top', color='C0')
    ax.text(0.95, 0.90, f"MCMC energy: {mcmc_energy.item():.1f}", transform=ax.transAxes,
            horizontalalignment='right', verticalalignment='top', color='C1')

    ax.legend()
    fig.tight_layout()
    save_path = f"{savedir}/sample_reconstruction.png"
    fig.savefig(save_path)
    # Close only this figure. Calling plt.cla()/plt.clf() after closing would make
    # pyplot create a fresh empty figure that then shows up in the notebook output.
    plt.close(fig)
    print(f"Saved {save_path}")

    #Latent plot
    fig_z, ax_z = plt.subplots(figsize=(10,5))

    latent_idx = range(val_z.shape[1])

    ax_z.plot(latent_idx, val_z[0].detach().cpu().numpy(), 'o-', color='C0', label='Val jet z')
    ax_z.plot(latent_idx, mcmc_z[0].detach().cpu().numpy(), 'o-', color='C1', label='MCMC jet z')

    ax_z.set_xlabel("Latent dimension index")
    ax_z.set_ylabel("Latent value")

    ax_z.legend()
    fig_z.tight_layout()
    save_path_z = f"{savedir}/sample_latent.png"
    fig_z.savefig(save_path_z)
    plt.close(fig_z)
    print(f"Saved {save_path_z}")

def plot_energy_distributions(model, bkg_loader, n_samples=10000, savedir="plots", device=torch.device("cpu")):
    """
    Histogram the energies of data jets (E+) and of replay-buffer MCMC jets (E-)
    side by side, each on its own x range, and save them as
    ``<savedir>/energy_distributions.png``.

    Args:
        model: model providing the energy function and the replay buffer; it is
            switched to eval mode here.
        bkg_loader: loader of data batches; ``batch[0]`` holds the jets.
        n_samples: maximum number of jets used for each distribution.
        savedir: output directory, created if missing.
        device: device the jets are moved to before evaluation.

    Raises:
        ValueError: if the model's replay buffer is empty.
    """
    model.eval()
    E_pos_list = []
    E_neg_list = []

    # Positive (data) energies, accumulated batch by batch until n_samples is reached.
    n_seen = 0
    with torch.no_grad():
        for batch in bkg_loader:
            jets = batch[0].to(device)
            batch_energies, _, _ = model._WNAE__energy_with_samples(jets)
            E_pos_list += list(batch_energies.cpu().numpy())
            n_seen += len(jets)
            if n_seen >= n_samples:
                del E_pos_list[n_samples:]
                break

    # Negative (MCMC) energies from the replay buffer.
    stored = model.buffer.buffer
    if len(stored) == 0:
        raise ValueError("MCMC buffer is empty!")
    n_mcmc = min(n_samples, len(stored))
    mcmc_jets = torch.stack(stored[:n_mcmc]).to(device)
    with torch.no_grad():
        mcmc_energies, _, _ = model._WNAE__energy_with_samples(mcmc_jets)
        E_neg_list = mcmc_energies.cpu().numpy()

    # Upper x limit of each panel is the 99th percentile of its own sample.
    x_pos_max = np.percentile(E_pos_list, 99)
    x_neg_max = np.percentile(E_neg_list, 99)

    fig, axs = plt.subplots(1, 2, figsize=(12,5), sharey=True)
    panels = (
        (axs[0], E_pos_list, x_pos_max, "E+ (data)", 'C0'),
        (axs[1], E_neg_list, x_neg_max, "E- (MCMC)", 'C1'),
    )
    for ax, values, upper, title, color in panels:
        ax.hist(values, bins=np.linspace(0, upper, 50), histtype='step', color=color)
        ax.set_title(title)
        ax.set_xlabel("Reconstruction energy")
        ax.set_xlim(0, upper)
    # The y axis is shared, so only the left panel is labelled.
    axs[0].set_ylabel("Counts")

    fig.tight_layout()
    os.makedirs(savedir, exist_ok=True)
    save_path = f"{savedir}/energy_distributions.png"
    fig.savefig(save_path)
    plt.close(fig)
    print(f"Saved {save_path}")


In [12]:
# --- Evaluation configuration --------------------------------------------------
# NOTE: this cell relies on model_config, SAVEDIR, BATCH_SIZE, INPUT_DIM, WNAE_PARAMS,
# CHECKPOINT_PATH and DEVICE from the training configuration cell.
MAX_JETS = 20000
PT_CUT = None
BKG_NAME = model_config["process"]
CONFIG_PATH = "data/dataset_config_small.json"

os.makedirs(f"{SAVEDIR}/plots", exist_ok=True)

with open(CONFIG_PATH, "r") as f:
    config = json.load(f)

# Background sample: the process the model was configured for.
bkg_path = config[BKG_NAME]["path"]
bkg_dataset = load_dataset(bkg_path, max_jets=MAX_JETS,pt_cut=PT_CUT)
bkg_loader = DataLoader(bkg_dataset, batch_size=BATCH_SIZE, sampler=SequentialSampler(bkg_dataset))

# Every other entry of the dataset config is treated as a signal sample.
signal_loaders = {}
for name, sample in config.items():
    if name==BKG_NAME:
        continue
    sig_dataset = load_dataset(sample["path"], max_jets=MAX_JETS,pt_cut=PT_CUT)
    signal_loaders[name] = DataLoader(sig_dataset, batch_size=BATCH_SIZE, sampler=SequentialSampler(sig_dataset))

model = WNAE(encoder=model_config["encoder"](),decoder=model_config["decoder"](),**WNAE_PARAMS)
# Read the checkpoint from disk once and reuse it for the weights, the replay buffer
# and the stored energy history.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(DEVICE)
model.eval()
if "buffer" in checkpoint:
    print("Loading replay buffer from checkpoint")
    if model.buffer.max_samples!=len(checkpoint["buffer"]):
        print(f'WARNING: stored buffer len ({len(checkpoint["buffer"])}) different from declared buffer size {model.buffer.max_samples}')
        # Keep only as many stored samples as the buffer is declared to hold.
        model.buffer.buffer = checkpoint["buffer"][:model.buffer.max_samples]
    else:
        model.buffer.buffer = checkpoint["buffer"]
print(f"Device is {DEVICE}")
plot_energy_distributions(model, bkg_loader, savedir=f"{SAVEDIR}/plots", device=DEVICE)
plot_sample_vs_reconstruction(model, bkg_loader, savedir=f"{SAVEDIR}/plots", device=DEVICE)
plot_checkpoint_energies(checkpoint, plot_dir=f"{SAVEDIR}/plots/")

# Reconstruction-MSE score for the background and for each signal sample.
print("[INFO] Computing background mse...")
bkg_mses = compute_mse(bkg_loader)

sig_mses_dict = {}
for name in signal_loaders:
    print(f"[INFO] Computing mse for signal: {name}")
    sig_mses_dict[name] = compute_mse(signal_loaders[name])

# --- Combined figure: mse, ROC, mass, pt ---
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
(ax_mse, ax_roc), (ax_mass, ax_pt) = axes

# Background first, then the signals, so every panel uses the same drawing order.
mses_by_sample = {BKG_NAME: bkg_mses, **sig_mses_dict}
datasets_by_sample = {BKG_NAME: bkg_dataset}
datasets_by_sample.update((name, loader.dataset) for name, loader in signal_loaders.items())

# --- 1) Mse distributions ---
# The x range is cut at the 99th percentile of all scores combined.
all_mses = np.concatenate(list(mses_by_sample.values()))
x_max = np.percentile(all_mses, 99.)
bins_mse = np.linspace(0, x_max, 101)
for label, mses in mses_by_sample.items():
    ax_mse.hist(mses, bins=bins_mse, histtype='step', label=label, density=True)
ax_mse.set_xlabel("Reconstruction MSE")
ax_mse.set_ylabel("Density")
ax_mse.set_xlim([0, x_max])
ax_mse.legend()

# --- 2) ROC curves ---
# Background is labelled 0 and each signal 1; the MSE is the discriminating score.
bkg_labels = np.zeros_like(bkg_mses)
for name, sig_mses in sig_mses_dict.items():
    y_true = np.concatenate([bkg_labels, np.ones_like(sig_mses)])
    y_score = np.concatenate([bkg_mses, sig_mses])
    fpr, tpr, _ = roc_curve(y_true, y_score)
    roc_auc = auc(fpr, tpr)
    ax_roc.plot(fpr, tpr, label=f"{name} (AUC = {roc_auc:.3f})")
ax_roc.plot([0, 1], [0, 1], color="navy", lw=1, linestyle="--")
ax_roc.set_xlabel("Background mistag rate")
ax_roc.set_ylabel("Signal efficiency")
ax_roc.legend(loc="lower right")
ax_roc.grid(True, alpha=0.3)

# --- 3) Jet mass distributions ---
bins_mass = np.linspace(0, 200, 101)
for label, sample_ds in datasets_by_sample.items():
    ax_mass.hist(sample_ds.get_mass(), bins=bins_mass, histtype='step', density=True, label=label)
ax_mass.set_xlabel("Jet mass [GeV]")
ax_mass.set_ylabel("Density")
ax_mass.legend()

# --- 4) Jet pt distributions ---
bins_pt = np.linspace(150, 800, 65)
for label, sample_ds in datasets_by_sample.items():
    ax_pt.hist(sample_ds.get_pt(), bins=bins_pt, histtype='step', density=True, label=label)
ax_pt.set_xlabel("Jet $p_T$ [GeV]")
ax_pt.set_ylabel("Density")
ax_pt.set_ylim(1e-4, 3*1e-2)
ax_pt.set_yscale("log")
ax_pt.legend()

fig.tight_layout()
savefig = f"{SAVEDIR}/plots/summary.png"
fig.savefig(savefig, dpi=200)
plt.close(fig)
print(f"Saved {savefig}")


# --- Save each subplot individually ---
# Technique from https://stackoverflow.com/questions/4325733/save-a-subplot-in-matplotlib
individual_plots = dict(mse=ax_mse, roc=ax_roc, mass=ax_mass, pt=ax_pt)

# Padding around each axes, as a fraction of its width/height, so that tick labels,
# axis labels and legends are not clipped from the cropped image.
expand_left_frac, expand_right_frac = 0.12, 0.05
expand_bottom_frac, expand_top_frac = 0.11, 0.01

for panel_name, panel_ax in individual_plots.items():
    fig = panel_ax.figure
    # Axes bounding box converted from display units to inches.
    extent = panel_ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
    padded_extent = Bbox.from_bounds(
        extent.x0 - extent.width * expand_left_frac,
        extent.y0 - extent.height * expand_bottom_frac,
        extent.width + extent.width * (expand_left_frac + expand_right_frac),
        extent.height + extent.height * (expand_bottom_frac + expand_top_frac),
    )
    savefig = f"{SAVEDIR}/plots/{panel_name}.png"
    fig.savefig(savefig, dpi=200, bbox_inches=padded_extent)
    print(f"Saved {savefig}")


plot_eff_vs_pt(bkg_mses, sig_mses_dict, bkg_dataset, signal_loaders, wp=0.1, savedir=SAVEDIR)
print("[INFO] Evaluation complete.")


Loading replay buffer from checkpoint
Device is cuda
Saved models/feat16_encoder32_deep_qcd_sinkhorn/plots/energy_distributions.png
Saved models/feat16_encoder32_deep_qcd_sinkhorn/plots/sample_reconstruction.png
Saved models/feat16_encoder32_deep_qcd_sinkhorn/plots/sample_latent.png
[INFO] Energy plot saved to: models/feat16_encoder32_deep_qcd_sinkhorn/plots/energies_per_batch.png
[INFO] Computing background mse...
[INFO] Computing mse for signal: Top_bqq
Saved models/feat16_encoder32_deep_qcd_sinkhorn/plots/summary.png
Saved models/feat16_encoder32_deep_qcd_sinkhorn/plots/mse.png
Saved models/feat16_encoder32_deep_qcd_sinkhorn/plots/roc.png
Saved models/feat16_encoder32_deep_qcd_sinkhorn/plots/mass.png
Saved models/feat16_encoder32_deep_qcd_sinkhorn/plots/pt.png
Saved models/feat16_encoder32_deep_qcd_sinkhorn/plots/eff_vs_pt_wp_0.1.png
[INFO] Evaluation complete.


<Figure size 640x480 with 0 Axes>