# Code taken directly from the [`ml4gw`](https://github.com/ML4GW/ml4gw) tutorial

In [1]:
import torch
import matplotlib.pyplot as plt

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
# Length, in seconds, of every time-domain waveform we generate
waveform_duration = 8
# Sample rate shared by all of the data used in this notebook
sample_rate = 2048

# Minimum, maximum, and reference frequencies for waveform generation
f_min = 10
f_max = 1024
f_ref = 20

# Derived sizes: samples per waveform, the matching number of one-sided
# frequency bins, and the Nyquist frequency
num_samples = int(sample_rate * waveform_duration)
num_freqs = 1 + num_samples // 2
nyquist = sample_rate / 2

# Evenly spaced frequency grid from 0 up to Nyquist on which the waveform
# is evaluated (at the moment, only frequency-domain approximants have
# been implemented)
frequencies = torch.linspace(0, nyquist, num_freqs).to(device)

# Boolean mask selecting the bins with f_min <= f < f_max
above_f_min = frequencies >= f_min
below_f_max = frequencies < f_max
freq_mask = above_f_min & below_f_max

In [3]:
from ml4gw.distributions import PowerLaw, Sine, Cosine, DeltaFunction
from torch.distributions import Uniform

# How many waveforms to simulate. Keep this around 100 on CPU; a GPU can
# comfortably handle more
num_waveforms = 1000

# For an equal-mass binary the chirp mass equals (1/4)**(3/5) times the
# total mass, so the chirp-mass bounds below correspond to total masses of
# 20 and 800 (units presumably solar masses -- TODO confirm)
equal_mass_factor = (1/4)**(3/5)

# Prior distribution for each waveform parameter. DeltaFunction entries pin
# a parameter to a single value (equal masses, zero spins, zero phase)
param_dict = {
    "chirp_mass": PowerLaw(equal_mass_factor * 20, equal_mass_factor * 800, -2.35),
    "mass_ratio": DeltaFunction(1),
    "chi1": DeltaFunction(0),
    "chi2": DeltaFunction(0),
    "distance": PowerLaw(100, 1000, 2),
    "phic": DeltaFunction(0),
    "inclination": Sine(),
}

# Alternative prior, kept for reference: varies the mass ratio and both
# spins instead of pinning them. Uncomment to use it in place of the above.
# param_dict = {
#     "chirp_mass": PowerLaw(10, 100, -2.35),
#     "mass_ratio": Uniform(0.125, 0.999),
#     "chi1": Uniform(-0.999, 0.999),
#     "chi2": Uniform(-0.999, 0.999),
#     "distance": PowerLaw(100, 1000, 2),
#     "phic": DeltaFunction(0),
#     "inclination": Sine(),
# }

# Draw num_waveforms values from every prior (in dictionary order) and move
# the samples to the selected device
params = {}
for name, dist in param_dict.items():
    params[name] = dist.sample((num_waveforms,)).to(device)

In [4]:
from ml4gw.waveforms import IMRPhenomD

# Instantiate the IMRPhenomD approximant on the selected device
approximant = IMRPhenomD().to(device)

# Evaluate the approximant on the in-band frequency bins only. Given the
# frequency array, the reference frequency, and the sampled waveform
# parameters, it returns the cross and plus polarizations
masked_frequencies = frequencies[freq_mask]
hc_f, hp_f = approximant(f=masked_frequencies, f_ref=f_ref, **params)
print(
    f"Cross polarization frequency shape: {hc_f.shape}, "
    f"Plus polarization frequency shape: {hp_f.shape}"
)

Cross polarization frequency shape: torch.Size([1000, 8112]), Plus polarization frequency shape: torch.Size([1000, 8112])


In [5]:
# Allocate zero-filled one-sided spectra, one row per waveform, spanning
# all num_freqs bins. With this frequency spacing the inverse FFT yields a
# time series that is waveform_duration seconds long
spectrum_shape = (hc_f.shape[0], num_freqs)
hc_spectrum = torch.zeros(spectrum_shape, dtype=hc_f.dtype, device=device)
hp_spectrum = torch.zeros_like(hc_spectrum)

# Copy the generated polarizations into the in-band bins; every bin
# outside the mask stays at zero
hc_spectrum[:, freq_mask] = hc_f
hp_spectrum[:, freq_mask] = hp_f

# Inverse real FFT back to the time domain, scaled by the sample rate
hc = torch.fft.irfft(hc_spectrum) * sample_rate
hp = torch.fft.irfft(hp_spectrum) * sample_rate

# The coalescence point is placed at the right edge, so roll each waveform
# to the left by ringdown_duration seconds to leave room for the ringdown
ringdown_duration = 0.5
ringdown_size = int(ringdown_duration * sample_rate)
hc, hp = (
    torch.roll(polarization, shifts=-ringdown_size, dims=-1)
    for polarization in (hc, hp)
)
print(
    f"Cross polarization strain shape: {hc.shape}, "
    f"Plus polarization strain shape: {hp.shape}"
)

Cross polarization strain shape: torch.Size([1000, 16384]), Plus polarization strain shape: torch.Size([1000, 16384])


In [6]:
from ml4gw.gw import get_ifo_geometry, compute_observed_strain

# Priors over sky location and polarization angle
dec = Cosine()
psi = Uniform(0, torch.pi)
phi = Uniform(-torch.pi, torch.pi)

# Detectors to project onto. ml4gw also ships the interferometer geometry
# for V1 and K1
ifos = ["H1", "L1"]
tensors, vertices = get_ifo_geometry(*ifos)

# Draw one set of sky parameters per waveform (dec, then psi, then phi)
draw_shape = (num_waveforms,)
dec_samples = dec.sample(draw_shape).to(device)
psi_samples = psi.sample(draw_shape).to(device)
phi_samples = phi.sample(draw_shape).to(device)

# Combine the detector geometry with the polarizations and sky parameters
# to obtain the strain observed in each interferometer
waveforms = compute_observed_strain(
    dec=dec_samples,
    psi=psi_samples,
    phi=phi_samples,
    detector_tensors=tensors.to(device),
    detector_vertices=vertices.to(device),
    sample_rate=sample_rate,
    cross=hc,
    plus=hp,
)
print(f"Waveform shape: {waveforms.shape}")

Waveform shape: torch.Size([1000, 2, 16384])


In [7]:
from ml4gw.transforms import SpectralDensity
import h5py

# Configure the PSD estimator: FFTs of length fftlength, overlap left at
# None, and median averaging
fftlength = 2
psd_settings = dict(
    sample_rate=sample_rate,
    fftlength=fftlength,
    overlap=None,
    average="median",
)
spectral_density = SpectralDensity(**psd_settings).to(device)

# This is H1 and L1 background data from the O3 Observation run.
# NOTE: site-specific absolute path; adjust it for your own environment
background_file = "/data/p_dsi/ligo/chattec-dgx01/chattec/LIGO/ligo_data/ml4gw_data_test/background-1240658942-9110.hdf5"

# Read one strain series per interferometer (datasets are keyed by the
# interferometer name), then stack them into a single tensor on the device
with h5py.File(background_file, "r") as f:
    strain_per_ifo = [torch.Tensor(f[ifo][:]) for ifo in ifos]
background = torch.stack(strain_per_ifo).to(device)

# Note the cast to double before estimating the PSD
psd = spectral_density(background.double())
print(f"Power Spectral Density shape: {psd.shape}")

Power Spectral Density shape: torch.Size([2, 2049])


In [8]:
from ml4gw.gw import compute_ifo_snr, compute_network_snr

# The PSD was estimated with a different number of frequency bins than the
# waveforms use, so resample it onto num_freqs bins when the sizes differ
if psd.shape[-1] != num_freqs:
    # Prepend singleton dimensions until the PSD is 3-dimensional, which is
    # the layout passed to the linear interpolation below
    missing_dims = 3 - psd.ndim
    if missing_dims > 0:
        psd = psd[(None,) * missing_dims]
    psd = torch.nn.functional.interpolate(
        psd,
        size=(num_freqs,),
        mode="linear",
    )

In [9]:
from ml4gw.gw import reweight_snrs

# Draw one target SNR per waveform from a power-law prior
snr_prior = PowerLaw(12, 100, -3)
target_snrs = snr_prior.sample((num_waveforms,)).to(device)

# Rescale every waveform by the ratio of its target SNR to its current SNR,
# measured against the PSD above and high-passed at f_min
waveforms = reweight_snrs(
    responses=waveforms,
    target_snrs=target_snrs,
    psd=psd,
    sample_rate=sample_rate,
    highpass=f_min,
)

In [10]:
from ml4gw.dataloading import Hdf5TimeSeriesDataset
from pathlib import Path

# Length, in seconds, of the data used to estimate each PSD
psd_length = 16
psd_size = int(psd_length * sample_rate)

# Length of filter. A segment of length fduration / 2
# will be cropped from either side after whitening
fduration = 2

# Length, in seconds, of the window of data we'll feed to our network.
# kernel_size is derived from kernel_length (rather than repeating the
# literal) so the two can never drift apart
kernel_length = 1.5
kernel_size = int(kernel_length * sample_rate)

# Total length of data to sample per example: PSD segment, plus the part
# cropped by whitening, plus the network kernel
window_length = psd_length + fduration + kernel_length

# Directory holding the background HDF5 files.
# NOTE: site-specific absolute path; adjust it for your own environment.
# iterdir() yields entries in arbitrary order, so sort for a reproducible
# file list
background_dir = Path("/data/p_dsi/ligo/chattec-dgx01/chattec/LIGO/ligo_data/ml4gw_data_test")
fnames = sorted(background_dir.iterdir())
dataloader = Hdf5TimeSeriesDataset(
    fnames=fnames,
    channels=ifos,
    kernel_size=int(window_length * sample_rate),
    # Grab twice as many background samples as we have waveforms
    batch_size=2 * num_waveforms,
    batches_per_epoch=1,
    coincident=False,
)

# Only a single batch is needed, so take the first one directly instead of
# materializing every batch into a list
background_samples = next(iter(dataloader)).to(device)
print(f"Background samples shape: {background_samples.shape}")

Background samples shape: torch.Size([2000, 2, 39936])


In [11]:
from ml4gw.transforms import Whiten

# Whitening transform; it uses the same filter duration, sample rate, and
# high-pass frequency defined earlier
whiten = Whiten(
    fduration=fduration,
    sample_rate=sample_rate,
    highpass=f_min,
).to(device)

# Estimate one PSD per sample from its first psd_length seconds, reusing
# the SpectralDensity module defined earlier (cast to double first)
psd_segment = background_samples[..., :psd_size]
psd = spectral_density(psd_segment.double())
print(f"PSD shape: {psd.shape}")

# Everything after the PSD segment becomes the input kernel, which is then
# whitened with the PSDs just computed
kernel = background_samples[..., psd_size:]
whitened_kernel = whiten(kernel, psd)
for label, tensor in (("Kernel", kernel), ("Whitened kernel", whitened_kernel)):
    print(f"{label} shape: {tensor.shape}")

PSD shape: torch.Size([2000, 2, 2049])
Kernel shape: torch.Size([2000, 2, 7168])
Whitened kernel shape: torch.Size([2000, 2, 3072])


In [12]:
# Number of samples in fduration / 2 seconds, i.e. the amount cropped from
# each side by whitening
pad = int(fduration / 2 * sample_rate)

# Work on a copy so the pure-background kernel stays untouched
injected = kernel.detach().clone()

# Add the last kernel_size samples of each waveform to every other
# background sample, restricted to the central region between the pads
center = slice(pad, -pad)
signal_tail = waveforms[..., -kernel_size:]
injected[::2, :, center] += signal_tail

# Whiten with the same PSDs as the background-only kernel
whitened_injected = whiten(injected, psd)

In [13]:
# Binary labels: 1 for every other sample starting with the first (the
# ones that received an injection), 0 for the rest
n_samples = len(injected)
y = torch.zeros(n_samples)
y[0::2] = 1

# Write the whitened inputs ("X") and their labels ("y") to one HDF5 file
with h5py.File("validation_dataset.hdf5", "w") as f:
    for name, values in (("X", whitened_injected.cpu()), ("y", y)):
        f.create_dataset(name, data=values)