In [1]:
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Optional

import h5py
import numpy as np
import torch
from bilby.core.prior import Constraint, Cosine, PriorDict, Uniform
from bilby.gw.conversion import convert_to_lal_binary_black_hole_parameters
from bilby.gw.prior import UniformSourceFrame
from bilby.gw.source import lal_binary_black_hole
from bilby.gw.waveform_generator import WaveformGenerator
from bokeh.io import output_notebook, show
from bokeh.layouts import gridplot
from bokeh.models import LinearAxis, LogAxis, Range1d
from bokeh.palettes import Dark2_8 as palette
from bokeh.plotting import figure
from gwpy.timeseries import TimeSeries
from rich.progress import track

from ml4gw.dataloading import InMemoryDataset
from ml4gw.distributions import Cosine as CosineSampler
from ml4gw.distributions import LogNormal as LogNormalSampler
from ml4gw.distributions import Uniform as UniformSampler
from ml4gw.gw import compute_network_snr
from ml4gw.transforms import RandomWaveformInjection, Whitening
from ml4gw.utils.slicing import slice_kernels
from utils.autoencoder import Autoencoder, ShiftedPearsonLoss
from utils.tracker import MultiThresholdAUROC, Run

output_notebook()

In [2]:
# Output and cache locations (everything lives under ~/bbhnet)
BASE_DIR = Path.home() / "bbhnet" / "results"
DATA_DIR = Path.home() / "bbhnet" / "data"
RUN_NAME = "notebook-run"

# Data parameters
START = 1_262_653_854  # GPS start of the open-data segment that is fetched
DURATION = 2_048  # seconds fetched per detector (end = START + DURATION)
SAMPLE_RATE = 2_048  # Hz, rate the strain is resampled to
KERNEL_LENGTH = 2  # seconds per network input window
HIGHPASS = 20  # highpass cutoff passed to whitening and injection

# Injection parameters
WAVEFORM_DURATION = 8  # seconds per simulated waveform
NUM_WAVEFORMS = 50_000
REFERENCE_FREQUENCY = 50
MINIMUM_FREQUENCY = 20
INJECTION_FRACTION = 0.5  # NOTE(review): not referenced elsewhere here
MEAN_SNR = 15
STD_SNR = 15
MIN_SNR = 1

# Optimization parameters
VALID_FRAC = 0.25
# NOTE(review): the OneCycleLR scheduler below sets its own max_lr, which
# overrides this initial learning rate
LEARNING_RATE = 0.004
BATCH_SIZE = 256
PATIENCE = 50  # NOTE(review): not referenced elsewhere here
MAX_EPOCHS = 100

In [3]:
# Fetch the same open-data segment for Hanford and Livingston, resample
# each to SAMPLE_RATE, and stack them with one row per detector (H, L).
background = np.stack(
    [
        TimeSeries.fetch_open_data(
            f"{ifo}1", start=START, end=START + DURATION
        )
        .resample(SAMPLE_RATE)
        .value
        for ifo in "HL"
    ]
)

In [4]:
# Use the first (1 - VALID_FRAC) of the samples (along time) for training
# and hold out the remainder for validation.
train_length = int((1 - VALID_FRAC) * SAMPLE_RATE * DURATION)
train_background = background[..., :train_length]
valid_background = background[..., train_length:]

In [5]:
# Source priors: masses, distance and sky location are sampled, while the
# orientation, phase and spin parameters are fixed to zero.
priors = dict(
    mass_1=Uniform(name="mass_1", minimum=5, maximum=100, unit=r"$M_{\odot}$"),
    mass_2=Uniform(name="mass_2", minimum=5, maximum=100, unit=r"$M_{\odot}$"),
    mass_ratio=Constraint(name="mass_ratio", minimum=0.2, maximum=5.0),
    luminosity_distance=UniformSourceFrame(
        name="luminosity_distance", minimum=100, maximum=3000, unit="Mpc"
    ),
    dec=Cosine(name="dec"),
    ra=Uniform(name="ra", minimum=0, maximum=2 * np.pi, boundary="periodic"),
    theta_jn=0,
    psi=0,
    phase=0,
    a_1=0,
    a_2=0,
    tilt_1=0,
    tilt_2=0,
    phi_12=0,
    phi_jl=0,
)


def _add_mass_ratio(sample):
    """Return a copy of ``sample`` with ``mass_ratio = mass_2 / mass_1``.

    A plain ``PriorDict`` only evaluates ``Constraint`` priors on keys its
    conversion function produces, and the default conversion returns the
    sample unchanged. Without this, the ``mass_ratio`` constraint above
    would never be applied. The [0.2, 5] bounds are symmetric, so the
    ratio's orientation does not matter.
    """
    converted = sample.copy()
    converted["mass_ratio"] = sample["mass_2"] / sample["mass_1"]
    return converted


# NOTE: a waveform cache generated before the constraint was enforced
# still holds unconstrained samples; delete it to regenerate.
prior_dict = PriorDict(priors, conversion_function=_add_mass_ratio)

In [6]:
# Settings forwarded to the LAL source model for every simulated waveform
waveform_arguments = {
    "waveform_approximant": "IMRPhenomPv2",
    "reference_frequency": REFERENCE_FREQUENCY,
    "minimum_frequency": MINIMUM_FREQUENCY,
}
waveform_generator = WaveformGenerator(
    duration=WAVEFORM_DURATION,
    sampling_frequency=SAMPLE_RATE,
    frequency_domain_source_model=lal_binary_black_hole,
    parameter_conversion=convert_to_lal_binary_black_hole_parameters,
    waveform_arguments=waveform_arguments,
)


def generate_waveform(i):
    """Simulate the time-domain polarizations for parameter set ``i``.

    Reads row ``i`` of the module-level ``params`` dict, which is assigned
    in the caching cell, so that cell must define ``params`` before this
    is called.

    Returns:
        Array with one row per polarization, ordered alphabetically by
        name (the cache cell stores row 0 as "cross" and row 1 as "plus"),
        rolled by half the waveform duration.
    """
    sample = {name: values[i] for name, values in params.items()}
    strain = waveform_generator.time_domain_strain(sample)
    stacked = np.stack([strain[name] for name in sorted(strain)])

    # roll by half the duration so that the coalescence time ends up at
    # the middle sample (assumes the generator places it at the start of
    # the array, with wrap-around)
    half_length = int(WAVEFORM_DURATION / 2 * SAMPLE_RATE)
    return np.roll(stacked, half_length, axis=-1)

09:51 bilby INFO    : Waveform generator initiated with
  frequency_domain_source_model: bilby.gw.source.lal_binary_black_hole
  time_domain_source_model: None
  parameter_conversion: bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters


Generating these waveforms is time-consuming, so we cache them to a file and skip generation on repeated runs. If the cache file doesn't exist yet, we generate the waveforms with a pool of threads to cut down the wait.

In [7]:
# Load simulated waveforms from the cache file if present; otherwise
# simulate them with a thread pool and write the cache for next time.
WAVEFORMS_FILE = DATA_DIR / "waveforms.h5"
if WAVEFORMS_FILE.exists():
    print("Using local cache file")
    with h5py.File(WAVEFORMS_FILE, "r") as f:
        # stack to (num_waveforms, 2, num_samples) with cross first,
        # matching the alphabetical order used by generate_waveform
        waveforms = np.stack([f["cross"][:], f["plus"][:]], axis=1)
        params = {k: v[:] for k, v in f["params"].items()}
else:
    waveforms = np.zeros(
        (NUM_WAVEFORMS, 2, int(SAMPLE_RATE * WAVEFORM_DURATION))
    )
    params = prior_dict.sample(NUM_WAVEFORMS)
    with ThreadPoolExecutor(4) as pool:
        it = pool.map(generate_waveform, range(NUM_WAVEFORMS))
        it = track(it, "Generating waveforms", total=NUM_WAVEFORMS)
        for i, polarizations in enumerate(it):
            waveforms[i] = polarizations

    # create the cache directory if needed; otherwise the h5py write
    # below fails only after all of the expensive generation is done
    WAVEFORMS_FILE.parent.mkdir(parents=True, exist_ok=True)
    with h5py.File(WAVEFORMS_FILE, "w") as f:
        f["cross"] = waveforms[:, 0]
        f["plus"] = waveforms[:, 1]
        params_group = f.create_group("params")
        for p, values in params.items():
            params_group[p] = values

Using local cache file


In [8]:
# Time axis built from the waveform length itself, so it always has exactly
# one entry per sample. A float-step np.arange can gain or lose a sample
# for sample rates whose reciprocal is not exact. Zero is the middle
# sample, where generate_waveform places the coalescence.
t = np.arange(waveforms.shape[-1]) / SAMPLE_RATE - WAVEFORM_DURATION / 2
p = figure(
    width=750,
    height=300,
    title="Polarizations of simulated waveform #1",
    x_axis_label="Time from coalescence [s]",
    y_axis_label="Gravitational wave strain [unitless]",
    tools="",
)
for i, name in enumerate(["cross", "plus"]):
    p.line(
        t,
        waveforms[1, i],
        line_color=palette[i],
        line_alpha=0.8,
        line_width=1.5,
        legend_label=name,
    )
p.legend.click_policy = "hide"
show(p)

In [9]:
# Hold out the last VALID_FRAC of the waveforms (and their source
# parameters) for validation; the rest are used for training.
num_train = int((1 - VALID_FRAC) * NUM_WAVEFORMS)
train_waveforms = waveforms[:num_train]
valid_waveforms = waveforms[num_train:]
train_params = {}
valid_params = {}
for name, values in params.items():
    train_params[name] = values[:num_train]
    valid_params[name] = values[num_train:]

In [10]:
kernel_size = int(KERNEL_LENGTH * SAMPLE_RATE)
# ceil(num_train / BATCH_SIZE): enough batches per epoch for roughly one
# pass over the number of training waveforms
batches_per_epoch = (num_train - 1) // BATCH_SIZE + 1

train_loader = InMemoryDataset(
    train_background,
    kernel_size=kernel_size,
    batch_size=BATCH_SIZE,
    coincident=False,
    shuffle=True,
    batches_per_epoch=batches_per_epoch,
)

# Set up the stride of the validation dataloader so that, by injecting
# one waveform into each window, we go through all of our validation
# waveforms. Count waveforms and samples from the arrays themselves
# rather than re-deriving them from VALID_FRAC, which can disagree by one
# because of int truncation (and then index past the end of the
# validation injector's waveforms).
num_valid_waveforms = len(valid_waveforms)
# Largest stride that still gives at least one window per waveform,
# assuming the loader yields (length - kernel_size) // stride + 1 windows
# (TODO confirm against InMemoryDataset). A stride of
# length / num_waveforms ignores kernel_size and can yield fewer windows
# than waveforms.
valid_stride = max(
    (valid_background.shape[-1] - kernel_size)
    // max(num_valid_waveforms - 1, 1),
    1,
)
valid_batch_size = 4 * BATCH_SIZE

valid_loader = InMemoryDataset(
    valid_background,
    kernel_size=kernel_size,
    stride=valid_stride,
    batch_size=valid_batch_size,
    coincident=True,
    shuffle=False,
)

In [11]:
# Training injector: picks random training waveforms and draws dec, psi,
# phi and SNR from the samplers below.
train_samplers = dict(
    dec=CosineSampler(),
    psi=UniformSampler(0, np.pi),
    phi=UniformSampler(-np.pi, np.pi),
    snr=LogNormalSampler(MEAN_SNR, STD_SNR, MIN_SNR),
)
injector = RandomWaveformInjection(
    sample_rate=SAMPLE_RATE,
    ifos=["H1", "L1"],
    highpass=HIGHPASS,
    prob=1.0,
    trigger_offset=-0.6,
    cross=train_waveforms[:, 0],
    plus=train_waveforms[:, 1],
    **train_samplers,
)
# fit on the training background (presumably estimating the background
# spectrum later read as injector.background — TODO confirm)
injector.fit(*train_background, sample_rate=SAMPLE_RATE, fftlength=2)
injector = injector.to("cuda")

# Validation injector: per-waveform dec/psi/phi taken from the validation
# prior samples (phi from "ra"), with snr=None.
# NOTE(review): psi is fixed to 0 in the prior, so every validation
# injection uses psi=0, unlike the uniform psi used for training.
valid_injector = RandomWaveformInjection(
    sample_rate=SAMPLE_RATE,
    ifos=["H1", "L1"],
    dec=valid_params["dec"],
    psi=valid_params["psi"],
    phi=valid_params["ra"],
    snr=None,
    highpass=HIGHPASS,
    cross=valid_waveforms[:, 0],
    plus=valid_waveforms[:, 1],
)
valid_injector = valid_injector.to("cuda")

In [12]:
# Draw a single random injection as a quick sanity check of the injector
waveform, sampled_params = injector.sample(1, "cuda")
# unpack the four sampled parameters of that one injection
dec, psi, phi, snr = tuple(sampled_params[0])

In [13]:
# Whitening transform fit on the training background. Its output is what
# the autoencoder sees and what predictions are correlated against.
preprocessor = Whitening(
    num_channels=2,
    sample_rate=SAMPLE_RATE,
    fduration=1,
)
preprocessor.fit(
    KERNEL_LENGTH,
    *train_background,
    fftlength=2,
    highpass=HIGHPASS,
)
preprocessor = preprocessor.to("cuda")

In [17]:
# initialize a neural network architecture
# note that kernel_size here refers to the
# size of the time dimension of the weights
# of the convolutional layers in the network
nn = Autoencoder(
    2,
    int(KERNEL_LENGTH * SAMPLE_RATE // 2),
    latent_dim=128,
    layers=[8, 32, 128, 256, 1024],
    skip_connections=[1, 3],
    kernel_size=7,
    independent=True,
)
nn.to("cuda")

# combine preprocessing and nn into a single model and report the number
# of trainable parameters for reference; parameters with
# requires_grad=False (e.g. any fitted whitening state) are excluded so
# the count matches what the optimizer actually updates
model = torch.nn.Sequential(preprocessor, nn)
num_params = sum(x.numel() for x in model.parameters() if x.requires_grad)
print(f"Number of trainable parameters: {num_params}")

# build a loss function to make predictions
# independent of phase factor
light_travel_time = 0.005
max_shift = int(SAMPLE_RATE * light_travel_time)
loss_fn = ShiftedPearsonLoss(max_shift=max_shift, alpha=10)
loss_fn.to("cuda")

# initialize an optimizer and some validation
# metrics to keep track of performance
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# NOTE(review): OneCycleLR resets the optimizer's learning rate according
# to its own schedule, peaking at max_lr, so LEARNING_RATE above is
# effectively unused. Decide which value is intended.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-2,
    epochs=MAX_EPOCHS,
    steps_per_epoch=batches_per_epoch,
    anneal_strategy="cos",
)
thresholds = [0.01, 0.1, 1]
auroc = MultiThresholdAUROC(thresholds).to("cuda")

# combine all these things into a Run object
# that can keep track of the progression of training
run = Run(model, BASE_DIR / RUN_NAME, auroc)

Number of parameters: 8411522


In [18]:
# Each validation waveform is injected at NUM_VALID_VIEWS different
# offsets within the kernel, and the predictions and losses are averaged.
NUM_VALID_VIEWS = 5
view_offsets = [
    int(0.9 * i * SAMPLE_RATE / 4) for i in range(NUM_VALID_VIEWS)
]
# the first view starts this many samples before the waveform's middle
view_buffer = int(0.55 * SAMPLE_RATE)

for epoch in run.run(MAX_EPOCHS):
    model.train()
    for X in epoch.track(train_loader, "train_loss", "Training"):
        optimizer.zero_grad(set_to_none=True)
        X = X.to("cuda")

        # do the forward step in parts so that we compute
        # the correlation with the whitened input
        X, idx, _ = injector(X)
        X = preprocessor(X)
        y_hat = nn(X)

        # compute loss and perform gradient update
        loss = loss_fn(y_hat, X)
        loss.backward()
        optimizer.step()
        scheduler.step()

        epoch.update(loss.item())

    model.eval()
    with torch.no_grad():
        waveform_idx = 0
        signal_preds, background_preds = [], []
        for X in epoch.track(valid_loader, "valid_loss", "Validating"):
            X = X.to("cuda")
            X_aug = preprocessor(X)
            y_hat = nn(X_aug)

            # score the raw (non-injected) windows to collect background
            # predictions; no validation loss is recorded for these, the
            # loss is only tracked on injected windows below
            # NOTE(review): this flips the last (time) axis, while the
            # training loss above uses the unflipped output and the
            # analysis cells below flip dims=(1,); confirm the intent
            y_hat = torch.flip(y_hat, dims=(-1,))
            corr = loss_fn.compute_corr(y_hat, X_aug)
            background_preds.append(loss_fn.score_max(corr))

            # deterministically take the next validation waveforms; the
            # loader can yield more windows than there are waveforms, so
            # skip once they are exhausted rather than sampling an empty
            # batch (which would make the loss NaN)
            waveform_stop = min(waveform_idx + len(X), num_valid_waveforms)
            if waveform_stop <= waveform_idx:
                continue
            idx = torch.arange(waveform_idx, waveform_stop)
            # batch-local name so the full `waveforms` dataset array is
            # not overwritten by a GPU tensor
            batch_waveforms, _ = valid_injector.sample(idx, device="cuda")
            X = X[: len(batch_waveforms)]

            # do inference with waveforms near the
            # start, middle, and end of kernels
            middle = batch_waveforms.size(-1) // 2 - view_buffer
            waveform_loss, waveform_pred = 0, 0
            for offset in view_offsets:
                # inject a different view of the waveform each iteration
                start = middle - offset
                X_aug = X + batch_waveforms[:, :, start : start + kernel_size]

                # predict on the injected data
                X_aug = preprocessor(X_aug)
                y_hat = nn(X_aug)

                # now compute loss and predictions
                y_hat = torch.flip(y_hat, dims=(-1,))
                corr = loss_fn.compute_corr(y_hat, X_aug)
                waveform_pred += loss_fn.score_max(corr)
                waveform_loss += -loss_fn.get_weighted_corr(corr).mean()

            waveform_loss /= NUM_VALID_VIEWS
            waveform_pred /= NUM_VALID_VIEWS
            signal_preds.append(waveform_pred)

            epoch.update(waveform_loss.item())
            waveform_idx += len(X)

        # compile all our predictions into arrays
        # and compute various AUCs using them
        signal_preds = torch.cat(signal_preds)
        background_preds = torch.cat(background_preds)
        aurocs = auroc(signal_preds, background_preds).cpu().numpy()
        epoch.metrics["aurocs"] = {k: v for k, v in zip(thresholds, aurocs)}

run.save()

Output()

KeyboardInterrupt: 

# MAIN PART OF NOTEBOOK ENDS HERE
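> **The cells below are exploratory plotting and diagnostics that still need to be cleaned up, and not all of them are guaranteed to work. Several rely on variables set by earlier cells in this section, so run them in order. Such is the downside of working in notebooks...**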

In [None]:
# Load the best weights and plot the model's predictions for an
# nrows x ncols grid of validation windows with injected waveforms.
model.load_state_dict(torch.load(run.best_weights_path))
# 1 s time axis centred on zero for the whitened windows
# (assumes those are SAMPLE_RATE samples long — TODO confirm)
t = np.arange(0, 1, 1 / SAMPLE_RATE) - 0.5
nrows = 4
ncols = 4

# the first nrows * ncols background windows from the validation loader
X = next(iter(valid_loader))[: nrows * ncols].to("cuda")
# the first nrows * ncols validation waveforms, in order
idx = torch.arange(nrows * ncols)
waveforms, _ = valid_injector.sample(idx, device="cuda")

# take a kernel_size-long slice centred on each waveform's middle sample
# and inject it into the background windows
start = waveforms.size(-1) // 2 - kernel_size // 2
stop = start + kernel_size
waveforms = waveforms[:, :, start:stop]
X += waveforms
with torch.no_grad():
    # model._modules["0"] is the whitening preprocessor and
    # model._modules["1"] the autoencoder (see torch.nn.Sequential above)
    X = model._modules["0"](X)
    y = model._modules["1"](X)
    # NOTE(review): this flips dim 1 (the channel axis), while the
    # validation loop during training flips the last (time) axis;
    # confirm which alignment is intended
    y = torch.flip(y, dims=(1,))
    # normalize each channel's prediction by its maximum, then apply
    # the loss function's window
    y /= y.max(axis=-1, keepdims=True).values
    y *= loss_fn.window

    # Inline re-implementation of the shifted Pearson correlation: pad
    # the predictions by max_shift on both sides and use unfold to build
    # one shifted copy of the prediction per lag
    pad = (loss_fn.max_shift, loss_fn.max_shift)
    num_windows = 2 * loss_fn.max_shift + 1
    batch_size = X.size(0)

    predictions = torch.nn.functional.pad(y, pad)
    predictions = predictions.unsqueeze(2)
    predictions = torch.nn.functional.unfold(predictions, (1, num_windows))
    predictions = predictions.reshape(
        batch_size, loss_fn.num_ifos, num_windows, -1
    )
    # -> (num_windows, batch_size, num_ifos, time)
    predictions = predictions.transpose(0, 2).transpose(1, 2)

    # Pearson correlation over time for every lag; corr ends up with
    # shape (num_windows, batch_size, num_ifos)
    targets = X
    predictions = predictions - predictions.mean(-1, keepdims=True)
    targets = targets - targets.mean(-1, keepdims=True)
    corr = (predictions * targets).sum(axis=-1)
    norm = (targets**2).sum(-1) * (predictions**2).sum(-1)
    corr /= norm**0.5

    # weight every lag by 1 - 2 * exp(-(lag - best_lag)^2 / 10): -1 at
    # each channel's best lag, approaching 1 far away from it
    _, indices = corr.max(axis=0)
    mask = torch.arange(num_windows).view(-1, 1, 1).to(indices.device)
    mask = 1 - 2 * torch.exp(-((mask - indices) ** 2) / 10)
    scores = corr * mask

# move everything to numpy for plotting
X = X.cpu().numpy()
y = y.cpu().numpy()
waveforms = waveforms.cpu().numpy()
# drop kernel_size // 4 samples from each end of the injected waveforms
# (presumably to match the crop applied by whitening — TODO confirm)
waveforms = waveforms[:, :, kernel_size // 4 : -kernel_size // 4]
corr = corr.cpu().numpy()
scores = scores.cpu().numpy()
# NOTE(review): scores and det_stats are not used by the plots below
det_stats = ((corr**2).sum(-1) ** 0.5).max(0)

# source parameters of the plotted waveforms, used in the panel titles
mass1s = valid_params["mass_1"][: nrows * ncols]
mass2s = valid_params["mass_2"][: nrows * ncols]
dists = valid_params["luminosity_distance"][: nrows * ncols]

rows = []
for i in range(nrows):
    cols = []
    for j in range(ncols):
        k = i * ncols + j
        # NOTE(review): `pred` is unpacked but never used
        pred, mass1, mass2, dist = y[k], mass1s[k], mass2s[k], dists[k]
        # panel title is mass_1 / mass_2 / luminosity distance
        p = figure(
            title=f"{mass1:0.1f}/{mass2:0.1f}/{dist:0.1f}",
            x_range=(-0.5, 0.5),
            y_range=(waveforms.min(), waveforms.max()),
        )
        # secondary right-hand axis for the normalized prediction
        p.extra_y_ranges = {"pred": Range1d(y.min(), y.max())}
        p.add_layout(LinearAxis(y_range_name="pred"), "right")
        for ifo in range(2):
            # injected waveform (solid), left axis
            p.line(
                t,
                waveforms[k, ifo],
                line_color=palette[ifo],
                line_width=0.5,
                line_alpha=0.6,
            )
            # prediction (dotted), right axis
            p.line(
                t,
                y[k, ifo],
                line_color=palette[ifo + 2],
                line_width=0.5,
                line_alpha=0.6,
                line_dash="dotted",
                y_range_name="pred",
            )
        cols.append(p)
    rows.append(cols)
grid = gridplot(rows, width=225, height=125)
show(grid)

In [None]:
# Correlation as a function of lag index for both detectors of grid
# example i (uses `corr` from the previous cell)
i = 3
p = figure(height=300, width=700)
x = np.arange(len(corr))
for j, color in zip(range(2), palette):
    p.line(x, corr[:, i, j], line_color=color)
show(p)

In [None]:
# Zoom in on one injected example (k) from the reconstruction grid and
# plot its injected strain against the prediction, shifting each
# detector's injected strain by the lag of its peak correlation. Relies
# on waveforms, y, corr, t and mass1s/mass2s/dists from that grid cell.
k = 1
# plot `width` samples around the centre of the 1 s segment
width = 256
start = SAMPLE_RATE // 2 - width // 2
stop = start + width
signal = waveforms
slc = slice(start, stop)
# lag (in samples) for each correlation index, zero at the central index
shifts = np.arange(len(corr)) - len(corr) // 2
# best lag index for each detector of example k
max_idx = np.argmax(corr[:, k], axis=0)
shift = shifts[max_idx]

# vertical offsets used to stack the two detectors' traces on one plot
sig_diff = 2 * (signal[k].max() - signal[k].min())
y_diff = 2 * (y[k].max() - y[k].min())

mass1, mass2, dist = mass1s[k], mass2s[k], dists[k]
label = f"m1={mass1:0.1f}/m2={mass2:0.1f}/dist={dist:0.1f}Mpc"
p = figure(
    title=f"Injected waveform and prediction, {label}",
    x_range=(t[slc].min(), t[slc].max()),
    y_range=(signal[k, :, slc].min() - sig_diff, signal[k, :, slc].max()),
    height=300,
    width=700,
    tools="",
    x_axis_label="Time from coalescence [s]",
    y_axis_label="Strain",
)
# secondary right-hand axis for the prediction
# NOTE(review): this range uses y over all examples, while y_diff is
# computed from example k only
p.extra_y_ranges = {"pred": Range1d(y.min() - y_diff, y.max())}
p.add_layout(
    LinearAxis(
        y_range_name="pred", axis_label="Predicted Strain [normalized]"
    ),
    "right",
)
for ifo in range(2):
    # injected strain, shifted by this detector's best lag and offset
    # downward by ifo * sig_diff
    p.line(
        t[slc],
        signal[k, ifo, start - shift[ifo] : stop - shift[ifo]]
        - ifo * sig_diff,
        line_color=palette[ifo],
        line_width=1.5,
        line_alpha=0.6,
    )
    # prediction (dotted), offset downward by ifo * y_diff
    p.line(
        t[slc],
        y[k, ifo, slc] - ifo * y_diff,
        line_color=palette[ifo + 2],
        line_width=1.5,
        line_alpha=0.6,
        line_dash="dotted",
        y_range_name="pred",
    )
show(p)

In [None]:
# Same reconstruction grid as the injected one above, but on pure
# validation background (no waveforms are injected in this cell).
model.load_state_dict(torch.load(run.best_weights_path))
# 1 s time axis centred on zero for the whitened windows
# (assumes those are SAMPLE_RATE samples long — TODO confirm)
t = np.arange(0, 1, 1 / SAMPLE_RATE) - 0.5
nrows = 4
ncols = 4

# the first nrows * ncols background windows from the validation loader
X = next(iter(valid_loader))[: nrows * ncols].to("cuda")
with torch.no_grad():
    # model._modules["0"] is the whitening preprocessor and
    # model._modules["1"] the autoencoder
    X = model._modules["0"](X)
    y = model._modules["1"](X)
    # NOTE(review): flips dim 1 (channels), unlike the time-axis flip in
    # the training-time validation loop; confirm which is intended
    y = torch.flip(y, dims=(1,))
    # normalize each channel by its maximum, then apply the loss window
    y /= y.max(axis=-1, keepdims=True).values
    y *= loss_fn.window

    # Inline shifted Pearson correlation: pad by max_shift on both sides
    # and unfold into one shifted copy of the prediction per lag
    pad = (loss_fn.max_shift, loss_fn.max_shift)
    num_windows = 2 * loss_fn.max_shift + 1
    batch_size = X.size(0)

    predictions = torch.nn.functional.pad(y, pad)
    predictions = predictions.unsqueeze(2)
    predictions = torch.nn.functional.unfold(predictions, (1, num_windows))
    predictions = predictions.reshape(
        batch_size, loss_fn.num_ifos, num_windows, -1
    )
    # -> (num_windows, batch_size, num_ifos, time)
    predictions = predictions.transpose(0, 2).transpose(1, 2)

    # Pearson correlation over time for every lag; corr ends up with
    # shape (num_windows, batch_size, num_ifos)
    targets = X
    predictions = predictions - predictions.mean(-1, keepdims=True)
    targets = targets - targets.mean(-1, keepdims=True)
    corr = (predictions * targets).sum(axis=-1)
    norm = (targets**2).sum(-1) * (predictions**2).sum(-1)
    corr /= norm**0.5

    # weight every lag by 1 - 2 * exp(-(lag - best_lag)^2 / 10)
    _, indices = corr.max(axis=0)
    mask = torch.arange(num_windows).view(-1, 1, 1).to(indices.device)
    mask = 1 - 2 * torch.exp(-((mask - indices) ** 2) / 10)
    scores = corr * mask

# move everything to numpy for plotting
X = X.cpu().numpy()
y = y.cpu().numpy()
corr = corr.cpu().numpy()
scores = scores.cpu().numpy()
# NOTE(review): scores and det_stats are not used by the plots below
det_stats = ((corr**2).sum(-1) ** 0.5).max(0)

rows = []
for i in range(nrows):
    cols = []
    for j in range(ncols):
        k = i * ncols + j
        p = figure(
            x_range=(-0.5, 0.5),
            y_range=(X.min(), X.max()),
        )
        # secondary right-hand axis for the normalized prediction
        p.extra_y_ranges = {"pred": Range1d(y.min(), y.max())}
        p.add_layout(LinearAxis(y_range_name="pred"), "right")
        for ifo in range(2):
            # whitened background (solid), left axis
            p.line(
                t,
                X[k, ifo],
                line_color=palette[ifo],
                line_width=0.5,
                line_alpha=0.6,
            )
            # prediction (dotted), right axis
            p.line(
                t,
                y[k, ifo],
                line_color=palette[ifo + 2],
                line_width=0.5,
                line_alpha=0.6,
                line_dash="dotted",
                y_range_name="pred",
            )
        cols.append(p)
    rows.append(cols)
grid = gridplot(rows, width=225, height=125)
show(grid)

In [None]:
# Evaluate a saved checkpoint on the validation set. This records the
# detection statistic for background windows and for injected windows,
# plus the network SNR of each injected waveform.
state_dict = torch.load(run.checkpoint_dir / "epoch_010.pt")
model.load_state_dict(state_dict)

# each waveform is injected at num_views offsets and the statistics are
# averaged; the first view starts view_buffer samples before its middle
num_views = 5
view_offsets = [int(0.9 * i * SAMPLE_RATE / 4) for i in range(num_views)]
view_buffer = int(0.55 * SAMPLE_RATE)


def _detection_statistic(y_hat, X_aug):
    """Root-sum-square over detectors of each window's peak correlation."""
    corr = loss_fn.correlate(y_hat, X_aug)
    values, _ = corr.max(axis=0)
    return (values**2).sum(-1) ** 0.5


# NOTE(review): no flip is applied to y_hat here, unlike the validation
# loop during training; confirm which convention loss_fn.correlate expects
model.eval()
with torch.no_grad():
    waveform_idx = 0
    signal_preds, background_preds, snrs = [], [], []
    for X in valid_loader:
        X = X.to("cuda")
        X_aug = preprocessor(X)
        background_preds.append(_detection_statistic(nn(X_aug), X_aug))

        # deterministically take the next validation waveforms; once they
        # are exhausted, remaining windows only contribute background
        waveform_stop = min(waveform_idx + len(X), num_valid_waveforms)
        if waveform_stop <= waveform_idx:
            continue
        idx = torch.arange(waveform_idx, waveform_stop)
        # batch-local name so the full `waveforms` array is not clobbered
        batch_waveforms, _ = valid_injector.sample(idx, device="cuda")
        snrs.append(
            compute_network_snr(
                batch_waveforms, injector.background, SAMPLE_RATE, injector.mask
            )
        )
        X = X[: len(batch_waveforms)]

        # do inference with waveforms near the
        # start, middle, and end of kernels
        middle = batch_waveforms.size(-1) // 2 - view_buffer
        waveform_pred = 0
        for offset in view_offsets:
            start = middle - offset
            X_aug = X + batch_waveforms[:, :, start : start + kernel_size]
            X_aug = preprocessor(X_aug)
            waveform_pred += _detection_statistic(nn(X_aug), X_aug)

        signal_preds.append(waveform_pred / num_views)
        waveform_idx += len(X)

    signal_preds = torch.cat(signal_preds)
    background_preds = torch.cat(background_preds)
    snrs = torch.cat(snrs)
signal_preds = signal_preds.cpu().numpy()
background_preds = background_preds.cpu().numpy()
snrs = snrs.cpu().numpy()

In [None]:
# Histogram the background detection statistics and turn the counts into
# a survival function: for each bin, the number of background windows in
# that bin or any bin above it.
counts, edges = np.histogram(background_preds, 32)
survival = counts[::-1].cumsum()[::-1]
centers = 0.5 * (edges[:-1] + edges[1:])
bar_width = 0.95 * (edges[1] - edges[0])

p = figure(
    height=300,
    width=700,
    x_axis_label="Detection Statistic",
    y_axis_label="Background survival function",
    y_axis_type="log",
    y_range=(0.1, 2 * survival.max()),
)
# secondary right-hand axis for the SNR of each injected waveform
# NOTE(review): depending on the bokeh version, an extra range may be
# mapped linearly even though the axis drawn for it is logarithmic
p.extra_y_ranges = {"SNR": Range1d(0.5 * snrs.min(), 2 * snrs.max())}
p.add_layout(LogAxis(y_range_name="SNR", axis_label="SNR"), "right")

# background survival function as bars
p.vbar(
    x=centers,
    top=survival,
    width=bar_width,
    bottom=0.1,
    fill_color=palette[0],
    fill_alpha=0.5,
    line_color="#000000",
    line_width=0.5,
)
# each injected waveform: its detection statistic vs. its SNR
p.circle(
    signal_preds,
    snrs,
    size=6,
    fill_color=palette[1],
    line_color=palette[1],
    fill_alpha=0.5,
    line_width=0.5,
    y_range_name="SNR",
)
show(p)

In [None]:
def do_inference_at_idx(
    idx: Optional[int] = None,
    inference_sampling_rate: float = 128,
    inference_duration: float = 8,
):
    if idx is None:
        N = len(background_preds) - int(inference_duration // valid_stride)
        idx = np.random.randint(N)
    inference_stride = int(SAMPLE_RATE / inference_sampling_rate)
    start = idx * valid_stride - int(inference_duration * SAMPLE_RATE // 2)
    stop = start + int(inference_duration * SAMPLE_RATE)
    bckgrd = torch.Tensor(valid_background[:, start:stop])
    num_steps = (
        inference_duration * SAMPLE_RATE - kernel_size
    ) // inference_stride + 1
    num_batches = (num_steps - 1) // (4 * BATCH_SIZE) + 1

    corrs, shifts = [], []
    with torch.no_grad():
        for i in range(num_batches):
            start = i * BATCH_SIZE * 4 * inference_stride
            stop = (i + 1) * BATCH_SIZE * 4 * inference_stride
            stop = min(stop, bckgrd.shape[-1] - kernel_size)
            ix = torch.arange(start, stop, inference_stride)
            x = slice_kernels(bckgrd, ix, kernel_size)
            x = torch.Tensor(x).to("cuda")
            x = preprocessor(x)
            y_hat = nn(x)
            corr = loss_fn.correlate(y_hat, x)
            values, indices = corr.max(axis=0)
            values = (values**2).sum(-1) ** 0.5
            corrs.append(values.cpu().numpy())
            shifts.append(indices.cpu().numpy())
    shifts = np.stack(shifts)[0]
    corrs = np.concatenate(corrs)
    return shifts, corrs, y_hat.cpu().numpy(), inference_stride


def plot_corrs(corrs, shifts, inference_sampling_rate=128):
    """Plot per-detector lag indices and the detection statistic vs. time.

    The lag indices (``shifts[:, 0]`` and ``shifts[:, 1]``) use the left
    axis. The detection statistic ``corrs`` uses a secondary right-hand
    axis fixed to [0, 0.5]. Time is the kernel index divided by
    ``inference_sampling_rate``.
    """
    times = np.arange(len(corrs)) / inference_sampling_rate
    p = figure(height=300, width=700)
    for ifo_idx, color in zip(range(2), palette):
        p.line(times, shifts[:, ifo_idx], line_color=color)

    p.extra_y_ranges = {"corr": Range1d(0, 0.5)}
    p.add_layout(LinearAxis(y_range_name="corr"), "right")
    p.line(times, corrs, y_range_name="corr")
    show(p)

In [None]:
# analyse a randomly chosen stretch of validation background
shifts, corrs, y_hat, inference_stride = do_inference_at_idx(idx=None)
plot_corrs(corrs=corrs, shifts=shifts)

In [None]:
# Overlay the predictions of 20 consecutive kernels from the last
# do_inference_at_idx call. Each is drawn at its own start time so that
# samples shared by overlapping kernels line up.
# The first kernel to plot sits a fixed fraction (0.6 / 6.2, presumably
# read off the correlation plot above) of the way through the kernels.
start = int(0.6 / 6.2 * len(y_hat))
# kernel k starts k * inference_stride samples into the analysed
# background; use that as the time origin rather than indexing an
# unrelated global time axis with a kernel index
t0 = start * inference_stride / SAMPLE_RATE
ti = t0 + np.arange(y_hat.shape[-1]) / SAMPLE_RATE
p = figure(
    height=300,
    width=700,
    tools="box_zoom,reset",
    x_axis_label="Time [s]",
    y_axis_label="Prediction",
)
for i in range(20):
    for j in range(2):
        # kernel start + i begins i * inference_stride samples later, so
        # drop that many samples from the front of the time axis and the
        # same number from the end of the prediction
        offset = (i * inference_stride) or None
        stop = -offset if offset is not None else None
        p.line(
            ti[offset:],
            y_hat[start + i, j, :stop],
            line_color=palette[j],
            line_alpha=0.4,
        )
show(p)

In [None]:
# analyse the stretch around the loudest background window
idx = np.argmax(background_preds)
shifts, corrs, y_hat, inference_stride = do_inference_at_idx(idx=idx)
plot_corrs(corrs=corrs, shifts=shifts)

In [None]:
# Overlay the predictions of 40 consecutive kernels around the loudest
# background window. Each kernel is placed at its own start time and
# shifted by its best-correlation lag.
start = int(3.9 / 6.2 * len(y_hat))
# kernel k starts k * inference_stride samples into the analysed
# background; use that as the time origin
t0 = start * inference_stride / SAMPLE_RATE
ti = t0 + np.arange(y_hat.shape[-1]) / SAMPLE_RATE
p = figure(
    height=300,
    width=700,
    tools="box_zoom,reset",
    x_axis_label="Time [s]",
    y_axis_label="Prediction",
)
for i in range(40):
    for j in range(2):
        offset = (i * inference_stride) or None
        stop = -offset if offset is not None else None
        # convert the argmax lag index to seconds. Index max_shift is zero
        # lag (as in the inline correlation above; assumes
        # loss_fn.correlate orders lags the same way). Integer division by
        # SAMPLE_RATE would always give 0 here.
        shift = (shifts[start + i, j] - loss_fn.max_shift) / SAMPLE_RATE
        p.line(
            ti[offset:] + shift,
            y_hat[start + i, j, :stop],
            line_color=palette[j],
            line_alpha=0.4,
        )
show(p)

In [None]:
# analyse the stretch around the second-loudest background window
idx = np.argsort(background_preds)[::-1][1]
shifts, corrs, y_hat, inference_stride = do_inference_at_idx(idx=idx)
plot_corrs(corrs=corrs, shifts=shifts)

In [None]:
# Overlay the predictions of 20 consecutive kernels around the
# second-loudest background window. Each kernel is placed at its own
# start time and shifted by its best-correlation lag.
start = int(3.2 / 6.2 * len(y_hat))
# kernel k starts k * inference_stride samples into the analysed
# background; use that as the time origin
t0 = start * inference_stride / SAMPLE_RATE
ti = t0 + np.arange(y_hat.shape[-1]) / SAMPLE_RATE
p = figure(
    height=300,
    width=700,
    tools="box_zoom,reset",
    x_axis_label="Time [s]",
    y_axis_label="Prediction",
)
for i in range(20):
    for j in range(2):
        offset = (i * inference_stride) or None
        stop = -offset if offset is not None else None
        # argmax lag index -> seconds, with index max_shift as zero lag
        # (assumes loss_fn.correlate orders lags like the inline
        # correlation above); integer division would always give 0
        shift = (shifts[start + i, j] - loss_fn.max_shift) / SAMPLE_RATE
        p.line(
            ti[offset:] + shift,
            y_hat[start + i, j, :stop],
            line_color=palette[j],
            line_alpha=0.4,
        )
show(p)