In [1]:
# default_exp inferno

In [2]:
%matplotlib inline
# Re-import edited library modules automatically before each cell executes
%reload_ext autoreload
%autoreload 2

# INFERNO loss

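> Train a network with the INFERNO loss on the paper's benchmark data, then evaluate the resulting likelihood scans for benchmarks 0–4.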

In [3]:
#hide
from nbdev.showdoc import *

In [4]:
from pytorch_inferno.model_wrapper import ModelWrapper
from pytorch_inferno.callback import *
from pytorch_inferno.data import get_paper_data
from pytorch_inferno.plotting import *
from pytorch_inferno.inference import *
from pytorch_inferno.utils import *
from pytorch_inferno.inferno import *

from fastcore.all import partialler
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import itertools
from typing import *
from collections import OrderedDict
from fastcore.all import store_attr
from abc import abstractmethod

import torch.nn.functional as F
from torch import optim, autograd, nn, Tensor
import torch
from torch.distributions import Normal

In [5]:
# Fix the RNG state so that the generated data and the network initialisation
# in the next cell are reproducible under Restart & Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

bs = 2000  # mini-batch size
# Presumably 200k training points; the test set holds 1M points.
data, test = get_paper_data(200000, bs=bs, n_test=1000000)

In [6]:
# 3 inputs -> two ReLU hidden layers of width 100 -> 10 outputs fed through VariableSoftmax(0.1)
hidden = 100
layers = [nn.Linear(3, hidden), nn.ReLU(),
          nn.Linear(hidden, hidden), nn.ReLU(),
          nn.Linear(hidden, 10), VariableSoftmax(0.1)]
net = nn.Sequential(*layers)
init_net(net)  # (re)initialise the weights with the project's helper
model = ModelWrapper(net)

In [None]:
%%time
# Train (presumably up to 200 epochs) with SGD at lr=1e-6. loss=None because the INFERNO loss is
# presumably supplied by the PaperInferno callback, configured like BM 4: r, l and the background
# yield float, with Normal(0,2) constraints on the two nuisances and Normal(1000,100) on the
# background yield — confirm.
# SaveBest/EarlyStopping(10): presumably keep the best-validation weights and stop after 10
# epochs without improvement.
model.fit(200, data=data, opt=partialler(optim.SGD,lr=1e-6), loss=None,
          cbs=[PaperInferno(float_r=True, float_l=True, alpha_aux=[Normal(0,2), Normal(0,2)], float_b=True, b_aux=Normal(1000,100)),
               LossTracker(),SaveBest('weights/best_ie4.h5'),EarlyStopping(10)])

1: Train=1655.936137084961 Valid=971.9936413574219
2: Train=759.7977899169922 Valid=685.628374633789
3: Train=676.5568557739258 Valid=770.0616033935547
4: Train=616.5499044799805 Valid=612.3249523925781
5: Train=577.2624838256836 Valid=596.5664288330078
6: Train=560.8420623779297 Valid=565.7110998535156
7: Train=548.3731503295899 Valid=559.9101647949219
8: Train=514.2079681396484 Valid=558.5456134033203
9: Train=518.2867294311524 Valid=529.9420764160157
10: Train=511.2441650390625 Valid=524.2600451660156
11: Train=502.26902099609373 Valid=506.7028521728516
12: Train=483.35285614013674 Valid=489.73298706054686
13: Train=496.1157257080078 Valid=486.9799993896484
14: Train=464.79114227294923 Valid=471.14662841796877
15: Train=455.9293603515625 Valid=480.18162536621094
16: Train=447.44385986328126 Valid=457.1601885986328
17: Train=439.6057748413086 Valid=459.266572265625
18: Train=439.7900717163086 Valid=439.43915283203125
19: Train=434.8647869873047 Valid=435.0459729003906
20: Train=423.9

In [8]:
# Save the weights currently held by the model.
# NOTE(review): SaveBest in the training cell writes to a different file (weights/best_ie4.h5);
# confirm whether the best-validation or the final weights are intended here.
model.save('weights/Inferno_Test_exact_bm4.h5')

In [9]:
# Reload the saved weights so the results below can be regenerated without re-training
model.load('weights/Inferno_Test_exact_bm4.h5')

# Results

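## BM 0

No nuisance parameters: scan the negative log-likelihood in the signal strength using the nominal signal and background shapes.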

In [None]:
# Predict on the full test set; InfernoPred presumably maps the network outputs to a per-event prediction — confirm
preds = model._predict_dl(test, pred_cb=InfernoPred())

In [None]:
# One row per test event: the model prediction alongside the generator-level target
df = pd.DataFrame({'pred':preds}).assign(gen_target=test.dataset.y)
df.head()

In [None]:
# Histogram the predictions in 10 unit-width bins over [0, 10], matching the network's 10 outputs
plot_preds(df, bin_edges=np.linspace(0,10,11))

In [None]:
# Bin the predictions. The result is not assigned, so this works by mutating df in place
# (presumably adding a binned-prediction column — confirm)
bin_preds(df)

In [None]:
# Inspect df after binning
df.head()

In [None]:
# Per-bin shapes for signal (gen_target == 1) and background (gen_target == 0)
f_s,f_b = get_shape(df,1),get_shape(df,0)

In [None]:
# Sanity check: presumably each shape is normalised and sums to ~1
f_s.sum(), f_b.sum()

In [None]:
# Show the signal and background shapes
f_s, f_b

In [None]:
# Asimov data: expected per-bin counts for a signal yield of 50 and a background yield of 1000
asimov = f_s*50 + f_b*1000

In [None]:
# Asimov counts per bin and their total (≈1050 if the shapes are normalised)
asimov, asimov.sum()

In [None]:
# BM 0: scan the Poisson NLL of the Asimov data over the signal strength mu,
# with the background yield fixed at its nominal value of 1000 (no nuisances).
n = 1050  # total count matching the 50 + 1000 Asimov yields; not used by the scan itself
x = np.linspace(20,80,61)  # mu values scanned, in unit steps
y = np.zeros_like(x)
for i,m in enumerate(x):
    # validate_args=False: the Asimov counts are non-integer, and Poisson's
    # nonnegative-integer support check (on by default in recent PyTorch) would reject them.
    pois = torch.distributions.Poisson((m*f_s)+(1000*f_b), validate_args=False)
    y[i] = -pois.log_prob(asimov).sum().item()
y

In [None]:
# Reference NLL scan: 61 values, one per mu point of the BM 0 scan (20..80).
# Presumably taken from a TensorFlow 2 implementation, kept for comparison with y.
y_tf2 = np.array([31.626238,31.466385,31.313095,31.166267,31.025808,30.891619,30.76361
,30.641693,30.525778,30.415783,30.31162,30.213215,30.120483,30.033348
,29.951736,29.875574,29.804789,29.739307,29.679066,29.623993,29.574026
,29.5291,29.489151,29.454117,29.423939,29.398558,29.377914,29.361954
,29.35062,29.343859,29.341618,29.343842,29.350483,29.36149,29.376812
,29.396404,29.420216,29.448202,29.480318,29.516518,29.556757,29.600994
,29.649185,29.70129,29.757267,29.817076,29.88068,29.948036,30.019108
,30.093859,30.17225,30.25425,30.339819,30.42892,30.521524,30.617598
,30.7171,30.820007,30.926281,31.035892,31.148808], dtype='float32')

In [None]:
# Shift the reference scan so its minimum is zero
y_tf2-y_tf2.min()

In [None]:
# Delta-NLL scan from this notebook's model
plot_likelihood(y-y.min())

In [None]:
# Delta-NLL scan of the reference values, for comparison
plot_likelihood(y_tf2-y_tf2.min())

# Nuisances - via interpolation

In [None]:
# Background events (gen_target == 0) of the test set; expects 500k of the 1M test events
is_bkg = test.dataset.y.squeeze() == 0
bkg = test.dataset.x[is_bkg]
assert len(bkg) == 500000

In [None]:
# Background shapes under nuisance shifts; keys used below are f_b_nom / f_b_up / f_b_dw.
# Presumably also adds shifted-prediction columns to df (e.g. 'pred_-0.2_3') — confirm.
b_shapes = get_paper_syst_shapes(bkg, df, model=model, pred_cb=InfernoPred())

In [None]:
# Inspect the first rows, including any shifted-prediction columns;
# avoids rendering the whole (~1M-row) test-set frame.
df.head()

In [None]:
# Compare nominal predictions with those under shifted nuisances (shift values inferred from the column names)
plot_preds(df, pred_names=['pred', 'pred_-0.2_3', 'pred_0.2_3', 'pred_0_2.5', 'pred_0_3.5'], bin_edges=np.linspace(0,10,11))

In [None]:
# Interpolated background shape for every combination of nuisance values (r, l) in {-1, 0, 1}^2.
# alpha is a (1, 2) tensor holding (r, l).
fig, ax = plt.subplots(figsize=(12,8))
for r, l in itertools.product([-1,0,1], repeat=2):
    alpha = Tensor((r,l))[None,:]
    s = interp_shape(alpha, **b_shapes).squeeze()
    print(s)
    ax.plot(s, label=f'{r} {l}')
ax.set_xlabel('Bin')
ax.set_ylabel('Interpolated background shape')
ax.legend(title='r l');

# Newton

In [None]:
# Profile-likelihood scanner shared by BMs 1-4: mu scanned over 20..80 in unit steps
# (same grid as BM 0), with n=1050 and true_mu=50 fixed.
profiler = partialler(
    calc_profile,
    n=1050,
    mu_scan=torch.linspace(20, 80, 61),
    true_mu=50,
)

## BM 1
Nuisance $r$ free (profiled), $l$ fixed at its nominal value.

In [None]:
# BM 1: keep only row 0 (the r nuisance) of the up/down templates, with a leading axis of size 1
bm1_b_shapes = OrderedDict(f_b_nom=b_shapes['f_b_nom'])
for shift in ('f_b_up', 'f_b_dw'):
    bm1_b_shapes[shift] = b_shapes[shift][0][None,:]

In [None]:
# Check the single-nuisance template keeps a leading axis of size 1
bm1_b_shapes['f_b_up'].shape

In [None]:
# BM 1: profile the NLL over mu with only r floating
nll = profiler(f_s=f_s, n_steps=100, **bm1_b_shapes)

In [None]:
# Convert (presumably) to a NumPy array for plotting
nll = to_np(nll)

In [None]:
# Delta-NLL relative to its minimum
plot_likelihood(nll-nll.min())

## BM 1l
Nuisance $r$ fixed at its nominal value, $l$ free (profiled).

In [None]:
# BM 1l: keep only row 1 (the l nuisance) of the up/down templates, with a leading axis of size 1
bm1l_b_shapes = OrderedDict(f_b_nom=b_shapes['f_b_nom'])
for shift in ('f_b_up', 'f_b_dw'):
    bm1l_b_shapes[shift] = b_shapes[shift][1][None,:]

In [None]:
# BM 1l: profile the NLL over mu with only l floating
nll = profiler(f_s=f_s, n_steps=100, **bm1l_b_shapes)

In [None]:
# Convert (presumably) to a NumPy array for plotting
nll = to_np(nll)

In [None]:
# Delta-NLL relative to its minimum
plot_likelihood(nll-nll.min())

## BM 2

Both nuisances $r$ and $l$ free, with no auxiliary constraints.

In [None]:
# BM 2: profile the NLL over mu with both nuisances floating, unconstrained
nll = profiler(f_s=f_s, n_steps=100, **b_shapes)

In [None]:
# Convert (presumably) to a NumPy array for plotting
nll = to_np(nll)

In [None]:
# Delta-NLL relative to its minimum
plot_likelihood(nll-nll.min())

## BM 3

Both nuisances free, each constrained by a Gaussian auxiliary measurement $\mathcal{N}(0, 2)$.

In [None]:
# Gaussian constraints (loc 0, scale 2), one per nuisance
alpha_aux = [Normal(0,2), Normal(0,2)]

In [None]:
# BM 3: both nuisances floating, constrained by alpha_aux
nll = profiler(f_s=f_s, n_steps=100, alpha_aux=alpha_aux, **b_shapes)

In [None]:
# Convert (presumably) to a NumPy array for plotting
nll = to_np(nll)

In [None]:
# Delta-NLL relative to its minimum
plot_likelihood(nll-nll.min())

## BM 4

As BM 3, plus a free background normalisation constrained by $\mathcal{N}(1000, 100)$.

In [None]:
# Same Gaussian nuisance constraints as BM 3, redefined so this section runs on its own
alpha_aux = [Normal(0,2), Normal(0,2)]

In [None]:
# BM 4: additionally float the background yield, constrained by Normal(1000, 100)
nll = profiler(f_s=f_s, n_steps=100, alpha_aux=alpha_aux, float_b=True, b_aux=Normal(1000,100), **b_shapes)

In [None]:
# Convert (presumably) to a NumPy array for plotting
nll = to_np(nll)

In [None]:
# Delta-NLL relative to its minimum
plot_likelihood(nll-nll.min())