In [1]:
import hax
from hax import cuts
import pickle
import os
from gain_extrapolator_reduced import get_gain



In [1]:
class XAMSAnalysis():
    '''
    This holds the data and functions to perform basic analysis on the XAMS data.
    Functions here are:
    - ``load`` to load the data and all relevant minitrees
    - ``cut_*`` to apply a cut (and give ``plot=True`` to see what it cuts)
    - ``cuts_*`` for quick application of multiple cuts
    - ``plot_*`` for standard plots.
    - ``corr_*`` for corrections
    '''
    __version__ = '0.0.2'
    # Holds dataframe containing all data
    d = None
    # This will list all basic cuts
    cut_list = []
    corr_s2_decrease_isapplied = False
    ################################################# BASICS
    
    def __init__(self, filenames, processed_data_path, minitree_path, include_NaI = False):
        self.filenames = filenames
        self.processed_data_path = processed_data_path
        self.minitree_path = minitree_path
        self.include_NaI = include_NaI
        # Cut list is not used for but may be used in future as a set of 'standard cuts'
        self.cut_list = [
            self.cut_interaction_exists,
            self.cut_thresholds,
            self.cut_low_energy,
            self.cut_largest_other_s1,
            self.cut_largest_other_s2,
            self.cut_saturation,
            self.cut_s2_aft,
            self.cut_s1_aft,
            self.cut_drift_time
        ]
        self.start_time = 0.
        self.end_time = 0.
        return
    
    def load(self, verbose=True, treemakers=None, **kwargs):
        '''
        Initialise hax for XAMS and read all datasets into the dataframe ``self.d``.

        - ``verbose``: print the number of events and the live time at the end
        - ``treemakers``: treemakers to load; if empty, the defaults are used
          (with the NaI treemaker added when ``include_NaI`` is set)
        - ``kwargs``: passed on to ``hax.minitrees.load``

        Columns added or overwritten after loading:
        - ``drift_time``: multiplied by 1e-3 (conversion to us)
        - ``cs1``, ``cs2``: set identical to ``s1`` and ``s2``
        - ``s2_bot``, ``cs2b``: the part of ``s2`` outside the area fraction top
        - ``t``: time in s since the first event
        - ``ms_since_previous_event``: zero (worst case) for the very first event
        - ``NaI_energy``: scaled ``NaI_area``, only when ``include_NaI`` is set
        - ``run_number``: dataset index, recomputed from jumps in ``event_number``
        Finally ``add_dataset_props`` is called to fill live time, event count and rate.
        '''
        hax.init(
            # Always use these settings to tell hax that we don't care about Xe1T
            experiment='XAMS', pax_version_policy='loose', use_runs_db=False,
            # Here come the useful settings
            main_data_paths=[self.processed_data_path],
            minitree_paths=[self.minitree_path],
        )

        if not treemakers:
            # No treemakers given: fall back to the defaults
            treemakers = ['Fundamentals', 'Basics', ExtraS1S2Properties]
            if self.include_NaI:
                treemakers.append(NaIProperties)

        self.d = hax.minitrees.load(self.filenames, treemakers, **kwargs)
        data = self.d  # same object, just a shorter name

        #################### EXTRA PROPERTIES COMPUTED UPON LOADING DATA
        data['drift_time'] = data['drift_time'] * 1e-3
        for corrected, raw in (('cs1', 's1'), ('cs2', 's2')):
            data[corrected] = data[raw]
        # Both bottom columns start out as the uncorrected bottom area
        for bottom_column in ('s2_bot', 'cs2b'):
            data[bottom_column] = data['s2'] * (1 - data['s2_area_fraction_top'])
        first_event_time = data['event_time'].values[0]
        data['t'] = (data['event_time'] - first_event_time) * 1e-9
        data['ms_since_previous_event'] = np.concatenate(([0], np.diff(data.event_time))) * 1e-6
        if self.include_NaI:
            # Calibration of NaI (presumably the 511 keV line at area 1287.81 -- confirm)
            data['NaI_energy'] = data['NaI_area'] * 511 / 1287.8107708648045

        # A new dataset starts wherever the event number does not simply increase by one
        jump_positions = np.where(np.diff(data['event_number']) != 1)[0]
        boundaries = np.concatenate([np.array([0]), jump_positions + 1, np.array([len(data)])])
        dset_lengths = np.diff(boundaries)
        if len(dset_lengths) != len(self.filenames):
            print('Warning: auto-computing of dataset index failed.')
        data['run_number'] = np.concatenate([np.ones(n_in_dset, dtype=int) * dset_index
                                             for dset_index, n_in_dset in enumerate(dset_lengths)])
        self.add_dataset_props()

        if verbose:
            print('Loaded %d (%.1f k) events.' % (self.n_events, self.n_events * 0.001))
            print('Total live time: %.1f seconds (%.1f hours)' % (self.livetime, self.livetime / 3600.))
    
    ################################################# ADDING PROPERTIES         
    def add_g1g2_props(self, pickle_file = '/home/erik/win/xams/analysis/light_yield/data/doke_sel2.pickle'):
        '''
        Add properties based on g1 and g2.
        '''
        popt_doke = pickle.load(open(pickle_file, 'rb'))
        g1, g2 = popt_doke
        # Number of gammas
        self.d['n_g'] = self.d['cs1'] / g1
        self.d['n_e'] = self.d['cs2b'] / g2
        self.d['n_quanta'] = self.d['n_g'] + self.d['n_e']
        self.d['f_g'] = self.d['n_g'] / (self.d['n_quanta']) # gamma fraction
        self.d['e_ces'] = 13.7e-3 * (self.d['cs1'] / g1 + self.d['cs2b'] / g2)
        return
        
    def add_dataset_props(self):
        '''
        Add the total run livetime. Warning: run only on data with high enough event rate and uncut!
        '''
        time = 0.
        for rn in np.unique(self.d['run_number']):
            time += (max(self.d[self.d['run_number'] == rn]['t']) - min(self.d[self.d['run_number'] == rn]['t']))
        self.livetime = time
        self.n_events = len(self.d)
        self.rate = self.n_events / self.livetime
        # Unix timestamps of first and last event
        self.start_time = self.d['event_time'].values[0] * 1e-9
        self.end_time = self.d['event_time'].values[-1] * 1e-9
        return
        
    def add_s1_waveforms(self, verbose = False, cache_path = '/home/erik/win/data/xams_run8/cache/'):
        '''
        Add the column ``s1_pulse`` with the S1 pulse shape of every event still in the dataframe.

        - ``verbose``: print each dataset name while it is being loaded
        - ``cache_path``: directory of the per-dataset ``<name>_S1Pulse.cache`` files. Building these
          takes time; once they exist loading is much faster.
        '''
        pulses_per_dataset = []
        for dataset_index, dataset_name in enumerate(self.filenames):
            # Only events that survived the cuts so far need a waveform
            in_this_dataset = self.d['run_number'] == dataset_index
            kept_event_numbers = self.d[in_this_dataset]['event_number'].values
            if verbose:
                print(dataset_name)
            cache_file = os.path.join(cache_path, '%s_S1Pulse.cache' % dataset_name)
            pulses = hax.minitrees.load(datasets=dataset_name, treemakers=[S1Pulse],
                                        cache_file=cache_file)
            # Throw out events not in dataframe
            still_present = pulses['event_number'].isin(kept_event_numbers)
            pulses_per_dataset.append(pulses[still_present])
        all_pulses = pd.concat(pulses_per_dataset)
        # Assigned by position: relies on the same event order as in ``self.d``
        self.d['s1_pulse'] = all_pulses.s1_pulse.values
    
    def add_percentiles(self, fractions_desired = [0.1, 0.2, 0.3, 0.4]):
        '''
        Compute the point in the S1 waveform where 10, 20, 30 and 40 percent of the area is reached.
        Adds properties to the dataframe: `s1_X0p_point` and `s1_fraction_outside_pulse`.
        Time in ns since start of pulse (pulse[0] = 0 ns)
        '''
        s1_percentile_points = []
        s1_area_fraction_outside_pulse = []


        for i, ev in self.d.iterrows():
            pulse = np.array(ev['s1_pulse'])
            area_times = np.ones(4) * float('nan')
            area_tot = ev['s1']
            self._integrate_until_fraction(pulse, area_tot, fractions_desired=fractions_desired, results=area_times)
            s1_percentile_points.append(2. * area_times) # Correct for 2 ns here
            s1_area_fraction_outside_pulse.append(1 - np.sum(pulse) / ev['s1'])

        s1_percentile_points = np.array(s1_percentile_points)
        for i, fraction in enumerate(fractions_desired):
            self.d['s1_%d_percentile_point' % (fraction * 100)] = s1_percentile_points[:, i]        
        self.d['s1_fraction_outside_pulse'] = s1_area_fraction_outside_pulse
        return 
        
    ################################################# PLOTTING
    def plot_s1s2(self, **kwargs):
        '''
        Draw the 2d histogram of S2 against S1. All ``kwargs`` go to ``plt.hist2d``.
        '''
        s1, s2 = self.d['s1'], self.d['s2']
        plt.hist2d(s1, s2, **kwargs)
        plt.xlabel('S1 (p.e.)')
        plt.ylabel('S2 (p.e.)')

    def plot_s1bs2(self, **kwargs):
        '''
        Draw the 2d histogram of the bottom S2 (``s2_bot``) against S1.
        All ``kwargs`` go to ``plt.hist2d``.
        '''
        s1, s2_bottom = self.d['s1'], self.d['s2_bot']
        plt.hist2d(s1, s2_bottom, **kwargs)
        plt.xlabel('S1 (p.e.)')
        plt.ylabel('Bottom S2 (p.e.)')
    
    def plot_cs1cs2(self, **kwargs):
        '''
        Draw the 2d histogram of cS2 against cS1. All ``kwargs`` go to ``plt.hist2d``.
        '''
        cs1, cs2 = self.d['cs1'], self.d['cs2']
        plt.hist2d(cs1, cs2, **kwargs)
        plt.xlabel('cS1 (p.e.)')
        plt.ylabel('cS2 (p.e.)')
    
    def plot_cs1cs2b(self, **kwargs):
        '''
        Draw the 2d histogram of cS2b against cS1. All ``kwargs`` go to ``plt.hist2d``.
        '''
        cs1, cs2b = self.d['cs1'], self.d['cs2b']
        plt.hist2d(cs1, cs2b, **kwargs)
        plt.xlabel('cS1 (p.e.)')
        plt.ylabel('cS2b (p.e.)')
    
    def plot_cs1bs2(self, **kwargs):
        '''
        Draw the 2d histogram of the bottom S2 (``s2_bot``) against cS1.
        All ``kwargs`` go to ``plt.hist2d``.
        '''
        cs1, s2_bottom = self.d['cs1'], self.d['s2_bot']
        plt.hist2d(cs1, s2_bottom, **kwargs)
        plt.xlabel('cS1 (p.e.)')
        plt.ylabel('S2 bottom (p.e.)')
    
    def plot_e_line(self, e, pickle_file = '/home/erik/win/xams/analysis/light_yield/data/doke_sel2.pickle',
                    s1_range=(0,2e3), y_axis = 'cs2b', s2_aft = None, **kwargs):
        popt_doke = pickle.load(open(pickle_file, 'rb'))
        g1, g2 = popt_doke
        x_plot = np.linspace(*s1_range, num = 100)
        if y_axis == 'cs2b':
            y_plot = g2 * (e / 13.7e-3 - x_plot / g1)
        elif y_axis == 'cs2':
            y_plot = g2 * (e / 13.7e-3 - x_plot / g1) * 1 / (1 - s2_aft)
        plt.plot(x_plot, y_plot, **kwargs)
        return
        
    def plot_cs1_rate(self, bin_width=1, cs1_max = 400, **kwargs):
        '''
        Plot the differential rate (counts per bin divided by the live time) as a function of cS1.

        - ``bin_width``: bin width in p.e.
        - ``cs1_max``: upper edge of the histogram, which starts at 0
        - ``kwargs``: passed to ``plt.plot``

        Returns the bin centers and the raw (not normalised) counts per bin.
        '''
        n_bins = int(round(cs1_max / bin_width))
        counts, bin_edges = np.histogram(self.d['cs1'], bins=n_bins, range=(0, cs1_max))
        bins_cs1 = 0.5 *(bin_edges[1:] + bin_edges[:-1])

        # ``drawstyle`` replaces ``ls='steps'``, which current matplotlib no longer accepts
        plt.plot(bins_cs1, 1/self.livetime * counts, drawstyle='steps', **kwargs)
        plt.ylim(0,)
        plt.xlabel('cS1 (p.e.)')
        if bin_width == 1:
            plt.ylabel('Differential rate (s$^{-1}$p.e.$^{-1}$)')
        else:
            plt.ylabel('Differential rate (s$^{-1}$(%.1f p.e.)$^{-1}$)' % bin_width)
        return bins_cs1, counts
    
    def plot_ces_rate(self, bin_width=1, ces_max = 400, **kwargs):
        '''
        Plot the differential rate as a function of the combined energy ``e_ces``.

        - ``bin_width``: bin width in keV
        - ``ces_max``: upper edge of the histogram, which starts at 0
        - ``kwargs``: passed through ``_plot_rate_hist`` to ``plt.plot``

        Returns the bin centers and the counts divided by the live time.
        '''
        # Forward the plot options; they used to be dropped silently
        bins, normed_counts = self._plot_rate_hist(key='e_ces', key_max = ces_max, bin_width = bin_width,
                                                   **kwargs)
        plt.xlabel('CES (keV)')
        if bin_width == 1:
            plt.ylabel('Differential rate (s$^{-1}$keV$^{-1}$)')
        else:
            plt.ylabel('Differential rate (s$^{-1}$(%.1f keV)$^{-1}$)' % bin_width)
        return bins, normed_counts
    
    def _plot_rate_hist(self, key, key_max, bin_width, **kwargs):
        '''
        Histogram column ``key`` between 0 and ``key_max`` and plot counts divided by the live time.

        - ``bin_width``: bin width, in the units of ``key``
        - ``kwargs``: passed to ``plt.plot``

        Returns the bin centers and the counts divided by the live time.
        '''
        n_bins = int(round(key_max / bin_width))
        counts, bin_edges = np.histogram(self.d[key], bins=n_bins, range=(0, key_max))
        bins = 0.5 *(bin_edges[1:] + bin_edges[:-1])
        rate = 1/self.livetime * counts

        # ``drawstyle`` replaces ``ls='steps'``, which current matplotlib no longer accepts
        plt.plot(bins, rate, drawstyle='steps', **kwargs)
        plt.ylim(0,)
        plt.xlim(0,)
        return bins, rate
        
    ################################################# CUTS
    
    def cuts_apply_all(self):
        for cut in self.cut_list:
            cut()
        return None
    
    def cuts_history(self):
        '''
        Return what ``cuts.history`` reports for the dataframe.
        '''
        history = cuts.history(self.d)
        return history
    
    
    ################################################# INDIVIDUAL CUTS
        
    def cut_interaction_exists(self, plot=False, apply=True):
        if apply:
            self.d = cuts.isfinite(self.d, 's1')
        return
    
    def cut_time_since_previous(self, max_time_ms=1., plot=False, apply=True, **kwargs):
        if plot:
            plt.hist(self.d['ms_since_previous_event'], **kwargs)
            plt.xlabel('Time since previous event (ms)')
            plt.ylabel('Counts')
            plt.axvline(max_time_ms)
        if apply:
            self.d = cuts.above(self.d, 'ms_since_previous_event', max_time_ms)
        return
    
    def cut_largest_other_s1(self, largest_other_s1_max=5, plot=False, apply=True, **kwargs):
        if plot:
            plt.hist2d(self.d['s1'], self.d['largest_other_s1'], **kwargs)
            plt.xlabel('S1 (p.e.)')
            plt.ylabel('Largest other S1 (p.e.)')
            plt.axhline(largest_other_s1_max, ls='--', color='red', lw=2)
        if apply: self.d = cuts.below(self.d, 'largest_other_s1', largest_other_s1_max)
        return

    def cut_largest_other_s2(self, largest_other_s2_max = 100., plot=False, apply=True, **kwargs):
        if plot:
            plt.hist2d(self.d['s2'], self.d['largest_other_s2'], **kwargs)
            plt.xlabel('S2 (p.e.)')
            plt.ylabel('Largest other S2 (p.e.)')
            plt.axhline(largest_other_s2_max, ls='--', color='red', lw=2)
        if apply: self.d = cuts.below(self.d, 'largest_other_s2', largest_other_s2_max)
        return
    
    def cut_thresholds(self, s1_threshold=5., s2_threshold=100., plot=False, apply=True, **kwargs):
        if plot:
            self.plot_s1s2(**kwargs)
            plt.axvline(s1_threshold, color='red', ls='--', lw=2)
            plt.axhline(s2_threshold, color='red', ls='--', lw=2)
            plt.show()
        if apply:
            self.d = cuts.above(self.d, 's1', s1_threshold)
            self.d = cuts.above(self.d, 's2', s2_threshold)
        return
    
    def cut_saturation(self, plot=False, apply=True):
        if plot:
            sat = (self.d['s1_n_saturated_channels'] > 0) | (self.d['s2_n_saturated_channels'] > 0)
            nonsat = np.invert(sat)
            plt.scatter(self.d['s1'][nonsat], self.d['s2'][nonsat], color='blue',
                        edgecolor='None', s=2, label='Not saturated')
            plt.scatter(self.d['s1'][sat], self.d['s2'][sat], color='red',
                        edgecolor='None', s=10, label='ADC saturated')
            plt.xlabel('S1 (p.e.)')
            plt.ylabel('S2 (p.e.)')
            plt.legend(loc='best')
            
        if apply:
            self.d = cuts.below(self.d, 's1_n_saturated_channels', 1)
            self.d = cuts.below(self.d, 's2_n_saturated_channels', 1)
        return

    
    def cut_low_energy(self, cs1_max = 200, cs2_max = 30e3, plot=False, apply=True, **kwargs):
        if plot:
            self.plot_cs1cs2(**kwargs)
            plt.axvline(cs1_max, color='red', ls='--', lw=2)
            plt.axhline(cs2_max, color='red', ls='--', lw=2)
            plt.show()
        if apply:
            self.d = cuts.below(self.d, 'cs1', cs1_max)
            self.d = cuts.below(self.d, 'cs2', cs2_max)
        return
        
    def cut_s2_aft(self,  s2_aft_range=(0.55, 0.78), plot=False, apply=True):
        if plot:
            plt.hist2d(np.log10(self.d['s2']), self.d['s2_area_fraction_top'], bins=100, norm=LogNorm())
            plt.xlabel('log10 of S2/p.e.')
            plt.ylabel('S2 aft')
            for _l in s2_aft_range:
                plt.axhline(_l, color='red', ls='--', lw=2)
            plt.show()
        if apply: self.d = cuts.range_selection(self.d, 's2_area_fraction_top', s2_aft_range)
        return
    
    def cut_s1_aft(self, plot=False, apply=True, s1_bins=20, s1_range=(0, 2000), dt_range=(0,60), dt_bins=60):
        '''
        Cut on the S1 area fraction top (AFT), relative to its trend with drift time.

        Procedure:
        - bin the S1 AFT in drift time (``dt_range``, ``dt_bins``) and interpolate linearly;
        - store the deviation from that trend in the column ``s1_aft_difference``;
        - in bins of S1 (``s1_range``, ``s1_bins``) take the 5th and 95th percentile of the
          deviation and interpolate both (cubic);
        - store the boolean columns ``AFT_Upper`` and ``AFT_Lower``. These and
          ``s1_aft_difference`` are written even when ``apply`` is False;
        - if ``apply``, select on both boolean columns with ``cuts.selection``.

        With ``plot`` both steps are shown.
        '''
        # Trend of the S1 AFT versus drift time
        x1, y1 = get_trend(self.d['drift_time'], self.d['s1_area_fraction_top'], dt_range, dt_bins)
        f_s1_aft = my_interp(x1, y1, kind='linear')
        # Compute the difference from the trend
        self.d['s1_aft_difference'] = self.d['s1_area_fraction_top'] - f_s1_aft(self.d['drift_time'])

        # Get the upper and lower percentiles...
        x, y_upper = get_trend(self.d['s1'], self.d['s1_aft_difference'], bins=s1_bins, x_range=s1_range,
                               mode='percentile', pct=95)
        x, y_lower = get_trend(self.d['s1'], self.d['s1_aft_difference'], bins=s1_bins, x_range=s1_range,
                               mode='percentile', pct=5)

        # ... And their interpolation...
        f_lower = my_interp(x, y_lower, kind='cubic')
        f_upper = my_interp(x, y_upper, kind='cubic')

        if plot:
            # Show the range that was actually fitted (used to be hard-coded to 0-60)
            plt.hist2d(self.d['drift_time'], self.d['s1_area_fraction_top'], bins=100, norm=LogNorm(),
                       range=(dt_range, (0, 1)))
            x_plot = np.linspace(dt_range[0], dt_range[1], 250)
            plt.plot(x_plot, f_s1_aft(x_plot), color='red', label='Interpolation')
            plt.scatter(x1, y1, s=5, label='Binned mean trend')
            # Raw string: '\m' is not a valid escape sequence in a normal string
            plt.xlabel(r'Drift time ($\mu$s)')
            plt.ylabel('S1 AFT')
            plt.legend()
            plt.show()

            plt.hist2d(self.d['s1'], self.d['s1_aft_difference'], bins=100,
                       range=(s1_range, (-0.5, 0.5)), norm=LogNorm())
            plt.axhline(0, color='black')
            x_plot = np.linspace(s1_range[0], s1_range[1], 400)

            plt.plot(x_plot, f_upper(x_plot), color='red', label='Interpolation')
            plt.plot(x_plot, f_lower(x_plot), color='red')
            plt.scatter(x, y_upper, s=10, color='red', label='Binned percentile')
            plt.scatter(x, y_lower, s=10, color='red')
            plt.xlabel('S1 (a.u.)')
            plt.ylabel('Difference from mean AFT')
            plt.show()

        # Add boolean variables
        self.d['AFT_Upper'] = (self.d['s1_aft_difference'] < f_upper(self.d['s1']))
        self.d['AFT_Lower'] = (self.d['s1_aft_difference'] > f_lower(self.d['s1']))
        if apply:
            self.d = cuts.selection(self.d, self.d['AFT_Upper'], 'AFT_Upper')
            self.d = cuts.selection(self.d, self.d['AFT_Lower'], 'AFT_Lower')
        return
        
    def cut_drift_time(self, drift_time_bounds=(0, 60), plot=False, apply=True, **kwargs):
        print('Warning: this cut is depricated, please use the fiducial volume cut!')
        if plot:
            plt.hist(self.d['drift_time'], **kwargs)
            for _l in drift_time_bounds:
                plt.axvline(_l, color='red', ls='--', lw=2)
        if apply:
            self.d = cuts.range_selection(self.d, 'drift_time', drift_time_bounds)
        return
    
    def cut_fiducial_volume(self, z_bounds=(-9.5, -0.5), plot=False, apply=True, **kwargs):
        if plot:
            plt.hist(self.d['z'], **kwargs)
            for _l in z_bounds:
                plt.axvline(_l, color='red', ls='--', lw=2)
        if apply:
            self.d = cuts.range_selection(self.d, 'z', z_bounds)
        return
        
    def cut_s2_range_50p_area(self, 
                              mode = 'auto',
                              dt_range = (0, 60),
                              nbins = 30,
                              dt_cutoff = 30,
                              pickle_file = '../light_yield/data/cs137_s2_width.pickle', cutoff=251.5,
                              plot=False, apply=True, verbose=True, **kwargs):
        '''
        Cut all events width a width inconsistent with drift time. S2 width model from Jelle, fit parameters from 
        pickle file in argument (mode == 'pickle') or fit medians (mode=='auto').
        Semi-automatic mode: fit the offset, but not the diffusion.
        
        Fitting options:
          `dt_range` sets the range where to bin
          `nbins` sets the number of bins
          `dt_cutoff` is the lower values for the cutoff.
        June 2017 / August 2017 / October 2017
        '''
        
        if (mode == 'auto' or mode == 'semi') | (plot):
            # We need the medians either for plotting or for computing diffusion...
            dt, s2_50 = get_trend(self.d['drift_time'], self.d['s2_range_50p_area'], dt_range, bins=nbins, mode='median')
      
        if mode == 'auto':
            # Full-auto mode: fix neither initial width nor diffusion
            syst_err = np.sqrt(0.1**2 + 0.4**2)
            popt_s2_w, pcov_s2_w = scipy.optimize.curve_fit(self.s2_width_model_t, dt[dt >= dt_cutoff], 
                                                            s2_50[dt >= dt_cutoff],
                                                            p0=[10, 250e-9])
            # Errors...
            perr_s2_w = np.sqrt(np.diag(pcov_s2_w))
            total_err = np.sqrt(perr_s2_w[0]**2 + syst_err**2)
            if verbose:
                print('Parameters found: ', popt_s2_w)
                print('Diffusion is %.2f +- %.2f +- %.2f cm**2 / s' % (popt_s2_w[0], perr_s2_w[0], syst_err))
                print('Diffusion is %.2f +- %.2f cm**2 / s' % (popt_s2_w[0], total_err))
        elif mode == 'pickle':
            # Full-pickle mode: read the parameters from file.
            popt_s2_w, pcov_s2_w = pickle.load(open(pickle_file, 'rb'))
        elif mode == 'semi':
            # Semi-auto mode: read the diffusion constant from file, but not the initial width.
            popt_s2_w, pcov_s2_w = pickle.load(open(pickle_file, 'rb'))
            diff_const = popt_s2_w[0]
            if verbose: 
                print('Using diffusion constant %.2f cm**2 / s' % (diff_const))
            # Define the function to optimize: the diffusion model with diffusion constant fixed.
            def width_model_diffusion_fixed(t, w0):
                return self.s2_width_model_t(t, diff_const, w0)
            
            popt_s2_w0, pcov_s2_w0 = scipy.optimize.curve_fit(width_model_diffusion_fixed, dt[dt >= dt_cutoff], 
                                                            s2_50[dt >= dt_cutoff],
                                                            p0=[250e-9])
            popt_s2_w = np.concatenate([[diff_const], popt_s2_w0])
            if verbose: print('w0 found: %.2f ns' % (popt_s2_w[1] * 1e9))
            
        # Now that we have the width model, compute the difference to the model.
        self.d['s2_width_difference'] = self.d['s2_range_50p_area'] - self.s2_width_model_t(self.d['drift_time'],
                                                                                          *popt_s2_w)

        if plot:
            plt.hist(self.d['s2_width_difference'], bins=230, histtype='step', range=(-300, 2000))
            plt.axvline(cutoff, ls='--', color='red')
            plt.yscale('log')
            plt.xlabel(r'Difference from S2 width model (\si{\nano s})')
            plt.ylabel(r'Counts / (\SI{10}{ns})')
            plt.xlim(-300, 2000)
            plt.show()
            
            plt.hist2d(self.d['drift_time'], self.d['s2_range_50p_area'], bins=100, range=((0, 62), (0, 1e3)),
                       norm=LogNorm(), **kwargs)
            plt.colorbar(label='Counts')
            plt.xlabel(r'Drift time (\si{\micro s})')
            plt.ylabel(r'S2 width (\si{\nano s})')
            x_plot = np.linspace(0, 62, 100)
            plt.plot(x_plot, self.s2_width_model_t(x_plot, *popt_s2_w), label='Width model')
            plt.plot(x_plot, self.s2_width_model_t(x_plot, *popt_s2_w) + cutoff, label='Cutoff')
            for _l in (dt_cutoff, dt_range[1]):
                plt.axvline(_l, ls='--', color='gray')
            plt.scatter(dt, s2_50)
            plt.legend(framealpha=0.9, loc = 'lower right')
            plt.show()
        if apply:
            self.d = cuts.below(self.d, 's2_width_difference', cutoff)
        return popt_s2_w
    
    def cut_s2_range_50p_area_low(self, cutoff = -50, apply=True, plot=False):
        '''
        Cut 
        '''
        if 's2_width_difference' not in self.d.keys():
            raise RuntimeWarning('Cut not applied, run cut_s2_range_50p_area first...')
            return
        if apply:
            self.d = cuts.above(self.d, 's2_width_difference', cutoff)
        return
    
    def cut_NaI_interaction_exists(self):
        if not self.include_NaI:
            print('Warning: not applying NaI cut since you disabled it!')
            return
        self.d = cuts.isfinite(self.d, 'NaI_area')
        
        
    ########################################### CORRECTIONS
    
    def corr_s1_ly(self, ly_filename='/home/erik/win/xams/analysis/light_yield/na22_ly.pickle', kind='quadratic'):
        '''
        Correct S1 light yield based on interpolation of points in pickle file.

        - ``ly_filename``: pickle holding the pair (x, y) of points of the light yield curve,
          which is evaluated at ``z``
        - ``kind``: interpolation kind passed to ``my_interp``

        ``cs1`` is divided by the curve at the event's ``z`` and multiplied by the average of the
        curve over 500 points between z = -10 and z = 0.
        NOTE(review): this starts from the current ``cs1`` (``corr_s1_ly_poly`` starts from ``s1``),
        so calling it twice applies the correction twice -- confirm that this is intended.
        '''
        # NOTE: unpickling runs arbitrary code, so only point this at trusted files.
        with open(ly_filename, 'rb') as f:
            x, y = pickle.load(f)
        # Interpolate light curve
        f_s1_corr = my_interp(x, y, kind=kind)
        average_s1 = np.average([f_s1_corr(_z) for _z in np.linspace(-10, 0, 500)])
        self.d['cs1'] = self.d['cs1'] / f_s1_corr(self.d['z']) * average_s1

    def corr_s1_ly_poly(self, ly_filename='data/cs137_ly_p2_rough.pickle', order=2):
        if order != 2:
            raise NotImplementedError('Only order-2 fits allowed for now...')
        
        # Assumes a second-degree polynomial fit.
        def p2(x, a0, a1, a2):
            return a0 + a1 * x + a2 * x**2
        
        popt, pcov = pickle.load(open(ly_filename, 'rb'))
        def get_cs1(s1, z, f):
            # Perform a volume average
            average_s1 = np.average([f(_z, *popt) for _z in np.linspace(-10, 0, 500)])
            return s1/f(z, *popt) * average_s1
        self.d['cs1']= get_cs1(self.d['s1'], self.d['z'], p2)
        return

    def corr_s2_electron_lifetime(self, lifetime=None, pickle_file=None, verbose=False):
        '''
        Enter electron lifetime in microseconds please
        '''
        if pickle_file:
            popt_life, _ = pickle.load(open(pickle_file, 'rb'))
            if lifetime:
                print('Warning: you gave both electron lifetime AND a pickle file, I will use the pickle ONLY.')
            lifetime = popt_life[1]
        if verbose: print('Using lifetime %f' % lifetime)
        if lifetime <= 0: raise ValueError('What? Negative lifetime? No, screw you! This is what I got: %f' % lifetime)
        self.d['cs2'] = self.d['s2'] * np.exp(self.d['drift_time'] / lifetime)
        self.d['cs2b'] = self.d['s2_bot'] * np.exp(self.d['drift_time'] / lifetime)
        
        
    def corr_s2_sag(self, cs1_range = (150, 250), cs2_cutoff = 7e3, time_bins=10, mode='median', plot=False, apply=True,
                   **kwargs):
        '''
        Deprecated in favour of ``corr_s2_decrease``; a warning is printed on every call.

        Correct the change of cS2 during the run ('sag': ugly word, but clear in our own jargon).
        Events with ``cs1`` inside ``cs1_range`` and ``cs2`` above ``cs2_cutoff`` are binned in time
        (``time_bins`` bins, trend computed with ``mode``) and interpolated linearly. With ``apply``
        every ``cs2`` is scaled by trend(0) / trend(t).

        With ``plot`` the selection is drawn (``kwargs`` go to ``plt.hist2d``), followed by the cS2
        spectrum before and, if applied, after the correction.
        Returns the points (x, y) of the binned trend.
        '''
        print('Warning: function is depricated. Use the `corr_s2_decrease` function please.')
        in_cs1_slice = (self.d['cs1'] > cs1_range[0]) & (self.d['cs1'] < cs1_range[1])
        _d = self.d[in_cs1_slice]
        _d2 = _d[_d['cs2'] > cs2_cutoff]
        # ``mode`` is passed on now; it used to be ignored in favour of a hard-coded 'median'
        x, y = get_trend(_d2['t'], _d2['cs2'], x_range=(min(_d2['t']), max(_d2['t'])), bins=time_bins, mode=mode)
        f_s2 = my_interp(x, y, kind='linear')

        if plot:
            plt.hist2d(_d['t'], _d['cs2'], **kwargs)
            plt.axhline(cs2_cutoff, color='red', lw=2, ls='--')
            plt.xlabel('Time (s)')
            plt.ylabel('cS2 (p.e.)')

            x_plot = np.linspace(0, max(_d['t']), 250)
            plt.scatter(x,y, color='red')
            plt.plot(x_plot, f_s2(x_plot), color='red')

            plt.show()

            plt.hist(_d['cs2'], bins=100, histtype='step', label='Before corr')
            plt.axvline(cs2_cutoff, color='red', lw=2, ls='--')
            plt.xlabel('cS2 (p.e.)')
        if apply:
            # Scale everything to the value of the trend at t = 0
            self.d['cs2'] = self.d['cs2'] * (f_s2(0) / f_s2(self.d['t']))
            if plot:
                _d = self.d[in_cs1_slice]
                plt.hist(_d['cs2'], bins=100, histtype='step', label='After corr')
        return (x, y)

#     def corr_s2_decrease(self, parameters, cs1_range = (175, 225), plot=False, apply=True, **kwargs):
#         '''
#         Use a linear correction function to correct the S2 decrease. Note that this supersedes the quick-and-dirty 
#         `corr_s2_sag` function.
#         '''
#         def lin(x, a0, a1):
#             return a0 + a1 * x
        
#         if plot:
#             # Select slice
#             _d = self.d[(self.d['cs1'] > cs1_range[0]) & (self.d['cs1'] < cs1_range[1])]
#             plt.hist2d(_d['t'], _d['cs2'], **kwargs)
#             plt.xlabel('Time (s)')
#             plt.ylabel('cS2 (p.e.)')
#             # Plot approximation
#             x_plot = np.linspace(0, max(_d['t']), 250)
#             plt.plot(x_plot, lin(x_plot, *parameters))
        
#         if apply:
#             if self.corr_s2_decrease_isapplied:
#                 print('Already applied correction, will not do it again!')
#                 return
#             self.corr_s2_decrease_isapplied = True
#             self.d['cs2'] = self.d['cs2'] * lin(0, *parameters) / (lin(self.d['t'], *parameters))
#         return
    
    def corr_s2_decrease(self, f, apply=True):
        if apply:
            if self.corr_s2_decrease_isapplied:
                print('Already applied correction, will not do it again!')
                return
            self.corr_s2_decrease_isapplied = True
            self.d['cs2'] = self.d['cs2'] * f(self.d['t'])
        return
        
    def corr_pmtgains(self, processing_gains, voltages, verbose=False):
        '''
        Rescale the peak areas from the gains used in processing to the gains from ``get_gain``.
        Should NOT be used if the correct gains were set in processing...

        - ``processing_gains``: gain list such as in the config; entry 3 belongs to the bottom PMT
          (PMT 1), entry 0 to the top PMT (PMT 2)
        - ``voltages``: [PMT1voltage, PMT2voltage]
        - ``verbose``: print the gains and the factors

        Corrects ``s1``, ``s2``, ``largest_other_s1`` and ``largest_other_s2`` and then resets
        ``cs1`` and ``cs2`` to the new ``s1`` and ``s2``. This assumes that the gain correction is
        performed BEFORE any other correction.
        NOTE(review): ``cs2b`` is not reset here -- confirm whether it should follow ``s2_bot``.
        '''
        print('Hey, why not use the correct gain in processing?')
        gain_pmt1 = get_gain(1, voltages[0])
        gain_pmt2 = get_gain(2, voltages[1])
        # Get the multiplicative factors; bottom PMT is pmt 1 is channel 3
        fbot = processing_gains[3] / gain_pmt1
        ftop = processing_gains[0] / gain_pmt2
        if verbose:
            print('Using gains %f and %f, factors %f and %f, PMT1 and 2 respectively.'
                  % (gain_pmt1, gain_pmt2, fbot, ftop))

        # Correction for all peaks
        for peak_name in ('s1', 's2', 'largest_other_s1', 'largest_other_s2'):
            self._correct_this_peak(peak_name, fbot, ftop)

        # The corrected areas start out as the freshly rescaled ones
        for corrected, raw in (('cs1', 's1'), ('cs2', 's2')):
            self.d[corrected] = self.d[raw]
    
    def corr_z(self, offset, drift_velocity, verbose=False):
        '''
        TODO place this in the XAMS ini!
        Enter offset in microseconds, velocity in km/s
        '''
        self.d['z'] = (self.d['drift_time'] - offset) * 1e-6 * drift_velocity * 1e5 * (- 1)
        return
        
    ########################################### OTHER
    
    def len(self):
        return len(self.d)
    
    ########################################### AUXILARY
    
    def _correct_this_peak(self, peak_name, fbot, ftop):
        self.d[peak_name + '_top'] = self.d[peak_name] * self.d[peak_name + '_area_fraction_top'] * ftop
        self.d[peak_name + '_bot'] = self.d[peak_name] * (1-self.d[peak_name + '_area_fraction_top']) * fbot
        self.d[peak_name] = self.d[peak_name + '_top'] + self.d[peak_name + '_bot']
        self.d[peak_name + 'area_fraction_top'] = self.d[peak_name + '_top'] / self.d[peak_name]
        return
   
    def _integrate_until_fraction(self, w, area_tot, fractions_desired, results):
        """For each fraction in fractions_desired, integrate w until that fraction of area_tot is
        reached and place the (fractional) sample index in results.

        The last sample needed is added fractionally:
        eg. if you want 25% and a sample takes you from 20% to 30%, 0.5 will be added.
        Assumes fractions_desired is sorted, not empty, and all in [0, 1]!

        - w: the samples to integrate
        - area_tot: the area that the fractions refer to
        - results: array that is filled in place, one entry per fraction. Entries of fractions
          that are never reached are left untouched.

        Returns None once all fractions are found. If the samples run out first: when the fraction
        being looked for is exactly 1, its result is set to len(w) (and None is returned),
        otherwise -1 is returned.
        This function is stolen and modified from Pax - Erik Hogenbirk September 2017
        """
        fraction_seen = 0
        current_fraction_index = 0
        needed_fraction = fractions_desired[current_fraction_index]
        for i, x in enumerate(w):
            # How much of the area is in this sample?
            fraction_this_sample = x/area_tot
            # Will this take us over the fraction we seek?
            # Must be while, not if, since we can pass several fractions_desired in one sample
            while fraction_seen + fraction_this_sample >= needed_fraction:
                # Yes, so we need to add the next sample fractionally
                area_needed = area_tot * (needed_fraction - fraction_seen)
                if x != 0:
                    results[current_fraction_index] = i + area_needed/x
                else:
                    # An empty sample cannot be split: use its index as is
                    results[current_fraction_index] = i
                # Advance to the next fraction
                current_fraction_index += 1
                if current_fraction_index > len(fractions_desired) - 1:
                    # All fractions found, we are done
                    return
                needed_fraction = fractions_desired[current_fraction_index]
            # Add this sample's area to the area seen, advance to the next sample
            fraction_seen += fraction_this_sample
        # Out of samples with at least one fraction still missing
        if needed_fraction == 1:
            results[current_fraction_index] = len(w)
        else:
            return -1


    def s2_width_model_t(self, t, diffusion_constant, w0):
        '''
        S2 width model stolen from Jelly Monster. Great info on wiki.
        Input drift time in us
        '''
        # diffusion_constant = PAX_CONFIG['WaveformSimulator']['diffusion_constant_liquid']
        v_drift = drift_velocity_liquid = 1.68 * 10**5 # cm/s
        t = t * 1e-6 # Convert to seconds
        # w0 = 348.6 * units.ns
        # WATCH the constant: it is NOT 4.something
        return 1e9 * np.sqrt(w0 ** 2 + 3.6395 * diffusion_constant * t / v_drift ** 2)

In [3]:
def combine_frameworks(x0, x1=None):
    '''
    Combine the data from multiple XAMSAnalysis objects into one.
    Either give two objects as argument, or a list (or tuple) of objects.

    The dataframes are concatenated and the run numbers of the second object are shifted, so that
    they continue after those of the first. Live time and event count are added up and the rate
    is recomputed. Data paths are taken from the first object. A list with a single object gives
    back that object itself.

    Raises ``TypeError`` if a single argument is not a list or tuple, ``ValueError`` if it is empty.
    '''
    if x1 is None:
        if not isinstance(x0, (list, tuple)):
            raise TypeError('Give either two objects or a list of objects, got %s' % type(x0).__name__)
        if not x0:
            raise ValueError('Cannot combine an empty list of objects.')
        # Fold the list from left to right
        x_combined = x0[0]
        for x_to_add in x0[1:]:
            x_combined = combine_frameworks(x_combined, x_to_add)
        return x_combined

    run_names = np.concatenate([x0.filenames, x1.filenames])
    # NaI properties are only available in the combination if both parts have them
    x = XAMSAnalysis(run_names, x0.processed_data_path, x0.minitree_path,
                     include_NaI=bool(x0.include_NaI and x1.include_NaI))
    # Increment the run numbers of the second dataset; ``assign`` leaves ``x1.d`` itself untouched
    run_number_offset = np.max(x0.d['run_number']) + 1
    d1_shifted = x1.d.assign(run_number=x1.d['run_number'] + run_number_offset)
    x.d = pd.concat([x0.d, d1_shifted])
    # Compute meta-data
    x.livetime = x0.livetime + x1.livetime
    x.n_events = x0.n_events + x1.n_events
    x.rate = x.n_events / x.livetime
    x.start_time = min(x0.start_time, x1.start_time)
    x.end_time = max(x0.end_time, x1.end_time)

    return x