In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.distributions as D
from matplotlib.backends.backend_pdf import PdfPages
import os, sys
os.chdir("..")
from Source.Models.autoregGMM import AutoRegGMM
from Source.Models.autoregBinned import AutoRegBinned
from Source.Util.simulateToyData import ToySimulator
from Source.Util.util import load_params, get, get_device

  from .autonotebook import tqdm as notebook_tqdm


# Ramp

In [2]:
runpath = "runs/paper_ramp2/"
params = load_params(runpath + "paramfile.yaml")
params["device"] = get_device()

In [3]:
model = AutoRegGMM(params)
state_dict = torch.load(runpath+"models/model_run0.pt", map_location=params["device"])
model.load_state_dict(state_dict)

Model AutoRegGMM hyperparameters: n_head=4, n_per_head=15, n_blocks=4, intermediate_fac=4, n_gauss=20
Bayesianization hyperparameters: bayesian=3, prior_prec=1.0, iterations=50


<All keys matched successfully>

## Data

In [4]:
data = ToySimulator(params).data
data_split = params["data_split"]
n_data = len(data)
cut1 = int(n_data - data_split[0])
cut2 = int(n_data * (data_split[0] + data_split[1]))
data_train = data[:cut1]
data_test = data[cut2:]

In [5]:
n_samples = 100000
nBNN = 50
data_predict = np.zeros((0, data_train.shape[1]))
for i in range(nBNN):
    data_predict = np.append(data_predict, model.sample_n(n_samples), axis=0)



Sampling time estimate: 4.28 s = 0.07 min
Sampling time estimate: 2.41 s = 0.04 min
Sampling time estimate: 2.46 s = 0.04 min
Sampling time estimate: 2.88 s = 0.05 min
Sampling time estimate: 2.53 s = 0.04 min
Sampling time estimate: 2.74 s = 0.05 min
Sampling time estimate: 2.98 s = 0.05 min
Sampling time estimate: 3.42 s = 0.06 min
Sampling time estimate: 3.06 s = 0.05 min
Sampling time estimate: 3.55 s = 0.06 min
Sampling time estimate: 3.07 s = 0.05 min
Sampling time estimate: 3.56 s = 0.06 min
Sampling time estimate: 3.13 s = 0.05 min
Sampling time estimate: 4.49 s = 0.07 min
Sampling time estimate: 3.25 s = 0.05 min
Sampling time estimate: 2.99 s = 0.05 min
Sampling time estimate: 3.64 s = 0.06 min
Sampling time estimate: 3.15 s = 0.05 min
Sampling time estimate: 3.56 s = 0.06 min
Sampling time estimate: 3.03 s = 0.05 min
Sampling time estimate: 3.06 s = 0.05 min
Sampling time estimate: 3.11 s = 0.05 min
Sampling time estimate: 3.08 s = 0.05 min
Sampling time estimate: 3.48 s = 0

## Plot

In [6]:
def plot_paper(out, obs_train, obs_test, obs_predict, name, bins=60, weight_samples=1, 
               predict_weights=None, unit=None, range=None, ymaxAbs=1., ymaxRel=1.):
    with PdfPages(out) as pp:
        y_t,  bins = np.histogram(obs_test, bins=bins, range=range)
        y_tr, _ = np.histogram(obs_train, bins=bins)

        if weight_samples == 1:
            y_g,  _ = np.histogram(obs_predict, bins=bins, weights=predict_weights)
            hists = [y_t, y_g, y_tr]
            hist_errors = [np.sqrt(y_t), np.sqrt(y_g), np.sqrt(y_tr)]
        else:
            obs_predict = obs_predict.reshape(weight_samples,
                    len(obs_predict)//weight_samples)
            hist_weights = (weight_samples*[None] if predict_weights is None
                            else predict_weights.reshape(obs_predict.shape))
            hists_g = np.array([np.histogram(obs_predict[i,:], bins=bins,
                                             weights=hist_weights[i])[0]
                                for i in np.arange(weight_samples)])
            hists = [y_t, np.mean(hists_g, axis=0), y_tr]
            hist_errors = [np.sqrt(y_t), np.std(hists_g, axis=0), np.sqrt(y_tr)]
        integrals = [np.sum((bins[1:] - bins[:-1]) * y) for y in hists]
        scales = [1 / integral if integral != 0. else 1. for integral in integrals]
            
        FONTSIZE = 14
        labels = ["True", "Model", "Train"]
        colors = ["#e41a1c", "#3b528b", "#1a8507"]
        dup_last = lambda a: np.append(a, a[-1])

        fig1, axs = plt.subplots(3, 1, sharex=True,
                gridspec_kw={"height_ratios" : [4, 1, 1], "hspace" : 0.00})
        fig1.tight_layout(pad=0.0, w_pad=0.0, h_pad=0.0, rect=(0.07,0.06,0.99,0.95))
        
        for y, y_err, scale, label, color in zip(hists, hist_errors, scales,
                                            labels, colors):

            axs[0].step(bins, dup_last(y) * scale, label=label, color=color,
                    linewidth=1.0, where="post")
            axs[0].step(bins, dup_last(y + y_err) * scale, color=color,
                    alpha=0.5, linewidth=0.5, where="post")
            axs[0].step(bins, dup_last(y - y_err) * scale, color=color,
                    alpha=0.5, linewidth=0.5, where="post")
            axs[0].fill_between(bins, dup_last(y - y_err) * scale,
                    dup_last(y + y_err) * scale, facecolor=color,
                    alpha=0.3, step="post")

            if label == "True": continue

            ratio = (y * scale)/ (hists[0] * scales[0])
            ratio_err = np.sqrt((y_err / y)**2 + (hist_errors[0] / hists[0])**2)
            ratio_isnan = np.isnan(ratio)
            ratio[ratio_isnan] = 1.
            ratio_err[ratio_isnan] = 0.

            axs[1].step(bins, dup_last(ratio), linewidth=1.0, where="post", color=color)
            axs[1].step(bins, dup_last(ratio + ratio_err), color=color, alpha=0.5,
                    linewidth=0.5, where="post")
            axs[1].step(bins, dup_last(ratio - ratio_err), color=color, alpha=0.5,
                    linewidth=0.5, where="post")
            axs[1].fill_between(bins, dup_last(ratio - ratio_err),
                    dup_last(ratio + ratio_err), facecolor=color, alpha=0.3, step="post")

            delta = np.fabs(ratio - 1) * 100
            delta_err = ratio_err * 100

            markers, caps, bars = axs[2].errorbar((bins[:-1] + bins[1:])/2, delta,
                    yerr=delta_err, ecolor=color, color=color, elinewidth=0.5,
                    linewidth=0, fmt=".", capsize=2)
            [cap.set_alpha(0.5) for cap in caps]
            [bar.set_alpha(0.5) for bar in bars]


        axs[0].legend(loc="upper left", frameon=False, fontsize=FONTSIZE)
        axs[0].set_ylabel("Normalized", fontsize = FONTSIZE)

        axs[1].set_ylabel(r"$\frac{\mathrm{Model}}{\mathrm{True}}$",
                fontsize = FONTSIZE)
        axs[1].set_yticks([0.95,1,1.05])
        axs[1].set_ylim([0.9,1.1])
        axs[1].axhline(y=1, c="black", ls="--", lw=0.7)
        axs[1].axhline(y=1.2, c="black", ls="dotted", lw=0.5)
        axs[1].axhline(y=0.8, c="black", ls="dotted", lw=0.5)
        plt.xlabel(r"${%s}$ %s" % (name, ("" if unit is None else f"[{unit}]")),
                fontsize = FONTSIZE)

        axs[2].set_ylim((0.05,20))
        axs[2].set_yscale("log")
        axs[2].set_yticks([0.1, 1.0, 10.0])
        axs[2].set_yticklabels([r"$0.1$", r"$1.0$", "$10.0$"])
        axs[2].set_yticks([0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9,
                            2., 3., 4., 5., 6., 7., 8., 9.], minor=True)

        axs[2].axhline(y=1.0,linewidth=0.5, linestyle="--", color="grey")
        axs[2].axhspan(0, 1.0, facecolor="#cccccc", alpha=0.3)
        axs[2].set_ylabel(r"$\delta [\%]$", fontsize = FONTSIZE)

        plt.savefig(pp, format="pdf")
        plt.close()
        
        fig2, axs = plt.subplots(1, 1)
        fig2.tight_layout(pad=0.0, w_pad=0.0, h_pad=0.0, rect=(0.07,0.06,0.99,0.95))
        
        axs.set_ylabel("Absolute uncertainty", fontsize = FONTSIZE)
        axs.set_xlabel(r"${%s}$ %s" % (name, ("" if unit is None else f"[{unit}]")),
                fontsize = FONTSIZE)
        
        axs.step(bins, dup_last(hist_errors[1] * scales[1]), color=colors[1])
        axs.set_ylim(0., ymaxAbs)
            
        plt.savefig(pp, format="pdf")
        plt.close()
        
        fig3, axs = plt.subplots(1, 1)
        fig3.tight_layout(pad=0.0, w_pad=0.0, h_pad=0.0, rect=(0.07,0.06,0.99,0.95))
        
        axs.set_ylabel("Relative uncertainty", fontsize = FONTSIZE)
        axs.set_xlabel(r"${%s}$ %s" % (name, ("" if unit is None else f"[{unit}]")),
                fontsize = FONTSIZE)
        
        axs.step(bins, dup_last(hist_errors[1] / hists[1]), color=colors[1])
        axs.set_ylim(0., ymaxRel)
            
        plt.savefig(pp, format="pdf")
        plt.close()

In [7]:
plot_paper("Scripts/paper/GMMramp.pdf", data_train[:,1], data_test[:,1], data_predict[:,1], 
           "x_1", weight_samples=nBNN, ymaxAbs=.1, ymaxRel=.1, range=[.1, .9])

## Plot with likelihoods

In [8]:
xmin, xmax = .1, .9
prec = 500

In [9]:
def getMargLikelihoodnDim(model, xmin, xmax, dim, prec=1000):
    xs = torch.linspace(xmin, xmax, prec)
    idx = model.n_jets * torch.ones(prec, 2, dtype=torch.int).float()
    idx[:,1] = xs
    
    mu, sigma, weights = model.net(idx) #indices: batch, component, gauss
    mix = D.Categorical(weights)
    comp = D.Normal(mu, sigma)
    gmm = D.MixtureSameFamily(mix, comp)
    
    probs = torch.zeros(prec, prec, dtype=torch.float) #likelihood array
    for ix in range(prec):
        base = torch.zeros(prec, 2)
        base[:,0] = xs[ix]
        base[:,1] = xs
        probs[ix,:] = torch.exp(torch.sum(gmm.log_prob(base), axis=-1))
    probs = probs.detach().cpu().numpy()
    xs = xs.detach().cpu().numpy()
    
    marg = np.trapz(probs, x=xs, axis=dim)
    norm = np.trapz(marg, x=xs)
    #print(f"Normalization: {norm:.5f}")
    marg /= norm
    
    return xs, marg

In [10]:
xs = np.zeros(prec)
margL = np.zeros((prec, nBNN))
#model.train_loader = DataLoader(dataset=data[:cut1], batch_size=params["batch_size"], shuffle=True)
for j in range(nBNN):
    if model.net.bayesian >= 1:
        model.net.map = get(model.params, "fix_mu", False)
        for i in range(model.net.n_blocks):
            model.net.transformer.h[i].mlp.c_fc.random = None
            model.net.transformer.h[i].mlp.c_proj.random = None
    if model.net.bayesian >= 2:
        for i in range(model.net.n_blocks):
            model.net.transformer.h[i].attn.c_attn.random = None
            model.net.transformer.h[i].attn.c_proj.random = None
    if model.net.bayesian >= 3:
        model.net.transformer.wte.random = None
        model.net.lm_head.random = None
    xs, margL[:,j] = getMargLikelihoodnDim(model, xmin, xmax, 0, prec=prec)

In [11]:
meanMargL = np.mean(margL, axis=1)
stdMargL = np.std(margL, axis=1)

In [12]:
def plot_paper_likelihoods(out, obs_train, obs_test, obs_predict, name, bins=60, weight_samples=1, 
               predict_weights=None, unit=None, range=None, ymaxAbs=1., ymaxRel=1.):
    with PdfPages(out) as pp:
        y_t,  bins = np.histogram(obs_test, bins=bins, range=range)
        y_tr, _ = np.histogram(obs_train, bins=bins)

        if weight_samples == 1:
            y_g,  _ = np.histogram(obs_predict, bins=bins, weights=predict_weights)
            hists = [y_t, y_g, y_tr]
            hist_errors = [np.sqrt(y_t), np.sqrt(y_g), np.sqrt(y_tr)]
        else:
            obs_predict = obs_predict.reshape(weight_samples,
                    len(obs_predict)//weight_samples)
            hist_weights = (weight_samples*[None] if predict_weights is None
                            else predict_weights.reshape(obs_predict.shape))
            hists_g = np.array([np.histogram(obs_predict[i,:], bins=bins,
                                             weights=hist_weights[i])[0]
                                for i in np.arange(weight_samples)])
            hists = [y_t, np.mean(hists_g, axis=0), y_tr]
            hist_errors = [np.sqrt(y_t), np.std(hists_g, axis=0), np.sqrt(y_tr)]
        integrals = [np.sum((bins[1:] - bins[:-1]) * y) for y in hists]
        scales = [1 / integral if integral != 0. else 1. for integral in integrals]
            
        FONTSIZE = 14
        labels = ["True", "Model", "Train"]
        colors = ["#e41a1c", "#3b528b", "#1a8507"]
        dup_last = lambda a: np.append(a, a[-1])

        fig1, axs = plt.subplots(3, 1, sharex=True,
                gridspec_kw={"height_ratios" : [4, 1, 1], "hspace" : 0.00})
        fig1.tight_layout(pad=0.0, w_pad=0.0, h_pad=0.0, rect=(0.07,0.06,0.99,0.95))
        
        for y, y_err, scale, label, color in zip(hists, hist_errors, scales,
                                            labels, colors):

            axs[0].step(bins, dup_last(y) * scale, label=label, color=color,
                    linewidth=1.0, where="post")
            axs[0].step(bins, dup_last(y + y_err) * scale, color=color,
                    alpha=0.5, linewidth=0.5, where="post")
            axs[0].step(bins, dup_last(y - y_err) * scale, color=color,
                    alpha=0.5, linewidth=0.5, where="post")
            axs[0].fill_between(bins, dup_last(y - y_err) * scale,
                    dup_last(y + y_err) * scale, facecolor=color,
                    alpha=0.3, step="post")

            if label == "True": continue

            ratio = (y * scale)/ (hists[0] * scales[0])
            ratio_err = np.sqrt((y_err / y)**2 + (hist_errors[0] / hists[0])**2)
            ratio_isnan = np.isnan(ratio)
            ratio[ratio_isnan] = 1.
            ratio_err[ratio_isnan] = 0.

            axs[1].step(bins, dup_last(ratio), linewidth=1.0, where="post", color=color)
            axs[1].step(bins, dup_last(ratio + ratio_err), color=color, alpha=0.5,
                    linewidth=0.5, where="post")
            axs[1].step(bins, dup_last(ratio - ratio_err), color=color, alpha=0.5,
                    linewidth=0.5, where="post")
            axs[1].fill_between(bins, dup_last(ratio - ratio_err),
                    dup_last(ratio + ratio_err), facecolor=color, alpha=0.3, step="post")

            delta = np.fabs(ratio - 1) * 100
            delta_err = ratio_err * 100

            markers, caps, bars = axs[2].errorbar((bins[:-1] + bins[1:])/2, delta,
                    yerr=delta_err, ecolor=color, color=color, elinewidth=0.5,
                    linewidth=0, fmt=".", capsize=2)
            [cap.set_alpha(0.5) for cap in caps]
            [bar.set_alpha(0.5) for bar in bars]


        axs[0].fill_between(xs, meanMargL + stdMargL, meanMargL - stdMargL,
                           facecolor=colors[1], alpha=.3)
        axs[0].legend(loc="upper left", frameon=False, fontsize=FONTSIZE)
        axs[0].set_ylabel("Normalized", fontsize = FONTSIZE)

        axs[1].set_ylabel(r"$\frac{\mathrm{Model}}{\mathrm{True}}$",
                fontsize = FONTSIZE)
        axs[1].set_yticks([0.95,1,1.05])
        axs[1].set_ylim([0.9,1.1])
        axs[1].axhline(y=1, c="black", ls="--", lw=0.7)
        axs[1].axhline(y=1.2, c="black", ls="dotted", lw=0.5)
        axs[1].axhline(y=0.8, c="black", ls="dotted", lw=0.5)
        plt.xlabel(r"${%s}$ %s" % (name, ("" if unit is None else f"[{unit}]")),
                fontsize = FONTSIZE)

        axs[2].set_ylim((0.05,20))
        axs[2].set_yscale("log")
        axs[2].set_yticks([0.1, 1.0, 10.0])
        axs[2].set_yticklabels([r"$0.1$", r"$1.0$", "$10.0$"])
        axs[2].set_yticks([0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9,
                            2., 3., 4., 5., 6., 7., 8., 9.], minor=True)

        axs[2].axhline(y=1.0,linewidth=0.5, linestyle="--", color="grey")
        axs[2].axhspan(0, 1.0, facecolor="#cccccc", alpha=0.3)
        axs[2].set_ylabel(r"$\delta [\%]$", fontsize = FONTSIZE)

        plt.savefig(pp, format="pdf")
        plt.close()
        
        fig2, axs = plt.subplots(1, 1)
        fig2.tight_layout(pad=0.0, w_pad=0.0, h_pad=0.0, rect=(0.07,0.06,0.99,0.95))
        
        axs.set_ylabel("Absolute uncertainty", fontsize = FONTSIZE)
        axs.set_xlabel(r"${%s}$ %s" % (name, ("" if unit is None else f"[{unit}]")),
                fontsize = FONTSIZE)
        
        axs.step(bins, dup_last(hist_errors[1] * scales[1]), color=colors[1])
        axs.plot(xs, stdMargL, color=colors[1])
        axs.set_ylim(0., ymaxAbs)
            
        plt.savefig(pp, format="pdf")
        plt.close()
        
        fig3, axs = plt.subplots(1, 1)
        fig3.tight_layout(pad=0.0, w_pad=0.0, h_pad=0.0, rect=(0.07,0.06,0.99,0.95))
        
        axs.set_ylabel("Relative uncertainty", fontsize = FONTSIZE)
        axs.set_xlabel(r"${%s}$ %s" % (name, ("" if unit is None else f"[{unit}]")),
                fontsize = FONTSIZE)
        
        axs.step(bins, dup_last(hist_errors[1] / hists[1]), color=colors[1])
        axs.plot(xs, stdMargL / meanMargL, color=colors[1])
        axs.set_ylim(0., ymaxRel)
            
        plt.savefig(pp, format="pdf")
        plt.close()

In [13]:
plot_paper_likelihoods("Scripts/paper/GMMramp_likelihoods.pdf", data_train[:,1], 
            data_test[:,1], data_predict[:,1], "x_1", weight_samples=nBNN, 
            ymaxAbs=.1, ymaxRel=.1, range=[.1, .9])

# Sphere

In [51]:
runpath = "runs/paper_sphere2/"
params = load_params(runpath + "paramfile.yaml")
params["device"] = get_device()

In [52]:
model = AutoRegGMM(params)
state_dict = torch.load(runpath+"models/model_run0.pt", map_location=params["device"])
model.load_state_dict(state_dict)

Model AutoRegGMM hyperparameters: n_head=4, n_per_head=15, n_blocks=4, intermediate_fac=4, n_gauss=20
Bayesianization hyperparameters: bayesian=3, prior_prec=1.0, iterations=50


<All keys matched successfully>

## Data

In [53]:
data = ToySimulator(params).data
data_split = params["data_split"]
n_data = len(data)
cut1 = int(n_data - data_split[0])
cut2 = int(n_data * (data_split[0] + data_split[1]))
data_train = data[:cut1]
data_test = data[cut2:]

In [54]:
n_samples = 100000
nBNN = 50
data_predict = np.zeros((0, data_train.shape[1]))
for i in range(nBNN):
    data_predict = np.append(data_predict, model.sample_n(n_samples), axis=0)



Sampling time estimate: 3.22 s = 0.05 min
Sampling time estimate: 3.11 s = 0.05 min
Sampling time estimate: 2.97 s = 0.05 min
Sampling time estimate: 3.40 s = 0.06 min
Sampling time estimate: 3.03 s = 0.05 min
Sampling time estimate: 2.46 s = 0.04 min
Sampling time estimate: 2.64 s = 0.04 min
Sampling time estimate: 3.39 s = 0.06 min
Sampling time estimate: 3.06 s = 0.05 min
Sampling time estimate: 3.28 s = 0.05 min
Sampling time estimate: 3.64 s = 0.06 min
Sampling time estimate: 2.99 s = 0.05 min
Sampling time estimate: 3.10 s = 0.05 min
Sampling time estimate: 3.14 s = 0.05 min
Sampling time estimate: 3.36 s = 0.06 min
Sampling time estimate: 3.22 s = 0.05 min
Sampling time estimate: 3.56 s = 0.06 min
Sampling time estimate: 2.87 s = 0.05 min
Sampling time estimate: 4.02 s = 0.07 min
Sampling time estimate: 4.02 s = 0.07 min
Sampling time estimate: 3.10 s = 0.05 min
Sampling time estimate: 3.99 s = 0.07 min
Sampling time estimate: 2.95 s = 0.05 min
Sampling time estimate: 3.40 s = 0

## Plot

In [66]:
def plot_paper(out, obs_train, obs_test, obs_predict, name, bins=60, weight_samples=1, 
               predict_weights=None, unit=None, range=None, ymaxAbs=1., ymaxRel=1.):
    with PdfPages(out) as pp:
        y_t,  bins = np.histogram(obs_test, bins=bins, range=range)
        y_tr, _ = np.histogram(obs_train, bins=bins)

        if weight_samples == 1:
            y_g,  _ = np.histogram(obs_predict, bins=bins, weights=predict_weights)
            hists = [y_t, y_g, y_tr]
            hist_errors = [np.sqrt(y_t), np.sqrt(y_g), np.sqrt(y_tr)]
        else:
            obs_predict = obs_predict.reshape(weight_samples,
                    len(obs_predict)//weight_samples)
            hist_weights = (weight_samples*[None] if predict_weights is None
                            else predict_weights.reshape(obs_predict.shape))
            hists_g = np.array([np.histogram(obs_predict[i,:], bins=bins,
                                             weights=hist_weights[i])[0]
                                for i in np.arange(weight_samples)])
            hists = [y_t, np.mean(hists_g, axis=0), y_tr]
            hist_errors = [np.sqrt(y_t), np.std(hists_g, axis=0), np.sqrt(y_tr)]
        integrals = [np.sum((bins[1:] - bins[:-1]) * y) for y in hists]
        scales = [1 / integral if integral != 0. else 1. for integral in integrals]
            
        FONTSIZE = 14
        labels = ["True", "Model", "Train"]
        colors = ["#e41a1c", "#3b528b", "#1a8507"]
        dup_last = lambda a: np.append(a, a[-1])

        fig1, axs = plt.subplots(3, 1, sharex=True,
                gridspec_kw={"height_ratios" : [4, 1, 1], "hspace" : 0.00})
        fig1.tight_layout(pad=0.0, w_pad=0.0, h_pad=0.0, rect=(0.07,0.06,0.99,0.95))
        
        for y, y_err, scale, label, color in zip(hists, hist_errors, scales,
                                            labels, colors):

            axs[0].step(bins, dup_last(y) * scale, label=label, color=color,
                    linewidth=1.0, where="post")
            axs[0].step(bins, dup_last(y + y_err) * scale, color=color,
                    alpha=0.5, linewidth=0.5, where="post")
            axs[0].step(bins, dup_last(y - y_err) * scale, color=color,
                    alpha=0.5, linewidth=0.5, where="post")
            axs[0].fill_between(bins, dup_last(y - y_err) * scale,
                    dup_last(y + y_err) * scale, facecolor=color,
                    alpha=0.3, step="post")

            if label == "True": continue

            ratio = (y * scale)/ (hists[0] * scales[0])
            ratio_err = np.sqrt((y_err / y)**2 + (hist_errors[0] / hists[0])**2)
            ratio_isnan = np.isnan(ratio)
            ratio[ratio_isnan] = 1.
            ratio_err[ratio_isnan] = 0.

            axs[1].step(bins, dup_last(ratio), linewidth=1.0, where="post", color=color)
            axs[1].step(bins, dup_last(ratio + ratio_err), color=color, alpha=0.5,
                    linewidth=0.5, where="post")
            axs[1].step(bins, dup_last(ratio - ratio_err), color=color, alpha=0.5,
                    linewidth=0.5, where="post")
            axs[1].fill_between(bins, dup_last(ratio - ratio_err),
                    dup_last(ratio + ratio_err), facecolor=color, alpha=0.3, step="post")

            delta = np.fabs(ratio - 1) * 100
            delta_err = ratio_err * 100

            markers, caps, bars = axs[2].errorbar((bins[:-1] + bins[1:])/2, delta,
                    yerr=delta_err, ecolor=color, color=color, elinewidth=0.5,
                    linewidth=0, fmt=".", capsize=2)
            [cap.set_alpha(0.5) for cap in caps]
            [bar.set_alpha(0.5) for bar in bars]


        axs[0].legend(loc="upper left", frameon=False, fontsize=FONTSIZE)
        axs[0].set_ylabel("Normalized", fontsize = FONTSIZE)

        axs[1].set_ylabel(r"$\frac{\mathrm{Model}}{\mathrm{True}}$",
                fontsize = FONTSIZE)
        axs[1].set_yticks([0.95,1,1.05])
        axs[1].set_ylim([0.9,1.1])
        axs[1].axhline(y=1, c="black", ls="--", lw=0.7)
        axs[1].axhline(y=1.2, c="black", ls="dotted", lw=0.5)
        axs[1].axhline(y=0.8, c="black", ls="dotted", lw=0.5)
        plt.xlabel(r"${%s}$ %s" % (name, ("" if unit is None else f"[{unit}]")),
                fontsize = FONTSIZE)

        axs[2].set_ylim((0.05,20))
        axs[2].set_yscale("log")
        axs[2].set_yticks([0.1, 1.0, 10.0])
        axs[2].set_yticklabels([r"$0.1$", r"$1.0$", "$10.0$"])
        axs[2].set_yticks([0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9,
                            2., 3., 4., 5., 6., 7., 8., 9.], minor=True)

        axs[2].axhline(y=1.0,linewidth=0.5, linestyle="--", color="grey")
        axs[2].axhspan(0, 1.0, facecolor="#cccccc", alpha=0.3)
        axs[2].set_ylabel(r"$\delta [\%]$", fontsize = FONTSIZE)

        plt.savefig(pp, format="pdf")
        plt.close()
        
        fig2, axs = plt.subplots(1, 1)
        fig2.tight_layout(pad=0.0, w_pad=0.0, h_pad=0.0, rect=(0.07,0.06,0.99,0.95))
        
        axs.set_ylabel("Absolute uncertainty", fontsize = FONTSIZE)
        axs.set_xlabel(r"${%s}$ %s" % (name, ("" if unit is None else f"[{unit}]")),
                fontsize = FONTSIZE)
        
        axs.step(bins, dup_last(hist_errors[1] * scales[1]), color=colors[1])
        axs.set_ylim(0., ymaxAbs)
            
        plt.savefig(pp, format="pdf")
        plt.close()
        
        fig3, axs = plt.subplots(1, 1)
        fig3.tight_layout(pad=0.0, w_pad=0.0, h_pad=0.0, rect=(0.07,0.06,0.99,0.95))
        
        axs.set_ylabel("Relative uncertainty", fontsize = FONTSIZE)
        axs.set_xlabel(r"${%s}$ %s" % (name, ("" if unit is None else f"[{unit}]")),
                fontsize = FONTSIZE)
        
        axs.step(bins, dup_last(hist_errors[1] / hists[1]), color=colors[1])
        axs.set_ylim(0., ymaxRel)
            
        plt.savefig(pp, format="pdf")
        plt.close()

In [None]:
R_train, _ = ToySimulator.getSpherical(data_train)
R_test, _ = ToySimulator.getSpherical(data_test)
R_predict, _= ToySimulator.getSpherical(data_predict)
plot_paper("Scripts/paper/GMMsphere.pdf", R_train, R_test, R_predict, 
           "x_1", weight_samples=nBNN, ymaxAbs=.03, ymaxRel=.1, range=[-1.5, 1.5])

## Plot with likelihoods

In [58]:
xmin, xmax = -1.5, 1.5
prec = 500

In [59]:
def getMargLikelihoodnDim(model, xmin, xmax, dim, prec=1000):
    xs = torch.linspace(xmin, xmax, prec)
    idx = model.n_jets * torch.ones(prec, 2, dtype=torch.int).float()
    idx[:,1] = xs
    
    mu, sigma, weights = model.net(idx) #indices: batch, component, gauss
    mix = D.Categorical(weights)
    comp = D.Normal(mu, sigma)
    gmm = D.MixtureSameFamily(mix, comp)
    
    probs = torch.zeros(prec, prec, dtype=torch.float) #likelihood array
    for ix in range(prec):
        base = torch.zeros(prec, 2)
        base[:,0] = xs[ix]
        base[:,1] = xs
        probs[ix,:] = torch.exp(torch.sum(gmm.log_prob(base), axis=-1))
    probs = probs.detach().cpu().numpy()
    xs = xs.detach().cpu().numpy()
    
    marg = np.trapz(probs, x=xs, axis=dim)
    norm = np.trapz(marg, x=xs)
    #print(f"Normalization: {norm:.5f}")
    marg /= norm
    
    return xs, marg

In [60]:
xs = np.zeros(prec)
margL = np.zeros((prec, nBNN))
#model.train_loader = DataLoader(dataset=data[:cut1], batch_size=params["batch_size"], shuffle=True)
for j in range(nBNN):
    if model.net.bayesian >= 1:
        model.net.map = get(model.params, "fix_mu", False)
        for i in range(model.net.n_blocks):
            model.net.transformer.h[i].mlp.c_fc.random = None
            model.net.transformer.h[i].mlp.c_proj.random = None
    if model.net.bayesian >= 2:
        for i in range(model.net.n_blocks):
            model.net.transformer.h[i].attn.c_attn.random = None
            model.net.transformer.h[i].attn.c_proj.random = None
    if model.net.bayesian >= 3:
        model.net.transformer.wte.random = None
        model.net.lm_head.random = None
    xs, margL[:,j] = getMargLikelihoodnDim(model, xmin, xmax, 0, prec=prec)

In [61]:
meanMargL = np.mean(margL, axis=1)
stdMargL = np.std(margL, axis=1)

In [64]:
def plot_paper_likelihoods(out, obs_train, obs_test, obs_predict, name, bins=60, weight_samples=1, 
               predict_weights=None, unit=None, range=None, ymaxAbs=1., ymaxRel=1.):
    with PdfPages(out) as pp:
        y_t,  bins = np.histogram(obs_test, bins=bins, range=range)
        y_tr, _ = np.histogram(obs_train, bins=bins)

        if weight_samples == 1:
            y_g,  _ = np.histogram(obs_predict, bins=bins, weights=predict_weights)
            hists = [y_t, y_g, y_tr]
            hist_errors = [np.sqrt(y_t), np.sqrt(y_g), np.sqrt(y_tr)]
        else:
            obs_predict = obs_predict.reshape(weight_samples,
                    len(obs_predict)//weight_samples)
            hist_weights = (weight_samples*[None] if predict_weights is None
                            else predict_weights.reshape(obs_predict.shape))
            hists_g = np.array([np.histogram(obs_predict[i,:], bins=bins,
                                             weights=hist_weights[i])[0]
                                for i in np.arange(weight_samples)])
            hists = [y_t, np.mean(hists_g, axis=0), y_tr]
            hist_errors = [np.sqrt(y_t), np.std(hists_g, axis=0), np.sqrt(y_tr)]
        integrals = [np.sum((bins[1:] - bins[:-1]) * y) for y in hists]
        scales = [1 / integral if integral != 0. else 1. for integral in integrals]
            
        FONTSIZE = 14
        labels = ["True", "Model", "Train"]
        colors = ["#e41a1c", "#3b528b", "#1a8507"]
        dup_last = lambda a: np.append(a, a[-1])

        fig1, axs = plt.subplots(3, 1, sharex=True,
                gridspec_kw={"height_ratios" : [4, 1, 1], "hspace" : 0.00})
        fig1.tight_layout(pad=0.0, w_pad=0.0, h_pad=0.0, rect=(0.07,0.06,0.99,0.95))
        
        for y, y_err, scale, label, color in zip(hists, hist_errors, scales,
                                            labels, colors):

            axs[0].step(bins, dup_last(y) * scale, label=label, color=color,
                    linewidth=1.0, where="post")
            axs[0].step(bins, dup_last(y + y_err) * scale, color=color,
                    alpha=0.5, linewidth=0.5, where="post")
            axs[0].step(bins, dup_last(y - y_err) * scale, color=color,
                    alpha=0.5, linewidth=0.5, where="post")
            axs[0].fill_between(bins, dup_last(y - y_err) * scale,
                    dup_last(y + y_err) * scale, facecolor=color,
                    alpha=0.3, step="post")

            if label == "True": continue

            ratio = (y * scale)/ (hists[0] * scales[0])
            ratio_err = np.sqrt((y_err / y)**2 + (hist_errors[0] / hists[0])**2)
            ratio_isnan = np.isnan(ratio)
            ratio[ratio_isnan] = 1.
            ratio_err[ratio_isnan] = 0.

            axs[1].step(bins, dup_last(ratio), linewidth=1.0, where="post", color=color)
            axs[1].step(bins, dup_last(ratio + ratio_err), color=color, alpha=0.5,
                    linewidth=0.5, where="post")
            axs[1].step(bins, dup_last(ratio - ratio_err), color=color, alpha=0.5,
                    linewidth=0.5, where="post")
            axs[1].fill_between(bins, dup_last(ratio - ratio_err),
                    dup_last(ratio + ratio_err), facecolor=color, alpha=0.3, step="post")

            delta = np.fabs(ratio - 1) * 100
            delta_err = ratio_err * 100

            markers, caps, bars = axs[2].errorbar((bins[:-1] + bins[1:])/2, delta,
                    yerr=delta_err, ecolor=color, color=color, elinewidth=0.5,
                    linewidth=0, fmt=".", capsize=2)
            [cap.set_alpha(0.5) for cap in caps]
            [bar.set_alpha(0.5) for bar in bars]


        axs[0].fill_between(xs, meanMargL + stdMargL, meanMargL - stdMargL,
                           facecolor=colors[1], alpha=.3)
        axs[0].legend(loc="upper left", frameon=False, fontsize=FONTSIZE)
        axs[0].set_ylabel("Normalized", fontsize = FONTSIZE)

        axs[1].set_ylabel(r"$\frac{\mathrm{Model}}{\mathrm{True}}$",
                fontsize = FONTSIZE)
        axs[1].set_yticks([0.95,1,1.05])
        axs[1].set_ylim([0.9,1.1])
        axs[1].axhline(y=1, c="black", ls="--", lw=0.7)
        axs[1].axhline(y=1.2, c="black", ls="dotted", lw=0.5)
        axs[1].axhline(y=0.8, c="black", ls="dotted", lw=0.5)
        plt.xlabel(r"${%s}$ %s" % (name, ("" if unit is None else f"[{unit}]")),
                fontsize = FONTSIZE)

        axs[2].set_ylim((0.05,20))
        axs[2].set_yscale("log")
        axs[2].set_yticks([0.1, 1.0, 10.0])
        axs[2].set_yticklabels([r"$0.1$", r"$1.0$", "$10.0$"])
        axs[2].set_yticks([0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9,
                            2., 3., 4., 5., 6., 7., 8., 9.], minor=True)

        axs[2].axhline(y=1.0,linewidth=0.5, linestyle="--", color="grey")
        axs[2].axhspan(0, 1.0, facecolor="#cccccc", alpha=0.3)
        axs[2].set_ylabel(r"$\delta [\%]$", fontsize = FONTSIZE)

        plt.savefig(pp, format="pdf")
        plt.close()
        
        fig2, axs = plt.subplots(1, 1)
        fig2.tight_layout(pad=0.0, w_pad=0.0, h_pad=0.0, rect=(0.07,0.06,0.99,0.95))
        
        axs.set_ylabel("Absolute uncertainty", fontsize = FONTSIZE)
        axs.set_xlabel(r"${%s}$ %s" % (name, ("" if unit is None else f"[{unit}]")),
                fontsize = FONTSIZE)
        
        axs.step(bins, dup_last(hist_errors[1] * scales[1]), color=colors[1])
        axs.plot(xs, stdMargL, color=colors[1])
        axs.set_ylim(0., ymaxAbs)
            
        plt.savefig(pp, format="pdf")
        plt.close()
        
        fig3, axs = plt.subplots(1, 1)
        fig3.tight_layout(pad=0.0, w_pad=0.0, h_pad=0.0, rect=(0.07,0.06,0.99,0.95))
        
        axs.set_ylabel("Relative uncertainty", fontsize = FONTSIZE)
        axs.set_xlabel(r"${%s}$ %s" % (name, ("" if unit is None else f"[{unit}]")),
                fontsize = FONTSIZE)
        
        axs.step(bins, dup_last(hist_errors[1] / hists[1]), color=colors[1])
        axs.plot(xs, stdMargL / meanMargL, color=colors[1])
        axs.set_ylim(0., ymaxRel)
            
        plt.savefig(pp, format="pdf")
        plt.close()

In [65]:
plot_paper_likelihoods("Scripts/paper/GMMsphere_likelihoods.pdf", data_train[:,1], 
            data_test[:,1], data_predict[:,1], "x_1", weight_samples=nBNN, 
            ymaxAbs=.1, ymaxRel=.1, range=[-1.5, 1.5])

  ratio = (y * scale)/ (hists[0] * scales[0])
  ratio_err = np.sqrt((y_err / y)**2 + (hist_errors[0] / hists[0])**2)
  ratio = (y * scale)/ (hists[0] * scales[0])
