In [None]:
%load_ext autoreload
%autoreload 2
import waffles
import numpy as np
import json
import shutil 
from tqdm import tqdm
import matplotlib.pyplot as plt
from plotly import graph_objects as go
import plotly.subplots as psu

from waffles.input_output.hdf5_structured import load_structured_waveformset
from waffles.data_classes.Waveform import Waveform
from waffles.data_classes.WaveformSet import WaveformSet
from waffles.data_classes.BasicWfAna import BasicWfAna
from waffles.data_classes.IPDict import IPDict
from waffles.data_classes.UniqueChannel import UniqueChannel
from waffles.data_classes.ChannelWsGrid import ChannelWsGrid
from waffles.utils.baseline.baseline import SBaseline
from waffles.utils.numerical_utils import average_wf_ch
from waffles.utils.selector_waveforms import WaveformSelector
from waffles.np02_utils.AutoMap import generate_ChannelMap, dict_uniqch_to_module, dict_module_to_uniqch, strUch
from waffles.np02_utils.PlotUtils import np02_gen_grids, plot_grid, plot_detectors, genhist, fithist, runBasicWfAnaNP02, plot_averages

In [None]:
dettype = "membrane"
#dettype = "cathode"

datadir = f"/eos/experiment/neutplatform/protodune/experiments/ProtoDUNE-VD/commissioning/"
det = "VD_Cathode_PDS" if dettype == "cathode" else "VD_Membrane_PDS"
endpoint = 106 if dettype == "cathode" else 107

dletter = dettype.upper()[0] # C or M...
group1 = [ f"{dletter}{detnum}({chnum})" for detnum in range(1, 3) for chnum in range(1,3) ]
group2 = [ f"{dletter}{detnum}({chnum})" for detnum in range(3, 5) for chnum in range(1,3) ]
group3 = [ f"{dletter}{detnum}({chnum})" for detnum in range(5, 7) for chnum in range(1,3) ]
group4 = [ f"{dletter}{detnum}({chnum})" for detnum in range(7, 9) for chnum in range(1,3) ]
groupLow = group1+group2
groupHig = group3+group4
groupall = group1+group2+group3+group4

In [None]:
# Which module labels to load from which run, per detector type.
# Commented-out entries are alternative run/module selections.
run_to_module = {

    "membrane" : {
        #40808 : ["M3(1)", "M3(2)", "M4(1)", "M4(2)"],
        #40807 : ["M5(1)"], # Mask 8
        #41522 : ["M5(2)"], # Mask 16
        40808 : ["M6(1)", "M6(2)"], # Mask 16
        #41236 : ["M7(1)", "M7(2)", "M8(1)", "M8(2)"],
    },
    "cathode" : {
        # C1 and C4 used to be listed under two separate 40808 keys. In a dict
        # literal the later key silently replaces the earlier one, so C1 was
        # never loaded. They are merged into a single entry here.
        # NOTE(review): confirm both C1 and C4 really come from run 40808.
        40808 : ["C1(1)", "C1(2)", "C4(1)", "C4(2)"],
        # "C6(2)" was listed twice here; the duplicate is removed.
        40807 : ["C2(1)", "C2(2)", "C3(1)", "C3(2)", "C5(1)", "C5(2)", "C6(1)", "C6(2)", "C7(1)", "C7(2)", "C8(1)", "C8(2)"],
    },

}
run_to_module = run_to_module[dettype]

# run -> channel numbers. dict.fromkeys drops accidental duplicate channels
# while keeping the original order.
run_to_unich = {
    r: list(dict.fromkeys(dict_module_to_uniqch[m].channel for m in modules))
    for r, modules in run_to_module.items()
}
channels = [ x for v in run_to_unich.values() for x in v]

In [None]:
from glob import glob
import copy
def open_processed(run, dettype, output_dir, channels = None, endpoints=None, nwaveforms=None, mergefiles = False, verbose=True):
    """
    Open the processed waveform set for a given run and detector type.

    First tries the merged file
    ``{output_dir}/processed/run{run:06d}_{dettype}/processed_merged_run{run:06d}_structured_{dettype}.hdf5``.
    If loading it fails, it falls back to the files matching
    ``processed_*_run{run:06d}_*_{dettype}.hdf5`` in the same directory.

    Parameters
    ----------
    run : int
        Run number (zero-padded to 6 digits in both directory and file names).
    dettype : str
        Detector-type tag used in the directory and file names.
    output_dir : str
        Base directory containing the ``processed/`` subdirectory.
    channels, endpoints : list or None
        Forwarded as ``channels_filter`` / ``endpoint_filter`` to every
        ``load_structured_waveformset`` call, including each merged file.
    nwaveforms : int or None
        Forwarded as ``max_to_load`` to each individual file load.
    mergefiles : bool
        In the fallback, merge all matching files instead of loading only the
        first one (in sorted order).
    verbose : bool
        Print diagnostic information.

    Returns
    -------
    The waveform set returned by ``load_structured_waveformset`` (merged if
    requested).

    Raises
    ------
    FileNotFoundError
        If the merged file could not be loaded and no fallback files match.
    """
    # Same (zero-padded) run directory for the merged file and the fallback;
    # the merged path previously used an unpadded run number.
    rundir = f"{output_dir}/processed/run{run:06d}_{dettype}"
    try:
        return load_structured_waveformset(
            f"{rundir}/processed_merged_run{run:06d}_structured_{dettype}.hdf5",
            max_to_load=nwaveforms,
            channels_filter=channels,
            endpoint_filter=endpoints
        )
    except Exception as err:  # not a bare except: KeyboardInterrupt/SystemExit still propagate
        if verbose:
            print(f"Merged file not loaded ({err!r}); looking for individual files.")

    # sorted() makes the file order, and therefore "the first file", deterministic.
    files = sorted(glob(f"{rundir}/processed_*_run{run:06d}_*_{dettype}.hdf5"))
    if verbose:
        print("List of files found:")
        print(files)
    if not files:
        raise FileNotFoundError(f"No processed files found for run {run} ({dettype}) in {rundir}")

    # Every file gets the same filters, so merged sets only hold the requested channels.
    load_kwargs = dict(max_to_load=nwaveforms, channels_filter=channels, endpoint_filter=endpoints)
    wfset = load_structured_waveformset(files[0], verbose=verbose, **load_kwargs)
    if mergefiles:
        for f in files[1:]:
            # Each extra set is freshly loaded and discarded after merging,
            # so there is no need to deepcopy it.
            wfset.merge(load_structured_waveformset(f, verbose=False, **load_kwargs))
    return wfset

In [None]:
import copy
import time
def select_channels(waveform: Waveform, channels: list) -> bool:
    """Filter predicate: keep ``waveform`` only if its channel is one of ``channels``."""
    return waveform.channel in channels
def create_wfset(run_to_unich, endpoint, nwaveforms=80000):
    """
    Load the processed waveform sets of several runs and merge them into one.

    For each ``run -> channels`` entry, ``open_processed`` is called with the
    module-level ``dettype`` and ``datadir``, that run's channel list and
    ``[endpoint]`` as filters.

    Parameters
    ----------
    run_to_unich : dict
        Maps run number -> list of channel numbers to load from that run.
    endpoint : int
        Endpoint passed (as a one-element list) to the endpoint filter.
    nwaveforms : int, optional
        Forwarded to ``open_processed`` as ``nwaveforms``. The default 80000
        is the value previously hard-coded here.

    Returns
    -------
    The merged waveform set, or ``None`` if ``run_to_unich`` is empty.
    """
    wfset_full = None
    for run, run_channels in run_to_unich.items():
        wfset = open_processed(run, dettype, datadir, run_channels, [endpoint], nwaveforms=nwaveforms, verbose=True)
        # Each wfset is freshly loaded and not referenced afterwards, so it can
        # be used/merged directly; a deepcopy would only double peak memory.
        if wfset_full is None:
            wfset_full = wfset
        else:
            wfset_full.merge(wfset)
        print(f"Loaded run {run}")
    return wfset_full
    
# Time the (potentially slow) load. perf_counter is monotonic, unlike
# time.time, so the measured duration cannot be skewed by clock adjustments.
start = time.perf_counter()
wfset_full = create_wfset(run_to_unich, endpoint)
elapsed = time.perf_counter() - start
print(f"Loading took {elapsed:.1f} s")
wfset_full

In [None]:
# Run the basic NP02 waveform analysis on the loaded set.
# NOTE(review): int_ll/int_ul and amp_ll/amp_ul look like integration and
# amplitude-search windows, and threshold a cut value; confirm units against
# runBasicWfAnaNP02.
runBasicWfAnaNP02(
    wfset_full,
    int_ll=250,
    int_ul=280,
    amp_ll=100,
    amp_ul=260,
    threshold=50,
    configyaml="",
)

In [None]:
# Heatmap of the raw (uncut) waveforms for a single detector.
# NOTE(review): "higth" looks misspelt (a later cell uses "heigth"); check which
# keyword plot_detectors actually honours before relying on the figure height.
argsheat = {
    "mode": "heatmap",
    "analysis_label": "std",
    "adc_range_above_baseline": 1500,
    "adc_range_below_baseline": -300,
    "adc_bins": 125,
    "time_bins": wfset_full.points_per_wf // 2,
    "filtering": 36,
    "share_y_scale": False,
    "share_x_scale": True,
    "wfs_per_axes": 5000,
    "zlog": True,
    "width": 1300,
    "higth": 650,
    "showplots": True,
}

detector = ["M6(1)"]
# detector = ["C6(1)"]

plot_detectors(wfset_full, detector, **argsheat)


In [None]:
# Apply the selection cuts from cuts.yaml and report how many waveforms survive.
extractor = WaveformSelector('cuts.yaml')
wfset_clean = WaveformSet.from_filtered_WaveformSet(
    wfset_full,
    extractor.applycuts,
    show_progress=True,
)
print(
    f"Original waveforms: {len(wfset_full.waveforms)}, "
    f"after cut: {len(wfset_clean.waveforms)}"
)

In [None]:
# Same heatmap settings as for the raw waveforms, now applied to the cut set
# so the two plots can be compared directly.
# NOTE(review): "higth" looks misspelt; check which keyword plot_detectors honours.
argsheat = {
    "mode": "heatmap",
    "analysis_label": "std",
    "adc_range_above_baseline": 1500,
    "adc_range_below_baseline": -300,
    "adc_bins": 125,
    "time_bins": wfset_full.points_per_wf // 2,
    "filtering": 36,
    "share_y_scale": False,
    "share_x_scale": True,
    "wfs_per_axes": 5000,
    "zlog": True,
    "width": 1300,
    "higth": 650,
    "showplots": True,
}
detector = ["M6(1)"]
# detector = ["C6(1)"]
plot_detectors(wfset_clean, detector, **argsheat)

In [None]:
# Average waveforms of both M6 channels after the cuts, in a 1x2 subplot layout.
fig = psu.make_subplots(rows=1, cols=2)
grids = np02_gen_grids(wfset_clean, detector=["M6(1)", "M6(2)"])
gt = grids
plot_averages(fig, gt["Custom"])
fig.show()

In [None]:
# Heatmap of the cut waveforms with a tighter ADC range. return_fig=True makes
# plot_detectors hand back the figure so the averages can be drawn on it.
# NOTE(review): "heigth" differs from the "higth" used in earlier cells; at
# least one is misspelt. Check which keyword plot_detectors honours.
argsheat = {
    "mode": "heatmap",
    "analysis_label": "std",
    "adc_range_above_baseline": 1000,
    "adc_range_below_baseline": -100,
    "adc_bins": 125,
    "time_bins": wfset_full.points_per_wf // 2,
    "filtering": 36,
    "share_y_scale": False,
    "share_x_scale": True,
    "wfs_per_axes": 5000,
    "zlog": True,
    "width": 1300,
    "heigth": 650,
    "return_fig": True,
}

# Alternative detector selections:
# detector = ["M3(1)", "M3(2)", "M4(1)", "M4(2)"]
# detector = ["M5(1)", "M5(2)", "M6(1)", "M6(2)", "M7(1)", "M7(2)", "M8(1)", "M8(2)"]
# detector = ["C1(1)", "C1(2)", "C3(1)", "C3(2)", "C4(1)", "C4(2)", "C5(1)", "C5(2)"]
# detector = ["C2(1)", "C2(2)", "C6(1)", "C6(2)", "C7(1)", "C7(2)", "C8(1)", "C8(2)"]
# detector = groupLow
# detector = groupHig
detector = ["M6(1)"]

figs = plot_detectors(wfset_clean, detector, **argsheat)

# Only the first returned entry is used; its last element (g) is what
# plot_averages draws onto the figure.
fig, rows, cols, title, g = figs[0]

# Per-channel module listing (disabled):
# wfset_ch = ChannelWsGrid.clusterize_waveform_set(wfset_clean)
# for ch, wfch in wfset_ch[endpoint].items():
#     print(dict_uniqch_to_module[strUch(endpoint,ch)])

plot_averages(fig, g)
fig.show()