In [None]:
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from ipywidgets import interact
import seaborn as sns

# Grid of firing rates [Hz] on which every density is evaluated and from which
# observations are sampled: -10 (inclusive) to 15 (exclusive) in .001 steps.
x_axis = np.arange(-10, 15, 0.001)

# Bernoulli prior over the hidden glucose state s (0: hypoglycemic,
# 1: hyperglycemic); p_1 is the prior probability of s = 1.  The *_process
# value parameterizes the generative process and the *_model value the
# generative model.  Equal values give the model the "good" prior, i.e. one
# that coincides with the generative process.
p_1_process = 0.7
p_1_model = 0.7

# Gaussian observation densities: component 1 belongs to s = 1 (weighted by
# p_1, drawn in blue), component 2 to s = 0 (weighted by 1 - p_1, drawn in
# red).  The values equal the defaults of update_parameters(), so this initial
# figure matches what update_parameters() draws for its defaults.
mu_1_process, mu_2_process = 0, 2
sigma_1_process, sigma_2_process = 2, 2
mu_1_model, mu_2_model = 0, 2
sigma_1_model, sigma_2_model = 2, 2

# Sample size for the simulation part.
N = 100
# Not used in this cell; kept in case later cells rely on them (TODO confirm).
samples_2 = []
new_sample_2 = []

# Generative process on the grid.
process_joint_1 = p_1_process * norm.pdf(x_axis, mu_1_process, sigma_1_process)        # p(s=1|r) f(o|s=1,r)
process_joint_0 = (1 - p_1_process) * norm.pdf(x_axis, mu_2_process, sigma_2_process)  # p(s=0|r) f(o|s=0,r)
process_marginal = process_joint_1 + process_joint_0  # f(o|r)
# Analytical posterior p(s=1|o,r) from Bayes' theorem.  This is not how
# approximate posteriors are usually obtained in Active Inference; it is
# possible here because the model has the "correct" structure and only its
# parameter values differ from the process.
process_posterior_1 = process_joint_1 / process_marginal

# Generative model on the grid (same quantities, model parameters).
model_joint_1 = p_1_model * norm.pdf(x_axis, mu_1_model, sigma_1_model)
model_joint_0 = (1 - p_1_model) * norm.pdf(x_axis, mu_2_model, sigma_2_model)
model_marginal = model_joint_1 + model_joint_0      # f(o|m)
model_posterior_1 = model_joint_1 / model_marginal  # p(s=1|o,m)

# Simulation: draw all N observations in one call from the process marginal
# discretized on x_axis (np.random.choice needs weights that sum to 1).
samples = np.random.choice(x_axis, size=N, p=process_marginal / process_marginal.sum())

sample_joint_1_process = p_1_process * norm.pdf(samples, mu_1_process, sigma_1_process)
sample_marginal_process = sample_joint_1_process + (1 - p_1_process) * norm.pdf(samples, mu_2_process, sigma_2_process)
sample_joint_1_model = p_1_model * norm.pdf(samples, mu_1_model, sigma_1_model)
sample_marginal_model = sample_joint_1_model + (1 - p_1_model) * norm.pdf(samples, mu_2_model, sigma_2_model)
q_process = sample_joint_1_process / sample_marginal_process  # p(s=1|o,r) per sample
q_model = sample_joint_1_model / sample_marginal_model        # p(s=1|o,m) per sample

# Surprise (negative log evidence) of every sample under process and model.
samples_likelihood = -np.log(sample_marginal_process)
samples_surprise_model = -np.log(sample_marginal_model)
# Free energy of the model: its surprise plus KL[Bern(q_model) || Bern(q_process)],
# so it is never below the surprise.
# NOTE(review): KL direction (model posterior first) follows the original formula — confirm.
samples_free_energy = samples_surprise_model + (
    q_model * np.log(q_model / q_process)
    + (1 - q_model) * np.log((1 - q_model) / (1 - q_process))
)

fig, ((ax1, ax2, ax3, ax4, ax5, ax6, ax13, ax14), (ax7, ax8, ax9, ax10, ax11, ax12, ax15, ax16)) = plt.subplots(2, 8, figsize=(32, 8))
fig.suptitle('Evaluating Generative Model depending on the Generative Process')

# ---- Top row: generative process ----
# first subplot: component likelihoods
ax1.set_xlabel("Firing Rate [Hz]")
ax1.set_ylabel("Process Likelihood f(o|r)")
line_1_1, = ax1.plot(x_axis, norm.pdf(x_axis, mu_1_process, sigma_1_process), color="b")
line_1_2, = ax1.plot(x_axis, norm.pdf(x_axis, mu_2_process, sigma_2_process), color="r")

# second subplot: prior over the glycemic state (stem-style markers)
ax2.set_xlabel("0: hypoglycemic, 1: hyperglycemic")
ax2.set_ylabel("Process Prior of Glycemic State p(s|r)")
ax2.set_ylim((0, 1))
ax2.set_xticks((0.0, 1.0))
line_2_1, = ax2.plot(0, (1 - p_1_process), "ro", ms=10, mec="r")
line_2_2, = ax2.plot(1, p_1_process, "bo", ms=10, mec="blue")
line_2_3 = ax2.axvline(0, ymin=0, ymax=(1 - p_1_process), color="r", lw=5)
line_2_4 = ax2.axvline(1, ymin=0, ymax=p_1_process, color="blue", lw=5)

# third subplot: prior-weighted component densities
ax3.set_xlabel("Firing Rate [Hz]")
ax3.set_ylabel("Conditional density f(o|s,r)")
ax3.set_ylim((0, 0.2))
line_3_1, = ax3.plot(x_axis, process_joint_1, color="b")
line_3_2, = ax3.plot(x_axis, process_joint_0, color="r")

# fourth subplot: marginal density of observations
ax4.set_ylim((0, 0.4))
ax4.set_xlabel("Firing Rate [Hz]")
ax4.set_ylabel("Marginal density f(o|r)")
line_4_1, = ax4.plot(x_axis, process_marginal, color="black")

# fifth subplot: analytical posterior
ax5.set_xlabel("Firing Rate [Hz]")
ax5.set_ylabel("Process Posterior of Glycemic State p(s|o,r)")
line_5_1, = ax5.plot(x_axis, process_posterior_1, color="b")
line_5_2, = ax5.plot(x_axis, 1 - process_posterior_1, color="r")

# sixth subplot: surprise
ax6.set_xlabel("Firing Rate [Hz]")
ax6.set_ylabel("Surprise of Process [nats]")
line_6_1, = ax6.plot(x_axis, -np.log(process_marginal), color="black")

# ---- Bottom row: generative model ----
ax7.set_xlabel("Firing Rate [Hz]")
ax7.set_ylabel("Model Likelihood f(o|m)")
line_7_1, = ax7.plot(x_axis, norm.pdf(x_axis, mu_1_model, sigma_1_model), color="b")
line_7_2, = ax7.plot(x_axis, norm.pdf(x_axis, mu_2_model, sigma_2_model), color="r")

ax8.set_xlabel("0: hypoglycemic, 1: hyperglycemic")
ax8.set_ylabel("Model Prior of Glycemic State p(s|m)")
ax8.set_ylim((0, 1))
ax8.set_xticks((0.0, 1.0))
line_8_1, = ax8.plot(0, (1 - p_1_model), "ro", ms=10, mec="red")
line_8_2, = ax8.plot(1, p_1_model, "bo", ms=10, mec="blue")
line_8_3 = ax8.axvline(0, ymin=0, ymax=(1 - p_1_model), color="r", lw=5)
line_8_4 = ax8.axvline(1, ymin=0, ymax=p_1_model, color="blue", lw=5)

ax9.set_xlabel("Firing Rate [Hz]")
ax9.set_ylabel("Model Conditional density f(o|s,m)")
ax9.set_ylim((0, 0.2))
line_9_1, = ax9.plot(x_axis, model_joint_1, color="b")
line_9_2, = ax9.plot(x_axis, model_joint_0, color="r")

ax10.set_ylim((0, 0.4))
ax10.set_xlabel("Firing Rate [Hz]")
ax10.set_ylabel("Marginal density f(o|m)")
line_10_1, = ax10.plot(x_axis, model_marginal, color="black")

ax11.set_xlabel("Firing Rate [Hz]")
ax11.set_ylabel("Model Posterior of Glycemic State p(s|o,m)")
line_11_1, = ax11.plot(x_axis, model_posterior_1, color="b")
line_11_2, = ax11.plot(x_axis, 1 - model_posterior_1, color="r")

ax12.set_xlabel("Firing Rate [Hz]")
ax12.set_ylabel("Surprise of Model [nats]")
line_12_1, = ax12.plot(x_axis, -np.log(model_marginal), color="black")

# Right-most panels: distributions of the per-sample quantities.  The process
# is scored with its own exact posterior, so its KL term vanishes and its free
# energy equals its surprise.
for kde_ax, kde_values, kde_label in (
    (ax13, samples_likelihood, "Surprise of Process [nats]"),
    (ax14, samples_likelihood, "Free Energy of Process [nats]"),
    (ax15, samples_surprise_model, "Surprise of Model [nats]"),
    (ax16, samples_free_energy, "Free Energy of Model [nats]"),
):
    kde_ax.set_xlabel(kde_label)
    sns.kdeplot(kde_values, ax=kde_ax, fill=True)
    kde_ax.set_xlim((0, 10))
    kde_ax.set_ylim((0, 2))

# Lay out only after all labels exist so tight_layout can make room for them.
fig.tight_layout()

def _mixture_terms(o, p_1, mu_1, sigma_1, mu_2, sigma_2):
    """Return ``(joint_1, joint_0, marginal)`` of a two-component Gaussian mixture at ``o``.

    ``joint_1 = p_1 * N(o; mu_1, sigma_1)``, ``joint_0 = (1 - p_1) * N(o; mu_2, sigma_2)``
    and ``marginal = joint_1 + joint_0``.  The posterior of state 1 is
    ``joint_1 / marginal``.
    """
    joint_1 = p_1 * norm.pdf(o, mu_1, sigma_1)
    joint_0 = (1 - p_1) * norm.pdf(o, mu_2, sigma_2)
    return joint_1, joint_0, joint_1 + joint_0


def _bernoulli_kl(q, r):
    """Return KL[Bern(q) || Bern(r)] element-wise, in nats.

    ``rel_entr`` evaluates 0 * log(0 / y) as 0, so posteriors that saturate at
    0 or 1 (e.g. a prior slider at its end point) give 0 or inf, not NaN.
    """
    from scipy.special import rel_entr
    return rel_entr(q, r) + rel_entr(1 - q, 1 - r)


def _redraw_kde(ax, values, xlabel):
    """Replace the contents of ``ax`` with a filled KDE of the finite ``values``."""
    # Clear first: otherwise every slider move stacks another KDE on top.
    ax.cla()
    # Infinite free energies (impossible states under one posterior) cannot be
    # density-estimated; drop them from the plot.
    sns.kdeplot(values[np.isfinite(values)], ax=ax, fill=True)
    ax.set_xlabel(xlabel)
    ax.set_xlim((0, 10))
    ax.set_ylim((0, 2))


def update_parameters(p_1_process=0.7, mu_1_process=0, mu_2_process=2, sigma_1_process=2, sigma_2_process=2, p_1_model=0.7, mu_1_model=0, mu_2_model=2, sigma_1_model=2, sigma_2_model=2):
    """Redraw every panel of ``fig`` for new process/model parameters.

    Updates the analytic curves of the generative process (top row) and the
    generative model (bottom row), draws N fresh observations from the process
    marginal, and re-plots the distributions of their surprise and free energy.

    Parameters
    ----------
    p_1_process, p_1_model : float
        Prior probability of state 1 (hyperglycemic) under the process / model.
    mu_1_*, sigma_1_* : float
        Mean and standard deviation of the observation density of state 1.
    mu_2_*, sigma_2_* : float
        Mean and standard deviation of the observation density of state 0.
    """
    process_params = (p_1_process, mu_1_process, sigma_1_process, mu_2_process, sigma_2_process)
    model_params = (p_1_model, mu_1_model, sigma_1_model, mu_2_model, sigma_2_model)

    # ---- generative process (top row) ----
    joint_1_p, joint_0_p, marginal_p = _mixture_terms(x_axis, *process_params)
    posterior_1_p = joint_1_p / marginal_p

    line_1_1.set_ydata(norm.pdf(x_axis, mu_1_process, sigma_1_process))
    line_1_2.set_ydata(norm.pdf(x_axis, mu_2_process, sigma_2_process))
    # Line2D.set_ydata expects a sequence, so the single markers get 1-element lists.
    line_2_1.set_ydata([1 - p_1_process])
    line_2_2.set_ydata([p_1_process])
    line_2_3.set_ydata([0, 1 - p_1_process])
    line_2_4.set_ydata([0, p_1_process])
    line_3_1.set_ydata(joint_1_p)
    line_3_2.set_ydata(joint_0_p)
    line_4_1.set_ydata(marginal_p)
    line_5_1.set_ydata(posterior_1_p)
    line_5_2.set_ydata(1 - posterior_1_p)
    line_6_1.set_ydata(-np.log(marginal_p))

    # ---- generative model (bottom row) ----
    joint_1_m, joint_0_m, marginal_m = _mixture_terms(x_axis, *model_params)
    posterior_1_m = joint_1_m / marginal_m

    line_7_1.set_ydata(norm.pdf(x_axis, mu_1_model, sigma_1_model))
    line_7_2.set_ydata(norm.pdf(x_axis, mu_2_model, sigma_2_model))
    line_8_1.set_ydata([1 - p_1_model])
    line_8_2.set_ydata([p_1_model])
    line_8_3.set_ydata([0, 1 - p_1_model])
    line_8_4.set_ydata([0, p_1_model])
    line_9_1.set_ydata(joint_1_m)
    line_9_2.set_ydata(joint_0_m)
    line_10_1.set_ydata(marginal_m)
    line_11_1.set_ydata(posterior_1_m)
    line_11_2.set_ydata(1 - posterior_1_m)
    line_12_1.set_ydata(-np.log(marginal_m))

    # ---- simulation ----
    # Draw all N observations in one call from the process marginal discretized
    # on x_axis; the sampling weights do not depend on the draw.
    samples_1 = np.random.choice(x_axis, size=N, p=marginal_p / marginal_p.sum())

    s_joint_1_p, _, s_marginal_p = _mixture_terms(samples_1, *process_params)
    s_joint_1_m, _, s_marginal_m = _mixture_terms(samples_1, *model_params)
    q_process = s_joint_1_p / s_marginal_p  # p(s=1|o,r)
    q_model = s_joint_1_m / s_marginal_m    # p(s=1|o,m)

    surprise_process = -np.log(s_marginal_p)
    surprise_model = -np.log(s_marginal_m)
    # Free energy of the model = its surprise + KL[p(s|o,m) || p(s|o,r)] >= surprise;
    # it equals the surprise when model and process coincide.
    # NOTE(review): KL direction (model posterior first) follows the original formula — confirm.
    free_energy_model = surprise_model + _bernoulli_kl(q_model, q_process)

    _redraw_kde(ax13, surprise_process, "Surprise of Process [nats]")
    # The process is scored with its own exact posterior, so its KL term is zero
    # and its free energy equals its surprise.
    _redraw_kde(ax14, surprise_process, "Free Energy of Process [nats]")
    _redraw_kde(ax15, surprise_model, "Surprise of Model [nats]")
    _redraw_kde(ax16, free_energy_model, "Free Energy of Model [nats]")

    fig.canvas.draw_idle()
    display(fig)
######

# Expose each keyword of update_parameters() as an interactive control; every
# (min, max) tuple is the range of that control.  The first five keywords
# shape the generative process, the last five the generative model.  The
# trailing semicolon suppresses the notebook's echo of the return value.
interact(
    update_parameters,
    # generative process
    p_1_process=(0.0, 1.0),
    mu_1_process=(-5.5, 5.5),
    mu_2_process=(-5.5, 5.5),
    sigma_1_process=(1.0, 3.0),
    sigma_2_process=(1.0, 3.0),
    # generative model
    p_1_model=(0.0, 1.0),
    mu_1_model=(-5.5, 5.5),
    mu_2_model=(-5.5, 5.5),
    sigma_1_model=(1.0, 3.0),
    sigma_2_model=(1.0, 3.0),
);

