# Setup

In [None]:
import numpy as np
import pandas as pd

import holoviews as hv
import datashader
from holoviews.operation.datashader import aggregate, shade, datashade, dynspread
from holoviews.operation import decimate
hv.extension('bokeh')

import strax

import gc
# Keep a handle on the real collector so a collection can still be forced by
# calling real_gc_collect(), or so gc.collect can be restored later.
real_gc_collect = gc.collect

# HACK: something in holoviews calls gc.collect() explicitly, which makes the
# dynamic PMT maps super slow. Until the offending call is traced, gc.collect
# is replaced by a no-op. This affects every caller in the process, not only
# holoviews. The stub takes no arguments, so gc.collect(generation) would
# raise a TypeError while the patch is active.
gc.collect = lambda : None

# Custom wheel zoom tool that only zooms along the horizontal (time) axis
from bokeh.models import WheelZoomTool
time_zoom = WheelZoomTool(dimensions='width')

In [None]:
# Load the DEFAULT section of the XENON1T pax configuration
from pax.configuration import load_configuration
from pax.dsputils import adc_to_pe
pax_config = load_configuration('XENON1T')["DEFAULT"]

# Multiplicative ADC->pe conversion factor, one entry per channel number
to_pe = np.array([adc_to_pe(pax_config, channel)
                  for channel in range(pax_config['n_channels'])])

tpc_r = pax_config['tpc_radius']
# PMT map axes extend to f * tpc_r on either side of the origin
f = 1.08

# Locations of the PMTs: one row per PMT with its position, PMT number and
# the array it belongs to ('other' when the config gives none)
r = [dict(x=pmt['position']['x'],
          y=pmt['position']['y'],
          i=pmt['pmt_position'],
          array=pmt.get('array', 'other'))
     for pmt in pax_config['pmts']]
pmt_locs = pd.DataFrame(r)
n_top = len(pax_config['channels_top'])

# Data loading

In [None]:
# Load the reduced test records, then derive hits and peaks from them
records = strax.load('test_records_reduced')
hits = strax.find_hits(records)
# gap_threshold=300 and min_hits=3 are passed straight to strax.find_peaks
# (the gap is presumably in ns -- confirm against the strax documentation)
peaks = strax.find_peaks(hits, to_pe, gap_threshold=300, min_hits=3)
# The return value is ignored: presumably this fills the waveform fields of
# `peaks` in place -- TODO confirm
strax.sum_waveform(peaks, records, to_pe)
# Sanity check: number of records, hits and peaks
print(len(records), len(hits), len(peaks))

## Basic stats

In [None]:
# Fraction of records representing an entire pulse, taken here as the records
# whose pulse_length is below 110.
# NOTE(review): 110 is presumably the number of samples per record; if so, a
# pulse of exactly 110 samples also fits in a single record and `<=` may be
# what is intended -- confirm.
(records['pulse_length'] < 110).sum() / len(records)

In [None]:
# Fraction of records in the most frequent channel.
# np.bincount counts every channel number that occurs, so no channel count has
# to be hard-coded. The previous histogram used edges np.arange(0, 260) - 0.5,
# which end at 258.5 and therefore never counted channel 259.
# minlength=1 keeps .max() defined when there are no records at all.
np.bincount(records['channel'], minlength=1).max() / len(records)

In [None]:
# Integral of each record in pe: sum records['data'] along axis 1 (one value
# per record) and scale by the ADC->pe factor of that record's channel
areas = records['data'].sum(axis=1) * to_pe[records['channel']]

In [None]:
# Largest single-record area (same units as `areas`)
areas.max()

## Records to holoviews

In [None]:
def normalize_time(t):
    """Express t relative to the first record's time, divided by 1e9.

    t may be a scalar or an array in the same units as records['time']
    (presumably ns, which makes the result seconds; the plots label this
    axis 'sec'). The reference time is read from the global `records` on
    every call.
    """
    t_start = records[0]['time']
    elapsed = t - t_start
    return elapsed / 1e9

# Create dataframe with record metadata: one row per record, holding its
# area, its normalized time and its channel number
df = pd.DataFrame(dict(area=areas,
                       time=normalize_time(records['time']), 
                       channel=records['channel']))

# Convert to holoviews Points: (time, channel) are the key dimensions and the
# record area is the only value dimension. The channel axis is fixed to
# (0, 260); the area range is left automatic (explicit range disabled below).
points = hv.Points(df, 
                   kdims=[hv.Dimension('time', label='Time', unit='sec'),
                          hv.Dimension('channel', label='PMT number', range=(0, 260))], 
                   vdims=[hv.Dimension('area', label='Area', unit='pe', 
                                       #range=(0, 1000)
                                      )])

# Plots

## PMT pattern

In [None]:
# Options for Points elements in the 'PMTPattern' group: color by dimension
# index 2 (for the Points built in pmt_map that is 'area', after the x and y
# key dimensions), enable the hover tool, hide the grid, marker size 17,
# 'magma' colormap.
%opts Points.PMTPattern [color_index=2 tools=['hover'] show_grid=False] (size=17, cmap='magma')

def pattern_plot(array, areas):
    """Attach per-PMT areas to the PMT locations of one array.

    Selects the rows of pmt_locs whose 'array' column equals `array` and
    stores the matching entries of `areas` in an 'area' column on a copy of
    those rows.

    NOTE: the resulting frame is neither returned nor plotted, so this
    function always returns None. pmt_map holds the complete version of
    this logic.
    """
    in_array = pmt_locs['array'] == array
    selected = pmt_locs[in_array].copy()
    selected['area'] = areas[in_array]

def pattern_between(t_0, t_1):
    """Return PMT pattern between time t_0 and t_1.

    Sums the area of every record in `points` with t_0 <= time < t_1 per
    channel, using the same selection and binning as pmt_map.

    Returns an array indexed by channel number with at least
    len(pmt_locs) entries; channels without records in the window get 0.

    The previous body ignored both arguments and returned the module-level
    per-record `areas` array, which is not a PMT pattern.
    """
    times = points['time']
    in_window = points[(t_0 <= times) & (times < t_1)]
    return np.bincount(in_window['channel'],
                       weights=in_window['area'],
                       minlength=len(pmt_locs))

def pmt_map(t_0, t_1, array='top', **kwargs):
    """Build the area pattern of one PMT array for records in [t_0, t_1).

    t_0, t_1: edges of the time window, on the normalized time axis of
        `points` (t_0 inclusive, t_1 exclusive).
    array: value of the pmt_locs 'array' column selecting which PMTs to show.
    **kwargs: forwarded unchanged to Dataset.to.

    Returns holoviews Points in group 'PMTPattern', labelled with the
    capitalized array name, with 'area' and 'i' as value dimensions.
    """
    # Summed record area per channel inside the window (fast)
    times = points['time']
    in_window = points[(t_0 <= times) & (times < t_1)]
    area_per_channel = np.bincount(in_window['channel'],
                                   weights=in_window['area'],
                                   minlength=len(pmt_locs))

    # Keep only the PMTs of the requested array and attach their areas
    in_array = pmt_locs['array'] == array
    pmts = pmt_locs[in_array].copy()
    pmts['area'] = area_per_channel[in_array]

    # Both axes span the TPC radius scaled by the margin factor f
    extent = (-tpc_r * f, tpc_r * f)
    dimensions = [hv.Dimension('x', unit='cm', range=extent),
                  hv.Dimension('y', unit='cm', range=extent),
                  hv.Dimension('i', label='PMT number'),
                  hv.Dimension('area', label='Area', unit='PE')]
    dataset = hv.Dataset(pmts, kdims=dimensions)

    # Convert to holoviews points
    return dataset.to(hv.Points,
                      vdims=['area', 'i'],
                      group='PMTPattern',
                      label=array.capitalize(),
                      **kwargs)

def pmt_map_range(x_range, array='top', **kwargs):
    """Stream-friendly wrapper around pmt_map taking an (x_0, x_1) range.

    Intended for use in a DynamicMap with streams. When x_range is None,
    the window (0, 0) is used instead. `array` and any extra keyword
    arguments are passed on to pmt_map.
    """
    t_0, t_1 = (0, 0) if x_range is None else (x_range[0], x_range[1])
    return pmt_map(t_0, t_1, array=array, **kwargs)

In [None]:
# Stand-alone PMT map driven by two key dimensions: t_0 ranges over [0, 1]
# and t_1 over [0.1, 10]; pmt_map is called with the chosen pair
# (array stays at its default, 'top').
hv.DynamicMap(pmt_map, kdims=['t_0', 't_1']).redim.range(t_0=(0., 1.), t_1=(0.1, 10.))

### Old selection stuff?

In [None]:
"""
%%opts QuadMesh [width=n_bins_t height=400 tools=['xbox_select']] (alpha=0 hover_line_alpha=1 hover_fill_alpha=0)

def selected_pmt_area(index):
    # NB: If you make exceptions in these callbacks, you get nothing!    
    # You also can't print anything. What kind of ghostly process is this running in??
    
    selected_bins = index
    # Get displayed time range
    if xyrange.x_range is None:
        t_0 = times[0]
        t_1 = times[-1]
    else:
        t_0 = xyrange.x_range[0]
        t_1 = xyrange.x_range[1]
    t_range = t_1 - t_0

    if len(selected_bins):
        tsel_0 = t_0 + selected_bins[0] * t_range/n_bins_t
        tsel_1 = t_0 + selected_bins[-1] * t_range/n_bins_t

        #return pmt_maps(tsel_0, tsel_1)
    
    return pmt_maps(t_0, t_1)

selection = hv.streams.Selection1D(source=points)  

quadmesh_helper = aggregate(points, width=40, height=20, 
                            streams=[xyrange, selection]).map(hv.QuadMesh, hv.Image)
                            
                            
tools=[wzt, 'xbox_select']
""";

## PMT vs time pulse map

In [None]:

# Stream carrying the x (time) range of plots whose source is `points`; it is
# passed to the datashaded channel map, the pulse-level waveform and the
# DynamicMaps of the combined layout.
xrange_stream = hv.streams.RangeX(source=points)
# TODO: weigh by area

def channel_map():
    """Return the datashaded PMT-number vs time map of all records.

    `points` is datashaded with the channel axis fixed to (0, 260) and
    re-rendered through xrange_stream, then passed through dynspread. The
    plot is 600 wide, has no grid, and uses the time-only wheel zoom and
    'xpan' tools on top of the listed default tools.
    """
    shaded = datashade(points,
                       y_range=(0, 260),
                       streams=[xrange_stream])
    plot_options = dict(
        width=600,
        tools=[time_zoom, 'xpan'],
        default_tools=['save', 'pan', 'box_zoom', 'save', 'reset'],
        show_grid=False)
    return dynspread(shaded).opts(plot=plot_options)

In [None]:
# channel_map()

## Pulse-level sum waveform

In [None]:
def pulse_level_waveform():
    """Return the summed record area versus time for the visible range.

    Datashader doesn't do 1d histograms, so the records are aggregated into
    a 2d image only 2 bins high (summing 'area' per bin), and that image is
    then summed over the channel axis.
    See https://github.com/bokeh/datashader/issues/225
    """
    binned = aggregate(points,
                       aggregator=datashader.sum('area'),
                       streams=[xrange_stream],
                       x_sampling=1e-8,
                       height=2)
    # Collapse the channel axis of every aggregated image
    summed = binned.map(lambda x: x.reduce(channel=np.sum), hv.Image)

    plot_options = dict(
        width=600,
        tools=[time_zoom, 'xpan'],
        show_grid=True,
        default_tools=['save', 'pan', 'box_zoom', 'save', 'reset'])
    # framewise normalization: passed as norm=dict(framewise=True)
    return summed.opts(plot=plot_options, norm=dict(framewise=True))

# pulse_level_waveform()

## Peak sum waveform

In [None]:
# Scratch check of negative indexing on argsort output: the argsort of an
# already sorted range is the identity, so this evaluates to 20.
np.argsort(np.arange(30))[-10]

In [None]:
def plot_peak(p):
    """Return the sum waveform of a single peak as a holoviews Curve.

    p: one peak record; its 'data', 'length', 'dt' and 'time' fields are read.

    Only the first p['length'] samples of p['data'] are drawn. Sample i is
    placed at normalize_time(p['time'] + i * p['dt']). The curve shares the
    time dimension of `points` and belongs to group 'PeakSumWaveform'.
    """
    n_samples = p['length']
    amplitudes = p['data'][:n_samples]
    sample_times = np.arange(n_samples, dtype=np.int64) * p['dt'] + p['time']

    curve = hv.Curve(dict(time=normalize_time(sample_times), y=amplitudes),
                     kdims=points.kdims[0],
                     group='PeakSumWaveform')

    # Custom tools / default_tools and framewise normalization are
    # deliberately not set here.
    plot_options = dict(interpolation='steps-mid',
                        width=600,
                        shared_axes=False,
                        show_grid=True)
    return curve.opts(plot=plot_options, style=dict(color='b'))

def plot_peaks(t_0, t_1, n_max=10):
    """Overlay the sum waveforms of the largest peaks overlapping (t_0, t_1).

    t_0, t_1: edges of the time window, on the normalize_time axis.
    n_max: maximum number of peaks to draw; the n_max peaks with the largest
        'area' are kept, in their original order.

    Returns an hv.Overlay with one plot_peak curve per selected peak (empty
    when no peak overlaps the window or n_max <= 0).

    Changes from the earlier version:
    - never draws more than n_max peaks (ties at the area threshold used to
      let extra peaks through);
    - n_max=0 draws nothing (np.sort(areas)[-0] is the smallest area, so
      every peak used to be kept);
    - no longer shadows the module-level `areas` and no longer prints the
      peak counts on every call.
    """
    # A peak overlaps the window if it ends after t_0 and starts before t_1
    peak_start = normalize_time(peaks['time'])
    peak_end = normalize_time(peaks['time'] + peaks['length'] * peaks['dt'])
    ps = peaks[(peak_end > t_0) & (peak_start < t_1)]

    # Show only the largest n_max peaks
    if len(ps) > n_max:
        by_area = np.argsort(ps['area'], kind='stable')
        # Slice from an explicit start index: by_area[-n_max:] would select
        # everything for n_max == 0. Re-sort to restore the original order.
        keep = np.sort(by_area[len(ps) - n_max:])
        ps = ps[keep]

    return hv.Overlay(items=[plot_peak(p) for p in ps])


def plot_peak_range(x_range, **kwargs):
    """Stream-friendly wrapper around plot_peaks taking an (x_0, x_1) range.

    Intended for use in a DynamicMap with streams. When x_range is None,
    the window (0, 10) is used instead. Extra keyword arguments are passed
    on to plot_peaks.
    """
    t_0, t_1 = (0, 10) if x_range is None else (x_range[0], x_range[1])
    return plot_peaks(t_0, t_1, **kwargs)

In [None]:
#hv.DynamicMap(plot_peaks, kdims=['t_0', 't_1']).redim.range(t_0=(0., 10.), t_1=(1, 10.))

# Combine

In [None]:
from functools import partial

In [None]:
# %%opts Curve.PeakSumWaveform (color='b') {+framewise}
# PMT area patterns of the top and bottom arrays, recomputed from the x range
# delivered by xrange_stream
top_map = hv.DynamicMap(partial(pmt_map_range, array='top'), streams=[xrange_stream])
bot_map = hv.DynamicMap(partial(pmt_map_range, array='bottom'), streams=[xrange_stream])
# Peak sum waveforms for the same x range (the overlay with the pulse-level
# waveform is currently disabled)
waveform = hv.DynamicMap(plot_peak_range, streams=[xrange_stream]) #* pulse_level_waveform()
# Four panels in two columns: waveform and top map, then channel map and bottom map
layout = waveform + top_map + channel_map() + bot_map
layout.cols(2)

# Old stuff

In [None]:
# Time visible in a single window (same units as normalize_time output)
t_window = 0.1

# Speed of the visualization: frames per second of the periodic update, and
# how far the window advances per fps frames (dt below is speed / fps)
fps = 20
speed = 0.1

# Total time spanned by the records, scaled by 1e9 like normalize_time
t_max = (records['time'][-1] - records['time'][0])/1e9
dt = speed/fps    # Time shift per frame
n_frames = int((t_max - t_window) / dt)   # Number of frames needed
# NOTE(review): n_frames is negative when the data span less than t_window;
# how periodic() treats a negative count is not checked here.


# Slide the visible window over the data: frame i shows
# (i * dt, i * dt + t_window). Updates are scheduled every 1/fps seconds for
# n_frames frames with timeout=60; the trailing semicolon suppresses the
# cell output.
waveform.periodic(1/fps, count=n_frames, 
                  param_fn=lambda i: dict(x_range=(i * dt, t_window + i * dt)),
                  timeout=60);

TODO: Link color scales!