In [1]:
%%capture
%pip install --upgrade --user pip
%pip install --upgrade --user tensorflow tensorflow_probability
%pip install git+git://github.com/deepmind/optax.git
%pip install --upgrade git+https://github.com/google/flax.git
%pip install git+git://github.com/blackjax-devs/blackjax.git
%pip install git+git://github.com/deepmind/distrax.git
%pip install superimport  einops arviz
%pip install jaxlib
%pip install latex
%pip install git+https://github.com/probml/probml-utils.git


In [2]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import seaborn as sns
import latex

from probml_utils import savefig, latexify
from jax.numpy.linalg import cholesky
from jax.scipy.linalg import inv
from scipy.stats import multivariate_normal

In [3]:
# Environment variables set before any figure is produced.
import os

# "LATEXIFY" is set to the string "1".
# NOTE(review): presumably read by probml_utils.latexify - confirm.
os.environ["LATEXIFY"] = "1"
# Directory that savefig reports writing to (see the "saving image to ..." log).
os.environ["FIG_DIR"] = "/content/Folder"

In [4]:
# Apply the probml_utils figure styling with these sizing arguments.
latexify(fig_height=2, width_scale_factor=2)

In [5]:
import matplotlib as mpl
# Reset every matplotlib rcParam to the library defaults.
# Original author's note: done to make latexify work in compatibility with
# matplotlib.
# NOTE(review): this runs after latexify(...), so it looks like it discards the
# rc settings latexify applied (the logged figure size is the default
# [6.4 4.8]) - confirm this is intended.
mpl.rcParams.update(mpl.rcParamsDefault)

In [6]:
def gaussSample(mu, sigma, n,key):
    """Draw `n` samples from a Gaussian with mean `mu` and covariance `sigma`.

    Standard-normal draws are coloured with the Cholesky factor of `sigma`
    and then shifted by `mu`.

    Args:
        mu: Mean vector; its length sets the sample dimension.
        sigma: Covariance matrix passed to `cholesky`.
        n: Number of samples to draw.
        key: JAX PRNG key used for the standard-normal draws.

    Returns:
        Array of shape (n, len(mu)), one sample per row.
    """
    dim = len(mu)
    chol_factor = cholesky(sigma)
    white_noise = jax.random.normal(key, shape=(dim, n))
    coloured = jnp.dot(chol_factor, white_noise)
    return coloured.T + mu

In [7]:
# ---- Ground-truth Gaussian and synthetic data ------------------------------
mtrue = {}
prior = {}
post = {}
ns = [10]  # sample sizes; only this single entry is used below
key = jax.random.PRNGKey(5)
muTrue = jnp.array([0.5, 0.5])
Ctrue = 0.1 * jnp.array([[2, 1], [1, 1]])
mtrue["mu"] = muTrue
mtrue["Sigma"] = Ctrue
X = gaussSample(mtrue["mu"], mtrue["Sigma"], ns[-1], key)

# ---- Prior over the mean, evaluated on a regular 2-D grid ------------------
xyrange = jnp.array([[-1, 1], [-1, 1]])
npoints = 100j  # complex step: mgrid treats its magnitude as a point count
prior["mu"] = jnp.array([0, 0])
prior["Sigma"] = 0.1 * jnp.eye(2)
out = jnp.mgrid[xyrange[0, 0] : xyrange[0, 1] : npoints, xyrange[1, 0] : xyrange[1, 1] : npoints]
X1, X2 = out[0], out[1]
# Grid rows and columns. Previously both were read from axis 0, which is only
# correct while the grid happens to be square.
nr, nc = X1.shape
points = jnp.vstack([jnp.ravel(X1), jnp.ravel(X2)]).T
prior_pdf = multivariate_normal.pdf(points, mean=prior["mu"], cov=prior["Sigma"]).reshape(nr, nc)

# ---- Posterior over the mean after n observations --------------------------
# The likelihood covariance S is taken to be the true covariance Ctrue.
#   Sn  = (S0^-1 + n S^-1)^-1
#   muN = Sn (n S^-1 xbar + S0^-1 mu0)
data = X[: ns[0], :]
n = ns[0]
S0 = prior["Sigma"]
S0inv = inv(S0)
S = Ctrue
Sinv = inv(S)
Sn = inv(S0inv + n * Sinv)
mu0 = prior["mu"]
xbar = jnp.mean(data, 0)
muN = jnp.dot(Sn, n * jnp.dot(Sinv, xbar) + jnp.dot(S0inv, mu0))
post["mu"] = muN
post["Sigma"] = Sn
post_pdf = multivariate_normal.pdf(points, mean=post["mu"], cov=post["Sigma"]).reshape(nr, nc)




In [8]:
def make_graph(X,muTrue,savename,title,fig=None,ax=None):
    """Scatter-plot the samples in `X` and mark `muTrue` with a cross.

    Args:
        X: 2-D array; column 0 is plotted on x and column 1 on y.
        muTrue: Indexable whose first two entries give the marker position.
        savename: Name handed to `savefig`; an empty (or None) value skips
            saving.
        title: Axes title.
        fig: Optional existing figure. If given without `ax`, its current
            axes are used.
        ax: Optional existing axes to draw on. If given without `fig`, the
            axes' own figure is returned.

    Returns:
        The `(fig, ax)` pair that was drawn on.
    """
    # Resolve fig/ax independently so a caller-supplied axes is never silently
    # replaced and a caller-supplied figure never leaves `ax` as None.
    if ax is None:
        if fig is None:
            fig, ax = plt.subplots()
        else:
            ax = fig.gca()
    elif fig is None:
        fig = ax.figure

    ax.plot(X[:, 0], X[:, 1], "o", markersize=8, markerfacecolor="b")
    ax.set_ylim([-1, 1])
    ax.set_xlim([-1, 1])
    ax.set_title(title)
    ax.plot(muTrue[0], muTrue[1], "x", linewidth=5, markersize=20, color="k")
    # Despine the axes we drew on rather than whatever figure is current.
    sns.despine(ax=ax)
    if savename:
        # NOTE(review): savefig only receives a name, so it presumably saves
        # the current pyplot figure - confirm when passing a custom fig/ax.
        savefig(savename)
    return fig, ax

In [9]:
def make_graph2(X1,X2,p,savename,title,fig=None,ax=None):
    """Draw a contour plot of `p` over the grid (`X1`, `X2`).

    Args:
        X1: Grid of x coordinates, passed straight to `ax.contour`.
        X2: Grid of y coordinates, passed straight to `ax.contour`.
        p: Values to contour, on the same grid.
        savename: Name handed to `savefig`; an empty (or None) value skips
            saving.
        title: Axes title.
        fig: Optional existing figure. If given without `ax`, its current
            axes are used.
        ax: Optional existing axes to draw on. If given without `fig`, the
            axes' own figure is returned.

    Returns:
        The `(fig, ax)` pair that was drawn on.
    """
    # Resolve fig/ax independently so a caller-supplied axes is never silently
    # replaced and a caller-supplied figure never leaves `ax` as None.
    if ax is None:
        if fig is None:
            fig, ax = plt.subplots()
        else:
            ax = fig.gca()
    elif fig is None:
        fig = ax.figure

    ax.contour(X1, X2, p)
    ax.set_ylim([-1, 1])
    ax.set_xlim([-1, 1])
    ax.set_title(title)
    # Despine the axes we drew on rather than whatever figure is current.
    sns.despine(ax=ax)
    if savename:
        # NOTE(review): savefig only receives a name, so it presumably saves
        # the current pyplot figure - confirm when passing a custom fig/ax.
        savefig(savename)
    return fig, ax

In [10]:
# Render and save the three static panels: data, prior, posterior.
_, _ = make_graph(X, muTrue, savename="gauss_infer_2d_(a)_latexified", title="data")
_, _ = make_graph2(X1, X2, prior_pdf, savename="gauss_infer_2d_(b)_latexified", title="prior")
_, _ = make_graph2(
    X1,
    X2,
    post_pdf,
    savename="gauss_infer_2d_(c)_latexified",
    title="posterior after 10 points",
)

saving image to /content/Folder/gauss_infer_2d_(a)_latexified
Figure size: [6.4 4.8]
saving image to /content/Folder/gauss_infer_2d_(b)_latexified
Figure size: [6.4 4.8]
saving image to /content/Folder/gauss_infer_2d_(c)_latexified
Figure size: [6.4 4.8]


In [12]:
from ipywidgets import interact


@interact(random_state=(1, 10),n_=(1,20),range=(1,5))
def generate_random(random_state,n_=10,range=1):
    """Interactively redraw the data, prior and posterior plots.

    Args:
        random_state: Integer seed for `jax.random.PRNGKey`.
        n_: Number of points sampled from the true Gaussian and used in the
            posterior update.
        range: Half-width of the square plotting window; the density grid and
            all three axes span [-range, range]. The name shadows the builtin
            but is kept because it is the slider's keyword.

    Side effects:
        Overwrites `post["mu"]` and `post["Sigma"]` with the posterior for the
        current slider values (the original cell also wrote to `post`), and
        shows three figures.
    """
    key = jax.random.PRNGKey(random_state)
    X = gaussSample(mtrue["mu"], mtrue["Sigma"], n_, key)
    limits = [-range, range]

    _, ax_data = make_graph(X, muTrue, "", "data")

    # Evaluation grid covering the requested window.
    xyrange = jnp.array([limits, limits])
    out = jnp.mgrid[xyrange[0, 0] : xyrange[0, 1] : npoints, xyrange[1, 0] : xyrange[1, 1] : npoints]
    X1, X2 = out[0], out[1]
    points = jnp.vstack([jnp.ravel(X1), jnp.ravel(X2)]).T
    prior_pdf = multivariate_normal.pdf(points, mean=prior["mu"], cov=prior["Sigma"]).reshape(X1.shape)

    # Posterior for the *selected* sample size. The previous version reused the
    # module-level `n` and `Sn`, which are fixed at 10 observations, so the
    # n_ slider never changed the posterior shape.
    post_cov = inv(S0inv + n_ * Sinv)
    xbar = jnp.mean(X, 0)
    post_mean = jnp.dot(post_cov, n_ * jnp.dot(Sinv, xbar) + jnp.dot(S0inv, mu0))
    post["mu"] = post_mean
    post["Sigma"] = post_cov
    post_pdf = multivariate_normal.pdf(points, mean=post["mu"], cov=post["Sigma"]).reshape(X1.shape)

    _, ax_prior = make_graph2(X1, X2, prior_pdf, "", "prior")
    _, ax_post = make_graph2(X1, X2, post_pdf, "", f"posterior after {n_} points")

    # The plot helpers clamp the axes to [-1, 1]; widen them so the `range`
    # slider actually changes the visible window.
    for axes in (ax_data, ax_prior, ax_post):
        axes.set_xlim(limits)
        axes.set_ylim(limits)
    plt.show()


interactive(children=(IntSlider(value=5, description='random_state', max=10, min=1), IntSlider(value=10, descr…