# Definitions for single-electron S2 gain analysis

### Imports

In [None]:
import numpy as np
import hax

from hax.minitrees import MultipleRowExtractor
from hax.treemakers.peak_treemakers import PeakExtractor

### Settings

In [3]:
# Global cap on the event number processed by the treemakers below.
# Event numbers are integers, so store the cap as an int (same value as 10e3).
stop_after = int(10e3)
print('Warning: set stop_after to %d (increase treemaker version to actually implement.)' % stop_after)



## Treemakers

### Hits

In [2]:
class HitExtractor(MultipleRowExtractor):
    """
    Extract one row per hit, for hits in peaks that pass the configured cuts.

    Configure by overriding the class attributes in a subclass:

    hit_fields     : hit attributes copied into each row. An 'event_number'
                     column is always added so rows can be joined to
                     event-level properties.
    event_cut_list : conditions on the event; each is prefixed with 'event.'
                     (e.g. 'x > 1' becomes 'event.x > 1'). Empty = no cut.
    peak_cut_list  : conditions on each peak; each is prefixed with 'peak.'
                     (e.g. 'area > 100' becomes 'peak.area > 100'). Empty = no cut.
    stop_after     : the event loop is stopped at the first event whose
                     number is >= this value.

    The cut strings are executed with eval(), so they must only ever come
    from trusted configuration, never from external input.
    """

    # Default branch selection is EVERYTHING in peaks, overwrite for speed increase
    # Don't forget to include branches used in cuts
    extra_branches = ['peaks.*', 'hits.*']
    hit_fields = ['area']
    event_cut_list = []
    peak_cut_list = []
    # Replaced per instance in __init__ with strings built from the cut lists.
    event_cut_string = 'True'
    peak_cut_string = 'True'
    stop_after = np.inf

    # Hacks for want of string support :'(
    # Integer codes for peak types / detectors. Not referenced in this class;
    # presumably meant for cuts on integer-coded fields -- TODO confirm.
    peaktypes = dict(lone_hit=0, s1=1, s2=2, unknown=3)
    detectors = dict(tpc=0, veto=1, sum_wv=2, busy_on=3, busy_off=4)

    def __init__(self, *args, **kwargs):
        MultipleRowExtractor.__init__(self, *args, **kwargs)
        # Build the cut expressions once here rather than for every event.
        self.event_cut_string = self.build_cut_string(self.event_cut_list, 'event')
        self.peak_cut_string = self.build_cut_string(self.peak_cut_list, 'peak')

    def build_cut_string(self, cut_list, obj):
        '''
        Build one boolean expression, for eval(), from a list of cuts.

        Each cut is prefixed with ``obj + '.'``, parenthesised, and the cuts
        are joined with '&'. For example (['area > 1', 'n_hits > 2'], 'peak')
        gives '(peak.area > 1) & (peak.n_hits > 2)'. An empty list gives
        'True', so everything passes.
        '''
        if len(cut_list) == 0:
            return 'True'
        # range_50p_area does not work inside a cut expression, so rewrite it
        # to range_area_decile[5] before prefixing.
        cut_list = [cut.replace('range_50p_area', 'range_area_decile[5]') for cut in cut_list]
        return ' & '.join('(%s.%s)' % (obj, cut) for cut in cut_list)

    def extract_data(self, event):
        '''
        Return a list of dicts, one per hit in the peaks of `event` that pass
        the peak cut (keys: self.hit_fields plus 'event_number'). Return []
        if the event fails the event cut.

        Raises hax.paxroot.StopEventLoop once the event number reaches
        self.stop_after.
        '''
        # '>=' rather than '==' so the loop still stops when event numbers
        # skip over stop_after.
        if event.event_number >= self.stop_after:
            raise hax.paxroot.StopEventLoop()

        # The cut strings refer to the local names `event` and `peak`, so they
        # must be eval()'d in this frame with exactly those names bound.
        if not eval(self.event_cut_string):
            return []

        hit_data = []
        for peak in event.peaks:
            if not eval(self.peak_cut_string):
                continue
            for hit in peak.hits:
                current_hit = {field: getattr(hit, field) for field in self.hit_fields}
                # The event number is necessary to join to event properties
                current_hit['event_number'] = event.event_number
                hit_data.append(current_hit)
        return hit_data


### Peaks

In [None]:
class XAMSPeaks(PeakExtractor):
    """
    Peak minitree configuration: one row per selected peak, with the fields
    listed in peak_fields.

    The cut and field lists are consumed by the PeakExtractor base class from
    hax; how exactly the cuts are applied is defined there -- TODO confirm.
    """
    # Version string of this treemaker; presumably hax uses it to decide
    # whether a previously made minitree can be reused -- TODO confirm.
    __version__ = '0.0.7'
    # Binds the module-level stop_after at class-definition time; changing
    # the global afterwards does not affect this class.
    stop_after = stop_after
    # Intended selection: TPC peaks that are not lone hits.
    peak_cut_list = ['detector == "tpc"', 'type !="lone_hit"']
    peak_fields = ['area', 'range_50p_area', 'center_time', 'n_hits']

### Pulses

In [None]:
class PulseExtractor(MultipleRowExtractor):
    '''
    Extract properties of the pulses listed in a dataframe.

    Before running, set the class attribute ``df`` to a dataframe with
    'event_number' and 'found_in_pulse' columns. For each event, every pulse
    whose index in event.pulses appears in that event's 'found_in_pulse'
    values produces one row. The row holds the attributes listed in
    ``pulse_properties``, plus 'found_in_pulse' and 'event_number'.
    '''
    extra_branches = ['pulses.*']
    df = None
    stop_after = np.inf
    pulse_properties = []

    def extract_data(self, event):
        '''
        Return a list of dicts, one per selected pulse of `event`, or [] if
        the dataframe lists no pulses for this event.

        Raises hax.paxroot.StopEventLoop once the event number reaches
        self.stop_after, and ValueError if ``df`` has not been set.
        '''
        # '>=' rather than '==' so the loop still stops when event numbers
        # skip over stop_after.
        if event.event_number >= self.stop_after:
            raise hax.paxroot.StopEventLoop()
        if self.df is None:
            raise ValueError("PulseExtractor.df must be set to a dataframe with "
                             "'event_number' and 'found_in_pulse' columns before extraction.")

        in_this_event = self.df['event_number'] == event.event_number
        # A set gives O(1) membership tests in the pulse loop below.
        wanted = set(self.df[in_this_event]['found_in_pulse'].values)
        if not wanted:
            return []

        pulse_data = []
        for i, pulse in enumerate(event.pulses):
            if i not in wanted:
                continue
            current_pulse = {'found_in_pulse': i}
            for prop in self.pulse_properties:
                current_pulse[prop] = getattr(pulse, prop)
            current_pulse['event_number'] = event.event_number
            pulse_data.append(current_pulse)
        return pulse_data

## Functions

### Derived properties

In [1]:
def _main_peak_center_times(d_p, area_column):
    '''
    Return, per row of d_p, the center_time of that row's event's main peak.

    The main peak of an event is the one peak whose 'area' equals the row's
    `area_column` value (e.g. 's1' or 's2'). Rows are matched to events via
    'event_number', so d_p does not need to be sorted or grouped by event.
    Raises ValueError if any event does not have exactly one such peak.
    '''
    events = np.asarray(d_p['event_number'])
    is_main = np.asarray(d_p['area'] == d_p[area_column], dtype=bool)
    # inverse[i] is the index of row i's event within unique_events.
    unique_events, inverse = np.unique(events, return_inverse=True)
    inverse = inverse.ravel()
    n_main = np.bincount(inverse[is_main], minlength=len(unique_events))
    bad_events = unique_events[n_main != 1]
    if len(bad_events):
        raise ValueError('%d event(s) do not have exactly one peak with area == %r, e.g. %s'
                         % (len(bad_events), area_column, bad_events[:5].tolist()))
    # Exactly one main peak per event, so every slot below gets filled.
    main_times = np.empty(len(unique_events), dtype=float)
    main_times[inverse[is_main]] = np.asarray(d_p['center_time'], dtype=float)[is_main]
    return main_times[inverse]


def add_s1s2_properties(d_p):
    '''
    Add S1/S2 timing columns to a peaks dataframe.

    For every event, the main S1 (S2) is the peak whose 'area' equals the
    event's 's1' ('s2') value. Its center_time is broadcast to all peaks of
    that event. Added columns:

    s1_center_time, s2_center_time : center_time of the event's main S1 / S2
    time_since_s1, time_since_s2   : center_time minus the corresponding value

    Requires columns 'event_number', 'area', 'center_time', 's1' and 's2'.
    Rows are matched to events by 'event_number', so the row order of d_p
    does not matter. d_p is modified in place and also returned.

    Raises ValueError if an event does not have exactly one peak whose area
    equals its 's1' (or 's2') value; d_p is then left unmodified.
    '''
    # Compute both before touching d_p so a failed check leaves it unchanged.
    s1_times = _main_peak_center_times(d_p, 's1')
    s2_times = _main_peak_center_times(d_p, 's2')

    d_p['s1_center_time'] = s1_times
    d_p['s2_center_time'] = s2_times

    d_p['time_since_s1'] = d_p['center_time'] - d_p['s1_center_time']
    d_p['time_since_s2'] = d_p['center_time'] - d_p['s2_center_time']

    return d_p