In [None]:
import os
import json

import numpy
import matplotlib
from matplotlib import pyplot

from discohisto import (
    fit_normal,
    fit_cabinetry,
    fit_cabinetry_post,
    fit_linspace,
    fit_mcmc_mix,
    fit_mcmc_tfp_ham,
    region,
)

In [None]:
def plot_fits(
    region_name,
    cabinetry_class=fit_cabinetry.FitCabinetry,
    cabinetry_post_class=fit_cabinetry_post.FitCabinetryPost,
    normal_class=fit_normal.FitNormal,
    linspace_class=fit_linspace.FitLinspace,
):
    """Print a summary of, and plot, all saved fits for one signal region.

    Loads the region from ``BASENAME/region_name`` and the fits from
    ``BASENAME/region_name/fit``, prints the reported vs. workspace observed
    count, and draws on a single axis:

    * reported background estimate (purple star, y = -5),
    * cabinetry post-fit yield +- error (green circle, y = -4.5),
    * cabinetry pre-fit yield +- error (mahogany circle, y = -4),
    * normal-fit yield +- error (black diamond, y = -3),
    * log density from the linspace scan (blue curve),
    * log density from the MCMC histogram (red steps).

    Reads the module globals ``BASENAME`` and ``REPORTED`` (set by ``main``).

    Raises:
        FileNotFoundError: propagated from the fit loaders, or raised when no
            MCMC fit variant can be loaded (so ``main`` can skip the region).
        KeyError: if ``region_name`` or one of its fields is missing from
            ``REPORTED``.
    """
    print(region_name)
    region_dir = os.path.join(BASENAME, region_name)
    path = os.path.join(region_dir, "fit")

    region_i = region.Region.load(region_dir)
    n_region = _get_n_region(region_i)

    reported = REPORTED[region_name]
    n_reported = reported["n"]
    # Flag loudly when the workspace disagrees with the published observed count.
    mismatch = ", !!!!!!!!!!!!!!" * (n_region != n_reported)
    print("n = %d (%d%s)" % (n_reported, n_region, mismatch))

    bkg = reported["bkg"]
    # bkg_hi / bkg_lo are added to bkg, so they appear to be signed offsets
    # (bkg_lo presumably negative) — TODO confirm against reported.json.
    bkg_hi = reported["bkg_hi"]
    bkg_lo = reported["bkg_lo"]

    cabinetry = cabinetry_class.load(path)
    cabinetry_post = cabinetry_post_class.load(path)
    normal = normal_class.load(path)
    linspace = linspace_class.load(path)
    mcmc = _load_mcmc(path)

    figure, axis = pyplot.subplots(tight_layout=True)

    # TODO normalize to area, not maximum

    # Point estimates with their intervals, stacked below the density curves.
    _plot_interval(
        axis, -5, bkg, bkg + bkg_hi, bkg + bkg_lo,
        "xkcd:barney purple", "*",
    )
    _plot_interval(
        axis, -4,
        cabinetry.yield_pre,
        cabinetry.yield_pre - cabinetry.error_pre,
        cabinetry.yield_pre + cabinetry.error_pre,
        "xkcd:mahogany", "o",
    )
    _plot_interval(
        axis, -4.5,
        cabinetry_post.yield_post,
        cabinetry_post.yield_post - cabinetry_post.error_post,
        cabinetry_post.yield_post + cabinetry_post.error_post,
        "xkcd:green", "o",
    )
    _plot_interval(
        axis, -3,
        normal.yield_linear,
        normal.yield_linear - normal.error_linear,
        normal.yield_linear + normal.error_linear,
        "k", "D",
    )

    # Linspace scan: one density value per evenly spaced grid point.
    linspace_y = _linspace_density(linspace)
    linspace_x = numpy.linspace(linspace.start, linspace.stop, len(linspace.levels))
    axis.plot(linspace_x, _safe_log(linspace_y), "b", linewidth=2)

    # MCMC histogram rebinned to 50 bins (_rebin needs len(mcmc.yields) to be
    # a multiple of 50); the extra final point lets steps-post draw the last bin.
    mcmc_y = _mcmc_density(mcmc, 50)
    mcmc_x = numpy.linspace(*mcmc.range_, len(mcmc_y))
    axis.plot(
        mcmc_x,
        _safe_log(mcmc_y),
        "r",
        linewidth=2,
        drawstyle="steps-post",
    )

    axis.set_xlim(
        min(linspace.start, mcmc.range_[0]),
        max(linspace.stop, mcmc.range_[1]),
    )
    axis.set_ylim(-8, 0.5)
    axis.set_title(region_name)
    axis.set_xlabel("yield")
    axis.set_ylabel("log density")

    pyplot.show()
    # Release the figure; this is called once per region in a loop.
    pyplot.close(figure)


def _plot_interval(axis, y, center, bound_a, bound_b, color, marker):
    """Draw a horizontal bar from bound_a to bound_b at height y, with an open marker at center."""
    axis.plot([bound_a, bound_b], [y, y], color=color, linewidth=2)
    axis.scatter(
        [center],
        [y],
        s=10 ** 2,
        color=color,
        marker=marker,
        linewidth=2,
        facecolor="w",
        zorder=2.2,  # draw markers above the bars and curves
    )


def _load_mcmc(path):
    """Load an MCMC fit from path, trying each known MCMC fit class.

    If several variants exist, the last one in the list wins (as before).

    Raises:
        FileNotFoundError: if no MCMC variant can be loaded from path.
    """
    mcmc = None
    for mcmc_class in (fit_mcmc_mix.FitMcmcMix, fit_mcmc_tfp_ham.FitMcmcTfpHam):
        try:
            mcmc = mcmc_class.load(path)
        except FileNotFoundError:
            pass
    if mcmc is None:
        raise FileNotFoundError("no MCMC fit found in %r" % (path,))
    return mcmc
    
    
def _rebin(array, len_new):
    return numpy.reshape(array, (len_new, -1)).sum(axis=-1)

    
def _safe_log(x):
    is_zero = x == 0
    return numpy.where(
        is_zero,
        -numpy.inf,
        numpy.log(x + is_zero),
    )

def _get_n_region(reg):
    sr_name = reg.signal_region_name
    observations = reg.workspace

    for obs in reg.workspace["observations"]:
        if obs["name"] == sr_name:
            return obs["data"][0]

    raise ValueError(sr_name)


def _linspace_density(linspace):
    levels = numpy.array(linspace.levels)
    y = numpy.exp(levels.min() - levels)
    norm = numpy.trapz(y, dx=(linspace.stop - linspace.start) / (len(levels) - 1))
    return y / norm


def _mcmc_density(mcmc, nbins):
    """Histogram ``mcmc.yields`` into ``nbins`` bins, normalized to unit area.

    The bins evenly span ``mcmc.range_``. The last density value is repeated
    once, so the result has ``nbins + 1`` entries and can be drawn with
    ``drawstyle="steps-post"`` against ``nbins + 1`` edges.
    """
    counts = _rebin(mcmc.yields, nbins)
    low, high = mcmc.range_[0], mcmc.range_[1]
    bin_width = (high - low) / nbins
    density = counts / (counts.sum() * bin_width)
    return numpy.concatenate([density, density[-1:]])

In [None]:
def load_reported(basename=None):
    """Load and return the parsed contents of ``reported.json``.

    Args:
        basename: directory containing ``reported.json``. Defaults to the
            module global ``BASENAME`` (set by ``main``), preserving the
            original no-argument call.

    Returns:
        The decoded JSON object.
    """
    if basename is None:
        basename = BASENAME
    path = os.path.join(basename, "reported.json")
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(path, encoding="utf-8") as file_:
        return json.load(file_)

In [None]:
def main(search, searches_dir="/home/tombs/Cambridge/interval-evidence/searches"):
    """Print summaries and plot fits for every region of one search.

    Sets the module globals ``BASENAME`` and ``REPORTED``, which ``plot_fits``
    reads, then calls ``plot_fits`` for each key of ``reported.json``. Regions
    whose fit files are missing are reported and skipped.

    Args:
        search: name of the search directory under ``searches_dir``.
        searches_dir: root directory holding all searches. The default is the
            original author's local checkout; pass your own path elsewhere.
    """
    # disgusting hack TODO: pass these explicitly instead of via globals.
    global BASENAME, REPORTED

    # The trailing "" keeps the trailing separator the original path had.
    BASENAME = os.path.join(searches_dir, search, "")
    REPORTED = load_reported()

    print("#", search)
    for sr_name in REPORTED:
        try:
            plot_fits(sr_name)
        except FileNotFoundError as e:
            print("!!! missing", sr_name, e)

In [None]:
# Every search to summarize; each gets its own "# <search>" heading in the output.
SEARCHES = [
    "atlas_susy_1Lbb_2020",
    "atlas_susy_1Ljets_2021",
    "atlas_susy_2hadtau_2020",
    "atlas_susy_2L0J_2019",
    "atlas_susy_2Ljets_2022",
    "atlas_susy_3L_2021",
    "atlas_susy_3Lresonance_2020",
    "atlas_susy_3LRJmimic_2020",
    "atlas_susy_3Lss_2019",
    "atlas_susy_4L_2021",
    "atlas_susy_compressed_2020",
    "atlas_susy_DVmuon_2020",
    "atlas_susy_hb_2019",
    "atlas_susy_jets_2021",
]

for search in SEARCHES:
    main(search)