In [1]:
import jax
import jax.numpy as jnp
import jax.random as jrand
from StatTests.MWTest import wilcoxon_test
from StatTests.TTest import t_test

### NOTE

I couldn't find JAX implementations of the t-test or the Wilcoxon test, so I had OpenAI's new `o3-mini-high` write quick implementations of them. I haven't independently verified the accuracy of those implementations, but the numbers lined up with the numbers in Very Normal's (Christian's) video, so I'm just going with them.

The main point of this was to show the difference in speed between JAX and R.

In [2]:
# Recreation of Christian's code but calling JAX functions

def runReplication(key, test, shift, dist):
    """Run one simulated two-sample experiment and test it at the 5% level.

    Two samples of 27 observations each are drawn from the chosen
    distribution, scaled by 1200 and centred at 10500; the second sample is
    additionally offset by ``shift``. The selected test is then applied to
    the pair of samples.

    Args:
        key: JAX PRNG key; it is split in two so that each sample is drawn
            from its own key.
        test: Which test to apply: ``"Student"`` (``t_test`` with
            ``equal_var=True``), ``"Welch"`` (``t_test`` with
            ``equal_var=False``) or ``"Wilcox"`` (``wilcoxon_test``).
        shift: Offset added to every observation of the second sample.
        dist: Sampling distribution, ``"Normal"`` or ``"Cauchy"``.

    Returns:
        The result of ``p_value < 0.05`` for the selected test.

    Raises:
        ValueError: If ``dist`` or ``test`` is not one of the supported
            names. Without this check an unknown name would fall through
            the branches below and fail later with an UnboundLocalError.
    """
    # Validate both selectors before doing any work so a typo produces a
    # clear message. Both are plain Python strings, so this check also works
    # when the function is traced by jax.vmap.
    if dist not in ("Normal", "Cauchy"):
        raise ValueError(
            f"unknown dist {dist!r}; expected 'Normal' or 'Cauchy'"
        )
    if test not in ("Student", "Welch", "Wilcox"):
        raise ValueError(
            f"unknown test {test!r}; expected 'Student', 'Welch' or 'Wilcox'"
        )

    key1, key2 = jrand.split(key)

    # Both distributions share the same location/scale transform; only the
    # standard sampler differs.
    sampler = jrand.normal if dist == "Normal" else jrand.cauchy
    A = sampler(key1, (27,)) * 1200 + 10500
    B = sampler(key2, (27,)) * 1200 + 10500 + shift

    if test == "Student":
        _, p_value = t_test(A, B, equal_var=True)
    elif test == "Welch":
        _, p_value = t_test(A, B, equal_var=False)
    else:  # "Wilcox" -- guaranteed by the validation above
        _, p_value = wilcoxon_test(A, B)

    return p_value < 0.05

Here we define the options we want to test and call the function.

`jax.vmap` is similar to the `pwalk` function in the video (a parallel map)

`jrand.PRNGKey` initializes the RNG seed. It takes any integer (I chose 0)

`jrand.split` creates `number_of_simulations` new keys (seeds) so that each call to `runReplication` has a unique seed.

Here I use nested `for` loops over `tests` and `dists`, and inside them two nested `jax.vmap`s: the outer one maps over the `number_of_simulations` keys and the inner one maps over the `shifts`.

In [3]:
tests = ["Student", "Welch", "Wilcox"]
dists = ["Normal", "Cauchy"]
shifts = jnp.array([300, 600, 900])
number_of_simulations = 1_000_000

# One PRNG key per simulated experiment. The split depends only on the seed
# and the number of simulations, so it is computed once here and reused for
# every (test, dist) combination instead of being recomputed inside the loops.
simulation_keys = jrand.split(jrand.PRNGKey(0), number_of_simulations)

# results[test][dist] holds the outcome of runReplication for every
# (simulation key, shift) pair: the outer vmap maps over the keys (leading
# axis) and the inner vmap maps over the shifts.
results = {}
for test in tests:
    results[test] = {}
    for dist in dists:
        results[test][dist] = jax.vmap(
            lambda k: jax.vmap(lambda s: runReplication(k, test, s, dist))(
                shifts
            )
        )(simulation_keys)

Here I average the results over the simulations to get the power of each test for every distribution and shift, then print them.

In [4]:
import jax.tree_util as jtu

# Average each leaf of the nested results dict over axis 0 (the simulation
# axis), leaving one value per shift for every (test, dist) combination.
powers = jtu.tree_map(lambda outcomes: jnp.mean(outcomes, axis=0), results)

# Walk the nested dict and print one line per shift.
for test_name, per_dist in powers.items():
    print(f"Power of {test_name} test:")
    for dist_name, per_shift in per_dist.items():
        print(f"with {dist_name} distribution:")
        for shift, power in zip(shifts, per_shift):
            print(f"Shift: {shift}, Power: {power:.3f}")

Power of Student test:
with Cauchy distribution:
Shift: 300, Power: 0.024
Shift: 600, Power: 0.034
Shift: 900, Power: 0.051
with Normal distribution:
Shift: 300, Power: 0.147
Shift: 600, Power: 0.438
Shift: 900, Power: 0.772
Power of Welch test:
with Cauchy distribution:
Shift: 300, Power: 0.023
Shift: 600, Power: 0.033
Shift: 900, Power: 0.049
with Normal distribution:
Shift: 300, Power: 0.146
Shift: 600, Power: 0.437
Shift: 900, Power: 0.771
Power of Wilcox test:
with Cauchy distribution:
Shift: 300, Power: 0.076
Shift: 600, Power: 0.162
Shift: 900, Power: 0.298
with Normal distribution:
Shift: 300, Power: 0.138
Shift: 600, Power: 0.414
Shift: 900, Power: 0.745
