## How to generate samples for Monte Carlo (MC) simulation using the JAX generator

In [3]:
import os
import time

import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np

from jax_src.generator import load_net, make_net
from jax_src.sampler_for_mlmc import LASamplerForMC

# First choose a PRNG key, from which all randomness in this notebook is derived. Re-running
# with the same key reproduces exactly the same samples (for a given JAX version and backend;
# JAX does not promise bit-identical random streams across releases or devices).
SEED = 17
key = jr.key(SEED)
# Never reuse one key for two independent random streams: split it so that each sampler and
# the freshly initialised net below each get their own independent key.
foster_key, gan_key, net_key = jr.split(key, 3)

# Then create two samplers: one that uses James' method and one that uses a GAN.
foster_sampler = LASamplerForMC(bm_dim=5, net=None, key=foster_key, method="Foster")

# Architecture hyperparameters shared by the loaded net and the newly created net.
# NOTE(review): load_net presumably needs these to match the saved weights — confirm.
NET_KWARGS = dict(
    noise_size=3,
    hidden_size=16,
    num_layers=3,
    slope=0.01,
    use_multlayer=True,
    dtype=jnp.complex64,
)

# Load a saved net for the GAN-based sampler. Set the LEVY_NETS_PATH environment variable to
# point at your own copy of the weights instead of editing the author's local default path.
path = os.environ.get("LEVY_NETS_PATH", "/home/andy/PycharmProjects/Levy_CFGAN/numpy_nets/")
net = load_net(path=path, **NET_KWARGS)
# Or just create a new net with the same architecture, seeded from the notebook's key.
net2 = make_net(net_key, **NET_KWARGS)

gan_sampler = LASamplerForMC(bm_dim=2, net=net, key=gan_key, method="GAN")

### Generating samples
Note that each call to `generate_mc_samples` on the same `LASamplerForMC` object returns different samples, because the sampler's PRNG key is updated after every call.

In [4]:
# Sampling is invoked identically whichever method the sampler was built with, and the result
# has the same shape as the original generate_MC_samples function produced.
samples = gan_sampler.generate_mc_samples(m=3, n=8, dt=0.1)
# The result is a JAX array; copy it into a NumPy array for use outside JAX.
samples_np = np.array(samples)
for arr in (samples, samples_np):
    print(arr.shape)

(3, 8, 3)
(3, 8, 3)


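### Sampling with Chen combination

Passing `use_chen=True` makes `generate_mc_samples` return both the raw samples and their Chen-combined version.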

In [5]:
# With use_chen=True the sampler returns a pair: the raw samples and their Chen-combined
# version. The output below shows the step axis going from n=8 to 4 (presumably consecutive
# pairs of steps are combined — confirm against the sampler implementation).
chen_kwargs = dict(m=3, n=8, dt=0.1, use_chen=True)
samples, chen_combined_samples = gan_sampler.generate_mc_samples(**chen_kwargs)
for arr in (samples, chen_combined_samples):
    print(arr.shape)

(3, 8, 3)
(3, 4, 3)


In [6]:
# Average wall-clock time of one sampling call. JAX dispatches work asynchronously, so
# block_until_ready is needed to time the computation itself rather than just its dispatch.
n_repeats = 1
# Warm-up call so any one-off tracing/compilation cost is excluded from the measurement,
# even if this cell is run without the cells above.
jax.block_until_ready(gan_sampler.generate_mc_samples(m=3, n=8, dt=0.1))
# perf_counter is monotonic and high-resolution, unlike time.time().
start = time.perf_counter()
for _ in range(n_repeats):
    samples = jax.block_until_ready(gan_sampler.generate_mc_samples(m=3, n=8, dt=0.1))
end = time.perf_counter()
print((end - start) / n_repeats)

0.0012984275817871094
