<a href="https://colab.research.google.com/github/nevencaplar/JaxPeriodDrwFit/blob/main/JaxAstronomy_save_compile_and_pad.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Jax imports
import jax
import jax.numpy as jnp
from jax import jit
from jax import export
from jax._src import compilation_cache as cc
from jax._src.lib import xla_client

import numpy as np
import timeit
import pickle
from pathlib import Path
import logging

# Print which backend JAX uses by default (e.g. cpu or gpu).
print("cpu/gpu: " + str(jax.default_backend()))

# Persistent compilation cache options.
# Use the public `jax_compilation_cache_dir` config option rather than the
# private `jax._src.compilation_cache.set_cache_dir` helper: `jax._src` is
# internal API and may change without notice between JAX releases.
jax.config.update("jax_compilation_cache_dir", "./cache_min_example")
# -1 removes the minimum-entry-size restriction, so even tiny compiled
# programs are written to the cache.
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
# Cache programs regardless of how quickly they compiled.
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
# Ask JAX to explain cache misses.
jax.config.update("jax_explain_cache_misses", True)

# Logging options: lower the root logger to DEBUG so debug-level messages
# (presumably including JAX's cache diagnostics) are emitted.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
logging.debug("test")

# A bare `jax.__version__` is only displayed when it is the last expression of
# a notebook cell; print it so it is also shown when run as a script.
print("jax version: " + jax.__version__)

DEBUG:root:test


cpu/gpu: cpu


In [None]:
# Persistent caching requires the backend to support executable serialization,
# so this must print True for the cache to be used.
# Look the flag up on the live backend client behind jax.devices(), not on the
# xla_client.Client class. If the flag is an instance property, a class-level
# lookup returns the property object itself, which is truthy whatever the
# backend actually supports.
# A missing attribute is treated as "supported" (default True), as before.
backend_client = jax.devices()[0].client
supports_serialization = getattr(
    backend_client, "supports_executable_serialization", True)
print("supports executable serialization: " + str(supports_serialization))

**Simple test function**

In [None]:
# test function, just for fun
# repeat 20 times so that it takes some time
@jax.jit
def test_fun(z):
  def fn(y):
    y -= jax.nn.logsumexp(y,0)
    y -= jax.nn.logsumexp(y,1)
    return y
  for _ in range(20): z = fn(z)
  return z

In [None]:
# Input data for test_fun: a 100x100 array of Gumbel samples drawn with a
# fixed PRNG key (seed 0), so every run sees the same values.
# Reference: https://jax.readthedocs.io/en/latest/_autosummary/jax.random.gumbel.html
a = jax.random.gumbel(key=jax.random.PRNGKey(0), shape=(100, 100))

In [None]:
%%time
# first time slow (due to compile)
_ = test_fun(a)

%%time
# second time fast (due to compile)
_ = test_fun(a)

**More complex example: export, serialize and reload a jitted kernel**

In [None]:

def complex_kernel(arr, multiplier=2.5):
    """Scale ``arr``, apply ``sin``, and standardize the result.

    Args:
        arr: input array (the benchmark below passes 1-D float32 vectors).
        multiplier: factor applied before the sine. Defaults to 2.5, the value
            that was previously hard-coded.

    Returns:
        ``(s - mean(s)) / std(s)`` with ``s = sin(arr * multiplier)``. The mean
        and standard deviation are taken over all elements. If all values of
        ``s`` coincide (zero standard deviation, e.g. a single-element input),
        the result is all zeros instead of NaN.
    """
    # Element-wise multiplication
    arr = arr * multiplier

    # Apply sine transformation
    arr = jnp.sin(arr)

    # Normalize the array to mean 0 and standard deviation 1
    mean = jnp.mean(arr)
    std_dev = jnp.std(arr)
    # Avoid 0/0 = NaN for constant inputs. Dividing by 1 returns the centered
    # values, which are all zero in that case.
    safe_std = jnp.where(std_dev > 0, std_dev, 1.0)

    # Repeated normalization whose result is discarded, presumably to make the
    # kernel heavier. NOTE(review): the results are never used, so under
    # jax.jit this is dead code the compiler may eliminate. Confirm it still
    # adds measurable cost.
    for _ in range(100):
        _ = (arr - mean) / safe_std

    return (arr - mean) / safe_std

def _load_or_export_kernel(length, out_file_name):
    """Return the serialized export of ``jax.jit(complex_kernel)`` for a
    float32 vector of ``length`` elements.

    If ``out_file_name`` exists, the serialized export is read from it.
    Otherwise the kernel is exported and serialized, and the result is written
    there so later runs can skip the export step.

    NOTE: the file is keyed only by ``length``. If ``complex_kernel`` changes,
    stale files must be deleted by hand.
    """
    path = Path(out_file_name)
    if path.exists():
        # NOTE: pickle.load can execute arbitrary code; only load files that
        # this notebook wrote itself.
        with path.open('rb') as f:
            return pickle.load(f)

    # ShapeDtypeStruct takes (shape, dtype): export for a 1-D float32 input
    # with exactly `length` elements.
    exported = export.export(jax.jit(complex_kernel))(
        jax.ShapeDtypeStruct((length,), jnp.float32))
    serialized = exported.serialize()
    with path.open('wb') as f:
        pickle.dump(serialized, f)
    return serialized


def test_complex_kernel(array_lengths=(10, 100, 1000, 10000), evaluations=10):
    """Benchmark the exported ``complex_kernel`` on JAX's default backend.

    For each length, a random float32 input is generated. The serialized
    export is loaded from disk (or created and saved), then deserialized and
    called ``evaluations`` times. The first call, which includes compiling the
    deserialized module, is timed separately from the remaining calls.

    Args:
        array_lengths: input lengths to benchmark.
        evaluations: total number of calls per length. Must be at least 2 so
            an average over the calls after the first can be reported.

    Raises:
        ValueError: if ``evaluations`` is less than 2.
    """
    if evaluations < 2:
        raise ValueError(f"evaluations must be at least 2, got {evaluations}")

    for length in array_lengths:
        # Generate a random 1D array of the given length
        arr = jnp.array(np.random.randn(length), dtype=jnp.float32)

        out_file_name = f'test_array_size_{length}.pkl'
        serialized = _load_or_export_kernel(length, out_file_name)
        rehydrated_exp = export.deserialize(serialized)

        # Wrapper for timeit. The closure is used only within this iteration,
        # so binding `arr`/`rehydrated_exp` late is harmless.
        def kernel_evaluation():
            # block_until_ready waits for the asynchronously dispatched result,
            # so timings cover execution and not just dispatch.
            return rehydrated_exp.call(arr).block_until_ready()

        first_iteration_time = timeit.timeit(kernel_evaluation, number=1)

        remaining_time = timeit.timeit(kernel_evaluation, number=evaluations - 1)
        avg_remaining_time = remaining_time / (evaluations - 1)
        total_time = first_iteration_time + remaining_time

        # Print timing results
        print(f"Array Length: {length}")
        print(f"Total time for {evaluations} evaluations: {total_time:f} seconds")
        print(f"Time for first compilation: {first_iteration_time}")
        print(f"Average time for remaining evaluation: {avg_remaining_time:f} seconds\n")

# Run the benchmark only when executed as a script or notebook cell, not when
# this file is imported. The benchmark writes test_array_size_*.pkl files to
# the working directory.
if __name__ == "__main__":
    test_complex_kernel()
