In [None]:
# Reload edited modules (e.g. cryoant, beest) before each cell executes, so
# library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Setup

In [None]:
"""Find optimized shaping by grid search

Loads a small amount of data for the day/chunk and then performs a grid search
to find the optimal shaping parameters. This includes:
- The shaping times (rise, flat)
- The fast shaping time (rise)
The grid search produces a dataframe of data which then needs to be substrate corrected.
Finally, a gaussian width of the laser data is reflective of the resolution.
"""

import click
import os
import glob
import json

# from cryoant.daq.xia.listmode import find_optimum_processing
from cryoant.daq.xia.listmode import load_and_process
from beest.laser import correct_substrate_heating
import matplotlib.pyplot as plt
import cryoant as ct
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
import gc


class ScatterHistogramPlot:
    """Scatter plot of ``x`` vs ``y`` with a marginal histogram along one axis.

    Parameters
    ----------
    x, y : array-like
        Coordinates of the scatter points.
    name : str
        Descriptive name of the plot; stored on the instance (the visible
        title comes from ``settings``).
    settings : dict
        Keyword arguments forwarded to ``Axes.set`` on the scatter axes
        (e.g. ``xlabel``, ``ylabel``, ``title``).
    axis : {"x", "y"}, default "x"
        ``"x"`` puts a histogram of ``x`` below the scatter (shared x-axis);
        ``"y"`` puts a histogram of ``y`` to the left (shared y-axis).
    bins : int, default 1000
        Number of histogram bins.

    Raises
    ------
    ValueError
        If ``axis`` is not ``"x"`` or ``"y"``.
    """

    def __init__(self, x, y, name, settings, axis="x", bins=1000):
        if axis not in ("x", "y"):
            raise ValueError(f"axis must be 'x' or 'y', got {axis!r}")
        self.x = x
        self.y = y
        self.name = name
        self.settings = settings
        self.axis = axis
        self.bins = bins

    def plot(self):
        """Draw the scatter + marginal histogram and return the Figure.

        The caller owns the returned figure (save it, then ``plt.close`` it).
        """
        if self.axis == "x":
            fig, (ax_scatter, ax_hist) = plt.subplots(
                2,
                1,
                sharex=True,
                gridspec_kw={"height_ratios": [3, 1]},
                figsize=(5, 5),
                dpi=500,
                constrained_layout=True,
            )
        else:
            fig, (ax_hist, ax_scatter) = plt.subplots(
                1,
                2,
                sharey=True,
                gridspec_kw={"width_ratios": [1, 3]},
                figsize=(5, 5),
                dpi=500,
                constrained_layout=True,
            )

        # Scatter plot: tiny, translucent markers so dense regions stay readable.
        ax_scatter.scatter(self.x, self.y, lw=0, s=0.1, alpha=0.2)
        ax_scatter.set(**self.settings)

        # Marginal histogram. stairs() draws each count across its full bin
        # [edges[i], edges[i+1]); the previous step(edges[:-1], counts) with the
        # default where="pre" shifted every bin by one, and the horizontal
        # variant step(counts, edges[:-1]) stepped along the (non-monotonic)
        # count axis instead of the bin axis.
        if self.axis == "x":
            counts, edges = np.histogram(self.x, bins=self.bins)
            ax_hist.stairs(counts, edges, baseline=None, color="blue")
            ax_hist.set(ylabel="Count", yscale="log")
        else:
            counts, edges = np.histogram(self.y, bins=self.bins)
            ax_hist.stairs(
                counts, edges, orientation="horizontal", baseline=None, color="blue"
            )
            ax_hist.set(xlabel="Count", xscale="log")

        return fig

In [None]:
def file_processor(file, multithread=True):
    """Load one listmode file, tag candidate laser events, and report its size.

    Parameters
    ----------
    file : str
        Path of the listmode file handed to ``load_and_process``.
    multithread : bool, default True
        Forwarded unchanged to ``load_and_process``.

    Returns
    -------
    tuple
        ``(dfi, header, size)`` with ``size`` the file size in bytes. An empty
        ``dfi`` is returned as-is; otherwise two columns are added:

        - ``ig_laser``: True where the event phase ``time % 0.1`` differs from
          its trailing rolling median (window ``min(1000, int(0.01 * len(dfi)))``)
          by less than 1% of the mean phase.
          NOTE(review): files with < 100 events get a window of 0, which
          presumably leaves every event unflagged — confirm.
        - ``fname``: basename of ``file``.
    """
    dfi, header, _ = load_and_process(file, multithread=multithread)
    size = os.path.getsize(file)
    print(f"Size: {size / 1024 / 1024 / 1024} GB")
    if not len(dfi):
        print(f"Skipping {file}")
        return dfi, header, size

    phase = dfi.time % 0.1
    window = min(1000, int(0.01 * len(dfi)))
    deviation = (phase - phase.rolling(window).median()).abs()
    dfi["ig_laser"] = deviation < 0.01 * np.mean(phase)
    dfi["fname"] = os.path.basename(file)
    gc.collect()
    return dfi, header, size


def process_chunk(files_chunk, no_parallel, n_jobs=8):
    """Run ``file_processor`` over a group of files.

    Parameters
    ----------
    files_chunk : iterable of str
        Paths of the files to process.
    no_parallel : bool
        If True, process the files sequentially in this process with
        ``multithread=False``. Otherwise dispatch them to joblib workers,
        each using ``file_processor``'s default ``multithread=True``.
    n_jobs : int, default 8
        Number of joblib workers; only used when ``no_parallel`` is False.

    Returns
    -------
    list
        One ``file_processor`` result tuple per file, in input order.
    """
    if no_parallel:
        return [file_processor(file, multithread=False) for file in files_chunk]
    return Parallel(n_jobs=n_jobs)(
        delayed(file_processor)(file) for file in files_chunk
    )

## File Processor Debug

## Low-Level Debugging

**TODO:** track down where the "Mean of empty slice" warning comes from.

In [None]:
"""Generate scatter from listmode data.

date is in YYMMDD format. Gets all files in setup with 20{date} in path.
Identify laser by rolling median filter.
Scatter laser data for investigation of substrate heating correction.
"""

plt.style.use(f"{list(ct.__path__)[0]}/plot.mpl")

DPI_VIS, DPI_SV = 200, 50
DATE = "240812"
no_parallel = False

directory = "/beest_data/summer2024/Be7_Ta_PR_Mask_listmode_ben/d/"
files = glob.glob(os.path.join(directory, "*_Al_*/*.bin"))

# Sort files by filesize
files.sort(key=os.path.getsize)
files = [file for file in files if os.path.getsize(file) > 0]

chunk_size = 8  # Define the chunk size for processing
df = pd.DataFrame()
total, skip = 0, 0
files = [file for file in files if f"20{DATE}" in os.path.dirname(file)][:4]
headers = []
for i in range(0, len(files) + chunk_size, chunk_size):
    files_chunk = files[i : i + chunk_size]
    print(f"Processing chunk {1 + i // chunk_size}/{-(-len(files)//chunk_size)}")
    results = process_chunk(files_chunk, no_parallel)
    for result in results:
        if result is None:
            skip += 1
            continue
        dfi, header, size = result
        if len(dfi) == 0:
            skip += 1
            continue
        headers.append(header)
        df = pd.concat([df, dfi])
        del dfi  # Explicitly delete dfi to free memory
        total += size
    del results  # Discard the processed chunk
    gc.collect()  # Trigger garbage collection
    print(f"Chunk {1 + i // chunk_size}/{-(-len(files)//chunk_size)} processed")
print(
    f"Processed {len(files) - skip} files, total size: {total / 1024 / 1024 / 1024} GB"
)

testcols = [col.split("_")[1] for col in df.columns if "h_" == col[:2]]
for col in testcols:
    fig = ScatterHistogramPlot(
        df[f"hmV_{col}"],
        df[f"omV_{col}"],
        f"Shaping: {col}",
        {
            "xlabel": f"hmV_{col}",
            "ylabel": f"omV_{col}",
            "title": f"Shaping: {col}",
        },
    ).plot()
    fig.savefig(f"out/shapetest_precorrected_{DATE}_{col}.png")
    df[f"corrected_hmV_{col}"] = correct_substrate_heating(
        df[f"hmV_{col}"], df.ig_laser, df[f"smV_{col}"]
    )
    fig = ScatterHistogramPlot(
        df[f"corrected_hmV_{col}"],
        df[f"omV_{col}"],
        f"Shaping: {col}",
        {
            "xlabel": f"corrected_hmV_{col}",
            "ylabel": f"omV_{col}",
            "title": f"Shaping: {col}",
        },
    ).plot()
    fig.savefig(f"out/shapetest_corrected_{DATE}_{col}.png")

print("Saving data to HDF5")
df.to_hdf(f"out/processed/shapetest_{DATE}.h5", key="data", mode="a")
[
    pd.Series(header).to_hdf(
        f"out/processed/shapetest_{DATE}.h5", key="metadata", mode="a"
    )
    for header in headers
]
print(
    f"Data Saved. File size: {os.path.getsize(f'out/processed/shapetest_{DATE}.h5') / 1024 / 1024 / 1024} GB"
)

In [None]:
# Shaping labels: the second "_"-separated field of every column starting with "h_".
testcols = []
for colname in df.columns:
    if colname[:2] == "h_":
        testcols.append(colname.split("_")[1])

In [None]:
"""
argmaxFFT, maxFFT, angmaxFFT, gradients = deskewFFT(
    pulseV_laser, pulseV_laser_other
)
"""

pulseHeight = pulseV_laser
otherHeight = pulseV_laser_other
bins = 100000
voltageRange = [0.0, 1.0]
minFFT = 350
gradientRange = (-0.040, 0.040, 0.00001)
# fmt: off
gradients = np.arange(*gradientRange, dtype=np.float32)
maxFFT    = np.empty_like(gradients,  dtype=np.float32)
angmaxFFT = np.empty_like(gradients,  dtype=np.float32)
argmaxFFT = np.empty_like(gradients,  dtype=np.intc)
# fmt: on
for i, grad in enumerate(gradients):
    h, b = np.histogram(pulseHeight - otherHeight * grad, range=voltageRange, bins=bins)
    fft = np.fft.rfft(h - np.mean(h), norm="ortho")
    argmaxFFT[i] = minFFT + np.argmax(np.abs(fft))
    maxFFT[i] = np.abs(fft)[argmaxFFT[i] - minFFT]
    angmaxFFT[i] = np.angle(fft[argmaxFFT[i] - minFFT])

In [None]:
from beest.laser import deskewFFT

# Step-by-step reproduction of
#     correct_substrate_heating(df[f"hmV_{col}"], df.ig_laser, df[f"smV_{col}"])
# for debugging. Uses the last shaping column explicitly instead of relying on
# the loop variable `col` leaking out of an earlier cell.
col = testcols[-1]
height_mV = df[f"hmV_{col}"]
ig_laser = df.ig_laser
sumV = df[f"smV_{col}"]
slice_max = 0.025
numLines = 20
firstLine = 7
pulseV_laser = height_mV[ig_laser].astype(np.float32).to_numpy()
sum_heights = sumV[ig_laser].astype(np.float32).to_numpy()
pulseV_laser_other = sum_heights - pulseV_laser  # both already float32

#: TODO Should deskewFFT take all entries to preserve column shape?
argmaxFFT, maxFFT, angmaxFFT, gradients = deskewFFT(pulseV_laser, pulseV_laser_other)
best = np.argmax(maxFFT)  # gradient with the strongest FFT peak
line_freq = argmaxFFT[best]  # FFT index of that peak; slice step is 1 / line_freq
#: Angmax may vary from -pi to pi, but its basis is still 2pi
offset = angmaxFFT[best] / 2 / np.pi

sliceEdges = np.arange(
    #: Offset is located at the peak heights, slice between peaks with 0.5
    (0.5 + offset) / line_freq,
    slice_max,
    1 / line_freq,
)

# Fit a line to each slice (numLines slices starting at firstLine); its slope is
# the residual gradient left after deskewing with gradients[best].
xmin, xmax = 0, 0.05
xs = pulseV_laser_other
ys = pulseV_laser - pulseV_laser_other * gradients[best]
in_x = (xs < xmax) & (xs > xmin)  # loop-invariant x window
# Never index past the last slice edge (sliceEdges may be shorter than needed).
n_fit = max(0, min(numLines, len(sliceEdges) - firstLine - 1))
corrections = np.full(numLines, np.nan)
for i in range(n_fit):
    ymin = sliceEdges[firstLine + i]
    ymax = sliceEdges[firstLine + i + 1]
    cutout = in_x & (ymin < ys) & (ymax > ys)
    try:
        poly = np.poly1d(np.polyfit(xs[cutout], ys[cutout], 1))
    except (TypeError, np.linalg.LinAlgError):
        # np.polyfit raises TypeError for an empty slice; skip it.
        continue
    corrections[i] = poly[1]  # coefficient of x**1, i.e. the slope

# Average only slices that were actually fitted; skipped slices previously
# stayed 0 and biased the correction toward zero.
fitted = corrections[~np.isnan(corrections)]
avg_correction = fitted.mean() if fitted.size else 0.0
true_grad = gradients[best] + avg_correction

corrected = np.array(height_mV.copy())
corrected[ig_laser.to_numpy()] = height_mV[ig_laser] - pulseV_laser_other * true_grad

The automated correction above does not give usable results, so the shaping is explored manually below.

# Manual filtering of the data to find shaping settings that look reasonable

In [None]:
# Imports for the manual exploration below. read_listmode_file_optimized is only
# referenced by the commented-out loader at the end of the notebook, and
# get_opt_tpz_for_channel is not used in the visible cells.
from cryoant.daq.xia.listmode import (
    read_listmode_file_optimized,
    get_opt_tpz_for_channel,
)

# Run date (YYMMDD); files are selected by "20{DATE}" in their directory name.
DATE = "240812"

In [None]:
# Pick the smallest non-empty listmode file recorded on DATE for quick iteration.
directory = "/beest_data/summer2024/Be7_Ta_PR_Mask_listmode_ben/d/"
files = glob.glob(os.path.join(directory, "*_Al_*/*.bin"))
files = [file for file in files if f"20{DATE}" in os.path.dirname(file)]
files = [file for file in files if os.path.getsize(file) > 0]
files.sort(key=os.path.getsize)
if not files:
    # Fail with a clear message instead of an opaque IndexError on files[0].
    raise FileNotFoundError(
        f"No non-empty *_Al_*/*.bin files for 20{DATE} under {directory}"
    )
filename = files[0]  # smallest matching file
display(filename)

In [None]:
# Processing keyword arguments for this run date (passed to load_and_process below).
# Use a context manager so the file handle is closed instead of leaked.
with open(f"kwfile-20{DATE}.json", "r") as kwfile:
    kwdict = json.load(kwfile)
display(kwdict)

In [None]:
# Disable debug plots and send processing output to out/dev.
kwdict.update(debug_plots=False, outdir="out/dev")

# Single-threaded load of `filename`, keeping traces (drop_trace=False) so that
# shaping can presumably be re-run on them below; events_total=-1 presumably
# means "all events" — confirm against load_and_process.
df, header, opt_tpzs = load_and_process(
    filename,
    kwdict=kwdict,
    known_tpz=None,
    multithread=False,
    events_total=-1,
    drop_trace=False,
)

In [None]:
from cryoant.daq.xia.listmode import extract_pulse_heights_threaded

# Grid-search ranges as [start, stop, step]; stop is exclusive (np.arange).
risetest = [100, 700, 100]
flattest = [100, 200, 100]
fastrise = [10, 100, 50]
testcols = []

rise_values = np.arange(*risetest)
flat_values = np.arange(*flattest)
fast_values = np.arange(*fastrise)
# Every (rise, flat, fast) combination, one per row: flat varies fastest,
# then rise, with fast varying slowest.
tests = np.array(
    [
        (rise, flat, fast)
        for fast in fast_values
        for rise in rise_values
        for flat in flat_values
    ]
).reshape(-1, 3)


def tester(params):
    """Extract pulse heights from the module-level ``df`` for one setting.

    ``params`` is a ``(rise, flat, fast)`` triple; the slow filter uses
    ``fall = rise`` and the fast filter uses ``fast`` for both rise and fall.
    Returns ``(label, heights, fast_amps)`` with ``label`` like ``"r100f100s10"``.
    """
    rise, flat, fast = params
    heights, fast_amps = extract_pulse_heights_threaded(
        df,
        opt_tpzs,
        rise=rise,
        flat=flat,
        fall=rise,
        fast_rise=fast,
        fast_fall=fast,
    )
    return f"r{rise}f{flat}s{fast}", heights, fast_amps


tester(tests[0])

In [None]:
# Disabled alternative: read a 1e4-event sample from every file in `files` and
# concatenate them into `df`, instead of processing a single file.
# dfs = []
# for file in files:
#     _, df = read_listmode_file_optimized(
#                     file, events_total=1e4
#                 )
#     dfs.append(df)
# df = pd.concat(dfs)
# del dfs