In [1]:
import copy
import os
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

import h5py
import numpy as np
import torch
from bilby.core.prior import Constraint, Cosine, PriorDict, Uniform
from bilby.gw.conversion import convert_to_lal_binary_black_hole_parameters
from bilby.gw.prior import UniformSourceFrame
from bilby.gw.source import lal_binary_black_hole
from bilby.gw.waveform_generator import WaveformGenerator
from bokeh.io import output_notebook, show
from bokeh.palettes import Dark2_8 as palette
from bokeh.plotting import figure
from gwpy.signal.filter_design import fir_from_transfer
from gwpy.timeseries import TimeSeries
from rich.progress import track

from ml4gw.dataloading import InMemoryDataset
from ml4gw.distributions import Cosine as CosineSampler
from ml4gw.distributions import LogNormal as LogNormalSampler
from ml4gw.distributions import Uniform as UniformSampler
from ml4gw.transforms import RandomWaveformInjection

# render bokeh figures (e.g. the `show(p)` call further down) inline in this notebook
output_notebook()

In [2]:
# Reproducibility: seed the global numpy and torch random number generators
# before any of the cells below draw random samples.
# NOTE(review): recent bilby releases draw prior samples from their own
# generator (bilby.core.utils.random), which these calls may not cover — confirm.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Data parameters
START = 1262607622  # GPS start time of the background segment
DURATION = 12288  # seconds of background strain to fetch per ifo
SAMPLE_RATE = 2048  # Hz; background is resampled to this rate
KERNEL_LENGTH = 2  # seconds of data per kernel drawn by the loaders
HIGHPASS = 20  # Hz; highpass used by the injector and whitening filter

# Injection parameters
WAVEFORM_DURATION = 8  # seconds per generated waveform
NUM_WAVEFORMS = 20000
REFERENCE_FREQUENCY = 50  # Hz, passed to the waveform approximant
MINIMUM_FREQUENCY = 20  # Hz, passed to the waveform approximant
INJECTION_FRACTION = 0.5  # passed to the injector as `prob`
# arguments of the injector's LogNormal SNR sampler
MEAN_SNR = 15
STD_SNR = 15
MIN_SNR = 1

# Optimization parameters
VALID_FRAC = 0.25  # fraction of background and waveforms held out for validation

In [3]:
def fetch_background(ifo: str) -> np.ndarray:
    """Download open strain data for detector `{ifo}1` spanning START to
    START + DURATION and return it resampled to SAMPLE_RATE as a numpy array."""
    strain = TimeSeries.fetch_open_data(
        f"{ifo}1", start=START, end=START + DURATION
    )
    return strain.resample(SAMPLE_RATE).value


# one row per interferometer, in the order H1, L1
background = np.stack([fetch_background(ifo) for ifo in "HL"])

In [4]:
# The leading (1 - VALID_FRAC) of the background (along time) is used for
# training; the remaining tail is held out for validation.
train_length = int((1 - VALID_FRAC) * SAMPLE_RATE * DURATION)
train_background = background[..., :train_length]
valid_background = background[..., train_length:]

In [5]:
def add_mass_ratio(parameters):
    """Conversion function for `prior_dict`.

    Returns a copy of `parameters` with `mass_ratio = mass_2 / mass_1` added.
    bilby only evaluates a `Constraint` prior on keys that appear in the
    output of the PriorDict's conversion function. Without this function the
    `mass_ratio` constraint below would never be checked, and masses would be
    drawn from the unconstrained mass_1 / mass_2 priors.
    """
    converted = dict(parameters)  # copy so the drawn sample isn't mutated
    converted["mass_ratio"] = parameters["mass_2"] / parameters["mass_1"]
    return converted


prior_dict = PriorDict(
    dict(
        mass_1=Uniform(
            name="mass_1", minimum=5, maximum=100, unit=r"$M_{\odot}$"
        ),
        mass_2=Uniform(
            name="mass_2", minimum=5, maximum=100, unit=r"$M_{\odot}$"
        ),
        # symmetric bounds: rejects pairs where either mass is more than
        # 5x the other, regardless of which one is larger
        mass_ratio=Constraint(name="mass_ratio", minimum=0.2, maximum=5.0),
        luminosity_distance=UniformSourceFrame(
            name="luminosity_distance", minimum=100, maximum=3000, unit="Mpc"
        ),
        dec=Cosine(name="dec"),
        ra=Uniform(
            name="ra", minimum=0, maximum=2 * np.pi, boundary="periodic"
        ),
        # remaining parameters are held fixed
        theta_jn=0,
        psi=0,
        phase=0,
        a_1=0,
        a_2=0,
        tilt_1=0,
        tilt_2=0,
        phi_12=0,
        phi_jl=0,
    ),
    conversion_function=add_mass_ratio,
)

In [6]:
# Approximant settings forwarded to bilby's lal_binary_black_hole source model.
waveform_arguments = dict(
    waveform_approximant="IMRPhenomPv2",
    reference_frequency=REFERENCE_FREQUENCY,
    minimum_frequency=MINIMUM_FREQUENCY,
)

# Frequency-domain BBH model evaluated over WAVEFORM_DURATION seconds at
# SAMPLE_RATE; parameters are converted to the LAL BBH convention first.
waveform_generator = WaveformGenerator(
    duration=WAVEFORM_DURATION,
    sampling_frequency=SAMPLE_RATE,
    frequency_domain_source_model=lal_binary_black_hole,
    parameter_conversion=convert_to_lal_binary_black_hole_parameters,
    waveform_arguments=waveform_arguments,
)


# bilby's WaveformGenerator keeps per-call state on the instance:
# `time_domain_strain(parameters)` stores the parameters on the generator
# (and compares/stores them in an internal cache) before evaluating the
# source model with whatever parameters the instance currently holds.
# The waveform bank is built from several threads, so sharing
# `waveform_generator` between them can mix one thread's parameters into
# another thread's waveform. Each thread therefore works on its own copy.
_generator_state = threading.local()


def _thread_waveform_generator():
    """Return the calling thread's private copy of `waveform_generator`,
    creating it on first use."""
    generator = getattr(_generator_state, "generator", None)
    if generator is None:
        generator = copy.deepcopy(waveform_generator)
        _generator_state.generator = generator
    return generator


def generate_waveform(i):
    """Generate the centered polarizations for the `i`-th prior sample.

    Reads entry `i` of every array in the global `params` dict, which must be
    populated before this is called (the waveform-generation cell does so).
    Safe to call concurrently from multiple threads.

    Args:
        i: index into the arrays of `params`.

    Returns:
        A numpy array with one row per polarization, stacked in sorted name
        order (the cache cell stores row 0 as "cross" and row 1 as "plus"),
        rolled by half of WAVEFORM_DURATION along the time axis.
    """
    row = {k: v[i] for k, v in params.items()}
    polarizations = _thread_waveform_generator().time_domain_strain(row)
    polarization_names = sorted(polarizations.keys())
    polarizations = np.stack([polarizations[p] for p in polarization_names])

    # center so that coalescence time is middle sample
    dt = WAVEFORM_DURATION / 2
    polarizations = np.roll(polarizations, int(dt * SAMPLE_RATE), axis=-1)
    return polarizations

14:17 bilby INFO    : Waveform generator initiated with
  frequency_domain_source_model: bilby.gw.source.lal_binary_black_hole
  time_domain_source_model: None
  parameter_conversion: bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters


Generating these waveforms is time consuming, so we cache them to a file and skip generation on repeated runs. If the cache file doesn't exist yet, we generate the waveforms with multiple threads to cut down the wait.

In [7]:
WAVEFORMS_FILE = "waveforms.h5"
# layout of `waveforms`: (waveform index, polarization [cross, plus], time sample)
WAVEFORMS_SHAPE = (NUM_WAVEFORMS, 2, int(SAMPLE_RATE * WAVEFORM_DURATION))

# Reuse the cache only if it is readable and matches the current settings;
# a cache written with a different NUM_WAVEFORMS or WAVEFORM_DURATION would
# otherwise be loaded silently and break the splits and plots below.
waveforms = params = None
if os.path.exists(WAVEFORMS_FILE):
    try:
        with h5py.File(WAVEFORMS_FILE, "r") as f:
            cached_waveforms = np.stack([f["cross"][:], f["plus"][:]], axis=1)
            cached_params = {k: v[:] for k, v in f["params"].items()}
    except (OSError, KeyError) as exc:
        print(f"Ignoring unreadable cache file {WAVEFORMS_FILE}: {exc}")
    else:
        if cached_waveforms.shape == WAVEFORMS_SHAPE:
            print("Using local cache file")
            waveforms, params = cached_waveforms, cached_params
        else:
            print(
                f"Ignoring cache file {WAVEFORMS_FILE} with waveform shape "
                f"{cached_waveforms.shape}, expected {WAVEFORMS_SHAPE}"
            )

if waveforms is None:
    waveforms = np.zeros(WAVEFORMS_SHAPE)
    # `generate_waveform` reads this global, so it must be set before the pool runs
    params = prior_dict.sample(NUM_WAVEFORMS)
    with ThreadPoolExecutor(4) as pool:
        it = enumerate(pool.map(generate_waveform, range(NUM_WAVEFORMS)))
        for i, polarizations in track(
            it, "Generating waveforms", total=NUM_WAVEFORMS
        ):
            waveforms[i] = polarizations

    # Write to a temporary file and move it into place only once it is
    # complete, so an interrupted write can't leave behind a truncated
    # cache that later runs would try to load.
    tmp_file = WAVEFORMS_FILE + ".tmp"
    with h5py.File(tmp_file, "w") as f:
        f["cross"] = waveforms[:, 0]
        f["plus"] = waveforms[:, 1]
        params_group = f.create_group("params")
        for p, values in params.items():
            params_group[p] = values
    os.replace(tmp_file, WAVEFORMS_FILE)

Output()

KeyboardInterrupt: 

In [None]:
# Time axis in seconds, shifted so that 0 is the middle sample, where
# generate_waveform places coalescence.
t = np.arange(0, WAVEFORM_DURATION, 1 / SAMPLE_RATE) - WAVEFORM_DURATION / 2

p = figure(
    width=750,
    height=300,
    x_axis_label="Time from coalescence [s]",
    y_axis_label="Gravitational wave strain [unitless]",
    tools="",
)
# overlay both polarizations of a single example waveform (index 1)
for i, (name, color) in enumerate(zip(["cross", "plus"], palette)):
    p.line(
        t,
        waveforms[1, i],
        line_color=color,
        line_alpha=0.8,
        line_width=1.5,
        legend_label=name,
    )
# clicking a legend entry toggles that polarization's line
p.legend.click_policy = "hide"
show(p)

In [None]:
# Split on the number of waveforms actually in memory (which may have been
# loaded from the cache file) rather than the NUM_WAVEFORMS setting, so the
# split stays consistent with the data.
num_train = int((1 - VALID_FRAC) * len(waveforms))
train_waveforms, valid_waveforms = waveforms[:num_train], waveforms[num_train:]

In [None]:
# both loaders draw kernels of KERNEL_LENGTH seconds
kernel_size = int(KERNEL_LENGTH * SAMPLE_RATE)

# training: shuffled, fixed number of batches per epoch; coincident=False
# presumably samples each ifo's kernel independently — confirm in ml4gw docs
train_loader = InMemoryDataset(
    train_background,
    kernel_size=kernel_size,
    batch_size=256,
    coincident=False,
    shuffle=True,
    batches_per_epoch=100,
)

# validation: unshuffled, coincident kernels across ifos, larger batches
valid_loader = InMemoryDataset(
    valid_background,
    kernel_size=kernel_size,
    batch_size=1024,
    coincident=True,
    shuffle=False,
)

In [None]:
# Training waveforms are stored with polarizations along axis 1 as
# [cross, plus]; background rows are [H1, L1].
cross_train, plus_train = train_waveforms[:, 0], train_waveforms[:, 1]
h1_train, l1_train = train_background[0], train_background[1]

# Injector sampling declination, polarization angle, phase and SNR for each
# injection; `prob` is presumably the fraction of samples that get a signal.
injector = RandomWaveformInjection(
    sample_rate=SAMPLE_RATE,
    ifos=["H1", "L1"],
    dec=CosineSampler(),
    psi=UniformSampler(0, np.pi),
    phi=UniformSampler(-np.pi, np.pi),
    snr=LogNormalSampler(MEAN_SNR, STD_SNR, MIN_SNR),
    highpass=HIGHPASS,
    prob=INJECTION_FRACTION,
    trigger_offset=0.5,
    plus=plus_train,
    cross=cross_train,
)
injector.fit(H1=h1_train, L1=l1_train)

In [None]:
# Draw a single sample from the fitted injector. The unpacking below assumes
# `sampled_params[0]` holds four values in the order (dec, psi, phi, snr).
# NOTE(review): confirm this return structure against ml4gw's
# RandomWaveformInjection.sample API.
waveform, sampled_params = injector.sample(1)
dec, psi, phi, snr = sampled_params[0]

In [None]:
class WhiteningTransform(torch.nn.Module):
    """Whiten fixed-length multi-ifo kernels with per-ifo FIR filters.

    `fit` builds one time-domain whitening filter per interferometer from
    background data. `forward` applies these filters as a grouped 1D
    convolution. The first and last (fduration / 2) seconds of each kernel
    are corrupted as the whitening filter settles in, so they are cropped.
    The output that is ultimately passed to the network is therefore
    (kernel_length - fduration) seconds long.
    """

    def __init__(
        self,
        num_ifos: int,
        sample_rate: float,
        kernel_length: float,
        fftlength: float = 2,
        highpass: Optional[float] = None,
        fduration: Optional[float] = None,
    ) -> None:
        """
        Args:
            num_ifos:
                Number of interferometer channels. `fit` must be given
                exactly this many backgrounds.
            sample_rate:
                Sample rate of the data, in Hz.
            kernel_length:
                Length, in seconds, of the kernels passed to `forward`.
            fftlength:
                FFT length, in seconds, of the median-averaged background
                ASD estimate computed in `fit`.
            highpass:
                Optional highpass frequency in Hz. It is converted to a
                number of frequency bins (`ncorner`) and passed to gwpy's
                `fir_from_transfer`.
            fduration:
                Length of the whitening filter, in seconds. Defaults to
                `kernel_length / 2`.

        Raises:
            ValueError:
                If `fduration` is so short that the filter has no edge-taper
                samples, or so long that cropping (fduration / 2) seconds
                from each end would leave no output.
        """
        super().__init__()
        self.num_ifos = num_ifos
        self.sample_rate = sample_rate
        self.kernel_length = kernel_length
        self.fftlength = fftlength

        # frequency resolution of a single kernel; the background ASD is
        # interpolated onto this grid in `fit`
        self.df = 1 / kernel_length
        self.ncorner = int(highpass / self.df) if highpass else 0
        self.fduration = fduration or kernel_length / 2

        self.kernel_size = int(kernel_length * sample_rate)
        # number of samples of corrupted data at each end
        # due to settling in of whitening filter
        self.crop_samples = int((self.fduration / 2) * self.sample_rate)
        self.ntaps = int(self.fduration * self.sample_rate)
        self.pad = (self.ntaps - 1) // 2

        # `forward` tapers `pad` samples at each edge. With pad == 0, the
        # `-self.pad:` slice there would select the entire kernel, and the
        # crop would remove everything.
        if self.pad < 1:
            raise ValueError(
                "fduration={} s is too short to build a whitening filter "
                "at sample_rate={} Hz".format(self.fduration, sample_rate)
            )
        # cropping crop_samples from both ends must leave some output
        if 2 * self.crop_samples >= self.kernel_size:
            raise ValueError(
                "fduration={} s must be shorter than kernel_length={} s, "
                "otherwise cropping leaves no output".format(
                    self.fduration, kernel_length
                )
            )

        # initialize the parameter with 0s, then fill it out later
        tdf = torch.zeros((num_ifos, 1, self.ntaps - 1))
        self.register_buffer("time_domain_filter", tdf)

        window = torch.hann_window(self.ntaps)
        self.register_buffer("window", window, persistent=False)

    def fit(self, **ifos: np.ndarray) -> None:
        """Build one whitening time-domain filter per interferometer.

        Args:
            **ifos:
                Background strain arrays keyed by ifo name, sampled at
                `sample_rate`. Filters are stored in the order the keywords
                are passed, which must match the channel order used in
                `forward`.

        Raises:
            ValueError:
                If the number of backgrounds differs from `num_ifos`, or a
                background ASD contains zeros (its inverse would be infinite).
        """
        if len(ifos) != self.num_ifos:
            raise ValueError(
                "Expected to fit whitening transform on {} backgrounds, "
                "but was passed {}".format(self.num_ifos, len(ifos))
            )

        tdfs = []
        for x in ifos.values():
            ts = TimeSeries(x, dt=1 / self.sample_rate)
            asd = ts.asd(
                fftlength=self.fftlength, window="hann", method="median"
            )
            asd = asd.interpolate(self.df).value
            if (asd == 0).any():
                raise ValueError("Found 0 values in background asd")

            # the whitening filter's transfer function is the inverse ASD
            tdf = fir_from_transfer(
                1 / asd,
                ntaps=self.ntaps,
                window="hann",
                ncorner=self.ncorner,
            )
            tdfs.append(tdf)

        # drop the last tap so the stored filter length matches the
        # (ntaps - 1)-sample buffer allocated in __init__
        tdf = torch.tensor(np.stack(tdfs)[:, None, :-1], dtype=torch.float64)
        self.time_domain_filter.copy_(tdf)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Whiten a batch of kernels.

        Args:
            X:
                Tensor indexed as (batch, channel, time). The grouped
                convolution requires `X.shape[1] == num_ifos`. It is
                presumably `kernel_size` samples long — TODO confirm with
                callers.

        Returns:
            The whitened tensor, with `crop_samples` removed from each end
            of the time axis.

        Raises:
            NotImplementedError:
                If the filter is too short relative to the kernel for the
                direct-convolution path.
        """
        # do a constant detrend along the time axis,
        # then taper the edges with the two halves of a Hann window
        X = X - X.mean(axis=-1, keepdims=True)
        X[:, :, : self.pad] *= self.window[: self.pad]
        X[:, :, -self.pad :] *= self.window[-self.pad :]

        nfft = min(8 * self.time_domain_filter.size(-1), self.kernel_size)
        if nfft >= self.kernel_size / 2:
            conv = torch.nn.functional.conv1d(
                X,
                self.time_domain_filter,
                groups=self.num_ifos,
                padding=int(self.pad),
            )

            # crop the beginning and ending fduration / 2
            conv = conv[:, :, self.crop_samples : -self.crop_samples]
        else:
            raise NotImplementedError(
                "An optimal torch implementation of whitening for short "
                "fdurations is not complete. Use a larger fduration "
            )
        # scale by sqrt(2 / sample_rate) == sqrt(2 * dt), mirroring the final
        # normalization in gwpy's TimeSeries.whiten, on which this
        # filter construction (fir_from_transfer of 1 / ASD) is modeled
        return conv * (2 / self.sample_rate) ** 0.5

In [None]:
# Use the GPU when present, but fall back to the CPU so the notebook still
# runs on machines without CUDA instead of crashing here.
device = "cuda" if torch.cuda.is_available() else "cpu"

preprocessor = WhiteningTransform(
    num_ifos=2,
    sample_rate=SAMPLE_RATE,
    kernel_length=KERNEL_LENGTH,
    highpass=HIGHPASS,
)
# filters are built from the training background only; rows are [H1, L1]
preprocessor.fit(H1=train_background[0], L1=train_background[1])
preprocessor.to(device)