In [2]:
import pybamm
import time
import numpy as np
import jax
import jax.numpy as jnp

In [3]:
# Parameters we want to differentiate with respect to, with their nominal values.
# Each key is swapped for an "[input]" placeholder in the parameter set below;
# later cells pass these values back in via `inputs=` / `f(t_eval, inputs)`.
inputs = {
    "Current function [A]": 0.222,
    "Separator porosity": 0.3,
}

# A short time grid for this example (0 to 360 s, 10 points) and the
# variables the solver should track and return
t_eval = np.linspace(0, 360, 10)
output_variables = [
    "Voltage [V]",
    "Current [A]",
    "Time [min]",
]

# DFN model with the chosen parameters turned into "[input]" placeholders
model = pybamm.lithium_ion.DFN()
geometry = model.default_geometry
param = model.default_parameter_values
param.update(dict.fromkeys(inputs, "[input]"))
param.process_geometry(geometry)
param.process_model(model)

# Mesh the geometry (number of points per spatial variable) and discretise
var = pybamm.standard_spatial_vars
var_pts = {
    var.x_n: 20,
    var.x_s: 20,
    var.x_p: 20,
    var.r_n: 10,
    var.r_p: 10,
}
mesh = pybamm.Mesh(geometry, model.default_submesh_types, var_pts)
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
disc.process_model(model)

# IDAKLU solver restricted to the tracked output variables
idaklu_solver = pybamm.IDAKLUSolver(
    atol=1e-6,
    rtol=1e-6,
    output_variables=output_variables,
)

In [4]:
# Reference path: an ordinary IDAKLU solve with sensitivities switched on
sim = idaklu_solver.solve(
    model, t_eval, inputs=inputs, calculate_sensitivities=True
)

# JAX path: jaxify the same solver for the same model and time grid ...
jax_solver = idaklu_solver.jaxify(model, t_eval)

# ... then extract the JAX-callable function that performs the solve
f = jax_solver.get_jaxpr()
print(f"JAX expression: {f}")

JAX expression: <function IDAKLUJax._jaxify.<locals>.f at 0x7f8d6ed0bd90>


In [5]:
# Cross-check the two solve paths set up in the previous cell. Evaluating the
# JAX expression `f` on the same time grid and inputs should reproduce the
# native IDAKLU solution `sim`. The columns of the JAX result follow the order
# of `output_variables`, so they can be compared one variable at a time.
jax_data = f(t_eval, inputs)
for column, name in enumerate(output_variables):
    np.testing.assert_allclose(
        np.asarray(jax_data[:, column]),
        np.asarray(sim[name].data),
        rtol=1e-6,
        # small absolute tolerance so entries that should be exactly 0
        # (e.g. the first "Time [min]" value) still compare sensibly
        atol=1e-9,
        err_msg=f"JAX and native IDAKLU results differ for {name!r}",
    )
print(f"JAX expression: {f}")

JAX expression: <function IDAKLUJax._jaxify.<locals>.f at 0x7f8ec42024d0>


In [6]:
# Print all output variables, evaluated over a given time vector.
# The result has one row per entry of t_eval and one column per tracked
# output variable, in the order of `output_variables`.
data = f(t_eval, inputs)
assert data.shape == (len(t_eval), len(output_variables)), data.shape

# "Time [min]" should simply be t_eval (seconds) converted to minutes.
# NOTE(review): a recorded run printed ~1e-311 rather than 0 for the first
# time entry, which looks like an uninitialised value; the absolute tolerance
# accepts it, but the root cause is worth confirming upstream.
if "Time [min]" in output_variables:
    np.testing.assert_allclose(
        np.asarray(data[:, output_variables.index("Time [min]")]),
        t_eval / 60,
        atol=1e-9,
    )
print(data)

'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)


[[3.81933939e+000 2.22000000e-001 1.15484130e-311]
 [3.81349276e+000 2.22000000e-001 6.66666667e-001]
 [3.81083261e+000 2.22000000e-001 1.33333333e+000]
 [3.80888692e+000 2.22000000e-001 2.00000000e+000]
 [3.80717743e+000 2.22000000e-001 2.66666667e+000]
 [3.80555556e+000 2.22000000e-001 3.33333333e+000]
 [3.80397075e+000 2.22000000e-001 4.00000000e+000]
 [3.80240520e+000 2.22000000e-001 4.66666667e+000]
 [3.80085126e+000 2.22000000e-001 5.33333333e+000]
 [3.79930646e+000 2.22000000e-001 6.00000000e+000]]
