In [4]:
import hax
from hax import cuts
import pickle
from gain_extrapolator_reduced import get_gain

In [5]:
class XAMSAnalysis:
    '''
    Hold the XAMS data and the functions used to perform basic analysis on it.

    Method families:
    - ``load`` to load the data and all relevant minitrees
    - ``cut_*`` to apply a single cut (most accept ``plot=True`` to see what is
      cut and ``apply=False`` to only look without cutting)
    - ``cuts_*`` for quick application of multiple cuts
    - ``plot_*`` for standard plots
    - ``corr_*`` for corrections

    NOTE(review): the methods use ``np``, ``plt``, ``LogNorm``, ``get_trend``,
    ``my_interp`` and ``ExtraS1S2Properties``, none of which is imported in the
    visible cells; presumably they come from earlier notebook cells -- confirm.
    '''
    __version__ = '0.0.0'
    # Holds dataframe containing all data (filled by ``load``)
    d = None
    # This will list all basic cuts; every instance gets its own list in __init__
    cut_list = []

    ################################################# BASICS

    def __init__(self, filenames, processed_data_path, minitree_path):
        '''
        Store the data locations and register the default sequence of cuts.

        :param filenames: dataset name(s), handed to ``hax.minitrees.load`` by ``load``
        :param processed_data_path: used as the only entry of hax' ``main_data_paths``
        :param minitree_path: used as the only entry of hax' ``minitree_paths``
        '''
        self.filenames = filenames
        self.processed_data_path = processed_data_path
        self.minitree_path = minitree_path
        # Order matters: ``cuts_apply_all`` runs these from first to last.
        self.cut_list = [
            self.cut_interaction_exists,
            self.cut_thresholds,
            self.cut_low_energy,
            self.cut_largest_other_s1,
            self.cut_largest_other_s2,
            self.cut_saturation,
            self.cut_s2_aft,
            self.cut_s1_aft,
            self.cut_drift_time,
        ]

    def load(self):
        '''
        Initialise hax for XAMS and load the data into the dataframe ``self.d``.

        Besides the minitree columns this sets ``cs1``/``cs2`` (initially plain
        copies of ``s1``/``s2``) and ``t``, the event time relative to the first
        loaded event.
        '''
        hax.init(
            # Always use these lines to tell hax that we don't care about Xe1T
            experiment='XAMS',
            pax_version_policy='loose', use_runs_db=False,
            # Here come the useful settings
            main_data_paths=[self.processed_data_path],
            minitree_paths=[self.minitree_path],
        )
        # Load data
        self.d = hax.minitrees.load(self.filenames, ['Fundamentals', 'Basics', ExtraS1S2Properties])
        # Recompute drift time to us
        self.d['drift_time'] = self.d['drift_time'] * 1e-3
        # Corrected signals start out uncorrected; the corr_* methods update them.
        self.d['cs1'] = self.d['s1']
        self.d['cs2'] = self.d['s2']
        # Time since the first event (presumably ns -> s; plots label it in seconds)
        self.d['t'] = (self.d['event_time'] - self.d['event_time'].values[0]) * 1e-9

    ################################################# PLOTTING

    def plot_s1s2(self, **kwargs):
        '''
        Make an S1-S2 2d histogram; ``kwargs`` are passed on to ``plt.hist2d``.
        '''
        plt.hist2d(self.d['s1'], self.d['s2'], **kwargs)
        plt.xlabel('S1 (p.e.)')
        plt.ylabel('S2 (p.e.)')

    def plot_cs1cs2(self, **kwargs):
        '''
        Make a cS1-cS2 2d histogram; ``kwargs`` are passed on to ``plt.hist2d``.
        '''
        plt.hist2d(self.d['cs1'], self.d['cs2'], **kwargs)
        plt.xlabel('cS1 (p.e.)')
        plt.ylabel('cS2 (p.e.)')

    def plot_cs1bs2(self, **kwargs):
        '''
        Make a 2d histogram of cS1 versus ``s2_bot``; ``kwargs`` go to ``plt.hist2d``.

        The ``s2_bot`` column is created by ``corr_pmtgains``.
        '''
        plt.hist2d(self.d['cs1'], self.d['s2_bot'], **kwargs)
        plt.xlabel('cS1 (p.e.)')
        plt.ylabel('S2 bottom (p.e.)')

    def _plot_s1s2_with_limits(self, s1_limit, s2_limit, **kwargs):
        '''
        Show the S1-S2 histogram with dashed red lines at the given S1 and S2 limits.
        '''
        self.plot_s1s2(**kwargs)
        plt.axvline(s1_limit, color='red', ls='--', lw=2)
        plt.axhline(s2_limit, color='red', ls='--', lw=2)
        plt.show()

    ################################################# CUTS

    def cuts_apply_all(self):
        '''
        Apply every cut in ``self.cut_list``, in order, with default arguments.
        '''
        for cut in self.cut_list:
            cut()

    def cut_interaction_exists(self):
        '''
        Keep only events whose ``s1`` passes ``cuts.isfinite``.
        '''
        self.d = cuts.isfinite(self.d, 's1')

    def cut_largest_other_s1(self, largest_other_s1_max=5, plot=False, apply=True, **kwargs):
        '''
        Cut on the largest other S1 in the event.

        :param largest_other_s1_max: limit handed to ``cuts.below``
        :param plot: show S1 versus largest other S1 with the limit drawn in
        :param apply: actually apply the cut to ``self.d``
        :param kwargs: passed on to ``plt.hist2d``
        '''
        if plot:
            plt.hist2d(self.d['s1'], self.d['largest_other_s1'], **kwargs)
            plt.xlabel('S1 (p.e.)')
            plt.ylabel('Largest other S1 (p.e.)')
            plt.axhline(largest_other_s1_max, ls='--', color='red', lw=2)
            plt.show()
        if apply:
            self.d = cuts.below(self.d, 'largest_other_s1', largest_other_s1_max)

    def cut_largest_other_s2(self, largest_other_s2_max=100., plot=False, apply=True, **kwargs):
        '''
        Cut on the largest other S2 in the event.

        :param largest_other_s2_max: limit handed to ``cuts.below``
        :param plot: show S2 versus largest other S2 with the limit drawn in
        :param apply: actually apply the cut to ``self.d``
        :param kwargs: passed on to ``plt.hist2d``
        '''
        if plot:
            plt.hist2d(self.d['s2'], self.d['largest_other_s2'], **kwargs)
            plt.xlabel('S2 (p.e.)')
            plt.ylabel('Largest other S2 (p.e.)')
            plt.axhline(largest_other_s2_max, ls='--', color='red', lw=2)
            plt.show()
        if apply:
            self.d = cuts.below(self.d, 'largest_other_s2', largest_other_s2_max)

    def cut_thresholds(self, s1_threshold=5., s2_threshold=100., plot=False, apply=True, **kwargs):
        '''
        Require S1 and S2 to be above their thresholds (``cuts.above``).

        :param s1_threshold: lower limit on ``s1``
        :param s2_threshold: lower limit on ``s2``
        :param plot: show the S1-S2 histogram with both thresholds drawn in
        :param apply: actually apply the cut to ``self.d``
        :param kwargs: passed on to ``plot_s1s2``
        '''
        if plot:
            self._plot_s1s2_with_limits(s1_threshold, s2_threshold, **kwargs)
        if apply:
            self.d = cuts.above(self.d, 's1', s1_threshold)
            self.d = cuts.above(self.d, 's2', s2_threshold)

    def cut_saturation(self):
        '''
        Keep events whose S1 and S2 saturated-channel counts are below 1 (``cuts.below``).
        '''
        self.d = cuts.below(self.d, 's1_n_saturated_channels', 1)
        self.d = cuts.below(self.d, 's2_n_saturated_channels', 1)

    def cut_low_energy(self, s1_max=1000, s2_max=60e3, plot=False, apply=True, **kwargs):
        '''
        Require S1 and S2 to be below their maxima (``cuts.below``).

        :param s1_max: upper limit on ``s1``
        :param s2_max: upper limit on ``s2``
        :param plot: show the S1-S2 histogram with both maxima drawn in
        :param apply: actually apply the cut to ``self.d``
        :param kwargs: passed on to ``plot_s1s2``
        '''
        if plot:
            self._plot_s1s2_with_limits(s1_max, s2_max, **kwargs)
        if apply:
            self.d = cuts.below(self.d, 's1', s1_max)
            self.d = cuts.below(self.d, 's2', s2_max)

    def cut_s2_aft(self, s2_aft_range=(0.55, 0.78), plot=False, apply=True):
        '''
        Select events whose S2 area fraction top lies in ``s2_aft_range``.

        :param s2_aft_range: (low, high) handed to ``cuts.range_selection``
        :param plot: show log10(S2) versus S2 area fraction top with both bounds drawn in
        :param apply: actually apply the cut to ``self.d``
        '''
        if plot:
            plt.hist2d(np.log10(self.d['s2']), self.d['s2_area_fraction_top'], bins=100, norm=LogNorm())
            plt.xlabel('log10 of S2/p.e.')
            plt.ylabel('S2 aft')
            for bound in s2_aft_range:
                plt.axhline(bound, color='red', ls='--', lw=2)
            plt.show()
        if apply:
            self.d = cuts.range_selection(self.d, 's2_area_fraction_top', s2_aft_range)

    def cut_s1_aft(self, plot=False, apply=True, s1_bins=20, s1_range=(0, 2000), dt_range=(0, 60), dt_bins=60):
        '''
        Cut on the S1 area fraction top relative to its trend with drift time.

        The trend of S1 AFT versus drift time is interpolated and subtracted; the
        5th and 95th percentiles of the remaining difference are then computed
        in bins of S1 and interpolated to give the lower and upper cut lines.

        Regardless of ``apply``, this adds the columns ``s1_aft_difference``,
        ``AFT_Upper`` and ``AFT_Lower`` (booleans) to ``self.d``.

        :param plot: show the drift-time trend and the percentile cut lines
        :param apply: actually apply both selections to ``self.d``
        :param s1_bins: number of S1 bins for the percentile trends
        :param s1_range: S1 range for the percentile trends
        :param dt_range: drift time range for the AFT trend (and its plot)
        :param dt_bins: number of drift time bins for the AFT trend
        '''
        # Interpolate S1 aft versus drift time
        x1, y1 = get_trend(self.d['drift_time'], self.d['s1_area_fraction_top'], dt_range, dt_bins)
        f_s1_aft = my_interp(x1, y1, kind='linear')
        # Compute the difference from the trend
        self.d['s1_aft_difference'] = self.d['s1_area_fraction_top'] - f_s1_aft(self.d['drift_time'])

        # Get the upper and lower percentiles...
        x, y_upper = get_trend(self.d['s1'], self.d['s1_aft_difference'], bins=s1_bins, x_range=s1_range,
                               mode='percentile', pct=95)
        x, y_lower = get_trend(self.d['s1'], self.d['s1_aft_difference'], bins=s1_bins, x_range=s1_range,
                               mode='percentile', pct=5)

        # ... And their interpolation...
        f_lower = my_interp(x, y_lower, kind='cubic')
        f_upper = my_interp(x, y_upper, kind='cubic')

        if plot:
            # Figure 1: S1 AFT versus drift time, over the same range the trend was made for
            plt.hist2d(self.d['drift_time'], self.d['s1_area_fraction_top'], bins=100, norm=LogNorm(),
                       range=(dt_range, (0, 1)))
            x_plot = np.linspace(dt_range[0], dt_range[1], 250)
            plt.plot(x_plot, f_s1_aft(x_plot), color='red', label='Interpolation')
            plt.scatter(x1, y1, s=5, label='Binned mean trend')
            # Raw string: '\m' is not a valid escape sequence in a normal string literal
            plt.xlabel(r'Drift time ($\mu$s)')
            plt.ylabel('S1 AFT')
            plt.legend()
            plt.show()

            # Figure 2: difference from the trend versus S1, with the cut lines
            plt.hist2d(self.d['s1'], self.d['s1_aft_difference'], bins=100,
                       range=(s1_range, (-0.5, 0.5)), norm=LogNorm())
            plt.axhline(0, color='black')
            x_plot = np.linspace(s1_range[0], s1_range[1], 400)

            plt.plot(x_plot, f_upper(x_plot), color='red', label='Interpolation')
            plt.plot(x_plot, f_lower(x_plot), color='red')
            plt.scatter(x, y_upper, s=10, color='red', label='Binned percentile')
            plt.scatter(x, y_lower, s=10, color='red')
            plt.xlabel('S1 (a.u.)')
            plt.ylabel('Difference from mean AFT')
            plt.show()

        # Add boolean variables so the selection can be inspected even if not applied
        self.d['AFT_Upper'] = (self.d['s1_aft_difference'] < f_upper(self.d['s1']))
        self.d['AFT_Lower'] = (self.d['s1_aft_difference'] > f_lower(self.d['s1']))
        if apply:
            self.d = cuts.selection(self.d, self.d['AFT_Upper'], 'AFT_Upper')
            self.d = cuts.selection(self.d, self.d['AFT_Lower'], 'AFT_Lower')

    def cut_drift_time(self, drift_time_bounds=(0, 60), plot=False, apply=True, **kwargs):
        '''
        Select events whose drift time lies within ``drift_time_bounds``.

        :param drift_time_bounds: (low, high) handed to ``cuts.range_selection``
        :param plot: show a histogram of the drift time
        :param apply: actually apply the cut to ``self.d``
        :param kwargs: passed on to ``plt.hist``
        '''
        if plot:
            plt.hist(self.d['drift_time'], **kwargs)
            plt.show()
        if apply:
            self.d = cuts.range_selection(self.d, 'drift_time', drift_time_bounds)

    ########################################### CORRECTIONS

    def corr_s1_ly(self, ly_filename='/home/erik/win/xams/analysis/light_yield/na22_ly.pickle', kind='quadratic'):
        '''
        Correct S1 light yield based on interpolation of points in pickle file.

        ``cs1`` is divided by the interpolated light yield at the event's ``z``
        and multiplied by the light yield averaged over 500 points in z between
        -10 and 0.

        :param ly_filename: pickle file holding an ``(x, y)`` pair of calibration points
        :param kind: interpolation kind handed to ``my_interp``
        '''
        # Import the correction from calibration.
        # NOTE: unpickling can execute arbitrary code; only use trusted files here.
        with open(ly_filename, 'rb') as ly_file:
            x, y = pickle.load(ly_file)
        # Interpolate light curve
        f_s1_corr = my_interp(x, y, kind=kind)
        average_ly = np.average([f_s1_corr(_z) for _z in np.linspace(-10, 0, 500)])
        self.d['cs1'] = self.d['cs1'] / f_s1_corr(self.d['z']) * average_ly

    def corr_s1_ly_poly(self, ly_filename='/home/erik/win/xams/analysis/light_yield/na22_ly_poly.pickle'):
        '''
        Correct S1 light yield with a fitted function stored in a pickle file.

        Same correction as ``corr_s1_ly``, but the light yield at depth z is
        evaluated as ``f(z, *popt)``.

        :param ly_filename: pickle file holding ``(f, popt, pcov)``; ``pcov`` is not used
        '''
        # NOTE: unpickling can execute arbitrary code; only use trusted files here.
        with open(ly_filename, 'rb') as ly_file:
            f, popt, _pcov = pickle.load(ly_file)
        average_ly = np.average([f(_z, *popt) for _z in np.linspace(-10, 0, 500)])
        self.d['cs1'] = self.d['cs1'] / f(self.d['z'], *popt) * average_ly

    def corr_s2_sag(self, cs1_range=(150, 250), cs2_cutoff=7e3, bins=10, mode='median', plot=False, apply=True):
        '''
        Correct cS2 for its drift over time (the S2 "sag" in our own jargon).

        Events with cS1 inside ``cs1_range`` and cS2 above ``cs2_cutoff`` are
        used to get the trend of cS2 versus time ``t``. The trend is linearly
        interpolated and ``cs2`` is scaled by (trend at t = 0) / (trend at t).

        :param cs1_range: (low, high) cS1 window of the events used for the trend
        :param cs2_cutoff: only events with cS2 above this enter the trend
        :param bins: number of time bins for the trend
        :param mode: trend mode handed to ``get_trend``
        :param plot: show the selection, the trend and the cS2 spectrum
        :param apply: actually apply the correction to ``self.d['cs2']``
        '''
        in_cs1_window = (self.d['cs1'] > cs1_range[0]) & (self.d['cs1'] < cs1_range[1])
        _d = self.d[in_cs1_window]
        _d2 = _d[_d['cs2'] > cs2_cutoff]
        x, y = get_trend(_d2['t'], _d2['cs2'], x_range=(min(_d2['t']), max(_d2['t'])), bins=bins, mode=mode)
        f_s2 = my_interp(x, y, kind='linear')

        if plot:
            # Figure 1: cS2 versus time with the cutoff and the fitted trend
            plt.hist2d(_d['t'], _d['cs2'], bins=10)
            plt.axhline(cs2_cutoff, color='red', lw=2, ls='--')
            plt.xlabel('Time (s)')
            plt.ylabel('cS2 (p.e.)')

            x_plot = np.linspace(0, max(_d['t']), 250)
            plt.scatter(x, y, color='red')
            plt.plot(x_plot, f_s2(x_plot), color='red')

            plt.show()

            # Figure 2: cS2 spectrum of the cS1-selected events with the cutoff
            plt.hist(_d['cs2'], bins=100, histtype='step')
            plt.axvline(cs2_cutoff, color='red', lw=2, ls='--')
            plt.xlabel('cS2 (p.e.)')
            plt.show()
        if apply:
            self.d['cs2'] = self.d['cs2'] * (f_s2(0) / f_s2(self.d['t']))

    def corr_pmtgains(self, processing_gains, voltages, verbose=False):
        '''
        Rescale all peak areas from the processing gains to the gains from ``get_gain``.

        Usage: give processing gain list such as in config, [PMT1voltage, PMT2voltage]
        Enable verbose to get a bunch of prints

        Afterwards ``cs1``/``cs2`` are reset to the gain-corrected ``s1``/``s2``,
        so any correction applied to them before is lost.

        :param processing_gains: gain list; index 3 is used for the bottom PMT, index 0 for the top PMT
        :param voltages: [PMT1voltage, PMT2voltage]
        :param verbose: print the gains and factors used
        '''
        gain_pmt1 = get_gain(1, voltages[0])
        gain_pmt2 = get_gain(2, voltages[1])
        # Get the multiplicative factors
        fbot = processing_gains[3] / gain_pmt1  # Bottom PMT is pmt 1 is channel 3
        ftop = processing_gains[0] / gain_pmt2
        if verbose:
            print('Using gains %f and %f, factors %f and %f, PMT1 and 2 respectively.'
                  % (gain_pmt1, gain_pmt2, fbot, ftop))

        # Correction for all peaks
        for peak_name in ['s1', 's2', 'largest_other_s1', 'largest_other_s2']:
            self._correct_this_peak(peak_name, fbot, ftop)

        # Set the cs1 to s1
        # I am assuming that you perform a gain correction BEFORE any other correction (fair assumption?)
        self.d['cs1'] = self.d['s1']
        self.d['cs2'] = self.d['s2']

    ########################################### AUXILARY

    def _correct_this_peak(self, peak_name, fbot, ftop):
        '''
        Apply the gain factors to one peak type, in place in ``self.d``.

        Splits the area ``<peak_name>`` into top and bottom parts using
        ``<peak_name>_area_fraction_top``, scales them by ``ftop`` and ``fbot``
        and stores them as ``<peak_name>_top`` and ``<peak_name>_bot``. The total
        area and the area fraction top are then overwritten with the corrected values.

        :param peak_name: column name of the peak area, e.g. ``'s1'``
        :param fbot: multiplicative factor for the bottom PMT
        :param ftop: multiplicative factor for the top PMT
        '''
        aft_column = peak_name + '_area_fraction_top'
        top_column = peak_name + '_top'
        bot_column = peak_name + '_bot'
        self.d[top_column] = self.d[peak_name] * self.d[aft_column] * ftop
        self.d[bot_column] = self.d[peak_name] * (1 - self.d[aft_column]) * fbot
        self.d[peak_name] = self.d[top_column] + self.d[bot_column]
        # Update the existing column; do not create a misspelled new one.
        self.d[aft_column] = self.d[top_column] / self.d[peak_name]
            