In [None]:
%load_ext autoreload
%autoreload 2
import waffles
import numpy as np
import json
import shutil 
from tqdm import tqdm
import matplotlib.pyplot as plt
from plotly import graph_objects as go
import plotly.subplots as psu
import pandas as pd
from typing import Optional
import os

import waffles
from waffles.input_output.hdf5_structured import load_structured_waveformset
from waffles.data_classes.Waveform import Waveform
from waffles.data_classes.WaveformSet import WaveformSet
from waffles.data_classes.BasicWfAna import BasicWfAna
from waffles.data_classes.IPDict import IPDict
from waffles.data_classes.UniqueChannel import UniqueChannel
from waffles.data_classes.ChannelWsGrid import ChannelWsGrid
from waffles.utils.baseline.baseline import SBaseline
from waffles.utils.numerical_utils import average_wf_ch
from waffles.utils.utils import compute_peaks_rise_fall_ch
from waffles.utils.selector_waveforms import WaveformSelector
from waffles.np02_utils.AutoMap import generate_ChannelMap, dict_uniqch_to_module, dict_module_to_uniqch, strUch, ordered_channels_membrane
from waffles.np02_utils.PlotUtils import np02_gen_grids, plot_grid, plot_detectors, genhist, fithist, runBasicWfAnaNP02, plot_averages, plot_averages_w_peaks_rise_fall, plot_averages_normalized
from waffles.np02_utils.load_utils import open_processed, ch_read_calib

In [None]:
#dettype = "membrane"
dettype = "cathode"

datadir = f"/eos/experiment/neutplatform/protodune/experiments/ProtoDUNE-VD/commissioning/"
det = "VD_Cathode_PDS" if dettype == "cathode" else "VD_Membrane_PDS"
endpoint = 106 if dettype == "cathode" else 107

dletter = dettype.upper()[0] # C or M...
group1 = [ f"{dletter}{detnum}({chnum})" for detnum in range(1, 3) for chnum in range(1,3) ]
group2 = [ f"{dletter}{detnum}({chnum})" for detnum in range(3, 5) for chnum in range(1,3) ]
group3 = [ f"{dletter}{detnum}({chnum})" for detnum in range(5, 7) for chnum in range(1,3) ]
group4 = [ f"{dletter}{detnum}({chnum})" for detnum in range(7, 9) for chnum in range(1,3) ]
groupLow = group1+group2
groupHig = group3+group4
groupall = group1+group2+group3+group4

In [None]:
# Run number -> module labels recorded in that run, for each detector type.
# The trailing comments keep the acquisition settings noted for each run.
_run_to_module_by_dettype = {
    "membrane": {
        #40800 : ["M3(1)", "M3(2)"], # Mask 8, int 1600
        40801: ["M3(1)", "M3(2)"],  # Mask 8, int 2250
        40808: ["M4(1)", "M4(2)", "M6(1)", "M6(2)"],  # Mask 16, int 4000
        41522: ["M5(1)", "M5(2)"],  # Mask 16, int 3500
        41236: ["M7(1)", "M7(2)"],  # Mask 8, int 3250
        #40807 : ["M8(1)", "M8(2)"], # Mask 8, int 4000
        #42062 : ["M8(1)", "M8(2)"], # !! Mask 8, int 1500 -> After swapping fibers: Mask 8 was top but here is bottom !!
        42156: ["M8(1)", "M8(2)"],  # !! Mask 8, int 2100 -> After swapping fibers !!
    },
    "cathode": {
        40808: ["C1(1)", "C1(2)", "C7(1)", "C7(2)", "C8(1)", "C8(2)"],  # Mask 16, int 4000
        41519: ["C2(1)", "C2(2)"],  # Mask 8, int 3000
        41536: ["C3(1)", "C3(2)"],  # Mask 1, 2800
        #41539 : ["C4(1)", "C4(2)", "C5(1)", "C5(2)"], # Mask 16, int 1600
        42002: ["C4(1)", "C4(2)", "C5(1)", "C5(2)"],  # Mask 16, int 2400
        40807: ["C6(1)", "C6(2)"],  # Mask 8, int 4000
    },
}

# Runs for the detector type chosen in the configuration cell.
run_to_module = _run_to_module_by_dettype[dettype]

# Run number -> channel numbers of its modules (via dict_module_to_uniqch).
run_to_unich = {
    run_number: [dict_module_to_uniqch[label].channel for label in module_labels]
    for run_number, module_labels in run_to_module.items()
}

# Every channel of every selected run, flattened in run order.
channels = [ch for run_channels in run_to_unich.values() for ch in run_channels]

In [None]:
from collections import Counter
def get_good_timestamps(wfset, run, endpoint, max_missing=2):
    """Return the timestamps seen on (almost) every channel of ``endpoint``.

    A timestamp is kept when the number of waveforms carrying it is at least
    ``len(wfset.available_channels[run][endpoint]) - max_missing``.

    Parameters
    ----------
    wfset : WaveformSet
        Set whose ``waveforms`` each expose a ``timestamp`` attribute.
    run, endpoint :
        Keys into ``wfset.available_channels`` selecting the channel collection
        the per-timestamp count is compared against.
    max_missing : int, optional
        How many channels may be absent at a timestamp for it to still be kept.
        The default of 2 is the value previously hard-coded.

    Returns
    -------
    list
        The kept timestamps, in ascending order.
    """
    counts = Counter(wf.timestamp for wf in wfset.waveforms)
    print(f"Total number of timestamps: {len(counts)}")
    # Loop-invariant: compute the required multiplicity once, not once per timestamp.
    min_count = len(list(wfset.available_channels[run][endpoint])) - max_missing
    # Sort only the survivors; the result order matches sorting the full list first.
    matchtimestamps = sorted(ts for ts, n in counts.items() if n >= min_count)
    print(f"Remaining timestamps: {len(matchtimestamps)}")
    return matchtimestamps

In [None]:
def get_external(waveform: Waveform, validtimes=()) -> bool:
    """Filter predicate: keep ``waveform`` only if its timestamp is in ``validtimes``.

    Used with ``WaveformSet.from_filtered_WaveformSet`` to keep the waveforms at
    the timestamps returned by ``get_good_timestamps``. This predicate runs once
    per waveform, so pass a ``set`` when ``validtimes`` is large; membership in a
    list is a linear scan.

    The default is an immutable empty tuple, which rejects every waveform. A
    shared mutable ``[]`` default is avoided.
    """
    return waveform.timestamp in validtimes

# NOTE(review): not referenced anywhere else in this notebook.
matchingtimes = []

In [None]:
import copy
import time
def select_channels(waveform: Waveform, channels: list) -> bool:
    """Filter predicate: True exactly when ``waveform.channel`` is one of ``channels``."""
    is_selected = waveform.channel in channels
    return is_selected

def create_wfset(run_to_unich, endpoint, nwaveforms=10000, det_type=None, data_dir=None):
    """Load, timestamp-match and channel-filter several runs, merging them into one WaveformSet.

    For each ``run -> channels`` entry of ``run_to_unich``:
      1. load up to ``nwaveforms`` processed waveforms of ``endpoint``;
      2. keep only waveforms at timestamps returned by ``get_good_timestamps``;
      3. keep only waveforms whose channel is in that run's channel list;
    then merge the per-run results.

    Parameters
    ----------
    run_to_unich : dict
        Run number -> channel numbers to keep from that run.
    endpoint : int
        Endpoint passed to ``open_processed`` and ``get_good_timestamps``.
    nwaveforms : int, optional
        Passed to ``open_processed``; the default of 10000 is the value previously hard-coded.
    det_type, data_dir : str, optional
        Detector type and data directory for ``open_processed``. When omitted,
        they are the notebook-level ``dettype`` / ``datadir``, looked up at call time.

    Returns
    -------
    WaveformSet or None
        The merged set, or ``None`` when ``run_to_unich`` is empty.
    """
    if det_type is None:
        det_type = dettype
    if data_dir is None:
        data_dir = datadir

    wfset_full = None
    for run, run_channels in run_to_unich.items():
        # Channel filter `None` (presumably: all channels). get_good_timestamps compares
        # per-timestamp counts with the endpoint's full channel list, so the
        # channel selection is applied only afterwards.
        wfset = open_processed(run, det_type, data_dir, None, [endpoint], nwaveforms=nwaveforms, verbose=True)

        # Sets make each per-waveform membership test O(1) instead of a list scan.
        matchtimestamps = set(get_good_timestamps(wfset, run, endpoint))
        wfset = WaveformSet.from_filtered_WaveformSet(wfset, get_external, matchtimestamps, show_progress=True)
        wfset = WaveformSet.from_filtered_WaveformSet(wfset, select_channels, set(run_channels))

        # Neither the loaded nor the filtered set is used after this iteration,
        # so `wfset` is merged directly rather than deep-copied first.
        if wfset_full is None:
            wfset_full = wfset
        else:
            wfset_full.merge(wfset)
        print(f"Loaded run {run}")
    return wfset_full
    
# Time the loading step. perf_counter is a monotonic clock intended for measuring intervals.
start = time.perf_counter()
wfset_full = create_wfset(run_to_unich, endpoint)
end = time.perf_counter()
print(f"create_wfset took {end - start:.1f} s")
wfset_full

In [None]:
# Basic waveform analysis on the merged set: integration window [int_ll, int_ul],
# amplitude window [amp_ll, amp_ul] and threshold.
# NOTE(review): units presumably samples / ADC counts; confirm in runBasicWfAnaNP02.
runBasicWfAnaNP02(
    wfset_full,
    int_ll=250,
    int_ul=280,
    amp_ll=100,
    amp_ul=260,
    threshold=30,
    configyaml="",
)

In [None]:
# Heatmap options forwarded to plot_detectors for the uncut selection.
argsheat = {
    "mode": "heatmap",
    "analysis_label": "std",
    "adc_range_above_baseline": 4000,
    "adc_range_below_baseline": -250,
    "adc_bins": 125,
    "time_bins": wfset_full.points_per_wf // 2,
    "filtering": 36,
    "share_y_scale": False,
    "share_x_scale": True,
    "wfs_per_axes": 5000,
    "zlog": True,
    "width": 1300,
    "higth": 650,  # NOTE(review): spelled "higth"; confirm this matches plot_detectors' keyword
    "showplots": True,
}

# Detector group to display (uncomment one of the alternatives to switch).
#detector = groupLow
detector = groupHig
#detector = groupall
#detector=["M7(1)","M7(2)"]
#detector=["C1(1)"]

plot_detectors(wfset_full, detector, **argsheat)


In [None]:
# Keep only waveforms passing the selection cuts defined in cuts.yaml.
extractor = WaveformSelector('cuts.yaml')
wfset_clean = WaveformSet.from_filtered_WaveformSet(
    wfset_full,
    extractor.applycuts,
    show_progress=True,
)
n_original, n_clean = len(wfset_full.waveforms), len(wfset_clean.waveforms)
print(f"Original waveforms: {n_original}, after cut: {n_clean}")

In [None]:
# Heatmap options for the cleaned selection. return_fig=True makes plot_detectors
# return its figure data so later cells can reuse the grid `g`.
argsheat = {
    "mode": "heatmap",
    "analysis_label": "std",
    "adc_range_above_baseline": 4100,
    "adc_range_below_baseline": -150,
    "adc_bins": 125,
    "time_bins": wfset_full.points_per_wf // 2,
    "filtering": 36,
    "share_y_scale": False,
    "share_x_scale": True,
    "wfs_per_axes": 5000,
    "zlog": True,
    "width": 1300,
    "higth": 650,  # NOTE(review): spelled "higth"; confirm this matches plot_detectors' keyword
    "return_fig": True,
    "showplots": True,
}

figs = plot_detectors(wfset_clean, detector, **argsheat)
# First returned entry: presumably (figure, n_rows, n_cols, title, channel grid).
fig, rows, cols, title, g = figs[0]


In [None]:
# Draw the averages for grid `g` on a fresh 4x4 subplot figure
# (switch to the 2x4 layout for an 8-channel selection).
#fig_simple = psu.make_subplots(rows=2, cols=4)
fig_simple = psu.make_subplots(rows=4, cols=4)
plot_averages(fig_simple, g)
fig_simple.show()

In [None]:

# Rebuild the cleaned-selection heatmap figure, then draw the averages on top of it.
#detector=["M8(1)","M8(2)"]
figs = plot_detectors(wfset_clean, detector, **argsheat)
fig, rows, cols, title, g = figs[0]

plot_averages(fig, g)
fig.show()

In [None]:
# Per-channel peak / rise / fall quantities of the cleaned selection.
peaks_all = compute_peaks_rise_fall_ch(wfset_clean)

# Averages with their peaks for grid `g`, x range 200-550, rise_fall disabled.
#fig_simple = psu.make_subplots(rows=2, cols=4)
fig_simple = psu.make_subplots(rows=4, cols=4)
plot_averages_w_peaks_rise_fall(peaks_all, fig_simple, g, x_range=(200, 550), rise_fall=False)
fig_simple.show()

In [None]:
# Per-endpoint, per-channel calibration table (includes the 'SpeAmpl' column used below).
calibration_file = "np02-config-v4.0.0.csv"
calibration_data = ch_read_calib(calibration_file)


In [None]:
# Divide each channel's peak value by its calibration 'SpeAmpl' (presumably the
# single-photo-electron amplitude) and collect those amplitudes in
# `spe_by_channel` for plot_averages_normalized further down.
# NOTE(review): keyed by (endpoint, channel) like `peaks_all`; confirm this is the
# key format plot_averages_normalized expects.
spe_by_channel = {}
for (ep, ch), vals in peaks_all.items():
    peak_value = vals["peak_value"]

    # Use this entry's own endpoint `ep`, not the notebook-level `endpoint`.
    calib_ep = calibration_data[ep] if ep in calibration_data else {}
    if ch not in calib_ep:
        print(f"Channel {ch} not found in calibration file")
        continue

    spe_amp = calib_ep[ch]['SpeAmpl']
    if not spe_amp:
        # A zero amplitude would raise ZeroDivisionError below.
        print(f"Channel {ch} has zero SPE amplitude in calibration file, skipping")
        continue
    spe_by_channel[(ep, ch)] = spe_amp
    normalized_peak = peak_value / spe_amp

    print(
        f"Endpoint {ep}, Channel {ch}: "
        f"Peak = {peak_value}, "
        f"SPE = {spe_amp}, "
        f"Normalized = {normalized_peak}"
    )

In [None]:
import waffles

In [None]:
# Output directory for the large-pulse templates, inside the installed waffles package.
template_outputdir = f"{waffles.__path__[0]}/np02_data/templates/templates_large_pulses/"
template_outputdir

In [None]:
# Plot SPE-normalised averages for all modules and save them as templates.
detector = groupall
fig_simple = psu.make_subplots(rows=4, cols=4)
gt = np02_gen_grids(wfset_clean, detector)
# Use the grid just built for `detector = groupall`. The `g` from earlier cells was
# built for the previous `detector` selection.
# NOTE(review): the "Custom" key comes from the commented-out calls in this notebook;
# confirm np02_gen_grids returns it.
plot_averages_normalized(fig_simple, gt["Custom"], spe_by_channel, save=True, save_dir=template_outputdir)
fig_simple.show()