<div style="margin-top:5%;margin-left:10%">
<h1 style="background-color:#a31f34;color:#eeeeee;margin-right:10%;border-radius:15px">
    Building a Production ML Pipeline For Gravitational Wave Detection
</h1>

<div style="color: #333333;font-size: 24pt">Alec Gunny$^{*,1}$, Ethan Marx$^1$, William Benoit$^2$</div>


<p style="font-size:14pt;padding-top:3%">
<sub><sup>1 - Massachusetts Institute of Technology</sup></sub>
</p>
<p style="font-size:14pt;margin-top:-1%">
<sub><sup>2 - University of Minnesota</sup></sub>
</p>
</div>

## Some implementation notes
- Not live, but everything you see was run in one fell swoop
- Show what matters, hide what doesn't
- Implementations meant to be illustrative, not necessarily optimal
- "What about..." Great question! Good tools let us answer good questions with more confidence

## Some implementation notes
Training will be handled by [`lightning`](https://lightning.ai/).
- Handling classes in slides is hard. You'll see a lot of this:

In [1]:
class BaseClass:
    def __init__(self, foo: str):
        self.foo = foo

In [2]:
class ExtendedClass(BaseClass):
    def repeat_foo(self, repeats):
        return self.foo * repeats

**NOT** a demo about how to use this DL framework - about a generic set of tools to augment *any* DL framework.
- Focus will be on the latter, will skim details of the former

## Gravitational wave astronomy

In [3]:
import h5py
import numpy as np

from utils import plotting

ifos = ["H1", "L1"]
sample_rate = 2048

## 3 minute crash course in gravitational wave data processing

In [4]:
! ls data

background.hdf5  cache.hdf5  signals.hdf5


`background.hdf5`

Real, open data observed by the Hanford (H1) and Livingston (L1) interferometers between April 1st and April 22nd 2019
- 1 week for training/validation, 2 weeks for test

In [5]:
with h5py.File("data/background.hdf5", "r") as f:
    dataset = f["train"]["1238175433-17136"]  # start timestamp of segment-duration
    background = {i: dataset[i][:sample_rate * 10] for i in ifos} # plot first 10s
t = np.arange(sample_rate * 10) / sample_rate
plotting.plot_timeseries(t, **background)

`signals.hdf5`

Simulated waveforms generated by gravitational wave events along with the parameters of those events
- Train and validation datasets contain raw waveform **polarizations**

In [6]:
with h5py.File("data/signals.hdf5") as f:
    dataset = f["train"]["polarizations"]
    signal = {i: dataset[i][440] for i in ["cross", "plus"]}
t = np.arange(4 * sample_rate) / sample_rate
plotting.plot_timeseries(t, legend_location="top_left", **signal)

- Interferometers act as antennae that respond linearly to polarizations
- Function of relative locations/orientations of detectors and source

<img src="images/orientations.png" width=50% style="display: block;margin-left:auto;margin-right:auto"></img>

In [7]:
from pycbc.detector import Detector

sky_params = ["declination", "right_ascension", "polarization"]
with h5py.File("data/signals.hdf5", "r") as f:
    params = {i: f[f"train/parameters/{i}"][0] for i in sky_params}

responses = {}
for ifo in ifos:
    detector = Detector(ifo)
    fp, fc = detector.antenna_pattern(t_gps=1238175433, **params)
    ht = fp * signal["plus"] + fc * signal["cross"]  # observed strain h(t)
    responses[ifo] = ht
plotting.plot_timeseries(t, legend_location="top_left", **responses)

- Projection from polarization $\rightarrow$ response introduces
    - phase shifts due to differences in arrival times
    - differences in relative amplitudes due to slight differences in polarization
- Test signals have been pre-projected so that we can analyze them as if they were real events
- All signals have been rejection sampled to ensure their **signal-to-noise ratio** (SNR) is $\geq$ 4
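
The deck doesn't show how the SNR was computed, so here is a minimal sketch using the standard optimal matched-filter definition, $\rho^2 = 4\int |\tilde{h}(f)|^2 / S_n(f)\,df$. The flat PSD and the sinusoidal "waveform" are illustrative assumptions, and `optimal_snr` is a hypothetical helper, not a function from the GW software stack.

```python
import numpy as np

def optimal_snr(h, psd, sample_rate):
    """Optimal SNR: rho^2 = 4 * sum(|h(f)|^2 / S_n(f)) * df."""
    hf = np.fft.rfft(h) / sample_rate  # approximate the continuous Fourier transform
    df = sample_rate / len(h)          # frequency resolution in Hz
    return np.sqrt(4 * np.sum(np.abs(hf) ** 2 / psd) * df)

sample_rate = 2048
t = np.arange(4 * sample_rate) / sample_rate
h = 1e-23 * np.sin(2 * np.pi * 100 * t)  # toy waveform (assumption)

# flat one-sided PSD, purely illustrative: real detector noise is strongly colored
psd = np.full(len(h) // 2 + 1, 1e-46)

snr = optimal_snr(h, psd, sample_rate)
accepted = snr >= 4  # rejection sampling keeps only signals above threshold

# SNR is linear in amplitude, so a waveform can be rescaled to any target SNR
target_snr = 8.0
rescaled = h * target_snr / snr
```

Because the SNR scales linearly with amplitude, choosing a target SNR is equivalent to choosing a distance to the source.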

## Back to detection
What examples can we give a neural network that will help it learn to detect the presence of a signal?
- "Loudest" part of signal is in last 0.5-1 seconds, near the **coalescence**
- Signals add simply to background noise, i.e. $h(t) = n(t) + s(t)$
- So why don't we:
    - Take short windows of background, say 1s
    - Add simulated/projected signals to ~50% of them
    - Train a binary classification network on this data
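
The recipe above can be sketched with stand-in data. Everything here is a toy assumption: Gaussian noise plays the role of background and a growing sinusoid plays the role of the projected signal.

```python
import numpy as np

rng = np.random.default_rng(0)
sample_rate = 2048
num_windows, num_ifos = 8, 2

# stand-in data: Gaussian noise for background, a growing sinusoid for the signal
background = rng.normal(size=(num_windows, num_ifos, sample_rate))
t = np.arange(sample_rate) / sample_rate
signal = 0.1 * np.sin(2 * np.pi * 100 * t) * np.exp(4 * (t - 1))

# pick roughly half of the 1s windows and inject: h(t) = n(t) + s(t)
inject = rng.random(num_windows) < 0.5
X = background.copy()
X[inject] += signal  # broadcasts the same signal across both interferometers

# binary labels for the classifier: 1 where we injected, 0 otherwise
y = inject.astype(np.float32)[:, None]
```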

What would this data look like? What does the network "see"?

In [8]:
injected, uninjected = {}, {}
for ifo in ifos:
    # I'm actually going to grab 2 seconds for
    # reasons that will become clear momentarily
    bg = background[ifo][-2 * sample_rate:]
    uninjected[ifo] = bg
    injected[ifo] = bg.copy()
    injected[ifo][:sample_rate] += responses[ifo][-sample_rate:]
t = np.arange(2 * sample_rate) / sample_rate
plotting.plot_side_by_side(uninjected, injected, t, titles=["Before injection", "After injection"])

No major difference - but this shouldn't be surprising
- Background strain is $\mathcal{O}(10^{-19})$, waveforms are $\mathcal{O}(10^{-23})$
- Well we tried

Thank you!

- Most of background is low frequency content - 10-30Hz
- Most of signal is in 60-500Hz
- Emphasize signal by **whitening** the data: normalize frequency content by amplitude spectral density (ASD) of background
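
A toy illustration of why whitening helps, assuming a made-up two-tone strain and a made-up step-function ASD. In practice the ASD is estimated from the data, and the frequencies and amplitudes here are chosen only to make the effect obvious.

```python
import numpy as np

sample_rate = 2048
t = np.arange(2 * sample_rate) / sample_rate

# toy strain: a loud 20 Hz "background" line plus a 1000x quieter 200 Hz "signal"
x = 1000 * np.sin(2 * np.pi * 20 * t) + np.sin(2 * np.pi * 200 * t)

# toy ASD (assumed, not measured): noise is 1000x louder below 50 Hz
freqs = np.fft.rfftfreq(len(x), d=1 / sample_rate)
asd = np.where(freqs < 50, 1000.0, 1.0)

# whiten: divide every frequency bin by the ASD, then return to the time domain
whitened = np.fft.irfft(np.fft.rfft(x) / asd, n=len(x))

def amplitude_at(series, freq):
    """Spectral magnitude of `series` in the bin closest to `freq` Hz."""
    return np.abs(np.fft.rfft(series))[np.argmin(np.abs(freqs - freq))]

ratio_before = amplitude_at(x, 20) / amplitude_at(x, 200)
ratio_after = amplitude_at(whitened, 20) / amplitude_at(whitened, 200)
```

Before whitening the low-frequency content dominates by three orders of magnitude. Afterwards both components sit on equal footing.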

In [9]:
from gwpy.timeseries import TimeSeries

asd_length, fftlength = 8, 2
for ifo in ifos:
    asd_bg = background[ifo][-(asd_length + 1) * sample_rate:-sample_rate]
    asd = TimeSeries(asd_bg, sample_rate=sample_rate).asd(fftlength, method="median")
    for src in [injected, uninjected]:
        x = TimeSeries(src[ifo], sample_rate=sample_rate).whiten(asd=asd, fduration=1)
        x = x.crop(0.5, 1.5)  # edges are corrupted from filter settle-in
        src[ifo] = x.value

In [10]:
t = t[sample_rate // 2: -sample_rate // 2]
plotting.plot_side_by_side(uninjected, injected, t, titles=["Before injection", "After injection"])

So a quick review of how we'll generate the samples on which to train our neural network:

<img src="images/sample-flow.png" width=50% style="display: block;margin-left:auto;margin-right:auto"></img>

## Training the network
<img src="images/sample-flow.png" width=50% style="display: block;margin-left:auto;margin-right:auto"></img>

How do we do this in practice?
- Too much background to fit in memory at once
- Not sure if we can do it in real-time
- Start by generating fixed train and validation datasets up front, then fit on these

- `make_dataset` will take care of these steps using the traditional GW software stack
- Don't need to worry about the details, but clock the throughput

In [None]:
from utils.data import make_dataset

background_f = h5py.File("data/background.hdf5")
signal_f = h5py.File("data/signals.hdf5")
with background_f, signal_f:
    datasets = {}
    for split in ["train", "valid"]:
        datasets[split] = make_dataset(
            ifos,
            background_f[split],
            signal_f[split],
            kernel_length=1,
            fduration=1,
            psd_length=8,
            fftlength=2,
            sample_rate=sample_rate,
            highpass=32
        )

In [11]:
# with h5py.File("data/cache.hdf5", "w") as f:
#     for split, value in datasets.items():
#         f[split] = value
with h5py.File("data/cache.hdf5", "r") as f:
    datasets = {k: v[:] for k, v in f.items()}
ys = {k: np.zeros((len(v), 1)) for k, v in datasets.items()}
for split in ["train", "valid"]:
    ys[split][1::2] = 1
datasets = {k: (v, ys[k]) for k, v in datasets.items()}

What does our data look like now?

In [12]:
for split in ["train", "valid"]:
    X, y = datasets[split]
    num_signal = (y == 1).sum()
    num_background = (y == 0).sum()
    print("{} samples of shape {} in {} split, {} signal and {} background".format(
        len(X), X.shape[1:], split, num_signal, num_background
    ))

100000 samples of shape (2, 2048) in train split, 50000 signal and 50000 background
20000 samples of shape (2, 2048) in valid split, 10000 signal and 10000 background


Start by defining a simple `lightning` model that will train a 1D ResNet architecture on our dataset.

In [13]:
import torch
from lightning import pytorch as pl
from torchmetrics.classification import BinaryAUROC
from utils.nn import ResNet

class BbhDetectionBase(pl.LightningModule):
    def __init__(
        self,
        learning_rate: float = 0.001,
        batch_size: int = 512,
        max_fpr: float = 1e-2  # only measure ourselves on FPRs close to where we'll operate
    ) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.nn = ResNet(len(ifos), layers=[2, 3, 4, 2])
        self.metric = BinaryAUROC(max_fpr=max_fpr)

    def forward(self, X):
        return self.nn(X)

Define the training and validation steps

In [14]:
class BbhDetectionWithStep(BbhDetectionBase):
    def training_step(self, batch):
        X, y = batch
        y_hat = self(X)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(y_hat, y)
        self.log("train_loss", loss, on_step=True, prog_bar=True)
        return loss

    def validation_step(self, batch):
        X, y = batch
        y_hat = self(X)
        self.metric.update(y_hat, y)
        self.log("valid_auroc", self.metric, on_epoch=True, prog_bar=True)

Add an optimizer. Let's even throw in a scheduler so we can use larger batch sizes comfortably

In [15]:
class BbhDetectionWithOptimizers(BbhDetectionWithStep):
    def configure_optimizers(self):
        parameters = self.nn.parameters()
        optimizer = torch.optim.AdamW(parameters, self.hparams.learning_rate)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            self.hparams.learning_rate,
            pct_start=0.1,
            total_steps=self.trainer.estimated_stepping_batches
        )
        scheduler_config = dict(scheduler=scheduler, interval="step")
        return dict(optimizer=optimizer, lr_scheduler=scheduler_config)

    def configure_callbacks(self):
        chkpt = pl.callbacks.ModelCheckpoint(monitor="valid_auroc", mode="max")
        return [chkpt]

Finally define some simple data iterators. Still no magic.

In [16]:
class BbhDetectionModel(BbhDetectionWithOptimizers):
    def make_dataset(self, split):
        X, y = datasets[split]
        X, y = torch.Tensor(X), torch.Tensor(y)
        return torch.utils.data.TensorDataset(X, y)

    def train_dataloader(self):
        dataset = self.make_dataset("train")
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            shuffle=True,
            pin_memory=True
        )

    def val_dataloader(self):
        dataset = self.make_dataset("valid")
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.hparams.batch_size * 4,
            shuffle=False,
            pin_memory=True
        )

And that's all it takes. Let's train on our static dataset and see how we perform

In [17]:
model = BbhDetectionModel(batch_size=1024, learning_rate=0.02)
trainer = pl.Trainer(
    max_epochs=20,
    precision="16-mixed",
    log_every_n_steps=5,
    logger=pl.loggers.CSVLogger("logs", name="vanilla-expt"),
    callbacks=[pl.callbacks.RichProgressBar()]
)
trainer.fit(model)

Output()

And now we can look at the loss curves for some reason

In [18]:
plotting.plot_run("vanilla-expt", 65)

## Mission accomplished?
- So it looks like we perform... well? Who knows, we'll address inference/evaluation later
- Consider all the ways we threw out data/priors/physics to make this work
    - Didn't use even close to all our background
    - Even worse when you consider we can _shift_ IFOs wrt one another
    - Only got to observe waveforms from one sky location/distance
    - Only got to observe waveforms inserted in one particular noise background
- We could just generate a _larger_ dataset, but that just kicks the can down the road

<img src="images/sample-flow.png" width=50% style="display: block;margin-left:auto;margin-right:auto"></img>
What if we did this _in real time_ during training?
- Take advantage of our data and physics to build more robust models
    - Our data generation throughput was ~40 samples/s
    - Our NN throughput was ~3500 samples/s
    - Even if we get a lot faster, existing tools insufficient for real-time use
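
To put the two throughput figures side by side, here is some back-of-the-envelope arithmetic using only the numbers quoted in this deck (100,000 training samples, ~40 samples/s generation, ~3500 samples/s through the network):

```python
train_samples = 100_000  # size of the static training set
generation_rate = 40     # samples/s from the offline pipeline
network_rate = 3500      # samples/s consumed by the network

generation_time = train_samples / generation_rate  # seconds to build the dataset
epoch_time = train_samples / network_rate          # seconds to train on it once
required_speedup = network_rate / generation_rate  # to keep the GPU fed in real time

print(
    f"generate: {generation_time / 60:.0f} min, "
    f"one epoch: {epoch_time:.0f} s, "
    f"speedup needed: {required_speedup:.1f}x"
)
```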

## Enter `ml4gw`
Library of `torch` utilities for common GW tasks/transforms
- Align with existing APIs
- `pip` installable
- GPU accelerated, tensor-ized operations ensure efficient utilization
- Auto-differentiation means we can take gradients through ops - build physics into models

In [19]:
import ml4gw

## Enter `ml4gw`
Let's re-implement our sample-generation code using `ml4gw` dataloaders and transforms

Start by clearing out the GPU:

In [20]:
import gc

def flush():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

flush()

## Training with `ml4gw`

Start by defining a new model that will generate samples in real time.

In [21]:
from ml4gw import transforms

class Ml4gwDetectionBase(BbhDetectionModel):
    def __init__(
        self,
        ifos: list[str],
        kernel_length: float,
        fduration: float,
        psd_length: float,
        sample_rate: float,
        fftlength: float,
        chunk_length: float = 128,  # we'll talk about chunks in a second
        reads_per_chunk: int = 40,
        highpass: float = 32,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)

        # real-time transformations defined as torch Modules
        self.spectral_density = transforms.SpectralDensity(
            sample_rate, fftlength, average="median", fast=True
        )
        self.whitener = transforms.Whiten(
            fduration, sample_rate, highpass=highpass
        )

Add some non-trainable parameters representing geometric information about our interferometers that we'll need for projection

In [22]:
from ml4gw import gw

class Ml4gwDetectionWithIfoGeometry(Ml4gwDetectionBase):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # get some geometry information about
        # the interferometers we're going to project to
        detector_tensors, vertices = gw.get_ifo_geometry(*self.hparams.ifos)
        self.register_buffer("detector_tensors", detector_tensors)
        self.register_buffer("detector_vertices", vertices)

We'll also need some distributions we can sample sky parameters from

In [23]:
from ml4gw import distributions

class Ml4gwDetectionWithSkyParamDistributions(Ml4gwDetectionWithIfoGeometry):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # define some sky parameter distributions
        self.declination = distributions.Cosine()
        self.polarization = distributions.Uniform(0, torch.pi)
        self.phi = distributions.Uniform(-torch.pi, torch.pi)  # relative RAs of detector and source

        # rather than sample distances, we'll sample target SNRs.
        # This way we can ensure we train our network on
        # signals that are actually detectable. We'll use a distribution
        # that looks roughly like our sampled SNR distribution
        self.snr = distributions.PowerLaw(4, 100, 3)
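
For intuition about what these distributions produce, here is a pure-Python sketch of inverse-CDF sampling. It assumes `PowerLaw(4, 100, 3)` means $p(x) \propto x^{-3}$ on $[4, 100]$ and that `Cosine` means $p(\delta) \propto \cos\delta$. Those conventions are assumptions about `ml4gw`, and the helper names are hypothetical.

```python
import math
import random

def sample_power_law(u, x_min=4.0, x_max=100.0, alpha=3.0):
    """Map u ~ Uniform(0, 1) to p(x) proportional to x**-alpha via the inverse CDF."""
    k = 1.0 - alpha
    lo, hi = x_min**k, x_max**k
    return (lo + u * (hi - lo)) ** (1.0 / k)

def sample_declination(u):
    """Map u ~ Uniform(0, 1) to p(dec) proportional to cos(dec) on [-pi/2, pi/2]."""
    return math.asin(2.0 * u - 1.0)

rng = random.Random(0)
snrs = [sample_power_law(rng.random()) for _ in range(10_000)]

# fraction of sampled SNRs below 8, i.e. close to the detection threshold
fraction_quiet = sum(s < 8 for s in snrs) / len(snrs)
```

Under that assumed convention most injections land near the SNR threshold, so the network spends most of its time on the hard cases.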

Quickly define some quantities in terms of number of samples for use later

In [24]:
class Ml4gwDetectionWithSampleParams(Ml4gwDetectionWithSkyParamDistributions):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.kernel_size = int(self.hparams.kernel_length * self.hparams.sample_rate)
        fduration_size = int(self.hparams.fduration * self.hparams.sample_rate)
        self.window_size = self.kernel_size + fduration_size
        self.psd_size = int(self.hparams.psd_length * self.hparams.sample_rate)
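
Plugging in the values used throughout this deck gives concrete sizes. This is the same arithmetic as above outside the class. The comment on `fduration_size` assumes the whitener trims `fduration / 2` from each edge, as the `crop(0.5, 1.5)` call did in the `gwpy` example.

```python
sample_rate = 2048
kernel_length, fduration, psd_length = 1, 1, 8  # seconds, as used in this deck

kernel_size = int(kernel_length * sample_rate)  # what the network sees
fduration_size = int(fduration * sample_rate)   # lost to whitening filter edges
window_size = kernel_size + fduration_size      # what we inject into before whitening
psd_size = int(psd_length * sample_rate)        # background used to estimate the PSD
samples_per_example = psd_size + window_size    # what the dataloader must return

print(kernel_size, window_size, psd_size, samples_per_example)
```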

Now we can sketch out what we want our real-time data augmentation function to look like

In [25]:
class Ml4gwDetectionWithAugmentation(Ml4gwDetectionWithSampleParams):
    @torch.no_grad()
    def augment(self, X: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # break off "background" from target kernel and compute its PSD
        # (in double precision since our scale is so small)
        background, X = torch.split(X, [self.psd_size, self.window_size], dim=-1)
        psd = self.spectral_density(background.double())

        # sample at most batch_size signals from our bank and move them to our
        # current device. Keep a mask that indicates which rows to inject in
        batch_size = X.size(0)
        hc, hp, mask = self.sample_waveforms(batch_size)
        hc, hp, mask = hc.to(X), hp.to(X), mask.to(X.device)

        # sample sky parameters and project to responses, then
        # rescale the response according to a randomly sampled SNR
        responses = self.project_waveforms(hc, hp)
        responses = self.rescale_snrs(responses, psd[mask])

        # randomly slice out a window of the waveform, add it
        # to our background, then whiten everything
        responses = self.sample_kernels(responses)
        X[mask] += responses.float()
        X = self.whitener(X, psd)

        # create labels, marking 1s where we injected
        y = torch.zeros((batch_size, 1), device=X.device)
        y[mask] = 1
        return X, y

Make sure it gets called _after_ we've moved data onto the GPU

In [26]:
class Ml4gwDetectionWithGpuAugmentation(Ml4gwDetectionWithAugmentation):
    def on_after_batch_transfer(self, batch, _):
        # this is a parent method that lightning calls
        # between when the batch gets moved to GPU and
        # when it gets passed to the training_step.
        # Apply our augmentations here
        if self.trainer.training:
            batch = self.augment(batch)
        return batch

Now let's define these actual steps using our `ml4gw` modules

In [27]:
class Ml4gwDetectionWithSampleWaveforms(Ml4gwDetectionWithGpuAugmentation):
    def setup(self, stage):
        # lightning automatically calls this method before training starts,
        # we'll use it to load in all our signals up front, though we could
        # in principle sample these from disk for larger datasets
        with h5py.File("data/signals.hdf5", "r") as f:
            group = f["train"]["polarizations"]
            self.Hp = torch.Tensor(group["plus"][:])
            self.Hc = torch.Tensor(group["cross"][:])

    def sample_waveforms(self, batch_size: int) -> tuple[torch.Tensor, ...]:
        rvs = torch.rand(size=(batch_size,))
        mask = rvs > 0.5
        num_injections = mask.sum().item()

        idx = torch.randint(len(self.Hp), size=(num_injections,))
        hp = self.Hp[idx]
        hc = self.Hc[idx]
        return hc, hp, mask

In [28]:
class Ml4gwDetectionWithProjectWaveforms(Ml4gwDetectionWithSampleWaveforms):
    def project_waveforms(self, hc: torch.Tensor, hp: torch.Tensor) -> torch.Tensor:
        # sample sky parameters
        N = len(hc)
        declination = self.declination(N).to(hc)
        polarization = self.polarization(N).to(hc)
        phi = self.phi(N).to(hc)

        # project to interferometer response
        return gw.compute_observed_strain(
            declination,
            polarization,
            phi,
            detector_tensors=self.detector_tensors,
            detector_vertices=self.detector_vertices,
            sample_rate=self.hparams.sample_rate,
            cross=hc,
            plus=hp
        )

In [29]:
from ml4gw.utils.slicing import sample_kernels

class Ml4gwDetectionWithSnrScaling(Ml4gwDetectionWithProjectWaveforms):
    def rescale_snrs(self, responses: torch.Tensor, psd: torch.Tensor) -> torch.Tensor:
        # make sure everything has the same number of frequency bins
        num_freqs = int(responses.size(-1) // 2) + 1
        if psd.size(-1) != num_freqs:
            psd = torch.nn.functional.interpolate(psd, size=(num_freqs,), mode="linear")
        snrs = gw.compute_network_snr(
            responses.double(), psd, self.hparams.sample_rate, self.hparams.highpass
        )

        N = len(responses)
        target_snrs = self.snr(N).to(snrs.device)
        weights = target_snrs / snrs
        return responses * weights.view(-1, 1, 1)

    def sample_kernels(self, responses: torch.Tensor) -> torch.Tensor:
        # slice off random views of each waveform to inject in arbitrary positions
        responses = responses[:, :, -self.window_size:]

        # pad so that at least half the kernel always contains signals
        pad = [0, int(self.window_size // 2)]
        responses = torch.nn.functional.pad(responses, pad)
        return sample_kernels(responses, self.window_size, coincident=True)

### Dataloading
The last thing to do is to build dataloaders that read from our background file in real-time
- But including PSD background, our batches occupy nearly 300 MB of memory - nearly our disk read throughput of ~320 MB/s
- If our NN can handle 3.5 batches per second, we'll need to do more to be able to saturate it
- We'll leverage **chunked loading**: asynchronously load large buffers of random data into memory, then sample multiple _batches_ from those buffers or "chunks" between reads
- Trades off some entropy in our batches for better read throughput
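
A rough estimate of the gap, using only the approximate figures quoted above (~300 MB per batch, ~320 MB/s from disk, ~3.5 batches/s through the network):

```python
batch_mb = 300       # approximate size of one batch including PSD background
disk_mb_per_s = 320  # approximate disk read throughput
batches_per_s = 3.5  # what the network can consume

# bandwidth we would need if every batch were read fresh from disk
required_mb_per_s = batch_mb * batches_per_s

# how many times, on average, each byte read must be reused to keep up
reuse_factor = required_mb_per_s / disk_mb_per_s

print(f"need {required_mb_per_s:.0f} MB/s, have {disk_mb_per_s} MB/s, "
      f"reuse each read ~{reuse_factor:.1f}x")
```

Sampling several batches from each chunk is one way to get that reuse.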

In [30]:
from ml4gw.dataloading import ChunkedTimeSeriesDataset, Hdf5TimeSeriesDataset

class Ml4gwDetection(Ml4gwDetectionWithSnrScaling):
    def train_dataloader(self):
        batches_per_epoch = int((2 * len(self.Hp) - 1) // self.hparams.batch_size) + 1
        batches_per_chunk = int(batches_per_epoch // 10)
        chunks_per_epoch = int(batches_per_epoch // batches_per_chunk) + 1

        # Hdf5TimeSeries dataset samples batches from disk.
        # In this instance, we'll make our batches really large so that
        # we can treat them as chunks to sample training batches from
        dataset = Hdf5TimeSeriesDataset(
            "data/background.hdf5",
            channels=self.hparams.ifos,
            kernel_size=int(self.hparams.chunk_length * self.hparams.sample_rate),
            batch_size=self.hparams.reads_per_chunk,
            batches_per_epoch=chunks_per_epoch,
            coincident=False,
            path="train"
        )

        # multiprocess this so there's always a new chunk ready when we need it
        dataloader = torch.utils.data.DataLoader(
            dataset,
            num_workers=2,
            pin_memory=True,
            persistent_workers=True
        )

        # sample batches to pass to our NN from the chunks loaded from disk
        return ChunkedTimeSeriesDataset(
            dataloader,
            kernel_size=self.window_size + self.psd_size,
            batch_size=self.hparams.batch_size,
            batches_per_chunk=batches_per_chunk,
            coincident=False
        )
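
To make the chunking idea concrete without any GW machinery, here is a pure-Python sketch. `chunked_batches` is a hypothetical helper for illustration, not how `ml4gw` implements it: one large read per chunk, followed by many cheap in-memory samples.

```python
import random

def chunked_batches(timeseries, chunk_size, kernel_size, batch_size,
                    batches_per_chunk, num_chunks, seed=0):
    """Yield batches of windows sampled from randomly located chunks."""
    rng = random.Random(seed)
    for _ in range(num_chunks):
        # one large contiguous read (the slow step, standing in for disk I/O)
        start = rng.randrange(len(timeseries) - chunk_size + 1)
        chunk = timeseries[start:start + chunk_size]

        # many cheap in-memory samples drawn from that single chunk
        for _ in range(batches_per_chunk):
            offsets = [
                rng.randrange(chunk_size - kernel_size + 1)
                for _ in range(batch_size)
            ]
            yield [chunk[o:o + kernel_size] for o in offsets]

data = list(range(1000))  # stand-in for a long strain timeseries
batches = list(chunked_batches(
    data, chunk_size=200, kernel_size=10,
    batch_size=4, batches_per_chunk=3, num_chunks=2
))
```

All windows in consecutive batches come from the same 200-sample neighbourhood, which is the entropy we give up in exchange for fewer reads.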

Now instantiate the model with all our preprocessing parameters from before.

For dataloading, each chunk will read 20 random 128-second segments from our data on disk, from which we'll sample ~10% of the batches in an epoch.
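
The chunk bookkeeping in `train_dataloader` can be worked through by hand. The number of training waveforms isn't stated in this deck, so `num_waveforms = 50_000` below is a hypothetical stand-in for `len(self.Hp)`. The rest mirrors the arithmetic in the dataloader.

```python
num_waveforms = 50_000                   # hypothetical stand-in for len(self.Hp)
batch_size = 1024
chunk_length, reads_per_chunk = 128, 20  # seconds per read, reads per chunk

# ~half of each batch gets an injection, so an "epoch" is ~2x the waveform count
batches_per_epoch = (2 * num_waveforms - 1) // batch_size + 1
batches_per_chunk = batches_per_epoch // 10
chunks_per_epoch = batches_per_epoch // batches_per_chunk + 1

# seconds of background per interferometer held in memory for each chunk
seconds_per_chunk = chunk_length * reads_per_chunk

print(batches_per_epoch, batches_per_chunk, chunks_per_epoch, seconds_per_chunk)
```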

In [31]:
model = Ml4gwDetection(
    ifos,
    kernel_length=1,
    fduration=1,
    psd_length=8,
    sample_rate=sample_rate,
    fftlength=2,
    chunk_length=128,
    reads_per_chunk=20,
    highpass=32,
    learning_rate=0.005,
    batch_size=1024
)

Now let's fit this model, and see if we can do any better

In [32]:
logger = pl.loggers.CSVLogger("logs", name="ml4gw-expt")
trainer = pl.Trainer(
    max_epochs=20,
    precision="16-mixed",
    log_every_n_steps=5,
    logger=logger,
    callbacks=[pl.callbacks.RichProgressBar()]
)
trainer.fit(model)

Output()

In [33]:
from utils import plotting
plotting.plot_run("ml4gw-expt", 48)

In [None]:
from utils.infer import infer_on_timeslide

In [None]:
from utils import infer
from importlib import reload

infer = reload(infer)

psd_length = 64
inference_sampling_rate = 8

bgf = h5py.File("data/background.hdf5", "r")
fgf = h5py.File("data/signals.hdf5", "r")
model = model.to("cuda")
with bgf, fgf:
    bgf = bgf["test"]
    fgf = fgf["test"]
    segment = "1238725606-32598"
    shifts = [0, 1]

    bg = bgf[segment]
    fg = fgf[segment][str(shifts)]

    timestamps = fg["parameters/gps_time"][:]
    snrs = fg["parameters/snr"][:]
    num_rejected = fg.attrs["num_rejected"]
    t0 = bg[ifos[0]].attrs["x0"]

    bg, fg = infer.infer_on_timeslide(
        model,
        bgf[segment],
        fgf[segment][str(shifts)],
        ifos,
        shifts=shifts,
        kernel_length=1,
        inference_sampling_rate=inference_sampling_rate,
        fduration=1,
        psd_length=psd_length,
        batch_size=2048
    )

    timestamps -= t0
    mask = timestamps > psd_length
    mask &= timestamps < len(bg) / inference_sampling_rate
    idx = timestamps[mask] * inference_sampling_rate  # seconds -> index into inference outputs
    snrs = snrs[mask]

with torch.no_grad():
    bg_pooled = torch.nn.functional.max_pool1d(
        bg.view(1, 1, -1), kernel_size=8 * 8
    )
    fg_pooled, indices = torch.nn.functional.max_pool1d(
        fg.view(1, 1, -1), kernel_size=8 * 8, return_indices=True
    )

    bg_pooled = bg_pooled.view(-1).cpu().numpy()
    fg_pooled = fg_pooled.view(-1).cpu().numpy()
    indices = indices.view(-1).cpu().numpy()

diffs = np.abs(indices[:, None] - idx)
argmin = diffs.argmin(axis=0)
events = fg_pooled[argmin]

In [None]:
import matplotlib.pyplot as plt

plt.scatter(events, snrs)
plt.yscale("log")