In [None]:
import os
from functools import partial

import numpy
import scipy
from matplotlib import pyplot

import os

import pyhf
import cabinetry

from discohist import (
    fit_cabinetry,
    fit_cabinetry_post,
    fit_linspace,
    fit_normal,
    fit_signal,
    region,
    serial,
    blind,
    fit_mcmc_mix,
    mcmc_core,
)

from discohist.region_properties import region_properties

In [None]:
# Root directory of the ``atlas_susy_1Ljets_2021`` search: every region is a
# subdirectory of it and ``bkg.json.gz`` sits at the top level.  Set
# $DISCOHIST_SEARCH_PATH to run on a machine other than the original author's.
BASEPATH = os.environ.get(
    "DISCOHIST_SEARCH_PATH",
    "/home/tombs/Cambridge/interval-evidence/searches/atlas_susy_1Ljets_2021/",
)

# Fits: linspace scans compared against the global optimum

In [None]:
def main():
    """Run the linspace comparison (``test``) for every enabled region.

    Each region is scanned over its own ``(lo, hi)`` range.  The anchor table
    is forwarded to ``test`` unchanged as ``region_name_to_anchors``.
    """
    # Regions currently switched off, with the scan ranges last used:
    #     "SR2JBVEM_meffInc30_gluino": (5, 40),
    #     "SR2JBVEM_meffInc30_squark": (40, 160),
    #     "SR4JhighxBVEM_meffInc30": (0, 30),
    #     "SR4JlowxBVEM_meffInc30": (0, 30),
    scans = [
        ("SR6JBVEM_meffInc30_gluino", (0, 14)),
        ("SR6JBVEM_meffInc30_squark", (2, 18)),
    ]

    anchors = {
        "SR2JBVEM_meffInc30_gluino": [38.0],
        "SR4JlowxBVEM_meffInc30": [25.0],
        "SR6JBVEM_meffInc30_gluino": [5.0, 1.0, 13.0],
        "SR6JBVEM_meffInc30_squark": [9.0, 2.0, 16.25],
    }

    for region_name, (lo, hi) in scans:
        print(region_name)
        test(region_name, lo, hi, region_name_to_anchors=anchors)


def test(
    name,
    lo,
    hi,
    *,
    nbins=200,
    region_name_to_anchors=None,
    run_reference_fits=False,
):
    """Run a linspace scan for one region and plot it relative to the optimum.

    Args:
        name: region directory name under ``BASEPATH``.
        lo, hi: scan limits passed to ``fit_linspace.fit``.
        nbins: ``nbins + 1`` is forwarded to ``fit_linspace.fit``.
        region_name_to_anchors: optional mapping from region name to the
            ``anchors`` argument of ``fit_linspace.fit``; regions missing from
            it get ``anchors=None``.
        run_reference_fits: if true, also run and print the cabinetry,
            post-fit cabinetry and normal fits (these used to sit behind a
            permanently disabled ``if 0:``).

    Raises:
        AssertionError: if the global L-BFGS-B minimisation does not succeed.
    """
    # ``import scipy`` alone does not guarantee that ``scipy.optimize`` is
    # loaded on older SciPy releases, so import the submodule explicitly.
    import scipy.optimize

    if region_name_to_anchors is None:
        region_name_to_anchors = {}

    region_1 = region.Region.load(os.path.join(BASEPATH, name))

    if run_reference_fits:
        print(fit_cabinetry.fit(region_1))
        print(fit_cabinetry_post.fit(region_1))
        print(fit_normal.fit(region_1))

    linspace = fit_linspace.fit(
        region_1,
        lo,
        hi,
        nbins + 1,
        anchors=region_name_to_anchors.get(name),
    )
    print(linspace)

    # Global minimum of the region objective, used as the reference level.
    properties = region_properties(region_1)
    optimum = scipy.optimize.minimize(
        properties.objective_value_and_grad,
        properties.init,
        bounds=properties.bounds,
        jac=True,
        method="L-BFGS-B",
    )
    assert optimum.success, optimum.message
    print("fun:", optimum.fun)

    levels = numpy.array(linspace.levels) - optimum.fun

    x = numpy.linspace(linspace.start, linspace.stop, len(levels))
    fig, ax = pyplot.subplots()
    ax.plot(x, -levels, "k")
    ax.set_ylim(-8, 0.5)
    ax.set(title=name, ylabel="-(level - global minimum)")
    pyplot.show()


main()

# MCMC histograms

In [None]:
# ``os``, ``numpy``, ``region``, ``fit_mcmc_mix`` and ``mcmc_core`` are all
# imported in the first cell (the last three from ``discohist``).  This cell
# used to re-import them from the old ``pyhf_stuff`` package.  That rebinds
# ``region`` for every later cell, or raises ModuleNotFoundError when
# ``pyhf_stuff`` is not installed.


def main():
    """Run the MCMC histogram (``dump_region``) for each region over its range."""
    region_name_to_scan = {
        "SR2JBVEM_meffInc30_gluino": (5, 40),
        "SR2JBVEM_meffInc30_squark": (40, 160),
        "SR4JhighxBVEM_meffInc30": (0, 30),
        "SR4JlowxBVEM_meffInc30": (0, 30),
        "SR6JBVEM_meffInc30_gluino": (0, 12),
        "SR6JBVEM_meffInc30_squark": (0, 12),
    }

    for name, (lo, hi) in region_name_to_scan.items():
        print(name)
        dump_region(name, lo, hi)


def dump_region(
    name,
    lo,
    hi,
    nbins=50,
    *,
    seed=0,
    nsamples=10_000,
    nrepeats=8,
    nprocesses=8,
    dump=False,
):
    """Run ``fit_mcmc_mix`` for one region, print diagnostics and plot it.

    Args:
        name: region directory name under ``BASEPATH``.
        lo, hi: histogram range passed to ``fit_mcmc_mix.fit``.
        nbins: number of histogram bins.
        seed, nsamples, nrepeats, nprocesses: forwarded to ``fit_mcmc_mix.fit``
            (defaults are the values previously hard-coded here).
        dump: if true, write the result to ``<region>/fit`` via ``mix.dump``.
    """
    dir_region = os.path.join(BASEPATH, name)
    region_1 = region.Region.load(dir_region)

    mix = fit_mcmc_mix.fit(
        region_1,
        nbins,
        (lo, hi),
        seed=seed,
        nsamples=nsamples,
        nrepeats=nrepeats,
        nprocesses=nprocesses,
    )
    if dump:
        mix.dump(os.path.join(dir_region, "fit"))

    # Diagnostics use the settings stored on the result object.
    neff = mcmc_core.n_by_fit(mix).sum()
    ntotal = mix.nrepeats * mix.nsamples
    yields = numpy.array(mix.yields)
    total = yields.sum()
    print("acceptance: %.2f (%d / %d)" % (total / ntotal, total, ntotal))
    print(
        "efficiency: %.2f (%.1f / %.1f)"
        % (mix.nrepeats * neff / total, neff, total / mix.nrepeats)
    )

    edges = numpy.linspace(*mix.range_, len(yields) + 1)
    peak = yields.max()
    # Normalise to unit height; an all-zero histogram would otherwise give NaNs.
    weights = yields / peak if peak > 0 else yields.astype(float)

    fig, ax = pyplot.subplots()
    ax.hist(
        edges[:-1],
        bins=edges,
        weights=weights,
        histtype="step",
        color="r",
        lw=2,
    )
    ax.set(title=name, ylabel="yield / max yield")
    pyplot.show()


main()

# Signal scans

In [None]:
def main():
    """Run the signal scan (``dump_region``) for each region over its range."""
    scans = (
        ("SR2JBVEM_meffInc30_gluino", 0, 35),
        ("SR2JBVEM_meffInc30_squark", 0, 80),
        ("SR4JhighxBVEM_meffInc30", 0, 25),
        ("SR4JlowxBVEM_meffInc30", 0, 20),
        ("SR6JBVEM_meffInc30_gluino", 0, 12),
        ("SR6JBVEM_meffInc30_squark", 0, 15),
    )
    for region_name, lo, hi in scans:
        print(region_name)
        dump_region(region_name, lo, hi)


def dump_region(name, lo, hi, nbins=10):
    """Run ``fit_signal`` for one region and plot its levels shifted to zero.

    Args:
        name: region directory name under ``BASEPATH``.
        lo, hi: scan limits passed to ``fit_signal.fit``.
        nbins: ``nbins + 1`` is forwarded to ``fit_signal.fit``.
    """
    region_1 = region.Region.load(os.path.join(BASEPATH, name))

    sig = fit_signal.fit(region_1, lo, hi, nbins + 1)
    print(sig)

    # Shift so the lowest scanned level sits at zero.
    levels = numpy.array(sig.levels)
    levels = levels - levels.min()

    x = numpy.linspace(sig.start, sig.stop, len(levels))
    fig, ax = pyplot.subplots()
    ax.plot(x, levels, "k")
    ax.set_ylim(0, 8)
    ax.set(title=name, ylabel="level - min level")
    pyplot.show()


main()

# Try fitting the background-only workspace without bin selection (this fails)

In [None]:
def test_native():
    """Fit the background-only workspace pruned to one SR plus its 2J CRs.

    Loads ``bkg.json.gz`` from ``BASEPATH`` and classifies channels. ``WR``/``TR``
    channels become control regions keyed by jet tag; ``SR`` channels become
    signal regions. The workspace is pruned to ``SR2JBVEM_meffInc30`` plus the
    2J control regions and wrapped in ``blind.Model`` for that signal region.
    It then runs a cabinetry fit and prints the resulting prediction. Per the
    notebook heading, this fit is expected to fail.

    Raises:
        ValueError: for a ``WR``/``TR`` channel whose tag is not 2J/4J/6J.
        AssertionError: for a channel that is neither ``WR``/``TR`` nor ``SR``.
    """
    # ``BASEPATH`` is the notebook-wide search directory; the name ``BASE``
    # used here previously is not defined anywhere (NameError).
    spec = serial.load_json_gz(os.path.join(BASEPATH, "bkg.json.gz"))
    workspace = pyhf.workspace.Workspace(region.clear_poi(spec))

    # Control regions keyed by characters 2-3 of the channel name, e.g.
    # "WR2J..." -> "2J".  Only the 2J set is used below; the rest of the
    # classification acts as a sanity check on the channel naming.
    control_regions = {"2J": set(), "4J": set(), "6J": set()}
    signal_regions = set()
    for name in workspace.channel_slices.keys():
        if name.startswith(("WR", "TR")):
            tag = name[2:4]
            if tag not in control_regions:
                raise ValueError(name)
            control_regions[tag].add(name)
            continue
        assert name.startswith("SR"), name
        signal_regions.add(name)

    sr_name = "SR2JBVEM_meffInc30"
    workspace = region.prune(workspace, sr_name, *control_regions["2J"])

    model = workspace.model()
    data = workspace.data(model)
    model_blind = blind.Model(model, {sr_name})

    # Manual cabinetry fit, then the model prediction using its results.
    prediction = cabinetry.model_utils.prediction(
        model_blind, fit_results=cabinetry.fit.fit(model_blind, data)
    )
    print(prediction)

    index = model_blind.config.channels.index(sr_name)
    print(numpy.sum(prediction.model_yields[index], axis=0))
    print(prediction.total_stdev_model_bins[index])


test_native()