In [27]:
from typing import Callable

import jax
import jax.numpy as jnp

In [28]:
f = Callable[[int, int], int]

def test(x: int, y: int) -> int:
    return x + y

def test2(x: int, y: int) -> int:
    print('test')
    return x + y

# Define the function you want to pass separately
@jax.jit
def add(x: int, y: int) -> int:
    return x + y

# Define the main function that uses the JIT-compiled function
def test_func(x: int, y: int, func: Callable[[int, int], int]) -> int:
    return func(x, y)

# JIT-compile the main function
test_func = jax.jit(test_func)

test_jax = jax.jit(test)
test_jax2 = jax.jit(test2)

In [29]:
test_jax(5,3)

Array(8, dtype=int32, weak_type=True)

In [30]:
test_jax2(5,3)

test


Array(8, dtype=int32, weak_type=True)

In [32]:
test_func(5,3, add)

TypeError: Cannot interpret value of type <class 'jaxlib.xla_extension.PjitFunction'> as an abstract array; it does not have a dtype attribute

In [35]:
# Define the operations you want to perform as separate functions
def add(x: int, y: int) -> int:
    return x + y

def subtract(x: int, y: int) -> int:
    return x - y

# Define the main function with control flow to select the operation
# Define the main function with control flow based on integer flags
@jax.jit
def test_func(x: int, y: int, operation: int) -> int:
    if operation == 0:
        return add(x, y)
    elif operation == 1:
        return subtract(x, y)
    else:
        raise ValueError("Unsupported operation")

# Call the main function with the desired operation
result_add = test_func(5, 3, 0)
print(result_add)  # Output should be 8

result_subtract = test_func(5, 3, 1)
print(result_subtract)  # Output should be 2

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function test_func at /var/folders/f1/8mlz_8fd39bdxjwt47vnbpx80000gn/T/ipykernel_5026/1191508438.py:10 for jit. This concrete value was not available in Python because it depends on the value of the argument operation.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

In [36]:
def create_func(op):
    if op == "add":
        return jax.jit(lambda x, y: x + y)
    elif op == "mul":
        return jax.jit(lambda x, y: x * y)
    # Add more operations as needed


In [38]:
jix_add = create_func("add")
jix_mul = create_func("mul")

jix_add(3,4)
jix_mul(3,4)

Array(12, dtype=int32, weak_type=True)

In [43]:
from methods.potential_energy import PotentialEnergy
from jaxtyping import Array


def U(q: Array) -> float:
    return 0.5

# Create a factory function that returns a JIT-compiled version of test_f
def create_test_f(energy_func: Callable[[jnp.ndarray], float]):
    @jax.jit
    def test_f(x: float) -> float:
        return energy_func(jnp.array(x))
    return test_f

# Instantiate PotentialEnergy with the potential function U
energy = PotentialEnergy(U)

# Create the JIT-compiled test function using the factory
jit_test_f = create_test_f(energy.__call__)

# Now you can use jit_test_f directly
test_energy = jit_test_f(1)

print(test_energy)

0.5
