<h2><b>BIVARIATE NORMAL DISTRIBUTION ANIMATION</b></h2>
<img src="https://upload.wikimedia.org/wikipedia/commons/8/8e/MultivariateNormal.png" alt="Samples from a bivariate normal distribution with its marginal distributions" width="300" height="200">

*   <p>Reproduce the above figure, showing samples from a bivariate normal distribution together with their
marginal PDFs, from scratch using JAX and matplotlib.</p>
*   <p>Add interactivity to the figure by adding sliders with ipywidgets. You should be able to vary the parameters of the bivariate normal distribution (mean vector and covariance matrix) using these sliders.</p>

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
import ipywidgets as w

# Plotting callback driven by the ipywidgets sliders, plus small JAX helpers.


def _covariance_matrix(a, b, c, d):
    """Return a valid 2x2 covariance matrix built from slider values.

    Args:
        a: Variance of X (top-left entry).
        b, c: Requested off-diagonal entries. Independent sliders can make
            them differ, which would give a non-symmetric matrix, so their
            average is used for both off-diagonal positions.
        d: Variance of Y (bottom-right entry).

    Returns:
        A symmetric ``(2, 2)`` float32 JAX array.

    Raises:
        ValueError: If the matrix is not positive definite; sampling from it
            would otherwise yield invalid (typically NaN) values.
    """
    rho = 0.5 * (b + c)
    # A symmetric 2x2 matrix is positive definite iff its leading entry and
    # its determinant are both positive; ``not (...)`` also rejects NaN.
    if not (a > 0 and a * d - rho * rho > 0):
        raise ValueError(
            f"Covariance [[{a}, {rho}], [{rho}, {d}]] is not positive definite; "
            "need a > 0 and a*d > ((b + c) / 2)**2."
        )
    return jnp.array([[a, rho], [rho, d]], dtype=jnp.float32)


def _normal_pdf(x, mean, var):
    """Evaluate the univariate normal density N(mean, var) at ``x`` with JAX."""
    return jnp.exp(-0.5 * (x - mean) ** 2 / var) / jnp.sqrt(2.0 * jnp.pi * var)


def _sigma_ellipse(mean, cov, n_sigma=3.0, num=200):
    """Return x and y coordinates of the ``n_sigma`` contour of N(mean, cov).

    Unit-circle points are scaled by the square roots of the covariance
    eigenvalues, rotated by its eigenvectors and shifted to ``mean``, so every
    returned point p satisfies (p - mean)^T cov^-1 (p - mean) = n_sigma**2.
    """
    evals, evecs = jnp.linalg.eigh(cov)
    t = jnp.linspace(0.0, 2.0 * jnp.pi, num)
    circle = jnp.stack([jnp.cos(t), jnp.sin(t)])  # shape (2, num)
    pts = mean[:, None] + n_sigma * (evecs @ (jnp.sqrt(evals)[:, None] * circle))
    return pts[0], pts[1]


def f(a, b, c, d, e, f):
    """Plot 5000 bivariate-normal samples together with their marginal densities.

    Draws a scatter of samples from N([e, f], [[a, (b+c)/2], [(b+c)/2, d]])
    with its 3-sigma ellipse, plus a histogram of each coordinate overlaid
    with the analytic marginal PDF: N(e, a) for X and N(f, d) for Y.

    Args:
        a: Variance of X.
        b, c: Off-diagonal covariance entries (averaged so the matrix stays
            symmetric).
        d: Variance of Y.
        e: Mean of X.
        f: Mean of Y. Inside the body this shadows the function's own name;
            the name is kept because ``interact`` passes it as a keyword.
    """
    try:
        cov = _covariance_matrix(a, b, c, d)
    except ValueError as err:
        # Report rather than raise, so dragging a slider into an invalid
        # region only shows a message in the widget output.
        print(err)
        return
    mean = jnp.array([e, f], dtype=jnp.float32)

    # A fixed key makes the draw deterministic for a given parameter set.
    key = jax.random.PRNGKey(0)
    samples = np.asarray(jax.random.multivariate_normal(key, mean, cov, (5000,)))
    x1, x2 = samples[:, 0], samples[:, 1]

    # Plot range: mean +/- 4 marginal standard deviations, which contains the
    # 3-sigma ellipse (its extent along each axis is mean +/- 3 std).
    sd_x, sd_y = float(np.sqrt(a)), float(np.sqrt(d))
    x_lo, x_hi = e - 4 * sd_x, e + 4 * sd_x
    y_lo, y_hi = f - 4 * sd_y, f + 4 * sd_y
    xs = jnp.linspace(x_lo, x_hi, 400)
    ys = jnp.linspace(y_lo, y_hi, 400)

    # Figure size is set locally instead of mutating the global plt.rcParams.
    fig = plt.figure(figsize=(8, 8))
    grid = fig.add_gridspec(2, 2, width_ratios=(4, 1), height_ratios=(1, 4),
                            wspace=0.05, hspace=0.05)
    ax = fig.add_subplot(grid[1, 0])
    ax_top = fig.add_subplot(grid[0, 0], sharex=ax)
    ax_right = fig.add_subplot(grid[1, 1], sharey=ax)

    # Joint samples and the 3-sigma ellipse.
    ell_x, ell_y = _sigma_ellipse(mean, cov)
    ax.scatter(x1, x2, s=2, alpha=0.4)
    ax.plot(np.asarray(ell_x), np.asarray(ell_y), color='red', linewidth=1.5,
            label=r'$3\sigma$ ellipse')
    ax.set_xlim(x_lo, x_hi)
    ax.set_ylim(y_lo, y_hi)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.legend(loc='upper left')

    # Marginal of X: histogram of the X samples and the analytic N(e, a) PDF.
    ax_top.hist(x1, bins=50, density=True, alpha=0.5)
    ax_top.plot(np.asarray(xs), np.asarray(_normal_pdf(xs, e, a)), color='red')
    ax_top.set_ylabel('p(X)')
    ax_top.tick_params(labelbottom=False)

    # Marginal of Y: drawn sideways so it lines up with the shared Y axis.
    ax_right.hist(x2, bins=50, density=True, alpha=0.5, orientation='horizontal')
    ax_right.plot(np.asarray(_normal_pdf(ys, f, d)), np.asarray(ys), color='red')
    ax_right.set_xlabel('p(Y)')
    ax_right.tick_params(labelleft=False)

    plt.show()

# Use the ipywidgets library to make the figure interactive. Explicit
# FloatSliders are passed because interact() would turn the integer defaults
# a=1 and d=2 into IntSliders whose ranges include zero and negative
# variances. Small variances combined with large off-diagonal values can still
# give a non-positive-definite covariance, so the callback must check that.
_ = w.interact(
    f,
    a=w.FloatSlider(value=1.0, min=0.1, max=3.0, step=0.1),
    b=w.FloatSlider(value=0.6, min=-1.5, max=1.5, step=0.05),
    c=w.FloatSlider(value=0.6, min=-1.5, max=1.5, step=0.05),
    d=w.FloatSlider(value=2.0, min=0.1, max=3.0, step=0.1),
    e=w.FloatSlider(value=0.5, min=-3.0, max=3.0, step=0.1),
    f=w.FloatSlider(value=0.5, min=-3.0, max=3.0, step=0.1),
)

interactive(children=(IntSlider(value=1, description='a', max=3, min=-1), FloatSlider(value=0.6, description='…