<div style="margin-top:5%;margin-left:10%">
<h1 style="background-color:#a31f34;color:#eeeeee;margin-right:10%;border-radius:15px">
    Building a Production ML Pipeline For Gravitational Wave Detection
</h1>

<div style="color: #333333;font-size: 24pt">Alec Gunny$^{*,1}$, Ethan Marx$^1$, William Benoit$^2$</div>


<p style="font-size:14pt;padding-top:3%">
<sub><sup>1 - Massachusetts Institute of Technology</sup></sub>
</p>
<p style="font-size:14pt;margin-top:-1%">
<sub><sup>2 - University of Minnesota</sup></sub>
</p>
</div>

## Gravitational wave astronomy

In [1]:
import h5py
import numpy as np

from utils import plotting

# interferometer names, used as dataset keys when reading the HDF5 files:
# Hanford (H1) and Livingston (L1)
ifos = ["H1", "L1"]
# sampling rate in samples per second; used to convert durations in
# seconds to sample counts and to build time axes
sample_rate = 2048

## Gravitational wave data

In [2]:
! ls data

background.hdf5  cache.hdf5  signals.hdf5


`background.hdf5`

Real, open data observed by the Hanford (H1) and Livingston (L1) interferometers between April 1st and April 22nd 2019
- 1 week for training/validation, 2 weeks for test

In [3]:
# Preview the first 10 seconds of one training segment from each detector.
# Segment groups are keyed as "<start timestamp>-<duration>".
num_samples = sample_rate * 10
with h5py.File("data/background.hdf5", "r") as f:
    dataset = f["train"]["1238175433-17136"]
    data = {i: dataset[i][:num_samples] for i in ifos}

# time axis in seconds matching the slice read above
t = np.arange(num_samples) / sample_rate
plotting.plot_timeseries(t, **data)

`signals.hdf5`

Simulated waveforms generated by gravitational wave events along with the parameters of those events
- Train and validation datasets contain raw waveform **polarizations**

In [4]:
# Read both polarizations of the first training waveform (index 0).
# The file is opened explicitly read-only, like every other read of it in
# this notebook, rather than relying on h5py's version-dependent default.
with h5py.File("data/signals.hdf5", "r") as f:
    dataset = f["train"]["polarizations"]
    data = {i: dataset[i][0] for i in ["cross", "plus"]}

# time axis in seconds; the 4 second span is hard-coded here
t = np.arange(4 * sample_rate) / sample_rate
plotting.plot_timeseries(t, legend_location="top_left", **data)

- Interferometers act as antennae that respond linearly to polarizations
- Function of relative locations/orientations of detectors and source

<img src="images/orientations.png" width=50% style="display: block;margin-left:auto;margin-right:auto"></img>

In [5]:
from pycbc.detector import Detector

sky_params = ["declination", "right_ascension", "polarization"]
# sky parameters of the first training waveform (index 0), i.e. the same
# waveform whose polarizations are held in `data`
with h5py.File("data/signals.hdf5", "r") as f:
    params = {i: f[f"train/parameters/{i}"][0] for i in sky_params}

# Project the polarizations onto each detector. The observed strain h(t)
# is the sum of the plus and cross polarizations, each weighted by that
# detector's antenna pattern factor.
responses = {}
for ifo in ifos:
    detector = Detector(ifo)
    # antenna factors evaluated at a fixed, hard-coded GPS time
    fp, fc = detector.antenna_pattern(t_gps=1238175433, **params)
    ht = fp * data["plus"] + fc * data["cross"]
    responses[ifo] = ht
plotting.plot_timeseries(t, legend_location="top_left", **responses)

- Projection from polarization $\rightarrow$ response introduces
    - phase shifts due to differences in arrival times
    - differences in relative amplitudes due to slight differences in polarization
- Test signals have been pre-projected so that we can analyze them as if they were real events
- All signals have been rejection sampled to ensure their **signal-to-noise ratio** (SNR) is $\geq$ 4

## Some implementation notes
- Not live, but everything you see was run in one fell swoop
- Show what matters, hide what doesn't
- Implementations meant to be illustrative, not necessarily optimal
- "What about..." Great question! Good tools let us answer good questions with more confidence

- Lots of low frequency wiggles
- Normalize the data using a technique called *whitening* to remove these

In [None]:
# NOTE(review): `TimeSeriesDict` and `fnames` are not imported or defined
# anywhere in the visible notebook — confirm where they come from.
ts = TimeSeriesDict.read(fnames[0], path=["H1", "L1"])
# one PSD per channel, estimated with the "median" method and a Hann window
psd = {k: v.psd(fftlength=2, window="hann", method="median") for k, v in ts.items()}
# restrict each PSD to the range 1 .. sample_rate / 2 (presumably Hz — confirm)
psd = {k: v.crop(1, sample_rate / 2) for k, v in psd.items()}

# NOTE(review): the plotting helpers are imported as `plotting` earlier in
# the notebook; `plot_utils` is not defined in the visible source.
plot_utils.plot_spectral(**psd)

- Whitening builds a filter whose frequency response is the inverse of the target PSD
- Output contains all frequencies roughly equally

In [None]:
# NOTE(review): `timeseries` is read here but never assigned in the visible
# notebook (the data loaded earlier is bound to `ts`) — confirm the input.
# Each channel is whitened against the square root of its PSD (the ASD).
timeseries = {k: v.whiten(asd=psd[k]**0.5, fduration=1) for k, v in timeseries.items()}
# NOTE(review): `plot_utils` is not defined in the visible notebook.
plot_utils.plot_timeseries(**timeseries)

- Noise looks "noisier", but some funky artifacts near the edges
- Due to filter settle-in: crop these off

In [None]:
# Keep only the span from start + 0.5 to start + 1.5 of each channel.
# NOTE(review): `start` is not defined in the visible notebook — confirm.
timeseries = {k: v.crop(start + 0.5, start + 1.5) for k, v in timeseries.items()}
# NOTE(review): `plot_utils` is not defined in the visible notebook.
plot_utils.plot_timeseries(**timeseries)

How does their frequency content look?

In [None]:
# magnitude of each channel's FFT
# NOTE(review): despite the name `kernel_psd`, this is |FFT| rather than a
# power spectral density — confirm that is what should be plotted.
kernel_psd = {k: v.fft().abs() for k, v in timeseries.items()}
# NOTE(review): `plot_utils` is not defined in the visible notebook.
plot_utils.plot_spectral(**kernel_psd)

Simplest way to start:
- Generate a bunch of examples of background - label them 0
- Take half of these examples and inject simulated waveforms - label them 1
- Fit like your life depends on it

In [None]:
from importlib import reload
from utils import data

data = reload(data)

# Build the "train" and "valid" arrays from the raw background and the
# simulated signals. Both files are opened read-only inside one `with`
# statement, so the first handle is closed even if opening the second
# one fails.
with h5py.File("data/background.hdf5", "r") as background_f, \
        h5py.File("data/signals.hdf5", "r") as signal_f:
    datasets = {}
    for split in ["train", "valid"]:
        # `data` here is the reloaded `utils.data` module
        datasets[split] = data.make_dataset(
            ifos,
            background_f[split],
            signal_f[split],
            kernel_length=1,
            fduration=1,
            psd_length=8,
            fftlength=2,
            sample_rate=sample_rate,
            highpass=32
        )

In [None]:
# Commented-out snippet that writes the in-memory `datasets` to the cache
# file; kept for reference.
# with h5py.File("data/cache.hdf5", "w") as f:
#     for split, value in datasets.items():
#         f[split] = value

# Read every top-level entry of the cache fully into memory.
with h5py.File("data/cache.hdf5", "r") as f:
    datasets = {name: f[name][:] for name in f}

In [None]:
import torch
from lightning import pytorch as pl
from torchmetrics.classification import BinaryAUROC
from utils.nn import ResNet

class BbhDetectionBase(pl.LightningModule):
    """ResNet binary classifier trained with a logit-based cross-entropy loss.

    The network is built from the module-level ``ifos`` list
    (``len(ifos)`` is passed as the first ``ResNet`` argument), so that
    name must exist before the class is instantiated.

    Args:
        learning_rate: Stored in ``self.hparams`` for use by optimizers.
        batch_size: Stored in ``self.hparams`` for use by dataloaders.
        max_fpr: Forwarded to ``BinaryAUROC`` as its ``max_fpr`` argument.
    """

    def __init__(
        self,
        learning_rate: float = 0.001,
        batch_size: int = 512,
        max_fpr: float = 1e-2
    ) -> None:
        super().__init__()
        # makes the constructor arguments available as self.hparams.<name>
        self.save_hyperparameters()
        self.nn = ResNet(len(ifos), layers=[2, 3, 4, 2])
        self.metric = BinaryAUROC(max_fpr=max_fpr)

    def forward(self, X):
        """Return the network output for a batch ``X``."""
        return self.nn(X)

    def training_step(self, batch):
        """Compute and log the binary cross-entropy loss for one batch.

        ``batch`` is an ``(X, y)`` pair; the network output is treated as
        logits by ``binary_cross_entropy_with_logits``.
        """
        X, y = batch
        y_hat = self(X)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(y_hat, y)
        self.log("train_loss", loss, on_step=True, prog_bar=True)
        return loss

    def validation_step(self, batch):
        """Accumulate the AUROC metric over one validation batch."""
        X, y = batch
        y_hat = self(X)
        # Labels are floats because the training loss needs them that way,
        # but torchmetrics documents the target of its binary metrics as
        # an integer tensor, so cast before updating.
        self.metric.update(y_hat, y.int())
        self.log("valid_auroc", self.metric, on_epoch=True, prog_bar=True)

In [None]:
class BbhDetectionWithDataLoaders(BbhDetectionBase):
    """Adds dataloaders backed by the module-level ``datasets`` arrays."""

    def make_dataset(self, split):
        """Pair the array for ``split`` with alternating 0/1 labels."""
        features = torch.Tensor(datasets[split])
        # even-indexed rows keep label 0, odd-indexed rows are labelled 1
        labels = torch.zeros((len(features), 1))
        labels[1::2] = 1
        return torch.utils.data.TensorDataset(features, labels)

    def train_dataloader(self):
        """Shuffled loader over the "train" split."""
        return torch.utils.data.DataLoader(
            self.make_dataset("train"),
            shuffle=True,
            pin_memory=True,
            batch_size=self.hparams.batch_size,
        )

    def val_dataloader(self):
        """Unshuffled loader over the "valid" split, with 4x the batch size."""
        val_batch_size = self.hparams.batch_size * 4
        return torch.utils.data.DataLoader(
            self.make_dataset("valid"),
            shuffle=False,
            pin_memory=True,
            batch_size=val_batch_size,
        )

In [None]:
class BbhDetectionModel(BbhDetectionWithDataLoaders):
    """Adds the optimizer, learning-rate schedule and checkpoint callback."""

    def configure_optimizers(self):
        """AdamW with a one-cycle schedule that is stepped on every batch."""
        max_lr = self.hparams.learning_rate
        optimizer = torch.optim.AdamW(self.nn.parameters(), max_lr)
        # the cycle spans the whole run, as estimated by the trainer
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr,
            total_steps=self.trainer.estimated_stepping_batches,
            pct_start=0.1,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

    def configure_callbacks(self):
        """Checkpoint on the highest logged ``valid_auroc``."""
        return [pl.callbacks.ModelCheckpoint(monitor="valid_auroc", mode="max")]

In [None]:
# Fit the baseline model on the pre-built in-memory datasets, writing
# metrics as CSV under logs/vanilla-expt.
logger = pl.loggers.CSVLogger("logs", name="vanilla-expt")
model = BbhDetectionModel(learning_rate=0.02, batch_size=1024)

trainer = pl.Trainer(
    logger=logger,
    callbacks=[pl.callbacks.RichProgressBar()],
    max_epochs=20,
    log_every_n_steps=5,
    precision="16-mixed",
)
trainer.fit(model)

In [None]:
from importlib import reload
from utils import plotting
# reload the already-imported module and rebind the name to the result
plotting = reload(plotting)

# NOTE(review): the meaning of the second argument (61) is not visible
# here — presumably the logger version of the run to plot; confirm
# against `plot_run`.
plotting.plot_run("vanilla-expt", 61)

In [None]:
import gc
import torch

def flush():
    """Run garbage collection and reset CUDA memory bookkeeping.

    The CUDA calls are only made when a GPU is available, so on a
    CPU-only machine this reduces to a plain ``gc.collect()``.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()


flush()

In [None]:
from math import ceil
from ml4gw import gw
# import the name that train_dataloader actually instantiates
from ml4gw.dataloading import ChunkedTimeSeriesDataset, Hdf5TimeSeriesDataset
from ml4gw.distributions import Cosine, PowerLaw, Uniform
from ml4gw.transforms import SpectralDensity, Whiten
from ml4gw.utils.slicing import sample_kernels


class Ml4gwDetectionBase(BbhDetectionModel):
    """Detection model whose training batches are streamed and augmented.

    Training data is read from ``data/background.hdf5`` in chunks and each
    training batch is passed through ``self.augment`` in
    ``on_after_batch_transfer``. ``augment`` is not defined on this class
    and must be supplied by a subclass.

    Args:
        ifos: Interferometer names; used to look up detector geometry and
            as the channels read from the background file.
        kernel_length: Added to ``fduration`` and multiplied by
            ``sample_rate`` to give ``window_size`` in samples.
        fduration: Passed to the ``Whiten`` transform and included in
            ``window_size``.
        psd_length: Multiplied by ``sample_rate`` and added to
            ``window_size`` to give ``load_size``, the length of each
            sample requested from the chunked dataset.
        sample_rate: Converts the lengths above into sample counts and is
            passed to the spectral transforms.
        fftlength: Passed to the ``SpectralDensity`` transform.
        chunk_length: Multiplied by ``sample_rate`` to give the size of
            each chunk read from disk.
        reads_per_chunk: Used as the ``batch_size`` of the on-disk
            dataset, read back through ``self.hparams``.
        highpass: Passed to the ``Whiten`` transform and stored on the
            instance.
        **kwargs: Forwarded to the parent class constructor.
    """

    def __init__(
        self,
        ifos: list[str],
        kernel_length: float,
        fduration: float,
        psd_length: float,
        sample_rate: float,
        fftlength: float,
        chunk_length: float = 128,
        reads_per_chunk: int = 40,
        highpass: float = 32,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)

        # some torch modules we'll use to perform
        # real-time transformations
        self.spectral_density = SpectralDensity(
            sample_rate,
            fftlength,
            average="median",
            fast=True
        )
        self.whitener = Whiten(
            fduration,
            sample_rate,
            highpass=highpass
        )

        # get some geometry information about
        # the interferometers we're going to project to;
        # registered as buffers so they move with the module
        detector_tensors, vertices = gw.get_ifo_geometry(*ifos)
        self.register_buffer("detector_tensors", detector_tensors)
        self.register_buffer("detector_vertices", vertices)

        # define some sky parameter distributions
        self.declination = Cosine()
        self.polarization = Uniform(0, torch.pi)
        self.phi = Uniform(-torch.pi, torch.pi)  # relative RAs of detector and source
        self.snr_distribution = PowerLaw(4, 100, 3)

        self.ifos = ifos
        self.sample_rate = sample_rate
        # all sizes below are in samples
        self.window_size = int((kernel_length + fduration) * sample_rate)
        self.load_size = self.window_size + int(psd_length * sample_rate)
        self.chunk_size = int(chunk_length * sample_rate)
        self.highpass = highpass

    def setup(self, stage):
        """Load all training plus/cross polarizations into memory."""
        with h5py.File("data/signals.hdf5", "r") as f:
            group = f["train"]["polarizations"]
            self.Hp = torch.Tensor(group["plus"][:])
            self.Hc = torch.Tensor(group["cross"][:])

    def train_dataloader(self):
        """Build the chunked loader that yields training batches.

        An epoch is sized as ``2 * len(self.Hp) // batch_size`` batches
        and split into roughly ten chunks read from disk.
        """
        batches_per_epoch = (2 * len(self.Hp)) // self.hparams.batch_size
        # never let this reach zero: a small waveform set would otherwise
        # make the division below raise ZeroDivisionError
        batches_per_chunk = max(1, batches_per_epoch // 10)
        chunks_per_epoch = ceil(batches_per_epoch / batches_per_chunk)

        dataset = Hdf5TimeSeriesDataset(
            "data/background.hdf5",
            channels=self.ifos,
            kernel_size=self.chunk_size,
            batch_size=self.hparams.reads_per_chunk,
            batches_per_epoch=chunks_per_epoch,
            coincident=False,
            path="train"
        )
        dataloader = torch.utils.data.DataLoader(
            dataset,
            num_workers=2,
            pin_memory=True,
            persistent_workers=True
        )
        return ChunkedTimeSeriesDataset(
            dataloader,
            kernel_size=self.load_size,
            batch_size=self.hparams.batch_size,
            batches_per_chunk=batches_per_chunk,
            coincident=False
        )

    def on_after_batch_transfer(self, batch, _):
        """Augment training batches; pass all other batches through."""
        if self.trainer.training:
            batch = self.augment(batch)
        return batch

In [None]:
class Ml4gwDetectionWithWaveformSampling(Ml4gwDetectionBase):
    """Adds random selection of stored waveforms for injection."""

    def sample_waveforms(self, batch_size: int):
        """Choose which batch elements get a signal and which waveform each uses.

        Returns ``(hc, hp, mask)``: the cross and plus polarizations of the
        selected waveforms and a boolean mask over the batch marking the
        elements that receive an injection.
        """
        # each element is selected when its uniform draw exceeds 0.5
        mask = torch.rand(size=(batch_size,)) > 0.5
        num_injections = mask.sum().item()

        # one waveform index per selected element, drawn with replacement
        idx = torch.randint(len(self.Hp), size=(num_injections,))
        return self.Hc[idx], self.Hp[idx], mask

In [None]:
class Ml4gwDetectionWithWaveformProjection(Ml4gwDetectionWithWaveformSampling):
    """Adds projection of polarizations onto the interferometers."""

    def project_waveforms(self, hc, hp):
        """Project each waveform using freshly sampled sky parameters."""
        num_waveforms = len(hc)

        # draw declination, polarization and phi (in that order), each
        # moved to the dtype and device of the waveforms
        samplers = (self.declination, self.polarization, self.phi)
        sky_params = [sampler(num_waveforms).to(hc) for sampler in samplers]

        # interferometer responses for the sampled sky parameters
        return gw.compute_observed_strain(
            *sky_params,
            detector_tensors=self.detector_tensors,
            detector_vertices=self.detector_vertices,
            sample_rate=self.sample_rate,
            cross=hc,
            plus=hp,
        )

In [None]:
class Ml4gwDetectionWithSnrRescaling(Ml4gwDetectionWithWaveformProjection):
    """Adds rescaling of projected waveforms to sampled target SNRs."""

    def rescale_snrs(self, responses, psd):
        """Scale ``responses`` in place so their network SNRs match new targets.

        The same tensor that was passed in is returned.
        """
        # the PSD must have size(-1) // 2 + 1 frequency bins to match the
        # responses; resample it linearly if it does not
        expected_bins = responses.size(-1) // 2 + 1
        if psd.size(-1) != expected_bins:
            psd = torch.nn.functional.interpolate(
                psd, size=(expected_bins,), mode="linear"
            )

        current_snrs = gw.compute_network_snr(
            responses.double(), psd, self.sample_rate, self.highpass
        )
        target_snrs = self.snr_distribution(len(responses))
        target_snrs = target_snrs.to(current_snrs.device)

        # one scale factor per waveform, broadcast over the last two axes
        scale = (target_snrs / current_snrs).view(-1, 1, 1)
        responses.mul_(scale)
        return responses

In [None]:
class Ml4gwDetectionWithKernelSampling(Ml4gwDetectionWithSnrRescaling):
    """Adds random windowing of the waveforms before injection."""

    def sample_kernels(self, responses):
        """Return a random ``window_size`` view of each padded waveform."""
        # keep only the last window_size samples of every waveform
        tail = responses[:, :, -self.window_size:]

        # zero-pad half a window on the right, so a sampled view can
        # extend at most half a window past the waveform's end
        right_pad = self.window_size // 2
        padded = torch.nn.functional.pad(tail, [0, right_pad])

        # `sample_kernels` here is the module-level ml4gw function
        return sample_kernels(padded, self.window_size, coincident=True)

In [None]:
class Ml4gwDetection(Ml4gwDetectionWithKernelSampling):
    """Full model: injects signals into, and whitens, each training batch."""

    @torch.no_grad()
    def augment(self, X):
        """Turn a raw batch into whitened kernels and their 0/1 labels."""
        # The leading part of each sample is used only to estimate the
        # PSD (in double precision since our scale is so small); the
        # trailing window_size samples become the kernel.
        psd_size = X.size(-1) - self.window_size
        background, X = torch.split(X, [psd_size, self.window_size], dim=-1)
        psd = self.spectral_density(background.double())

        # choose which kernels receive a waveform, then match X's
        # dtype and device
        hc, hp, mask = self.sample_waveforms(X.size(0))
        hc = hc.to(X)
        hp = hp.to(X)
        mask = mask.to(X.device)

        # project, rescale against the PSDs of the selected kernels,
        # then cut to kernel size
        projected = self.project_waveforms(hc, hp)
        rescaled = self.rescale_snrs(projected, psd[mask])
        injections = self.sample_kernels(rescaled)

        # inject into the selected kernels only, then whiten everything
        X[mask] += injections.float()
        X = self.whitener(X, psd)

        # label 1 where a signal was injected, 0 elsewhere
        y = torch.zeros((X.size(0), 1), device=X.device)
        y[mask] = 1
        return X, y

In [None]:
# Instantiate the model that augments its training data on the fly.
model = Ml4gwDetection(
    ifos,
    # optimisation settings, forwarded to the parent classes
    learning_rate=0.005,
    batch_size=512,
    # data geometry and spectral settings
    sample_rate=sample_rate,
    kernel_length=1,
    fduration=1,
    psd_length=8,
    fftlength=2,
    highpass=32,
)

In [None]:
# Fit the on-the-fly augmentation model, writing metrics as CSV under
# logs/ml4gw-expt.
logger = pl.loggers.CSVLogger("logs", name="ml4gw-expt")
progress_bar = pl.callbacks.RichProgressBar()
trainer = pl.Trainer(
    logger=logger,
    callbacks=[progress_bar],
    max_epochs=20,
    log_every_n_steps=5,
    precision="16-mixed",
)
trainer.fit(model)

In [None]:
from utils import plotting
# NOTE(review): the meaning of the second argument (48) is not visible
# here — presumably the logger version of the run to plot; confirm
# against `plot_run`.
plotting.plot_run("ml4gw-expt", 48)

In [None]:
from utils.infer import infer_on_timeslide

In [None]:
from utils import infer
from importlib import reload

infer = reload(infer)

psd_length = 64
inference_sampling_rate = 8
# max-pooling window in network outputs, shared by the background and
# foreground pooling below so the two cannot drift apart
# NOTE(review): presumably 8 seconds at the inference sampling rate — confirm
pool_size = 8 * 8

model = model.to("cuda")

# Both files are opened read-only inside one `with` statement, so the
# first handle is closed even if opening the second one fails.
with h5py.File("data/background.hdf5", "r") as bg_file, \
        h5py.File("data/signals.hdf5", "r") as fg_file:
    bgf = bg_file["test"]
    fgf = fg_file["test"]
    segment = "1238725606-32598"
    shifts = [0, 1]

    bg = bgf[segment]
    fg = fgf[segment][str(shifts)]

    # injection metadata stored alongside the foreground data
    timestamps = fg["parameters/gps_time"][:]
    snrs = fg["parameters/snr"][:]
    num_rejected = fg.attrs["num_rejected"]
    t0 = bg[ifos[0]].attrs["x0"]

    # from here on `bg` and `fg` hold the values returned by inference
    # rather than the HDF5 groups
    bg, fg = infer.infer_on_timeslide(
        model,
        bg,
        fg,
        ifos,
        shifts=shifts,
        kernel_length=1,
        inference_sampling_rate=inference_sampling_rate,
        fduration=1,
        psd_length=psd_length,
        batch_size=2048
    )

    # express injection times relative to the segment start, keeping only
    # those after the first psd_length seconds and within the output length
    timestamps -= t0
    mask = timestamps > psd_length
    mask &= timestamps < len(bg) / inference_sampling_rate
    # seconds -> index into the network output; uses the configured rate
    # rather than a hard-coded 8
    idx = timestamps[mask] * inference_sampling_rate
    snrs = snrs[mask]

with torch.no_grad():
    # keep the maximum output in each non-overlapping pooling window
    bg_pooled = torch.nn.functional.max_pool1d(
        bg.view(1, 1, -1), kernel_size=pool_size
    )
    # for the foreground also record where each maximum occurred
    fg_pooled, indices = torch.nn.functional.max_pool1d(
        fg.view(1, 1, -1), kernel_size=pool_size, return_indices=True
    )

    bg_pooled = bg_pooled.view(-1).cpu().numpy()
    fg_pooled = fg_pooled.view(-1).cpu().numpy()
    indices = indices.view(-1).cpu().numpy()

# match every injection to the pooled maximum closest to it
diffs = np.abs(indices[:, None] - idx)
argmin = diffs.argmin(axis=0)
events = fg_pooled[argmin]

In [None]:
# NOTE(review): these statements repeat the last three lines of the
# previous cell verbatim, so they recompute the same values.
# distance between each pooled maximum's position and each injection index
diffs = np.abs(indices[:, None] - idx)
# for every injection, the pooled window whose maximum lies closest
argmin = diffs.argmin(axis=0)
# pooled network output at that closest maximum
events = fg_pooled[argmin]

In [None]:
# NOTE(review): `plt` is never imported in the visible notebook —
# presumably `import matplotlib.pyplot as plt` is required first; confirm.
# matched network output (x) against injected SNR (y, log scale)
plt.scatter(events, snrs)
plt.yscale("log")