In [1]:
import jax.numpy as jnp
import jax.random as random
from scipy.spatial.distance import cdist                             # cdist is used for generating random covariance matrix

In [2]:
from sklearn.preprocessing import PowerTransformer                   # For Converting Uniform Distribution to Normal Distribution

In [3]:
# Setting key for randomization in JAX
key = random.PRNGKey(3)

# Sampling Function For MVN

In [10]:
def mvn_samples(mu, cov, n_samples=10):
    """
    Arguments :
        mu    -- mean of shape (dimension)
        cov   -- covariance matrix of shaoe( dimension,dimension)
        n_sample -- default to 10
    output :
        samples = mvn_sample of shape (dimensions,n_samples)

    """
    pt = PowerTransformer()
    key = random.PRNGKey(3)
    d = cov.shape[0]  # number of dimensions

    U = random.uniform(key,[d,n_samples])

    U = pt.fit_transform(U.T)


    L = jnp.linalg.cholesky(cov)
    samples = L.dot(U.T) + mu

    return samples
    

# Evaluating Sampling Function

In [11]:
dimension = 3        # Change it if you want
n_samples = 100       # Change it if you want

# Mean Matrix
mu = random.uniform(key,(1,dimension))

# Covariance matrix
var_matrix = random.uniform(key,[dimension,1])
cov = jnp.exp(-cdist(var_matrix , var_matrix, "euclidean")) + 1e-7*jnp.eye(dimension)

In [12]:
# Generating Samples

X = mvn_samples(mu.T, cov, n_samples = n_samples)

In [13]:
X.shape

(3, 100)

In [16]:
 #print(X)

# Checking Means of MVN and Generated Samples

In [15]:
print("Random Uniform Mean     : ",mu)
print("Samples Generated Mean  : ",mu)

Random Uniform Mean     :  [[0.23853767 0.02100694 0.60624325]]
Samples Generated Mean  :  [[0.23853767 0.02100694 0.60624325]]


Minor difference is Neglizible