In [None]:
import os
import json
import functools
from types import SimpleNamespace

import numpy
import matplotlib
from matplotlib import pyplot
import scipy.special
import scipy.optimize

from discohisto import (
    fit_normal,
    fit_cabinetry,
    fit_cabinetry_post,
    fit_linspace,
    fit_mcmc_mix,
    fit_mcmc_tfp_ham,
    region,
    limit,
    stats,
)

import report.frame

In [None]:
# Global matplotlib style applied to every figure drawn below.
# NOTE(review): "text.usetex": True hands text rendering to an external LaTeX
# toolchain — confirm one is installed wherever this notebook is executed.
pyplot.rcParams.update(
    {
        "text.usetex": True,
        "font.family": "serif",
        "font.size": 10,
        "figure.facecolor": "w",
    }
)

In [None]:
# Load the results table once; the cells below read its fields as attributes
# of FRAME (for example FRAME.reported_bkg).
# NOTE(review): the path is relative to the working directory — confirm the
# notebook is started from the directory that contains "report/".
FRAME = report.frame.load("report/results.csv")

In [None]:
# Sanity check: prints True only when FRAME.reported_n and FRAME.region_n have
# the same shape and the same elements.
print(numpy.array_equal(FRAME.reported_n, FRAME.region_n))

# Compare fitted backgrounds

In [None]:
len(set(FRAME.search_))

In [None]:
len(FRAME.fit_cabinetry_bkg)

In [None]:
# TODO error bars

In [None]:
def plot_bkgs():
    """Scatter the ratio FRAME.fit_cabinetry_bkg / FRAME.reported_bkg per entry.

    Prints the number of entries, draws one point per entry at x = index + 0.5
    with a faint reference line at ratio 1, saves the figure to
    "validation_frame_plot_bkgs.png", shows it and closes it.
    """
    reported = FRAME.reported_bkg
    fitted = FRAME.fit_cabinetry_bkg

    ratio = fitted / reported
    n_entries = len(ratio)
    positions = 0.5 + numpy.arange(n_entries)

    margins = {
        "top": 0.97,
        "right": 0.995,
        "bottom": 0.05,
        "left": 0.07,
    }
    figure, axis = pyplot.subplots(
        dpi=400,
        figsize=0.7 * numpy.array([10, 3]),
        gridspec_kw=margins,
    )

    print(n_entries)

    axis.scatter(positions, ratio, lw=0, s=16, marker=".", color="xkcd:blue")

    # Reference line where the fitted and reported backgrounds agree.
    axis.axhline(1.0, color="k", alpha=0.1, zorder=0.5)

    axis.set_xlim(-1, n_entries + 1)
    axis.set_ylim(0, 2)
    axis.set_ylabel("fit / reported")

    # The x position is only an entry index, so ticks carry no information.
    axis.set_xticks([])

    figure.savefig("validation_frame_plot_bkgs.png")
    pyplot.show()
    pyplot.close(figure)
    
plot_bkgs()

In [None]:
def plot_bkgs_post():
    """Scatter the ratio FRAME.fit_cabinetry_post_bkg / FRAME.reported_bkg per entry.

    One point is drawn per entry at x = index + 0.5, with the y axis fixed to
    [0, 2]. The figure is created explicitly, labelled, shown and then closed,
    so the function neither draws into a figure left current by other code nor
    leaves its own figure open.
    """
    reported_bkg = FRAME.reported_bkg
    fit_cabinetry_post_bkg = FRAME.fit_cabinetry_post_bkg

    y = fit_cabinetry_post_bkg / reported_bkg
    x = numpy.arange(len(y)) + 0.5

    # Own figure/axis instead of the implicit pyplot "current axes".
    figure, axis = pyplot.subplots()

    axis.scatter(x, y, lw=0, s=2, marker=",")

    axis.set_ylim(0, 2)
    axis.set_ylabel("fit / reported")

    pyplot.show()
    pyplot.close(figure)
    
plot_bkgs_post()

# Inspect mean log likelihoods

In [None]:
def print_mean_logls():
    """Print the mean of each FRAME.limit_<name>_logl field and its gap to the best.

    One line is printed per method: the name, its mean log likelihood, and that
    mean minus the largest mean among all listed methods (so the best method
    shows a gap of zero and the others a negative gap).
    """
    method_names = (
        "cabinetry",
        "normal",
        "normal_log",
        "delta",
        "linspace",
        "mcmc",
    )

    mean_by_method = {}
    for method in method_names:
        logl = getattr(FRAME, "limit_%s_logl" % method)
        mean_by_method[method] = logl.mean()

    best_mean = max(mean_by_method.values())

    for method, mean_logl in mean_by_method.items():
        gap = mean_logl - best_mean
        print("%15s %7.4f %7.4f" % (method, mean_logl, gap))

print_mean_logls()    

In [None]:
def print_optimized_mixture():
    """Optimize mixture weights over four limit methods and report mean log likelihoods.

    The per-entry log likelihoods of the listed FRAME.limit_<name>_logl fields
    are combined as a weighted mixture. The weights are optimized with
    scipy.optimize.minimize to maximize the mean mixture log likelihood.

    Prints the optimizer result, the optimized weights, the optimized mean log
    likelihood, and the mean log likelihood of a fixed 0.6/0.4 mixture of the
    first two components. Finally plots the mean log likelihood of a
    two-component mixture of the first two components as the first weight is
    scanned from 0 to 1.

    Relies on the module-level helpers log_softmax and _safe_log.
    """
    name_to_mixture_part = {
        "cabinetry": FRAME.limit_cabinetry_logl,
        "normal_log": FRAME.limit_normal_log_logl,
        "linspace": FRAME.limit_linspace_logl,
        "mcmc": FRAME.limit_mcmc_logl,
    }
    n_parts = len(name_to_mixture_part)

    # One row per entry, one column per mixture component.
    parts = numpy.stack(list(name_to_mixture_part.values())).T

    def mixture_mean_logl(x):
        # x holds unnormalized log weights; log_softmax normalizes them.
        log_weights = log_softmax(x)
        return scipy.special.logsumexp(parts + log_weights, axis=1).mean()

    # logit coordinates have a shift freedom. Constrain it by setting x[-1]=0
    def loss(x_start):
        x = numpy.append(x_start, 0.0)
        return -mixture_mean_logl(x)

    def fixed_mixture_mean_logl(leading_weights):
        # Evaluate a mixture with explicitly given weights on the leading
        # components and exactly zero weight on the rest. Going through
        # mixture_mean_logl directly avoids the appended zero logit of loss(),
        # so no large offset is needed to suppress the last component.
        weights = numpy.zeros(n_parts)
        weights[: len(leading_weights)] = leading_weights
        return mixture_mean_logl(_safe_log(weights))

    result = scipy.optimize.minimize(
        loss,
        [0.0] * (n_parts - 1)
    )
    print(result)

    result_weights = numpy.exp(log_softmax(numpy.append(result.x, 0.0)))
    print("weights", result_weights)

    print("%15s %7.4f _______" % ("mixture", -loss(result.x)))
    print("%15s %7.4f _______" % (".6, .4", fixed_mixture_mean_logl([0.6, 0.4])))

    # plot a scan
    x = numpy.linspace(0, 1, 100)
    y = [fixed_mixture_mean_logl([xi, 1 - xi]) for xi in x]
    pyplot.plot(x, y)
    pyplot.show()
    
    
def log_softmax(x):
    """Return log(softmax(x)), i.e. log(e^xi / sum_j e^xj) for each element.

    Parameters:
        x: array-like of unnormalized log weights (a list, tuple or ndarray).

    Returns:
        numpy array of the same shape whose exponentials sum to one.
    """
    x = numpy.asarray(x)
    # Subtracting the maximum does not change the result but keeps exp() from
    # overflowing for large inputs.
    s = x - x.max()
    return s - numpy.log(numpy.exp(s).sum())


def _safe_log(x):
    x = numpy.asarray(x)
    iszero = x == 0
    return numpy.where(iszero, -numpy.inf, numpy.log(x + iszero))


print_optimized_mixture()

# Compare observed limits

In [None]:
def plot_limits(label, limit_range=(1.5, 400)):
    """Scatter our observed limits against FRAME.reported_s95obs on log-log axes.

    FRAME.limit_<label>_2obs is drawn in red and FRAME.limit_<label>_3obs in
    blue, together with a faint diagonal marking equality.

    Parameters:
        label: method name inserted into the FRAME field names
            "limit_<label>_2obs" and "limit_<label>_3obs".
        limit_range: (low, high) used for both axis ranges and for the
            endpoints of the diagonal. Both values must be positive because
            the axes are logarithmic.
    """
    reported_obs = FRAME.reported_s95obs
    label_2obs = getattr(FRAME, "limit_%s_2obs" % label)
    label_3obs = getattr(FRAME, "limit_%s_3obs" % label)

    low, high = limit_range

    figure, axis = pyplot.subplots(
        dpi=400,
        figsize=numpy.array([4, 4]) * 0.7,
        gridspec_kw={
            "top": 0.97,
            "right": 0.995,
            "bottom": 0.05,
            "left": 0.07,
        },
    )

    axis.scatter(reported_obs, label_2obs, color="xkcd:red", lw=0, s=3 ** 2, marker=".")
    axis.scatter(reported_obs, label_3obs, color="xkcd:blue", lw=0, s=3 ** 2, marker=".")
    # Equality diagonal. Its endpoints sit on the positive axis range rather
    # than at 0, which has no position on a logarithmic axis.
    axis.plot([low, high], [low, high], "k", alpha=0.1, zorder=0.5)

    axis.set_yscale("log")
    axis.set_xscale("log")
    axis.set_xlim(low, high)
    axis.set_ylim(low, high)

    axis.set_xlabel("reported limit")
    axis.set_ylabel("our limit")

    pyplot.show()
    # This function is called once per method; close the high-dpi figure so
    # figures do not accumulate.
    pyplot.close(figure)


In [None]:
plot_limits("cabinetry")

In [None]:
plot_limits("cabinetry_post")

In [None]:
plot_limits("normal")

In [None]:
plot_limits("normal_log")

In [None]:
plot_limits("delta")

In [None]:
plot_limits("linspace")

In [None]:
plot_limits("mcmc")

In [None]:
# Displays the pair (2 / ln 2, e**2).
# NOTE(review): presumably reference values for the "2obs" limits — confirm.
2 / numpy.log(2), numpy.exp(2)

In [None]:
# Displays the pair (3 / ln 2, e**3).
# NOTE(review): presumably reference values for the "3obs" limits — confirm.
3 / numpy.log(2), numpy.exp(3)

# Inspect anomalous differences

In [None]:
def print_anomalies(label):
    """Print entries whose FRAME.limit_<label>_2obs differs strongly from the reported limit.

    For every entry the natural log of (our limit / FRAME.reported_s95obs) is
    computed; entries where its absolute value exceeds 0.3 are printed as
    search, region, reported limit, our limit and the log ratio.
    """
    rows = zip(
        FRAME.search_,
        FRAME.region_,
        FRAME.reported_s95obs,
        getattr(FRAME, "limit_%s_2obs" % label),
    )

    for search_name, region_name, reported, ours in rows:
        log_ratio = numpy.log(ours / reported)
        # A NaN log ratio compares False here and is therefore not printed.
        if abs(log_ratio) > 0.3:
            print(
                "%28s %28s %6.1f %6.1f %6.1f" %
                (search_name, region_name, reported, ours, log_ratio)
            )

In [None]:
print_anomalies("cabinetry")

In [None]:
print_anomalies("linspace")

# Check orderings relative to expected

In [None]:
def print_orderings(label):
    """Report entries where data and limit disagree about being above expectation.

    For each entry and for each of the three expectation variants (central,
    "hi", "lo"), compare whether FRAME.reported_n exceeds the matching
    FRAME.limit_<label>_nexp* value with whether FRAME.limit_<label>_3obs
    exceeds the matching FRAME.limit_<label>_3exp* value. Every disagreement is
    printed as search, region, observed count, expected count, observed limit
    and expected limit. If no disagreement is found, "ALL OK" with the label
    is printed.
    """
    def field(suffix):
        return getattr(FRAME, f"limit_{label}_{suffix}")

    rows = zip(
        FRAME.search_,
        FRAME.region_,
        FRAME.reported_n,
        field("nexp"),
        field("nexp_hi"),
        field("nexp_lo"),
        field("3obs"),
        field("3exp"),
        field("3exp_hi"),
        field("3exp_lo"),
    )

    mismatch_found = False

    for row in rows:
        (
            search_i,
            region_i,
            nobs_i,
            nexp_central,
            nexp_high,
            nexp_low,
            obs_i,
            exp_central,
            exp_high,
            exp_low,
        ) = row

        # Same check for the central, hi and lo expectations, in that order.
        expectation_pairs = (
            (nexp_central, exp_central),
            (nexp_high, exp_high),
            (nexp_low, exp_low),
        )

        for nexp_i, exp_i in expectation_pairs:
            data_above = nobs_i > nexp_i
            limit_above = obs_i > exp_i

            if data_above != limit_above:
                mismatch_found = True
                print(
                    "%28s %28s %6d %6.1f %6.1f %6.1f" %
                    (search_i, region_i, nobs_i, nexp_i, obs_i, exp_i)
                )

    if not mismatch_found:
        print("ALL OK %r" % label)

In [None]:
print_orderings("cabinetry")

In [None]:
print_orderings("normal")

In [None]:
print_orderings("normal_log")

In [None]:
print_orderings("linspace")

In [None]:
print_orderings("delta")

In [None]:
print_orderings("mcmc")