# Comparing the Different Components of Our True Metrics to Measurable Metrics

In [32]:
import yaml
import torch
import torch.nn as nn
from pathlib import Path
from joint_tof_opt import *
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output


# Experiment catalogue. NOTE: the file has a .json extension but is parsed
# with the YAML loader, exactly like the YAML config below.
with open("../data/parameter_mapping.json", "r") as f:
    parameter_mapping = yaml.safe_load(f)

# Time-of-flight simulation settings shared by every experiment.
with open("../experiments/tof_config.yaml", "r") as f:
    tof_config = yaml.safe_load(f)

experiments = parameter_mapping["experiments"]

# Pull the scalar settings used throughout the notebook out of the config.
# Note that `num_timepoints` is stored under the "datapoint_count" key.
fetal_f, maternal_f, sampling_rate, num_timepoints, bin_count = (
    tof_config[key]
    for key in (
        "fetal_f",
        "maternal_f",
        "sampling_rate",
        "datapoint_count",
        "bin_count",
    )
)

# Passed as the fourth argument to every CombSeparator built below
# (presumably the filter half-width — confirm against joint_tof_opt).
filter_hw = 0.001

In [33]:
# Human-readable label per experiment: the optional "name" entry, falling back
# to the mandatory "filename" entry (which is always looked up).
experiment_labels = [exp.get("name", exp["filename"]) for exp in experiments]

# Dropdown whose values are indices into `experiments`; the first entry is
# selected initially.
experiment_dropdown = widgets.Dropdown(
    options=list(zip(experiment_labels, range(len(experiment_labels)))),
    value=0,
    description="Experiment:",
    layout=widgets.Layout(width="60%"),
)

# Toggle between time-domain traces and RFFT magnitude spectra.
view_rfft_checkbox = widgets.Checkbox(
    value=False,
    description="View RFFT",
    indent=False,
)

# Capture area that receives the printed metrics and the figure.
output = widgets.Output()


def _to_plot_data(series: torch.Tensor, use_rfft: bool):
    """Convert a tensor into the NumPy array that gets plotted.

    Args:
        series: Signal to plot.
        use_rfft: When true, return the magnitude of the real FFT of the
            mean-removed signal; otherwise return the signal unchanged.

    Returns:
        A NumPy array, detached from autograd and moved to the CPU.
    """
    # Removing the mean first keeps the DC bin from dominating the spectrum.
    values = torch.fft.rfft(series - series.mean()).abs() if use_rfft else series
    return values.detach().cpu().numpy()


def _plot_line(ax, series: torch.Tensor, label: str, use_rfft: bool, color=None, linestyle="-", logscale: bool = False):
    """Draw one labelled trace on `ax` and refresh its legend.

    Args:
        ax: Matplotlib axes to draw on.
        series: Signal to plot; converted through `_to_plot_data`.
        label: Legend entry for the trace.
        use_rfft: Forwarded to `_to_plot_data` (time series vs. RFFT magnitude).
        color: Line colour, or None for Matplotlib's default.
        linestyle: Matplotlib line style string.
        logscale: When true, switch the y axis of `ax` to a log scale.
    """
    line_options = {"label": label, "color": color, "linestyle": linestyle}
    ax.plot(_to_plot_data(series, use_rfft), **line_options)
    if logscale:
        ax.set_yscale("log")
    # Rebuilt on every call so traces added later to the same axes are listed.
    ax.legend()


def run_experiment(idx: int, use_rfft: bool = False):
    """Print energy metrics and plot signal components for one experiment.

    The "abs" moment of the TOF series is computed three times with different
    flag pairs passed to `compute_tof_data_series`. The (True, True) series is
    split into a maternal and a fetal part with comb filters ("measured"
    components), while the (True, False) and (False, True) series provide the
    "pure" maternal and fetal references. Energies (sum of squared deviations
    from the mean) and the ratio fetal_energy / sqrt(maternal_energy) are
    printed for both, followed by a 3x2 grid of plots.

    Args:
        idx: Index into the module-level `experiments` list.
        use_rfft: Plot RFFT magnitudes instead of time-domain traces.
    """

    def _energy(signal):
        # Sum of squared deviations from the mean, as a Python float.
        return ((signal - signal.mean()) ** 2).sum().item()

    def _separator(base_f):
        # Comb filter configured with base_f and twice base_f.
        return CombSeparator(
            sampling_rate,
            base_f,
            2 * base_f,
            filter_hw,
            num_timepoints // 2 + 1,
            True,
        )

    def _apply(separator, signal):
        # Two leading singleton axes are added before filtering and one is
        # removed afterwards.
        # NOTE(review): the 1/4 scale factor is not explained in this file —
        # presumably it compensates for the filter gain; confirm.
        return separator.forward(signal.unsqueeze(0).unsqueeze(0)).squeeze(0) / 4.

    def _moment(flag_a, flag_b):
        # NOTE(review): the two flags presumably toggle the maternal and fetal
        # contributions respectively — confirm against joint_tof_opt.
        tof_series = compute_tof_data_series(data_file, tof_config, flag_a, flag_b)
        return get_named_moment_module(moment_name, tof_series).forward(uniform_window)

    data_file = Path("../data") / experiments[idx]["filename"]
    moment_name = "abs"

    # Uniform window over the bins, scaled to unit L1 norm.
    ones = torch.ones(bin_count)
    uniform_window = ones / ones.norm(p=1)

    fetal_filter = _separator(fetal_f)
    maternal_filter = _separator(maternal_f)

    # Measured components: extract the maternal part first, then search for
    # the fetal part in what remains after also removing the mean.
    baseline_measurand = _moment(True, True)
    maternal_component = _apply(maternal_filter, baseline_measurand)
    residual = baseline_measurand - maternal_component - baseline_measurand.mean()
    fetal_component = _apply(fetal_filter, residual)
    measured_maternal_energy = _energy(maternal_component)
    measured_fetal_energy = _energy(fetal_component)

    # Reference ("pure") components.
    maternal_measurand = _moment(True, False)
    pure_maternal_energy = _energy(maternal_measurand)

    fetal_measurand = _moment(False, True)
    pure_fetal_energy = _energy(fetal_measurand)

    measured_metric = measured_fetal_energy / measured_maternal_energy ** 0.5
    pure_metric = pure_fetal_energy / pure_maternal_energy ** 0.5

    report = (
        ("Measured Maternal Energy: ", measured_maternal_energy),
        ("Measured Fetal Energy: ", measured_fetal_energy),
        ("Pure Maternal Energy: ", pure_maternal_energy),
        ("Pure Fetal Energy: ", pure_fetal_energy),
        ("Measured Metric: ", measured_metric),
        ("Pure Metric: ", pure_metric),
    )
    for caption, value in report:
        print(caption, value)

    detrended_baseline = baseline_measurand - baseline_measurand.mean()
    measured_sum = fetal_component + maternal_component
    pure_sum = (
        maternal_measurand + fetal_measurand - maternal_measurand.mean() - fetal_measurand.mean()
    )

    _, axes = plt.subplots(3, 2, figsize=(10, 8))
    # Row 0: measured components; row 1: pure components; row 2: detrended
    # baseline overlaid with the sum of measured (left) / pure (right) parts.
    panels = (
        (axes[0, 0], maternal_component, "Measured Maternal Component", {}),
        (axes[0, 1], fetal_component, "Measured Fetal Component", {"color": "orange"}),
        (
            axes[1, 0],
            maternal_measurand - maternal_measurand.mean(),
            "Pure Maternal Component",
            {"color": "green"},
        ),
        (
            axes[1, 1],
            fetal_measurand - fetal_measurand.mean(),
            "Pure Fetal Component",
            {"color": "red"},
        ),
        (
            axes[2, 0],
            detrended_baseline,
            "Detrended Baseline Measurand",
            {"color": "purple", "logscale": use_rfft},
        ),
        (
            axes[2, 0],
            measured_sum,
            "Sum of Measured Components",
            {"color": "brown", "linestyle": "--"},
        ),
        (
            axes[2, 1],
            detrended_baseline,
            "Detrended Baseline Measurand",
            {"color": "purple", "logscale": use_rfft},
        ),
        (
            axes[2, 1],
            pure_sum,
            "Sum of Pure Components",
            {"color": "cyan", "linestyle": "--"},
        ),
    )
    for axis, series, caption, options in panels:
        _plot_line(axis, series, caption, use_rfft, **options)

    x_caption = "Frequency Bin" if use_rfft else "Time Index"
    for axis in axes.flat:
        axis.set_xlabel(x_caption)
    plt.tight_layout()
    plt.show()


def on_experiment_change(change):
    """Re-run the analysis when the dropdown's selected experiment changes.

    Args:
        change: Change dictionary delivered by the widget observer. Only
            changes whose "name" is "value" are acted on; "new" is then used
            as the experiment index.
    """
    if change["name"] == "value":
        with output:
            # wait=True delays clearing until new output is ready.
            clear_output(wait=True)
            run_experiment(change["new"], view_rfft_checkbox.value)


def on_view_change(change):
    """Re-run the analysis when the "View RFFT" checkbox is toggled.

    Args:
        change: Change dictionary delivered by the widget observer. Only
            changes whose "name" is "value" are acted on; "new" is then used
            as the RFFT flag.
    """
    if change["name"] == "value":
        with output:
            # wait=True delays clearing until new output is ready.
            clear_output(wait=True)
            run_experiment(experiment_dropdown.value, change["new"])


# Hook both controls up to their handlers; the handlers themselves ignore
# every trait change except "value".
experiment_dropdown.observe(on_experiment_change)
view_rfft_checkbox.observe(on_view_change)

# Controls side by side, with the captured output underneath.
controls_row = widgets.HBox([experiment_dropdown, view_rfft_checkbox])
display(controls_row, output)

# Render once for the initial selection so the output area is not empty.
with output:
    clear_output(wait=True)
    run_experiment(experiment_dropdown.value, view_rfft_checkbox.value)

HBox(children=(Dropdown(description='Experiment:', layout=Layout(width='60%'), options=(('experiment_0000.npz'…

Output()