In [8]:
import pickle
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Callable, List, Literal, Optional

import h5py
import numpy as np
import torch
from bilby.core.prior import Constraint, Cosine, PriorDict, Uniform
from bilby.gw.conversion import convert_to_lal_binary_black_hole_parameters
from bilby.gw.prior import UniformSourceFrame
from bilby.gw.source import lal_binary_black_hole
from bilby.gw.waveform_generator import WaveformGenerator
from bokeh.io import output_notebook, show
from bokeh.palettes import Dark2_8 as palette
from bokeh.plotting import figure
from gwpy.signal.filter_design import fir_from_transfer
from gwpy.timeseries import TimeSeries
from rich.progress import track

from ml4gw.dataloading import InMemoryDataset
from ml4gw.distributions import Cosine as CosineSampler
from ml4gw.distributions import LogNormal as LogNormalSampler
from ml4gw.distributions import Uniform as UniformSampler
from ml4gw.transforms import RandomWaveformInjection

output_notebook()

In [9]:
# Paths and what not
# Root directory under which each run's training artifacts are written
BASE_DIR = Path.home() / "bbhnet" / "results"
# Directory holding cached inputs such as the generated waveform file
DATA_DIR = Path.home() / "bbhnet" / "data"
# Subdirectory of BASE_DIR that this run's outputs are written to
RUN_NAME = "notebook-run"

# Data parameters
# Start time passed to `TimeSeries.fetch_open_data`
# (presumably a GPS timestamp -- confirm against the gwpy docs)
START = 1262607622
# Length of background to fetch; added to START to obtain the end time
DURATION = 12288
# Rate the fetched strain is resampled to; SAMPLE_RATE * DURATION
# is used as the total number of background samples
SAMPLE_RATE = 2048
# Length of each kernel read from the background, in the same time
# units as DURATION (KERNEL_LENGTH * SAMPLE_RATE samples)
KERNEL_LENGTH = 2
# Highpass value handed to both the injectors and the whitening transform
HIGHPASS = 20

# Injection parameters
# Duration passed to the bilby waveform generator for each waveform
WAVEFORM_DURATION = 8
# Number of parameter sets sampled from the prior, one waveform each
NUM_WAVEFORMS = 20000
# Waveform arguments forwarded to the source model
REFERENCE_FREQUENCY = 50
MINIMUM_FREQUENCY = 20
# Passed as `prob` to the training injector and used to
# work out how many waveforms a batch consumes
INJECTION_FRACTION = 0.5
# Arguments, in order, of the log-normal SNR sampler
MEAN_SNR = 15
STD_SNR = 15
MIN_SNR = 1

# Optimization parameters
# Fraction of both the background and the waveforms held out for validation
VALID_FRAC = 0.25
# Learning rate the Adam optimizer is constructed with
LEARNING_RATE = 4e-3
# Training batch size (validation uses 4x this)
BATCH_SIZE = 256
# NOTE(review): DECAY and PATIENCE are not referenced by any
# cell visible in this notebook
DECAY = 80000
PATIENCE = 50
# Number of training epochs, also used to size the LR schedule
MAX_EPOCHS = 100

In [10]:
# Fetch open strain data for each detector, bring it to the
# working sample rate, and stack the raw arrays along a
# leading detector axis (Hanford first, then Livingston)
strains = []
for detector in ("H1", "L1"):
    strain = TimeSeries.fetch_open_data(
        detector, start=START, end=START + DURATION
    )
    strains.append(strain.resample(SAMPLE_RATE).value)
background = np.stack(strains)

In [11]:
# Split along the time axis: the leading (1 - VALID_FRAC) of the
# samples are used for training, the remainder for validation
train_length = int((1 - VALID_FRAC) * SAMPLE_RATE * DURATION)
train_background = background[..., :train_length]
valid_background = background[..., train_length:]

In [12]:
# Priors over the binary black hole parameters used to simulate waveforms.
# Component masses, luminosity distance and sky position are sampled;
# orientation, phase and all spin parameters are pinned to 0.
# NOTE(review): `mass_ratio` is a `Constraint`. bilby appears to apply
# constraints only to quantities produced by the PriorDict's conversion
# function, and a plain `PriorDict` is built below without one, so this
# constraint may be silently ignored -- confirm against the bilby docs.
priors = dict(
    mass_1=Uniform(name="mass_1", minimum=5, maximum=100, unit=r"$M_{\odot}$"),
    mass_2=Uniform(name="mass_2", minimum=5, maximum=100, unit=r"$M_{\odot}$"),
    mass_ratio=Constraint(name="mass_ratio", minimum=0.2, maximum=5.0),
    luminosity_distance=UniformSourceFrame(
        name="luminosity_distance", minimum=100, maximum=3000, unit="Mpc"
    ),
    dec=Cosine(name="dec"),
    ra=Uniform(name="ra", minimum=0, maximum=2 * np.pi, boundary="periodic"),
    theta_jn=0,
    psi=0,
    phase=0,
    a_1=0,
    a_2=0,
    tilt_1=0,
    tilt_2=0,
    phi_12=0,
    phi_jl=0,
)
prior_dict = PriorDict(priors)

In [13]:
# bilby generator that turns one set of sampled parameters into
# polarization timeseries of WAVEFORM_DURATION at SAMPLE_RATE, using
# the LAL binary black hole source model and the approximant below
waveform_generator = WaveformGenerator(
    duration=WAVEFORM_DURATION,
    sampling_frequency=SAMPLE_RATE,
    frequency_domain_source_model=lal_binary_black_hole,
    parameter_conversion=convert_to_lal_binary_black_hole_parameters,
    waveform_arguments={
        "waveform_approximant": "IMRPhenomPv2",
        "reference_frequency": REFERENCE_FREQUENCY,
        "minimum_frequency": MINIMUM_FREQUENCY,
    },
)


def generate_waveform(i):
    """Simulate the polarizations for the `i`-th set of sampled parameters.

    Reads row `i` out of the module-level `params` dict and runs it
    through the module-level `waveform_generator`.

    Args:
        i: Index into each of the arrays in `params`.

    Returns:
        Array with one row per polarization, ordered by sorted
        polarization name, rolled by half the waveform duration
        along the time axis.
    """
    sample = {name: values[i] for name, values in params.items()}
    strain = waveform_generator.time_domain_strain(sample)
    stacked = np.stack([strain[name] for name in sorted(strain.keys())])

    # shift by half the duration so that the coalescence,
    # lands on the middle sample of the kernel
    half_duration = WAVEFORM_DURATION / 2
    shift = int(half_duration * SAMPLE_RATE)
    return np.roll(stacked, shift, axis=-1)

08:29 bilby INFO    : Waveform generator initiated with
  frequency_domain_source_model: bilby.gw.source.lal_binary_black_hole
  time_domain_source_model: None
  parameter_conversion: bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters


Generating these waveforms can be quite time-consuming, so we cache them in a file and skip generation on repeated runs. If the cache file doesn't exist yet, we generate the waveforms with multiple threads so that we're not waiting around too long.

In [14]:
# Load the simulated waveforms from the cache file if there is one,
# otherwise generate them and write the cache for subsequent runs.
# Either way this cell leaves behind:
#   waveforms: array of shape (num waveforms, 2, num samples), with
#       the "cross" dataset at index 0 and "plus" at index 1 of axis 1
#   params: dict mapping parameter name to an array of sampled values
WAVEFORMS_FILE = DATA_DIR / "waveforms.h5"
if WAVEFORMS_FILE.exists():
    print("Using local cache file")
    with h5py.File(WAVEFORMS_FILE, "r") as f:
        polarizations = [f[p][:][:, None] for p in ["cross", "plus"]]
        waveforms = np.concatenate(polarizations, axis=1)
        params = {k: v[:] for k, v in f["params"].items()}
else:
    # create the cache directory up front: h5py won't create missing
    # parent directories, and we'd rather fail now than after spending
    # a long time generating waveforms that can't be saved
    WAVEFORMS_FILE.parent.mkdir(parents=True, exist_ok=True)

    waveforms = np.zeros(
        (NUM_WAVEFORMS, 2, int(SAMPLE_RATE * WAVEFORM_DURATION))
    )
    params = prior_dict.sample(NUM_WAVEFORMS)

    # `generate_waveform` reads `params` from module scope, so
    # it has to be assigned before the pool starts mapping
    with ThreadPoolExecutor(4) as pool:
        it = pool.map(generate_waveform, range(NUM_WAVEFORMS))
        it = track(it, "Generating waveforms", total=NUM_WAVEFORMS)
        for i, polarizations in enumerate(it):
            waveforms[i] = polarizations

    with h5py.File(WAVEFORMS_FILE, "w") as f:
        f["cross"] = waveforms[:, 0]
        f["plus"] = waveforms[:, 1]
        params_group = f.create_group("params")
        for p, values in params.items():
            params_group[p] = values

Using local cache file


In [15]:
# Plot both polarizations of the waveform at index 1 against
# time measured relative to the middle of the kernel
t = np.arange(0, WAVEFORM_DURATION, 1 / SAMPLE_RATE) - WAVEFORM_DURATION / 2
p = figure(
    width=750,
    height=300,
    x_axis_label="Time from coalescence [s]",
    y_axis_label="Gravitational wave strain [unitless]",
    tools="",
)
polarization_labels = ["cross", "plus"]
for i, label in enumerate(polarization_labels):
    p.line(
        t,
        waveforms[1, i],
        line_color=palette[i],
        line_alpha=0.8,
        line_width=1.5,
        legend_label=label,
    )

# clicking a legend entry toggles the corresponding line
p.legend.click_policy = "hide"
show(p)

In [16]:
# Split the waveforms and their parameters at the same index so
# that row i of each parameter array still describes waveform i
num_train = int((1 - VALID_FRAC) * NUM_WAVEFORMS)
train_waveforms = waveforms[:num_train]
valid_waveforms = waveforms[num_train:]

train_params, valid_params = {}, {}
for k, v in params.items():
    train_params[k] = v[:num_train]
    valid_params[k] = v[num_train:]

In [17]:
# number of samples in each kernel read from the background
kernel_size = int(KERNEL_LENGTH * SAMPLE_RATE)

# Training: define an "epoch" as the number of batches it takes
# to use as many waveforms as there are in the training set
waveforms_per_batch = int(INJECTION_FRACTION * BATCH_SIZE)
batches_per_epoch = num_train // waveforms_per_batch
train_loader = InMemoryDataset(
    train_background,
    kernel_size=kernel_size,
    batch_size=BATCH_SIZE,
    coincident=False,
    shuffle=True,
    batches_per_epoch=batches_per_epoch,
)

# Validation: pick the stride so that stepping through the validation
# background yields at most two windows per validation waveform
num_valid_waveforms = int(VALID_FRAC * NUM_WAVEFORMS)
num_windows = 2 * num_valid_waveforms
usable_samples = valid_background.shape[-1] - kernel_size
stride = usable_samples // num_windows + 1
valid_loader = InMemoryDataset(
    valid_background,
    kernel_size=kernel_size,
    stride=stride,
    batch_size=4 * BATCH_SIZE,
    coincident=True,
    shuffle=False,
)

In [18]:
# Training injector: adds randomly chosen training waveforms to a
# fraction (`prob`) of each batch, drawing the sky parameters and
# SNR from the samplers given here. It is fit on the training
# background of each detector before being moved to the GPU.
# NOTE(review): the meaning of `trigger_offset` and of `fit` is
# defined by ml4gw and not visible here -- confirm against its docs.
injector = RandomWaveformInjection(
    sample_rate=SAMPLE_RATE,
    ifos=["H1", "L1"],
    dec=CosineSampler(),
    psi=UniformSampler(0, np.pi),
    phi=UniformSampler(-np.pi, np.pi),
    snr=LogNormalSampler(MEAN_SNR, STD_SNR, MIN_SNR),
    highpass=HIGHPASS,
    prob=INJECTION_FRACTION,
    trigger_offset=-0.6,
    cross=train_waveforms[:, 0],
    plus=train_waveforms[:, 1],
)
injector.fit(H1=train_background[0], L1=train_background[1])
injector = injector.to("cuda")

# Validation injector: uses the held-out waveforms together with the
# sky parameters they were sampled with rather than random samplers,
# and is never fit on background.
# NOTE(review): `snr=None` presumably leaves the waveform amplitudes
# unscaled, and `ra` is passed as `phi` without conversion -- confirm
# both against the ml4gw documentation.
valid_injector = RandomWaveformInjection(
    sample_rate=SAMPLE_RATE,
    ifos=["H1", "L1"],
    dec=valid_params["dec"],
    psi=valid_params["psi"],
    phi=valid_params["ra"],
    snr=None,
    highpass=HIGHPASS,
    trigger_offset=0.5,
    cross=valid_waveforms[:, 0],
    plus=valid_waveforms[:, 1],
)
valid_injector = valid_injector.to("cuda")



In [19]:
# Draw a single waveform from the training injector along with the
# parameters sampled for it, then unpack that one row of parameters.
# NOTE(review): none of the names assigned here are read by the
# later cells visible in this notebook.
waveform, sampled_params = injector.sample(1, "cuda")
dec, psi, phi, snr = sampled_params[0]

In [20]:
class WhiteningTransform(torch.nn.Module):
    """Torch module that whitens batches of multi-detector timeseries.

    A time domain whitening filter is built per detector by `fit`
    and applied in `forward` as a grouped 1D convolution. The first
    and last (fduration / 2) seconds of each kernel are corrupted by
    the filter and get cropped, so the output passed on to the
    network is (kernel_length - fduration) seconds long.

    Args:
        num_ifos: Number of detector channels in the input.
        sample_rate: Sample rate of the input.
        kernel_length: Length in seconds of each input kernel.
        fftlength: FFT length used to estimate the background ASD.
        highpass: If given, used to compute the number of low
            frequency bins (`ncorner`) passed to the filter design.
        fduration: Duration of the time domain filter. Falls back
            to half of `kernel_length` when left unset.
    """

    def __init__(
        self,
        num_ifos: int,
        sample_rate: float,
        kernel_length: float,
        fftlength: float = 2,
        highpass: Optional[float] = None,
        fduration: Optional[float] = None,
    ) -> None:
        super().__init__()
        self.num_ifos = num_ifos
        self.sample_rate = sample_rate
        self.kernel_length = kernel_length
        self.fftlength = fftlength

        # frequency resolution the ASD gets interpolated to
        self.df = 1 / kernel_length
        if highpass:
            self.ncorner = int(highpass / self.df)
        else:
            self.ncorner = 0
        self.fduration = fduration if fduration else kernel_length / 2

        # samples at either edge that the filter corrupts
        # while settling in, and which get cropped away
        half_filter_duration = self.fduration / 2
        self.crop_samples = int(half_filter_duration * self.sample_rate)
        self.ntaps = int(self.fduration * self.sample_rate)
        self.pad = (self.ntaps - 1) // 2
        self.kernel_size = int(kernel_length * sample_rate)

        # placeholder filter of the right shape: the
        # real values are copied in by `fit`
        placeholder = torch.zeros((num_ifos, 1, self.ntaps - 1))
        self.register_buffer("time_domain_filter", placeholder)

        # the window is fully determined by ntaps, so
        # there's no need to save it with the weights
        self.register_buffer(
            "window", torch.hann_window(self.ntaps), persistent=False
        )

    def fit(self, **ifos: np.ndarray) -> None:
        """Build the time domain whitening filter for each detector.

        Args:
            **ifos: One background timeseries per detector, passed
                by keyword. Filters are stored in the order given.

        Raises:
            ValueError: If the number of backgrounds doesn't match
                `num_ifos`, or if a background ASD contains zeros.
        """
        num_backgrounds = len(ifos)
        if num_backgrounds != self.num_ifos:
            raise ValueError(
                "Expected to fit whitening transform on {} backgrounds, "
                "but was passed {}".format(self.num_ifos, num_backgrounds)
            )

        filters = []
        for strain in ifos.values():
            timeseries = TimeSeries(strain, dt=1 / self.sample_rate)
            spectrum = timeseries.asd(
                fftlength=self.fftlength, window="hann", method="median"
            )
            spectrum = spectrum.interpolate(self.df).value
            if (spectrum == 0).any():
                raise ValueError("Found 0 values in background asd")

            # the whitening filter is the inverse of the ASD
            taps = fir_from_transfer(
                1 / spectrum,
                ntaps=self.ntaps,
                window="hann",
                ncorner=self.ncorner,
            )
            filters.append(taps)

        # add the singleton channel dimension the grouped convolution
        # expects and drop the last tap to match the buffer's length
        stacked = np.stack(filters)[:, None, :-1]
        stacked = torch.tensor(stacked, dtype=torch.float64)
        self.time_domain_filter.copy_(stacked)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Whiten `X` and crop the corrupted edges.

        Args:
            X: Tensor indexed as (batch, detector, time).

        Returns:
            Whitened tensor with `crop_samples` removed from
            both ends of the time axis.

        Raises:
            NotImplementedError: If the filter is too short
                relative to the kernel size.
        """
        # constant detrend along the time axis. This is out-of-place,
        # so the in-place tapering below never touches the caller's
        # tensor
        X = X - X.mean(axis=-1, keepdims=True)
        head, tail = slice(None, self.pad), slice(-self.pad, None)
        X[:, :, head] *= self.window[head]
        X[:, :, tail] *= self.window[tail]

        nfft = min(8 * self.time_domain_filter.size(-1), self.kernel_size)
        if nfft < self.kernel_size / 2:
            raise NotImplementedError(
                "An optimal torch implementation of whitening for short "
                "fdurations is not complete. Use a larger fduration "
            )

        # one filter per detector via a grouped convolution
        whitened = torch.nn.functional.conv1d(
            X,
            self.time_domain_filter,
            groups=self.num_ifos,
            padding=int(self.pad),
        )

        # crop the beginning and ending fduration / 2
        whitened = whitened[:, :, self.crop_samples : -self.crop_samples]

        # scale by sqrt(2 / sample_rate), the same
        # normalization the original implementation applies
        scale = (2 / self.sample_rate) ** 0.5
        return whitened * scale

In [21]:
# Whitening is done on the GPU as the first stage of the model. The
# filter is fit on the training background only, with one keyword
# argument per detector in the same order as the channels of the data.
preprocessor = WhiteningTransform(
    num_ifos=2,
    sample_rate=SAMPLE_RATE,
    kernel_length=KERNEL_LENGTH,
    highpass=HIGHPASS,
)
preprocessor.fit(H1=train_background[0], L1=train_background[1])
preprocessor = preprocessor.to("cuda")

In [22]:
def convN(
    in_planes: int,
    out_planes: int,
    kernel_size: int = 3,
    stride: int = 1,
    groups: int = 1,
    dilation: int = 1,
) -> torch.nn.Conv1d:
    """Bias-free 1D convolution padded by half the (dilated) kernel width.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        kernel_size: Size of the convolutional kernel. Must be odd.
        stride: Stride of the convolution.
        groups: Number of convolutional groups.
        dilation: Dilation of the kernel, which also scales the padding.

    Raises:
        ValueError: If `kernel_size` is even.
    """
    if kernel_size % 2 == 0:
        raise ValueError("Can't use even sized kernels")

    half_width = int(kernel_size // 2)
    return torch.nn.Conv1d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=dilation * half_width,
        dilation=dilation,
        groups=groups,
        bias=False,
    )


def conv1(in_planes: int, out_planes: int, stride: int = 1) -> torch.nn.Conv1d:
    """Bias-free pointwise (kernel size 1) convolution.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Stride of the convolution.
    """
    pointwise = torch.nn.Conv1d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise


class BasicBlock(torch.nn.Module):
    """Residual block made of two conv-norm stages and a skip connection.

    The first convolution applies `stride`, the second keeps the
    length fixed. If `downsample` is given it is applied to the
    input before it gets added back onto the output.

    Args:
        inplanes: Number of input channels.
        planes: Number of channels produced by both convolutions.
        kernel_size: Kernel size of both convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the skip connection.
        groups: Must be 1.
        base_width: Must be 64.
        dilation: Must be no greater than 1.
        norm_layer: Normalization layer constructor, defaulting
            to `torch.nn.BatchNorm1d`.

    Raises:
        ValueError: If `groups != 1` or `base_width != 64`.
        NotImplementedError: If `dilation > 1`.
    """

    # ratio of output channels to `planes`
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        kernel_size: int = 3,
        stride: int = 1,
        downsample: Optional[torch.nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
    ) -> None:
        super().__init__()
        norm_layer = torch.nn.BatchNorm1d if norm_layer is None else norm_layer

        uses_defaults = groups == 1 and base_width == 64
        if not uses_defaults:
            raise ValueError(
                "BasicBlock only supports groups=1 and base_width=64"
            )
        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock"
            )

        # when stride != 1, the first convolution and the
        # downsample module both shorten the time axis
        self.conv1 = convN(inplanes, planes, kernel_size, stride)
        self.bn1 = norm_layer(planes)
        self.relu = torch.nn.ReLU(inplace=True)
        self.conv2 = convN(planes, planes, kernel_size)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both conv stages, add the skip connection, and activate."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)


class Bottleneck(torch.nn.Module):
    """
    Residual block with three convolutions rather than the two
    of a basic block. In these layers, the `planes` parameter
    sets the reduced number of feature maps used by the first
    two convolutions, which the third then expands out to
    `planes * Bottleneck.expansion` feature maps at the output.

    Args:
        inplanes: Number of input channels.
        planes: Base number of feature maps inside the block.
        kernel_size: Kernel size of the first two convolutions.
        stride: Stride of the second convolution.
        downsample: Optional module applied to the skip connection.
        groups: Number of groups in the second convolution.
        base_width: Scales the inner width as `planes * base_width / 64`.
        dilation: Dilation of the second convolution.
        norm_layer: Normalization layer constructor, defaulting
            to `torch.nn.BatchNorm1d`.
    """

    # ratio of output channels to `planes`
    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        kernel_size: int = 3,
        stride: int = 1,
        downsample: Optional[torch.nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
    ) -> None:
        super().__init__()
        norm_layer = torch.nn.BatchNorm1d if norm_layer is None else norm_layer

        # equal to planes when groups == 1 and base_width == 64
        width = int(planes * (base_width / 64.0)) * groups
        out_planes = planes * self.expansion

        # first stage: no downsampling, only maps
        # the inplanes feature maps onto width
        self.conv1 = convN(inplanes, width, kernel_size)
        self.bn1 = norm_layer(width)

        # second stage: same number of feature maps, but this
        # is where the stride, groups and dilation are applied
        self.conv2 = convN(width, width, kernel_size, stride, groups, dilation)
        self.bn2 = norm_layer(width)

        # third stage: pointwise expansion to planes * expansion
        self.conv3 = conv1(width, out_planes)
        self.bn3 = norm_layer(out_planes)

        self.relu = torch.nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the three conv stages, add the skip connection, and activate."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)


class ResNet(torch.nn.Module):
    """1D ResNet architecture

    Simple extension of ResNet to 1D convolutions with
    arbitrary kernel sizes to support the longer timeseries
    used in BBH detection.

    Args:
        num_ifos:
            The number of interferometers used for BBH
            detection. Sets the channel dimension of the
            input tensor
        layers:
            A list representing the number of residual
            blocks to include in each "layer" of the
            network. Total layers (e.g. 50 in ResNet50)
            is `2 + sum(layers) * factor`, where factor
            is `2` for vanilla `ResNet` and `3` for
            `BottleneckResNet`.
        kernel_size:
            The size of the convolutional kernel to
            use in all residual layers. _NOT_ the size
            of the input kernel to the network, which
            is determined at run-time.
        zero_init_residual:
            Flag indicating whether to initialize the
            weights of the batch-norm layer in each block
            to 0 so that residuals are initialized as
            identities. Can improve training results.
        groups:
            Number of convolutional groups to use in all
            layers. Grouped convolutions induce local
            connections between feature maps at subsequent
            layers rather than global. Generally won't
            need this to be >1, and will raise an error if
            >1 when using vanilla `ResNet`.
        width_per_group:
            Base width of each of the feature map groups,
            which is scaled up by the typical expansion
            factor at each layer of the network. Must be
            left at 64 for vanilla `ResNet`.
        stride_type:
            Whether to achieve downsampling on the time axis
            by strided or dilated convolutions for each layer.
            If left as `None`, strided convolutions will be
            used at each layer. Otherwise, `stride_type` should
            be one element shorter than `layers` and indicate either
            `stride` or `dilation` for each layer after the first.
        norm_layer:
            The layer type to use for normalization after each
            convolution. If left as `None`, defaults to 1D batch norm
            (`torch.nn.BatchNorm1d`).

    Raises:
        ValueError:
            If `stride_type` has the wrong length or contains
            anything other than `"stride"` or `"dilation"`.
    """

    block = BasicBlock

    def __init__(
        self,
        num_ifos: int,
        layers: List[int],
        kernel_size: int = 3,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        stride_type: Optional[List[str]] = None,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = torch.nn.BatchNorm1d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1

        # TODO: should we support passing a single string
        # for simplicity here?
        if stride_type is None:
            # each element in the tuple indicates if we should replace
            # the stride with a dilated convolution instead
            stride_type = ["stride"] * (len(layers) - 1)
        if len(stride_type) != (len(layers) - 1):
            raise ValueError(
                "'stride_type' should be None or a "
                "{}-element tuple, got {}".format(len(layers) - 1, stride_type)
            )

        self.groups = groups
        self.base_width = width_per_group

        # start with a basic conv-bn-relu-maxpool block
        # to reduce the dimensionality before the heavy
        # lifting starts
        self.conv1 = torch.nn.Conv1d(
            num_ifos,
            self.inplanes,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )
        self.bn1 = norm_layer(self.inplanes)
        self.relu = torch.nn.ReLU(inplace=True)
        self.maxpool = torch.nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        # now create layers of residual blocks where each
        # layer uses the same number of feature maps for
        # all its blocks (some power of 2 times 64).
        # Don't downsample along the time axis in the first
        # layer, but downsample in all the rest (either by
        # striding or dilating depending on the stride_type
        # argument). `block_size` is set before the loop so
        # that it's still defined for the head below when
        # `layers` only has a single element.
        block_size = 64
        residual_layers = [
            self._make_layer(block_size, layers[0], kernel_size)
        ]
        it = zip(layers[1:], stride_type)
        for i, (num_blocks, layer_stride_type) in enumerate(it):
            block_size = 64 * 2 ** (i + 1)
            layer = self._make_layer(
                block_size,
                num_blocks,
                kernel_size,
                stride=2,
                stride_type=layer_stride_type,
            )
            residual_layers.append(layer)
        self.residual_layers = torch.nn.ModuleList(residual_layers)

        # Average pool over each feature map to create a
        # single value for each feature map that we'll use
        # in the fully connected head
        self.avgpool = torch.nn.AdaptiveAvgPool1d(1)

        # use a fully connected layer to map from the
        # feature maps of the last residual layer to
        # the binary output that we need
        self.fc = torch.nn.Linear(block_size * self.block.expansion, 1)

        for m in self.modules():
            if isinstance(m, torch.nn.Conv1d):
                torch.nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu"
                )
            elif isinstance(m, (torch.nn.BatchNorm1d, torch.nn.GroupNorm)):
                torch.nn.init.constant_(m.weight, 1)
                torch.nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros,
        # and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to
        # https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    torch.nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    torch.nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(
        self,
        planes: int,
        blocks: int,
        kernel_size: int = 3,
        stride: int = 1,
        stride_type: Literal["stride", "dilation"] = "stride",
    ) -> torch.nn.Sequential:
        """Build one "layer" of `blocks` residual blocks.

        Updates `self.inplanes` (and `self.dilation` when dilating)
        so that the next layer is built against this layer's output.

        Args:
            planes: Base number of feature maps for each block.
            blocks: Number of residual blocks in the layer.
            kernel_size: Kernel size of the blocks' convolutions.
            stride: Downsampling factor for the layer.
            stride_type: Whether to apply `stride` by striding the
                first block or by dilating the blocks instead.

        Raises:
            ValueError: If `stride_type` is not recognized.
        """
        block = self.block
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation

        if stride_type == "dilation":
            # trade the stride for a larger dilation
            self.dilation *= stride
            stride = 1
        elif stride_type != "stride":
            raise ValueError(f"Unknown stride type {stride_type}")

        # the skip connection needs its own projection whenever
        # the block changes the length or number of feature maps
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = torch.nn.Sequential(
                conv1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        # only the first block downsamples, and it uses the
        # dilation from before this layer's update
        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                kernel_size,
                stride,
                downsample,
                self.groups,
                self.base_width,
                previous_dilation,
                norm_layer,
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    kernel_size,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return torch.nn.Sequential(*layers)

    def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
        """Run the stem, the residual layers, and the pooled linear head."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        for layer in self.residual_layers:
            x = layer(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of timeseries to one output value per sample."""
        return self._forward_impl(x)


# TODO: implement as arg of ResNet instead?
class BottleneckResNet(ResNet):
    """A version of ResNet that uses bottleneck blocks

    Identical to `ResNet` apart from the `block` class attribute,
    which makes every residual layer get built from `Bottleneck`
    blocks rather than `BasicBlock`s.
    """

    block = Bottleneck

In [23]:
# Build the network, optimizer, LR schedule and loss. The full model
# whitens its input on the fly before passing it to the network, so
# the weights that get saved include the whitening filter.
# NOTE(review): OneCycleLR derives its schedule from `max_lr`, so the
# LEARNING_RATE given to Adam appears to be overridden as soon as the
# scheduler is constructed -- confirm which of the two is intended.
nn = BottleneckResNet(2, layers=[3, 4, 4, 3]).to("cuda")
optimizer = torch.optim.Adam(nn.parameters(), lr=LEARNING_RATE)
# the schedule length assumes one scheduler step per training batch
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-2,
    epochs=MAX_EPOCHS,
    steps_per_epoch=batches_per_epoch,
    anneal_strategy="cos",
)
# the network outputs raw logits, hence the "with_logits" loss
loss_fn = torch.nn.functional.binary_cross_entropy_with_logits
model = torch.nn.Sequential(preprocessor, nn)

In [24]:
from contextlib import contextmanager

from rich import box
from rich.align import Align
from rich.console import Console, Group
from rich.layout import Layout
from rich.panel import Panel
from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn
from rich.syntax import Syntax
from rich.table import Table
from rich.live import Live

console = Console(force_jupyter=True)


def make_layout():
    """Build the display layout: a short progress row above the loss table.

    Returns:
        The root `Layout`, with children named "progress" and "table".
    """
    root = Layout(name="root")
    progress_row = Layout(name="progress", size=3)
    table_area = Layout(name="table", ratio=1)
    root.split(progress_row, table_area)
    return root


class Tracker:
    """Record per-epoch losses, checkpoint the model, and update the table.

    Relies on the module-level `model`, `layout` and `MAX_EPOCHS`
    being defined by the time `update` gets called.

    Args:
        output_directory: Directory under which a "training"
            subdirectory is created to hold weights,
            checkpoints and the loss history.
    """

    def __init__(self, output_directory: Path) -> None:
        self.output_directory = output_directory
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        self.table = Table()
        for header in ("Epoch", "Train Loss", "Valid Loss"):
            self.table.add_column(header)

        self.best_train_loss = float("inf")
        self.best_valid_loss = float("inf")
        self.history = {"train_loss": [], "valid_loss": []}

    @property
    def checkpoint_dir(self) -> Path:
        """Directory that the periodic checkpoints are written to."""
        return self.output_directory / "training" / "checkpoints"

    @property
    def best_weights_path(self) -> Path:
        """File holding the weights with the best validation loss so far."""
        return self.output_directory / "training" / "weights.pt"

    def update(self, train_loss: float, valid_loss: float) -> None:
        """Log the losses from one epoch and save weights as needed.

        Weights are written to `best_weights_path` whenever the
        validation loss matches or beats the best seen so far, and
        to a numbered checkpoint every 5th epoch.

        Args:
            train_loss: Mean training loss over the epoch.
            valid_loss: Mean validation loss over the epoch.
        """
        self.history["train_loss"].append(train_loss)
        self.history["valid_loss"].append(valid_loss)

        # highlight any loss that's the best seen so far
        train_color = ""
        if train_loss <= self.best_train_loss:
            train_color = "[cyan]"
            self.best_train_loss = train_loss

        valid_color = ""
        if valid_loss <= self.best_valid_loss:
            torch.save(model.state_dict(), self.best_weights_path)
            valid_color = "[green]"
            self.best_valid_loss = valid_loss

        # epochs are counted from 1 by the length of the history
        epoch = len(self.history["train_loss"])
        if epoch % 5 == 0:
            padded = str(epoch).zfill(3)
            fname = self.checkpoint_dir / f"epoch_{padded}.pt"
            torch.save(model.state_dict(), fname)

        row = (
            f"Epoch {epoch}/{MAX_EPOCHS}",
            f"{train_color}{train_loss:0.3e}",
            f"{valid_color}{valid_loss:0.3e}",
        )
        self.table.add_row(*row)
        layout["table"].update(self.table)

    def save(self) -> None:
        """Pickle the loss history into the "training" directory."""
        fname = self.output_directory / "training" / "history.pkl"
        with open(fname, "wb") as f:
            pickle.dump(self.history, f)


# Module-level display state shared by `Tracker.update`, `train_loop`
# and `epoch_loop`: the tracker itself, the layout that gets rendered
# live, and a single progress task that is reused for every epoch.
tracker = Tracker(BASE_DIR / RUN_NAME)
layout = make_layout()
prog = Progress()
# the total here is a placeholder: `epoch_loop` overwrites
# it with the length of whichever loader is being iterated
task_id = prog.add_task("Working on it", total=100)
layout["progress"].update(prog)
layout["table"].update(tracker.table)


def train_loop(epochs):
    """Yield epoch indices while keeping the live display open.

    The module-level `layout` is rendered for as long as
    the generator is being iterated.

    Args:
        epochs: Number of epoch indices to yield, starting from 0.
    """
    with Live(layout, refresh_per_second=1):
        yield from range(epochs)


def epoch_loop(loader):
    """Yield the batches of `loader`, tracking them on the progress bar.

    Resets the shared progress task to the length of `loader`, then
    advances it each time the consumer comes back for another batch.

    Args:
        loader: Sized iterable of batches.
    """
    task = prog.tasks[task_id]
    task.total = len(loader)
    task.completed = 0
    for batch in loader:
        yield batch
        # runs once the consumer has finished with the batch
        prog.advance(task_id)

In [25]:
# Main loop: one pass over the training loader with random injections,
# followed by one pass over the validation loader with the held-out
# waveforms injected into every other window. The last validation
# batch (`X`, `y`, `y_hat`) is left in scope for inspection afterwards.
for i in train_loop(MAX_EPOCHS):
    # make sure layers like batch norm are in training
    # mode, since validation switches them to eval mode
    model.train()
    epoch_loss = 0
    for X in epoch_loop(train_loader):
        optimizer.zero_grad(set_to_none=True)
        X = X.to("cuda")

        # `idx` holds the batch indices that received an injection
        X, idx, _ = injector(X)
        y_hat = model(X)

        y = torch.zeros((len(X), 1), device="cuda")
        y[idx] = 1

        loss = loss_fn(y_hat, y)
        loss.backward()
        optimizer.step()
        # the one-cycle schedule is stepped per batch, not per epoch
        scheduler.step()

        epoch_loss += loss.item()
    epoch_loss /= batches_per_epoch

    # evaluate in eval mode so that batch norm uses its running
    # statistics and doesn't update them with validation data
    model.eval()
    valid_loss = 0
    with torch.no_grad():
        steps, waveform_idx = 0, 0
        for X in epoch_loop(valid_loader):
            X = X.to("cuda")
            y = torch.zeros((len(X), 1), device="cuda")

            # inject into every other window, but never more
            # than the number of validation waveforms left
            remaining = num_valid_waveforms - waveform_idx
            num_injections = min(len(X) // 2, remaining)
            if num_injections > 0:
                idx = torch.arange(waveform_idx, waveform_idx + num_injections)

                # deliberately not called `waveforms`, which would
                # overwrite the array of all simulated waveforms
                injections, _ = valid_injector.sample(idx, device="cuda")

                # keep the central `kernel_size` samples
                start = injections.size(-1) // 2 - kernel_size // 2
                stop = start + kernel_size
                injections = injections[:, :, start:stop]

                # select exactly as many even-indexed windows as
                # there are injections so that shapes always agree,
                # including for odd-sized or final partial batches
                injected = slice(0, 2 * num_injections, 2)
                X[injected] += injections
                y[injected] = 1

            y_hat = model(X)
            loss = loss_fn(y_hat, y)
            valid_loss += loss.item()

            waveform_idx += len(X) // 2
            steps += 1
        valid_loss /= steps
    tracker.update(epoch_loss, valid_loss)
tracker.save()

Output()

KeyboardInterrupt: 

In [19]:
from bokeh.layouts import gridplot

# Plot the whitened strain of the last few injected windows from the
# final validation batch. This relies on `X`, `y` and `y_hat` still
# being in scope from the training loop.
# The whitened data gets its own name so that `X` is left untouched
# and re-running this cell doesn't whiten already-whitened data.
with torch.no_grad():
    whitened = preprocessor(X)

# build the time axis from the whitened data itself so that
# it stays valid if the kernel or filter lengths change
num_samples = whitened.shape[-1]
t = np.arange(num_samples) / SAMPLE_RATE
nrows = 4
ncols = 4
mask = y[:, 0] == 1
X1 = whitened[mask].cpu().numpy()[-nrows * ncols:]
y1 = y_hat[mask].cpu().numpy()[-nrows * ncols:, 0]

# NOTE(review): these are the parameters of the last validation
# waveforms, which only line up with the windows plotted here if the
# final batch was injected with exactly those waveforms -- confirm
# before reading anything into the titles.
mass1s = valid_params["mass_1"][-nrows * ncols:]
mass2s = valid_params["mass_2"][-nrows * ncols:]
dists = valid_params["luminosity_distance"][-nrows * ncols:]

# one panel per window, titled with the network output,
# both masses and the distance, on a shared y range
rows = []
for i in range(nrows):
    cols = []
    for j in range(ncols):
        k = i * ncols + j
        p = figure(
            title=f"{y1[k]:0.3f}/{mass1s[k]:0.1f}/{mass2s[k]:0.1f}/{dists[k]:0.1f}",
            x_range=(0, num_samples / SAMPLE_RATE),
            y_range=(X1.min(), X1.max())
        )
        for ifo in range(2):
            p.line(
                t,
                X1[k, ifo],
                line_color=palette[ifo],
                line_width=0.5,
                line_alpha=0.6
            )
        cols.append(p)
    rows.append(cols)
grid = gridplot(rows, width=225, height=125)
show(grid)