# Prototype of NORTA Algorithm
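The goal is to translate the existing methods in our `bigsimr` package into Python code using the NumPy library. NORTA (NORmal To Anything) draws correlated multivariate normal data and then maps each coordinate through the standard normal CDF followed by the inverse CDF (`ppf`) of its target marginal distribution. The code here is written with the intention of later switching to `jax.numpy` for JIT compilation and GPU acceleration. JAX works best with small, composable functions, so every part of the `rvec` function should be broken up into its simplest pieces.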

In [7]:
import numpy as np
from scipy import stats

In [58]:
def normal2marginal(x, margin):
    """Push standard-normal values through a target marginal distribution.

    This is the per-coordinate NORTA step. ``x`` is first mapped to
    probabilities with the standard normal CDF, and those probabilities are
    then fed to the inverse CDF (``ppf``) of the target distribution.

    Parameters
    ----------
    x : array_like
        Values to transform (expected to be standard-normal draws).
    margin : tuple
        ``(dist, *params)``, where ``dist`` exposes a ``ppf`` method (e.g. a
        ``scipy.stats`` distribution) and ``params`` are its positional
        shape parameters.

    Returns
    -------
    ndarray or scalar
        ``dist.ppf(norm.cdf(x), *params)``.
    """
    dist = margin[0]
    dist_args = margin[1:]
    probs = stats.norm.cdf(x)
    return dist.ppf(probs, *dist_args)

In [59]:
def rvec(n, rho, margins, rng=None):
    """Draw ``n`` random vectors with the requested margins via NORTA.

    Standard multivariate normal data are generated with covariance ``rho``.
    Each column is then mapped to its target distribution with
    ``normal2marginal``.

    Parameters
    ----------
    n : int
        Number of random vectors (rows) to generate.
    rho : array_like of shape (d, d)
        Covariance of the intermediate normal draw. Its diagonal should be 1
        (a correlation matrix), because ``normal2marginal`` applies the
        *standard* normal CDF to each column.
    margins : sequence of length d
        One entry per column. Each entry is a tuple ``(dist, *params)`` as
        accepted by ``normal2marginal``.
    rng : numpy.random.Generator, optional
        Source of randomness. When omitted, NumPy's global legacy RNG
        (``np.random``) is used, so ``np.random.seed`` still controls the
        output exactly as before.

    Returns
    -------
    ndarray of shape (n, d)
        Column ``i`` holds draws from ``margins[i]``.

    Raises
    ------
    ValueError
        If the number of margins does not match the dimension of ``rho``.
    """
    rho = np.asarray(rho)
    d = rho.shape[0]
    # Previously the loop read the global `my_margins`, silently ignoring the
    # `margins` argument; validate and use the argument instead.
    if len(margins) != d:
        raise ValueError(
            f"expected {d} margins to match rho of shape {rho.shape}, got {len(margins)}"
        )

    # Generate multivariate normal data
    sampler = np.random if rng is None else rng
    z = sampler.multivariate_normal(mean=np.zeros(d), cov=rho, size=n)

    # For each margin, apply `normal2marginal()` to its own column
    u = np.empty_like(z)
    for i, margin in enumerate(margins):
        u[:, i] = normal2marginal(z[:, i], margin)

    return u

In [60]:
# Four negative-binomial margins, each given as (distribution, n, p).
my_margins = [(stats.nbinom, 12, 0.75),
              (stats.nbinom, 20, 0.50),
              (stats.nbinom, 5, 0.35),
              (stats.nbinom, 12, 0.25)]

# Seed NumPy's global RNG (which `rvec` draws from by default) so that the
# sample below is reproducible under Restart & Run All.
SEED = 2020
np.random.seed(SEED)

# 10 draws with an identity correlation matrix, i.e. independent margins.
x = rvec(10, np.eye(4), my_margins)

In [61]:
# Benchmark the NumPy prototype: 2000 draws over the 4 margins with an identity
# correlation matrix. This is the baseline for the planned `jax.numpy` version.
%timeit rvec(2000, np.eye(4), my_margins)

41.1 ms ± 128 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [62]:
# Show the 10 x 4 sample drawn earlier; column i follows my_margins[i].
x

array([[ 4., 18., 12., 51.],
       [ 2., 27.,  3., 27.],
       [ 8., 22., 14., 32.],
       [ 4., 13., 17., 24.],
       [ 2., 15., 16., 20.],
       [ 2., 15., 15., 42.],
       [ 3., 16., 15., 22.],
       [ 7., 12.,  8., 35.],
       [ 2., 24., 14., 46.],
       [ 2., 26., 11., 33.]])