# Comparing the Different Components of Our True Metrics to Measurable Metrics

In [5]:
import yaml
import torch
import torch.nn as nn
from pathlib import Path
from joint_tof_opt import *
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output
import numpy as np

import json  # stdlib; used to parse the .json mapping file (see note below)

# Experiment catalogue. Read with ``json`` rather than ``yaml.safe_load``:
# PyYAML implements YAML 1.1, whose float rule requires a decimal point (and a
# signed exponent), so JSON numbers such as ``1e-05`` or ``2e3`` would silently
# come back as *strings*. JSON text is UTF-8 by specification (RFC 8259).
with open("../data/parameter_mapping.json", "r", encoding="utf-8") as f:
    parameter_mapping = json.load(f)

# Base time-of-flight configuration; the notebook-specific overrides follow.
with open("../experiments/tof_config.yaml", "r") as f:
    tof_config = yaml.safe_load(f)

tof_config["datapoint_count"] = 301  # end_sec * sampling_rate + 1 (presumably both endpoints sampled)
tof_config["end_sec"] = 10.0
tof_config["sampling_rate"] = 30
tof_config["selected_sdd_index"] = 2
tof_config["fetal_f"] = 2.8  # presumably the fetal fundamental frequency in Hz
tof_config["maternal_f"] = 1.0  # presumably the maternal fundamental frequency in Hz

# Convenience aliases used by the cells below.
experiments = parameter_mapping["experiments"]
fetal_f = tof_config["fetal_f"]
maternal_f = tof_config["maternal_f"]
sampling_rate = tof_config["sampling_rate"]
num_timepoints = tof_config["datapoint_count"]
bin_count = tof_config["bin_count"]  # taken from the YAML file, not overridden above
filter_hw = 0.3  # only referenced by the commented-out Comb/Fourier separator variants

In [6]:
window_len = 20  # number of slider knobs that define the window shape


def _make_window_sliders():
    """Create ``window_len`` sliders that together describe the window shape.

    Each slider spans [0, 1] in steps of 0.01 and starts at 1.0, so the
    initial window is flat. Sliders are labelled ``w00``, ``w01``, ... and use
    ``continuous_update=False``, so their value is only published when the
    user releases the handle.
    """
    sliders = []
    for knob in range(window_len):
        knob_slider = widgets.FloatSlider(
            description=f"w{knob:02d}",
            value=1.0,
            min=0.0,
            max=1.0,
            step=0.01,
            readout_format=".2f",
            continuous_update=False,
        )
        sliders.append(knob_slider)
    return sliders


def _build_window(sliders):
    """Return the slider-defined window as a 1-D float32 tensor of length ``bin_count``.

    The slider values are taken in order. When their count differs from the
    module-level ``bin_count``, they are resampled with 1-D linear
    interpolation (``align_corners=False``).
    """
    window = torch.tensor([s.value for s in sliders], dtype=torch.float32)
    if window.numel() == bin_count:
        return window
    # F.interpolate expects (batch, channel, length) input.
    resampled = F.interpolate(window[None, None, :], size=bin_count, mode="linear", align_corners=False)
    return resampled.flatten()


def _to_plot_data(series: torch.Tensor, use_rfft: bool):
    """Convert *series* into a NumPy array ready for plotting.

    Time view (``use_rfft=False``): the samples, unchanged.
    Spectrum view (``use_rfft=True``): magnitude of the one-sided FFT of the
    mean-removed, Hann-tapered series.
    """
    if use_rfft:
        centered = series - series.mean()
        taper = torch.hann_window(len(centered), device=centered.device)
        return torch.fft.rfft(centered * taper).abs().detach().cpu().numpy()
    return series.detach().cpu().numpy()

def _get_x_axis(length: int, use_rfft: bool):
    """Return x coordinates matching ``_to_plot_data`` output for a *length*-sample series.

    Time view: sample times, ``index / sampling_rate`` (seconds, assuming
    ``sampling_rate`` is samples per second).
    Spectrum view: the one-sided FFT bin frequencies for that sample spacing.
    """
    if use_rfft:
        return torch.fft.rfftfreq(length, d=1 / sampling_rate).detach().cpu().numpy()
    return np.arange(length) / sampling_rate

def _plot_line(ax, series: torch.Tensor, label: str, use_rfft: bool, color=None, linestyle="-", logscale: bool = False):
    """Draw *series* on *ax* as a time trace or magnitude spectrum, then refresh the legend.

    Args:
        ax: matplotlib axes to draw on.
        series: samples to plot, converted by ``_to_plot_data``.
        label: legend entry for this line.
        use_rfft: plot the spectrum instead of the time trace.
        color, linestyle: forwarded to ``ax.plot``.
        logscale: switch the axes' y scale to logarithmic.
    """
    ax.plot(
        _get_x_axis(len(series), use_rfft),
        _to_plot_data(series, use_rfft),
        label=label,
        color=color,
        linestyle=linestyle,
    )
    if logscale:
        ax.set_yscale("log")
    ax.legend()

def _build_filters():
    """Return ``(fetal_filter, maternal_filter)`` as two ``PSAFESeparator`` instances.

    Both use the module-level ``sampling_rate``. They are tuned to ``fetal_f``
    and ``maternal_f`` respectively. The trailing ``True`` flag is passed
    unchanged; its meaning is defined by ``PSAFESeparator``
    (TODO: document once confirmed).
    """
    def make_separator(target_f):
        return PSAFESeparator(sampling_rate, target_f, True)

    # Tuple elements are evaluated left to right: fetal first, as before.
    return make_separator(fetal_f), make_separator(maternal_f)

# def _build_filters():
#     fetal_filter = CombSeparator(
#         sampling_rate,
#         fetal_f,
#         2 * fetal_f,
#         filter_hw,
#         num_timepoints // 2 + 1,
#         True,
#     )
#     maternal_filter = CombSeparator(
#         sampling_rate,
#         maternal_f,
#         2 * maternal_f,
#         filter_hw,
#         num_timepoints // 2 + 1,
#         True,
#     )
#     return fetal_filter, maternal_filter


# def _build_filters():
#     fetal_filter = FourierSeparator(
#         sampling_rate,
#         fetal_f,
#         2 * fetal_f,
#         filter_hw,
#     )
#     maternal_filter = FourierSeparator(
#         sampling_rate,
#         maternal_f,
#         2 * maternal_f,
#         filter_hw,
#     )
#     return fetal_filter, maternal_filter



def _energy(series: torch.Tensor) -> float:
    """Return the sum of squared deviations of *series* from its mean, as a Python float."""
    return ((series - series.mean()) ** 2).sum().item()


def _ratio_metric(fetal_energy: float, maternal_energy: float) -> float:
    """Return ``fetal_energy / sqrt(maternal_energy)``.

    Returns NaN when the maternal energy is not positive, instead of raising
    ``ZeroDivisionError`` inside the widget callback.
    """
    if maternal_energy <= 0:
        return float("nan")
    return fetal_energy / math.sqrt(maternal_energy)


def _compute_components(ppath_filename: Path, window: torch.Tensor):
    """Compare measured and pure (single-source) maternal/fetal components.

    ``compute_tof_data_series`` produces three ToF series from
    *ppath_filename*: combined (flags ``True, True``), maternal-only
    (``True, False``) and fetal-only (``False, True``). Each is reduced to the
    ``"abs"`` measurand using *window*. Each also has its mean removed before
    filtering, so the measured and pure paths get identical preprocessing:

    * measured: the combined series goes through both separators;
    * pure: each single-source series goes through its own separator only.

    Energies are sums of squared deviations from the mean. Each metric is
    ``fetal_energy / sqrt(maternal_energy)``, or NaN if the maternal energy is 0.

    Args:
        ppath_filename: data file handed to ``compute_tof_data_series``.
        window: weighting tensor handed to the measurand module's ``forward``.

    Returns:
        dict with the intermediate tensors used by ``_plot_results``, plus the
        float energies and the ``measured_metric`` / ``pure_metric`` values.
    """
    measurand = "abs"
    fetal_filter, maternal_filter = _build_filters()

    def measurand_of(with_maternal: bool, with_fetal: bool) -> torch.Tensor:
        # Flag order (maternal, fetal) follows the call sites' naming, e.g.
        # the maternal-only series uses ``True, False``.
        tof = compute_tof_data_series(ppath_filename, tof_config, with_maternal, with_fetal)
        return get_named_moment_module(measurand, tof).forward(window)

    def separate(separator, series: torch.Tensor) -> torch.Tensor:
        # Separators receive the series with two leading singleton dims
        # (presumably (1, 1, T) for a 1-D series; confirm against PSAFESeparator).
        return separator.forward(series.unsqueeze(0).unsqueeze(0))

    # Measured path: split the combined signal with both separators.
    baseline_measurand = measurand_of(True, True)
    detrended_baseline = baseline_measurand - baseline_measurand.mean()
    maternal_component = separate(maternal_filter, detrended_baseline)
    # The fetal part is taken from the full detrended signal. An alternative,
    # formerly left commented out here, is to filter the residual
    # ``detrended_baseline - maternal_component`` instead.
    fetal_component = separate(fetal_filter, detrended_baseline)
    measured_maternal_energy = _energy(maternal_component)
    measured_fetal_energy = _energy(fetal_component)

    # Pure paths: detrend each single-source measurand exactly like the
    # baseline before filtering, so both metrics see the same preprocessing.
    pure_maternal = measurand_of(True, False)
    maternal_measurand = separate(maternal_filter, pure_maternal - pure_maternal.mean())
    pure_maternal_energy = _energy(maternal_measurand)

    pure_fetal = measurand_of(False, True)
    fetal_measurand = separate(fetal_filter, pure_fetal - pure_fetal.mean())
    pure_fetal_energy = _energy(fetal_measurand)

    return {
        "baseline_measurand": baseline_measurand,
        "detrended_baseline": detrended_baseline,
        "maternal_component": maternal_component,
        "fetal_component": fetal_component,
        "maternal_measurand": maternal_measurand,
        "fetal_measurand": fetal_measurand,
        "measured_maternal_energy": measured_maternal_energy,
        "measured_fetal_energy": measured_fetal_energy,
        "pure_maternal_energy": pure_maternal_energy,
        "pure_fetal_energy": pure_fetal_energy,
        "measured_metric": _ratio_metric(measured_fetal_energy, measured_maternal_energy),
        "pure_metric": _ratio_metric(pure_fetal_energy, pure_maternal_energy),
    }


def _plot_results(data, use_rfft: bool):
    """Draw a 3x2 grid comparing measured (left column) and pure (right column) components.

    Rows:
      0: maternal component, measured vs pure.
      1: fetal component, measured vs pure.
      2: left shows the detrended baseline overlaid with the sum of the
         measured components; right shows the detrended baseline overlaid
         with the sum of the mean-removed pure components.

    Args:
        data: dict returned by ``_compute_components``.
        use_rfft: plot windowed magnitude spectra instead of time traces. The
            bottom row then uses a logarithmic y axis.
    """
    fig, axes = plt.subplots(3, 2, figsize=(10, 8))
    _plot_line(axes[0, 0], data["maternal_component"], "Measured Maternal Component", use_rfft)
    _plot_line(axes[1, 0], data["fetal_component"], "Measured Fetal Component", use_rfft, color="orange")
    _plot_line(
        axes[0, 1],
        data["maternal_measurand"] - data["maternal_measurand"].mean(),
        "Pure Maternal Component",
        use_rfft,
        color="green",
    )
    _plot_line(
        axes[1, 1],
        data["fetal_measurand"] - data["fetal_measurand"].mean(),
        "Pure Fetal Component",
        use_rfft,
        color="red",
    )
    _plot_line(
        axes[2, 0],
        data["baseline_measurand"] - data["baseline_measurand"].mean(),
        "Detrended Baseline Measurand",
        use_rfft,
        color="purple",
        logscale=use_rfft,
    )
    _plot_line(
        axes[2, 0],
        data["fetal_component"] + data["maternal_component"],
        "Sum of Measured Components",
        use_rfft,
        color="brown",
        linestyle="--",
    )
    _plot_line(
        axes[2, 1],
        data["detrended_baseline"],
        "Detrended Baseline Measurand",
        use_rfft,
        color="purple",
        logscale=use_rfft,
    )
    _plot_line(
        axes[2, 1],
        data["maternal_measurand"] + data["fetal_measurand"] - data["maternal_measurand"].mean() - data["fetal_measurand"].mean(),
        "Sum of Pure Components",
        use_rfft,
        color="cyan",
        linestyle="--",
    )
    # _get_x_axis returns rfftfreq(..., d=1/sampling_rate) in the spectrum
    # view and sample_index / sampling_rate in the time view. Those are
    # physical units (Hz and seconds, assuming sampling_rate is samples per
    # second), not bin or sample indices.
    for ax in axes.flat:
        ax.set_xlabel("Frequency (Hz)" if use_rfft else "Time (s)")
    plt.tight_layout()
    plt.show()


def _render(idx: int, use_rfft: bool, sliders):
    """Compute, report and plot the components for ``experiments[idx]``.

    Args:
        idx: index into the module-level ``experiments`` list.
        use_rfft: forwarded to ``_plot_results`` (spectrum vs time view).
        sliders: window sliders, converted with ``_build_window``.
    """
    ppath_filename = Path("../data") / experiments[idx]["filename"]
    data = _compute_components(ppath_filename, _build_window(sliders))

    # Scalar summary first, then the figure.
    report = (
        ("Measured Maternal Energy: ", "measured_maternal_energy"),
        ("Measured Fetal Energy: ", "measured_fetal_energy"),
        ("Pure Maternal Energy: ", "pure_maternal_energy"),
        ("Pure Fetal Energy: ", "pure_fetal_energy"),
        ("Measured Metric: ", "measured_metric"),
        ("Pure Metric: ", "pure_metric"),
    )
    for caption, key in report:
        print(caption, data[key])

    _plot_results(data, use_rfft)


# Dropdown options are (label, experiment index) pairs. An experiment's
# "name" is used as its label when present, otherwise its "filename".
experiment_labels = [entry.get("name", entry["filename"]) for entry in experiments]
experiment_dropdown = widgets.Dropdown(
    options=list(zip(experiment_labels, range(len(experiment_labels)))),
    value=0,
    description="Experiment:",
    layout=widgets.Layout(width="60%"),
)
view_rfft_checkbox = widgets.Checkbox(value=False, description="View RFFT", indent=False)
output = widgets.Output()  # _rerun() renders the text summary and figure in here

# The window sliders, laid out four per row.
window_sliders = _make_window_sliders()
window_grid = widgets.GridBox(
    window_sliders,
    layout=widgets.Layout(grid_gap="4px 8px", grid_template_columns="repeat(4, 1fr)"),
)


def _rerun():
    """Redraw the output area from the current dropdown, checkbox and slider state."""
    with output:
        # wait=True defers clearing until new output arrives, avoiding flicker.
        clear_output(wait=True)
        experiment_idx = experiment_dropdown.value
        show_spectrum = view_rfft_checkbox.value
        _render(experiment_idx, show_spectrum, window_sliders)


def _on_value_change(change):
    """Re-render whenever an observed widget's ``value`` changes.

    The handler is registered with ``names="value"``, so only value changes
    should arrive. The guard is kept as a safety net in case it is attached
    elsewhere without that filter.
    """
    if change["name"] != "value":
        return
    _rerun()


# The three former handlers were identical. The public names are kept as
# aliases so any other cell that references them still works.
on_experiment_change = _on_value_change
on_view_change = _on_value_change
on_window_change = _on_value_change

# Subscribe only to "value". Without names=..., observe() fires for every
# trait change on the widget.
experiment_dropdown.observe(on_experiment_change, names="value")
view_rfft_checkbox.observe(on_view_change, names="value")
for slider in window_sliders:
    slider.observe(on_window_change, names="value")

# Controls row (experiment picker + RFFT toggle) above the slider grid,
# followed by the output area that _rerun() draws into.
_controls = widgets.VBox(
    [widgets.HBox([experiment_dropdown, view_rfft_checkbox]), window_grid]
)
display(_controls, output)

# Populate the output once for the initially selected experiment.
_rerun()

VBox(children=(HBox(children=(Dropdown(description='Experiment:', layout=Layout(width='60%'), options=(('exper…

Output()