# ER / NR discrimination test

In [None]:
# Fetch the Colab setup helper notebook (wget -nc: do not re-download if the file already exists)
!wget -nc https://raw.githubusercontent.com/FlamTeam/flamedisx-notebooks/master/_if_on_colab_setup_flamedisx.ipynb
# NOTE(review): BRANCH is presumably read by the helper notebook run below -- confirm
BRANCH="master"  # git branch to use (only for Colab)
# Run the helper in this kernel's namespace
%run _if_on_colab_setup_flamedisx.ipynb

In [None]:
import gzip
import os
import pickle

import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

from multihist import Hist1d, Histdd
import numpy as np
from scipy import stats
import tensorflow as tf
from tqdm import tqdm

import flamedisx as fd

# True when TensorFlow can see at least one GPU (used to pick batch sizes).
have_gpu = bool(tf.config.list_physical_devices('GPU'))
print(f"Running tensorflow {tf.__version__}, GPU {'active' if have_gpu else 'NOT ACTIVE!'}")

## Initialization

In [None]:
# Construct toy spatial distribution of neutron background source

# Length scale of the exponential rise towards the TPC walls [cm]
tau = 5
nbins = 100

# Full dimensions of TPC (not FV!) [cm]
tpc_radius, tpc_length = 47.9, 97.6

# XENON1T 1.3t FV [cm]
fv_radius, fv_high, fv_low = 41.26, -9, -92.9

# Bin edges in cylindrical coordinates
r = np.linspace(0, fv_radius, nbins + 1)
theta = np.linspace(0, 2 * np.pi, nbins + 1)
z = np.linspace(fv_low, fv_high, nbins + 1)

# Empty (r, theta, z) histogram that will hold the toy background
spatial_h = Histdd(bins=[r, theta, z], axis_names=['r', 'theta', 'z'])
r_centers, theta_centers, z_centers = spatial_h.bin_centers()

# Grid of bin centres on which the (theta-independent) rate is evaluated;
# meshgrid's default 'xy' indexing gives arrays of shape (n_z, n_r).
rr, zz = np.meshgrid(r_centers, z_centers)

def rate(r, z):
    """Toy, unnormalized neutron-background rate at cylindrical position (r, z).

    Distances are measured to the edges of the *full* TPC (not the FV): the
    side wall at r = tpc_radius, the top at z = 0 and the bottom at
    z = -tpc_length. The rate is a small constant floor plus an exponential
    rise towards each of those edges. Works elementwise on numpy arrays.
    """
    # The side-wall length scale grows away from mid-height, which makes the
    # distribution look more rounded at the corners.
    side_tau = tau * (0.9 + ((2*z + tpc_length) / tpc_length)**2)
    floor = 5e-3

    to_side = tpc_radius - r
    to_top = -z
    to_bottom = tpc_length + z

    return (floor
            + np.exp(-to_side / side_tau)
            + np.exp(-to_top / tau)
            + np.exp(-to_bottom / tau))

vv = rate(rr, zz)

# rate() does not depend on theta: copy the same (r, z) slice into every
# theta bin at once. vv has shape (n_z, n_r) from meshgrid, hence the .T;
# the inserted axis broadcasts over theta.
spatial_h.histogram[:] = vv.T[:, np.newaxis, :]

# Remove the overall scale of the toy rate...
spatial_h.histogram /= spatial_h.n
# ...then weight each bin by its cylindrical volume (r dr dtheta dz). After
# this the histogram is NOT normalized; it holds (rate density) x (volume).
bin_volumes = spatial_h.bin_volumes() * r_centers[:, np.newaxis, np.newaxis]
spatial_h.histogram *= bin_volumes

# Sanity check: draw positions from the histogram and show them in (r^2, z)
sample = spatial_h.get_random(1_000_000)

test_h = Histdd(sample.T[0]**2, sample.T[2], bins=100, axis_names=['r2', 'z'])
test_h.plot(log_scale=True, cmap='jet', log_scale_vmin=1e-1)
plt.xlim(0, tpc_radius**2)
plt.ylim(-tpc_length, 0)
plt.show()

In [None]:
import pandas as pd

# Radiogenic neutron background spectrum: two unnamed CSV columns
# (energy [keV], rate). Only energies below 50 keV are kept.
radiogenics = pd.read_csv("radiogenic_spectrum.csv", names=['energy', 'rate'])
radiogenics = radiogenics.loc[radiogenics['energy'] < 50]

In [None]:
# Electron lifetime, used as the default `elife` of every
# electron_detection_eff override (same units as drift_time;
# presumably ns, i.e. 326 us -- confirm).
elife = 326_000.0

class LowMassWIMPSource(fd.x1t_sr0.SR0WIMPSource):
    """Time-averaged (non-modulating) WIMP signal of mass `mw`, in the 1.3t FV."""
    mw = 200  # GeV
    pretend_wimps_dont_modulate = True

    # Fiducial volume taken from the notebook-level constants
    fv_radius = fv_radius
    fv_high = fv_high
    fv_low = fv_low

    @staticmethod
    def electron_detection_eff(drift_time, *, elife=elife, extraction_eff=0.96):
        """Extraction efficiency times the fraction of electrons surviving the drift."""
        survival = tf.exp(-drift_time / elife)
        return survival * extraction_eff

class LowMassWIMPModulationSource(LowMassWIMPSource):
    """Same WIMP model as LowMassWIMPSource, but with WIMP modulation switched on."""
    pretend_wimps_dont_modulate = False

class LowEnergyERSource(fd.x1t_sr0.SR0ERSource):
    """ER background with a flat energy spectrum on [0, 10] keV, in the 1.3t FV."""
    fv_radius = fv_radius
    fv_high = fv_high
    fv_low = fv_low

    def _single_spectrum(self):
        """Return (energies in keV, rate at these energies).

        1000 energies evenly spaced on [0, 10] keV, each with unit rate.
        """
        n_points = 1000
        # 10 keV for 1 TeV WIMP
        energies = tf.dtypes.cast(tf.linspace(0., 10., n_points),
                                  dtype=fd.float_type())
        rates = tf.ones(n_points, dtype=fd.float_type())
        return energies, rates

    @staticmethod
    def electron_detection_eff(drift_time, *, elife=elife, extraction_eff=0.96):
        """Extraction efficiency times the fraction of electrons surviving the drift."""
        survival = tf.exp(-drift_time / elife)
        return survival * extraction_eff

#class NRSpectrum(fd.x1t_sr0.SR0WIMPSource):
#    # Make a time-averaged 200 GeV WIMP spectrum
#    mw = 10  # GeV
#    pretend_wimps_dont_modulate = True
#
#nr_spectrum_hist = NRSpectrum().energy_hist

class NRBackgroundSource(fd.x1t_sr0.SR0NRSource):
    """Radiogenic neutron background.

    The energy spectrum is tabulated in `radiogenics`. The rate is not
    uniform in space: the toy (r, theta, z) histogram `spatial_h`, with its
    cylindrical `bin_volumes`, describes the spatial dependence.
    """
    spatial_rate_hist = spatial_h
    spatial_rate_bin_volumes = bin_volumes

    def _single_spectrum(self):
        """Return (energies in keV, rate at these energies) from `radiogenics`."""
        energies = tf.dtypes.cast(radiogenics['energy'],
                                  dtype=fd.float_type())
        # Scaled by 1000 * 365.25 to give events/tonne/year
        # (NOTE(review): implies the CSV rate is per kg per day -- confirm).
        rates = tf.dtypes.cast(radiogenics['rate'] * 1000 * 365.25,
                               dtype=fd.float_type())
        return energies, rates

    @staticmethod
    def electron_detection_eff(drift_time, *, elife=elife, extraction_eff=0.96):
        """Extraction efficiency times the fraction of electrons surviving the drift."""
        survival = tf.exp(-drift_time / elife)
        return survival * extraction_eff


def add_corrected_signals(d, *, g1=0.142, p_dpe=0.219, g2=11.4,
                          s2_area_fraction_top=0.63, extraction_eff=0.96):
    """Add corrected signals 'cs1' and 'cs2' to `d`, in place.

    Each raw signal is divided by the event's detection efficiency times mean
    gain, then rescaled by a constant factor. The default constants reproduce
    the values this notebook has always used. They are presumably the
    XENON1T SR0 g1 [PE/photon], double-PE emission probability, g2
    [PE/electron], S2 area fraction top and electron extraction efficiency
    -- confirm against flamedisx's x1t_sr0 model.

    Parameters
    ----------
    d : pandas.DataFrame or dict-like of arrays
        Must contain 's1', 's2', 'photon_detection_eff', 'photon_gain_mean',
        'electron_detection_eff' and 'electron_gain_mean'.
    g1, p_dpe, g2, s2_area_fraction_top, extraction_eff : float, keyword-only
        Constants entering the cS1 / cS2 scale factors.

    Returns
    -------
    None. `d` gains (or has overwritten) the 'cs1' and 'cs2' entries.
    """
    s1_scale = g1 / (1 + p_dpe)
    # extraction_eff is divided out here too, because the cS2 scale
    # factor includes it.
    s2_scale = g2 / (1 - s2_area_fraction_top) / extraction_eff
    d['cs1'] = s1_scale * d['s1'] / (
        d['photon_detection_eff'] * d['photon_gain_mean'])
    d['cs2'] = s2_scale * d['s2'] / (
        d['electron_detection_eff'] * d['electron_gain_mean'])

In [None]:
# One dataset per physics model; each gets a flamedisx source instance.
dsets = dict(
    er=dict(source_class=LowEnergyERSource),
    nr=dict(source_class=LowMassWIMPSource),
    nr_bkg=dict(source_class=NRBackgroundSource),
    nr_mod=dict(source_class=LowMassWIMPModulationSource))

for dset in dsets.values():
    # Larger batches when a GPU is available
    dset['source'] = dset['source_class'](
        batch_size=300 if have_gpu else 10,
        max_sigma=5)

# The two WIMP (NR signal) models: time-averaged and modulating
nr_conditions = ['nr', 'nr_mod']

# Compare the WIMP models' energy_hist, summed over its axis 1
# (NOTE(review): which physical axis that is depends on flamedisx -- confirm).
for nrc in nr_conditions:
    dsets[nrc]['source'].energy_hist.sum(axis=1).plot(label=nrc)
plt.legend()

## Compute rate histograms

In [None]:
def std_axes():
    """Style the current axes for a (cS1, cS2) plot: log-scaled cS2, labelled axes."""
    ax = plt.gca()
    ax.set_yscale('log')
    ax.set_xlabel("cS1 [PE]")
    ax.set_ylabel("cS2 [PE]")


# Restore previously built rate histograms, if the cache file is present.
# Without it, the next block rebuilds them by simulation instead of
# crashing here with FileNotFoundError.
# Caches for other WIMP masses:
#fn = 'hists_1e6_20200213_30GeV_326elife.pkl'
#fn = 'hists_1e6_20200218_15GeV_326elife.pkl'
fn = 'hists_1e6_20200217_200GeV_326elife.pkl'
if os.path.exists(fn):
    # NB: pickle.load can execute arbitrary code -- only load trusted files.
    with open(fn, mode='rb') as f:
        hists = pickle.load(f)
    for dname, q in dsets.items():
        if dname in hists:
            q['mh'] = hists[dname]
else:
    print(f"{fn} not found; histograms will be rebuilt by simulation")

remake_hists = False

for dname, q in dsets.items():
    needs_build = remake_hists or 'mh' not in q
    if needs_build:
        print(f"Building histogram for {dname}")

        # 81 linear cS1 bins on [0, 80] PE; 69 log-spaced cS2 bins
        mh = Histdd(bins=(
            np.linspace(0, 80, 81 + 1),
            np.geomspace(10**1.7 / (1 - 0.63),
                         10**3.9 / (1 - 0.63),
                         70)))

        # The ER dataset gets more simulation batches
        n_batches = 40
        if dname == 'er':
            n_batches = 100
        trials_per_batch = int(1e6)

        for _ in tqdm(range(n_batches)):
            d = q['source'].simulate(trials_per_batch)
            add_corrected_signals(d)
            mh.add(d['cs1'], d['cs2'])

        # Convert counts to a PDF over (cS1, cS2). The denominator counts every
        # simulated trial, so events not landing in the histogram lower its integral.
        mh /= mh.bin_volumes() * trials_per_batch * n_batches

        # Scale by the total expected event rate
        # (from the source, i.e. before correcting for efficiencies)
        mh *= q['source'].mu_before_efficiencies()

        q['mh'] = mh

    # Expected events per bin = density x bin volume
    q['events_per_bin'] = q['mh'] * q['mh'].bin_volumes()
    q['mh'].plot(cblabel='rate * PDF')
    plt.title(f"{dname}: {q['events_per_bin'].n:.02f} expected events")
    std_axes()
    plt.show()

In [None]:
# # If you save histograms, remember to download them!!
# with open('hists_1e6_20200205_15GeV.pkl', mode='wb') as f:
#     pickle.dump({
#         k: v['mh'] 
#         for k, v in dsets.items()}, f)

  * Make sure none of the models are 'cut off' in cS1 / cS2, since cS1 and cS2 cut acceptances are not currently accounted for in our likelihood (unlike S1 or S2 cuts). This is not a limitation of flamedisx: the correction value is known for each event since the correction depends only on observables, so ultimately a cS1 cut is just a space-dependent S1 cut (which flamedisx fully supports).
  * The ROC curves will depend on the extent of the ER spectrum. If you include more high-energy ER events that can be discriminated anyway, the ER leakage in any likelihood will go down. The key figure of merit we are trying to derive here, the decrease in ER leakage at ~50 % NR acceptance when switching to the full likelihood, should be unaffected by this.

## Histogram-based discrimination

In [None]:
# ER / WIMP density ratio per (cS1, cS2) bin; the 1e-20 offsets avoid 0/0.
for dname in nr_conditions:
    q = dsets[dname]
    ratio = (dsets['er']['mh'] + 1e-20) / (q['mh'] + 1e-20)
    q['signal_background_ratio'] = ratio
    # Flattened bin indices, most WIMP-like (smallest ratio) first
    q['histogram_ordering'] = np.argsort(ratio.histogram.ravel())

    ratio.plot(log_scale=True, cmap=plt.cm.seismic,
               vmin=1e-4, vmax=1e4, cblabel='ER / NR density ratio')
    std_axes()
    plt.title(f"er vs {dname}")
    plt.show()

In [None]:
# Neutron-background / WIMP density ratio per (cS1, cS2) bin
for dname in nr_conditions:
    q = dsets[dname]
    ratio = (dsets['nr_bkg']['mh'] + 1e-20) / (q['mh'] + 1e-20)
    q['signal_background_ratio_nrbkg'] = ratio
    # Flattened bin indices, most WIMP-like (smallest ratio) first
    q['histogram_ordering_nrbkg'] = np.argsort(ratio.histogram.ravel())

    ratio.plot(log_scale=True, cmap=plt.cm.seismic,
               vmin=1e-4, vmax=1e4, cblabel='NRBKG / NR density ratio')
    std_axes()
    plt.title(f"nrbkg vs {dname}")
    plt.show()

In [None]:
# Non-modulating / modulating WIMP density ratio per (cS1, cS2) bin.
# The histograms carry no time information, so this only reflects MC noise.
for dname in nr_conditions:
    q = dsets[dname]
    ratio = (dsets['nr']['mh'] + 1e-20) / (q['mh'] + 1e-20)
    q['signal_background_ratio_mod'] = ratio
    # Flattened bin indices, smallest ratio first
    q['histogram_ordering_mod'] = np.argsort(ratio.histogram.ravel())

    ratio.plot(log_scale=True, cmap=plt.cm.seismic,
               vmin=1e-4, vmax=1e4, cblabel='non-mod / mod WIMP density ratio')
    std_axes()
    plt.title(f"non-mod vs {dname}")
    plt.show()

In [None]:
# New figure (white background) for the histogram-based ROC curves below
plt.figure()
plt.gcf().patch.set_facecolor('white')

def hist_to_cdf(hist, ordering):
    """Cumulative fraction of `hist`'s content, visiting its bins in `ordering`.

    `hist` must expose `.histogram` (bin contents) and `.n` (their total);
    `ordering` indexes the flattened `.histogram`.
    """
    contents = hist.histogram.ravel()
    return contents[ordering].cumsum() / hist.n

# ROC curves from the binned PDFs: cutting on the density ratio (most
# WIMP-like bins first), x is the fraction of background events kept and
# y the fraction of WIMP events kept.
for dname in nr_conditions:
    q = dsets[dname]
    for roc_key, bkg_name, ordering_key, bkg_label in [
            ('roc_from_histogram', 'er', 'histogram_ordering', 'er'),
            ('roc_from_histogram_nrbkg', 'nr_bkg', 'histogram_ordering_nrbkg', 'nr bkg'),
            ('roc_from_histogram_mod', 'nr', 'histogram_ordering_mod', 'non-mod')]:
        ordering = q[ordering_key]
        q[roc_key] = (
            hist_to_cdf(dsets[bkg_name]['events_per_bin'], ordering),
            hist_to_cdf(q['events_per_bin'], ordering))
        plt.plot(*q[roc_key], label=f'{bkg_label} vs. {dname}')

plt.xscale('log')
plt.legend(loc='best')
plt.axhline(0.5, alpha=0.5, c='k', linewidth=0.5)

plt.xlim(1e-4, 1)
# x is the fraction of background *accepted* (leakage), not the rejection
# (which would be 1 - x).
plt.xlabel("Background acceptance")

plt.ylim(0, 1)
plt.ylabel("NR acceptance")
plt.show()

As expected, there is no difference between NR with and without modulation, since the 2D histogram does not see time.

## Flamedisx-based discrimination

In [None]:
# Load earlier simulation + likelihood results, if the cache file is present.
# Without it, the next block simulates the data from scratch instead of
# crashing here with FileNotFoundError.
# Caches for other WIMP masses:
#fn = 'discstudy_20200218_1e6_maxsigma5_15GeV_326elife.pkl.gz'
#fn = 'discstudy_20200213_1e6_maxsigma5_30GeV_326elife.pkl.gz'
fn = 'discstudy_20200217_1e6_maxsigma5_200GeV_326elife.pkl.gz'
if os.path.exists(fn):
    # NB: pickle.load can execute arbitrary code -- only load trusted files.
    with gzip.open(fn, mode='rb') as f:
        loaded_data = pickle.load(f)
    for k, v in loaded_data.items():
        dsets[k]['data'] = v
else:
    print(f"{fn} not found; data will be simulated")

In [None]:
remake_data = False
n_trials_events = int(1e6) # if have_gpu else int(1e4)

for dname, q in dsets.items():
    # 1. Simulated events for this dataset (unless already loaded)
    if remake_data or 'data' not in q:
        print(f"Simulating data for {dname}")
        sim_data = q['source'].simulate(n_trials_events)
        q['data'] = sim_data
        sim_data['event_time'] += int(5e9)  # Add 5 seconds
        add_corrected_signals(sim_data)

        # Keep only events inside the histogram's (cS1, cS2) range, so that
        # multihist's lookup never extrapolates.
        # NB: assumes all datasets' histograms share this binning!
        cs1_edges, cs2_edges = q['mh'].bin_edges
        in_range = ((cs1_edges[0] <= sim_data['cs1'])
                    & (sim_data['cs1'] < cs1_edges[-1])
                    & (cs2_edges[0] <= sim_data['cs2'])
                    & (sim_data['cs2'] < cs2_edges[-1]))
        print(f"{dname}: Throwing out {100 * (~in_range).sum() / len(sim_data):.2f}% of events")
        q['data'] = sim_data[in_range].copy()
    sim_data = q['data']

    # 2. Differential rates ("likelihoods") of these events under every model:
    #    histogram lookup always, flamedisx only where not yet computed.
    for model_name, model in dsets.items():
        sim_data['l_mh_' + model_name] = model['mh'].lookup(
            sim_data['cs1'], sim_data['cs2'])

        full_key = 'l_full_' + model_name
        if full_key in sim_data:
            continue
        print(f"Computing likelihood of {dname} data under {model_name} model")
        model['source'].set_data(sim_data.copy())
        sim_data[full_key] = model['source'].batched_differential_rate()

    # 3. Likelihood ratios, for both the histogram- and flamedisx-based method
    for nrc in nr_conditions:
        # ER / WIMP model
        sim_data[f'lr_mh_{nrc}'] = sim_data['l_mh_er'] / sim_data[f'l_mh_{nrc}']
        sim_data[f'lr_full_{nrc}'] = sim_data['l_full_er'] / sim_data[f'l_full_{nrc}']
        # neutron background / WIMP model
        sim_data[f'lr_mh_nr_{nrc}'] = sim_data['l_mh_nr_bkg'] / sim_data[f'l_mh_{nrc}']
        sim_data[f'lr_full_nr_{nrc}'] = sim_data['l_full_nr_bkg'] / sim_data[f'l_full_{nrc}']

    # non-modulating / modulating WIMP model
    sim_data['lr_mh_mod'] = sim_data['l_mh_nr'] / sim_data['l_mh_nr_mod']
    sim_data['lr_full_mod'] = sim_data['l_full_nr'] / sim_data['l_full_nr_mod']

In [None]:
# Save results to gzipped pickle. Compression takes a while; download longer.
# Don't forget to download it!!
#with gzip.open('discstudy_20200205_1e6_maxsigma5_15GeV.pkl.gz', mode='wb') as f:
#    pickle.dump({dname: q['data']
#                 for dname, q in dsets.items()},
#                f)

## Compare differential rates

Compare differential rates. There will be an offset because (cS1, cS2) and (S1, S2) have different ranges/means -- so the rates are differential with respect to different coordinates.

In [None]:
# Per-event differential rate, histogram (x) vs. flamedisx (y), for ER and NR
# data under ER and NR models. The black diagonal marks exact agreement.
f, axes = plt.subplots(2, 2, figsize=(12, 10))

for row, dname in enumerate(['er', 'nr']):
    q = dsets[dname]['data']
    if dname == 'er':
        q_s = q.copy()
    for col, lh_name in enumerate(['er', 'nr']):
        plt.sca(axes[row, col])
        x, y = q[f'l_mh_{lh_name}'], q[f'l_full_{lh_name}']

        edges = np.geomspace(1e-7, 1e-1, 100)
        Histdd(x, y, bins=(edges, edges)).plot(
            log_scale=True, cblabel='Events/bin',
            vmin=0.8, vmax=n_trials_events * 0.05)

        plt.plot([1e-7, 1e-1], [1e-7, 1e-1], 'k-')
        plt.xscale('log')
        plt.yscale('log')
        plt.xlabel("Histogram")
        plt.ylabel("Flamedisx")
        plt.xlim(1e-7, 1e-1)
        plt.ylim(1e-7, 1e-1)
        plt.title(f"{dname.upper()} data, {lh_name.upper()} model")
        plt.gca().set_aspect(1)
plt.show()

In [None]:
# As above, for neutron-background and WIMP data/models.
f, axes = plt.subplots(2, 2, figsize=(12, 10))

min_ll = 1e-7
for row, dname in enumerate(['nr_bkg', 'nr']):
    q = dsets[dname]['data']
    for col, lh_name in enumerate(['nr_bkg', 'nr']):
        plt.sca(axes[row, col])
        x, y = q[f'l_mh_{lh_name}'], q[f'l_full_{lh_name}']

        edges = np.geomspace(min_ll, 1e-1, 100)
        Histdd(x, y, bins=(edges, edges)).plot(
            log_scale=True, cblabel='Events/bin',
            vmin=0.8, vmax=n_trials_events * 0.05)

        plt.plot([min_ll, 1e-1], [min_ll, 1e-1], 'k-')
        plt.xscale('log')
        plt.yscale('log')
        plt.xlabel("Histogram")
        plt.ylabel("Flamedisx")
        plt.xlim(min_ll, 1e-1)
        plt.ylim(min_ll, 1e-1)
        plt.title(f"{dname.upper()} data, {lh_name.upper()} model")
        plt.gca().set_aspect(1)
plt.show()

In [None]:
# # Zoom-in on the low-energy NR data
# d = dsets['nr']['data']
# dsets['er']['mh'].plot(log_scale=True, vmin=1e-6, vmax=1e-1, cmap=plt.cm.Blues, cblabel="Diffrate hist")
# plt.scatter(d['cs1'], d['cs2'], c=d['l_full_er'], s=0.1, 
#             vmax=1e-1, vmin=1e-6, norm=matplotlib.colors.LogNorm(), cmap=plt.cm.Reds)
# plt.colorbar(label='Diffrate Flamedisx')
# plt.xlim(0, 10)
# plt.xlabel("cS1 [PE]")
# plt.ylim(0, 3e3)
# plt.ylabel("cS2 [PE]")

In [None]:
# Check spatial distribution of each dataset's events, in (r^2, z)
for dname in ['er', 'nr', 'nr_bkg', 'nr_mod']:
    events = dsets[dname]['data']
    print(len(events))
    Histdd(events['r']**2, events['z'],
           bins=100, axis_names=['r2', 'z']).plot(log_scale=True)
    plt.title(dname)
    plt.xlim(0, tpc_radius**2)
    plt.ylim(-tpc_length, 0)
    plt.show()

## Compare event-by-event discrimination

In [None]:
# d_er = dsets['er']['data']
# d_nr = dsets['nr']['data']

# for d, alt, cmap in [#(d_er, d_nr, plt.cm.viridis),
#                      (d_nr, d_er, plt.cm.magma)
#                     ]:
#     # For each event in d, find what fraction of the alt data is more NR-like than it
#     # (under both likelihoods)
#     f_above = {
#         lt: np.searchsorted(np.sort(alt[f'lr_{lt}_nr'].values), 
#                             d[f'lr_{lt}_nr'].values).astype(np.float) / len(alt)
#         for lt in ('mh', 'full')}

#     # Get ratio. 
#     #   0 = mh sees the event as more NR-like than any of the alt data
#     #   > 1: mh is worse at discriminating, < 1 mh is better at discriminating
#     ratio = f_above['mh'] / f_above['full']
#     mask = np.isfinite(ratio)
    
#     xkey, ykey = 's1', 'z'
    
#     plt.scatter(d[xkey][mask], d[ykey][mask], c=ratio[mask], #cmap=cmap,
#                 vmin=0, vmax=2, cmap=plt.cm.seismic,
#                 s=0.2,)
#     plt.colorbar()
#     plt.scatter(d[xkey][~mask], d[ykey][~mask], c='g',
#                 s=0.2, alpha=0.2)
    
#     plt.xlabel(xkey)
#     #plt.ylim(0, 700)
#     plt.ylabel(ykey)
#     plt.show()

For these NR events (plotted by the cell above, which is currently commented out), red events have a more NR-like ER/NR likelihood ratio in flamedisx, and blue ones have a more NR-like ratio in the histogram likelihood. For green events the ratio of the two likelihood ratios is not a finite number.

Log likelihood ratio histograms for ER and NR data under both likelihoods are shown below. Note that many events pile up at the edges for the histogram likelihood: for these, either the NR or the ER histogram was zero, and we clip the likelihood ratio to $[10^{-21}, 10^{21}]$.

In [None]:
# Histogram of log10(ER / WIMP likelihood ratio) for ER and WIMP data.
# Colour = likelihood (type, model); solid = ER data, dashed = WIMP data.
# Ratios are clipped to [10^-clip_exp, 10^clip_exp] so infinite/zero ratios
# land in the edge bins.
clip_exp = 21
plt.figure()
plt.gcf().patch.set_facecolor('white')
for lt, nrc, color in [['mh', 'nr', 'b'], ['full', 'nr', 'g'], ['full', 'nr_mod', 'purple']]:
    for dname, linestyle in [('er', '-'), (nrc, '--')]:
        d = dsets[dname]['data'][f'lr_{lt}_{nrc}']
        print(f"Found {np.sum(np.isnan(d))} NaNs for {lt}:{dname}")
        # Local name: do not clobber the global `hists` (loaded histogram cache)
        llr_hist = Hist1d(
            np.log10(d.clip(10**-clip_exp, 10**clip_exp).values.astype('float')),
            # +0.1 so ratios clipped to exactly 10^clip_exp fall inside the last bin
            bins=np.linspace(-clip_exp, clip_exp + 0.1, 140))
        llr_hist.plot(color=color, linestyle=linestyle,
                      label=f"{dname} data, {lt} ({nrc} model): {llr_hist.n}")
plt.yscale('log')
plt.xlabel("log10(ER / NR likelihood ratio)")
plt.ylabel("Events per bin")
plt.legend()
plt.show()

## ROC Curves

In [None]:
def binom_interval(success, total, conf_level=0.95):
    """Central confidence interval on a binomial proportion (Clopper-Pearson).

    This is the exact Clopper-Pearson interval built from beta-distribution
    quantiles (not the Jeffreys interval). Adapted from
    https://gist.github.com/paulgb/6627336; agrees with the "exact" interval
    of http://statpages.info/confint.html for binom_interval(1, 10).

    Parameters
    ----------
    success : int or array of int
        Number of successes; broadcasts against `total`.
    total : int or array of int
        Number of trials.
    conf_level : float
        Confidence level of the two-sided interval (default 0.95).

    Returns
    -------
    (lower, upper) : numpy arrays (0-d for scalar input)
        Where a beta quantile is undefined (success == 0 for the lower bound,
        success == total for the upper bound), the trivial limit 0 resp. 1 is
        returned, which is the Clopper-Pearson limit in those cases.
    """
    quantile = (1 - conf_level) / 2.
    # scipy returns NaN for a zero shape parameter, i.e. success == 0 (lower)
    # or success == total (upper).
    lower = stats.beta.ppf(quantile, success, total - success + 1)
    upper = stats.beta.ppf(1 - quantile, success + 1, total - success)

    # np.where instead of masked assignment: scalar inputs make ppf return a
    # numpy scalar, which does not support item assignment.
    lower = np.where(np.isnan(lower), 0., lower)
    upper = np.where(np.isnan(upper), 1., upper)
    return lower, upper

def make_roc(key, dnames, nbins=10_000):
    """Histogram a likelihood-ratio column per dataset and collect ROC ingredients.

    For each name in `dnames`, column `key` of dsets[name]['data'] is clipped
    to [1e-20, 1e20], log10'ed and histogrammed on `nbins` edges spanning
    [-21, 21] (so nbins - 1 bins).

    Returns a dict of dicts keyed by dataset name:
      'hists'          -- the Hist1d of log10(ratio)
      'cdfs'           -- its cumulative density
      'cdf_intervals'  -- stacked (lower, upper) 68% binomial intervals on
                          the cumulative counts
    """
    result = dict(hists=dict(), cdfs=dict(), cdf_intervals=dict())
    for dname in dnames:
        log_ratio = np.log10(dsets[dname]['data'][key].clip(1e-20, 1e20))
        h = Hist1d(log_ratio, bins=np.linspace(-21, 21, nbins))
        result['hists'][dname] = h
        result['cdfs'][dname] = h.cumulative_density
        result['cdf_intervals'][dname] = np.stack(
            binom_interval(np.cumsum(h.histogram), h.n, conf_level=.68))
    return result

# ER background vs. WIMP signal: columns lr_{lt}_{nrc} (ER / WIMP-model
# likelihood ratio), datasets 'er' and 'nr'.
rocs = dict()
for lt, nrc in (('mh', 'nr'), ('full', 'nr'), ('full', 'nr_mod')):
    rocs[(lt, nrc)] = make_roc(f'lr_{lt}_{nrc}', dnames=['er', 'nr'])

# Neutron background vs. WIMP signal: columns lr_{lt}_nr_{nrc}
# (NR-background / WIMP-model likelihood ratio), datasets 'nr_bkg' and 'nr'.
rocs_nrbkg = dict()
for lt, nrc in (('mh', 'nr'), ('full', 'nr'), ('full', 'nr_mod')):
    rocs_nrbkg[(lt, nrc)] = make_roc(f'lr_{lt}_nr_{nrc}', dnames=['nr_bkg', 'nr'])

# Non-modulating vs. modulating WIMP: columns lr_mh_mod and lr_full_mod,
# datasets 'nr' and 'nr_mod', on a finer grid (100k edges).
rocs_mod = dict()
for lt, nrc in (('mh', 'mod'), ('full', 'mod')):
    rocs_mod[(lt, nrc)] = make_roc(f'lr_{lt}_{nrc}', dnames=['nr', 'nr_mod'], nbins=100_000)

# (roc key, legend label, colour)
_roc_style = [
    (('mh', 'nr'), "2d Hist (sampled)", 'b'),
    (('full', 'nr'), "Flamedisx", 'g'),
    (('full', 'nr_mod'), "Fd modulation", 'darkorange'),
    (('mh', 'mod'), "2d Hist (sampled)", 'b'),
    (('full', 'mod'), "Fd modulation", 'darkorange'),
]
roc_labels = {rkey: label for rkey, label, _ in _roc_style}
roc_colors = {rkey: color for rkey, _, color in _roc_style}

In [None]:
# for rn, roc in rocs.items():
#     x, y = roc['cdfs']['er'], roc['cdfs']['nr']
#     plt.errorbar(x, y,
#                  xerr=np.abs(roc['cdf_intervals']['er'] - x),
#                  yerr=np.abs(roc['cdf_intervals']['nr'] - y),
#                  linestyle='',
#                  color=roc_colors[rn],
#                  label=roc_labels[rn])
# plt.plot(*dsets['nr']['roc_from_histogram'], 
#          label='2D Hist ($N = \infty$)',
#          color='k', linestyle='--')
#      
# plt.xlabel("ER (flat 0-10 keV) background")
# plt.ylabel(f"{LowMassWIMPSource.mw} GeV/c^2 WIMP acceptance")
# plt.xscale('log')
# plt.xlim(1e-5, 1e-1)
# plt.ylim(0., 1)
# plt.legend(loc='best')
# 
# plt.tight_layout()
# #plt.savefig('fd_roc_15GeV.png', dpi=200)
# plt.show()

In [None]:
# for rn, roc in rocs_nrbkg.items():
#     x, y = roc['cdfs']['nr_bkg'], roc['cdfs']['nr']
#     plt.errorbar(x, y,
#                  xerr=np.abs(roc['cdf_intervals']['nr_bkg'] - x),
#                  yerr=np.abs(roc['cdf_intervals']['nr'] - y),
#                  linestyle='',
#                  color=roc_colors[rn],
#                  label=roc_labels[rn])
# plt.plot(*dsets['nr']['roc_from_histogram_nrbkg'], 
#          label='2D Hist ($N = \infty$)',
#          color='k', linestyle='--')
#      
# plt.xlabel("NR radiogenics background")
# plt.ylabel(f"{LowMassWIMPSource.mw} GeV/c^2 WIMP acceptance")
# plt.xscale('log')
# plt.xlim(1e-5, 1)
# plt.ylim(0., 1)
# plt.legend(loc='best')
# 
# plt.tight_layout()
# #plt.savefig('fd_roc_nr_bkg_15GeV.png', dpi=200)
# plt.show()

In [None]:
# for rn, roc in rocs_mod.items():
#     x, y = roc['cdfs']['nr'], roc['cdfs']['nr_mod']
#     plt.errorbar(x, y,
#                  xerr=np.abs(roc['cdf_intervals']['nr'] - x),
#                  yerr=np.abs(roc['cdf_intervals']['nr_mod'] - y),
#                  linestyle='',
#                  #marker='.',
#                  color=roc_colors[rn],
#                  label=roc_labels[rn]
#                 )
# plt.plot(*dsets['nr']['roc_from_histogram_mod'], 
#          label='2D Hist ($N = \infty$)',
#          color='k', linestyle='--')
# #plt.plot([0,1],[0,1], c='k', linewidth=1)
# plt.xlabel("non mod wimp")
# plt.ylabel(f"{LowMassWIMPSource.mw} GeV/c^2 WIMP acceptance")
# plt.xscale('log')
# plt.xlim(1e-5, 1)
# plt.ylim(0., 1)
# plt.legend(loc='best')
# 
# plt.tight_layout()
# #plt.savefig('fd_roc_nr_bkg_15GeV.png', dpi=200)
# plt.show()

In [None]:
# Combined ROC curve plot: WIMP acceptance vs. background acceptance for
# three backgrounds. Dashed lines: histogram-based ROC computed from the
# binned PDFs. Solid lines with error bars: flamedisx likelihood on the
# simulated events (68% binomial intervals from make_roc).

def _plot_sampled_roc(roc, bkg_name, sig_name, color, label):
    """Plot one make_roc result as background vs. signal acceptance with error bars."""
    x, y = roc['cdfs'][bkg_name], roc['cdfs'][sig_name]
    plt.errorbar(x, y,
                 xerr=np.abs(roc['cdf_intervals'][bkg_name] - x),
                 yerr=np.abs(roc['cdf_intervals'][sig_name] - y),
                 color=color,
                 label=label)

plt.figure()

# No-discrimination diagonal
diag = np.geomspace(1e-5, 1, 100)
plt.plot(diag, diag, c='k', alpha=0.5)

for hist_roc_key, roc, bkg_name, sig_name, color, label in [
        ('roc_from_histogram_mod', rocs_mod[('full', 'mod')],
         'nr', 'nr_mod', 'C0', 'WIMP-like background'),
        ('roc_from_histogram_nrbkg', rocs_nrbkg[('full', 'nr_mod')],
         'nr_bkg', 'nr', 'C1', 'neutron background'),
        ('roc_from_histogram', rocs[('full', 'nr_mod')],
         'er', 'nr', 'C2', 'ER background')]:
    plt.plot(*dsets['nr'][hist_roc_key], color=color, linestyle='--')
    _plot_sampled_roc(roc, bkg_name, sig_name, color, label)

plt.xlabel("Background acceptance")
plt.ylabel(f"{LowMassWIMPSource.mw} GeV/c$^2$ WIMP acceptance")
plt.xlim(1e-5, 1)
plt.ylim(0., 1)
plt.legend(loc='lower right', frameon=False)
plt.gca().set_aspect('equal')

#plt.tight_layout()
#plt.savefig('fd_rocs_15GeV_326elife.png', dpi=200, bbox_inches='tight')
plt.show()

In [None]:
def bg_reduction_factor(nr_acceptance, er_leakage, orig_roc=None):
    """Factor by which a ROC reduces ER leakage relative to a reference ROC.

    At each NR acceptance, the reference ROC's ER leakage (linearly
    interpolated) is divided by `er_leakage`. Values above 1 mean less ER
    background than the reference at equal signal acceptance.

    Parameters
    ----------
    nr_acceptance : float or array
        Signal (NR) acceptance(s) at which to compare.
    er_leakage : float or array
        ER leakage of the ROC under study at those acceptances.
    orig_roc : (er_leakage, nr_acceptance) pair of arrays, optional
        Reference ROC. Its NR acceptance must be non-decreasing, as
        required by np.interp. Defaults to the histogram-based ROC
        dsets['nr']['roc_from_histogram'].

    Returns
    -------
    float or array: reference leakage / `er_leakage`.
    """
    if orig_roc is None:
        orig_roc = dsets['nr']['roc_from_histogram']
    ref_leakage, ref_acceptance = orig_roc[0], orig_roc[1]
    leakage_hist = np.interp(nr_acceptance, ref_acceptance, ref_leakage)
    return leakage_hist / er_leakage

# Background reduction factor of each ER-vs-WIMP ROC relative to the
# histogram-based reference; the band comes from the 68% interval on the
# ER leakage.
for rocname, roc in rocs.items():
    acc = roc['cdfs']['nr']
    leak = roc['cdfs']['er']
    red = bg_reduction_factor(acc, leak)
    red_bounds = np.stack([bg_reduction_factor(acc, bound)
                           for bound in roc['cdf_intervals']['er']])

    color = roc_colors[rocname]
    plt.plot(acc, red, label=roc_labels[rocname], color=color)
    plt.fill_between(acc,
                     red_bounds.min(axis=0),
                     red_bounds.max(axis=0),
                     color=color,
                     alpha=0.3, linewidth=0, step='mid')

    # If you're really paranoid, you might want to plot the xerror too:
    # plt.errorbar(acc,
    #              red,
    #              xerr=np.abs(roc['cdf_intervals']['nr'] - acc),
    #              yerr=[red - red_bounds.min(axis=0),
    #                    red_bounds.max(axis=0) - red],
    #              label=roc_labels[rocname])

plt.legend(loc='best')
plt.grid(alpha=0.2, c='k', linestyle='-')

plt.xlabel("NR acceptance")
plt.ylabel("Background reduction factor")

plt.ylim(0, 4)
plt.xlim(0, 1)

#plt.ylim(0.8, 1.2)
plt.tight_layout()
#plt.savefig('fd_bg_reduction_15GeV_326elife.png', dpi=200, bbox_inches='tight')

## Compare sensitivity

Create likelihoods for:
  1. ER vs NR histogram-based
  2. ER vs NR flamedisx
  3. ER vs NR-mod histogram-based -- could omit, should be same as (1)
  4. ER vs NR-mod flamedisx

Sources are represented by `ColumnSource`s whose `random_truth` method resamples events (with replacement) from `dsets[dname]['data']`.


In [None]:
class FastSource(fd.ColumnSource):
    """Source whose differential rate is a precomputed per-event data column.

    Subclasses set `column` (name of the column holding the rate),
    `scale_by` (factor applied to it) and `dname` (the dsets entry whose
    events `random_truth` resamples).
    """

    def _differential_rate(self, data_tensor, ptensor):
        # Clip, presumably to keep log(rate) finite for zero/huge entries
        scaled = self._fetch(self.column, data_tensor) * self.scale_by
        return tf.clip_by_value(scaled, 1e-20, 1e20)

    def random_truth(self, n_events, fix_truth=None, **params):
        # Only plain resampling is supported: no fixed truth, no parameters
        if fix_truth is not None or params:
            raise NotImplementedError
        # Draw with replacement from the stored simulated events
        return dsets[self.dname]['data'].sample(n_events, replace=True)

# Expected events per source; FastSources rescale their rates to match
mu_of_source = dict(er=1, nr=1, nr_mod=1, nr_bkg=1)
likelihood_types = ('mh', 'full')


def _fast_source_class(ltype, dname):
    """Build the FastSource subclass reading column l_{ltype}_{dname}."""
    mu = mu_of_source[dname]
    return type(
        f'FS_{dname}_{ltype}',
        (FastSource,),
        dict(column=f'l_{ltype}_{dname}',
             dname=dname,
             # Adjust mu to the desired value, scaling
             # the differential rate accordingly
             mu=mu,
             scale_by=mu / dsets[dname]['events_per_bin'].n))


fast_sources = {
    ltype: {dname: _fast_source_class(ltype, dname) for dname in dsets}
    for ltype in likelihood_types}


def _likelihood(ltype, nrc):
    """ER + WIMP model `nrc` + neutron background, all three rates free."""
    srcs = fast_sources[ltype]
    return fd.LogLikelihood(
        sources=dict(er=srcs['er'],
                     nr=srcs[nrc],
                     nr_bkg=srcs['nr_bkg']),
        free_rates=('er', 'nr', 'nr_bkg'),
        # max_sigma and n_trials are irrelevant for ColumnSource
        batch_size=10_000)


likelihoods = {
    ltype: {nrc: _likelihood(ltype, nrc) for nrc in nr_conditions}
    for ltype in likelihood_types}


In [None]:
# # Check simulated events are reasonable
# ll = likelihoods['full']['nr']
# for sname, color in [['er', 'b'], ['nr', 'r']]:
#     d = ll.sources[sname].simulate(1000)
#     plt.scatter(d['cs1'], d['cs2'], color=color, s=1)
# likelihoods['full']['nr']
# d = ll.simulate()
# plt.scatter(d['cs1'], d['cs2'], s=1, c='g')
# from collections import Counter
# Counter(d['source'].values)

Time for 400 trials with 1e4 ER events per trial, using minuit with the default tolerance:
  * master 23 January, colab GPU: 41 minutes
  * trace_likelihood 25 January, Jelle's laptop: 5 minutes

In [None]:
# Module-wide flamedisx.inference setting; presumably the lower bound imposed
# on fitted rate multipliers (NOTE(review): confirm in the flamedisx version used)
fd.inference.LOWER_RATE_MULTIPLIER_BOUND = 1e-3

In [None]:
from tqdm.notebook import tqdm

In [None]:
# Grid for the (commented-out) likelihood scan inside the trial loop:
#nrs = np.linspace(1e-9, 24, 40)
#ers = np.linspace(6100, 6500, 40)
#xx, yy = np.meshgrid(nrs, ers)

n_trials = 10000
allow_failure = False
discovery = False  # True: inject WIMPs (see below); False: background-only toys

rm_scale = 15.  # Exposure scale [tonne year]
er_bkg_reduction = 8.2  # ER background reduction from Xe1T to XenT
nr_bkg_reduction = 40  # NR background reduction from Xe1T to XenT

# Optimizer tolerances
xtol = 1e-4
gtol = 1e-3

# Failure counters
n_bestfit_failures = n_cond_bestfit_failures = 0
n_guess_limit_failures = n_limit_failures = 0

import pandas as pd

# One zero-initialized record array of per-trial results per (likelihood, WIMP model)
sensi_results = {
    ltype: {
        nrc: pd.DataFrame({
            column: np.zeros(n_trials)
            for column in ('t', 'bestfit_nr', 'bestfit_er', 'condfit_er', 'limit')
        }).to_records()
        for nrc in nr_conditions}
    for ltype in likelihood_types}

# True rate multipliers for the toys (mu_of_source is 1 for every source)
#(er=627, nr=1, nr_mod=1, nr_bkg=1.4)
truth = dict(nr_rate_multiplier=0.,
             er_rate_multiplier=627. * rm_scale / er_bkg_reduction,
             nr_bkg_rate_multiplier=1.4 * rm_scale / nr_bkg_reduction)

# Expected total number of wimps at reference xsec of 1e-45
nwimps = 627 * rm_scale * dsets['nr_mod']['events_per_bin'].n / dsets['er']['events_per_bin'].n
if discovery:
    nwimps *= 2e-47 / 1e-45  # desired xsec / wimprates xsec
    truth['nr_rate_multiplier'] = nwimps

import warnings

for trial_i in tqdm(range(n_trials)):
    # The nr model choice changes what events we need to pick from,
    # so this has to be the outer loop.
    for nrc in nr_conditions: 

        # Draw simulated events.
        # The histogram and flamedisx likelihood draw from the same
        # event reservoir (they just look at different columns later)
        d = likelihoods['mh'][nrc].simulate(**truth)

        for ltype in likelihood_types:
            #if discovery:  # Only use mh nr and full nrmod, maybe also for sensi.. yeah
            if f'{ltype}_{nrc}' not in ['mh_nr', 'full_nr_mod']:
                continue
            
            q = sensi_results[ltype][nrc]

            ll = likelihoods[ltype][nrc]
            ll.set_data(d)

            guess = dict(nr_rate_multiplier=3, 
                         er_rate_multiplier=len(d))
            
            try:
                bf = ll.bestfit(
                    guess=guess,
                    fix=dict(nr_bkg_rate_multiplier=truth['nr_bkg_rate_multiplier']),
                    optimizer_kwargs=dict(options=dict(gtol=gtol, xtol=xtol)),
                    allow_failure=allow_failure,
                )
            except:
                n_bestfit_failures += 1
                print("bestfit failure")
                continue
            q['bestfit_nr'][trial_i] = bf['nr_rate_multiplier']
            q['bestfit_er'][trial_i] = bf['er_rate_multiplier']
            
            #continue
            
            # lls = [-2*ll(nr_rate_multiplier=x,
            #              er_rate_multiplier=y,
            #              nr_bkg_rate_multiplier=truth['nr_bkg_rate_multiplier'])
            #        for (x, y) in np.array([xx.ravel(), yy.ravel()]).T]
            # ll_min = min(lls)
            # lls = np.reshape(lls, (40, 40))
            # 
            # grid_min = np.unravel_index(np.argmin(lls), lls.shape)
            # 
            # plt.figure(figsize=(10, 8))
            # plt.pcolormesh(xx, yy, (lls - ll_min), vmin=0, vmax=16)
            # plt.colorbar()
            # plt.plot([3, 0], [truth['er_rate_multiplier'],
            #                   truth['er_rate_multiplier']], 'o', color='orange')
            # plt.plot(nrs[grid_min[1]], ers[grid_min[0]], 'go')
            # plt.plot(bf['nr_rate_multiplier'], bf['er_rate_multiplier'], 'ro')
            # plt.title(f'{ltype} {nrc}')
            # plt.show()
            # 
            # continue

            # Compute best fit nuisance params conditional on truth 
            # and test statistic value (which depends on this)
            try:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    cf = ll.bestfit(
                        guess=dict(er_rate_multiplier=bf['er_rate_multiplier']),
                        fix=dict(nr_rate_multiplier=0 if discovery else truth['nr_rate_multiplier'],
                                 nr_bkg_rate_multiplier=truth['nr_bkg_rate_multiplier']),
                        optimizer_kwargs=dict(options=dict(gtol=gtol, xtol=xtol)),
                        allow_failure=allow_failure)
            except:
                n_cond_bestfit_failures += 1
                print("conditional bestfit failure")
            q['condfit_er'][trial_i] = cf['er_rate_multiplier']
            q['t'][trial_i] = 2 * (ll(**bf) - ll(**cf))
            
            if discovery:
                continue
            
            try:
                guess_limit = ll.limit(
                    'nr_rate_multiplier', 
                    bestfit=bf,
                    fix={k: v for k, v in bf.items()
                         if k != 'nr_rate_multiplier'},
                    optimizer_kwargs=dict(options=dict(gtol=gtol, xtol=xtol)),
                    allow_failure=allow_failure)
            except:
                n_guess_limit_failures += 1
                print("guess limit failure")
                continue
            
            try:
                q['limit'][trial_i] = ll.limit(
                    'nr_rate_multiplier', 
                    bestfit=bf,
                    guess={**bf, **dict(nr_rate_multiplier=guess_limit)},
                    fix=dict(nr_bkg_rate_multiplier=truth['nr_bkg_rate_multiplier']),
                    optimizer_kwargs=dict(options=dict(gtol=gtol, xtol=xtol)),
                    allow_failure=False)
            except:
                n_limit_failures += 1
                print("limit failure")
                
            continue
            
            lls = [-2*ll(nr_rate_multiplier=x,
                         er_rate_multiplier=y,
                         nr_bkg_rate_multiplier=truth['nr_bkg_rate_multiplier'])
                   for (x, y) in np.array([xx.ravel(), yy.ravel()]).T]
            m2ll_best = -2*ll(**bf)
            wilks_crit = stats.norm.ppf(0.9) ** 2
            
            fun = (np.array(lls) - (m2ll_best + wilks_crit))**2
            fun_min = min(fun)
            fun = np.reshape(fun, (40, 40))
            
            fun -= xx * 0.01
            
            grid_min = np.unravel_index(np.argmin(fun), fun.shape)
            
            plt.figure(figsize=(10, 8))
            plt.pcolormesh(xx, yy, (fun - fun_min), vmin=0, vmax=9)
            plt.colorbar()
            plt.plot([3, 0], [truth['er_rate_multiplier'],
                              truth['er_rate_multiplier']], 'o', color='orange', label='truth and guess')
            plt.plot(nrs[grid_min[1]], ers[grid_min[0]], 'go', label='grid minimum')
            plt.plot(bf['nr_rate_multiplier'], bf['er_rate_multiplier'], 'ro', label='bestfit')
            plt.axvline(q['limit'][trial_i], color='r', linestyle='--', label='limit')
            plt.legend()
            plt.title(f'{ltype} {nrc}')
            plt.show()

print("n bestfit failures", n_bestfit_failures)
print("n conditional bestfit failures", n_cond_bestfit_failures)
print("n guess limit failures: ", n_guess_limit_failures)
print("n limit failures: ", n_limit_failures)

In [None]:
# Output file for the trial results. The name encodes date, number of trials,
# exposure and WIMP mass, plus the injected cross-section for discovery runs.
# NOTE(review): '10k' and '15ty' are hard-coded — keep them in sync with
# n_trials and rm_scale in the trial cell.
fn = ('disc_results_20200219_10k_15ty_200GeV_2e-47.pkl' if discovery
      else 'sensi_results_20200219_10k_15ty_200GeV.pkl')

In [None]:
# Persist the per-trial results so they can be reloaded without rerunning the trials
with open(fn, mode='wb') as out_file:
    pickle.dump(sensi_results, out_file)

In [None]:
# Reload previously saved trial results.
# Only unpickle files you produced yourself: pickle.load can execute arbitrary code.
with open(fn, mode='rb') as in_file:
    sensi_results = pickle.load(in_file)

In [None]:
# Discovery significance percentiles per exposure, read from the saved 10k-trial discovery runs
exposures = [5, 10, 15, 20]  # tonne year (a 25 tonne-year point is left out)
# Percentiles at the -2, -1, 0, +1, +2 sigma points of a normal distribution
band_percentiles = stats.norm.cdf([-2, -1, 0, 1, 2]) * 100
disc_sigmas = {exposure: dict() for exposure in exposures}
for exposure in exposures:
    with open(f'disc_results_20200219_10k_{exposure}ty_200GeV_2e-47.pkl', mode='rb') as pkl_file:
        disc_res = pickle.load(pkl_file)

    for ltype, nrc in (('mh', 'nr'), ('full', 'nr_mod')):
        q_t = disc_res[ltype][nrc]['t']
        # Keep finite, non-negative q0 only; asymptotic significance is Z = sqrt(q0)
        q_t = q_t[np.isfinite(q_t) & (q_t >= 0)]
        disc_sigmas[exposure][ltype] = np.percentile(q_t ** 0.5, band_percentiles)

In [None]:
# Discovery significance vs exposure: one box per exposure and likelihood.
# Dark box: -1..+1 sigma-equivalent percentiles; light box: -2..+2; black line: median.
delta = 0.25  # half-width of one box along the exposure axis [tonne year]
label_set = False  # legend labels only for the first exposure
plt.figure()
# Reference lines at 3, 4 and 5 sigma
for i in [3, 4, 5]:
    plt.axhline(i, color='gray', linestyle='--')
for exp in exposures:
    for (ltype, offset) in [('mh', -1.5*delta), ('full', 1.5*delta)]:
        p = disc_sigmas[exp][ltype]
        c, box_label = ('C0', '2D Hist (sampled)') if ltype == 'mh' else ('orange', 'Flamedisx')
        x_span = [exp-delta+offset, exp+delta+offset]
        plt.fill_between(x_span, 2*[p[4]], 2*[p[0]],
                         color=c, alpha=0.4, lw=0)
        plt.fill_between(x_span, 2*[p[3]], 2*[p[1]],
                         color=c, alpha=1, lw=0, label=box_label if not label_set else None)
        plt.plot(x_span, 2*[p[2]], color='k')
    label_set = True


def s_to_p(s):
    """One-sided p-value for a significance of s sigma: 0.5 * P(chi2_1 > s^2) = Phi(-s)."""
    return stats.chi2(1).sf(s**2)/2


def p_to_s(p):
    """Significance in sigma for a one-sided p-value (inverse of s_to_p)."""
    return -stats.norm.ppf(p)


plt.xlim(0, max(exposures)+5)
plt.ylim(0, 9)
plt.xlabel('Exposure [tonne year]')
plt.ylabel('Discovery significance [sigma]')
plt.legend(frameon=False, loc='upper left')

# Right-hand axis: same sigma scale, labelled with the corresponding p-values
secax = plt.gca().twinx()
secax.set_ylabel('Discovery p-value')
secax.set_ylim(0, 9)
# Pin the ticks at integer sigma before labelling them: set_yticklabels alone
# would attach these labels to whatever ticks the default locator picked.
sigma_ticks = np.arange(10)
secax.set_yticks(sigma_ticks)
secax.set_yticklabels(['%.2g' % s_to_p(s) for s in sigma_ticks])

#plt.savefig('discovery_exposure_sig_10k.png', dpi=200, bbox_inches='tight')
plt.show()

In [None]:
# Three stacked panels: best fit (top), test statistic (middle), limit / discovery significance (bottom)
f, axes = plt.subplots(nrows=3, ncols=1, figsize=(7, 10))

def hist_plot(data, bins, color, label):
    """Histogram the finite entries of ``data`` on the current axes.

    A dashed vertical line marks the mean. The legend label reports the mean
    and its naive standard error.

    Args:
        data: per-trial values; NaN/inf entries (e.g. failed trials) are dropped.
        bins: bin edges passed to ``Hist1d``.
        color: matplotlib color for the histogram and the mean line.
        label: prefix of the legend label.
    """
    data = np.asarray(data)
    data = data[np.isfinite(data)]
    mean = np.mean(data)
    # Error calculation is very dodgy, since dist
    # is far from Gaussian.
    err = np.std(data)/len(data)**0.5

    Hist1d(data, bins=bins).plot(
        color=color,
        set_xlim=True,
        errors=True,
        error_style='band',
        # Raw f-string: '\p' in a normal string is an invalid escape sequence
        label=rf'{label}: {mean:.2f} $\pm$ {err:.2f}')
    plt.axvline(mean,
                color=color,
                linestyle='--')
    plt.legend(loc='best', frameon=False)

def norm_plot(data, bins, color='k', label=r'$\mathcal{N}(\mu,\sigma)$', alpha=0.6):
    """Overlay expected per-bin counts of a normal distribution fitted to ``data``.

    Mean and standard deviation come from the finite entries of ``data``.
    Counts are scaled by the number of finite entries, which are the same
    entries ``hist_plot`` histograms, so curve and histogram share a
    normalization.

    Args:
        data: per-trial values; NaN/inf entries are ignored.
        bins: bin edges (list or array).
        color, label, alpha: passed to ``plt.plot``.
    """
    data = np.asarray(data, dtype=float)
    finite = data[np.isfinite(data)]
    # asarray so that list bins are not concatenated by `+`
    bins = np.asarray(bins, dtype=float)
    bin_centers = (bins[:-1] + bins[1:])/2
    plt.plot(bin_centers,
             len(finite)*np.diff(stats.norm.cdf(bins,
                                                loc=np.mean(finite),
                                                scale=np.std(finite))),
             color=color, alpha=alpha, label=label)

n_bins = 50
pvals = False  # discovery mode only: bottom panel shows p-values instead of sigmas
x_max = 70 if discovery else 5
x_min = 0 if discovery else -0.2

done = False  # whether the cross-section axis was already added to the bottom panel

for ltype in likelihood_types:
    for nrc in nr_conditions:
        if nrc == 'nr' and ltype == 'full':
            continue
        if (ltype, nrc) not in rocs:
            continue

        color = roc_colors[(ltype, nrc)]
        label = roc_labels[(ltype, nrc)]
        if label.startswith('Fd'):
            label = 'Flamedisx'

        q = sensi_results[ltype][nrc]

        ####################################################################################
        # Top panel: best-fit number of DM events
        plt.sca(axes[0])
        bins = np.linspace(0, x_max, n_bins)
        hist_plot(q['bestfit_nr'],
                  bins=bins,
                  label=label, color=color)
        plt.xlabel("Best fit [DM events]")
        if discovery:
            plt.ylim(0, n_trials/11)
            norm_plot(q['bestfit_nr'], bins)
        else:
            plt.yscale('log')
            plt.ylim(0.7, n_trials)

        #####################################################################################
        # Middle panel: test statistic, with its asymptotic distribution overlaid
        plt.sca(axes[1])
        bins = np.linspace(x_min, x_max, n_bins)
        bin_centers = (bins[:-1] + bins[1:])/2
        hist_plot(q['t'],
                  bins=bins,
                  label=label, color=color)
        plt.xlabel(r'$-2\ln\lambda(0)$' if discovery else r'$-2\ln\lambda(\mu)$')
        # Scale reference curves to the trials hist_plot actually histograms
        n_finite = np.isfinite(q['t']).sum()

        if discovery:
            plt.ylim(0, n_trials/14)

            p_vals = stats.chi2(1).sf(q['t'])/2
            # Asymptotic discovery significance Z = sqrt(q0)
            sigmas = q['t']**0.5

            # Noncentral chi2 with the non-centrality from the Asimov likelihood
            # (median q0). Alternative giving the same: 
            # (0 - mean(bestfit_nr))**2 / std(bestfit_nr)**2
            non_centrality = np.nanmedian(q['t'])
            plt.plot(bin_centers,
                     n_finite*np.diff(stats.ncx2.cdf(x=bins, df=1, nc=non_centrality)),
                     'k', alpha=0.6, label=r'$\chi^{2}(k=1, \lambda=$' + 'median' + r'$[q_{0}])$')
        else:
            plt.yscale('log')
            plt.ylim(0.7, n_trials)

            # Half-chi2(k=1) density on a fine grid, scaled to counts per histogram bin
            test_ax = np.linspace(x_min, x_max, 200)
            plt.plot(test_ax, n_finite*0.5*stats.chi2(1).pdf(test_ax)*(bins[1]-bins[0]), 'k',
                     label=r'$\frac{1}{2}\chi^{2}(k=1)$')

        ##########################################################################################
        # Bottom panel: discovery p-value / significance, or upper limit
        plt.sca(axes[2])
        if discovery:
            if pvals:
                bins = np.geomspace(1e-7, 1, n_bins)
                hist_plot(p_vals,
                          bins=bins,
                          label=label, color=color)
                plt.xscale('log')
                plt.xlabel('Discovery p-value')
            else:
                bins = np.linspace(0, 8, n_bins)
                hist_plot(sigmas,
                          bins=bins,
                          label=label, color=color)
                plt.xlabel('Discovery sigma')

                norm_plot(sigmas, bins)
            plt.ylim(0, None)
        else:
            bins = np.geomspace(0.5, 100, n_bins)
            hist_plot(q['limit'],
                      bins=bins,
                      label=label, color=color)
            plt.xlabel("Limit [DM events]")
            plt.xscale('log')
            plt.ylim(0, max(10, n_trials / 11))
            plt.xlim(0.5, 100)

            if not done:
                done = True

                # Secondary axis converting DM events to cross-section:
                # nwimps events correspond to the reference xsec of 1e-45 cm^2
                def n_to_x(n):
                    return n * 1e-45 / nwimps

                def x_to_n(x):
                    return x * nwimps / 1e-45

                secax = plt.gca().secondary_xaxis('top', functions=(n_to_x, x_to_n))
                secax.set_xlabel(r'Limit [cm$^2$]')

for i_ax, ax in enumerate(axes):
    ax.set_ylabel("Trials / bin")

    if discovery or i_ax == 1:
        plt.sca(ax)
        # Sort legend entries (https://stackoverflow.com/a/46160465).
        # Presumably the handles are [hist 1, reference curve, hist 2, reference curve];
        # showing [0, 2, 1] lists both histograms and the shared curve once.
        # With fewer handles, keep the legend hist_plot already drew.
        handles, labels = ax.get_legend_handles_labels()
        order = [0, 2, 1]
        if len(handles) > max(order):
            plt.legend([handles[idx] for idx in order],
                       [labels[idx] for idx in order],
                       frameon=False)

plt.tight_layout()
plt.subplots_adjust(hspace=0.4)
#plt.savefig('discovery_5ty_200GeV_10k.png', dpi=200, bbox_inches='tight')
#plt.savefig('sensitivity_15ty_200GeV_10k.png', dpi=200, bbox_inches='tight')


#plt.savefig('sensitivity_results_10ty_30wimps_200GeV_20200217_disc.png', dpi=200, bbox_inches='tight')
# Bottom panel for the paper
#extent = axes[2].get_window_extent().transformed(f.dpi_scale_trans.inverted())
#f.savefig('discovery_power.png', dpi=200, bbox_inches=extent.expanded(1.2, 1.4))

In [None]:
# Calibration check: observed vs asymptotic p-values of the sensitivity test statistic.
log = False  # log-log axes instead of linear ones
mw = 200  # WIMP mass [GeV], as used in the result file names
exposures = [10, 15, 20]  # tonne year


def _pvalue_curve(t):
    """Return (asymptotic p, observed p) arrays for sensitivity test statistics ``t``.

    Non-finite entries (failed or skipped trials) are dropped. The asymptotic
    p-value is 0.5 * P(chi2_1 > t), with t clipped at 0 so that t = 0 gives 0.5.
    This equals 1 - (0.5 + 0.5 * cdf), computed via sf for better precision.
    The observed p-value is the empirical tail fraction. It runs from 1 at the
    smallest t to 0 at the largest, and is sized to the number of trials
    actually present.
    """
    t = np.sort(np.asarray(t, dtype=float)[np.isfinite(t)])
    asymptotic_p = 0.5 * stats.chi2(1).sf(t.clip(0, None))
    observed_p = 1 - np.linspace(0, 1, len(t))
    return asymptotic_p, observed_p


plt.figure(figsize=(6, 6))
for i, exp in enumerate(exposures):
    name = f'sensi_results_20200219_10k_{exp}ty_{mw}GeV.pkl'
    with open(name, mode='rb') as f:
        sensi_res = pickle.load(f)

    # Solid: flamedisx ('full', 'nr_mod'); dashed: 2D histogram ('mh', 'nr')
    plt.plot(*_pvalue_curve(sensi_res['full']['nr_mod']['t']),
             color=f'C{i}', label=f'{exp}ty')
    plt.plot(*_pvalue_curve(sensi_res['mh']['nr']['t']),
             color=f'C{i}', linestyle='--')
plt.legend(frameon=False)
min_p = 1e-3
max_p = 1 if log else 0.55
# Diagonal: observed == asymptotic
plt.plot([min_p, max_p], [min_p, max_p], 'k--')
# Asymptotic p-values cannot exceed 0.5 (t clipped at 0); shade that region
plt.axvline(0.5, color='gray', linestyle='--')
plt.fill_between([0.5, 1], 2*[1], 2*[min_p], color='lightgray', lw=0)
plt.gca().set_aspect(1)
if log:
    plt.yscale('log')
    plt.xscale('log')
plt.ylim(min_p, max_p)
plt.xlim(min_p, max_p)
plt.xlabel('Asymptotic p-value')
plt.ylabel('Observed p-value')
#plt.savefig('pvalue_check_200GeV_10k_linlin.png', dpi=200, bbox_inches='tight')
plt.show()