Note: run this notebook on a node with a GPU. See my ADASS demo for an environment that should be sufficient to run this and also has some plotting utils for making the output plot a bit prettier.

In [None]:
import functools
import time

import h5py
import numpy as np
from bilby.core.prior import Cosine, PriorDict, Uniform
from gwpy.timeseries import TimeSeriesDict
from pycbc.detector import Detector
from pycbc.filter.matchedfilter import sigmasq, make_frequency_series
from pycbc.types import FrequencySeries, TimeSeries
from pycbc.waveform import get_td_waveform
from tqdm import trange

# This assumes that you have data like the kind generated
# by the ADASS demo. It should be straightforward to adapt
# to the aframe data setup.
ifos = ["H1", "L1"]
detectors = {i: Detector(i) for i in ifos}
t0 = 1238175433
kernel_length = 1.5
fduration = 1
sample_length = kernel_length + fduration

# load the polarizations of the first training waveform as float64 arrays
with h5py.File("data/signals.hdf5", "r") as f:
    hp, hc = [
        f["train"]["polarizations"][i][0].astype(np.float64)
        for i in ["plus", "cross"]
    ]

# read the background strain of each interferometer between t0 and t0 + 10.5
with h5py.File("data/background.hdf5", "r") as f:
    tsd = TimeSeriesDict.read(
        f[f"train/{t0}-17136"], path=ifos, start=t0, end=t0 + 10.5
    )

# gwpy exposes the sample rate as an astropy Quantity (in Hz). Strip the
# unit so that everything downstream works with a plain float: casts like
# int(kernel_length * sample_rate) fail on a Quantity that carries units
sample_rate = float(tsd[ifos[0]].sample_rate.value)
signal_length = hp.shape[-1] / sample_rate


class StopWatch:
    """
    Decorator object that records how long each call to
    the callables it wraps takes.

    Durations are stored, in seconds, in the `intervals`
    dictionary, which maps the name of each wrapped callable
    to the list of its per-call durations.
    """

    def __init__(self):
        # callable name -> list of call durations in seconds
        self.intervals = {}

    def __call__(self, f):
        """
        Wrap `f` so that the duration of each of its calls
        gets appended to `self.intervals`.

        Args:
            f:
                The callable to time. Its timings are keyed
                by its `__name__` or, for callables without
                one, by the name of its class. Registering
                a name again resets its recorded timings.
        Returns:
            A function that forwards all of its arguments to
            `f` and returns the output of `f` unchanged.
        """
        try:
            name = f.__name__
        except AttributeError:
            name = f.__class__.__name__
        self.intervals[name] = []

        # keep the wrapped callable's name and docstring on the wrapper
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            # perf_counter is monotonic and uses the highest resolution
            # clock available, unlike the wall-clock time.time, which
            # can jump if the system clock gets adjusted
            tick = time.perf_counter()
            out = f(*args, **kwargs)
            tock = time.perf_counter()
            self.intervals[name].append(tock - tick)
            return out
        return wrapper

Benchmark the pycbc implementation. If you would rather benchmark bilby, it should be straightforward to adapt this to the bilby code (which would frankly probably look simpler). Start by defining the functions we'll benchmark.

TODO: need to add injection and whitening step at the end.

In [None]:
# priors over the sky parameters, keyed by the keyword arguments
# that get forwarded to the detector antenna pattern computation
skyloc_prior = PriorDict({
    "declination": Cosine(),
    "right_ascension": Uniform(0, 2 * np.pi),
    "polarization": Uniform(0, np.pi),
})

# collects the timings of the functions decorated below
vanilla_timer = StopWatch()

def _compute_psd(x):
    """
    Estimate the PSD of `x` and interpolate it to a frequency
    resolution of `1 / signal_length`.

    Args:
        x:
            Object with a gwpy-style `psd` method
            (presumably a gwpy TimeSeries; confirm with caller)
    Returns:
        The interpolated PSD values as a pycbc FrequencySeries
    """
    df = 1 / signal_length
    psd = x.psd(fftlength=2, method="median", window="hann")
    psd = psd.interpolate(df)
    return FrequencySeries(psd.value, delta_f=df)


@vanilla_timer
def compute_psds(x):
    """
    Compute a PSD for each entry of the mapping `x`, using
    only the data between `t0` and `t0 + 8`.

    Returns:
        Dictionary mapping each key of `x` to its PSD
    """
    psds = {}
    for ifo, strain in x.items():
        psds[ifo] = _compute_psd(strain.crop(t0, t0 + 8))
    return psds


def _compute_ht(ifo, hp, hc, **skyloc):
    """
    Project a pair of polarizations onto the response of
    a single interferometer.

    Args:
        ifo:
            Key into `detectors` of the interferometer to project onto
        hp:
            Plus polarization samples
        hc:
            Cross polarization samples
        **skyloc:
            Sky parameters forwarded to the detector's
            `antenna_pattern` method
    Returns:
        The response `fp * hp + fc * hc` as a pycbc TimeSeries
    """
    fp, fc = detectors[ifo].antenna_pattern(t_gps=t0, **skyloc)

    # Take the sampling interval from the data's sample rate rather than
    # hard-coding 2048 Hz: `signal_length`, and with it the frequency
    # resolution of the PSDs, is derived from `sample_rate` as well.
    # `sample_rate` may be a unit-carrying Quantity, so unwrap its value
    rate = float(getattr(sample_rate, "value", sample_rate))
    return TimeSeries(fp * hp + fc * hc, delta_t=1 / rate)


@vanilla_timer
def compute_ht(hp, hc, **skyloc):
    """
    Project the polarizations onto every interferometer in `ifos`.

    Returns:
        Dictionary mapping each interferometer name to its response
    """
    responses = {}
    for ifo in ifos:
        responses[ifo] = _compute_ht(ifo, hp, hc, **skyloc)
    return responses


def _compute_snr(ht, psd, highpass=32):
    """
    Evaluate pycbc's `sigmasq` for the response `ht` against `psd`,
    using `highpass` as the low frequency cutoff.

    Note that this returns the output of `sigmasq` as-is; the
    square root is only taken by the caller after summing
    over interferometers.
    """
    return sigmasq(
        make_frequency_series(ht), psd, low_frequency_cutoff=highpass
    )
    

@vanilla_timer
def compute_snr(ht, psds, highpass=32):
    """
    Compute the network SNR of the responses in `ht`: the square
    root of the per-interferometer `sigmasq` values summed over `ifos`.
    """
    total = 0
    for ifo in ifos:
        total += _compute_snr(ht[ifo], psds[ifo], highpass)
    return total ** 0.5


def make_sample(hp, hc):
    """
    Run one pass of the CPU pipeline: estimate PSDs from the
    background, project the polarizations with randomly sampled
    sky parameters, and rescale the responses to a target SNR
    drawn uniformly between 7 and 100.

    Returns:
        Dictionary mapping each interferometer to its rescaled response
    """
    psds = compute_psds(tsd)
    sky_params = skyloc_prior.sample()
    responses = compute_ht(hp, hc, **sky_params)
    target_snr = np.random.uniform(7, 100)
    snr = compute_snr(responses, psds)

    rescaled = {}
    for ifo, response in responses.items():
        rescaled[ifo] = response * target_snr / snr

    # TODO: something like this but wrapped in a timer fn
    # ht = {k: v.crop(t0 + 8, t0 + 10.5) + ht[k].value[SOME CROP] for k in tsd.items()}
    # ht = {k: v.whiten(asd=psds[k]**0.5) for k, v in ht.items()}
    return rescaled

Now actually run this over 1000 iterations.

In [None]:
# each call appends one timing per decorated function to `vanilla_timer`
for _ in trange(1000):
    make_sample(hp, hc)

Define the same functions but in `ml4gw` land

In [None]:
import torch
from ml4gw import distributions, gw, transforms
from ml4gw.utils.slicing import sample_kernels


class Ml4gwAugmenter(torch.nn.Module):
    """
    Module for performing our preprocessing augmentations
    in real-time on the GPU with ml4gw: PSD estimation,
    projection of waveform polarizations onto interferometer
    responses, rescaling of those responses to sampled
    target SNRs, and sampling of kernels from them.

    Args:
        ifos:
            Names of the interferometers to project
            waveforms onto
        kernel_length:
            Length of the kernels. Multiplied by
            `sample_rate` to get `kernel_size` in samples
        fduration:
            Passed to the whitening transform. Its length
            in samples also gets added to `kernel_size` to
            get the `window_size` used by `sample_kernels`
        psd_length:
            Multiplied by `sample_rate` to get `psd_size`
            in samples
        sample_rate:
            Rate at which the data is sampled
        fftlength:
            FFT length passed to the spectral density transform
        highpass:
            Highpass frequency passed to the whitening
            transform and to the network SNR computation
    """

    def __init__(
        self,
        ifos: list[str],
        kernel_length: float,
        fduration: float,
        psd_length: float,
        sample_rate: float,
        fftlength: float,
        highpass: float = 32
    ) -> None:
        super().__init__()

        # torch.nn.Module has no `hparams` attribute (that is a
        # Lightning feature), so explicitly keep the values that
        # the methods below need
        self.sample_rate = sample_rate
        self.highpass = highpass

        # real-time transformations defined as torch Modules
        self.spectral_density = transforms.SpectralDensity(
            sample_rate, fftlength, average="median", fast=True
        )
        self.whitener = transforms.Whiten(fduration, sample_rate, highpass=highpass)

        # get some geometry information about
        # the interferometers we're going to project to
        detector_tensors, vertices = gw.get_ifo_geometry(*ifos)
        self.register_buffer("detector_tensors", detector_tensors)
        self.register_buffer("detector_vertices", vertices)

        # define some sky parameter distributions
        self.declination = distributions.Cosine()
        self.polarization = distributions.Uniform(0, torch.pi)
        self.phi = distributions.Uniform(-torch.pi, torch.pi)  # relative RAs of detector and source

        # rather than sample distances, we'll sample target SNRs.
        # This way we can ensure we train our network on
        # signals that are actually detectable. We'll use a distribution
        # that looks roughly like our sampled SNR distribution
        self.snr = distributions.PowerLaw(4, 100, 3)

        # up front let's define some properties in units of samples
        self.kernel_size = int(kernel_length * sample_rate)
        self.window_size = self.kernel_size + int(fduration * sample_rate)
        self.psd_size = int(psd_length * sample_rate)

    def project_waveforms(self, hc: torch.Tensor, hp: torch.Tensor) -> torch.Tensor:
        """
        Project a batch of polarizations onto the interferometers
        using randomly sampled sky parameters.

        Args:
            hc: Batch of cross polarizations
            hp: Batch of plus polarizations
        Returns:
            The output of `gw.compute_observed_strain`
        """
        # sample one set of sky parameters per waveform, matching
        # the device and dtype of the waveforms
        N = len(hc)
        declination = self.declination(N).to(hc)
        polarization = self.polarization(N).to(hc)
        phi = self.phi(N).to(hc)

        # project to interferometer response
        return gw.compute_observed_strain(
            declination,
            polarization,
            phi,
            detector_tensors=self.detector_tensors,
            detector_vertices=self.detector_vertices,
            sample_rate=self.sample_rate,
            cross=hc,
            plus=hp
        )

    def rescale_snrs(self, responses: torch.Tensor, psd: torch.Tensor) -> torch.Tensor:
        """
        Rescale each response so that its network SNR with
        respect to `psd` matches a target sampled from `self.snr`.

        Args:
            responses: Batch of interferometer responses
            psd: PSDs to compute the SNRs against
        Returns:
            `responses` multiplied by one weight per batch element
        """
        # make sure everything has the same number of frequency bins
        num_freqs = int(responses.size(-1) // 2) + 1
        if psd.size(-1) != num_freqs:
            psd = torch.nn.functional.interpolate(psd, size=(num_freqs,), mode="linear")
        snrs = gw.compute_network_snr(
            responses.double(), psd, self.sample_rate, self.highpass
        )

        N = len(responses)
        target_snrs = self.snr(N).to(snrs.device)
        weights = target_snrs / snrs
        return responses * weights.view(-1, 1, 1)

    def sample_kernels(self, responses: torch.Tensor) -> torch.Tensor:
        """
        Sample windows of `window_size` samples from the responses.

        Args:
            responses: Batch of interferometer responses
        Returns:
            The output of ml4gw's `sample_kernels` for the
            cropped and padded responses
        """
        # keep only the last `window_size` samples of each waveform;
        # random views of these get sliced off below so that signals
        # can be injected in arbitrary positions
        responses = responses[:, :, -self.window_size:]

        # pad so that at least half the kernel always contains signals
        pad = [0, int(self.window_size // 2)]
        responses = torch.nn.functional.pad(responses, pad)

        # this resolves to the function imported from
        # ml4gw.utils.slicing, not to this method
        return sample_kernels(responses, self.window_size, coincident=True)

Build the module and move it to GPU

In [None]:
# `fftlength` is a required argument; use the same 2 second FFT
# length as the pycbc PSD estimate so the implementations are comparable
model = Ml4gwAugmenter(
    ifos,
    kernel_length,
    fduration,
    psd_length=8,
    sample_rate=sample_rate,
    fftlength=2,
    highpass=32,
).to("cuda")

Define a function that will run through the relevant benchmarking as a function of batch size.

In [None]:
# TODO: this also reflects the ADASS demo data structure. Again,
# it should be trivial to transition to the aframe data structure
def get_polarization(f, polar, batch_size):
    """
    Read the first `batch_size` entries of the `polar` training
    polarization from the open file `f` and move them to the GPU.
    """
    dataset = f["train/polarizations"][polar]
    waveforms = torch.Tensor(dataset[:batch_size])
    return waveforms.to("cuda")


@torch.no_grad()
def ml4gw_benchmark(batch_size, N=1000):
    """
    Time the ml4gw implementations of PSD estimation, waveform
    projection and SNR rescaling on the GPU.

    Args:
        batch_size:
            Number of samples to process per call
        N:
            Number of timed iterations to run
    Returns:
        A StopWatch whose `intervals` hold the per-call durations
        under the keys "SpectralDensity", "project_waveforms"
        and "rescale_snrs"
    """

    def synchronized(fn):
        # CUDA kernels are launched asynchronously, so without waiting
        # for the device the stopwatch would mostly measure how long
        # it takes to queue the work rather than to execute it
        def run(*args, **kwargs):
            out = fn(*args, **kwargs)
            torch.cuda.synchronize()
            return out

        # StopWatch keys its timings on the callable's name, falling
        # back to the class name, so preserve that of the wrapped object
        run.__name__ = getattr(fn, "__name__", type(fn).__name__)
        return run

    # move a single copy of the background to the GPU and repeat it
    # there, rather than stacking `batch_size` copies in host memory
    X = torch.Tensor(np.stack([tsd[i].value for i in ifos])).to("cuda")
    X = X[None].repeat(batch_size, 1, 1)

    with h5py.File("data/signals.hdf5", "r") as f:
        hp = get_polarization(f, "plus", batch_size)
        hc = get_polarization(f, "cross", batch_size)

    ml4gw_timer = StopWatch()
    spec = ml4gw_timer(synchronized(model.spectral_density))
    proj = ml4gw_timer(synchronized(model.project_waveforms))
    resc = ml4gw_timer(synchronized(model.rescale_snrs))

    # make sure the setup above isn't attributed to the first timed call
    torch.cuda.synchronize()
    for _ in trange(N):
        psds = spec(X)
        responses = proj(hc, hp)
        responses = resc(responses, psds)
        # TODO: add slice, add, run through whitener
    return ml4gw_timer

Do the benchmarking over the various batch sizes.

In [None]:
from collections import defaultdict
# both map step name -> batch size -> statistic of the call durations
mean = defaultdict(dict)
var = defaultdict(dict)
for batch_size in (32, 128, 512, 2048):
    print(batch_size)
    timer = ml4gw_benchmark(batch_size)
    for step, durations in timer.intervals.items():
        mean[step][batch_size] = np.mean(durations)
        var[step][batch_size] = np.var(durations)

Now plot everything in a logarithmic bar chart.

In [None]:
from bokeh.models import FactorRange
from bokeh.io import output_notebook, show
from bokeh.plotting import figure
# bokeh's palettes live in the `bokeh.palettes` module (plural)
# NOTE(review): confirm that `Bright8` exists in the installed bokeh
# version; the Bright family may only be defined up to 7 colors
from bokeh.palettes import Bright8 as palette

output_notebook()

factors = ["CPU, no batching"] + [f"GPU, batch size {i}" for i in [32, 128, 512, 2048]]

# TODO: feel free to use some of my ADASS plotting utils
# to make this figure prettier
p = figure(
    x_range=FactorRange(*factors),
    background_fill_alpha=1.0,
    x_axis_label=r"\text{Implementation}",
    y_axis_label=r"\text{Average Execution Time [ms]}",
    y_axis_type="log"
)

# (key in the ml4gw timings, key in the pycbc timings, legend label)
maps = [
    ("SpectralDensity", "compute_psds", "Compute PSD"),
    ("project_waveforms", "compute_ht", "Project Waveforms"),
    ("rescale_snrs", "compute_snr", "Rescale SNR")
]

# build one stacked bar per implementation, with one segment per step
x = []
bottom = []
top = []
labels = []
colors = []
for factor in factors:
    is_cpu = factor.startswith("CPU")
    if not is_cpu:
        # the batch size is the last word of the factor name
        bs = int(factor.split()[-1])

    # bars can't start at 0 on a logarithmic axis
    start = 0.1
    for (gpu_key, cpu_key, label), color in zip(maps, palette[-2::-1]):
        if is_cpu:
            # mean time per call in ms; each call produces one sample
            t = np.mean(vanilla_timer.intervals[cpu_key]) * 1000
        else:
            # mean time per call in ms, divided by the batch size
            # to get the time spent per sample
            t = mean[gpu_key][bs] * 1000 / bs

        x.append(factor)
        bottom.append(start)
        top.append(start + t)
        labels.append(label)
        colors.append(color)
        start += t

source = dict(
    x=x,
    top=top,
    bottom=bottom,
    label=labels,
    color=colors
)
p.vbar(
    x="x",
    top="top",
    bottom="bottom",
    width=0.9,
    legend_field="label",
    fill_color="color",
    fill_alpha=0.8,
    line_color="#333333",
    line_width=0.5,
    source=source
)
show(p)