In [1]:
import subprocess
from pathlib import Path

import h5py
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import tensorflow as tf

BATCH_SIZE = 8192
LATENT_DIM = 32
VALID_FRAC = 0.25
SEED = 101588
NUM_TRAIN_SAMPLES = 5000000

GEN_UPDATES = 8
D_LR = 0.0003
G_LR = 0.0008
FPR_THRESH = 1e-5

BACKGROUND_FNAME = Path("background.h5")
SIGNAL_FNAMES = {
    "A-4_leptons": "https://zenodo.org/record/7152590/files/Ato4l_lepFilter_13TeV_filtered.h5?download=1",
    "leptoquarks-b_tau": "https://zenodo.org/record/7152599/files/leptoquark_LOWMASS_lepFilter_13TeV_filtered.h5?download=1",
    "h_0-tau_tau": "https://zenodo.org/record/7152614/files/hToTauTau_13TeV_PU20_filtered.h5?download=1",
    "h_plus-tau_nu": "https://zenodo.org/record/7152617/files/hChToTauNu_13TeV_PU20_filtered.h5?download=1"
}

2023-07-13 15:50:20.079502: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
def download_dataset(fname, url):
    subprocess.run(f"wget -O {fname} {url}", shell=True)
    
if not BACKGROUND_FNAME.exists():
    download_dataset(
        str(BACKGROUND_FNAME),
        "https://zenodo.org/record/5046428/files/background_for_training.h5?download=1"
    )

for signal, url in SIGNAL_FNAMES.items():
    fname = signal + ".h5"
    if not Path(fname).exists():
        download_dataset(fname, url)

In [3]:
def load_dataset(dataset, N):
    X = dataset[:N]
    X, y = np.split(X, [3], axis=-1)
    return X, y[:, :, 0]


with h5py.File(BACKGROUND_FNAME, "r") as f:
    print("Features:", ", ".join([i.decode() for i in f["Particles_Names"][:]]))
    print("Event types:", ", ".join([i.decode() for i in f["Particles_Classes"][:]]))
    print("Total background events:", len(f["Particles"]))
    print(f"Loading {NUM_TRAIN_SAMPLES} events")
    X, masks = load_dataset(f["Particles"], NUM_TRAIN_SAMPLES)

_, num_events, num_features = X.shape
FEATURE_DIM = num_events * num_features
X = X.reshape(-1, FEATURE_DIM)

Features: Pt, Eta, Phi, Class
Event types: MET_class_1, Four_Ele_class_2, Four_Mu_class_3, Ten_Jet_class_4
Total background events: 13451915
Loading 5000000 events


In [4]:
(
    train_bg_events,
    valid_bg_events,
    train_bg_masks,
    valid_bg_masks
) = train_test_split(X, masks, test_size=VALID_FRAC, random_state=SEED)

In [5]:
scaler = StandardScaler()
scaler.fit(train_bg_events)

def preprocess(X, mask):
    X = scaler.transform(X)
    X = X.reshape(-1, num_events, num_features)
    X[mask == 0] *= 0
    X = X.reshape(-1, FEATURE_DIM)
    return X

train_bg_events = preprocess(train_bg_events, train_bg_masks)
valid_bg_events = preprocess(valid_bg_events, valid_bg_masks)

In [6]:
valid_signal_events = []
valid_signal_masks = []

for signal in SIGNAL_FNAMES:
    with h5py.File(signal + ".h5", "r") as f:
        dataset = f["Particles"]
        n = int(0.25 * len(dataset))
        print(f"Loading {n} events from signal {signal}")
        events, masks = load_dataset(dataset, n)
    events = events.reshape(-1, FEATURE_DIM)
    events = preprocess(events, masks)
    valid_signal_events.append(events)
    valid_signal_masks.append(masks)

valid_signal_events = np.concatenate(valid_signal_events)
valid_signal_masks = np.concatenate(valid_signal_masks)

valid_y = np.concatenate([
    np.zeros((len(valid_bg_events), 1)),
    np.ones((len(valid_signal_events), 1))
])
valid_events = np.concatenate([valid_bg_events, valid_signal_events])
valid_masks = np.concatenate([valid_bg_masks, valid_signal_masks])

idx = np.random.permutation(len(valid_y))
valid_y = valid_y[idx]
valid_events = valid_events[idx]
valid_masks = valid_masks[idx]
valid_X = (valid_events, valid_masks)

Loading 13992 events from signal A-4_leptons
Loading 85136 events from signal leptoquarks-b_tau
Loading 172820 events from signal h_0-tau_tau
Loading 190068 events from signal h_plus-tau_nu


In [7]:
generator_input = tf.keras.Input((LATENT_DIM,))
x = tf.keras.layers.Dense(128, activation="relu")(generator_input)
x = tf.keras.layers.Dense(256, activation="relu")(x)
generator_output = tf.keras.layers.Dense(FEATURE_DIM, activation="linear")(x)
generator_mask = tf.keras.layers.Dense(num_events, activation="sigmoid")(x)
generator = tf.keras.Model(
    inputs=generator_input,
    outputs=[generator_output, generator_mask]
)

discriminator_input = tf.keras.Input((FEATURE_DIM,))
disc_x = tf.keras.layers.Dense(256, activation="relu")(discriminator_input)
disc_x = tf.keras.layers.Dense(128, activation="relu")(disc_x)
disc_x = tf.keras.layers.Dense(64, activation="relu")(disc_x)

discriminator_mask = tf.keras.Input((num_events,))
disc_mask = tf.keras.layers.Dense(256, activation="relu")(discriminator_mask)
disc_mask = tf.keras.layers.Dense(128, activation="relu")(disc_mask)
disc_mask = tf.keras.layers.Dense(64, activation="relu")(disc_mask)

disc_x = tf.keras.layers.Concatenate()([disc_x, disc_mask])
disc_x = tf.keras.layers.Dense(256, activation="relu")(disc_x)
disc_x = tf.keras.layers.Dense(512, activation="relu")(disc_x)
disc_x = tf.keras.layers.Dense(1, activation="linear")(disc_x)
discriminator = tf.keras.Model(
    inputs=[discriminator_input, discriminator_mask],
    outputs=disc_x
)

2023-07-13 15:50:55.559566: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1960] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


Can't seem to get this custom metric to work, so using a custom training loop below. But someone smarter than I should figure this out so we can just use `Model.fit`

In [8]:
class TPR(tf.keras.metrics.Metric):
    def __init__(self, k, **kwargs):
        self.k = k
        super().__init__(**kwargs)

    def reset_state(self):
        # TODO: maybe these states need to be registered as variables?
        # Also not sure what happens with `reset_state` if we don't pass
        # the metric to `Model.compile`
        self.background_preds = tf.convert_to_tensor(())
        self.signal_preds = tf.convert_to_tensor(())

    def update_state(self, y_true, y_pred, sample_weight=None):
        background = y_pred[y_true == 0]
        self.background_preds = tf.concat([self.background_preds, background], axis=0)

        signal = y_pred[y_true == 1]
        self.signal_preds = tf.concat([self.signal_preds, signal], axis=0)

    def result(self):
        k = tf.shape(self.background_preds)[0] - self.k
        threshold = tf.sort(self.background_preds)[k]
        mask = self.signal_preds > threshold
        mask = tf.cast(mask, tf.int64)
        tpr = tf.math.reduce_mean(mask)
        return tpr

threshold_k = int(FPR_THRESH * len(valid_bg_events))

In [9]:
class Callback(tf.keras.callbacks.Callback):
    def on_epoch_start(self, epoch, logs):
        self.model.tpr.reset_state()


@tf.function
def fudge_mask(mask, noisy=True):
    mask = tf.where(mask == 0, -10., 10.)
    if noisy:
        mask = mask + tf.random.normal(shape=mask.shape)
    return tf.sigmoid(mask)


class GAN(tf.keras.Model):
    def __init__(self, discriminator, generator, latent_dim, gen_updates):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.gen_updates = gen_updates

        self.d_loss_tracker = tf.keras.metrics.Mean(name="d_loss")
        self.g_loss_tracker = tf.keras.metrics.Mean(name="g_loss")
        self.tpr = tf.keras.metrics.SensitivityAtSpecificity(1 - FPR_THRESH)

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def step_generator(self, batch_size):
        # Sample random points in the latent space
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Assemble labels that say "all real images"
        misleading_labels = tf.zeros((batch_size, 1))

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        with tf.GradientTape() as tape:
            predictions = discriminator(generator(random_latent_vectors))
            g_loss = loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, generator.trainable_weights)
        g_optimizer.apply_gradients(zip(grads, generator.trainable_weights))
        return g_loss

    def train_step(self, X):
        real_events, real_mask = X
        batch_size = tf.shape(real_events)[0]

        # train the generator for multiple steps
        # in between a single step of the discriminator
        g_loss = 0
        for i in range(self.gen_updates):
            g_loss += self.step_generator(batch_size)
        g_loss /= self.gen_updates

        # Sample random points in the latent space
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Decode them to fake events
        generated_events, generated_mask = generator(random_latent_vectors)

        # Combine them with real events
        combined_events = tf.concat([generated_events, real_events], axis=0)
        combined_masks = tf.concat([generated_mask, real_mask], axis=0)

        # Assemble labels discriminating real from fake events
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )

        # Add random noise to the labels - important trick!
        labels += 0.05 * tf.random.uniform((2 * batch_size, 1))

        # Train the discriminator
        with tf.GradientTape() as tape:
            predictions = discriminator([combined_events, combined_masks])
            d_loss = loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, discriminator.trainable_weights)
        d_optimizer.apply_gradients(zip(grads, discriminator.trainable_weights))

        self.d_loss_tracker.update_state(d_loss)
        self.g_loss_tracker.update_state(g_loss)
        return {
            "d_loss": self.d_loss_tracker.result(),
            "g_loss": self.g_loss_tracker.result(),
        }

    def test_step(self, data):
        # Unpack the data
        x, y = data
        events, masks = x
        masks = fudge_mask(masks, noisy=False)

        # Compute predictions
        y_pred = self.discriminator((events, masks), training=False)

        # Updates the metrics tracking the loss
        self.tpr.update_state(y, y_pred)
        return {"tpr": self.tpr.result()}

In [10]:
d_optimizer = tf.keras.optimizers.Adam(learning_rate=D_LR)
g_optimizer = tf.keras.optimizers.Adam(learning_rate=G_LR)
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=False)

gan = GAN(discriminator, generator, LATENT_DIM, GEN_UPDATES)
gan.compile(d_optimizer, g_optimizer, loss_fn)

In [11]:
from tqdm import tqdm
from contextlib import contextmanager


@contextmanager
def progbar(dataset, epoch):
    with tqdm(dataset, desc=f"Epoch {epoch + 1}") as pbar:
        yield pbar

In [None]:
from tqdm import tqdm

events = tf.data.Dataset.from_tensor_slices(train_bg_events.astype("float32"))
masks = tf.data.Dataset.from_tensor_slices(train_bg_masks.astype("float32"))
masks = masks.map(fudge_mask)

dataset = tf.data.Dataset.zip((events, masks))
dataset = dataset.shuffle(buffer_size=1024).batch(BATCH_SIZE)

losses, tprs = [], []
for epoch in range(100):
    epoch_losses = [0, 0]
    with tqdm(dataset, desc=f"Epoch {epoch + 1}") as pbar:
        for step, x in enumerate(pbar):
            losses = gan.train_on_batch(*x)
            epoch_losses = [i + j for i, j in zip(losses, epoch_losses)]
            pbar.set_postfix(
                discriminator_loss=epoch_losses[0] / (step + 1) ,
                generator_loss=epoch_losses[1] / (step + 1)
            )
        
    epoch_losses = [i / len(dataset) for i in epoch_losses]
    print(
        "Discriminator Loss: {:0.3e}, Generator Loss {:0.3e}".format(
            *epoch_losses
        )
    )
    losses.append(epoch_losses)

    masks = fudge_mask(valid_masks, noisy=False)
    y_pred = discriminator((valid_events, masks), training=False).numpy()
    thresh = np.sort(y_pred[valid_y == 0])[-threshold_k]
    tpr = (y_pred[valid_y == 1] >= thresh).mean()
    print(f"Valid TPR: {tpr:0.3e}")
    tprs.append(tpr)

Epoch 1: 100%|█████████████████████████████████████| 458/458 [02:39<00:00,  2.88it/s, discriminator_loss=0.722, generator_loss=0.712]


Discriminator Loss: 7.218e-01, Generator Loss 7.116e-01:
Valid TPR: 1.273e-03


Epoch 2:   3%|▉                                     | 12/458 [00:04<02:26,  3.05it/s, discriminator_loss=0.717, generator_loss=0.699]