# Animate bivariate normal distribution using JAX

The probability density function (PDF) describes the relative likelihood of a random variable $X$ taking a value near a given point. Where the density is high, random samples of $X$ are most likely to fall. The density function of a bivariate Gaussian random variable $X$, which is responsible for its characteristic “bell shape”, is mathematically defined as:

$$f_X(x) = \frac{1}{2\pi\sqrt{|\Sigma|}} \exp\left(-\frac{(x-\mu)^T \Sigma^{-1}(x-\mu)}{2}\right)$$

where $x \in \mathbb{R}^2$ is any input vector, $\mu \in \mathbb{R}^2$ is the mean vector, $\Sigma$ is the $2 \times 2$ positive definite covariance matrix, and $|\Sigma|$ is its determinant.
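As a sanity check on the formula, here is a minimal sketch that evaluates $f_X$ directly in JAX and compares it with `jax.scipy.stats.multivariate_normal.pdf`. The function name `bivariate_normal_pdf` and the example values of $\mu$, $\Sigma$ and $x$ are illustrative assumptions, not part of the original notebook.

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def bivariate_normal_pdf(x, mu, sigma):
    """Hypothetical helper: evaluate f_X(x) for a single 2-vector x, term by term."""
    diff = x - mu
    # Quadratic form (x - mu)^T Sigma^{-1} (x - mu), via a linear solve
    # rather than forming the inverse explicitly
    quad = diff @ jnp.linalg.solve(sigma, diff)
    # Normalizing constant 2*pi*sqrt(|Sigma|) for the 2-dimensional case
    norm = 2 * jnp.pi * jnp.sqrt(jnp.linalg.det(sigma))
    return jnp.exp(-quad / 2) / norm


# Illustrative parameters: zero mean, unit variances, covariance 0.8
mu = jnp.array([0.0, 0.0])
sigma = jnp.array([[1.0, 0.8], [0.8, 1.0]])
x = jnp.array([0.5, -0.3])

manual = bivariate_normal_pdf(x, mu, sigma)
reference = multivariate_normal.pdf(x, mu, sigma)
print(manual, reference)  # the two values should agree up to float32 rounding
```

At $x = \mu$ with $\Sigma = I$, the exponent vanishes and the density reduces to $1/(2\pi) \approx 0.159$, which is a quick way to check the normalizing constant.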

# jax.random.multivariate_normal

`jax.random.multivariate_normal(key, mean, cov, shape=None, dtype=<class 'numpy.float64'>, method='cholesky')`

Parameters

- `key` (`Union[Any, PRNGKeyArray]`): a PRNG key used as the random key.
- `mean` (`Any`): a mean vector of shape `(..., n)`.
- `cov` (`Any`): a positive definite covariance matrix of shape `(..., n, n)`. The batch shape `...` must be broadcast-compatible with that of `mean`.
- `shape` (`Optional[Sequence[int]]`): optional, a tuple of nonnegative integers specifying the result batch shape; that is, the prefix of the result shape excluding the last axis. Must be broadcast-compatible with `mean.shape[:-1]` and `cov.shape[:-2]`. The default (`None`) produces a result batch shape by broadcasting together the batch shapes of `mean` and `cov`.
- `dtype` (`Any`): optional, a float dtype for the returned values (default `float64` if `jax_enable_x64` is true, otherwise `float32`).
- `method` (`str`): optional, a method to compute the factor of `cov`. Must be one of `'svd'`, `'eigh'`, or `'cholesky'`. Default `'cholesky'`.

Return type

`ndarray`

Returns

A random array with the specified dtype and shape given by `shape + mean.shape[-1:]` if `shape` is not `None`, or else `broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:]`.
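To make the `shape` semantics concrete, here is a small sketch that draws samples and checks their empirical moments. The seed, the covariance value 0.8 and the sample count of 10000 are illustrative assumptions.

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)

# Illustrative parameters: zero mean, unit variances, covariance 0.8
mean = jnp.array([0.0, 0.0])
cov = jnp.array([[1.0, 0.8], [0.8, 1.0]])

# shape=(10000,) is the batch shape, so the result is shape + mean.shape[-1:] = (10000, 2)
samples = random.multivariate_normal(key, mean, cov, shape=(10000,), method='cholesky')
print(samples.shape)  # (10000, 2)

# With enough samples, the empirical moments should be close to mean and cov
print(samples.mean(axis=0))
print(jnp.cov(samples, rowvar=False))

# With shape=None the batch shape comes from broadcasting mean.shape[:-1] and
# cov.shape[:-2], which are both () here, so a single 2-vector is returned
single = random.multivariate_normal(key, mean, cov)
print(single.shape)  # (2,)
```

The sample means and covariances will not match `mean` and `cov` exactly; the deviation shrinks roughly like $1/\sqrt{N}$ in the number of samples.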

```python
# Importing the necessary modules
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
from jax import random
from jax.scipy.stats import multivariate_normal

# 'seaborn-dark' was renamed 'seaborn-v0_8-dark' in Matplotlib 3.6
style = 'seaborn-v0_8-dark' if 'seaborn-v0_8-dark' in plt.style.available else 'seaborn-dark'
plt.style.use(style)
plt.rcParams['figure.figsize'] = 14, 6
fig = plt.figure()

# Initializing the random seed and the JAX PRNG key
random_seed = 1000
key = random.PRNGKey(random_seed)

# List containing the covariance values between x1 and x2
cov_val = [-0.8, 0, 0.8]

# Setting the mean of the distribution to be at (0, 0)
mean = np.array([0.0, 0.0])

# Storing density function values and samples for
# further analysis
pdf_list = []
samples_list = []

# Iterating over different covariance values
for idx, val in enumerate(cov_val):

    # Initializing the covariance matrix (unit variances)
    cov = np.array([[1.0, val], [val, 1.0]])

    # Drawing samples with jax.random.multivariate_normal,
    # using a fresh subkey for every covariance value
    key, subkey = random.split(key)
    samples = random.multivariate_normal(subkey, mean=mean, cov=cov, shape=(1000,), method='cholesky')
    samples_list.append(samples)

    # Generating a meshgrid consistent with the 3-sigma boundary
    # (sigma is the standard deviation, i.e. the square root of the variance)
    mean_1, mean_2 = mean[0], mean[1]
    sigma_1, sigma_2 = np.sqrt(cov[0, 0]), np.sqrt(cov[1, 1])

    x = np.linspace(mean_1 - 3*sigma_1, mean_1 + 3*sigma_1, num=100)
    y = np.linspace(mean_2 - 3*sigma_2, mean_2 + 3*sigma_2, num=100)
    X, Y = np.meshgrid(x, y)

    # Evaluating the density function at every point of the
    # meshgrid in one vectorized call
    points = jnp.stack([X, Y], axis=-1)  # shape (100, 100, 2)
    pdf = np.asarray(multivariate_normal.pdf(points, mean, cov))

    # Plotting the density function values in subplot idx + 1 of a 1x3 grid
    ax = fig.add_subplot(1, 3, idx + 1, projection='3d')
    ax.plot_surface(X, Y, pdf, cmap='viridis')
    ax.set_xlabel("x1")
    ax.set_ylabel("x2")
    ax.set_title(f'Covariance between x1 and x2 = {val}')
    pdf_list.append(pdf)
    ax.zaxis.set_ticks([])

plt.tight_layout()
plt.show()
```
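The figure above shows three static snapshots. One way to turn them into an animation is a minimal sketch with Matplotlib's `FuncAnimation`, which redraws the density surface while the covariance sweeps from -0.8 to 0.8. The helper name `bivariate_pdf_grid`, the 33 frames, the 100 ms frame interval, the fixed z-limit and the zero-mean, unit-variance setup are all illustrative assumptions rather than part of the original notebook.

```python
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from jax.scipy.stats import multivariate_normal


def bivariate_pdf_grid(rho, num=60):
    """Hypothetical helper: density on a 3-sigma grid for zero mean, unit variances, covariance rho."""
    mean = jnp.array([0.0, 0.0])
    cov = jnp.array([[1.0, rho], [rho, 1.0]])
    x = np.linspace(-3, 3, num)
    X, Y = np.meshgrid(x, x)
    points = jnp.stack([X, Y], axis=-1)  # shape (num, num, 2)
    pdf = np.asarray(multivariate_normal.pdf(points, mean, cov))
    return X, Y, pdf


# Covariance values swept by the animation
rhos = np.linspace(-0.8, 0.8, 33)

fig = plt.figure(figsize=(7, 6))
ax = fig.add_subplot(projection='3d')


def update(frame):
    rho = rhos[frame]
    X, Y, pdf = bivariate_pdf_grid(rho)
    ax.clear()  # redraw the surface from scratch for each frame
    ax.plot_surface(X, Y, pdf, cmap='viridis')
    ax.set_zlim(0, 0.3)  # fixed z-range (peak is about 0.265 at rho = +/-0.8)
    ax.set_xlabel("x1")
    ax.set_ylabel("x2")
    ax.set_title(f'Covariance between x1 and x2 = {rho:.2f}')
    return []


anim = FuncAnimation(fig, update, frames=len(rhos), interval=100)

if __name__ == "__main__":
    plt.show()
```

To write the animation to a file instead of showing it, `anim.save("bivariate.gif", writer="pillow")` should work when Pillow is installed. Keeping the z-limit fixed makes the frames comparable: as the covariance approaches ±0.8 the peak rises to $1/(2\pi\sqrt{1-0.8^2}) \approx 0.265$ while the bell stretches along a diagonal.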



# Conclusion
We explored how the covariance between x1 and x2 shapes the bivariate Gaussian density by plotting it for several covariance values with JAX and Matplotlib. The reader is encouraged to play around with the code snippets, for example by changing the mean or the covariance values, to gain a more profound intuition about this magical distribution!

# References
- Multivariate normal distribution (Wikipedia): https://en.wikipedia.org/wiki/Multivariate_normal_distribution

- JAX documentation for `jax.random.multivariate_normal`: https://jax.readthedocs.io/en/latest/_autosummary/jax.random.multivariate_normal.html#jax.random.multivariate_normal