
***Sampling Method To Draw Samples From a Multivariate Normal Distribution (MVN) In JAX***

***MULTIVARIATE NORMAL DISTRIBUTION***

In probability theory and statistics, the multivariate normal distribution (also called the multivariate Gaussian distribution or joint normal distribution) generalizes the one-dimensional (univariate) normal distribution to higher dimensions. One definition: a random vector is k-variate normally distributed if every linear combination of its k components has a univariate normal distribution. Its importance derives mainly from the multivariate central limit theorem. The multivariate normal distribution is often used to describe, at least approximately, any set of (possibly) correlated real-valued random variables, each of which clusters around a mean value.
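For reference, the density and the identity that the sampling steps rely on are standard results and can be written as follows. Here d plays the role of k, m is the mean vector and K the covariance matrix, matching the notebook's variable names. For a positive definite K, the density is

$$
p(\mathbf{x}) = \frac{1}{(2\pi)^{d/2}\,\det(K)^{1/2}} \exp\!\left(-\tfrac{1}{2}(\mathbf{x}-\mathbf{m})^\top K^{-1}(\mathbf{x}-\mathbf{m})\right).
$$

If K factors as $K = L L^\top$ and $\mathbf{u} \sim \mathcal{N}(\mathbf{0}, I)$, then

$$
\mathbf{x} = \mathbf{m} + L\mathbf{u}, \qquad \mathbb{E}[\mathbf{x}] = \mathbf{m}, \qquad \operatorname{Cov}(\mathbf{x}) = L\,\mathbb{E}[\mathbf{u}\mathbf{u}^\top]\,L^\top = L L^\top = K.
$$

Because $\mathbf{x}$ is an affine function of the Gaussian vector $\mathbf{u}$, it is itself Gaussian, so $\mathbf{x} \sim \mathcal{N}(\mathbf{m}, K)$.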




In [None]:
import jax.numpy as jnp
import jax.random as random
key = random.PRNGKey(23)

In [None]:
# Setting parameters
# Dimension of the random vector.
d = 10
# Mean vector m = (1, 2, ..., 10), stored as floats.
m = jnp.arange(1, d + 1, dtype=jnp.float32)
# Covariance matrix: the d x d identity (independent, unit-variance coordinates).
K_0 = jnp.eye(d)

K_0, m.reshape(d, 1)
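Not every square matrix can serve as a covariance matrix: it must be symmetric and positive semi-definite. A minimal sketch of a sanity check, assuming a hypothetical helper `is_valid_covariance` (not part of JAX) that tests symmetry and the smallest eigenvalue with a small tolerance:

```python
import jax.numpy as jnp

def is_valid_covariance(K, tol=1e-6):
    """Hypothetical helper: True if K is symmetric and positive semi-definite."""
    symmetric = jnp.allclose(K, K.T, atol=tol)
    # eigvalsh assumes a symmetric matrix and returns eigenvalues in ascending order.
    smallest = jnp.linalg.eigvalsh(K)[0]
    return bool(symmetric) and bool(smallest >= -tol)

d = 10
K_0 = jnp.eye(d)                            # the identity covariance used above
bad = jnp.array([[1.0, 2.0], [2.0, 1.0]])   # eigenvalues -1 and 3: not a covariance

print(is_valid_covariance(K_0))  # True
print(is_valid_covariance(bad))  # False
```

The identity passes trivially (all eigenvalues are 1); the second matrix is symmetric but has a negative eigenvalue, so no random vector can have it as its covariance.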

**Sampling Process**

**Step 1: Compute the Cholesky Decomposition**

In [None]:
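# Define epsilon (a small "jitter" term).
epsilon = 0.0001

# Add a small perturbation to the diagonal so K is numerically positive definite.
K = K_0 + epsilon * jnp.identity(d)

# Cholesky decomposition: K = L L^T with L lower-triangular.
L = jnp.linalg.cholesky(K)
L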
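The jitter matters when the covariance is singular or nearly so, because the Cholesky factorization requires a strictly positive definite matrix. A small sketch with an illustrative rank-deficient 2 x 2 matrix (chosen for this example, not taken from the notebook): without jitter, JAX typically signals the failed factorization by returning NaNs rather than raising an error, while adding `epsilon * I` raises every eigenvalue by `epsilon` and the factorization succeeds.

```python
import jax.numpy as jnp

# Rank-deficient covariance: both coordinates are the same variable (eigenvalues 0 and 2).
K_singular = jnp.array([[1.0, 1.0],
                        [1.0, 1.0]])
epsilon = 1e-4

# Without jitter the factorization may fail; inspect the result for NaNs.
L_raw = jnp.linalg.cholesky(K_singular)
print("NaNs without jitter:", bool(jnp.isnan(L_raw).any()))

# With jitter the eigenvalues become epsilon and 2 + epsilon, so K is positive definite.
K_jittered = K_singular + epsilon * jnp.eye(2)
L_jit = jnp.linalg.cholesky(K_jittered)
print(L_jit)
```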

In [None]:
# Verify the desired property: L L^T should reproduce K.
L @ L.T
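Eyeballing a 10 x 10 matrix is error-prone, so the same check can be made programmatic. A sketch that rebuilds the notebook's `K` and `L` and tests three properties with `jnp.allclose`: `L L^T` reproduces the jittered `K`, it differs from the original `K_0` only by about `epsilon`, and `L` is lower-triangular.

```python
import jax.numpy as jnp

# Rebuild the notebook's setup so this cell runs on its own.
d = 10
epsilon = 0.0001
K_0 = jnp.eye(d)
K = K_0 + epsilon * jnp.eye(d)
L = jnp.linalg.cholesky(K)

reconstructed = L @ L.T
# 1. L L^T matches the jittered K up to floating-point error.
print(jnp.allclose(reconstructed, K, atol=1e-6))
# 2. The deviation from the original K_0 is on the order of epsilon.
print(jnp.max(jnp.abs(reconstructed - K_0)))
# 3. L is lower-triangular: entries above the diagonal are zero.
print(jnp.allclose(L, jnp.tril(L)))
```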

**Step 2: Generate Independent Samples u ∼ N(0, I)**

In [None]:
n = 10000
# Draw a d x n array of independent standard normal values: each column is one u ~ N(0, I).
u = random.normal(key, shape=(d, n))
u
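Step 2 needs standard normal draws specifically. Since Cov(m + Lu) = L Cov(u) L^T, using `random.uniform(minval=-3, maxval=3)` instead would give entries with mean 0 but variance (3 - (-3))² / 12 = 3, so x would have covariance 3K rather than K and would not be Gaussian. A quick moment check, assuming n = 10000 draws per coordinate are enough to see the difference:

```python
import jax.numpy as jnp
import jax.random as random

key = random.PRNGKey(23)
d, n = 10, 10000

# Standard normal draws: each entry ~ N(0, 1).
u_normal = random.normal(key, shape=(d, n))
# Uniform draws on [-3, 3]: mean 0, variance (b - a)^2 / 12 = 36 / 12 = 3.
u_uniform = random.uniform(key, shape=(d, n), minval=-3, maxval=3)

# Row-wise sample moments (each row is one coordinate of u).
print(u_normal.mean(axis=1), u_normal.var(axis=1))   # means near 0, variances near 1
print(u_uniform.var(axis=1))                         # variances near 3
```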

**Step 3: Compute x = m + Lu**

In [None]:
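# L @ u has shape (d, n); transposing gives (n, d), and m broadcasts across rows,
# so each row of x is one sample from N(m, K).
x = m + (L @ u).T
x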
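The three steps can be collected into one function and checked against a target with correlations, which the identity covariance cannot exercise. A minimal sketch, assuming a hypothetical helper `sample_mvn` (not a JAX API) and reusing the illustrative 2-D mean and covariance from the `random.multivariate_normal` cell so the two approaches can be compared:

```python
import jax.numpy as jnp
import jax.random as random

def sample_mvn(key, m, K, n, epsilon=1e-4):
    """Hypothetical helper: draw n samples from N(m, K) with the Cholesky method.

    Returns an array of shape (n, d), one sample per row.
    """
    d = m.shape[0]
    L = jnp.linalg.cholesky(K + epsilon * jnp.eye(d))  # Step 1: K + eps*I = L L^T
    u = random.normal(key, shape=(d, n))                # Step 2: columns u ~ N(0, I)
    return m + (L @ u).T                                # Step 3: x = m + L u

# Illustrative correlated 2-D target.
m = jnp.array([3.0, -1.0])
K = jnp.array([[1.2, 0.4],
               [0.4, 1.0]])

x = sample_mvn(random.PRNGKey(0), m, K, n=10000)
print(x.mean(axis=0))            # should be close to m
print(jnp.cov(x, rowvar=False))  # should be close to K (plus epsilon on the diagonal)
```

The empirical mean and covariance only approach `m` and `K` as `n` grows; with 10000 samples they typically agree to about two decimal places.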

In [None]:
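# Using JAX's built-in sampler (shown with a separate 2-D example)
key = random.PRNGKey(67)
cov = jnp.array([[1.2, 0.4], [0.4, 1.0]])
mean = jnp.array([3.0, -1.0])
# shape=(10000,) returns an array of shape (10000, 2); .T gives (2, 10000), one row per coordinate.
x1 = random.multivariate_normal(key, mean, cov, (10000,)).T
x1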
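To confirm the built-in sampler hits its target, the sample moments of `x1` can be compared with `mean` and `cov`. A short sketch; note that after the transpose each row of `x1` is a coordinate, so `jnp.cov`'s default `rowvar=True` applies directly:

```python
import jax.numpy as jnp
import jax.random as random

key = random.PRNGKey(67)
mean = jnp.array([3.0, -1.0])
cov = jnp.array([[1.2, 0.4],
                 [0.4, 1.0]])

x1 = random.multivariate_normal(key, mean, cov, (10000,)).T

print(x1.shape)         # (2, 10000)
print(x1.mean(axis=1))  # close to mean
print(jnp.cov(x1))      # rows are variables; close to cov
```

By default `jax.random.multivariate_normal` factorizes `cov` with a Cholesky decomposition, the same idea as the manual steps; in recent JAX versions its `method` argument also accepts `'eigh'` and `'svd'`, which can be more forgiving for covariances that are only positive semi-definite.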