# QN1
*The unmodified code takes about 1 minute to run on the free tier of Google Colab.*

The 'master_key' variable seeds all of the randomness in the program. Change 'master_key' to change the output.

In [1]:
# Import JAX and the submodules used below
import jax
import jax.numpy
import jax.scipy.special

# Seed for all randomness in the program
master_key = 20
key = jax.random.PRNGKey(master_key)
subkey, key = jax.random.split(key)
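Because every random draw descends from `PRNGKey(master_key)`, rerunning with the same seed reproduces the output exactly, and changing the seed changes it. A minimal sketch of that behaviour, assuming the same split pattern as the cell above (the helper `first_draw` is hypothetical and not part of the notebook):

```python
import jax

def first_draw(master_key, dimension=10):
    """Hypothetical helper: reproduce the notebook's first draw for a given seed."""
    key = jax.random.PRNGKey(master_key)
    subkey, key = jax.random.split(key)  # same split pattern as the notebook
    return jax.random.uniform(subkey, [dimension, dimension])

same_a = first_draw(20)
same_b = first_draw(20)
other = first_draw(21)

# Same master_key -> identical draws; a different master_key -> different draws.
print(bool((same_a == same_b).all()))  # True
print(bool((same_a == other).all()))   # False
```

JAX has no hidden global random state, so reproducibility depends entirely on how keys are created and split.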



The variable 'dimension' sets the dimension of the multivariate Gaussian.

The variable 'instances' sets the number of samples drawn from that multivariate Gaussian distribution.

The covariance and mean of the multivariate Gaussian are generated randomly.


In [2]:
dimension = 10  # dimension of the multivariate Gaussian
instances = 1000  # number of samples to draw

# Create a random covariance matrix as AA' (this guarantees it is symmetric positive semidefinite)
A = jax.random.uniform(subkey, [dimension, dimension])
subkey, key = jax.random.split(key)
True_cov = jax.numpy.matmul(A, jax.numpy.transpose(A)) # Cov=A*A'

# Creating random mean
True_mean = jax.random.uniform(subkey, [dimension,1])
subkey, key = jax.random.split(key)
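The AA' construction works because AA' is symmetric and, for any vector x, x'AA'x = ||A'x||² ≥ 0, so it is positive semidefinite (positive definite when A has full rank). A small sketch that checks both properties numerically; it rebuilds a uniform A with its own key (an assumption, so the snippet runs on its own):

```python
import jax
import jax.numpy as jnp

# Same construction as the notebook, re-created with an independent key.
A = jax.random.uniform(jax.random.PRNGKey(0), [10, 10])
cov = A @ A.T  # Cov = A * A'

# Symmetry: cov equals its transpose up to float32 rounding.
is_symmetric = bool(jnp.allclose(cov, cov.T))

# Positive semidefinite: all eigenvalues are >= 0 (eigvalsh assumes symmetric input).
eigenvalues = jnp.linalg.eigvalsh(cov)

# The quadratic form equals a squared norm, which can never be negative.
x = jax.random.normal(jax.random.PRNGKey(1), (10,))
quad = x @ cov @ x
norm_sq = jnp.sum((A.T @ x) ** 2)

print(is_symmetric, eigenvalues.min(), quad, norm_sq)
```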

Initialize variables to store the sampled data and the mean and covariance estimates computed from it.

In [3]:
Est_mean = jax.numpy.zeros([dimension, 1]) # to store mean estimate
Est_cov = jax.numpy.zeros([dimension, dimension]) # to store covariance estimate
Gaussian_samples_save = jax.numpy.zeros([dimension, instances]) # to store all samples, one column per instance

Next, we sample the specified multivariate Gaussian. The approach is as follows:

*   Draw a vector of 10 independent samples from the uniform distribution on [0, 1).
*   Map each uniform sample through the inverse CDF of the one-dimensional standard Gaussian. This is inverse transform sampling: the CDF-to-CDF mapping turns uniform samples into independent standard Gaussian samples. Instead of a dedicated Gaussian inverse-CDF (inverse Q-function) routine, the inverse error function is used, via $\Phi^{-1}(u)=\sqrt{2}\,\mathrm{erfinv}(2u-1)$. (JAX does also provide the Gaussian inverse CDF directly as `jax.scipy.special.ndtri`.)
*   Multiply the 10-dimensional standard Gaussian vector $Z$ by the matrix 'A'. Since $\mathrm{Cov}(AZ) = A\,I\,A' = AA'$, the result has the specified covariance (recall that the true covariance was constructed as AA').
*   Add the mean to the obtained samples.
*   Repeat this process for all instances.
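The inverse-CDF mapping in the second step can be checked directly. This sketch compares $\sqrt{2}\,\mathrm{erfinv}(2u-1)$ against JAX's built-in Gaussian inverse CDF (`ndtri`, and `norm.ppf`), then confirms that mapping many uniform draws gives samples with mean near 0 and standard deviation near 1. The `minval` guard is an assumption added here: a uniform draw of exactly 0 would map to `-inf`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import erfinv, ndtri
from jax.scipy.stats import norm

# Inverse standard Gaussian CDF written through the inverse error function.
u = jnp.array([0.025, 0.16, 0.5, 0.84, 0.975])
via_erfinv = jnp.sqrt(2.0) * erfinv(2.0 * u - 1.0)
via_ndtri = ndtri(u)   # built-in inverse of the standard normal CDF
via_ppf = norm.ppf(u)  # same function exposed through jax.scipy.stats
print(via_erfinv)      # roughly [-1.96, -0.99, 0.0, 0.99, 1.96]

# Inverse transform sampling: Phi^{-1}(U) ~ N(0, 1) when U ~ Uniform(0, 1).
# minval > 0 keeps erfinv(2u - 1) finite.
uniform = jax.random.uniform(jax.random.PRNGKey(0), (100_000,), minval=1e-7, maxval=1.0)
z = jnp.sqrt(2.0) * erfinv(2.0 * uniform - 1.0)
print(z.mean(), z.std())  # close to 0 and 1
```

A single formula covers both halves of (0, 1): for $u < 0.5$ the argument $2u-1$ is negative, so the result is negative, and no branching on $u$ is needed.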

In [4]:
for ins in range(instances):
  Uniform_samples = jax.random.uniform(subkey, [dimension,1]) # Sampling the uniform distribution
  subkey, key = jax.random.split(key)

  Gaussian_samples = jax.numpy.zeros([dimension,1]) # Initializing the gaussian samples

  # This for loop converts the uniform samples to standard Gaussian samples of the given dimension.
  # Each uniform sample u is mapped through the inverse standard Gaussian CDF:
  # Phi^{-1}(u) = sqrt(2) * erfinv(2u - 1), negative for u < 0.5 and positive for u > 0.5.
  for i in range(dimension):
    Gaussian_samples = Gaussian_samples.at[i,0].set(jax.numpy.sqrt(2) * jax.scipy.special.erfinv(2*Uniform_samples[i,0] - 1))


  # Convert the standard gaussian samples to samples of given mean and covariance
  Gaussian_samples = True_mean + jax.numpy.matmul(A, Gaussian_samples)

  # Save the samples for later to calculate the sample mean and covariance
  Gaussian_samples_save = Gaussian_samples_save.at[:, ins:ins+1].set(Gaussian_samples)
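The per-instance loop above can also be written without Python loops by drawing all uniform samples at once. This is a sketch under assumptions: the function name `sample_gaussian`, the independent key split, and the `minval` guard are not from the original notebook, so its samples will not match the loop's output exactly.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import erfinv

def sample_gaussian(key, mean, A, instances):
    """Hypothetical helper: draw `instances` samples of N(mean, A A'), one per column."""
    dimension = A.shape[0]
    # minval > 0 keeps erfinv(2u - 1) finite (u = 0 would map to -inf).
    u = jax.random.uniform(key, (dimension, instances), minval=1e-7, maxval=1.0)
    z = jnp.sqrt(2.0) * erfinv(2.0 * u - 1.0)  # standard Gaussian samples
    return mean + A @ z                         # mean [d, 1] broadcasts over the columns

key_a, key_m, key_s = jax.random.split(jax.random.PRNGKey(20), 3)
dimension, instances = 10, 1000
A = jax.random.uniform(key_a, (dimension, dimension))
true_mean = jax.random.uniform(key_m, (dimension, 1))
samples = sample_gaussian(key_s, true_mean, A, instances)
print(samples.shape)  # (10, 1000)
```

Vectorizing replaces thousands of small dispatched operations with a few large ones, which is usually much faster in JAX. For comparison, JAX also offers `jax.random.multivariate_normal(key, mean, cov)`, which factorizes the covariance itself instead of taking `A`.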

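After sampling, the parameters are estimated from the $N$ samples $S_1, \dots, S_N$ (here $N$ is 'instances') as follows:

Sample mean: $E(S)=\frac{1}{N}\sum_{i=1}^{N} S_i$

Sample covariance: $\frac{1}{N-1}\sum_{i=1}^{N} \big(S_i-E(S)\big)\big(S_i-E(S)\big)^T$

Dividing by $N-1$ instead of $N$ (Bessel's correction) makes the covariance estimate unbiased when the mean is itself estimated from the same samples.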

In [5]:
# Sample mean is calculated here
Est_mean = jax.numpy.zeros([dimension,1])
Est_mean = Est_mean.at[:,0].set(jax.numpy.sum(Gaussian_samples_save, 1) / instances)

# Sample covariance is calculated here
for ins in range(instances):
  Est_cov = Est_cov + jax.numpy.matmul( Gaussian_samples_save[:,ins:ins+1]-Est_mean, jax.numpy.transpose(Gaussian_samples_save[:,ins:ins+1]-Est_mean) )

Est_cov = Est_cov / (instances-1)
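The sum of outer products in the loop above equals a single matrix product of the centered data with its transpose. A minimal sketch of both estimators in vectorized form, using stand-in data laid out like 'Gaussian_samples_save' (one column per sample, an assumption for self-containment), checked against `jnp.cov`:

```python
import jax
import jax.numpy as jnp

# Stand-in data with the same layout as Gaussian_samples_save: [dimension, instances].
X = jax.random.normal(jax.random.PRNGKey(0), (10, 1000))
N = X.shape[1]

est_mean = X.mean(axis=1, keepdims=True)   # (1/N) * sum_i S_i, shape [10, 1]
centered = X - est_mean                    # S_i - E(S), one column per sample
est_cov = centered @ centered.T / (N - 1)  # sum of outer products / (N - 1)

# jnp.cov treats rows as variables by default and also divides by N - 1.
print(bool(jnp.allclose(est_cov, jnp.cov(X), atol=1e-4)))  # True
```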

Finally, the true parameter values, the estimated parameter values, and their differences are printed, in that order.

In [6]:
# Printing true values
print("True mean")
print(True_mean)
print("\nTrue Covariance")
print(True_cov)

# Printing estimated values
print("\n\nEst mean")
print(Est_mean)
print("\nEst Covariance")
print(Est_cov)

# Printing differences
print("\n\nTrue mean - Est mean")
print(True_mean - Est_mean)
print("\nTrue Covariance - Est Covariance")
print(True_cov - Est_cov)

True mean
[[0.05800748]
 [0.26793325]
 [0.998407  ]
 [0.69090223]
 [0.59423995]
 [0.3626442 ]
 [0.36861145]
 [0.99709284]
 [0.6963452 ]
 [0.44293797]]

True Covariance
[[1.8399405 1.3392221 1.564645  2.4566624 1.249694  2.107464  1.7751497
  1.8875052 1.6365731 2.4960427]
 [1.3392221 2.50501   1.5026025 2.3401546 2.2614415 2.1456811 1.5607216
  2.3542547 1.5120171 2.140727 ]
 [1.564645  1.5026025 2.7307885 2.4518714 2.0720088 3.019265  2.3068993
  2.5066159 2.4883673 2.7407084]
 [2.4566624 2.3401546 2.4518714 4.894763  2.6920557 4.011526  3.1403604
  3.2589016 2.9205432 3.9793098]
 [1.249694  2.2614415 2.0720088 2.6920557 3.5932405 3.127026  1.6811109
  2.6638625 2.4867241 2.620227 ]
 [2.107464  2.1456811 3.019265  4.011526  3.127026  4.9752345 2.6989992
  2.9629798 3.2046342 4.010004 ]
 [1.7751497 1.5607216 2.3068993 3.1403604 1.6811109 2.6989992 3.118404
  2.7301116 2.7383723 2.6447108]
 [1.8875052 2.3542547 2.5066159 3.2589016 2.6638625 2.9629798 2.7301116
  3.7282097 2.6469927 2.75