# Dataset Validation and Spectral Extraction

- Verify data processing
- Apply timestamp and laser correction algorithms
- Extract nuclear and laser spectra
- Verify spectra (per channel, per dataset)

In [None]:
%load_ext autoreload
%autoreload 2

import os, glob
from cryoant.daq.xia.listmode import load_and_process
from beest.laser import correct_substrate_heating
import matplotlib.pyplot as plt
from matplotlib import colors
import cryoant as ct
import numpy as np
from joblib import Parallel, delayed
# import pandas as pd
import modin.pandas as pd
import modin.config as mcfg

#: Plot defaults: pandas draws through matplotlib, styled by cryoant's bundled sheet
pd.set_option("plotting.backend", "matplotlib")
plt.style.use(f"{next(iter(ct.__path__))}/plot.mpl")

#: Dataset selector and figure resolutions (interactive / saved)
DATE = 20240725
DPI_VIS, DPI_SV = 200, 50

#: All figures for this dataset go into one directory
figdir = f"out/spectra_calibration/{DATE}-setup/"
os.makedirs(figdir, exist_ok=True)

#: Per-event table produced by the trace-chewing step
df = pd.read_feather(f"out/trace-chewing/processed/{DATE}.feather")
display(df.shape)

### Wall-Clock Time

Event IDs and event times repeat across files, so every event must first be given a real (wall-clock) timestamp, derived from the file timestamps, before continuing.

In [None]:
#: File start time, parsed from the two "_"-separated fields just before the
#: chunk suffix (format "%Y-%m-%d_%H.%M.%S").
timestamps = {
    k: pd.to_datetime("_".join(k.split("_")[-3:-1]), format="%Y-%m-%d_%H.%M.%S")
    for k in df.fname.unique()
}
#: Plain dict lookup per file name (same pattern as run_id below); no per-row
#: copies of ``timestamps``.
df["timestamp"] = df.fname.map(timestamps).astype("category")

#: Chunk number from the trailing "..._chunk<N>.<ext>" field, parsed once per file
chunks = {
    k: int(k.split("_")[-1].split(".")[0].removeprefix("chunk")) for k in timestamps
}
df["chunk"] = df.fname.map(chunks).astype("category")

zero_timestamps = {
    k: v for k, v in timestamps.items() if k.split("_")[-1].split(".")[0] == "chunk0"
}
zero_timestamps = dict(sorted(zero_timestamps.items(), key=lambda x: x[1]))
#: Create a mapping of fname to run_id. Timestamps between two zero_timestamps belong to the former run_id
#: zero_timestamps is sorted so first, the zero files can be assigned to their run ids
run_ids = {k: v for v, k in enumerate(zero_timestamps)}
#: Now the remainder of files need to be assigned a run id based on the latest zero file they are greater than
for fname, timestamp in timestamps.items():
    if fname in zero_timestamps:
        continue
    preceding = [run_ids[k] for k, t0 in zero_timestamps.items() if timestamp > t0]
    if not preceding:
        raise ValueError(
            f"{fname} ({timestamp}) is not later than any chunk0 file; "
            "cannot assign a run_id"
        )
    run_ids[fname] = preceding[-1]

df["run_id"] = df.fname.map(run_ids).astype("category")

#: APPLY does not work with modin (produces DataFrame). Luckily, the minimum timestamp for each run is the zero chunk timestamp.
#: observed=True: skip empty categorical groups (consistent with the later groupbys)
for _, g in df.groupby("run_id", observed=True):
    df.loc[g.index, "run_start"] = g["timestamp"].astype("datetime64[ns]").min()
#: Event ``time`` (seconds) offset by its run's wall-clock start. The str
#: round-trip normalizes run_start to datetime64 whatever dtype .loc produced.
df["realtime"] = pd.Series(pd.to_timedelta(df.time, unit="s")).add(
    pd.to_datetime(df.run_start.astype(str))
)

In [None]:
%%capture
%%
"""Validated wall-clock time.

Realtime axis is now non-repeating across all files, regardless of if multiple runs exist
(and event ID resets)
"""
fig, ax = plt.subplots(figsize=(10,2), dpi=250, constrained_layout=True)
for file in df.fname.unique():
    data = df[df.fname == file]
    ax.plot(
        (
            data["realtime"].astype("datetime64[s]")
            - df["realtime"].astype("datetime64[s]").min()
        )/60/60,
        data["eventID"],
    )
ax.set(
    xlabel="Time (h)",
    ylabel="Event ID",
    title=f"Event ID vs Wall-Clock Time for {DATE}",
)

### Laser Identification

Laser identification must come before the substrate correction so that the correction is not applied to the nuclear data.

In [None]:
"""Setup: Good Event Selection"""

#: Generally good data
df["ig_data"] = ~(df.ib_head | df.ib_clipped | df.ib_flat | df.ib_error)

#: Initial Setting of laser/nuclear data
df["ig_laser"] = False
df["ig_event"] = False

In [None]:
"""Setup: Laser Frequencies per Run"""

laser_hzs = {k: 100 for k in df.run_id.unique()}

In [None]:
"""Identify laser via frequency.

Once identified, anticoincident is nuclear data.
"""

groups = []
for (run, ch), chgroup in df[df.ig_data].groupby(["run_id", "channel"]):
    groups.append(
        (
            (
                (chgroup.realtime - chgroup.realtime.min()).dt.total_seconds()
                % (1 / laser_hzs[run])
            )
            - (
                (chgroup.realtime - chgroup.realtime.min()).dt.total_seconds()
                % (1 / laser_hzs[run])
            )
            .rolling(50)
            .median()
        )
        .abs()
        .lt(
            0.05
            * (
                (chgroup.realtime - chgroup.realtime.min()).dt.total_seconds()
                % (1 / laser_hzs[run])
            ).mean()
        )
    )

df["ig_laser"] = pd.concat(groups).reindex(df.index).fillna(df.ig_laser)

del groups

#: Same flags for nuclear events plus NOT laser
df["ig_event"] = ~(df.ib_head | df.ib_clipped | df.ib_flat | df.ib_error | df.ig_laser)

In [None]:
"""Validate Laser Identification

Modulo
"""

runs = df.run_id.unique()
channels = df.channel.unique()
fig, axes = plt.subplots(
    len(runs),
    len(channels),
    figsize=(10 * len(channels), 4 * len(runs)),
    dpi=200,
    constrained_layout=True,
)
if len(runs) == 1:
    axes = np.expand_dims(axes, axis=0)
if len(channels) == 1:
    axes = np.expand_dims(axes, axis=1)
fig.suptitle(f"Laser Identification for {DATE}")

for (run, channel), chgroup in df[df.ig_data].groupby(["run_id", "channel"]):
    i = np.where(runs == run)[0][0]
    j = np.where(channels == channel)[0][0]
    ax = axes[i, j]
    ax.scatter(
        (chgroup.realtime - df.realtime.min()).dt.total_seconds(),
        (chgroup.realtime - df.realtime.min()).dt.total_seconds()
        % (1 / laser_hzs[run]),
        s=0.1,
        lw=0,
        alpha=0.5,
        label="All",
    )
    data = chgroup[chgroup.ig_laser]
    ax.scatter(
        (data.realtime - df.realtime.min()).dt.total_seconds(),
        (data.realtime - df.realtime.min()).dt.total_seconds() % (1 / laser_hzs[run]),
        s=0.1,
        lw=0,
        alpha=0.5,
        label="Laser ID (rolling 50x median)",
    )
    ax.set(
        xlabel="Time since start (s)",
        ylabel=f"Time modulo 1/{laser_hzs[run]} s",
        title=f"Run {run}, Channel {channel}",
    )
    ax.legend(loc="lower left", bbox_to_anchor=(1, 1))

del ax, data, run, channel, chgroup, i, j

display(fig)
fig.savefig(os.path.join(figdir, f"laserIdentification.png"))
print(f"Saved {os.path.join(figdir, f'laserIdentification.png')}")
plt.close(fig)

In [None]:
"""Validate Laser Identification

Spectra
"""

plt.close("all")
runs = df.run_id.unique()
channels = df.channel.unique()
fig, axes = plt.subplots(
    len(runs),
    len(channels),
    figsize=(5 * len(channels), 3 * len(runs)),
    dpi=200,
    constrained_layout=True,
)
axes = np.array(fig.axes).reshape(len(runs), len(channels))
fig.suptitle(f"Corrected Laser for {DATE}")

for group in df[df.ig_data].groupby(["run_id", "channel"]):
    assert isinstance(group[0], tuple)
    run, channel = tuple(group[0])
    chgroup = group[1]
    i = np.where(runs == run)[0][0]
    j = np.where(channels == channel)[0][0]
    ax = axes[i, j]
    assert isinstance(ax, type(fig.axes[0])), f"Expected AxesSubplot, got {type(ax)}"
    data = chgroup[
        (chgroup.height_mV.between(*chgroup.height_mV.quantile([0, 0.99])))
        & (chgroup.otherV.between(*chgroup.otherV.quantile([0, 0.99])))
    ]
    hb = ax.hexbin(
        data.height_mV,
        data.otherV,
        gridsize=500,
        lw=0,
        cmap="inferno",
        norm=colors.LogNorm(),
    )
    cb = fig.colorbar(hb, ax=ax)
    cb.set_label("log10(N)")
    ax.set(
        xlabel="Height (mV)",
        ylabel="Other (V)",
        title=f"Channel {channel}",
    )

display(fig)
fig.savefig(os.path.join(figdir, f"laser-precorrected.png"))
print(f"Saved {os.path.join(figdir, f'laser-precorrected.png')}")
plt.close(fig)

### Laser Substrate Correction

In [None]:
"""Setup: Specific min frequency selections

Varies per channel and dataset based on STJ gain and noise
"""

mVmins = {rid: {ch: 2.6 for ch in df.channel.unique()} for rid in df.run_id.unique()}
#: Per run, channel mVmin assignment
# mVmins[0][18] = 3

In [None]:
"""Do substrate correction.

Takes a while. May need parallelization.
Should I use groupby or simply filter by channel/run?
"""


#: Parallelize AFTER groupby
#: If pandas operations in parallel makes modin confused, then
#: do parallel numpy operations. Makes mapping back to DataFrame more difficult
def process_group(run_id, channel, *args, **kwargs):
    """Pass grouby indicators to results for concatenation."""
    correct, gradient = correct_substrate_heating(*args, **kwargs)
    return (run_id, channel, correct, gradient)


grouped = df[df.ig_laser].groupby(["run_id", "channel"], observed=True)

results = Parallel(n_jobs=-1)(
    #: This for loop is sequential, thus what gets parallelized is numpy only
    # construct_delayed(
    delayed(process_group)(
        run,
        ch,
        group.height_mV._to_pandas(),
        group.sumV._to_pandas(),
        dev_plot=True,
        dev_dir=figdir,
        dev_name=f"substrateCorrection-run{run}-ch{ch}",
        kwdict={
            "mVmin": mVmins[run][ch],
            "dev_plot": True,
            "dev_dir": figdir,
            "dev_name": f"substrateCorrection-run{run}-ch{ch}",
        },
    )
    for (run, ch), group in grouped
)

runs, channels, corrects, gradients = zip(*results)

df.loc["correct"] = pd.concat(corrects).reindex(df.index)
df.loc["gradient"] = pd.concat(gradients).reindex(df.index).astype("category")

del grouped, results, runs, channels, corrects, gradients

In [None]:
"""Validate the correction.

Takes ~3m
"""

plt.close("all")
runs = df.run_id.unique()
channels = df.channel.unique()
fig, axes = plt.subplots(
    len(runs),
    len(channels),
    figsize=(5 * len(channels), 3 * len(runs)),
    dpi=200,
    constrained_layout=True,
)
axes = np.array(fig.axes).reshape(len(runs), len(channels))
fig.suptitle(f"Corrected Laser for {DATE}")

for group in df[df.ig_data].groupby(["run_id", "channel"]):
    assert isinstance(group[0], tuple)
    run, channel = tuple(group[0])
    chgroup = group[1]
    i = np.where(runs == run)[0][0]
    j = np.where(channels == channel)[0][0]
    ax = axes[i, j]
    assert isinstance(ax, type(fig.axes[0])), f"Expected AxesSubplot, got {type(ax)}"
    data = chgroup[
        (chgroup.correct.between(*chgroup.correct.quantile([0, 0.99])))
        & (chgroup.otherV.between(*chgroup.otherV.quantile([0, 0.99])))
    ]
    hb = ax.hexbin(
        data.correct,
        data.otherV,
        gridsize=500,
        lw=0,
        cmap="inferno",
        norm=colors.LogNorm(),
    )
    cb = fig.colorbar(hb, ax=ax)
    cb.set_label("log10(N)")
    ax.set(
        xlabel="Height (mV)",
        ylabel="Other (V)",
        title=f"Channel {channel}",
    )

display(fig)
fig.savefig(os.path.join(figdir, f"run{run}-laser-corrected.png"))
print(f"Saved {os.path.join(figdir, f'run{run}-laser-corrected.png')}")
plt.close(fig)

In [None]:
%%capture
%%
"""Validate laser correction.

Unlikely useful here as laser wasn't very well resolved across all runs.
"""
fig, ax = plt.subplots(figsize=(10, 6), dpi=200, constrained_layout=True)
for file in df.fname.unique():
    d = df[(df.fname == file) & (df.ig_laser)]
    ax.scatter(d.height_mV, d.otherV, s=0.1, lw=0, alpha=0.1, label=file)
ax.legend(
    loc="upper right",
    fontsize=8,
    # title="File",
    # title_fontsize=8,
    # shadow=True,
    # fancybox=True,
    # frameon=True,
    # framealpha=0.5,
    # edgecolor="black",
    # facecolor="white",
    ncol=2,
    markerscale=10,
    scatterpoints=10,
    # handletextpad=0.1,
    # handlelength=0.5,
    # handleheight=0.5,
    # borderpad=0.1,
    # labelspacing=0.1,
    # columnspacing=0.1,
    # numpoints=1,
    # mode="expand",
    bbox_to_anchor=(1.05, 1),
    # bbox_transform=None,
    # handler_map=None,
)
fig.savefig(os.path.join(figdir,"laserCorrected-hexbin.png"))

In [None]:
#: Save the event table with every column added above (wall-clock time, run ids,
#: laser/event flags, substrate correction) next to this dataset's figures.
df.to_feather(f"out/spectra_calibration/{DATE}-setup.feather")

# Further Studies

## Nuclear Data

## Using Laser Data!

In [None]:
"""Per Channel Nuclear Data

~ig_laser

"""

channels = df.channel.unique()
fig, axes = plt.subplots(
    len(channels),
    1,
    figsize=(8, 3.2 * len(channels)),
    dpi=DPI_VIS,
    constrained_layout=True,
)
for channel, ax in list(zip(channels, np.ravel(axes))):
    range = (0, 50)
    bins = 1000
    if channel == 27:
        bins = 350
        range = (0, 25)
    data = df[(df.channel == channel) & (~df.ig_laser)]
    h, b = np.histogram(data.height_mV, bins=bins, range=range)
    ax.step(b[:-1], h, where="post", label="All Laser-Anticoincident Data", zorder=-21)
    for mxmult in [2, 4, 6, 8, 10, 15, 20]:
        data = df[(df.channel == channel) & (~df.ig_laser) & (df.multiplicity < mxmult)]
        h, b = np.histogram(data.height_mV, bins=bins, range=range)
        ax.step(
            b[:-1], h, where="post", label=f"Multiplicity < {mxmult}", zorder=-mxmult
        )
    ax.set(
        title=f"Channel {channel}",
        xlabel="Height [mV]",
        ylabel="Counts",
        yscale="log",
    )
    ax.legend()
display(fig)
fig.savefig(os.path.join(figdir, "per_channel-nuclear_data-laser_anticoincidence.png"))
print(
    f"Saved {os.path.join(figdir, 'per_channel-nuclear_data-laser_anticoincidence.png')}"
)
plt.close(fig)

# Debug