In [None]:
import plotly.graph_objs as go
from plotly.subplots import make_subplots
import plotly.io as pio
pio.templates.default = 'plotly_white'

import uproot
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import plotly.express as px

In [None]:
# Shared plotly axis style for every subplot: a black mirrored frame,
# outward ticks and a thin grey grid.
xaxis = {
    'showline': True,
    'ticks': 'outside',
    'mirror': True,
    'linecolor': 'black',
    'showgrid': True,
    'gridcolor': 'grey',
    'gridwidth': 0.25,
}

# The y axes use the same frame/grid style (as an independent copy) plus a
# thin black zero line.
yaxis = {
    **xaxis,
    'zeroline': True,
    'zerolinecolor': 'black',
    'zerolinewidth': 0.25
}


def charge_comp(wfs_array, baselines, channel_number, left_b, right_b, evt_range, coef=1.5e-9,
                n_baseline_samples=15):
    """Integrate the charge of one channel for a range of events.

    For every event ``evt`` in ``range(evt_range[0], evt_range[1])``:

    * charge = ``coef * sum(baselines[evt, channel_number]
      - wfs_array[evt, channel_number, left_b:right_b])``;
    * a waveform-based baseline is estimated as the mean of the first
      ``n_baseline_samples`` samples of the waveform.

    Parameters
    ----------
    wfs_array : np.ndarray
        Waveforms indexed as ``[event, channel, sample]``.
    baselines : np.ndarray
        Per-event, per-channel baselines indexed as ``[event, channel]``.
    channel_number : int
        Channel to process.
    left_b, right_b : int
        Integration window in samples, half-open ``[left_b, right_b)``.
    evt_range : sequence of two ints
        Half-open event interval ``[start, stop)``.
    coef : float
        Scale factor applied to the integrated (baseline - sample) sum.
    n_baseline_samples : int
        Number of leading samples averaged for the waveform-based baseline.

    Returns
    -------
    charges : np.ndarray of float64, one per event.
    baselines_wf : np.ndarray of float64, waveform-based baseline per event.
    baselines_gcu : the slice ``baselines[start:stop, channel_number]``.

    Raises
    ------
    IndexError
        If ``evt_range`` extends past the last event of ``wfs_array``.
    """
    start, stop = evt_range[0], evt_range[1]
    n_events = wfs_array.shape[0]
    if stop > n_events:
        # Slicing would silently truncate; keep the loud failure of the old per-event loop.
        raise IndexError(f"evt_range stop {stop} exceeds the number of events ({n_events})")

    # float64 so that `baseline - sample` cannot wrap around for unsigned ADC integer inputs.
    waveforms = np.asarray(wfs_array[start:stop, channel_number, :], dtype=np.float64)
    baselines_gcu = baselines[start:stop, channel_number]
    baseline_col = np.asarray(baselines_gcu, dtype=np.float64)[:, np.newaxis]

    signal = baseline_col - waveforms[:, left_b:right_b]
    charges = coef * signal.sum(axis=1)
    baselines_wf = waveforms[:, :n_baseline_samples].mean(axis=1)
    return charges, baselines_wf, baselines_gcu
 

def wfs_2d_plot_np(data, xlabel, ylabel,
                   plot_width=200, plot_height=200,
                   col_wrap=4, height=500, width=950):
    """Show one log-scaled 2D histogram per dataset as plotly facets.

    Parameters
    ----------
    data : sequence
        One entry per facet; each entry must support ``entry[xlabel]`` and
        ``entry[ylabel]`` (e.g. a DataFrame).
    xlabel, ylabel : str
        Keys of the values histogrammed along x and y.
    plot_width, plot_height : int
        Number of histogram bins along x and y.
    col_wrap : int
        Number of facets per row.
    height, width : int
        Figure size in pixels.

    Populated bins are coloured by ``log10(count)``; empty bins are NaN and
    therefore left blank. All facets are binned over the same x/y range so
    that the single x/y grid passed to ``px.imshow`` is correct for every panel.

    Raises
    ------
    ValueError
        If ``data`` is empty.
    """
    n_panels = len(data)
    if n_panels == 0:
        raise ValueError("data must contain at least one dataset")

    # Common range for all panels: px.imshow draws every facet against one
    # set of x/y coordinates, so per-panel binning would mislabel the axes.
    x_range = (float(min(np.min(entry[xlabel]) for entry in data)),
               float(max(np.max(entry[xlabel]) for entry in data)))
    y_range = (float(min(np.min(entry[ylabel]) for entry in data)),
               float(max(np.max(entry[ylabel]) for entry in data)))

    log_counts = []
    for entry in data:
        counts, xedges, yedges = np.histogram2d(
            entry[xlabel], entry[ylabel],
            bins=(plot_width, plot_height),
            range=[x_range, y_range],
        )
        counts = counts.T  # histogram2d is indexed [x, y]; imshow wants rows = y
        log_hist = np.full_like(counts, np.nan)
        np.log10(counts, out=log_hist, where=counts > 0)
        log_counts.append(log_hist)

    # Bin edges are identical for every panel because the range is shared.
    x_centers = (xedges[1:] + xedges[:-1]) / 2
    y_centers = (yedges[1:] + yedges[:-1]) / 2

    fig = px.imshow(
        np.array(log_counts),
        x=x_centers,
        y=y_centers,
        origin='lower',
        labels={
            'color': 'Log10(count)'
        },
        color_continuous_scale='inferno',
        height=height,
        width=width,
        facet_col=0,
        facet_col_wrap=col_wrap,
        aspect='auto'
    )

    # One x/y axis pair per facet.
    axis_params = {}
    for i in range(1, n_panels + 1):
        axis_params[f'xaxis{i}'] = xaxis
        axis_params[f'yaxis{i}'] = yaxis

    fig.update_layout(
        coloraxis_colorbar=dict(
            title='Log10',
            tickprefix='10^'
        ),
        showlegend=True,
        **axis_params,
        font=dict(
            family="Times New Roman",
            size=20,
            color='black'
        ),
    )

    fig.show()
    
def plot_baselines_diffs(wfs_array, baseline_array,
                         evt_range, nrows=4, ncols=6):
    """Histogram, per channel, of (waveform-estimated baseline - stored baseline).

    Each channel gets one subplot on an ``nrows`` x ``ncols`` grid, filled
    row-major. The histogrammed value is ``baselines_wf - baselines_gcu`` as
    returned by ``charge_comp`` for the events in ``evt_range``.

    Parameters
    ----------
    wfs_array : np.ndarray
        Waveforms indexed as ``[event, channel, sample]``.
    baseline_array : np.ndarray
        Stored baselines indexed as ``[event, channel]``.
    evt_range : sequence of two ints
        Half-open event interval ``[start, stop)``.
    nrows, ncols : int
        Subplot grid. At most ``nrows * ncols`` channels are drawn, and never
        more than ``wfs_array`` contains.
    """
    fig = make_subplots(rows=nrows, cols=ncols,
                        vertical_spacing=0.1,
                        horizontal_spacing=0.1)

    n_panels = nrows * ncols
    for ch in range(min(n_panels, wfs_array.shape[1])):
        # Only the baselines are used; the (150, 250) window affects the discarded charges.
        _, baselines_wf, baselines_gcu = charge_comp(
            wfs_array, baseline_array, ch, 150, 250, evt_range)
        row, col = divmod(ch, ncols)
        fig.add_trace(
            go.Histogram(
                x=baselines_wf - baselines_gcu,
                name=f"Channel: {ch}"
            ),
            row=row + 1, col=col + 1
        )

    # Style every subplot of the grid (not a fixed 24).
    axis_params = {}
    for i in range(1, n_panels + 1):
        axis_params[f'xaxis{i}'] = xaxis
        axis_params[f'yaxis{i}'] = yaxis

    fig.update_layout(
        height=1000,
        width=1200,
        **axis_params,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.05,
            xanchor="right",
            x=1
        )
    )

    fig.show()
    
def plot_charges_hist(wfs_array, baseline_array, evt_range, nrows=4, ncols=6,
                      left_b=150, right_b=250, n_bins=100):
    """Per-channel histograms (log y-axis) of the manually integrated charge.

    Charges come from ``charge_comp`` with the integration window
    ``[left_b, right_b)``. Each channel gets one subplot on an ``nrows`` x
    ``ncols`` grid, filled row-major.

    Parameters
    ----------
    wfs_array : np.ndarray
        Waveforms indexed as ``[event, channel, sample]``.
    baseline_array : np.ndarray
        Stored baselines indexed as ``[event, channel]``.
    evt_range : sequence of two ints
        Half-open event interval ``[start, stop)``.
    nrows, ncols : int
        Subplot grid. At most ``nrows * ncols`` channels are drawn, and never
        more than ``wfs_array`` contains.
    left_b, right_b : int
        Integration window in samples.
    n_bins : int
        Number of equal-width bins between a channel's min and max charge.
    """
    fig = make_subplots(rows=nrows, cols=ncols,
                        vertical_spacing=0.1,
                        horizontal_spacing=0.1)

    n_panels = nrows * ncols
    for ch in range(min(n_panels, wfs_array.shape[1])):
        charges, _, _ = charge_comp(
            wfs_array, baseline_array, ch, left_b, right_b, evt_range)
        row, col = divmod(ch, ncols)

        if charges.size and charges.max() > charges.min():
            lo, hi = float(charges.min()), float(charges.max())
            xbins = dict(start=lo, end=hi, size=(hi - lo) / n_bins)
        else:
            # Empty or constant sample: a zero bin width is invalid, let plotly auto-bin.
            xbins = None

        fig.add_trace(
            go.Histogram(
                x=charges,
                name=f"Channel: {ch}",
                xbins=xbins,
            ), row=row + 1, col=col + 1
        )

        fig.update_yaxes(type="log", row=row + 1, col=col + 1)

    axis_params = {}
    for i in range(1, n_panels + 1):
        axis_params[f'xaxis{i}'] = xaxis
        axis_params[f'yaxis{i}'] = yaxis

    fig.update_layout(
        height=1000,
        width=1200,
        **axis_params,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.05,
            xanchor="right",
            x=1
        )
    )

    fig.show()
    
def plot_charges_scatter(wfs_array, charge_array,
                         baseline_array, evt_range,
                         nrows=4, ncols=6, left_b=150, right_b=250):
    """Per-channel scatter of manually integrated charge (x) vs. stored charge (y).

    The x values come from ``charge_comp`` with the integration window
    ``[left_b, right_b)``. The y values are
    ``charge_array[evt_range[0]:evt_range[1], channel]``. Each channel gets one
    subplot on an ``nrows`` x ``ncols`` grid, filled row-major.

    Parameters
    ----------
    wfs_array : np.ndarray
        Waveforms indexed as ``[event, channel, sample]``.
    charge_array : np.ndarray
        Stored charges indexed as ``[event, channel]``.
    baseline_array : np.ndarray
        Stored baselines indexed as ``[event, channel]``.
    evt_range : sequence of two ints
        Half-open event interval ``[start, stop)``.
    nrows, ncols : int
        Subplot grid. At most ``nrows * ncols`` channels are drawn, and never
        more than both arrays contain.
    left_b, right_b : int
        Integration window in samples.
    """
    fig = make_subplots(rows=nrows, cols=ncols,
                        vertical_spacing=0.1,
                        horizontal_spacing=0.1)

    n_panels = nrows * ncols
    start, stop = evt_range[0], evt_range[1]
    for ch in range(min(n_panels, wfs_array.shape[1], charge_array.shape[1])):
        charges, _, _ = charge_comp(
            wfs_array, baseline_array, ch, left_b, right_b, evt_range)
        row, col = divmod(ch, ncols)

        fig.add_trace(
            go.Scattergl(
                x=charges,
                y=charge_array[start:stop, ch],
                name=f"Channel: {ch}",
                mode='markers',
            ), row=row + 1, col=col + 1
        )

    axis_params = {}
    for i in range(1, n_panels + 1):
        axis_params[f'xaxis{i}'] = xaxis
        axis_params[f'yaxis{i}'] = yaxis

    fig.update_layout(
        height=1000,
        width=1200,
        **axis_params,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.05,
            xanchor="right",
            x=1
        )
    )

    fig.show()
    
def plot_wf_diff_channels_same_evt(
    wfs_array, EvtNumber, range_y=(11200, 11800), nrows=4, ncols=6):
    """Plot the waveform of each channel for a single event, one subplot per channel.

    Parameters
    ----------
    wfs_array : np.ndarray
        Waveforms indexed as ``[event, channel, sample]``.
    EvtNumber : int
        Event to display.
    range_y : pair of numbers
        Y-axis range applied to every subplot.
    nrows, ncols : int
        Subplot grid, filled row-major. At most ``nrows * ncols`` channels are
        drawn, and never more than ``wfs_array`` contains.
    """
    fig = make_subplots(rows=nrows, cols=ncols)

    # x axis is the sample index; taken from the array itself, not a global.
    samples = np.arange(wfs_array.shape[2])
    n_panels = nrows * ncols
    for ch in range(min(n_panels, wfs_array.shape[1])):
        row, col = divmod(ch, ncols)
        fig.add_trace(
            go.Scattergl(
                x=samples,
                y=wfs_array[EvtNumber, ch, :],
                name=f"Channel {ch}"
            ), row=row + 1, col=col + 1
        )

    axis_params = {}
    for i in range(1, n_panels + 1):
        axis_params[f'xaxis{i}'] = xaxis
        axis_params[f'yaxis{i}'] = yaxis

    fig.update_layout(
        **axis_params,
        height=800,
        width=1400,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.05,
            xanchor="right",
            x=1
        )
    )

    # Same y-range on every subplot (no per-panel row/col arithmetic to get wrong).
    fig.update_yaxes(range=list(range_y))

    fig.show()

def plot_wf_same_channel_diff_evts(
    wfs_array, ChannelNumber, range_y=(11200, 11800), nrows=10, ncols=6,
    min_evt=2000, seed=None):
    """Plot waveforms of one channel for randomly chosen events, one event per subplot.

    ``nrows * ncols`` event ids are drawn uniformly, with replacement, from
    ``[min_evt, n_events)``. Each subplot shows the waveform of channel
    ``ChannelNumber`` for one of those events, with a common y-range.

    Parameters
    ----------
    wfs_array : np.ndarray
        Waveforms indexed as ``[event, channel, sample]``.
    ChannelNumber : int
        Channel to plot.
    range_y : pair of numbers
        Y-axis range applied to every subplot.
    nrows, ncols : int
        Subplot grid, filled row-major.
    min_evt : int
        Smallest event id eligible for sampling.
    seed : int or None
        If given, events are drawn from ``np.random.default_rng(seed)`` so the
        selection is reproducible. If None, the global ``np.random`` state is used.

    Raises
    ------
    ValueError
        If ``wfs_array`` has no event with id >= ``min_evt``.
    """
    n_panels = nrows * ncols
    n_events = wfs_array.shape[0]
    if n_events <= min_evt:
        raise ValueError(
            f"no events with id >= {min_evt} to sample from (array has {n_events} events)")
    if seed is None:
        evt_ids = np.random.randint(min_evt, n_events, size=n_panels)
    else:
        evt_ids = np.random.default_rng(seed).integers(min_evt, n_events, size=n_panels)

    fig = make_subplots(rows=nrows, cols=ncols)

    samples = np.arange(wfs_array.shape[2])
    for panel, evt_id in enumerate(evt_ids):
        row, col = divmod(panel, ncols)
        fig.add_trace(
            go.Scattergl(
                x=samples,
                y=wfs_array[evt_id, ChannelNumber, :],
                name=f"EvtN: {evt_id}",
            ), row=row + 1, col=col + 1
        )

    axis_params = {}
    for i in range(1, n_panels + 1):
        axis_params[f'xaxis{i}'] = xaxis
        axis_params[f'yaxis{i}'] = yaxis

    fig.update_layout(
        **axis_params,
        height=1600,
        width=1400,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.05,
            xanchor="right",
            x=1
        )
    )

    # Honour the caller's range_y on every subplot.
    fig.update_yaxes(range=list(range_y))

    fig.show()

In [None]:
# Base directory of the calibration data used below. Keep the trailing slash:
# the input path is built by appending to this string.
path = '/data/48PMT/calib/20221207/'

## Run1: Cs137 collimated, 10', n_th 5. IPbus rate ~500 Hz

In [None]:
# ROOT file for run 1; all arrays below are read from its `eventTree`.
path_wfs = f'{path}run1/tier1_wf_integration/output_dt2.root'

In [None]:
# Stored baselines, indexed as [event, channel] by the code below.
# The context manager closes the ROOT file once the branch has been read.
with uproot.open(path_wfs) as root_file:
    baseline_array = np.array(root_file['eventTree']['baseline'].array())

In [None]:
# Stored charges, indexed as [event, channel] by the code below.
# The context manager closes the ROOT file once the branch has been read.
with uproot.open(path_wfs) as root_file:
    charge_array = np.array(root_file['eventTree']['charge'].array())

In [None]:
# `total_charge` branch of eventTree. The context manager closes the ROOT file afterwards.
with uproot.open(path_wfs) as root_file:
    total_charge_array = np.array(root_file['eventTree']['total_charge'].array())

In [None]:
# Raw waveforms, indexed as [event, channel, sample] by the code below.
# The context manager closes the ROOT file once the branch has been read.
with uproot.open(path_wfs) as root_file:
    wfs_array = np.array(root_file['eventTree']['waveforms'].array())

In [None]:
# Waveform array dimensions: (events, channels, time points per waveform).
shp = np.shape(wfs_array)
print(f"Number of events: {shp[0]}, number of channels: {shp[1]}, N. of time points: {shp[2]}")

## All events overlaid: per-channel 2D histograms of the waveforms

In [None]:
# One long-format DataFrame per channel, concatenating all events: column `t`
# is the sample index (repeated for every event) and `wfs` the waveform value.
# The sample-index column is the same for every channel, so build it once.
t_all_events = np.tile(np.arange(shp[2]), shp[0])

# Building from a dict keeps each column's own dtype, instead of stacking both
# into one 2 x N array first.
dfs = [
    pd.DataFrame({'t': t_all_events, 'wfs': wfs_array[:, ch, :].ravel()})
    for ch in range(shp[1])
]

In [None]:
# Per-channel density of (sample index, waveform value) over all events, 100 x 100 bins.
wfs_2d_plot_np(dfs, 't', 'wfs', height=1000, width=1000, plot_width=100, plot_height=100)

In [None]:
# Half-open event interval [start, stop) used by the baseline/charge plots below.
evt_range = [1_000, 31_000]

## Distributions of baseline differences (waveform-estimated minus GCU baseline)

In [None]:
# One histogram of baseline differences per channel.
plot_baselines_diffs(wfs_array, baseline_array, evt_range=evt_range)

## Charge distributions (charge integrated manually from the waveforms)

In [None]:
# Per-channel histograms of the manually integrated charge.
plot_charges_hist(wfs_array, baseline_array, evt_range=evt_range)

## Manually calculated charge (x axis) vs. GCU calculated charge (y axis)

In [None]:
# Manually integrated charge vs. the stored charge branch, per channel.
plot_charges_scatter(wfs_array, charge_array, baseline_array, evt_range=evt_range)

## Examples of waveforms for different channels in the same event

In [None]:
# Waveforms of event 2000, one subplot per channel.
plot_wf_diff_channels_same_evt(wfs_array=wfs_array, EvtNumber=2000)

## Examples of waveforms for the same channel in different events: 60 randomly selected events (event IDs ≥ 2000)

In [None]:
# Randomly sampled events, requesting channel 22.
plot_wf_same_channel_diff_evts(wfs_array=wfs_array, ChannelNumber=22)