# Quick initial tests to find good scan ranges and check fit behaviour,
# with printed output and plots

In [None]:
import os

# Configure the worker count in the environment before discohisto is imported
# in the next cell (presumably discohisto reads NPROCESSES at import time —
# confirm against the library).
os.environ.update(NPROCESSES="8")

In [None]:
import os

import numpy
from matplotlib import pyplot

from discohisto import (
    fit_cabinetry,
    fit_cabinetry_post,
    fit_linspace,
    fit_mcmc_mix,
    fit_normal,
    fit_signal,
    mcmc_core,
    region,
)
from discohisto.region_fit import region_fit

In [None]:
# fill me in
BASEPATH = "/home/tombs/Cambridge/interval-evidence/searches/atlas_susy_2Ljets_2022/"

# fits

In [None]:
def main_fits(names=None):
    """Run ``dump`` (linspace scan + plot) for a selection of search regions.

    Args:
        names: iterable of region names to process, in the given order. Each
            must be a key of ``region_name_to_scan`` below. Defaults to the
            three regions that were enabled when this notebook was written.

    Raises:
        KeyError: if a requested region has no configured scan range.
    """
    # Scan range (lo, hi) per region, forwarded unchanged to ``dump``.
    region_name_to_scan = {
        "ewk_high": (2, 8),
        "ewk_int": (20, 70),
        "ewk_llbb": (0, 3),
        "ewk_low": (0, 100),
        "ewk_offshell": (0, 40),
        "rjr_sr2l_isr": (0, 100),
        "rjr_sr2l_low": (0, 100),
        "str_src_12_31": (0, 20),
        "str_src_12_61": (0, 40),
        "str_src_31_81": (0, 40),
        "str_src_81": (0, 100),
        "str_srhigh_12_301": (0, 60),
        "str_srhigh_301": (0, 20),
        "str_srlow_101_201": (0, 80),
        "str_srlow_101_301": (0, 80),
        "str_srlow_12_81": (0, 35),
        "str_srlow_301": (0, 20),
        "str_srmed_101": (0, 100),
        "str_srmed_12_101": (20, 90),
        "str_srzhigh": (0, 20),
        "str_srzlow": (0, 60),
        "str_srzmed": (0, 40),
    }

    # Per-region anchors handed to fit_linspace (via dump); regions without
    # an entry get anchors=None.
    region_name_to_anchors = {
        "str_src_12_31": [20.0],
        "str_src_12_61": [40.0],
        "str_srlow_101_201": [80.0],
        "str_src_31_81": [36.0],
        "str_srhigh_301": [15.0],
        "str_src_81": [150.0],
        "str_srlow_101_301": [100.0],
    }

    if names is None:
        # Select regions here instead of commenting dict entries in and out.
        names = ("str_src_12_31", "str_src_12_61", "str_srlow_101_201")

    for name in names:
        lo, hi = region_name_to_scan[name]
        print(name)
        dump(name, lo, hi, region_name_to_anchors=region_name_to_anchors)


def dump(name, lo, hi, nbins=200, region_name_to_anchors=None, *, reference_fits=False):
    """Run a linspace scan for one region, print the result and plot it.

    Args:
        name: region sub-directory under BASEPATH.
        lo, hi: scan bounds passed to ``fit_linspace.fit``.
        nbins: ``fit_linspace.fit`` is called with ``nbins + 1`` (presumably
            the number of scan points, i.e. ``nbins`` intervals — confirm).
        region_name_to_anchors: optional mapping from region name to the
            ``anchors`` argument of ``fit_linspace.fit``. Missing names give
            ``anchors=None``.
        reference_fits: if True, also run and print the cabinetry,
            cabinetry-post and normal fits. This was previously hard-disabled
            with ``if 0:``.
    """
    if region_name_to_anchors is None:
        region_name_to_anchors = {}

    dir_region = os.path.join(BASEPATH, name)
    region_1 = region.Region.load(dir_region)

    if reference_fits:
        print(fit_cabinetry.fit(region_1))
        print(fit_cabinetry_post.fit(region_1))
        print(fit_normal.fit(region_1))

    linspace = fit_linspace.fit(
        region_1,
        lo,
        hi,
        nbins + 1,
        anchors=region_name_to_anchors.get(name),
    )
    print(linspace)

    # Express the scanned levels relative to region_fit's objective value
    # (``.fun``), so the curve is measured from that fit's result.
    levels = numpy.array(linspace.levels) - region_fit(region_1).fun

    x = numpy.linspace(linspace.start, linspace.stop, len(levels))
    pyplot.plot(x, -levels, "k")
    pyplot.ylim(-8, 0.5)
    pyplot.show()


# Run the linspace scans for the regions selected in main_fits.
main_fits()

# mcmc

In [None]:
import jax
# Number of devices JAX can see; as the cell's last expression it is displayed.
jax.device_count()

In [None]:
def main_mcmc():
    """Call ``dump_region`` for every search region over its (lo, hi) range."""
    scans = dict(
        ewk_high=(2, 8),
        ewk_int=(20, 70),
        ewk_llbb=(0, 3),
        ewk_low=(0, 100),
        ewk_offshell=(0, 40),
        rjr_sr2l_isr=(0, 100),
        rjr_sr2l_low=(0, 100),
        str_src_12_31=(0, 20),
        str_src_12_61=(0, 40),
        str_src_31_81=(0, 40),
        str_src_81=(0, 100),
        str_srhigh_12_301=(0, 60),
        str_srhigh_301=(0, 20),
        str_srlow_101_201=(0, 80),
        str_srlow_101_301=(0, 80),
        str_srlow_12_81=(0, 35),
        str_srlow_301=(0, 20),
        str_srmed_101=(0, 100),
        str_srmed_12_101=(20, 90),
        str_srzhigh=(0, 20),
        str_srzlow=(0, 60),
        str_srzmed=(0, 40),
    )

    for region_name, (low, high) in scans.items():
        print(region_name)
        dump_region(region_name, low, high)


def dump_region(name, lo, hi, nbins=50, *, step_size=None, prob_eye=None):
    """Run the mixed MCMC fit for one region, print diagnostics and plot.

    NOTE: the signal-scan cell further down defines another ``dump_region``.
    The most recently executed definition wins, so re-run this cell before
    calling ``main_mcmc`` again.

    Args:
        name: region sub-directory under BASEPATH.
        lo, hi: range passed to ``fit_mcmc_mix.fit``.
        nbins: number of bins passed to ``fit_mcmc_mix.fit``.
        step_size, prob_eye: sampler settings. When None, regions whose name
            starts with "rjr" use 0.1 / 0.01 and all others use 0.5 / 0.1.
    """
    dir_region = os.path.join(BASEPATH, name)
    region_1 = region.Region.load(dir_region)

    if name.startswith("rjr"):
        default_step_size, default_prob_eye = 0.1, 0.01
    else:
        default_step_size, default_prob_eye = 0.5, 0.1
    if step_size is None:
        step_size = default_step_size
    if prob_eye is None:
        prob_eye = default_prob_eye

    mix = fit_mcmc_mix.fit(
        region_1,
        nbins,
        (lo, hi),
        seed=0,
        nsamples=10_000,
        nrepeats=8,
        step_size=step_size,
        prob_eye=prob_eye,
    )

    neff = mcmc_core.n_by_fit(mix).sum()
    nrepeats = mix.nrepeats
    nsamples = mix.nsamples
    ntotal = nrepeats * nsamples
    total = numpy.sum(mix.yields)
    print("acceptance: %.2f (%d / %d)" % (total / ntotal, total, ntotal))

    # With no yield at all, the efficiency ratio divides by zero and the
    # log-normalised plot below would be all NaN, so stop here.
    if not total > 0:
        print("no yield in range; skipping efficiency and plot")
        return

    print(
        "efficiency: %.2f (%.1f / %.1f)"
        % (nrepeats * neff / total, neff, total / nrepeats)
    )

    edges = numpy.linspace(*mix.range_, len(mix.yields) + 1)
    # Place each sample at its bin centre, so no sample lies on a bin edge.
    centers = 0.5 * (edges[:-1] + edges[1:])
    y = numpy.array(mix.yields)
    # Log of yields normalised to the peak; the floor avoids log(0) for empty bins.
    weight = numpy.log(numpy.maximum(y / y.max(), 1e-300))
    pyplot.hist(
        centers,
        weights=weight,
        range=mix.range_,
        bins=len(edges) - 1,
        histtype="step",
        color="r",
        lw=2,
    )
    pyplot.ylim(-8, 0.5)
    pyplot.xlim(*mix.range_)
    pyplot.show()
    
# Run the MCMC scans over all regions; the IPython %time magic reports timing.
%time main_mcmc()

# signal scan

In [None]:
def main_signal():
    """Call ``dump_region`` (signal scan) for every region over its (lo, hi) range."""
    scans = dict(
        ewk_high=(0, 10),
        ewk_int=(0, 40),
        ewk_llbb=(0, 10),
        ewk_low=(0, 30),
        ewk_offshell=(0, 30),
        rjr_sr2l_isr=(0, 40),
        rjr_sr2l_low=(0, 40),
        str_src_12_31=(0, 15),
        str_src_12_61=(0, 20),
        str_src_31_81=(0, 25),
        str_src_81=(0, 40),
        str_srhigh_12_301=(0, 25),
        str_srhigh_301=(0, 12),
        str_srlow_101_201=(0, 25),
        str_srlow_101_301=(0, 35),
        str_srlow_12_81=(0, 30),
        str_srlow_301=(0, 20),
        str_srmed_101=(0, 40),
        str_srmed_12_101=(0, 40),
        str_srzhigh=(0, 15),
        str_srzlow=(0, 40),
        str_srzmed=(0, 30),
    )

    for region_name, (low, high) in scans.items():
        print(region_name)
        dump_region(region_name, low, high)


def dump_region(name, lo, hi, nbins=10):
    """Run the signal scan for one region, print the result and plot it.

    NOTE: this redefines the ``dump_region`` from the MCMC cell above. Re-run
    that cell before calling ``main_mcmc`` again.

    Args:
        name: region sub-directory under BASEPATH.
        lo, hi: scan bounds passed to ``fit_signal.fit``.
        nbins: ``fit_signal.fit`` is called with ``nbins + 1``.
    """
    region_1 = region.Region.load(os.path.join(BASEPATH, name))

    sig = fit_signal.fit(region_1, lo, hi, nbins + 1)
    print(sig)

    # Shift the curve so its minimum is 0. nanmin keeps one failed (NaN) scan
    # point from turning the whole curve into NaN.
    levels = numpy.array(sig.levels)
    levels -= numpy.nanmin(levels)

    x = numpy.linspace(sig.start, sig.stop, len(levels))
    pyplot.plot(x, -levels, "k")
    pyplot.ylim(-8, 0.5)
    pyplot.show()


# Run the signal scans over all regions.
main_signal()