In [36]:
import pybamm
import time
import numpy as npq
import jax
import jax.numpy as jnp

In [37]:
# 1) pip install "pybamm[iree,jax]"  
# 2) then install after "pip install jax[cuda12]" - this upgrades jax to support CUDA12
print("Available devices:", jax.devices())

Available devices: [cuda(id=0)]


In [59]:
# We will want to differentiate our model, so let's define two input parameters
inputs = {
    "Current function [A]": 0.222,
    "Separator porosity": 0.3,
    
}

# Set-up the model
options = {"cell geometry": "arbitrary", "thermal": "lumped"}
model = pybamm.lithium_ion.DFN(options=options)
geometry = model.default_geometry
param = model.default_parameter_values
param.update({key: "[input]" for key in inputs.keys()})
param.process_geometry(geometry)
param.process_model(model)
var = pybamm.standard_spatial_vars
var_pts = {var.x_n: 20, var.x_s: 20, var.x_p: 20, var.r_n: 10, var.r_p: 10}
mesh = pybamm.Mesh(geometry, model.default_submesh_types, var_pts)
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
disc.process_model(model)

# Use a short time-vector for this example, and declare which variables to track
t_eval = np.linspace(0, 360, 10)
output_variables = [
    "Voltage [V]",
    "Current [A]",
    "Time [min]",
]

# Create the IDAKLU Solver object
idaklu_solver = pybamm.IDAKLUSolver(
    rtol=1e-6,
    atol=1e-6,
    output_variables=output_variables,
)

In [60]:
# This is how we would normally perform a solve using IDAKLU
sim = idaklu_solver.solve(
    model,
    t_eval,
    inputs=inputs,
    calculate_sensitivities=True,
)

# Instead, we Jaxify the IDAKLU solver using similar arguments...
jax_solver = idaklu_solver.jaxify(
    model,
    t_eval,
)

# ... and then obtain a JAX expression for the solve
f = jax_solver.get_jaxpr()
print(f"JAX expression: {f}")

JAX expression: <function IDAKLUJax._jaxify.<locals>.f at 0x7fb6d8c94430>


In [61]:
# This is how we would normally perform a solve using IDAKLU
sim = idaklu_solver.solve(
    model,
    t_eval,
    inputs=inputs,
    calculate_sensitivities=True,
)

# Instead, we Jaxify the IDAKLU solver using similar arguments...
jax_solver = idaklu_solver.jaxify(
    model,
    t_eval,
)

# ... and then obtain a JAX expression for the solve
f = jax_solver.get_jaxpr()
print(f"JAX expression: {f}")

JAX expression: <function IDAKLUJax._jaxify.<locals>.f at 0x7fb6d8c95990>


In [62]:
# Print all output variables, evaluated over a given time vector
data = f(t_eval, inputs)
print(data)

[[3.81933939e+000 2.22000000e-001 1.15635786e-311]
 [3.81351212e+000 2.22000000e-001 6.66666667e-001]
 [3.81085763e+000 2.22000000e-001 1.33333333e+000]
 [3.80891360e+000 2.22000000e-001 2.00000000e+000]
 [3.80720490e+000 2.22000000e-001 2.66666667e+000]
 [3.80558327e+000 2.22000000e-001 3.33333333e+000]
 [3.80399869e+000 2.22000000e-001 4.00000000e+000]
 [3.80243297e+000 2.22000000e-001 4.66666667e+000]
 [3.80087903e+000 2.22000000e-001 5.33333333e+000]
 [3.79933426e+000 2.22000000e-001 6.00000000e+000]]


In [63]:
# Isolate a single variables
data = jax_solver.get_var("Voltage [V]")(t_eval, inputs)
print(f"Isolating a single variable returns an array of shape {data.shape}")
print(data)

# Isolate two variables from the solver
data = jax_solver.get_vars(
    [
        "Voltage [V]",
        "Current [A]",
    ],
)(t_eval, inputs)
print(f"\nIsolating two variables returns an array of shape {data.shape}")
print(data)

Isolating a single variable returns an array of shape (10,)
[3.81933939 3.81351212 3.81085763 3.8089136  3.8072049  3.80558327
 3.80399869 3.80243297 3.80087903 3.79933426]

Isolating two variables returns an array of shape (10, 2)
[[3.81933939 0.222     ]
 [3.81351212 0.222     ]
 [3.81085763 0.222     ]
 [3.8089136  0.222     ]
 [3.8072049  0.222     ]
 [3.80558327 0.222     ]
 [3.80399869 0.222     ]
 [3.80243297 0.222     ]
 [3.80087903 0.222     ]
 [3.79933426 0.222     ]]


In [53]:
# Calculate the Jacobian matrix (via forward autodiff)
t_start = time.time()
out = jax.jacfwd(f, argnums=1)(t_eval, inputs)
print(f"Jacobian forward method ran in {time.time()-t_start:0.3} secs")
print(out)

# Calculate Jacobian matrix (via backward autodiff)
t_start = time.time()
out = jax.jacrev(f, argnums=1)(t_eval, inputs)
print(f"\nJacobian reverse method ran in {time.time()-t_start:0.3} secs")
print(out)

Jacobian forward method ran in 0.0959 secs
{'Current function [A]': Array([[-0.13629603,  1.        ,  0.        ],
       [-0.16386375,  1.        ,  0.        ],
       [-0.17615635,  1.        ,  0.        ],
       [-0.18494974,  1.        ,  0.        ],
       [-0.19258651,  1.        ,  0.        ],
       [-0.19978548,  1.        ,  0.        ],
       [-0.20678279,  1.        ,  0.        ],
       [-0.21365524,  1.        ,  0.        ],
       [-0.22043194,  1.        ,  0.        ],
       [-0.22711774,  1.        ,  0.        ]], dtype=float64), 'Separator porosity': Array([[0.00579553, 0.        , 0.        ],
       [0.0079704 , 0.        , 0.        ],
       [0.0095279 , 0.        , 0.        ],
       [0.01024855, 0.        , 0.        ],
       [0.01053721, 0.        , 0.        ],
       [0.01064576, 0.        , 0.        ],
       [0.0106863 , 0.        , 0.        ],
       [0.01070118, 0.        , 0.        ],
       [0.01070757, 0.        , 0.        ],
       [


Jacobian reverse method ran in 1.2 secs
{'Current function [A]': Array([[-0.13629603,  1.        ,  0.        ],
       [-0.16386375,  1.        ,  0.        ],
       [-0.17615635,  1.        ,  0.        ],
       [-0.18494974,  1.        ,  0.        ],
       [-0.19258651,  1.        ,  0.        ],
       [-0.19978548,  1.        ,  0.        ],
       [-0.20678279,  1.        ,  0.        ],
       [-0.21365524,  1.        ,  0.        ],
       [-0.22043194,  1.        ,  0.        ],
       [-0.22711774,  1.        ,  0.        ]],      dtype=float64, weak_type=True), 'Separator porosity': Array([[0.00579553, 0.        , 0.        ],
       [0.0079704 , 0.        , 0.        ],
       [0.0095279 , 0.        , 0.        ],
       [0.01024855, 0.        , 0.        ],
       [0.01053721, 0.        , 0.        ],
       [0.01064576, 0.        , 0.        ],
       [0.0106863 , 0.        , 0.        ],
       [0.01070118, 0.        , 0.        ],
       [0.01070757, 0.        , 0.

In [64]:
# Isolate the derivate of Voltage with respect to the Current function:
out = jax.jacfwd(f, argnums=1)(t_eval, inputs)
data = jax_solver.get_var(out["Current function [A]"], "Voltage [V]")
print(data)

[-0.13629603 -0.16363422 -0.17585803 -0.18462944 -0.19225735 -0.19945328
 -0.20644869 -0.21332169 -0.22009827 -0.22678395]


In [65]:
# Example evaluation using the `grad` function
t_start = time.time()
data = jax.vmap(
    jax.grad(
        jax_solver.get_var("Voltage [V]"),
        argnums=1,  # take derivative with respect to `inputs`
    ),
    in_axes=(0, None),  # map time over the 0th dimension and do not map inputs
)(t_eval, inputs)
print(f"Gradient method ran in {time.time()-t_start:0.3} secs")
print(data)

Gradient method ran in 0.536 secs
{'Current function [A]': Array([-0.13629603, -0.16363422, -0.17585803, -0.18462944, -0.19225735,
       -0.19945328, -0.20644869, -0.21332169, -0.22009827, -0.22678395],      dtype=float64), 'Separator porosity': Array([0.00579553, 0.00796344, 0.00951735, 0.01023533, 0.01052245,
       0.01063035, 0.01067032, 0.01068552, 0.01069203, 0.01069544],      dtype=float64)}


# use case example

As a use-case example, consider a fitting procedure where we want to compare simulation data against some experimental data. We achieve this by computing the sum-of-squared error (SEE) between the two. Many fitting procedures will converge more quickly (with fewer iterations) if both the value and gradient of the SSE function are provided. By making use of JAX-expressions we can derive these effortlessly.

Note: We do not need to map over time when calling value_and_grad in this example as the sse function returns a scalar (despite taking vector inputs).

In [66]:
# Simulate some experimental data using our original parameter settings
data = sim["Voltage [V]"](t_eval)


# Sum-of-squared errors
def sse(t, inputs):
    modelled = jax_solver.get_var("Voltage [V]")(t_eval, inputs)
    return jnp.sum((modelled - data) ** 2)


# Provide some predicted model inputs (these could come from a fitting procedure)
inputs_pred = {
    "Current function [A]": 0.150,
    "Separator porosity": 0.333,
}

# Get the value and gradient of the SSE function
t_start = time.time()
value, gradient = jax.value_and_grad(sse, argnums=1)(t_eval, inputs_pred)
print(f"Value and gradient computed in {time.time()-t_start:0.3} secs")
print("SSE value: ", value)
print("SSE gradient (wrt each input): ", gradient)

Value and gradient computed in 0.317 secs
SSE value:  0.0020770188284553034
SSE gradient (wrt each input):  {'Current function [A]': array(-0.05756411), 'Separator porosity': array(0.00146621)}
