# **TASK 2**
Implement from scratch a sampling method to draw samples from a multivariate Normal (MVN) distribution in JAX
- set the number of
dimensions (random variables of MVN) to 10 for this task
- You are only allowed to use `jax.random.uniform`. You are especially not allowed to use
`jax.random.normal`
- You should randomly create the mean and covariance matrix to fully specify an MVN
distribution

### **Attempting to create random means and covariance**

In [1]:
import jax.numpy as np
import jax
import random
from random import uniform
import math
from sklearn.datasets import make_spd_matrix

k = 10   # how many classes
N = 20   # how many samples
x_min = -20
x_max = 20
y_min = -20
y_max = 20



# w = np.random.random(k)    # generate k random numbers


from jax import random
key = random.PRNGKey(42)
w = random.uniform(key, shape=(10,))
w /= w.sum() # divide by their sum
# now these sum to 1
#print(w)  # actual weight vector

number_of_divisions_x = number_of_divisions_y = math.ceil(math.sqrt(k))
x_div = (x_max - x_min)/number_of_divisions_x
y_div = (y_max - y_min)/number_of_divisions_y

# divide a grid
# x = x_min to x_max
# y = y_min to y_max

# now form the grid and put the means

points = []  # the list of actual means

# now generate means
for i in range(number_of_divisions_x):
    for j in range(number_of_divisions_y):
        if len(points) == k:   # after generating k points, break the loop
            break
        val = x_min + i * x_div + y_min + j * y_div
        if val > 20:
          val -= 20
        else:
          val += 20
        point = val  # array (not tuple) is needed for the function
        
        points.append(point)

    if len(points) == k:  # after generating k points, break the loop
        break  

samples = []  # list containing all the samples
actual_sample = [None] * k   # how many samples in each class actually. Stored for later weight correction

covs = []    # actual covariances

for i in range(k):

    number_of_samples = math.ceil(N * w[i])  # number of samples from class i, generate at least one sample
    actual_sample[i] = number_of_samples  # for later weight correction

    if i == k-1:
        number_of_samples = N - sum(actual_sample[:k-1])   # for the last class, take all the remaining samples
        actual_sample[i] = number_of_samples

    mean = points[i]

    # generate covariance matrix

    cov = make_spd_matrix(10)   # Generate a random symmetric, positive-definite matrix, whose size is 2x2

    covs.append(cov)

In [2]:
cov.shape

(10, 10)

In [3]:
points

[-20.0, -10.0, 0.0, 10.0, -10.0, 0.0, 10.0, 20.0, 0.0, 10.0]

In [5]:
cov

array([[ 0.91403572,  0.27073634, -0.37070834,  0.20190801, -0.49408699,
         0.56851996,  0.46955073, -0.22527234,  1.05803919, -0.22501412],
       [ 0.27073634,  0.81706591, -0.2600977 ,  0.07777133, -0.25370266,
         0.41010781,  0.24963775, -0.23113938,  0.64841425, -0.17429734],
       [-0.37070834, -0.2600977 ,  1.14271222, -0.33178492,  0.66371463,
        -0.95918116, -0.76241614,  0.51013315, -1.58803968,  0.40677391],
       [ 0.20190801,  0.07777133, -0.33178492,  0.97573738, -0.51836549,
         0.54883474,  0.57466159, -0.26358343,  0.99273913, -0.30451835],
       [-0.49408699, -0.25370266,  0.66371463, -0.51836549,  1.63435328,
        -1.28627407, -0.77375221,  0.52707412, -2.06895041,  0.48501229],
       [ 0.56851996,  0.41010781, -0.95918116,  0.54883474, -1.28627407,
         2.61417108,  1.36848501, -0.85397541,  2.7808604 , -0.58402555],
       [ 0.46955073,  0.24963775, -0.76241614,  0.57466159, -0.77375221,
         1.36848501,  1.64051378, -0.64518508

In [6]:
import numpy
points = numpy.array(points)
points = points.reshape(10,1)

### **Creating a function multivariate_normal_sampler**

In [8]:
#Importing jax.numpy and jax.scipy api
import jax.numpy as np
import jax.scipy.linalg as spla

key = random.PRNGKey(42)
def multivariate_normal_sampler(mean, covariance, n_samples=1):
  L = spla.cholesky(covariance)

#cholesky decomposition can improve numerical stability and we can also use it to draw samples from a multivariate normal
 
  Z = random.uniform(key,shape=(n_samples,covariance.shape[0])) #Generating some white gausian noise
  
  return Z.dot(L) + mean 
#returns random sample that are distributed from a multivariate normal distribution with correct mean and caovariance

In [9]:
X = multivariate_normal_sampler(mean, cov, n_samples=1000)

In [10]:
X

DeviceArray([[10.381706 , 10.9372225,  9.913212 , ..., 10.503998 ,
              11.04511  ,  9.85002  ],
             [10.375615 , 10.59613  , 10.599334 , ..., 10.057558 ,
              10.056023 , 10.222786 ],
             [10.313506 , 10.192428 , 10.029142 , ..., 10.059924 ,
              10.745003 , 10.32871  ],
             ...,
             [10.261748 , 10.769014 , 10.14813  , ...,  9.683733 ,
              12.20946  , 10.226092 ],
             [10.893305 , 10.442641 , 10.173006 , ..., 10.535437 ,
              11.616586 , 10.292549 ],
             [10.9359455, 10.476151 , 10.35238  , ..., 10.549779 ,
              10.754279 , 10.522897 ]], dtype=float32)

In [11]:
X.mean(axis=0)

DeviceArray([10.490202, 10.562163, 10.193903, 10.459148, 10.242573,
             10.550854, 10.78474 , 10.261264, 11.037973, 10.229966],            dtype=float32)

In [12]:
np.cov(X.T)

DeviceArray([[ 0.07697921,  0.0200586 , -0.03137587,  0.01664387,
              -0.03977153,  0.04651411,  0.03703811, -0.01427077,
               0.08667421, -0.01580602],
             [ 0.0200586 ,  0.06650877, -0.02030928,  0.00785543,
              -0.0187023 ,  0.03336751,  0.01668165, -0.01564905,
               0.05020088, -0.0125584 ],
             [-0.03137587, -0.02030928,  0.09128047, -0.02915336,
               0.05317736, -0.07689525, -0.0600033 ,  0.03940117,
              -0.1265764 ,  0.02962657],
             [ 0.01664387,  0.00785543, -0.02915336,  0.08517961,
              -0.04976802,  0.05415992,  0.0516171 , -0.02413062,
               0.0945849 , -0.02380342],
             [-0.03977153, -0.0187023 ,  0.05317736, -0.04976802,
               0.14494088, -0.11050644, -0.06604397,  0.04438507,
              -0.17813413,  0.03754091],
             [ 0.04651411,  0.03336751, -0.07689525,  0.05415992,
              -0.11050644,  0.22196814,  0.11483034, -0.07325884,
   

## **Using builtin random.multivariate_normal**

In [28]:
import jax
import jax.numpy as jnp
import numpy as np

key = random.PRNGKey(0)
cov = np.array([[1.2, 0.4], [0.4, 1.0]])
mean = np.array([3,-1])
x1,x2 = jax.random.multivariate_normal(key, mean, cov, (5000,)).T

### Usually building the function `multivariate_normal_sampler` is faster than built in  sampling routine for mltivariate normal `random.multivariate_normal`

In [38]:
import numpy as np
import scipy.linalg as spla
from scipy.spatial.distance import cdist

# key = random.PRNGKey(0)
def multivariate_normal_sampler(mean, covariance, n_samples=1):
  L = spla.cholesky(covariance)
  Z = np.random.normal(size=(n_samples,covariance.shape[0]))
  return Z.dot(L) + mean

X = np.random.normal(size=(500,1))
K = np.exp(-cdist(X,X,"sqeuclidean")) + 1e-6 + np.eye(X.shape[0])
mean = np.zeros((X.shape[0],))

In [39]:
import time

In [41]:
#USING BUILTIN FUNCTION
time.sleep(1.)
start_time = time.time()
samples = np.random.multivariate_normal(mean, k, size=(10000,))
print("Time elapsed: {:.4f}".format(time.time() - start_time))

Time elapsed: 0.5319


In [43]:
#FUNCTION WE MADE
time.sleep(1.)
start_time = time.time()
samples = multivariate_normal_sampler(mean, k, 10000)
print("Time elapsed: {:.4f}".format(time.time() - start_time))

Time elapsed: 0.4205
