# Detecting anomalies from HEP data with Generative Adversarial Networks

(Someone smarter than me introduce the physics/objectives)

We'll train a generative adversarial network (GAN) to build a generator network that can simulate samples from the training distribution, while simultaneously building a discriminator network that learns to distinguish real training samples from simulated ones. This requires the discriminator to build a robust representation of the training distribution (provided you've trained a sufficiently good generator). Then when we want to detect new anomalous events, we ditch the generator and just run new samples through the discriminator. Samples that are unlikely under the training distribution (and which are therefore likely to be anomalies) will then be ranked more highly by the discriminator, and we can use its output as a detection statistic.

A good chunk of the GAN training code (and its sometimes obnoxious comments) comes from [this official Keras example](https://keras.io/guides/writing_a_training_loop_from_scratch/).

Begin by doing our imports and defining some hyperparameters.

In [1]:
import subprocess
from pathlib import Path

import h5py
import numpy as np
import tensorflow as tf
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from bokeh.plotting import figure
from bokeh.io import output_notebook, show
from bokeh.palettes import Dark2_8 as palette

output_notebook()


# Basic params
BATCH_SIZE = 8192
LATENT_DIM = 32  # dimension of the latent vector the generator samples
VALID_FRAC = 0.25  # fraction of background/signal datasets to use for validation
SEED = 101588  # random seed
NUM_TRAIN_SAMPLES = 5000000  # you know what it is

# Optimization params
GEN_UPDATES = 8  # number of gradient updates to perform to the
                 # generator in between updates to the discriminator
D_LR = 0.0003  # learning rate for the discriminator
G_LR = 0.0008  # learning rate for the generator
FPR_THRESH = 1e-5  # false positive rate to use during validation

# paths and what not
BACKGROUND_FNAME = Path("background.h5")
SIGNAL_FNAMES = {
    "A-4_leptons": "https://zenodo.org/record/7152590/files/Ato4l_lepFilter_13TeV_filtered.h5?download=1",
    "leptoquarks-b_tau": "https://zenodo.org/record/7152599/files/leptoquark_LOWMASS_lepFilter_13TeV_filtered.h5?download=1",
    "h_0-tau_tau": "https://zenodo.org/record/7152614/files/hToTauTau_13TeV_PU20_filtered.h5?download=1",
    "h_plus-tau_nu": "https://zenodo.org/record/7152617/files/hChToTauNu_13TeV_PU20_filtered.h5?download=1"
}

2023-07-13 18:04:27.197515: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Next we'll download the challenge datasets from Zenodo, caching them for later use. This will take 1-2GB worth of disk space.

In [2]:
def download_dataset(fname, url):
    subprocess.run(f"wget -O {fname} {url}", shell=True)
    
if not BACKGROUND_FNAME.exists():
    download_dataset(
        str(BACKGROUND_FNAME),
        "https://zenodo.org/record/5046428/files/background_for_training.h5?download=1"
    )

for signal, url in SIGNAL_FNAMES.items():
    fname = signal + ".h5"
    if not Path(fname).exists():
        download_dataset(fname, url)

Start by loading in the background data and seeing how many events we have and look at what our features represent.

In [3]:
def load_dataset(dataset, N):
    X = dataset[:N]

    # each dataset has 4 columns along its last dimension,
    # the last of which represents a mask indicating whether
    # that particle recorded any values for this sample.
    # Break out that colum separately to use as a separate
    # feature vector to the network.
    X, y = np.split(X, [3], axis=-1)
    return X, y[:, :, 0]


with h5py.File(BACKGROUND_FNAME, "r") as f:
    print("Features:", ", ".join([i.decode() for i in f["Particles_Names"][:]]))
    print("Event types:", ", ".join([i.decode() for i in f["Particles_Classes"][:]]))
    print("Total background events:", len(f["Particles"]))
    print(f"Loading {NUM_TRAIN_SAMPLES} events")
    X, masks = load_dataset(f["Particles"], NUM_TRAIN_SAMPLES)

_, num_events, num_features = X.shape
FEATURE_DIM = num_events * num_features
X = X.reshape(-1, FEATURE_DIM)

Features: Pt, Eta, Phi, Class
Event types: MET_class_1, Four_Ele_class_2, Four_Mu_class_3, Ten_Jet_class_4
Total background events: 13451915
Loading 5000000 events


In [4]:
(
    train_bg_events,
    valid_bg_events,
    train_bg_masks,
    valid_bg_masks
) = train_test_split(X, masks, test_size=VALID_FRAC, random_state=SEED)

In [5]:
scaler = StandardScaler()
scaler.fit(train_bg_events)

def preprocess(X, mask):
    X = scaler.transform(X)
    X = X.reshape(-1, num_events, num_features)
    X[mask == 0] *= 0
    X = X.reshape(-1, FEATURE_DIM)
    return X

train_bg_events = preprocess(train_bg_events, train_bg_masks)
valid_bg_events = preprocess(valid_bg_events, valid_bg_masks)

In [6]:
valid_events = [valid_bg_events]
valid_masks = [valid_bg_masks]
valid_y = [np.zeros((len(valid_bg_events), 1))]

signals = sorted(SIGNAL_FNAMES)
for i, signal in enumerate(signals):
    with h5py.File(signal + ".h5", "r") as f:
        dataset = f["Particles"]
        n = int(0.25 * len(dataset))
        print(f"Loading {n} events from signal {signal}")
        events, masks = load_dataset(dataset, n)

    events = events.reshape(-1, FEATURE_DIM)
    events = preprocess(events, masks)
    valid_events.append(events)
    valid_masks.append(masks)

    classes = np.ones((n, 1)) * (i + 1)
    valid_y.append(classes)

valid_y = np.concatenate(valid_y)
valid_events = np.concatenate(valid_events)
valid_masks = np.concatenate(valid_masks)

idx = np.random.permutation(len(valid_y))
valid_y = valid_y[idx]
valid_events = valid_events[idx]
valid_masks = valid_masks[idx]

valid_X = (valid_events, valid_masks)

Loading 13992 events from signal A-4_leptons
Loading 172820 events from signal h_0-tau_tau
Loading 190068 events from signal h_plus-tau_nu
Loading 85136 events from signal leptoquarks-b_tau


In [7]:
generator_input = tf.keras.Input((LATENT_DIM,))
x = tf.keras.layers.Dense(128, activation="relu")(generator_input)
x = tf.keras.layers.Dense(256, activation="relu")(x)
generator_output = tf.keras.layers.Dense(FEATURE_DIM, activation="linear")(x)
generator_mask = tf.keras.layers.Dense(num_events, activation="sigmoid")(x)
generator = tf.keras.Model(
    inputs=generator_input,
    outputs=[generator_output, generator_mask]
)

discriminator_input = tf.keras.Input((FEATURE_DIM,))
disc_x = tf.keras.layers.Dense(256, activation="relu")(discriminator_input)
disc_x = tf.keras.layers.Dense(128, activation="relu")(disc_x)
disc_x = tf.keras.layers.Dense(64, activation="relu")(disc_x)

discriminator_mask = tf.keras.Input((num_events,))
disc_mask = tf.keras.layers.Dense(256, activation="relu")(discriminator_mask)
disc_mask = tf.keras.layers.Dense(128, activation="relu")(disc_mask)
disc_mask = tf.keras.layers.Dense(64, activation="relu")(disc_mask)

disc_x = tf.keras.layers.Concatenate()([disc_x, disc_mask])
disc_x = tf.keras.layers.Dense(256, activation="relu")(disc_x)
disc_x = tf.keras.layers.Dense(512, activation="relu")(disc_x)
disc_x = tf.keras.layers.Dense(1, activation="linear")(disc_x)
discriminator = tf.keras.Model(
    inputs=[discriminator_input, discriminator_mask],
    outputs=disc_x
)

2023-07-13 18:05:13.804099: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1639] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 14605 MB memory:  -> device: 0, name: Tesla V100-SXM2-16GB, pci bus id: 0000:8a:00.0, compute capability: 7.0


Can't seem to get this custom metric to work, so using a custom training loop below. But someone smarter than I should figure this out so we can just use `Model.fit`

In [8]:
class TPR(tf.keras.metrics.Metric):
    def __init__(self, k, **kwargs):
        self.k = k
        super().__init__(**kwargs)

    def reset_state(self):
        # TODO: maybe these states need to be registered as variables?
        # Also not sure what happens with `reset_state` if we don't pass
        # the metric to `Model.compile`
        self.background_preds = tf.convert_to_tensor(())
        self.signal_preds = tf.convert_to_tensor(())

    def update_state(self, y_true, y_pred, sample_weight=None):
        background = y_pred[y_true == 0]
        self.background_preds = tf.concat([self.background_preds, background], axis=0)

        signal = y_pred[y_true == 1]
        self.signal_preds = tf.concat([self.signal_preds, signal], axis=0)

    def result(self):
        k = tf.shape(self.background_preds)[0] - self.k
        threshold = tf.sort(self.background_preds)[k]
        mask = self.signal_preds > threshold
        mask = tf.cast(mask, tf.int64)
        tpr = tf.math.reduce_mean(mask)
        return tpr

threshold_k = int(FPR_THRESH * len(valid_bg_events))

In [9]:
@tf.function
def fudge_mask(mask, noisy=True):
    mask = tf.where(mask == 0, -10., 10.)
    if noisy:
        eps = tf.random.normal(shape=(BATCH_SIZE, num_events))
        eps = eps[:tf.shape(mask)[0]]
        mask = mask + eps
    return tf.sigmoid(mask)


class GAN(tf.keras.Model):
    def __init__(self, discriminator, generator, latent_dim, gen_updates):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.gen_updates = gen_updates

        self.d_loss_tracker = tf.keras.metrics.Mean(name="d_loss")
        self.g_loss_tracker = tf.keras.metrics.Mean(name="g_loss")
        self.tpr = tf.keras.metrics.SensitivityAtSpecificity(1 - FPR_THRESH)

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def step_generator(self, batch_size):
        # Sample random points in the latent space
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Assemble labels that say "all real images"
        misleading_labels = tf.zeros((batch_size, 1))

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        with tf.GradientTape() as tape:
            predictions = discriminator(generator(random_latent_vectors))
            g_loss = loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, generator.trainable_weights)
        g_optimizer.apply_gradients(zip(grads, generator.trainable_weights))
        return g_loss

    def train_step(self, X):
        real_events, real_mask = X
        batch_size = tf.shape(real_events)[0]

        # train the generator for multiple steps
        # in between a single step of the discriminator
        g_loss = 0
        for i in range(self.gen_updates):
            g_loss += self.step_generator(batch_size)
        g_loss /= self.gen_updates

        # Sample random points in the latent space
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Decode them to fake events
        generated_events, generated_mask = generator(random_latent_vectors)

        # Combine them with real events
        combined_events = tf.concat([generated_events, real_events], axis=0)
        combined_masks = tf.concat([generated_mask, real_mask], axis=0)

        # Assemble labels discriminating real from fake events
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )

        # Add random noise to the labels - important trick!
        labels += 0.05 * tf.random.uniform((2 * batch_size, 1))

        # Train the discriminator
        with tf.GradientTape() as tape:
            predictions = discriminator([combined_events, combined_masks])
            d_loss = loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, discriminator.trainable_weights)
        d_optimizer.apply_gradients(zip(grads, discriminator.trainable_weights))

        self.d_loss_tracker.update_state(d_loss)
        self.g_loss_tracker.update_state(g_loss)
        return {
            "d_loss": self.d_loss_tracker.result(),
            "g_loss": self.g_loss_tracker.result(),
        }

    def test_step(self, data):
        # Unpack the data
        x, y = data
        events, masks = x
        masks = fudge_mask(masks, noisy=False)

        # Compute predictions
        y_pred = self.discriminator((events, masks), training=False)

        # Updates the metrics tracking the loss
        self.tpr.update_state(y, y_pred)
        return {"tpr": self.tpr.result()}

In [10]:
d_optimizer = tf.keras.optimizers.Adam(learning_rate=D_LR)
g_optimizer = tf.keras.optimizers.Adam(learning_rate=G_LR)
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=False)

gan = GAN(discriminator, generator, LATENT_DIM, GEN_UPDATES)
gan.compile(d_optimizer, g_optimizer, loss_fn)

In [11]:
events = tf.data.Dataset.from_tensor_slices(train_bg_events.astype("float32"))
masks = tf.data.Dataset.from_tensor_slices(train_bg_masks.astype("float32"))

dataset = tf.data.Dataset.zip((events, masks))
dataset = dataset.shuffle(buffer_size=1024).batch(BATCH_SIZE)
dataset = dataset.map(lambda e, m: (e, fudge_mask(m)))

losses, tprs = [], []
best_score = 0
for epoch in range(100):
    epoch_losses = [0, 0]
    with tqdm(dataset, desc=f"Epoch {epoch + 1}") as pbar:
        for step, x in enumerate(pbar):
            losses = gan.train_on_batch(*x)
            epoch_losses = [i + j for i, j in zip(losses, epoch_losses)]
            pbar.set_postfix(
                discriminator_loss=epoch_losses[0] / (step + 1) ,
                generator_loss=epoch_losses[1] / (step + 1)
            )
        
    epoch_losses = [i / len(dataset) for i in epoch_losses]
    print(
        "Discriminator Loss: {:0.3e}, Generator Loss {:0.3e}".format(
            *epoch_losses
        )
    )
    losses.append(epoch_losses)

    masks = fudge_mask(valid_masks, noisy=False)
    y_pred = discriminator((valid_events, masks), training=False).numpy()
    thresh = np.sort(y_pred[valid_y == 0])[-threshold_k]
    tpr = (y_pred[valid_y  > 0] >= thresh).mean()
    print(f"Valid TPR: {tpr:0.3e}")
    tprs.append(tpr)

    per_signal = []
    for i, signal in enumerate(signals):
        signal_tpr = (y_pred[valid_y == (i + 1)] >= thresh).mean() * 100
        per_signal.append(f"{signal}: {signal_tpr:0.3f}%")
    print("\t".join(per_signal))
        

    if tpr > best_score:
        best_score = tpr
        print("Achieved new best score! Saving weights")
        discriminator.save_weights("checkpoints/discriminator")
        generator.save_weights("checkpoints/generator")

Epoch 1:   0%|                                                                                               | 0/458 [00:00<?, ?it/s]2023-07-13 18:05:29.612507: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7f50b0d87640 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-07-13 18:05:29.612543: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Tesla V100-SXM2-16GB, Compute Capability 7.0
2023-07-13 18:05:29.620432: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:255] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2023-07-13 18:05:32.998260: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:432] Loaded cuDNN version 8600
2023-07-13 18:05:33.122702: I ./tensorflow/compiler/jit/device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
Epoch 1: 100%|███████████████████████████████

Discriminator Loss: 2.008e+00, Generator Loss 4.859e-01
Valid TPR: 1.214e-03
A-4_leptons: 3.009%	h_0-tau_tau: 0.035%	h_plus-tau_nu: 0.026%	leptoquarks-b_tau: 0.034%
Achieved new best score! Saving weights


Epoch 2: 100%|█████████████████████████████████████| 458/458 [00:21<00:00, 21.81it/s, discriminator_loss=0.703, generator_loss=0.784]


Discriminator Loss: 7.035e-01, Generator Loss 7.844e-01
Valid TPR: 7.143e-04
A-4_leptons: 1.780%	h_0-tau_tau: 0.019%	h_plus-tau_nu: 0.016%	leptoquarks-b_tau: 0.021%


Epoch 3: 100%|█████████████████████████████████████| 458/458 [00:20<00:00, 22.53it/s, discriminator_loss=-.0778, generator_loss=14.2]


Discriminator Loss: -7.779e-02, Generator Loss 1.423e+01
Valid TPR: 4.697e-04
A-4_leptons: 1.215%	h_0-tau_tau: 0.012%	h_plus-tau_nu: 0.008%	leptoquarks-b_tau: 0.014%


Epoch 4: 100%|██████████████████████████████████████| 458/458 [00:21<00:00, 21.29it/s, discriminator_loss=0.204, generator_loss=9.53]


Discriminator Loss: 2.043e-01, Generator Loss 9.530e+00
Valid TPR: 7.705e-04
A-4_leptons: 1.973%	h_0-tau_tau: 0.018%	h_plus-tau_nu: 0.017%	leptoquarks-b_tau: 0.019%


Epoch 5: 100%|█████████████████████████████████████| 458/458 [00:20<00:00, 22.03it/s, discriminator_loss=0.711, generator_loss=0.744]


Discriminator Loss: 7.109e-01, Generator Loss 7.436e-01
Valid TPR: 9.350e-04
A-4_leptons: 2.358%	h_0-tau_tau: 0.024%	h_plus-tau_nu: 0.020%	leptoquarks-b_tau: 0.026%


Epoch 6: 100%|█████████████████████████████████████| 458/458 [00:20<00:00, 22.13it/s, discriminator_loss=0.706, generator_loss=0.737]


Discriminator Loss: 7.057e-01, Generator Loss 7.372e-01
Valid TPR: 8.961e-04
A-4_leptons: 2.280%	h_0-tau_tau: 0.023%	h_plus-tau_nu: 0.017%	leptoquarks-b_tau: 0.027%


Epoch 7: 100%|██████████████████████████████████████| 458/458 [00:20<00:00, 22.15it/s, discriminator_loss=0.685, generator_loss=1.16]


Discriminator Loss: 6.852e-01, Generator Loss 1.156e+00
Valid TPR: 9.805e-04
A-4_leptons: 2.473%	h_0-tau_tau: 0.025%	h_plus-tau_nu: 0.021%	leptoquarks-b_tau: 0.028%


Epoch 8: 100%|██████████████████████████████████████| 458/458 [00:22<00:00, 20.66it/s, discriminator_loss=0.71, generator_loss=0.768]


Discriminator Loss: 7.099e-01, Generator Loss 7.682e-01
Valid TPR: 1.151e-03
A-4_leptons: 2.923%	h_0-tau_tau: 0.029%	h_plus-tau_nu: 0.024%	leptoquarks-b_tau: 0.033%


Epoch 9: 100%|█████████████████████████████████████| 458/458 [00:20<00:00, 22.10it/s, discriminator_loss=0.703, generator_loss=0.737]


Discriminator Loss: 7.029e-01, Generator Loss 7.371e-01
Valid TPR: 1.208e-03
A-4_leptons: 3.080%	h_0-tau_tau: 0.029%	h_plus-tau_nu: 0.025%	leptoquarks-b_tau: 0.034%


Epoch 10: 100%|████████████████████████████████████| 458/458 [00:20<00:00, 22.58it/s, discriminator_loss=0.707, generator_loss=0.774]


Discriminator Loss: 7.067e-01, Generator Loss 7.744e-01
Valid TPR: 1.056e-03
A-4_leptons: 2.680%	h_0-tau_tau: 0.025%	h_plus-tau_nu: 0.023%	leptoquarks-b_tau: 0.031%


Epoch 11: 100%|████████████████████████████████████| 458/458 [00:21<00:00, 21.75it/s, discriminator_loss=0.701, generator_loss=0.742]


Discriminator Loss: 7.012e-01, Generator Loss 7.416e-01
Valid TPR: 1.128e-03
A-4_leptons: 2.823%	h_0-tau_tau: 0.030%	h_plus-tau_nu: 0.025%	leptoquarks-b_tau: 0.032%


Epoch 12: 100%|████████████████████████████████████| 458/458 [00:20<00:00, 22.66it/s, discriminator_loss=0.687, generator_loss=0.799]


Discriminator Loss: 6.871e-01, Generator Loss 7.985e-01
Valid TPR: 1.193e-03
A-4_leptons: 2.995%	h_0-tau_tau: 0.032%	h_plus-tau_nu: 0.026%	leptoquarks-b_tau: 0.032%


Epoch 13: 100%|█████████████████████████████████████| 458/458 [00:21<00:00, 20.89it/s, discriminator_loss=0.577, generator_loss=1.35]


Discriminator Loss: 5.771e-01, Generator Loss 1.346e+00
Valid TPR: 7.467e-04
A-4_leptons: 1.787%	h_0-tau_tau: 0.020%	h_plus-tau_nu: 0.021%	leptoquarks-b_tau: 0.023%


Epoch 14: 100%|██████████████████████████████████████| 458/458 [00:20<00:00, 22.48it/s, discriminator_loss=0.46, generator_loss=1.82]


Discriminator Loss: 4.596e-01, Generator Loss 1.815e+00
Valid TPR: 7.900e-04
A-4_leptons: 1.937%	h_0-tau_tau: 0.020%	h_plus-tau_nu: 0.022%	leptoquarks-b_tau: 0.022%


Epoch 15: 100%|█████████████████████████████████████| 458/458 [00:20<00:00, 22.84it/s, discriminator_loss=0.266, generator_loss=2.11]


Discriminator Loss: 2.665e-01, Generator Loss 2.108e+00
Valid TPR: 7.034e-04
A-4_leptons: 1.708%	h_0-tau_tau: 0.019%	h_plus-tau_nu: 0.019%	leptoquarks-b_tau: 0.019%


Epoch 16: 100%|█████████████████████████████████████| 458/458 [00:20<00:00, 22.77it/s, discriminator_loss=0.188, generator_loss=8.69]


Discriminator Loss: 1.879e-01, Generator Loss 8.688e+00
Valid TPR: 7.446e-04
A-4_leptons: 1.830%	h_0-tau_tau: 0.017%	h_plus-tau_nu: 0.018%	leptoquarks-b_tau: 0.028%


Epoch 17: 100%|█████████████████████████████████████| 458/458 [00:20<00:00, 22.52it/s, discriminator_loss=-.131, generator_loss=15.2]


Discriminator Loss: -1.308e-01, Generator Loss 1.525e+01
Valid TPR: 6.472e-04
A-4_leptons: 1.644%	h_0-tau_tau: 0.012%	h_plus-tau_nu: 0.014%	leptoquarks-b_tau: 0.026%


Epoch 18: 100%|██████████████████████████████████████| 458/458 [00:21<00:00, 20.92it/s, discriminator_loss=0.44, generator_loss=5.88]


Discriminator Loss: 4.398e-01, Generator Loss 5.880e+00
Valid TPR: 7.705e-04
A-4_leptons: 1.865%	h_0-tau_tau: 0.021%	h_plus-tau_nu: 0.020%	leptoquarks-b_tau: 0.025%


Epoch 19: 100%|████████████████████████████████████| 458/458 [00:22<00:00, 20.76it/s, discriminator_loss=0.688, generator_loss=0.948]


Discriminator Loss: 6.885e-01, Generator Loss 9.477e-01
Valid TPR: 8.030e-04
A-4_leptons: 1.930%	h_0-tau_tau: 0.021%	h_plus-tau_nu: 0.023%	leptoquarks-b_tau: 0.025%


Epoch 20: 100%|█████████████████████████████████████| 458/458 [00:20<00:00, 22.29it/s, discriminator_loss=0.645, generator_loss=1.48]


Discriminator Loss: 6.451e-01, Generator Loss 1.478e+00
Valid TPR: 7.814e-04
A-4_leptons: 1.872%	h_0-tau_tau: 0.023%	h_plus-tau_nu: 0.022%	leptoquarks-b_tau: 0.021%


Epoch 21:   5%|██                                     | 24/458 [00:02<00:43, 10.08it/s, discriminator_loss=0.65, generator_loss=1.29]


KeyboardInterrupt: 

Now load in the best version of our discriminator and visualize some of its predictions

In [12]:
discriminator.load_weights("checkpoints/discriminator")

masks = fudge_mask(valid_masks, noisy=False)
y_pred = discriminator((valid_events, masks), training=False).numpy()

In [13]:
p = figure(
    height=300,
    width=700,
    y_axis_type="log",
    x_axis_label="Discriminator Score",
    y_axis_label="Fraction in bin"
)
p.toolbar_location = None

# only look at predictions above the FPR threshold
bg_preds = y_pred[valid_y == 0]
thresh = np.percentile(bg_preds, 100 * (1 - FPR_THRESH))
bg_preds = bg_preds[bg_preds > thresh]
bins = np.linspace(thresh, y_pred.max(), 41)

bg, _ = np.histogram(bg_preds, bins=bins)
bg = bg / (valid_y == 0).sum()

def plot_hist(values, label, color):
    p.vbar(
        x=centers,
        top=np.clip(values, 1e-7, 1),
        bottom=1e-7,
        width=width,
        fill_color=color,
        fill_alpha=0.3,
        line_color="#111111",
        line_width=1.5,
        legend_label=label
    )
    

centers = (bins[:-1] + bins[1:]) / 2
width = bins[1] - bins[0]
plot_hist(bg, "Background", palette[0])

for i, signal in enumerate(signals):
    mask = valid_y == (i + 1)
    fg_preds = y_pred[mask]
    fg_preds = fg_preds[fg_preds > thresh]
    hist, _ = np.histogram(fg_preds, bins)
    hist = hist / mask.sum()
    plot_hist(hist, signal, palette[i + 1])

p.legend.click_policy = "hide"
show(p)