In [4]:
import pybamm
import time
import numpy as np
import jax
import jax.numpy as jnp

In [5]:
# 1) pip install "pybamm[iree,jax]"  
# 2) then install after "pip install jax[cuda12]" - this upgrades jax to support CUDA12
print("Available devices:", jax.devices())

Available devices: [cuda(id=0)]


In [6]:
# We will want to differentiate our model, so let's define two input parameters
inputs = {
    "Current function [A]": 0.222,
    "Separator porosity": 0.3,
    
}

# Set-up the model
options = {"cell geometry": "arbitrary", "thermal": "lumped"}
model = pybamm.lithium_ion.DFN(options=options)
geometry = model.default_geometry
param = model.default_parameter_values
param.update({key: "[input]" for key in inputs.keys()})
param.process_geometry(geometry)
param.process_model(model)
var = pybamm.standard_spatial_vars
var_pts = {var.x_n: 20, var.x_s: 20, var.x_p: 20, var.r_n: 10, var.r_p: 10}
mesh = pybamm.Mesh(geometry, model.default_submesh_types, var_pts)
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
disc.process_model(model)

# Use a short time-vector for this example, and declare which variables to track
t_eval = np.linspace(0, 360, 10)
output_variables = [
    "Voltage [V]",
    "Current [A]",
    "Time [min]",
]

# Create the IDAKLU Solver object
idaklu_solver = pybamm.IDAKLUSolver(
    rtol=1e-6,
    atol=1e-6,
    output_variables=output_variables,
)

In [7]:
# This is how we would normally perform a solve using IDAKLU
sim = idaklu_solver.solve(
    model,
    t_eval,
    inputs=inputs,
    calculate_sensitivities=True,
)

# Instead, we Jaxify the IDAKLU solver using similar arguments...
jax_solver = idaklu_solver.jaxify(
    model,
    t_eval,
)

# ... and then obtain a JAX expression for the solve
f = jax_solver.get_jaxpr()
print(f"JAX expression: {f}")

JAX expression: <function IDAKLUJax._jaxify.<locals>.f at 0x7f5c45007370>


In [8]:
# This is how we would normally perform a solve using IDAKLU
sim = idaklu_solver.solve(
    model,
    t_eval,
    inputs=inputs,
    calculate_sensitivities=True,
)

# Instead, we Jaxify the IDAKLU solver using similar arguments...
jax_solver = idaklu_solver.jaxify(
    model,
    t_eval,
)

# ... and then obtain a JAX expression for the solve
f = jax_solver.get_jaxpr()
print(f"JAX expression: {f}")

JAX expression: <function IDAKLUJax._jaxify.<locals>.f at 0x7f5c450077f0>


In [9]:
# Print all output variables, evaluated over a given time vector
data = f(t_eval, inputs)
print(data)

'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)


[[3.81933939 0.222      0.        ]
 [3.81351212 0.222      0.66666667]
 [3.81085748 0.222      1.33333333]
 [3.80891269 0.222      2.        ]
 [3.80720408 0.222      2.66666667]
 [3.80558248 0.222      3.33333333]
 [3.80399826 0.222      4.        ]
 [3.80243274 0.222      4.66666667]
 [3.80087893 0.222      5.33333333]
 [3.79933422 0.222      6.        ]]


In [10]:
# Isolate a single variables
data = jax_solver.get_var("Voltage [V]")(t_eval, inputs)
print(f"Isolating a single variable returns an array of shape {data.shape}")
print(data)

# Isolate two variables from the solver
data = jax_solver.get_vars(
    [
        "Voltage [V]",
        "Current [A]",
    ],
)(t_eval, inputs)
print(f"\nIsolating two variables returns an array of shape {data.shape}")
print(data)

Isolating a single variable returns an array of shape (10,)
[3.81933939 3.81351212 3.81085748 3.80891269 3.80720408 3.80558248
 3.80399826 3.80243274 3.80087893 3.79933422]

Isolating two variables returns an array of shape (10, 2)
[[3.81933939 0.222     ]
 [3.81351212 0.222     ]
 [3.81085748 0.222     ]
 [3.80891269 0.222     ]
 [3.80720408 0.222     ]
 [3.80558248 0.222     ]
 [3.80399826 0.222     ]
 [3.80243274 0.222     ]
 [3.80087893 0.222     ]
 [3.79933422 0.222     ]]


'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring f

In [11]:
# Calculate the Jacobian matrix (via forward autodiff)
t_start = time.time()
out = jax.jacfwd(f, argnums=1)(t_eval, inputs)
print(f"Jacobian forward method ran in {time.time()-t_start:0.3} secs")
print(out)

# Calculate Jacobian matrix (via backward autodiff)
t_start = time.time()
out = jax.jacrev(f, argnums=1)(t_eval, inputs)
print(f"\nJacobian reverse method ran in {time.time()-t_start:0.3} secs")
print(out)

'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring f

Jacobian forward method ran in 0.55 secs
{'Current function [A]': Array([[-0.13629603,  1.        ,  0.        ],
       [-0.13630352,  1.        ,  0.        ],
       [-0.13631101,  1.        ,  0.        ],
       [-0.13632595,  1.        ,  0.        ],
       [-0.13635572,  1.        ,  0.        ],
       [-0.13641497,  1.        ,  0.        ],
       [-0.13647369,  1.        ,  0.        ],
       [-0.13653216,  1.        ,  0.        ],
       [-0.13664787,  1.        ,  0.        ],
       [-0.1367626 ,  1.        ,  0.        ]], dtype=float64), 'Separator porosity': Array([[0.00579554, 0.        , 0.        ],
       [0.00579554, 0.        , 0.        ],
       [0.00579554, 0.        , 0.        ],
       [0.00579555, 0.        , 0.        ],
       [0.0057956 , 0.        , 0.        ],
       [0.00579579, 0.        , 0.        ],
       [0.0057961 , 0.        , 0.        ],
       [0.00579647, 0.        , 0.        ],
       [0.0057974 , 0.        , 0.        ],
       [0.

'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring f


Jacobian reverse method ran in 1.23 secs
{'Current function [A]': Array([[-0.13629603,  1.        ,  0.        ],
       [-0.13630352,  1.        ,  0.        ],
       [-0.13631101,  1.        ,  0.        ],
       [-0.13632595,  1.        ,  0.        ],
       [-0.13635572,  1.        ,  0.        ],
       [-0.13641497,  1.        ,  0.        ],
       [-0.13647369,  1.        ,  0.        ],
       [-0.13653216,  1.        ,  0.        ],
       [-0.13664787,  1.        ,  0.        ],
       [-0.1367626 ,  1.        ,  0.        ]],      dtype=float64, weak_type=True), 'Separator porosity': Array([[0.00579554, 0.        , 0.        ],
       [0.00579554, 0.        , 0.        ],
       [0.00579554, 0.        , 0.        ],
       [0.00579555, 0.        , 0.        ],
       [0.0057956 , 0.        , 0.        ],
       [0.00579579, 0.        , 0.        ],
       [0.0057961 , 0.        , 0.        ],
       [0.00579647, 0.        , 0.        ],
       [0.0057974 , 0.        , 0

In [12]:
# Isolate the derivate of Voltage with respect to the Current function:
out = jax.jacfwd(f, argnums=1)(t_eval, inputs)
data = jax_solver.get_var(out["Current function [A]"], "Voltage [V]")
print(data)

[-0.13629603 -0.13630352 -0.13631101 -0.13632595 -0.13635572 -0.13641497
 -0.13647369 -0.13653216 -0.13664787 -0.1367626 ]


In [13]:
# Example evaluation using the `grad` function
t_start = time.time()
data = jax.vmap(
    jax.grad(
        jax_solver.get_var("Voltage [V]"),
        argnums=1,  # take derivative with respect to `inputs`
    ),
    in_axes=(0, None),  # map time over the 0th dimension and do not map inputs
)(t_eval, inputs)
print(f"Gradient method ran in {time.time()-t_start:0.3} secs")
print(data)

'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring f

Gradient method ran in 0.524 secs
{'Current function [A]': Array([-0.13629603, -0.13630352, -0.13631101, -0.13632595, -0.13635572,
       -0.13641497, -0.13647369, -0.13653216, -0.13664787, -0.1367626 ],      dtype=float64), 'Separator porosity': Array([0.00579554, 0.00579554, 0.00579554, 0.00579555, 0.0057956 ,
       0.00579579, 0.0057961 , 0.00579647, 0.0057974 , 0.00579851],      dtype=float64)}


'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)


# use case example

As a use-case example, consider a fitting procedure where we want to compare simulation data against some experimental data. We achieve this by computing the sum-of-squared error (SEE) between the two. Many fitting procedures will converge more quickly (with fewer iterations) if both the value and gradient of the SSE function are provided. By making use of JAX-expressions we can derive these effortlessly.

Note: We do not need to map over time when calling value_and_grad in this example as the sse function returns a scalar (despite taking vector inputs).

In [14]:
# Simulate some experimental data using our original parameter settings
data = sim["Voltage [V]"](t_eval)


# Sum-of-squared errors
def sse(t, inputs):
    modelled = jax_solver.get_var("Voltage [V]")(t_eval, inputs)
    return jnp.sum((modelled - data) ** 2)


# Provide some predicted model inputs (these could come from a fitting procedure)
inputs_pred = {
    "Current function [A]": 0.150,
    "Separator porosity": 0.333,
}

# Get the value and gradient of the SSE function
t_start = time.time()
value, gradient = jax.value_and_grad(sse, argnums=1)(t_eval, inputs_pred)
print(f"Value and gradient computed in {time.time()-t_start:0.3} secs")
print("SSE value: ", value)
print("SSE gradient (wrt each input): ", gradient)

'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring feature)
'+ptx86' is not a recognized feature for this target (ignoring f

Value and gradient computed in 0.577 secs
SSE value:  0.002077049099434896
SSE gradient (wrt each input):  {'Current function [A]': array(-0.04099523), 'Separator porosity': array(0.00087822)}


# getting a parameter, x-avg cell volt

this is more so for my own understanding on how the battery simulator api should work when being called to solve for a simulation. right now i dont think it uses output variables but rather simulates all with casadi then filters based on what the user wants... need to double check this.
if so i would need to see if i can call output_variables before using any solver so the results you get from the api call are agnostic to unique solvers

# DFN simulation

In [15]:
import pybamm
import numpy as np
import jax

devices = jax.devices()
device_type = devices[0].device_kind
print(f"Simulation is running on: {device_type}")

# Define inputs
inputs = {
    "Current function [A]": 0.222,
    "Separator porosity": 0.3,
}

# Set-up the model
options = {"cell geometry": "arbitrary", "thermal": "lumped"}
model = pybamm.lithium_ion.DFN(options=options)
geometry = model.default_geometry
param = model.default_parameter_values
param.update({key: "[input]" for key in inputs.keys()})
param.process_geometry(geometry)
param.process_model(model)
var = pybamm.standard_spatial_vars
var_pts = {var.x_n: 20, var.x_s: 20, var.x_p: 20, var.r_n: 10, var.r_p: 10}
mesh = pybamm.Mesh(geometry, model.default_submesh_types, var_pts)
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
disc.process_model(model)

# Use a short time-vector for this example
t_eval = np.linspace(0, 360, 10)

# Include the variable of interest in output_variables
output_variables = [
    "Voltage [V]",
    "Current [A]",
    "Time [min]",
    "X-averaged cell temperature [C]" # want to retrieve this and print the output
]

# Create the IDAKLU Solver object
idaklu_solver = pybamm.IDAKLUSolver(rtol=1e-6, atol=1e-6, output_variables=output_variables)

start_time = time.time()

# Perform the simulation
sim = idaklu_solver.solve(
    model,
    t_eval,
    inputs=inputs,
    calculate_sensitivities=True,
)

end_time = time.time()
print(f"Simulation took {end_time - start_time:.2f} seconds.")

# Variable name to extract
var_name = "X-averaged cell temperature [C]"

# Check if the variable is present in the simulation results
if var_name not in sim._variables:
    print(f"Variable '{var_name}' not found in the simulation results.")
else:
    # Extract the variable from the simulation result
    try:
        # Evaluate the variable using the simulation result
        temp_data = sim[var_name](t_eval)
        print("X-averaged cell temperature [C]:")
        print(temp_data)
    except Exception as e:
        print(f"An error occurred while evaluating the variable: {e}")

Simulation is running on: Quadro T2000
Simulation took 0.98 seconds.
X-averaged cell temperature [C]:
[25.         25.01242858 25.0152713  25.01595535 25.0161316  25.0161862
 25.01620831 25.01622085 25.01622946 25.0162357 ]


# SPM simulation

In [16]:
import pybamm
import numpy as np
import jax

devices = jax.devices()
device_type = devices[0].device_kind
print(f"Simulation is running on: {device_type}")

# Define inputs
inputs = {
    "Current function [A]": 0.222,
    "Separator porosity": 0.3,
}

# Set-up the model
options = {"cell geometry": "arbitrary", "thermal": "lumped"}
model = pybamm.lithium_ion.SPM(options=options)
geometry = model.default_geometry
param = model.default_parameter_values
param.update({key: "[input]" for key in inputs.keys()})
param.process_geometry(geometry)
param.process_model(model)
var = pybamm.standard_spatial_vars
var_pts = {var.x_n: 20, var.x_s: 20, var.x_p: 20, var.r_n: 10, var.r_p: 10}
mesh = pybamm.Mesh(geometry, model.default_submesh_types, var_pts)
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
disc.process_model(model)

# Use a short time-vector for this example
t_eval = np.linspace(0, 360, 10)

# Include the variable of interest in output_variables
output_variables = [
    "Voltage [V]",
    "Current [A]",
    "Time [min]",
    "X-averaged cell temperature [C]" # want to retrieve this and print the output
]

# Create the IDAKLU Solver object
idaklu_solver = pybamm.IDAKLUSolver(rtol=1e-6, atol=1e-6, output_variables=output_variables)

start_time = time.time()

# Perform the simulation
sim = idaklu_solver.solve(
    model,
    t_eval,
    inputs=inputs,
    calculate_sensitivities=True,
)

end_time = time.time()
print(f"Simulation took {end_time - start_time:.2f} seconds.")

# Variable name to extract
var_name = "X-averaged cell temperature [C]"

# Check if the variable is present in the simulation results
if var_name not in sim._variables:
    print(f"Variable '{var_name}' not found in the simulation results.")
else:
    # Extract the variable from the simulation result
    try:
        # Evaluate the variable using the simulation result
        temp_data = sim[var_name](t_eval)
        print("X-averaged cell temperature [C]:")
        print(temp_data)
    except Exception as e:
        print(f"An error occurred while evaluating the variable: {e}")

Simulation is running on: Quadro T2000
Simulation took 0.17 seconds.
X-averaged cell temperature [C]:
[25.         25.01072713 25.01289727 25.01334668 25.01344455 25.01347274
 25.01348601 25.0134948  25.01350122 25.0135058 ]


# SPMe Simulation

In [17]:
import pybamm
import numpy as np
import jax

devices = jax.devices()
device_type = devices[0].device_kind
print(f"Simulation is running on: {device_type}")

# Define inputs
inputs = {
    "Current function [A]": 0.222,
    "Separator porosity": 0.3,
}

# Set-up the model
options = {"cell geometry": "arbitrary", "thermal": "lumped"}
model = pybamm.lithium_ion.SPMe(options=options)
geometry = model.default_geometry
param = model.default_parameter_values
param.update({key: "[input]" for key in inputs.keys()})
param.process_geometry(geometry)
param.process_model(model)
var = pybamm.standard_spatial_vars
var_pts = {var.x_n: 20, var.x_s: 20, var.x_p: 20, var.r_n: 10, var.r_p: 10}
mesh = pybamm.Mesh(geometry, model.default_submesh_types, var_pts)
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
disc.process_model(model)

# Use a short time-vector for this example
t_eval = np.linspace(0, 360, 10)

# Include the variable of interest in output_variables
output_variables = [
    "Voltage [V]",
    "Current [A]",
    "Time [min]",
    "X-averaged cell temperature [C]" # want to retrieve this and print the output
]

# Create the IDAKLU Solver object
idaklu_solver = pybamm.IDAKLUSolver(rtol=1e-6, atol=1e-6, output_variables=output_variables)

start_time = time.time()

# Perform the simulation
sim = idaklu_solver.solve(
    model,
    t_eval,
    inputs=inputs,
    calculate_sensitivities=True,
)

end_time = time.time()
print(f"Simulation took {end_time - start_time:.2f} seconds.")

# Variable name to extract
var_name = "X-averaged cell temperature [C]"

# Check if the variable is present in the simulation results
if var_name not in sim._variables:
    print(f"Variable '{var_name}' not found in the simulation results.")
else:
    # Extract the variable from the simulation result
    try:
        # Evaluate the variable using the simulation result
        temp_data = sim[var_name](t_eval)
        print("X-averaged cell temperature [C]:")
        print(temp_data)
    except Exception as e:
        print(f"An error occurred while evaluating the variable: {e}")

Simulation is running on: Quadro T2000
Simulation took 0.68 seconds.
X-averaged cell temperature [C]:
[25.         25.01253857 25.01545068 25.01616518 25.0163572  25.01641667
 25.01643949 25.0164513  25.01645871 25.0164637 ]
