# Import packages

In [None]:
from array import array
import itertools
import time

import matplotlib.pyplot as plt
import numpy as np

import astroML.plotting as aml
import iminuit
import pandas as pd
import root_pandas as rpd
import ROOT

# Define template fit and purity

In [None]:
def haveSameLength(*args):
    """Return True if every sequence passed in has the same length.

    Called with no arguments there is nothing to compare, so the result
    is True (previously this raised IndexError).
    """
    if not args:
        return True
    n = len(args[0])
    return all(len(seq) == n for seq in args)

def normalize(x):
    """Scale x so that its entries sum to one.

    The entries are converted to a single-precision ('f') numpy array and
    divided by the sum of the original input.
    """
    values = np.array(x, dtype='f')
    total = np.sum(x)
    return values/total

class TemplateFit:
    """Two-template fit of a binned distribution.

    The data histogram is modelled as ``N*(f*signal + (1-f)*bkg)``, where
    ``signal`` and ``bkg`` are the input templates normalized to unit sum,
    ``N`` is the overall normalization and ``f`` is the signal fraction.
    The fit runs in the constructor; results are stored in ``fitN``,
    ``fitNerr``, ``fitf``, ``fitferr`` and in the scaled templates
    ``fitSignal``/``fitBkg`` (with ``fitSignalerr``/``fitBkgerr``).
    """

    def __init__(self, data, dataerr, signal, signalerr, bkg, bkgerr, binEdges, verbosity=1):
        """Store the inputs, normalize the templates and run the fit.

        data, dataerr     -- per-bin data values and their uncertainties
        signal, signalerr -- per-bin signal template and its uncertainties
        bkg, bkgerr       -- per-bin background template and its uncertainties
        binEdges          -- bin edges; one more entry than the per-bin inputs
        verbosity         -- passed to Minuit as print_level; the purity is
                             printed by getPurity when it equals 1

        Raises ValueError if the per-bin inputs differ in length or binEdges
        does not have exactly one more entry than they do.
        """
        if not haveSameLength(data, dataerr, signal, signalerr, bkg, bkgerr, binEdges[1:]):
            raise ValueError('Inputs do not have the same length (binEdges should have 1 more than the rest)')

        self.data = np.array(data, dtype='f')
        self.dataerr = np.array(dataerr, dtype='f')
        self.inputSignal = np.array(signal, dtype='f')
        self.inputSignalerr = np.array(signalerr, dtype='f')
        self.inputBkg = np.array(bkg, dtype='f')
        self.inputBkgerr = np.array(bkgerr, dtype='f')
        self.binEdges = binEdges

        # Templates (and their errors) scaled so that each template sums to one
        self.signal = self.inputSignal/np.sum(self.inputSignal)
        self.signalerr = self.inputSignalerr/np.sum(self.inputSignal)
        self.bkg = self.inputBkg/np.sum(self.inputBkg)
        self.bkgerr = self.inputBkgerr/np.sum(self.inputBkg)
        self.binCenters = np.array([(hedge+ledge)/2.0 for ledge, hedge in zip(binEdges[:-1], binEdges[1:])])
        self.binWidths = np.array([hedge-ledge for ledge, hedge in zip(binEdges[:-1], binEdges[1:])])

        self.signalColor = '#3B7EA1'
        self.bkgColor = '#FDB515'
        self.figureSize = (10, 10)

        self.verbosity = verbosity

        self.doFit()

    def doFit(self):
        """Fit N and f with Minuit (migrad) and store the fitted values and scaled templates."""
        def Chi2(N, f):
            model = N*(f*self.signal + (1-f)*self.bkg)
            # NOTE(review): this is an unweighted sum of squared residuals;
            # dataerr and the template errors are not used, so with errordef=1
            # the reported parameter errors may not be true 1-sigma errors.
            # Confirm whether a weighted chi-square was intended.
            return np.sum(np.power(self.data-model, 2.0))

        mt = iminuit.Minuit(Chi2, N=np.sum(self.data), f=0.5, error_N=1, error_f=1,
                            errordef=1, print_level=self.verbosity)
        mt.migrad()

        self.fitN = mt.values['N']
        self.fitNerr = mt.errors['N']
        self.fitf = mt.values['f']
        self.fitferr = mt.errors['f']

        # Templates scaled to the fitted yields; only the template errors are
        # scaled, the uncertainties on N and f are not propagated here.
        self.fitSignal = self.fitN*self.fitf*self.signal
        self.fitSignalerr = self.fitN*self.fitf*self.signalerr
        self.fitBkg = self.fitN*(1-self.fitf)*self.bkg
        self.fitBkgerr = self.fitN*(1-self.fitf)*self.bkgerr

    def getPurity(self, purityMin, purityMax):
        """Return (purity, pmin, pmax) for the range [purityMin, purityMax].

        pmin and pmax are the bin-edge indices bounding the range, as returned
        by the module-level getPurity. When verbosity is 1 the purity is also
        printed together with its variation for f -/+ its fitted error.
        """
        purity, pmin, pmax = getPurity(self.signal, self.bkg, self.binEdges, self.fitf, purityMin, purityMax, True)
        if self.verbosity == 1:
            # The shifted purities are only needed for the printout
            puritylow = getPurity(self.signal, self.bkg, self.binEdges, self.fitf-self.fitferr, purityMin, purityMax)
            purityhigh = getPurity(self.signal, self.bkg, self.binEdges, self.fitf+self.fitferr, purityMin, purityMax)
            print('Purity = %2.5f, +%2.5f, -%2.5f'%(purity, purityhigh-purity, purity-puritylow))
        return purity, pmin, pmax

    def plotFit(self, xlabel, title, showPurity=False, purityMin=0.0, purityMax=1.0, newFigure=True):
        """Plot the data points on top of the stacked fitted signal and background.

        With showPurity the region outside the purity range is shaded and the
        purity is appended to the title. With newFigure=False the plot is drawn
        into the current axes and plt.show() is left to the caller.
        """
        if newFigure:
            plt.figure(figsize=self.figureSize)
        plt.errorbar(self.binCenters, self.data, yerr=self.dataerr, label='Data', fmt='ko')
        plt.bar(self.binCenters, self.fitSignal, yerr=self.fitSignalerr, width=self.binWidths,
                align='center', label='Signal (MC)', capsize=0,
                color=self.signalColor, ec=self.signalColor, ecolor=self.signalColor)
        plt.bar(self.binCenters, self.fitBkg, yerr=self.fitBkgerr, bottom=self.fitSignal, width=self.binWidths,
                align='center', label='Bkg (data) + Signal (MC)', capsize=0,
                color=self.bkgColor, ec=self.bkgColor, ecolor=self.bkgColor)

        if showPurity:
            purity, pmin, pmax = self.getPurity(purityMin, purityMax)

            # pmin/pmax index binEdges and the purity covers bins pmin..pmax-1,
            # i.e. binEdges[pmin] to binEdges[pmax]; shade everything outside.
            # Indexing binEdges directly also stays valid when pmax is the
            # last edge (the default purityMax).
            ax = plt.gca()
            ax.axvspan(self.binEdges[0], self.binEdges[pmin], color='black', alpha=0.4)
            ax.axvspan(self.binEdges[pmax], self.binEdges[-1], color='black', alpha=0.4)
            plt.suptitle('%s; Purity=%2.2f'%(title, purity))
        else:
            plt.suptitle(title)

        plt.legend(numpoints=1, fontsize=12, loc='best')
        plt.xlabel(xlabel)
        plt.ylabel('Entries')

        if newFigure:
            plt.show()

    def plotResiduals(self, xlabel):
        """Plot (data - fit)/dataerr for every bin."""
        plt.figure(figsize=self.figureSize)
        fitTotal = self.fitSignal + self.fitBkg
        residuals = np.divide(self.data-fitTotal, self.dataerr)
        plt.plot(self.binCenters, residuals, 'o')
        plt.xlabel(xlabel)
        plt.ylabel('(Data - Fit)/Dataerr')
        plt.show()

    def _plotTemplatePair(self, signal, bkg, xlabel):
        """Overlay a signal and a background histogram as semi-transparent bars."""
        plt.figure(figsize=(self.figureSize[0], self.figureSize[1]/2.0))
        plt.bar(self.binCenters, signal, width=self.binWidths, align='center', label='Signal (MC)',
                color=self.signalColor, ec=self.signalColor, alpha=0.4)
        plt.bar(self.binCenters, bkg, width=self.binWidths, align='center', label='Background (data)',
                color=self.bkgColor, ec=self.bkgColor, alpha=0.4)
        plt.xlabel(xlabel)
        plt.ylabel('Entries')
        plt.legend(loc='best')
        plt.show()

    def plotTemplates(self, xlabel):
        """Overlay the signal and background templates scaled to their fitted yields."""
        self._plotTemplatePair(self.fitSignal, self.fitBkg, xlabel)

    def plotNormalizedTemplates(self, xlabel):
        """Overlay the unit-normalized signal and background templates."""
        self._plotTemplatePair(self.signal, self.bkg, xlabel)
        
def getPurity(signal, bkg, binEdges, frac, purityMin, purityMax, returnRange=False):
    """Signal purity between purityMin and purityMax for signal fraction frac.

    The range is bounded by the first bin edge >= purityMin (index pmin) and
    the last bin edge <= purityMax (index pmax); bins pmin..pmax-1 are summed.
    Returns the purity, or (purity, pmin, pmax) when returnRange is true.
    """
    # signal and bkg should be normalized to 1
    pmin = min(idx for idx, edge in enumerate(binEdges) if edge >= purityMin)
    pmax = max(idx for idx, edge in enumerate(binEdges) if edge <= purityMax)

    signalPart = frac*signal[pmin:pmax]
    bkgPart = (1-frac)*bkg[pmin:pmax]
    purity = np.sum(signalPart)/np.sum([bkgPart, signalPart])

    if not returnRange:
        return purity
    return purity, pmin, pmax

# Define function to perform fits over many datasets and/or templates

In [None]:
def getFitResults(datasets, signals, bkgs, binEdges, pmin=None, pmax=None, verbosity=0, showDistributions=False):
    """Run a TemplateFit for every (dataset, signal, bkg) combination.

    Each input histogram is given sqrt(content) as its per-bin uncertainty.
    pmin/pmax bound the purity range and default to the first/last bin edge.
    With showDistributions the fitted signal fractions, normalizations and
    purities are histogrammed and the number of fits is printed.

    Returns three array('f') objects: the fitted signal fractions, the fitted
    normalizations and the purities, one entry per fit.
    """
    fitfvals, fitNvals, purvals = array('f'), array('f'), array('f')
    if pmin is None:
        pmin = binEdges[0]
    if pmax is None:
        pmax = binEdges[-1]

    for (dataset, signal, bkg) in itertools.product(datasets, signals, bkgs):
        tf = TemplateFit(dataset, np.sqrt(dataset), signal, np.sqrt(signal),
                         bkg, np.sqrt(bkg), binEdges, verbosity=verbosity)
        fitfvals.append(tf.fitf)
        fitNvals.append(tf.fitN)
        purvals.append(tf.getPurity(pmin, pmax)[0])

    if showDistributions:
        plt.figure(figsize=(15,6))
        plt.subplot(131)
        aml.hist(fitfvals, 'knuth')
        plt.xlabel('Signal fraction')
        plt.subplot(132)
        aml.hist(fitNvals, 'knuth')
        plt.xlabel('Normalization')
        plt.subplot(133)
        aml.hist(purvals, 'knuth')
        plt.xlabel('Purity')
        plt.show()
        print('Number of results: %i'%len(fitfvals))

    return fitfvals, fitNvals, purvals

def varyWithinBins(realShape, nVariations):
    """Draw nVariations Poisson-fluctuated copies of a histogram.

    Each bin of realShape is used as the Poisson mean for that bin; the
    result has one row per variation and one column per bin.
    """
    expected = np.array(realShape)
    drawShape = (nVariations, expected.size)
    return np.random.poisson(lam=expected, size=drawShape)

## Calculate fit uncertainty due to statistical uncertainty on template

In [None]:
def calculateFitUncertainty(data, signal, bkg, binEdges, pmin=None, pmax=None, nVariations=100):
    """Estimate the spread of the fit results caused by template statistics.

    The signal and background templates are each Poisson-fluctuated
    nVariations times and the data are fitted with every combination
    (nVariations**2 fits). The distributions are plotted and the mean and
    standard deviation of the signal fraction, normalization and purity are
    printed. pmin/pmax bound the purity range (see getFitResults).
    """
    signals = varyWithinBins(signal, nVariations)
    bkgs = varyWithinBins(bkg, nVariations)
    fitfvals, fitNvals, purvals = getFitResults([data], signals, bkgs, binEdges, pmin=pmin, pmax=pmax, verbosity=0, showDistributions=True)
    print('Signal fraction: %2.3f, sigma: %2.3f'%(np.mean(fitfvals), np.std(fitfvals)))
    print('Normalization: %2.2f, sigma: %2.2f'%(np.mean(fitNvals), np.std(fitNvals)))
    print('Purity: %2.5f, sigma: %2.5f'%(np.mean(purvals), np.std(purvals)))

In [None]:
def plotFitUncertaintyExamples(data, dataerr, inputSignal, inputBkg, binEdges):
    """Show four example fits using Poisson-fluctuated templates.

    Two fluctuated copies of each template are drawn and the data are fitted
    with every signal/background pairing, one fit per panel of a 2x2 grid.
    """
    variedSignals = varyWithinBins(inputSignal, 2)
    variedBkgs = varyWithinBins(inputBkg, 2)
    templatePairs = itertools.product(variedSignals, variedBkgs)

    plt.figure(figsize=(15,12))
    for panel, (sig, bg) in enumerate(templatePairs, 1):
        plt.subplot(2, 2, panel)
        fit = TemplateFit(data, dataerr, sig, np.sqrt(sig), bg, np.sqrt(bg), binEdges, verbosity=0)
        fit.plotFit('', '', newFigure=False)
        # Drop the per-panel legend and y label that plotFit adds
        axes = plt.gca()
        axes.legend_.remove()
        axes.set_ylabel('')
    plt.show()

## Check closure of template fit

In [None]:
def checkClosure(signal, bkg, binEdges, norm, f, nDatasets, verbosity=0, showDistributions=False):
    """Fit pseudo-data generated from the templates with known N and f.

    A true shape norm*(f*signal + (1-f)*bkg) is built from the unit-normalized
    templates, nDatasets Poisson-fluctuated datasets are drawn from it and
    each is fitted with the unmodified templates.

    Returns a dict with the mean and standard deviation of the fitted signal
    fraction ('fmean', 'fsigma'), the true fraction ('ftrue') and the mean and
    standard deviation of the fitted normalization ('Nmean', 'Nsigma').
    """
    # Normalize in floating point so integer-valued templates are not
    # floor-divided (which would give all-zero shapes under Python 2)
    signalValues = np.asarray(signal, dtype=float)
    bkgValues = np.asarray(bkg, dtype=float)
    normSignal = signalValues/np.sum(signalValues)
    normBkg = bkgValues/np.sum(bkgValues)
    realShape = norm*(f*normSignal + (1-f)*normBkg)
    datasets = varyWithinBins(realShape, nDatasets)

    fitfvals, fitNvals, purvals = getFitResults(datasets, [signal], [bkg], binEdges, verbosity=verbosity, showDistributions=showDistributions)
    return {'fmean': np.mean(fitfvals), 'fsigma': np.std(fitfvals), 'ftrue': f,
            'Nmean': np.mean(fitNvals), 'Nsigma': np.std(fitNvals)}

def checkClosureOverParameters(data, signal, bkg, binEdges, nDatasets):
    """Run checkClosure over a grid of normalizations and signal fractions.

    The normalizations are the integer total of data, half, a quarter and
    twice that total; the signal fractions run from 0 to 1 in steps of 0.1.

    Returns (results, fvals, normvals), where results maps (norm, f) to the
    dict returned by checkClosure.
    """
    datanorm = int(np.sum(data))
    # Explicit floor division keeps the normalizations integral on both
    # Python 2 and Python 3
    normvals = [datanorm, datanorm//2, datanorm//4, datanorm*2]
    fvals = np.linspace(0.0, 1.0, num=11)

    results = {}
    for (norm, f) in itertools.product(normvals, fvals):
        results[(norm, f)] = checkClosure(signal, bkg, binEdges, norm, f, nDatasets)

    return results, fvals, normvals

def plotCheckClosureResults(results, fvals, normvals):
    """Plot the closure-test results, one panel per normalization.

    Figure 1 shows the mean fitted signal fraction against the true fraction;
    figure 2 shows the ratio of fitted to true normalization against the true
    fraction. Error bars are the standard deviations over the pseudo-datasets.
    The panels are laid out on a 2x2 grid, so at most four normalizations
    can be shown.
    """
    plots = {norm: [] for norm in normvals}
    # .items() works on both Python 2 and Python 3 (iteritems is Python 2 only)
    for (norm, f), result in results.items():
        plots[norm].append((result['fmean'], result['fsigma'], result['ftrue'], result['Nmean']/norm, result['Nsigma']/norm))

    plt.figure(1, figsize=(12,12))
    plt.figure(2, figsize=(12,12))

    sortedNorms = sorted(plots.keys())
    for i, norm in enumerate(sortedNorms, 1):
        fmean, fsigma, ftrue, ratiomean, ratiosigma = zip(*plots[norm])
        plt.figure(1)
        plt.subplot(2,2,i)
        plt.errorbar(ftrue, fmean, fsigma, fmt='ko')
        # Dotted diagonal: fitted fraction equal to the true fraction
        plt.plot([0,1], [0,1], 'y:')
        ax = plt.gca()
        ax.set_xlim([-0.1, 1.1])
        ax.set_ylim([-0.1, 1.1])
        ax.text(0.0, 1.0, '%i events'%norm, fontsize=20)

        plt.figure(2)
        plt.subplot(2,2,i)
        plt.errorbar(ftrue, ratiomean, ratiosigma, fmt='ko')
        # Dotted line at 1: fitted normalization equal to the true one
        plt.plot([0,1], [1,1], 'y:')
        ax = plt.gca()
        ax.set_xlim([-0.1, 1.1])
        ax.set_title('%i events'%norm, fontsize=20)

    # Figure 2 is current here, so the title applies to the normalization plots
    plt.suptitle('Fit norm/actual norm')
    plt.show()

# Directly define test histograms
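Run exactly one of the two cells below (NN data or $\lambda_0$ data). Both define the same global variables (`dataiso`, `datanoniso`, `signalmc`, `binEdges`, the error arrays, `title`, `xlabel`, `puritymin`, `puritymax`), which are used by the cells in the rest of the notebook, so whichever cell was run last determines the inputs.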

## NN data

In [None]:
# Per-bin contents of the NN test histograms, parsed from whitespace-separated
# text (100 values each). In the cells further down, dataiso is fitted as the
# data, signalmc is used as the signal template and datanoniso as the
# background template.
dataiso = [float(x) for x in '104.   73.  113.   81.  119.  147.  139.  120.  163.  220.  200.  200.\
  250.  246.  283.  296.  365.  368.  353.  333.  154.   85.   77.   68.\
   63.   71.   50.   66.   62.   51.   57.   36.   41.   38.   54.   54.\
   61.   44.   67.   62.   80.   65.   58.   48.   55.   41.   63.   62.\
   66.   68.   76.   54.   73.   80.   86.  137.  136.  186.  171.  182.\
  187.  152.  149.  169.  142.  138.  149.  147.  128.  126.  154.  146.\
  150.  161.  173.  171.  226.  328.  426.  328.  271.  170.   70.   41.\
   36.   20.    7.    5.    2.    1.    0.    2.    2.    0.    0.    3.\
    1.    1.    2.   25.'.split()]
datanoniso = [float(x) for x in '296.  155.  184.  242.  258.  262.  266.  297.  313.  380.  391.  423.\
  498.  579.  617.  708.  664.  719.  798.  681.  270.  179.  184.  177.\
  147.  133.  115.  116.  111.   93.   97.   79.   99.   85.   78.  102.\
   88.  118.  119.  132.   98.  102.   88.   97.   90.   91.   99.   87.\
   97.   99.  105.   88.  113.  109.  105.  166.  184.  221.  234.  184.\
  220.  165.  177.  166.  155.  145.  157.  139.  147.  150.  151.  158.\
  147.  157.  159.  155.  191.  303.  417.  306.  195.  148.   78.   42.\
   27.   20.    6.    9.    7.    1.    5.    6.    6.    4.    4.    4.\
    3.    7.    4.  129.'.split()]
signalmc = [float(x) for x in '176.   57.   51.   64.   80.   85.  103.   88.  119.  125.  164.  178.\
  186.  245.  275.  270.  292.  296.  378.  334.  133.  119.   91.  102.\
  103.   75.   67.   56.   77.   75.   49.   61.   72.   43.   74.   78.\
   70.   79.  111.   92.  124.  114.   98.  109.   85.   82.  105.   92.\
   92.   96.   95.  107.   92.  115.  148.  178.  205.  293.  297.  283.\
  211.  253.  245.  255.  239.  226.  233.  237.  226.  189.  246.  227.\
  227.  279.  278.  305.  335.  598.  846.  820.  737.  546.  317.  164.\
   76.   38.   32.    9.   16.   17.   16.   15.   21.   13.   16.   13.\
   16.   15.   21.  171.'.split()]
# 101 bin edges from 0.00 to 1.00 in steps of 0.01 (100 bins)
binEdges = [(x/100.0) for x in range(0,101)]

# Per-bin uncertainties: square root of each bin content
dataisoerr, datanonisoerr, signalmcerr = map(np.sqrt, (dataiso, datanoniso, signalmc))

# Plot labels and the range over which the purity is evaluated
title, xlabel = '12 < pT < 14 GeV', 'NN1'
puritymin, puritymax = 0.75, 0.85

## $\lambda_0$ data

In [None]:
# Per-bin contents of the lambda_0 test histograms, parsed from
# whitespace-separated text (100 values each). In the cells further down,
# dataiso is fitted as the data, signalmc is used as the signal template and
# datanoniso as the background template.
dataiso = [float(x) for x in ' 2.00000000e+00   0.00000000e+00   0.00000000e+00   1.00000000e+00\
   0.00000000e+00   0.00000000e+00   1.00000000e+00   6.00000000e+00\
   5.00000000e+00   9.00000000e+00   1.30000000e+01   2.00000000e+01\
   2.50000000e+01   4.40000000e+01   6.30000000e+01   8.60000000e+01\
   1.29000000e+02   2.16000000e+02   2.23000000e+02   3.05000000e+02\
   3.39000000e+02   4.26000000e+02   5.62000000e+02   7.60000000e+02\
   1.06900000e+03   9.23000000e+02   3.66000000e+02   2.94000000e+02\
   2.50000000e+02   2.10000000e+02   1.89000000e+02   1.77000000e+02\
   1.76000000e+02   1.59000000e+02   1.57000000e+02   1.34000000e+02\
   1.25000000e+02   1.51000000e+02   1.23000000e+02   1.18000000e+02\
   1.27000000e+02   1.19000000e+02   1.08000000e+02   1.06000000e+02\
   9.70000000e+01   8.30000000e+01   9.10000000e+01   9.60000000e+01\
   7.40000000e+01   7.90000000e+01   9.60000000e+01   9.00000000e+01\
   7.50000000e+01   7.60000000e+01   7.20000000e+01   1.03000000e+02\
   9.60000000e+01   8.60000000e+01   4.80000000e+01   6.50000000e+01\
   7.80000000e+01   8.60000000e+01   6.10000000e+01   7.60000000e+01\
   7.20000000e+01   6.20000000e+01   8.10000000e+01   6.30000000e+01\
   6.20000000e+01   6.00000000e+01   5.40000000e+01   5.70000000e+01\
   4.50000000e+01   4.40000000e+01   3.90000000e+01   4.00000000e+01\
   5.40000000e+01   4.20000000e+01   3.00000000e+01   4.90000000e+01\
   4.40000000e+01   3.10000000e+01   3.20000000e+01   2.50000000e+01\
   3.80000000e+01   3.20000000e+01   2.80000000e+01   2.90000000e+01\
   2.80000000e+01   2.70000000e+01   1.60000000e+01   2.20000000e+01\
   2.20000000e+01   2.20000000e+01   1.60000000e+01   2.50000000e+01\
   2.00000000e+01   2.00000000e+01   1.20000000e+01   2.30000000e+01'.split()]
datanoniso = [float(x) for x in '23.     0.     0.     0.     2.     0.     3.     4.     4.     7.\
    16.    14.    25.    40.    59.    87.   124.   151.   191.   288.\
   366.   442.   568.   694.  1095.  1019.   494.   397.   356.   383.\
   336.   298.   273.   286.   304.   266.   260.   265.   280.   215.\
   214.   255.   235.   219.   205.   201.   198.   187.   196.   205.\
   241.   179.   208.   213.   191.   196.   184.   170.   188.   185.\
   180.   151.   157.   166.   131.   149.   129.   118.   132.   118.\
   115.   112.   105.   106.    96.   108.   102.    81.    85.    89.\
    74.    71.    61.    64.    58.    70.    52.    77.    62.    36.\
    61.    56.    37.    40.    63.    49.    38.    49.    49.    39.'.split()]
signalmc = [float(x) for x in '34.     2.     3.     3.     3.     5.     3.    12.    27.    27.\
    49.    58.   105.   157.   185.   238.   319.   390.   412.   545.\
   606.   780.   838.  1175.  1781.  1293.   444.   281.   268.   252.\
   218.   244.   252.   275.   187.   172.   120.   165.   145.   138.\
   130.   145.   118.   148.   134.   114.   122.   113.   108.   105.\
   129.    99.    94.    96.   111.    98.    83.    85.    74.    80.\
    80.    68.    66.    52.    59.    49.    58.    52.    66.    37.\
    45.    45.    30.    46.    34.    38.    29.    25.    40.    40.\
    24.    54.    44.    42.    35.    33.    34.    34.    26.    42.\
    40.    24.    31.    26.    27.    22.    27.    26.    13.    22.'.split()]
# 101 bin edges from 0.00 to 1.00 in steps of 0.01 (100 bins)
binEdges = [(x/100.0) for x in range(0,101)]

# Per-bin uncertainties: square root of each bin content
dataisoerr = np.sqrt(dataiso)
datanonisoerr = np.sqrt(datanoniso)
signalmcerr = np.sqrt(signalmc)

# Plot labels and the range over which the purity is evaluated
title = '12 < pT < 14 GeV'
# Raw string: '\l' is not a valid escape sequence, so the plain literal only
# worked by accident and triggers a warning on recent Python versions
xlabel = r'$\lambda_0$'
puritymin = 0.0
puritymax = 0.25

# Run on test data

In [None]:
# Fit dataiso with signalmc as the signal template and datanoniso as the
# background template, then show the fit, the scaled templates and the residuals
tf = TemplateFit(data=dataiso, dataerr=dataisoerr,
                 signal=signalmc, signalerr=signalmcerr,
                 bkg=datanoniso, bkgerr=datanonisoerr,
                 binEdges=binEdges)
tf.plotFit(xlabel, title, showPurity=True, purityMin=puritymin, purityMax=puritymax)
tf.plotTemplates(xlabel)
tf.plotResiduals(xlabel)

In [None]:
# Spread of the fit results when both templates are Poisson-fluctuated
calculateFitUncertainty(dataiso, signalmc, datanoniso, binEdges,
                        pmin=puritymin, pmax=puritymax)

In [None]:
# Four example fits with fluctuated templates
plotFitUncertaintyExamples(data=dataiso, dataerr=dataisoerr,
                           inputSignal=signalmc, inputBkg=datanoniso,
                           binEdges=binEdges)

The two cells below should be run sequentially. They are split only because the first takes some time to run.

In [None]:
# Closure test over the normalization / signal-fraction grid, 1000 pseudo-datasets per point
closureResults = checkClosureOverParameters(dataiso, signalmc, datanoniso, binEdges,
                                            nDatasets=1000)

In [None]:
# closureResults is the (results, fvals, normvals) tuple returned by
# checkClosureOverParameters, which matches the parameter order here
plotCheckClosureResults(*closureResults)