Currently, we have very rudimentary handling of stochastic layers. Initialization of RNGs for stochastic layers is done as:
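randn(rng, 1)  # draw and discard one sample so that rng advances
return (rng=replicate(rng), training=true)  # store a copy of the advanced rng in the layer state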
Because the discarded draw advances rng, each stochastic layer starts from a different RNG state. We still need to look at how JAX frameworks handle this.
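The effect of the discarded draw can be seen in a minimal, self-contained sketch. This is an illustration under stated assumptions, not the library's implementation: init_state is a hypothetical helper, copy from Julia's Random standard library stands in for replicate, and Xoshiro assumes Julia 1.7 or later.

```julia
using Random

# Hypothetical stand-in for a stochastic layer's state initialization.
# `copy` is used here in place of `replicate` (an assumption for illustration).
function init_state(rng::AbstractRNG)
    randn(rng, 1)                          # discard one draw to advance the shared rng
    return (rng=copy(rng), training=true)  # snapshot the advanced rng for this layer
end

rng = Xoshiro(0)       # shared RNG passed to every layer in turn
st1 = init_state(rng)  # state for the first stochastic layer
st2 = init_state(rng)  # state for the second stochastic layer

# The two snapshots were taken at different points of the same stream,
# so their first draws should differ.
x1 = rand(st1.rng)
x2 = rand(st2.rng)
println(x1 != x2)
```

Without the discarded draw, both layers would receive identical copies of rng and produce the same random stream, for example identical dropout masks.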