# In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import re
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
import time

from trigger_study.wk9.preliminary.OptimumFilter import *
from trace_IO import *
# --- Configuration ---
# Per-channel standard deviation (sigma) used to scale the trigger thresholds.
# Presumably the baseline-noise width of each channel's fitted amplitude —
# confirm against the study that produced these numbers.
channel_std = {
    45: 0.8440,
    46: 0.7196,
    47: 0.7514,
    48: 0.7595,
    49: 0.7367,
    50: 0.8430,
    51: 0.6612,
    52: 0.7112,
    53: 0.6817,
}
# thresholds[n_sigma][ch] == n_sigma * sigma_ch: one per-channel cut table for
# each n-sigma level.
thresholds = {
    n_sigma: {ch: n_sigma * sigma for ch, sigma in channel_std.items()}
    for n_sigma in (3, 4, 5)
}
# Channels to process, taken from channel_std so the two can never disagree.
channels = sorted(channel_std)

# --- Locate and sort files by energy ---
base_path = Path("/ceph/dwong/trigger_samples/qp_sample")
file_paths = list(base_path.glob("QP_sample_quantized_*eV.zst"))
# Lexicographic order; a later stable sort by energy keeps this as the
# tie-break for files that parse to the same energy.
file_paths.sort()

def extract_energy(fp):
    """Return the integer energy encoded in a sample file's name.

    The first ``_<digits>eV`` token in ``fp.stem`` is used, e.g.
    ``QP_sample_quantized_100eV.zst`` -> ``100``. Fractional energies such as
    ``_1.5eV`` do not match the pattern.

    Parameters
    ----------
    fp : pathlib.Path
        Path whose ``stem`` is searched.

    Returns
    -------
    int
        The parsed energy, or ``-1`` when no ``_<digits>eV`` token is present
        (such files therefore sort before all others by energy).
    """
    found = re.search(r'_(\d+)eV', fp.stem)
    if found is None:
        return -1
    return int(found.group(1))

# Stable sort: files with equal parsed energy keep their previous order.
file_paths.sort(key=extract_energy)
energies = list(map(extract_energy, file_paths))

# --- Efficiency results container ---
# efficiency_results[n_sigma][file_index] -> triggered fraction, filled in by
# the parallel loop. Every level gets its own independent list.
efficiency_results = {
    level: [0.0 for _ in file_paths] for level in thresholds
}

# --- Worker function ---
def process_file_optimized(index, energy, path_str):
    """Compute the n-sigma trigger efficiency for one sample file.

    Requests 100 traces (``n_traces=100``) from ``path_str``, fits every trace
    in each channel of the module-level ``channels`` list with
    ``channel_optimum_filters[ch].fit``, and reports, for every level in the
    module-level ``thresholds`` table, the fraction of traces in which at
    least one channel's fitted amplitude exceeds that channel's threshold.

    Parameters
    ----------
    index : int
        Position of this file in the caller's result lists; returned unchanged
        so results can be stored correctly when futures finish out of order.
    energy : int
        Energy label of the file. Not used here; kept so the call signature
        stays unchanged for existing callers.
    path_str : str
        Path of the compressed trace file, passed to ``load_traces_from_zstd``.

    Returns
    -------
    tuple
        ``(index, {n_sigma: efficiency})``. Efficiency is ``nan`` when the
        file yields no traces.

    Notes
    -----
    Reads the globals ``channels``, ``thresholds`` and
    ``channel_optimum_filters``; they must be visible in the worker process.
    ``channel_optimum_filters`` is not defined in this cell — presumably it
    comes from a star import or an earlier cell (TODO confirm).
    """
    # Assumed shape (n_traces, 54, 32768), i.e. trace x channel x sample —
    # TODO confirm against load_traces_from_zstd.
    traces = load_traces_from_zstd(path_str, n_traces=100)
    n_traces = traces.shape[0]

    if n_traces == 0:
        # Efficiency is undefined for an empty file; return nan explicitly
        # rather than relying on a 0/0 division warning.
        return index, {thr_val: float("nan") for thr_val in thresholds}

    # Keep only the channels we fit and drop the full array immediately.
    # Several workers run concurrently, so holding every channel through the
    # slow fit loop below would multiply the per-worker memory footprint.
    selected = traces[:, channels]
    del traces

    # amps[k, i] = first value returned by the fit (used as the amplitude)
    # for trace k in channel channels[i].
    amps = np.zeros((n_traces, len(channels)))
    for i, ch in enumerate(channels):
        fit = channel_optimum_filters[ch].fit
        amps[:, i] = [fit(trace)[0] for trace in selected[:, i]]

    # Evaluate threshold efficiency for every n-sigma level.
    result = {}
    for thr_val, per_channel in thresholds.items():
        threshold_vec = np.array([per_channel[ch] for ch in channels])
        # A trace triggers if ANY channel's amplitude exceeds its threshold.
        triggered = (amps > threshold_vec).any(axis=1)
        result[thr_val] = triggered.mean()

    return index, result

# --- Parallel execution ---
start_time = time.time()

with ProcessPoolExecutor(max_workers=10) as executor:
    # One task per file; the index lets out-of-order results land in the
    # right slot of efficiency_results.
    futures = [
        executor.submit(process_file_optimized, i, e, str(p))
        for i, (e, p) in enumerate(zip(energies, file_paths))
    ]
    try:
        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing files"):
            idx, result = future.result()  # re-raises any worker exception here
            for thr_val in result:
                efficiency_results[thr_val][idx] = result[thr_val]
    finally:
        # On a worker error or KeyboardInterrupt, cancel files still waiting
        # in the queue. Otherwise leaving the `with` block calls
        # shutdown(wait=True), which keeps processing every queued file before
        # the exception surfaces. cancel() returns False (and does nothing) for
        # futures that are already running or finished, so this is a no-op on
        # normal completion.
        for pending in futures:
            pending.cancel()

print(f"Completed in {time.time() - start_time:.2f} seconds")

# --- Plotting ---
# One efficiency-vs-energy curve per n-sigma threshold level, ascending in n.
plt.figure(figsize=(8, 5))
for thr_val in sorted(thresholds):
    plt.plot(energies, efficiency_results[thr_val], label=f"{thr_val}σ")
plt.xlabel("Energy [eV]")
plt.ylabel("Trigger Efficiency")
# Derive the channel count from the configuration rather than hard-coding it,
# so the title stays correct if `channels` changes.
plt.title(f"Trigger Efficiency vs Energy (Any of {len(channels)} Channels)")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()


# Last recorded cell output (run interrupted before any file completed):
# Processing files:   0%|          | 0/32 [05:26<?, ?it/s]
#
# KeyboardInterrupt: