In [2]:
import jax
import matplotlib.pyplot as plt
import jax.numpy as jnp
key = jax.random.PRNGKey(0)
jnp.set_printoptions(formatter={'float': lambda x: "{0:0.2f}".format(x)})




Box Muller Transoform is used to transform Uniform Random Variable to Univariate Gaussian Random Variable.

We first start with two random samples of equal length, u_1 and u_2, drawn from the uniform distribution U(0,1). Then, we generate from them two normally-distributed random variables z_1 and z_2. Their values are:
\begin{equation*}
Z_1 = \sqrt{-2ln(u_1)}cos(2πu_2)
\end{equation*}
\begin{equation*}
Z_2 = \sqrt{-2ln(u_1)}sin(2πu_2)
\end{equation*}
source: [Uniform to Normal distribution](https://www.baeldung.com/cs/uniform-to-normal-distribution)

Samples can be drawn form Normal distribution using Central limit theorm too.

In [3]:
def box_muller(shape):
    global key,subkey
    u1, u2 = 0.0, 0.0
    epsilon = 0.1**5
    u1 = jax.random.uniform(key = key,shape = shape,minval=0.0001)
    key, subkey = jax.random.split(key)
    u2 = jax.random.uniform(key = key,shape = shape)
    key, subkey = jax.random.split(key)
    n1 = jnp.sqrt(-2 * jnp.log(u1)) * jnp.cos(2 * jnp.pi * u2)
    n2 = jnp.sqrt(-2 * jnp.log(u1)) * jnp.sin(2 * jnp.pi * u2)
    return n1

Gaussian_random_variable_sample = box_muller(shape=(5000,))

I have used cholesky decomposition to convert Univariate Gaussian to Multivariate guassian.

Let the given mean is **m** and covariance matrix is Σ.
Let the cholesky decomposition of Σ yield \$L\$ such that \$Σ = LL^T\$.
Let \$u\$ have Normal distribution. 
ie \$ u ∼ N(0,1) \$

then if we compute 
\begin{equation*}
x = m + Lu
\end{equation*}

Then \$x\$ follows multivariate normal distribution with mean \$m\$ and covariance \$Σ\$.
ie. 

\begin{equation*}
x∼ N(m,Σ)
\end{equation*}
source: https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Drawing_values_from_the_distribution

In [4]:
def multivariate_guassian(mean,cov, size ):
  d = mean.shape[0]
  mean = mean.reshape(d,1)
  epsilon = 0.0001
  cov = cov
  L = jnp.linalg.cholesky(cov)
  shape = (size*d,)
  u = box_muller(shape =(size*d,))
  u = u.reshape(d,size)
  X = jnp.dot(L,u)
  X = X+mean
  return X.T

dimension = 10 # change the dimension here.
mean = jax.random.uniform(key,shape=(dimension,),maxval=30)
rand = jax.random.uniform(key+11,shape= (dimension,dimension),minval=1,maxval=10)
cov = jnp.dot(rand,rand.T)


X = multivariate_guassian(mean,cov,(5000))

In [5]:
sample_mean = jnp.mean(X,axis=0)
print(f"given mean={mean}\nsample mean = {sample_mean}")

sample_cov = jnp.cov(X.T)
print(f"given cov = \n{cov}\nsample cov=\n{sample_cov}")

given mean=[25.48 4.00 24.89 7.23 16.90 7.12 5.95 25.94 6.67 13.66]
sample mean = [25.61 4.16 24.96 7.34 17.06 7.14 6.03 26.03 6.79 13.71]
given cov = 
[[364.74 240.74 344.65 277.72 291.31 266.32 433.64 329.20 292.56 324.72]
 [240.74 306.49 332.71 272.43 247.39 192.59 346.58 340.15 229.14 251.10]
 [344.65 332.71 430.01 339.18 311.42 262.40 470.22 425.77 299.51 357.02]
 [277.72 272.43 339.18 360.56 256.41 268.85 399.48 345.23 286.78 300.44]
 [291.31 247.39 311.42 256.41 292.65 213.05 370.94 308.74 288.45 293.70]
 [266.32 192.59 262.40 268.85 213.05 283.99 354.83 285.54 261.14 254.69]
 [433.64 346.58 470.22 399.48 370.94 354.83 606.74 485.66 403.06 458.58]
 [329.20 340.15 425.77 345.23 308.74 285.54 485.66 502.69 317.94 386.12]
 [292.56 229.14 299.51 286.78 288.45 261.14 403.06 317.94 346.56 284.97]
 [324.72 251.10 357.02 300.44 293.70 254.69 458.58 386.12 284.97 418.48]]
sample cov=
[[360.87 240.82 342.15 273.46 286.73 264.75 432.69 329.54 287.62 323.93]
 [240.82 314.50 336.64 274.24 25

In order to check whether our samples have the same distribution as we wanted, we will check the L2 norm of the difference between the given mean and sample mean.


In [6]:
error = jnp.linalg.norm(sample_mean-mean)
print("error in mean is")
print(error/dimension)

error in mean is
0.034104403


In [7]:
errorcov = jnp.linalg.norm(sample_cov-cov)
print("error in cov is")
print(errorcov/(dimension**2))

error in cov is
0.26023322


Since the values are very low, we can say that the samples have the same distribution as we wanted.
