In [None]:
import matplotlib.pyplot as plt
import pybamm
import time
import dask
from distributed import Client

In [None]:
def generate_plots(discharge, t, capacity, current, voltage):
    """Plot current, terminal voltage and voltage-vs-capacity for each run.

    Three figures are drawn (current vs time, terminal voltage vs time,
    terminal voltage vs discharge capacity), each with one line per
    discharge current, then shown with ``plt.show()``.

    Parameters
    ----------
    discharge : sequence of float
        Discharge currents [A]; used only to build the legend labels.
    t, capacity, current, voltage : sequence of array-like
        Per-run series, index-aligned with ``discharge``: time [s],
        discharge capacity [Ah], current [A] and terminal voltage [V].

    Raises
    ------
    ValueError
        If any series sequence does not have one entry per discharge current.
    """
    n_runs = len(discharge)
    series = {'t': t, 'capacity': capacity, 'current': current, 'voltage': voltage}
    # Validate up front: index-based plotting would otherwise raise an opaque
    # IndexError for short inputs and silently ignore extra entries.
    mismatched = [name for name, values in series.items() if len(values) != n_runs]
    if mismatched:
        raise ValueError(
            f'expected {n_runs} entries (one per discharge current) in: '
            + ', '.join(mismatched)
        )

    labels = [f'{dis} A' for dis in discharge]

    def _plot_family(xs, ys, xlabel, ylabel):
        # One figure with a labelled line per discharge current, shared styling.
        _, ax = plt.subplots(tight_layout=True)
        for x, y, label in zip(xs, ys, labels):
            ax.plot(x, y, label=label)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.legend(loc='best')
        ax.grid(color='0.9')
        ax.set_frame_on(False)
        ax.tick_params(color='0.9')

    _plot_family(t, current, 'Time [s]', 'Current [A]')
    _plot_family(t, voltage, 'Time [s]', 'Terminal voltage [V]')
    _plot_family(capacity, voltage, 'Discharge capacity [Ah]', 'Terminal voltage [V]')

    plt.show()

def run_simulation(dis, t_eval):
    """Solve the PyBaMM SPMe lithium-ion model at one discharge current.

    Parameters
    ----------
    dis : float
        Discharge current [A], supplied at solve time as the value of the
        'Current function [A]' input.
    t_eval : list or array
        Evaluation times [s], forwarded unchanged to ``Simulation.solve``.

    Returns
    -------
    The ``solution`` attribute of the solved ``pybamm.Simulation``.
    """
    current_key = 'Current function [A]'

    # A new model is built on every call.
    spme = pybamm.lithium_ion.SPMe()

    # Declare the current as an input so its value comes from `inputs` below.
    parameter_values = spme.default_parameter_values
    parameter_values[current_key] = '[input]'

    simulation = pybamm.Simulation(spme, parameter_values=parameter_values)
    simulation.solve(t_eval, inputs={current_key: dis})
    return simulation.solution

In [None]:
# Solve the model once per discharge current, timing the sweep (the timer is
# stopped after result extraction in the following lines).
tic = time.perf_counter()

discharge = [4, 3.5, 3, 2.5, 2, 1.8, 1.5, 1]  # discharge currents [A]
t_eval = [0, 4000]                            # evaluation time [s]
USE_DASK = False  # True: build the runs as dask.delayed tasks and compute them together

if USE_DASK:
    label = 'Dask'
    lazy_sols = [dask.delayed(run_simulation)(dis, t_eval) for dis in discharge]
    # NOTE(review): no distributed Client is created here, so dask.compute
    # uses dask's default scheduler for delayed objects; confirm it actually
    # parallelises the PyBaMM solves before comparing timings.
    sols = dask.compute(*lazy_sols)
else:
    label = 'no Dask'
    sols = [run_simulation(dis, t_eval) for dis in discharge]

# Pull the series needed for plotting out of each solution; every list holds
# one entry per run, in the same order as `discharge`.
t = [sol['Time [s]'].entries for sol in sols]
capacity = [sol['Discharge capacity [A.h]'].entries for sol in sols]
current = [sol['Current [A]'].entries for sol in sols]
voltage = [sol["Terminal voltage [V]"].entries for sol in sols]

# Stop the timer started before the solves and report the total.
toc = time.perf_counter()
print(f'Elapsed time ({label}) = {toc - tic:.2f} s')

generate_plots(discharge, t, capacity, current, voltage)