In [None]:
import numpy as np
import pandas as pd
from tqdm import tqdm

import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

# Show up to 100 characters per DataFrame cell before truncating.
# NOTE(review): presumably widened so long text columns (e.g. in data_info output)
# stay readable; the author considered shortening those texts instead — confirm.
pd.options.display.max_colwidth = 100

# import logging
# logging.basicConfig(level=logging.DEBUG,
#                     format='{name} in {threadName} at {asctime}: {message}', style='{')
# logging.getLogger().setLevel(logging.DEBUG)

In [None]:
import os
import glob
import shutil

run_id = '180219_2005'

def clean_dir(strax_dir='./strax_data', run=None):
    """Delete derived data for a run, keeping only the records-level data.

    Every entry in ``strax_dir`` named ``<run>_*`` is removed, except those whose
    name ends in 'records' but not in 'reduced_records'.

    :param strax_dir: directory holding the strax output (default './strax_data').
    :param run: run id to clean; defaults to the notebook-level ``run_id``.
    """
    if run is None:
        run = run_id
    # Escape both parts so glob metacharacters in paths/run ids are matched literally.
    pattern = os.path.join(glob.escape(strax_dir), f'{glob.escape(run)}_*')
    for path in glob.glob(pattern):
        name = os.path.basename(path)
        keep = name.endswith('records') and not name.endswith('reduced_records')
        # Entries are expected to be directories (they are removed with rmtree);
        # skip anything else instead of crashing on it.
        if not keep and os.path.isdir(path):
            shutil.rmtree(path)

clean_dir()

In [None]:
import strax
# Register all plugins defined in strax.xenon.plugins.
strax.register_all(strax.xenon.plugins)
# Processing context used throughout this notebook.
# NOTE(review): max_workers=2 presumably caps the number of parallel workers — confirm.
mystrax = strax.Strax(max_workers=2)

In [None]:
%%time
# Process the run up to 'peaks'; save= lists which data types are stored.
# (To profile this call, add profile_to='test_peaks_par.prof'.)
mystrax.make(run_id, 'peaks',
             save=['peaks', 'reduced_records'])

# Reduce peak info

In [None]:
# Display information about the 'peak_basics' data type (presumably its fields).
mystrax.data_info('peak_basics')

In [None]:
# Load peak_basics for this run as a DataFrame and preview it.
df = mystrax.get_df(run_id, 'peak_basics')
df.head()

The peak_basics data takes about 30x less memory than the full peaks (which include waveforms, area_per_channel, etc.). That is a substantial reduction, but not enough to forgo chunking.

In [None]:
from multihist import Histdd

# 2D histogram of peak area vs. width (range containing 50% of the area),
# log-binned on both axes.
mh = Histdd(df['area'], df['range_50p_area'],
            bins=(np.logspace(0, 7, 100),
                  np.logspace(1, 4, 100)))
mh.plot(log_scale=True)
plt.xscale('log')
plt.yscale('log')
plt.xlabel("Area (pe)")
plt.ylabel("Width (50% area, ns)");

In [None]:
# Peak width vs. area for peaks with n_channels > 3, coloured by area fraction top.
d = df[df['n_channels'] > 3]
plt.scatter(d['area'],
            d['range_50p_area'],
            c=d['area_fraction_top'],
            s=0.1,
            cmap=plt.cm.rainbow, vmin=0, vmax=1)

plt.colorbar(label='Area fraction top')
plt.xlabel("Area (pe)")
plt.ylabel("Width (50% area, ns)")
# Black background so the tiny coloured markers stand out.
plt.gca().patch.set_facecolor('black')
plt.xscale('log')
plt.yscale('log')
plt.xlim(1, 5e6)
plt.ylim(10, 2e4);

# Merging

In [None]:
# Data type built from peak_basics and peak_classification via strax.MergePlugin
# (presumably merging their fields into one array), registered on this context.
@mystrax.register
class PeakInfo(strax.MergePlugin):
    depends_on = ('peak_basics', 'peak_classification')

# NOTE(review): the registered name 'peak_info' is presumably derived from the class name — confirm.
mystrax.data_info('peak_info')

In [None]:
# Load the merged peak_info data type and preview it.
df = mystrax.get_df(run_id, 'peak_info')
df.head()

# Event building

In [None]:
# Load events for this run (data_dir was never defined; use run_id like every other cell).
events = mystrax.get_array(run_id, 'events')

# Events do not overlap: each event starts strictly after the previous one ends.
# The check needs at least two events (np.min of an empty array raises).
if len(events) > 1:
    assert np.min(events['time'][1:] - events['endtime'][:-1]) > 0, \
        "Found overlapping (or touching) events"

In [None]:
# Load per-event summary quantities (event_basics) for the plot below.
df = mystrax.get_df(run_id, 'event_basics')

In [None]:
# ev_props = EventBasics().process_and_slurp(data_dir, n_per_iter=10)
# df = pd.DataFrame.from_records(ev_props)

In [None]:
# Drift time vs. S2 width, both converted from ns to us, coloured by S1 area fraction top.
plt.scatter(df['drift_time'] / 1000,
            df['s2_range_50p_area'] / 1000,
            c=df['s1_area_fraction_top'],
            cmap='jet', vmin=0, vmax=0.25,
            marker='.', edgecolors='none')
plt.colorbar(label="S1 area fraction top", extend='max')
plt.xlabel('Drift time (us)')
plt.ylabel('S2 width (us)')
plt.ylim(0, 4)
plt.tight_layout()

In [None]:
# Remove the derived data again (everything except the records-level directories).
clean_dir()

In [None]:
# Preview the 'largest_peak_area' data type for this run.
mystrax.get_df(run_id, 'largest_peak_area').head()

In [None]:
# # Show we've been shown all the correct peaks
# ps = chio.slurp(data_dir + '/peak_basics')
# n_contained_in = np.bincount(fully_contained_in(ps, events) + 1)[1:]
# assert np.all(ev_props['n_peaks'] == n_contained_in)

# Find things to investigate (older, but still useful, helper functions are further below)

In [None]:
# Reload peak_basics (df was overwritten by the event cells above). Data is loaded
# through the mystrax context, as everywhere else in this notebook.
df = mystrax.get_df(run_id, 'peak_basics')

In [None]:
# Peak width vs. area for peaks with n_channels >= 5, coloured by area fraction top.
mask = df['n_channels'] >= 5
# Optionally also drop peaks whose max_pmt is 31 or 87:
# mask &= ~np.isin(df['max_pmt'], [31, 87])
d = df[mask]

plt.scatter(d['area'],
            d['range_50p_area'],
            c=d['area_fraction_top'],
            s=0.1,
            cmap=plt.cm.rainbow, vmin=0, vmax=1)

plt.colorbar(label='Area fraction top')
plt.xlabel("Area (pe)")
plt.ylabel("Width (50% area, ns)")
# Black background so the tiny coloured markers stand out.
plt.gca().patch.set_facecolor('black')
plt.xscale('log')
plt.yscale('log')
plt.xlim(1, 5e6)
plt.ylim(10, 2e4);

In [None]:
# Deliberately halt "Run All" here. The cells below are older, exploratory
# waveform-inspection tools, several of which reference names not defined above
# (e.g. get_chunk_starts, data_dir); run them manually.
raise RuntimeError("Stopping Run All: the cells below are exploratory and must be run manually.")

# Waveform inspection tools

In [None]:
def chunk_i(t, subdir='records'):
    """Return the index of the chunk of ``subdir`` that contains time ``t``.

    Chunk k is taken to cover [chunk_starts[k], chunk_starts[k + 1]), so a time
    exactly equal to a chunk start belongs to that chunk.
    NOTE(review): assumes get_chunk_starts returns sorted start times — confirm.

    :raises ValueError: if ``t`` is before the start of the first chunk.
    """
    chunk_starts = get_chunk_starts(subdir)
    # side='right' maps t == chunk_starts[k] to chunk k. With the default
    # side='left' it mapped to chunk k - 1, and raised for the very first chunk.
    i = np.searchsorted(chunk_starts, t, side='right') - 1
    if i < 0:
        raise ValueError("time before first chunk starts")
    # TODO: Assumes last chunk is infinitely long...
    return i
    
def get_data(t_start, t_end, channels=None, subdir='records'):
    """Return all items from ``subdir`` that overlap the interval [t_start, t_end).

    An item overlaps if it ends after ``t_start`` (end = time + length * dt) and
    starts before ``t_end``. If ``channels`` is given, only items whose 'channel'
    is in it are kept.

    This is quite slow if you have big chunks: every chunk touching the interval
    is loaded in full before filtering.
    """
    first_chunk = chunk_i(t_start, subdir)
    last_chunk = chunk_i(t_end, subdir)
    in_files = chunk_files(subdir)
    result = []
    for i in range(first_chunk, last_chunk + 1):
        d = strax.load(in_files[i])
        d = d[(t_start < d['time'] + d['length'] * d['dt'])
              & (d['time'] < t_end)]
        if channels is not None:
            # np.isin replaces np.in1d, which is deprecated in NumPy 2.0.
            d = d[np.isin(d['channel'], channels)]
        result.append(d)
    return np.concatenate(result)
    
def plot_wvs(r, t0=None, time_unit='ns', alternate_colors=False, **kwargs):
    """Plot the waveforms of records ``r`` on a shared time axis.

    :param r: record array with 'time', 'length', 'dt' and 'data' fields.
    :param t0: time placed at 0 on the x-axis; defaults to the first record's time.
    :param time_unit: x-axis unit, one of 'ns', 'us', 'ms', 's'.
    :param alternate_colors: alternate black / dark slate grey between consecutive
        records so neighbouring records can be told apart.
    :param kwargs: passed on to plt.plot.
    """
    ns_per_unit = int(dict(ns=1, us=1e3, ms=1e6, s=1e9)[time_unit])

    # Only fall back to the first record's time when no t0 was given
    # (the argument used to be unconditionally overwritten).
    if t0 is None:
        t0 = r['time'][0]
    for i, d in enumerate(r):
        length = d['length']
        w = d['data'][:length]
        t = np.arange(length, dtype=np.int64) * d['dt'] + (d['time'] - t0)
        if alternate_colors and i % 2 == 1:
            color = 'darkslategrey'
        else:
            color = 'k'
        # Divide samples by dt to plot amplitude per ns (see y-label).
        plt.plot(t / ns_per_unit, w / d['dt'], color=color, **kwargs)

    plt.xlabel("Time (%s)" % time_unit)
    plt.ylabel("Amplitude (pe/ns)")

# Try to view PMT waveforms

In [None]:
# Load peak_basics through the mystrax context, like the other cells
# (data_dir and the old strax.io_chunked API are not set up in this notebook).
df = mystrax.get_df(run_id, 'peak_basics')

In [None]:
# Select large (area > 1e4), top-heavy (area_fraction_top > 0.9) peaks with max_pmt 87.
sd = df.query('area > 1e4 and area_fraction_top > 0.9 and max_pmt == 87')

In [None]:
#get_data(d.time - before, d.endtime + after, subdir='peaks')

In [None]:
def get_wv(t_start, t_end, subdir='peaks', channels=None, **kwargs):
    """Plot everything from ``subdir`` overlapping [t_start, t_end).

    Data is fetched with get_data and drawn with plot_wvs (``kwargs`` are passed
    to plot_wvs). Prints a message instead if nothing is found.
    """
    r = get_data(t_start, t_end, subdir=subdir, channels=channels)
    if len(r):
        plot_wvs(r, **kwargs)
    else:
        print("Nothing found")

def get_wv_of(x, extend=0, **kwargs):
    """Plot the data around item ``x``, padded by ``extend`` on both sides.

    ``x`` must provide 'time' and either 'endtime' or both 'dt' and 'length';
    ``extend`` is in the same units as x['time']. ``kwargs`` go to get_wv.
    """
    try:
        t_end = x['endtime']
    except (KeyError, ValueError):
        # pandas rows and dicts raise KeyError for a missing field, but numpy
        # structured records raise ValueError ("no field of name ...").
        t_end = x['time'] + x['dt'] * x['length']
    get_wv(x['time'] - extend, t_end + extend,
           **kwargs)

In [None]:
# Plot the raw records of channel 87 around the second selected peak,
# padded by 1e5 (time units of the data) on both sides, with the x-axis in us.
get_wv_of(sd.iloc[1], extend=int(1e5), 
          channels=[87], subdir='records',
          time_unit='us', alternate_colors=True)

In [None]:
# Duration covered by the records: span of chunk start times plus one average
# chunk length, divided by 1e9 (presumably ns -> s; confirm).
# NOTE(review): get_chunk_starts is not defined anywhere in this notebook.
ts = get_chunk_starts('records')
detector_time = (ts[-1] - ts[0] + np.diff(ts).mean()) / int(1e9)

In [None]:
# Disk usage of the records data.
# NOTE(review): input_dir is not defined anywhere in this notebook; records now seem
# to live under ./strax_data/{run_id}_* (see clean_dir) — update this path.
!du -h {input_dir}/records

In [None]:
# weirdo_is = np.where((peaks['area'] > 1e5) & (aft > 0.9))[0]

In [None]:
def plot_peak(p, t0=None, **kwargs):
    """Plot the sum waveform of a single peak ``p`` as a step curve.

    :param p: peak record with 'time', 'length', 'dt' and 'data' fields.
    :param t0: time placed at 0 on the x-axis; defaults to the peak's own start.
    :param kwargs: passed on to plt.plot (e.g. label, color).
    """
    n = p['length']
    if t0 is None:
        t0 = p['time']
    # Step drawing is selected with drawstyle; passing 'steps-mid' as a linestyle
    # was deprecated in Matplotlib 3.1 and is rejected by current versions.
    plt.plot((p['time'] - t0) + np.arange(n) * p['dt'],
             p['data'][:n] / p['dt'],
             drawstyle='steps-mid',
             **kwargs)
    plt.xlabel("Time (ns)")
    plt.ylabel("Sum waveform (PE / ns)")
    
def plot_peaks(peaks):
    """Overlay the sum waveforms of several peaks, aligned to the first peak's start.

    Each curve is labelled with the peak's area and sample width (dt).
    """
    reference_time = peaks[0]['time']
    for peak in peaks:
        label = '%.1e PE, %d ns dt' % (peak['area'], peak['dt'])
        plot_peak(peak, t0=reference_time, label=label)
    plt.ylim(0, None)

# NOTE(review): weirdo_is, peaks and aft are not defined anywhere in this notebook
# (the cell that would compute weirdo_is is commented out), so this cell fails
# on a fresh kernel.
# Plot the first flagged peak with one neighbour before and four after it.
i = weirdo_is[0]
plot_peaks(peaks[i-1:i+5])
plt.legend(loc='best')
#plt.yscale('symlog')
plt.show()
# Show aft around the flagged peak (presumably area fraction top — confirm).
aft[i-1:i+3]

In [None]:
#peaks[max_pmt[]]