In [1]:
import pyhf
import numpy
import jax

# Select JAX as the pyhf tensor backend; the cells below call jax.jit / jax.grad
# on functions built from pyhf objects.
pyhf.set_backend("jax")
# Ask JAX for the "cpu" platform and turn on 64-bit floats.
jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)

In [2]:
def test_jax(xtest):
    """Print diagnostics for a pyhf Poisson log-probability evaluated under JAX.

    Printed, in order: the eager value at ``xtest``, the jitted value, the
    ``jax.grad`` derivative next to a forward finite-difference estimate, and
    the jaxpr produced by ``jax.make_jaxpr``.
    """
    observed = 3
    step = 1e-6

    def logp(rate):
        return pyhf.probability.Poisson(rate).log_prob(observed)

    logp_jit = jax.jit(logp)

    print(logp(xtest))
    print(logp_jit(xtest))

    # Forward difference as an independent cross-check of the autodiff result.
    numeric_slope = (logp(xtest + step) - logp(xtest)) / step
    autodiff_slope = jax.grad(logp)(xtest)
    print(autodiff_slope, numeric_slope)

    print(jax.make_jaxpr(logp)(xtest))


test_jax(2.1)

-1.665947435039924
-1.665947435039924
0.4285714285714286 0.42857108883964656
{ lambda ; a:f64[]. let
    b:f64[] = log a
    c:f64[] = mul 3.0 b
    d:f64[] = sub c a
    e:f64[] = add 3.0 1.0
    f:f64[] = lgamma e
    g:f64[] = sub d f
  in (g,) }


In [3]:
def test_model(xtest):
    """Print the summed log-pdf of a one-bin pyhf model, eagerly and jitted.

    The model is evaluated at parameters ``[0.0, xtest]`` against one observed
    count (55) followed by the model's auxdata. The last printed line shows the
    ``jax.grad`` derivative next to a forward finite-difference estimate.
    """
    model = pyhf.simplemodels.uncorrelated_background(
        signal=[10.0], bkg=[50.0], bkg_uncertainty=[7.0]
    )

    observed_counts = numpy.array([55])
    observations = numpy.concatenate([observed_counts, model.config.auxdata])

    def total_logpdf(value):
        return model.logpdf([0.0, value], observations).sum()

    total_logpdf_jit = jax.jit(total_logpdf)

    print(total_logpdf(xtest))
    print(total_logpdf_jit(xtest))

    step = 1e-6
    numeric_slope = (total_logpdf(xtest + step) - total_logpdf(xtest)) / step
    autodiff_slope = jax.grad(total_logpdf)(xtest)
    print(autodiff_slope, numeric_slope)

    # Timings recorded by the author (eager first, jitted second):
    # %timeit numpy.array(foo(xtest))
    # %timeit numpy.array(foo_jax(xtest))
    # 6.57 ms ± 79.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    # 4.8 µs ± 21.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

test_model(1.1)

-6.050081503235106
-6.050081503235134
-4.638218923933216 -4.638262822709294


In [4]:
def _model_logpdf_masked(model, pars, data, mask):
    """
    Evaluate the model log-density with selected main-measurement bins masked.

    Args:
        model: Object providing ``config.npars``, ``config.auxdata``,
            ``nominal_rates`` and ``batch_size``, plus whatever
            ``_model_make_pdf_masked`` reads from it.
        pars: Parameter values; the last axis must have length
            ``model.config.npars``.
        data: Main-measurement data followed by the auxiliary data.
        mask: Condition handed to ``tensorlib.where``. Where it is falsy the
            observed main-measurement value is replaced by 0. The same mask is
            forwarded to ``_model_make_pdf_masked``.

    Returns:
        The value of ``pdf.log_prob`` on the masked data. When
        ``model.batch_size`` is falsy the result is reshaped to ``(1,)``.

    Raises:
        ValueError: If ``pars`` or ``data`` does not have the expected length
            along its last axis.
    """
    tensorlib, _ = pyhf.get_backend()
    pars, data = tensorlib.astensor(pars), tensorlib.astensor(data)

    # Verify parameter and data shapes
    if pars.shape[-1] != model.config.npars:
        raise ValueError(
            f"pars has len {pars.shape[-1]} but "
            f"{model.config.npars} was expected"
        )

    len_actualdata = model.nominal_rates.shape[-1]
    len_auxdata = len(model.config.auxdata)
    # Report exactly the length that is compared, so the message can never
    # disagree with the check that triggered it.
    len_expected = len_actualdata + len_auxdata
    if data.shape[-1] != len_expected:
        raise ValueError(
            f"data has len {data.shape[-1]} but "
            f"{len_expected} was expected"
        )

    # Build the pdf with the mask applied to the main-measurement rates.
    pdf = _model_make_pdf_masked(model, pars, mask)

    actualdata = data[:len_actualdata]
    auxdata = data[len_actualdata:]

    # Masked-out bins are evaluated at an observed value of 0; auxdata is
    # passed through unchanged.
    actualdata_masked = tensorlib.where(mask, actualdata, 0)
    data_masked = tensorlib.concatenate([actualdata_masked, auxdata])

    result = pdf.log_prob(data_masked)

    if model.batch_size:
        return result

    return tensorlib.reshape(result, (1,))


def _model_make_pdf_masked(model, pars, mask):
    """
    Build the simultaneous pdf of ``model`` with a masked main-measurement part.

    Args:
        model: Object providing ``main_model``, ``constraint_model``,
            ``fullpdf_tv`` and ``batch_size``.
        pars: Parameter values passed to both sub-models.
        mask: Mask forwarded to ``_main_model_make_pdf_masked``.

    Returns:
        A ``pyhf.probability.Simultaneous`` built from the masked main pdf and
        the constraint pdf. A sub-pdf is left out when it is falsy.
    """
    pdfobjs = []

    mainpdf = _main_model_make_pdf_masked(model.main_model, pars, mask)
    if mainpdf:
        pdfobjs.append(mainpdf)

    # The constraint terms are not affected by the mask.
    constraintpdf = model.constraint_model.make_pdf(pars)
    if constraintpdf:
        pdfobjs.append(constraintpdf)

    return pyhf.probability.Simultaneous(pdfobjs, model.fullpdf_tv, model.batch_size)


def _main_model_make_pdf_masked(main_model, pars, mask):
    """
    Build the main-measurement Poisson pdf with masked-out rates replaced.

    Where ``mask`` is truthy the rate from ``main_model.expected_data(pars)`` is
    used; elsewhere the rate is ``numpy.finfo(numpy.float64).tiny``.
    """
    backend, _ = pyhf.get_backend()

    # pyhf gets poisson(0 | 0.0) wrong, so masked-out bins get a tiny positive
    # mean instead of an exact zero.
    floor_rate = numpy.finfo(numpy.float64).tiny
    rates = backend.where(mask, main_model.expected_data(pars), floor_rate)

    poisson = pyhf.probability.Poisson(rates)
    return pyhf.probability.Independent(poisson)

In [5]:
def test_model_masked(xtest):
    """Print the full and the masked log-pdf of a one-bin model side by side.

    Printed, in order: the full log-pdf (eager, then jitted), the masked
    log-pdf with the single bin masked out (eager, then jitted), and finally
    the full log-pdf next to the masked log-pdf plus the Poisson log-probability
    of the observed count at the first expected-data entry.
    """
    model = pyhf.simplemodels.uncorrelated_background(
        signal=[0.0], bkg=[55.0], bkg_uncertainty=[7.0]
    )

    nobs = 55
    data = numpy.concatenate([numpy.array([nobs]), model.config.auxdata])

    # One main bin, and it is masked out.
    bin_mask = numpy.array([False])

    def full_logpdf(x):
        return model.logpdf([0.0, x], data).sum()

    def masked_logpdf(x):
        return _model_logpdf_masked(model, [0.0, x], data, bin_mask).sum()

    for logpdf in (full_logpdf, masked_logpdf):
        print(logpdf(xtest))
        print(jax.jit(logpdf)(xtest))

    backend, _ = pyhf.get_backend()
    expected = model.make_pdf(backend.astensor([0.0, xtest])).expected_data()
    main_term = pyhf.probability.Poisson(expected[0]).log_prob(nobs)

    print(full_logpdf(xtest), masked_logpdf(xtest) + main_term)


test_model_masked(1.2)

-7.969519336366574
-7.969519336366631
-4.073084699392325
-4.073084699392382
-7.969519336366574 -7.969519336366574


In [6]:
# Probe the active backend's Poisson log-pdf with arguments 0 and 1e-307;
# the recorded output is -1e-307.
tensorlib, _ = pyhf.get_backend()
tensorlib.poisson_logpdf(0, 1e-307)

DeviceArray(-1.e-307, dtype=float64, weak_type=True)

In [7]:
# Scale float64 `tiny` down by 2**-52; the recorded result is 5e-324.
numpy.finfo(numpy.float64).tiny * 2 ** -52

5e-324

In [8]:
def test_model_named():
    """Print a one-bin model, its config, and the config's channel summary attributes."""
    model = pyhf.simplemodels.uncorrelated_background(
        signal=[0.0], bkg=[55.0], bkg_uncertainty=[7.0]
    )
    config = model.config

    print(model)
    print(config)
    # Channel summary attributes, see
    # https://github.com/scikit-hep/pyhf/blob/9135b41605296727ce949d329886ad5b44345f44/src/pyhf/mixins.py#L6
    # NOTE(review): `channels` is printed twice; the second print was possibly
    # meant to show a different attribute — confirm.
    print(config.channels)
    print(config.channels)
    print(config.channel_nbins)
    print(config.channel_slices)

test_model_named()

<pyhf.pdf.Model object at 0x7f712fd39960>
<pyhf.pdf._ModelConfig object at 0x7f70cc433460>
['singlechannel']
['singlechannel']
{'singlechannel': 1}
{'singlechannel': slice(0, 1, None)}


In [9]:
def _make_mask(model, blind_bins):
    """
    Return a boolean mask that is False for every bin listed in blind_bins.

    Args:
        model: Object whose ``config.channel_slices`` maps channel names to
            ``slice`` objects into the main-measurement data.
        blind_bins: Iterable of either a pair ``(channel_name, bin_index)`` or
            a str ``channel_name``. The str form requires that the channel has
            exactly one bin.

    Returns:
        numpy.ndarray: bool array with one entry per main-measurement bin;
        True everywhere except at the blinded bins.

    Raises:
        ValueError: If a channel is given in str form but has more than one bin.
        KeyError: If a channel name is not in ``config.channel_slices``.
        IndexError: If a bin index is out of range for its channel.
    """
    channel_to_slice = model.config.channel_slices

    # The mask length is the largest slice end. Taking the max, rather than the
    # last-inserted slice, does not depend on dict order, and the default covers
    # a model with no channels.
    ntot = max((slice_.stop for slice_ in channel_to_slice.values()), default=0)
    mask = numpy.ones(ntot, dtype=bool)

    for channelbin in blind_bins:
        str_form = isinstance(channelbin, str)
        if str_form:
            channelbin = (channelbin, 0)

        channel, bin_ = channelbin

        slice_ = channel_to_slice[channel]
        assert slice_.step is None
        slice_range = range(slice_.start, slice_.stop)

        nbins = len(slice_range)
        if str_form and nbins != 1:
            raise ValueError(f"bin index is needed for channel {channel} with {nbins=}")

        # Indexing the range maps the channel-local bin index (negative
        # indices included) to its position in the flat mask.
        mask[slice_range[bin_]] = False

    return mask

In [10]:
def test_model_named_mask():
    """Print the mask that blinds bin 0 of channel "singlechannel" in a one-bin model."""
    model = pyhf.simplemodels.uncorrelated_background(
        signal=[0.0], bkg=[55.0], bkg_uncertainty=[7.0]
    )

    blind_bins = {("singlechannel", 0)}
    mask = _make_mask(model, blind_bins)
    print(mask)

test_model_named_mask()

[False]


In [11]:
def model_logpdf_blind(model, pars, data, blind_bins):
    """
    Return a "logpdf" value in which the listed channel-bins are blinded.

    Modified from pyhf.pdf.Model.logpdf

    Args:
        model: pyhf.pdf.Model-like
        pars (:obj:`tensor`): The parameter values
        data (:obj:`tensor`): The measurement data
        blind_bins: Collection whose items are either
            a pair (channel_name, bin_index)
            or a str channel_name.
            The str form requires that the channel has one bin only.

    Returns:
        Tensor: The "log density" value

    """
    return _model_logpdf_masked(model, pars, data, _make_mask(model, blind_bins))

In [12]:
def test_model_blinded(xtest):
    """Print the full and the blinded log-pdf of a one-bin model side by side.

    Printed, in order: the full log-pdf (eager, then jitted), the log-pdf with
    channel "singlechannel" blinded (eager, then jitted), and finally the full
    log-pdf next to the blinded log-pdf plus the Poisson log-probability of the
    observed count at the first expected-data entry.
    """
    model = pyhf.simplemodels.uncorrelated_background(
        signal=[0.0], bkg=[55.0], bkg_uncertainty=[7.0]
    )

    nobs = 55
    data = numpy.concatenate([numpy.array([nobs]), model.config.auxdata])

    # str form: the channel has a single bin, so no bin index is needed.
    blinded = {"singlechannel"}

    def full_logpdf(x):
        return model.logpdf([0.0, x], data).sum()

    def blind_logpdf(x):
        return model_logpdf_blind(model, [0.0, x], data, blinded).sum()

    for logpdf in (full_logpdf, blind_logpdf):
        print(logpdf(xtest))
        print(jax.jit(logpdf)(xtest))

    backend, _ = pyhf.get_backend()
    expected = model.make_pdf(backend.astensor([0.0, xtest])).expected_data()
    main_term = pyhf.probability.Poisson(expected[0]).log_prob(nobs)

    print(full_logpdf(xtest), blind_logpdf(xtest) + main_term)


test_model_blinded(1.2)

-7.969519336366574
-7.969519336366631
-4.073084699392325
-4.073084699392382
-7.969519336366574 -7.969519336366574
