**Create some Fake Data with a Signal**

In [None]:
from prism.loading import generate_null_brain_map, fetch_mni152_2mm_mask_img
from prism.stats import t
from prism.inference import permutation_analysis_volumetric_dense
import numpy as np
from tqdm import tqdm
from nilearn.maskers import NiftiMasker
import os
import pandas as pd

# --- Output locations --------------------------------------------------------
# NOTE(review): absolute, user-specific path kept as-is; change OUTPUT_DIR (or
# make it relative to the repo) to run this notebook on another machine.
OUTPUT_DIR = '/Users/jt041/repos/pstn/testing_outputs'
NULL_MAP_DIR = os.path.join(OUTPUT_DIR, 'null_brain_maps')
# Create the directories up front; every write below fails if they are missing.
os.makedirs(NULL_MAP_DIR, exist_ok=True)

# --- Simulation parameters ---------------------------------------------------
n_maps = 100
# Size of each of the two "signal" groups (the first and the last n_signal
# maps). The design matrix below uses the same count, so the two always agree.
n_signal = 10
signal_seed_first = 42  # shared by every map in the first group (presumably identical maps)
signal_seed_last = 37   # shared by every map in the last group (presumably identical maps)
# Noise-map seeds start above both group seeds, so no noise map can reuse a
# group seed (with seed = i, maps 37 and 42 would duplicate a signal group).
noise_seed_start = max(signal_seed_first, signal_seed_last) + 1

# --- Create random maps ------------------------------------------------------
mask_img = fetch_mni152_2mm_mask_img()
imgs = []
img_paths = []
for i in tqdm(range(n_maps)):
    if i < n_signal:
        random_state = signal_seed_first
    elif i >= n_maps - n_signal:
        random_state = signal_seed_last
    else:
        random_state = noise_seed_start + i
    img = generate_null_brain_map(mask_img, random_state=random_state)
    img_path = os.path.join(NULL_MAP_DIR, f'null_brain_map_{i+1}.nii.gz')
    img.to_filename(img_path)
    imgs.append(img)
    img_paths.append(os.path.abspath(img_path))

# Seed NumPy's global RNG so the design-matrix draws below are reproducible.
np.random.seed(42)

# --- Concatenate images and save paths ---------------------------------------
masker = NiftiMasker(mask_img=mask_img)
data = masker.fit_transform(imgs)  # presumably (n_maps, n_voxels) masked values
concatenated_img = masker.inverse_transform(data)
concatenated_img.to_filename(os.path.join(OUTPUT_DIR, 'null_brain_maps_concatenated.nii'))
pd.DataFrame(img_paths).to_csv(
    os.path.join(OUTPUT_DIR, 'null_brain_map_paths.csv'), index=False, header=False
)

# --- Design matrix -----------------------------------------------------------
# Column 1 is 1.75 on the first n_signal maps and random elsewhere.
# Column 2 is random except -0.75 on the last n_signal maps.
# Column 3 is scaled noise; column 4 is the intercept.
# Columns 1 and 2 also get a small jitter.
design_col1 = np.hstack([np.ones(n_signal) + 0.75, np.random.randn(n_maps - n_signal)])
design_col1 += np.random.randn(n_maps) * 0.01
design_col2 = np.hstack([np.random.randn(n_maps - n_signal), np.ones(n_signal) - 1.75])
design_col2 += np.random.randn(n_maps) * 0.01
design_col3 = np.random.randn(n_maps) * 2
intercept = np.ones(n_maps)
design = np.vstack([design_col1, design_col2, design_col3, intercept]).T
# NOTE(review): float16 has a step of ~0.001 near 1.75, so the 0.01-scale jitter
# is coarsely quantised; confirm the reduced precision is intentional.
design = design.astype(np.float16)

# --- T and F contrasts -------------------------------------------------------
# One T contrast per row: design column 1 alone, then design column 2 alone.
contrast = np.array([[1, 0, 0, 0],
                     [0, 1, 0, 0]])
# NOTE(review): presumably flags which contrast rows enter the F test - confirm
# against prism.inference.
f_contrast_indices = [1, 1]

# --- Exchangeability blocks --------------------------------------------------
# Block 1 for the first n_maps // 2 maps and block 2 for the rest (the extra
# map of an odd n_maps lands in block 2).
exchangeability_blocks = np.where(np.arange(n_maps) < n_maps // 2, 1.0, 2.0)

# Cheap invariants before anything is written to disk.
assert design.shape == (n_maps, contrast.shape[1])
assert exchangeability_blocks.shape == (n_maps,)

# --- Save design, contrasts and blocks as .npy and .csv ----------------------
# Each entry is (array written to .npy, table written to .csv). The 1D blocks
# are written as a single CSV column; the F-contrast indices as a single row.
outputs = {
    'design': (design, design),
    'contrast': (contrast, contrast),
    'exchangeability_blocks': (exchangeability_blocks, exchangeability_blocks),
    'f_contrast_indices': (f_contrast_indices, np.atleast_2d(f_contrast_indices)),
}
for name, (array, table) in outputs.items():
    np.save(os.path.join(OUTPUT_DIR, f'{name}.npy'), array)
    pd.DataFrame(table).to_csv(os.path.join(OUTPUT_DIR, f'{name}.csv'), index=False, header=False)

100%|██████████| 100/100 [00:20<00:00,  4.97it/s]


In [None]:
import time

import numpy as np
from jax import block_until_ready, jit
from jax import numpy as jnp
from scipy.stats import zscore as scipy_zscore
from tqdm import tqdm

from prism.stats import zscore as jax_zscore

# Benchmark input: a 3 x 100000 uniform random matrix, of which only the first
# row (a view into `arr`) is used below.
arr = np.random.rand(3, 100000)
vec1 = arr[0]

n_iterations = 10000

# Both implementations must agree before their speed is worth comparing. The
# loose tolerances allow for JAX's default float32 (and any ddof difference)
# versus scipy's float64. This untimed call also warms up any JIT compilation
# (presumably jax_zscore is jitted - confirm in prism.stats), so compile time
# stays out of the measurement below.
np.testing.assert_allclose(
    np.asarray(jax_zscore(vec1)), scipy_zscore(vec1), rtol=1e-4, atol=1e-4
)

# JAX path. JAX dispatches work asynchronously, so each result is blocked on;
# without that the loop would mostly time dispatch, not the computation.
# The NumPy input is passed as-is, so host-to-device transfer is included.
time_start = time.perf_counter()
for _ in range(n_iterations):
    vec2_z = block_until_ready(jax_zscore(vec1))
time_per_iteration_jax = (time.perf_counter() - time_start) / n_iterations
print("Jax compiled approach takes", time_per_iteration_jax, "seconds per iteration")

# Reference path: scipy.stats.zscore on the same NumPy input.
time_start = time.perf_counter()
for _ in range(n_iterations):
    vec2_z = scipy_zscore(vec1)
time_per_iteration_scipy = (time.perf_counter() - time_start) / n_iterations
print("Scipy zscore approach takes", time_per_iteration_scipy, "seconds per iteration")

# Ratio > 1 means the JAX version is faster per call.
speedup_ratio = time_per_iteration_scipy / time_per_iteration_jax
print("Speedup ratio using jax:", speedup_ratio)

100%|██████████| 10000/10000 [00:00<00:00, 20133.72it/s]


Jax compiled approach takes 5.06742000579834e-05 seconds per iteration


100%|██████████| 10000/10000 [00:02<00:00, 3633.78it/s]

Scipy zscore approach takes 0.000275308084487915 seconds per iteration
Speedup ratio using jax: 5.432904400521306



