In [None]:
import os
import json
import functools

import numpy
import matplotlib
from matplotlib import pyplot
import scipy.special
import scipy.optimize

from discohisto import (
    fit_normal,
    fit_cabinetry,
    fit_cabinetry_post,
    fit_linspace,
    fit_mcmc_mix,
    fit_mcmc_tfp_ham,
    region,
    limit,
    stats,
)

In [None]:
# Root directory of the inputs. load_searches treats every sub-directory as one
# search, and load_reported reads "<SEARCHES_PATH>/<search>/reported.json".
SEARCHES_PATH = "searches/"

In [None]:
@functools.cache
def load_searches():
    """Return the sorted names of the directories found in SEARCHES_PATH.

    Entries that are not directories are ignored.

    The result is memoised by ``functools.cache``: the directory is scanned
    only on the first call, and every later call returns the very same list
    object, so callers must not mutate it.
    """
    # The context manager releases the directory handle even when an entry
    # check raises part-way through the scan.
    with os.scandir(SEARCHES_PATH) as entries:
        return sorted(entry.name for entry in entries if entry.is_dir())


def load_reported(search):
    """Parse and return ``reported.json`` of the given search directory.

    Args:
        search: name of a directory inside SEARCHES_PATH.

    Returns:
        The decoded JSON document.
    """
    reported_path = os.path.join(SEARCHES_PATH, search, "reported.json")
    with open(reported_path) as stream:
        return json.load(stream)


def _get_n_region(reg):
    sr_name = reg.signal_region_name
    observations = reg.workspace

    for obs in reg.workspace["observations"]:
        if obs["name"] == sr_name:
            return obs["data"][0]

    raise ValueError(sr_name)


def _load_mcmc_limits(path, *, suffix):
    """Load the MCMC limit scan stored under ``path``.

    Both MCMC variants ("mix", then "tfp_ham") are tried with the file
    suffix ``_mcmc_<variant>_<suffix>``. A variant whose file is missing is
    skipped. If both variants are present, the one tried last is returned.

    Raises:
        AssertionError: if neither variant could be loaded.
    """
    found = None
    for variant in ("mix", "tfp_ham"):
        variant_suffix = "_mcmc_%s_%s" % (variant, suffix)
        try:
            found = limit.LimitScan.load(path, suffix=variant_suffix)
        except FileNotFoundError:
            # no scan saved for this variant; keep whatever we already have
            continue
    assert found is not None
    return found


def _load_limit(limit_dir, suffix):
    """Load a ``LimitScan`` and check that levels 6 and 7 are -2 and -3.

    The rest of the notebook reads ``points[6]`` and ``points[7]`` and files
    them as the "_obs_2" and "_obs_3" columns, so the level layout is
    verified here once per load.

    Raises:
        AssertionError: if ``levels[6:8]`` is not ``[-2, -3]``; the offending
            slice is the assertion message.
    """
    scan = limit.LimitScan.load(limit_dir, suffix=suffix)
    checked_levels = scan.levels[6:8]
    assert checked_levels == [-2, -3], checked_levels
    return scan


def _limit_logl(lim):
    return numpy.log(numpy.mean(lim.integral_zero))

In [None]:
def load_frame():
    """Gather reported numbers, fit results and limits for every region.

    Walks every search from ``load_searches`` and every region named in that
    search's ``reported.json``; for each one the region, its fits and its
    limit scans are loaded from ``<SEARCHES_PATH>/<search>/<region>``.

    Returns:
        dict mapping column name to a ``numpy`` array. Every column is filled
        once per (search, region) pair in the same order, so index ``i``
        refers to the same region in all columns.

    Raises:
        AssertionError: if a limit scan does not have levels -2 and -3 at
            indices 6 and 7.
    """
    # One list per output column; insertion order fixes the key order of the
    # returned dict.
    columns = {
        name: []
        for name in (
            # labels
            "search_",
            "region_",
            # reported
            "reported_n",
            "reported_bkg",
            "reported_bkg_hi",
            "reported_bkg_lo",
            "reported_s95obs",
            "reported_s95exp",
            "reported_s95exp_hi",
            "reported_s95exp_lo",
            "region_n",
            # fits
            "fit_cabinetry_bkg",
            "fit_cabinetry_err",
            "fit_cabinetry_post_bkg",
            "fit_cabinetry_post_err",
            # limits
            "limit_cabinetry_obs_2",
            "limit_cabinetry_obs_3",
            "limit_cabinetry_log_like",
            "limit_cabinetry_post_obs_2",
            "limit_cabinetry_post_obs_3",
            "limit_normal_obs_2",
            "limit_normal_obs_3",
            "limit_normal_log_like",
            "limit_normal_log_obs_2",
            "limit_normal_log_obs_3",
            "limit_normal_log_log_like",
            "limit_delta_obs_2",
            "limit_delta_obs_3",
            "limit_delta_log_like",
            "limit_linspace_obs_2",
            "limit_linspace_obs_3",
            "limit_linspace_log_like",
            "limit_mcmc_obs_2",
            "limit_mcmc_obs_3",
            "limit_mcmc_log_like",
        )
    }

    def add(name, value):
        columns[name].append(value)

    def add_limit_points(label, scan):
        # Indices 6 and 7 are the -2 and -3 levels; every scan passed in here
        # has had that layout asserted beforehand.
        add("limit_%s_obs_2" % label, scan.points[6][-1])
        add("limit_%s_obs_3" % label, scan.points[7][-1])

    for search in load_searches():
        reported = load_reported(search)
        for region_name in reported:
            add("search_", search)
            add("region_", region_name)

            # reported
            reported_reg = reported[region_name]
            n_observed = reported_reg["n"]
            add("reported_n", n_observed)
            for key in (
                "bkg",
                "bkg_hi",
                "bkg_lo",
                "s95obs",
                "s95exp",
                "s95exp_hi",
                "s95exp_lo",
            ):
                add("reported_" + key, reported_reg[key])

            # region
            region_dir = os.path.join(SEARCHES_PATH, search, region_name)
            add("region_n", _get_n_region(region.Region.load(region_dir)))

            # standard fits
            fit_dir = os.path.join(region_dir, "fit")

            cabinetry_fit = fit_cabinetry.FitCabinetry.load(fit_dir)
            add("fit_cabinetry_bkg", cabinetry_fit.yield_pre)
            add("fit_cabinetry_err", cabinetry_fit.error_pre)

            cabinetry_post_fit = fit_cabinetry_post.FitCabinetryPost.load(fit_dir)
            add("fit_cabinetry_post_bkg", cabinetry_post_fit.yield_post)
            add("fit_cabinetry_post_err", cabinetry_post_fit.error_post)

            mu_delta = fit_normal.FitNormal.load(fit_dir).yield_linear

            # limits
            limit_dir = os.path.join(fit_dir, "limit")

            # (label, has a log-likelihood column), listed in load order
            for label, has_log_like in (
                ("cabinetry", True),
                ("cabinetry_post", False),
                ("linspace", True),
                ("normal", True),
                ("normal_log", True),
            ):
                scan = _load_limit(limit_dir, "_%s_observed" % label)
                add_limit_points(label, scan)
                if has_log_like:
                    add("limit_%s_log_like" % label, _limit_logl(scan))

            scan = limit.LimitScanDelta.load(limit_dir, suffix="_observed")
            assert scan.levels[6:8] == [-2, -3], scan.levels[6:8]
            add_limit_points("delta", scan)
            add(
                "limit_delta_log_like",
                stats.poisson_log_minus_max(n_observed, mu_delta),
            )

            scan = _load_mcmc_limits(limit_dir, suffix="observed")
            assert scan.levels[6:8] == [-2, -3], scan.levels[6:8]
            add_limit_points("mcmc", scan)
            add("limit_mcmc_log_like", _limit_logl(scan))

    return {name: numpy.array(values) for name, values in columns.items()}
        

FRAME = load_frame()

In [None]:
# Consistency check: the counts quoted in reported.json ("reported_n") should
# match the counts read from each region's workspace ("region_n"). Prints True
# when every row agrees.
print(numpy.array_equal(FRAME["reported_n"], FRAME["region_n"]))

In [None]:
# Show which search directories were picked up from SEARCHES_PATH.
print(load_searches())

# Compare fitted backgrounds

In [None]:
# TODO error bars

In [None]:
def plot_bkgs():
    """Scatter the ratio fit_cabinetry_bkg / reported_bkg, one point per row.

    Reads the two columns from the module-level FRAME. The y axis is fixed to
    [0, 2], so ratios outside that window are not visible.
    """
    reported_bkg = FRAME["reported_bkg"]
    fit_bkg = FRAME["fit_cabinetry_bkg"]

    ratio = fit_bkg / reported_bkg
    # offset by one half so that row i is drawn in the middle of [i, i + 1]
    row = numpy.arange(len(ratio)) + 0.5

    pyplot.scatter(row, ratio, lw=0, s=2, marker=",")

    pyplot.xlabel("FRAME row")
    pyplot.ylabel("fit_cabinetry_bkg / reported_bkg")
    pyplot.ylim(0, 2)
    pyplot.show()
    
plot_bkgs()

In [None]:
def plot_bkgs_post():
    """Scatter the ratio fit_cabinetry_post_bkg / reported_bkg, one point per row.

    Reads the two columns from the module-level FRAME. The y axis is fixed to
    [0, 2], so ratios outside that window are not visible.
    """
    reported_bkg = FRAME["reported_bkg"]
    fit_post_bkg = FRAME["fit_cabinetry_post_bkg"]

    ratio = fit_post_bkg / reported_bkg
    # offset by one half so that row i is drawn in the middle of [i, i + 1]
    row = numpy.arange(len(ratio)) + 0.5

    pyplot.scatter(row, ratio, lw=0, s=2, marker=",")

    pyplot.xlabel("FRAME row")
    pyplot.ylabel("fit_cabinetry_post_bkg / reported_bkg")
    pyplot.ylim(0, 2)
    pyplot.show()
    
plot_bkgs_post()

# Inspect mean log likelihoods

In [None]:
def print_mean_log_likes():
    """Print the mean of each method's ``limit_<method>_log_like`` column.

    One line per method: the name, the mean over all FRAME rows, and the
    difference between that mean and the largest mean (0 for the best one).
    """
    methods = ("cabinetry", "normal", "normal_log", "delta", "linspace", "mcmc")
    means = {
        method: FRAME["limit_%s_log_like" % method].mean() for method in methods
    }

    best = max(means.values())

    for method, mean in means.items():
        print("%15s %7.4f %7.4f" % (method, mean, mean - best))

print_mean_log_likes()    

In [None]:
def print_optimized_mixture():
    """Fit mixture weights over four limit methods and report the result.

    The mixture combines the per-row log-likelihood columns of "cabinetry",
    "normal_log", "linspace" and "mcmc" with softmax weights, and the weights
    are chosen to maximise the mean log-likelihood of the mixture.

    Prints the optimiser result, the fitted weights, the mean log-likelihood
    at the optimum and at fixed weights (0.6, 0.4, 0, ~0), then plots the mean
    log-likelihood while the weight is moved between the first two methods.
    """
    part_names = ("cabinetry", "normal_log", "linspace", "mcmc")
    # one row per FRAME entry, one column per mixture part
    parts = numpy.stack(
        [FRAME["limit_%s_log_like" % name] for name in part_names]
    ).T

    # Logits are only defined up to a common shift. Remove that freedom by
    # pinning the last logit to 0 and exposing only the others.
    def mean_log_like(free_logits):
        log_weights = log_softmax(numpy.append(free_logits, 0.0))
        return scipy.special.logsumexp(parts + log_weights, axis=1).mean()

    def loss(free_logits):
        return -mean_log_like(free_logits)

    result = scipy.optimize.minimize(loss, [0.0] * (len(part_names) - 1))
    print(result)

    result_weights = numpy.exp(log_softmax(numpy.append(result.x, 0.0)))
    print("weights", result_weights)

    print("%15s %7.4f _______" % ("mixture", mean_log_like(result.x)))
    # The large offset makes the pinned zero logit negligible next to the
    # chosen ones, so the last part gets essentially no weight.
    fixed_logits = _safe_log([0.6, 0.4, 0]) + 300
    print("%15s %7.4f _______" % (".6, .4", mean_log_like(fixed_logits)))

    # plot a scan
    fractions = numpy.linspace(0, 1, 100)
    scan = [
        # offset again washes out the pinned zero logit
        mean_log_like(_safe_log([fraction, 1 - fraction, 0]) + 700)
        for fraction in fractions
    ]
    pyplot.plot(fractions, scan)
    pyplot.show()
    
    
def log_softmax(x):
    """Return ``log(exp(x_i) / sum_j exp(x_j))`` for a numpy array ``x``.

    The maximum is subtracted before exponentiating so that large entries do
    not overflow.
    """
    shifted = x - x.max()
    log_norm = numpy.log(numpy.exp(shifted).sum())
    return shifted - log_norm


def _safe_log(x):
    x = numpy.asarray(x)
    iszero = x == 0
    return numpy.where(iszero, -numpy.inf, numpy.log(x + iszero))


print_optimized_mixture()

# Compare observed limits

In [None]:
def plot_limits(label):
    """Scatter our observed limits against the reported observed limits.

    Both axes are logarithmic and span the same fixed range, with a faint
    y = x guide line for reference.

    Args:
        label: method name used to pick the FRAME columns
            ``limit_<label>_obs_2`` (red) and ``limit_<label>_obs_3`` (blue).
    """
    axis_lo, axis_hi = 1.5, 400

    reported_obs = FRAME["reported_s95obs"]
    name_2 = "limit_%s_obs_2" % label
    name_3 = "limit_%s_obs_3" % label
    label_obs_2 = FRAME[name_2]
    label_obs_3 = FRAME[name_3]

    pyplot.scatter(
        reported_obs, label_obs_2, color="r", lw=0, s=2, marker=",", label=name_2
    )
    pyplot.scatter(
        reported_obs, label_obs_3, color="b", lw=0, s=2, marker=",", label=name_3
    )
    # y = x guide. It starts at the lower axis bound rather than 0, because 0
    # has no position on a logarithmic axis.
    pyplot.plot([axis_lo, axis_hi], [axis_lo, axis_hi], "k", alpha=0.2)

    pyplot.yscale("log")
    pyplot.xscale("log")
    pyplot.xlim(axis_lo, axis_hi)
    pyplot.ylim(axis_lo, axis_hi)

    pyplot.xlabel("reported_s95obs")
    pyplot.ylabel("limit_%s_obs_2 / _obs_3" % label)
    # the scatter markers are tiny; enlarge them in the legend only
    pyplot.legend(markerscale=5)

    pyplot.show()


In [None]:
plot_limits("cabinetry")

In [None]:
plot_limits("cabinetry_post")

In [None]:
plot_limits("normal")

In [None]:
plot_limits("normal_log")

In [None]:
plot_limits("delta")

In [None]:
plot_limits("linspace")

In [None]:
plot_limits("mcmc")

In [None]:
2 / numpy.log(2), numpy.exp(2)

In [None]:
3 / numpy.log(2), numpy.exp(3)

# Inspect anomalous differences

In [None]:
def print_anomalies(label):
    """Print the rows whose ``limit_<label>_obs_2`` is far from the reported limit.

    A row is printed when the absolute natural log of
    ``limit_<label>_obs_2 / reported_s95obs`` exceeds 0.3. Each line shows the
    search, the region, the reported limit, our limit and that log ratio.
    Rows whose log ratio is NaN are never printed.
    """
    rows = zip(
        FRAME["search_"],
        FRAME["region_"],
        FRAME["reported_s95obs"],
        FRAME["limit_%s_obs_2" % label],
    )

    for search_name, region_name, reported, ours in rows:
        log_ratio = numpy.log(ours / reported)
        if abs(log_ratio) > 0.3:
            print(
                "%28s %28s %6.1f %6.1f %6.1f"
                % (search_name, region_name, reported, ours, log_ratio)
            )

In [None]:
print_anomalies("cabinetry")

In [None]:
print_anomalies("linspace")