In [1]:
import json
import copy

import numpy
import pyhf

import blind
import jsongz

In [2]:
import jax
import scipy.optimize

# Configure JAX (CPU platform, x64 flag switched on) before selecting it as
# the pyhf computational backend.
jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)
pyhf.set_backend("jax")

In [3]:
# Specification loaded through the local ``jsongz`` helper; it is later handed
# to ``pyhf.workspace.Workspace``. The commented line swaps in the other input
# file.
SPEC = jsongz.load("ins1852821_bkg.json.gz")
# SPEC = jsongz.load("ins1852821_signal.json.gz")

In [4]:
# Name of the channel used as the signal channel throughout the notebook;
# the commented line is an alternative choice.
# SR_NAME = "SR0bvetotight_cuts"
SR_NAME = "SR0ZZloose_cuts"

In [5]:
def get_named(seq, name):
    """Return the first element of ``seq`` whose ``"name"`` entry equals ``name``.

    Raises
    ------
    KeyError
        If no element matches (the missing name is the exception argument).
    """
    not_found = object()
    match = next((entry for entry in seq if entry["name"] == name), not_found)
    if match is not_found:
        raise KeyError(name)
    return match

In [6]:
def add_signal(
    spec,
    channel_name,
    measurement_name="NormalMeasurement",
    poi_name="mu_SIG",
):
    """Attach a zero-yield dummy signal sample and make its normfactor the POI.

    ``spec`` is modified in place and also returned; pass a copy if the
    original must remain untouched.

    Parameters
    ----------
    spec : dict
        Specification with ``"channels"`` and ``"measurements"`` lists.
    channel_name : str
        Name of the channel that receives the signal sample.
    measurement_name : str
        Name of the measurement whose config gets the new parameter and POI.
    poi_name : str
        Name used for the normfactor modifier and the parameter of interest.

    Returns
    -------
    dict
        The same ``spec`` object, after modification.

    Raises
    ------
    KeyError
        If the channel or measurement is not found (raised by ``get_named``).
    """
    # add a dummy signal to its channel
    channel = get_named(spec["channels"], channel_name)

    # Size the dummy yields from an existing sample so the new sample has the
    # same number of bins as its siblings; fall back to a single bin when the
    # channel has no samples yet.
    existing_samples = channel["samples"]
    n_bins = len(existing_samples[0]["data"]) if existing_samples else 1

    channel["samples"].append({
        "data": [0.0] * n_bins,
        "modifiers": [{
            "data": None,
            "name": poi_name,
            "type": "normfactor",
        }],
        "name": "signal",
    })

    # add its modifier name to the measurement and declare it as the POI
    measurement = get_named(spec["measurements"], measurement_name)
    measurement["config"]["parameters"].append({
        "bounds": [[0.0, 2.0]],
        "fixed": False,
        "inits": [1.0],
        "name": poi_name,
    })
    measurement["config"]["poi"] = poi_name

    return spec

In [7]:
def make_model_and_data(
    spec,
    signal_channel_name,
    control_channel_names=("CRZZ_cuts", "CRttZ_cuts"),
):
    """Build a pruned model (signal channel + control channels) and its data.

    The input ``spec`` is deep-copied first, so the caller's object is not
    modified by ``add_signal``.

    Parameters
    ----------
    spec : dict
        Workspace specification.
    signal_channel_name : str
        Channel that receives the dummy signal and is always kept.
    control_channel_names : str or iterable of str
        Additional channels to keep; every other channel is pruned. The
        default reproduces the previously hard-coded selection.

    Returns
    -------
    tuple
        ``(model, data)`` where ``data`` is a numpy array built from
        ``workspace_pruned.data(model)``.
    """
    # A bare string would otherwise be unpacked character by character.
    if isinstance(control_channel_names, str):
        control_channel_names = (control_channel_names,)

    spec = copy.deepcopy(spec)
    spec = add_signal(spec, signal_channel_name)
    workspace = pyhf.workspace.Workspace(spec)

    channels_keep = {signal_channel_name, *control_channel_names}

    # prune() is documented to take a list of channel names; sorting keeps the
    # argument deterministic.
    channels_prune = sorted(workspace.channel_slices.keys() - channels_keep)

    workspace_pruned = workspace.prune(channels=channels_prune)

    model = workspace_pruned.model()
    data = numpy.array(workspace_pruned.data(model))

    return model, data


# Module-level model and data shared by the test cells that follow.
MODEL, DATA = make_model_and_data(SPEC, SR_NAME)

In [8]:
def test_logpdf(model, data, signal_channel_name):
    """Print the model log-pdf and a blinded decomposition of it.

    Everything is evaluated at the model's suggested initial parameters.
    Printed, in order: ``model.logpdf``; ``blind.model_logpdf_blind`` with
    nothing blinded; the same with ``signal_channel_name`` blinded; the
    expected and observed values in the signal channel slice; the Poisson
    log-probability of that slice; and the sum of the blinded value and the
    Poisson term.

    NOTE(review): the recorded output shows the sum matching the first two
    values, which is presumably the point of the check -- the semantics of
    ``blind.model_logpdf_blind`` are not visible here.
    """
    init_pars = numpy.array(model.config.suggested_init())

    # reference values: pyhf directly, then the blind helper with no blinding
    print(model.logpdf(init_pars, data))
    print(blind.model_logpdf_blind(model, init_pars, data, []))

    blinded_logpdf = blind.model_logpdf_blind(
        model, init_pars, data, {signal_channel_name}
    )
    print(blinded_logpdf)

    # expected vs. observed in the signal channel
    sr_slice = model.config.channel_slices[signal_channel_name]
    expected_sr = model.expected_actualdata(init_pars)[sr_slice]
    observed_sr = data[sr_slice]
    print(expected_sr, observed_sr)

    sr_poisson = pyhf.probability.Poisson(expected_sr)
    sr_logpdf = sr_poisson.log_prob(numpy.array(observed_sr))
    print(sr_logpdf)
    print(blinded_logpdf + sr_logpdf)


# Run the log-pdf check on the module-level model and data.
test_logpdf(MODEL, DATA, SR_NAME)

[-105.97103393]
[-105.97103393]
[-101.97906232]
[144.28615681] [157.]
[-3.99197161]
[-105.97103393]


In [9]:
def test_hess():
    """Print ``jax.hessian`` of ``0.5 * x ** 2`` evaluated at ``x = 0.1``."""
    half_square = lambda x: 0.5 * x ** 2
    second_derivative = jax.hessian(half_square)(0.1)
    print(second_derivative)
    
# Run the Hessian sanity check.
test_hess()

1.0


In [10]:
def inner_product(x, c):
    """Return ``x . (c . x)``, using the ``dot`` methods of ``c`` and ``x``."""
    c_applied = c.dot(x)
    return x.dot(c_applied)

In [11]:
def d2fdx2(fminus, f, fplus, eps):
    """Second finite difference from three values spaced ``eps`` apart.

    Computed as (forward difference - backward difference) / eps ** 2.
    """
    forward = fplus - f
    backward = f - fminus
    return (forward - backward) / eps ** 2

In [12]:
def test_opt(model, data, signal_channel_name):
    """Minimise the blinded negative log-pdf and compare uncertainty estimates.

    The function only prints; it returns nothing. In order it prints:
    the objective at the suggested initial parameters, the scipy optimisation
    result, the signal-channel yield before and after the fit, determinants of
    two covariance estimates, two propagated standard deviations of the
    yield, and several finite-difference cross-checks.

    Parameters
    ----------
    model
        Model providing ``config`` and ``expected_actualdata``.
    data
        Data passed through to ``blind.model_logpdf_blind``.
    signal_channel_name : str
        Channel that is blinded and whose yield is studied. The single-element
        unpacking in ``y`` requires this channel slice to contain exactly one
        entry.
    """
    blind_bins = {signal_channel_name}
#     blind_bins = []
    
    # Objective: negated value of the blind log-pdf helper (the helper's result
    # is unpacked as a one-element sequence).
    def f(x):
        logy, = blind.model_logpdf_blind(model, x, data, blind_bins)
        return -logy
    
    # Jitted (value, gradient) pair, the form scipy expects with jac=True.
    fjit = jax.jit(jax.value_and_grad(f))
          
    parameters = numpy.array(model.config.suggested_init())
    bounds = numpy.array(model.config.suggested_bounds())
    
    print(fjit(parameters)[0])

    # No method is given; the recorded output shows an LbfgsInvHessProduct
    # for hess_inv.
    result = scipy.optimize.minimize(
        fjit,
        parameters,
        bounds=bounds,
        jac=True,
    )
    
    print(result)
    
    # SR yields
    slice_ = model.config.channel_slices[signal_channel_name]
    
    # Expected value in the signal channel as a scalar function of the
    # parameters.
    def y(x):
        expected_data = model.expected_actualdata(x)
        yres, = expected_data[slice_]
        return yres
    
    print(y(parameters))
    print(y(result.x))
    
    # Gradient of the yield with respect to the parameters at the optimum.
    ygrad = jax.grad(y)(result.x)
    
    # Covariance estimate as the inverse of the objective's Hessian.
    # NOTE(review): the recorded jac contains exact zeros, and the dummy signal
    # yields are zero, so some directions may be flat and this inverse
    # ill-conditioned -- possibly the cause of the "broken answers" remarked on
    # further down; confirm.
    def cov_jax(x):
        hess = jax.hessian(f)(x)
        return jax.numpy.linalg.inv(hess)
    
    cov_lbfgs = result.hess_inv.todense()
    cov = numpy.array(cov_jax(result.x))
    
    print("cov_det")
    print(numpy.linalg.det(cov_lbfgs))
    print(numpy.linalg.det(cov))
    
    # Quadratic forms ygrad . C . ygrad for each covariance estimate.
    var_lbfgs = inner_product(ygrad, result.hess_inv)
    var = inner_product(ygrad, cov)
    
    print("std")
    print(var_lbfgs ** 0.5)
    # jax hessian seems to give broken answers - numerical failure?
    print(var ** 0.5)
    
    # try numerical derivatives
    # Step the parameters by +/- eps along ygrad and take the second finite
    # difference of the objective along that direction.
    # NOTE(review): the curvature along ygrad (ygrad . H . ygrad) is in general
    # not the reciprocal of ygrad . H^-1 . ygrad, so these numbers need not
    # agree with the "std" values above -- confirm the intended comparison.
    eps = 1e-6
    fplus = f(result.x + eps * ygrad)
    fminus = f(result.x - eps * ygrad)
    
    print("d2fdx2")
    hess = d2fdx2(fminus, result.fun, fplus, eps)
    print(hess)
    print(1 / hess ** 0.5)
    
    # Same second difference, this time via numpy.diff.
    print((numpy.diff([fminus, result.fun, fplus], n=2) / eps ** 2) ** -0.5)
    
    # One-sided variant: solves delta = 0.5 * eps ** 2 / var for var.
    # Note that ``var`` is rebound here and no longer holds the quadratic form
    # computed above.
    delta = float(fplus - result.fun)
    print(f"{delta=}")
    var = 0.5 * eps ** 2 / delta
    print(var ** 0.5, 1 / var ** 0.5)


# Run the optimisation study on the module-level model and data.
test_opt(MODEL, DATA, SR_NAME)

101.97906231564971
      fun: 80.16637890947715
 hess_inv: <99x99 LbfgsInvHessProduct with dtype=float64>
      jac: array([-2.63176379e-02,  0.00000000e+00,  6.21899187e-05,  2.86279983e-06,
        3.08237655e-07,  1.76529066e-08,  1.26341255e-06, -1.77251172e-07,
        2.68654848e-06, -5.99351291e-07, -1.51296082e-06, -1.08804916e-05,
       -1.20997963e-05, -2.93091429e-06,  3.17972131e-06, -2.54529769e-06,
       -5.70551087e-05,  4.35344098e-05,  2.09393092e-05,  1.35850853e-05,
       -7.39881927e-06,  5.69339832e-06,  0.00000000e+00,  1.43000027e-05,
        2.90790711e-05,  2.66584414e-05,  5.54546785e-06, -1.29090211e-06,
        2.43114820e-04,  0.00000000e+00,  2.65374718e-05, -4.29519718e-06,
       -2.46554387e-06, -1.40414983e-05, -1.55818629e-05, -3.13001696e-05,
        8.09581830e-05,  1.79290707e-05, -6.29505916e-06, -9.44049656e-06,
       -9.99091942e-06,  2.74049950e-05,  2.48982782e-10,  2.76722690e-08,
       -1.23242421e-06, -2.92237853e-05,  9.05852423e-06, 

  r = _umath_linalg.det(a, signature=signature)


522011714.56545126
4.376831983832624e-05
[4.37683198e-05]
delta=0.00026479148496605376
4.345432364018805e-05 23012.669769761775


# TODO optimize with constraints

# TODO try cabinetry