In [None]:
import os
import math
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

import pybamm
from pymcmcstat.MCMC import MCMC
from pymcmcstat.settings.DataStructure import DataStructure

## SPMe Model Definition

In [None]:
def current_function(c):
    """Build a multi-sine current profile scaled by ``c``.

    The returned callable adds four unit sine waves with frequencies of
    0.1, 1, 10 and 100 cycles per unit of ``t`` and multiplies the sum by ``c``.
    It is used as ``"Current function [A]"`` below; PyBaMM presumably calls it
    with time in seconds (TODO confirm), i.e. a 0.1-100 Hz excitation.

    Parameters
    ----------
    c : float
        Amplitude multiplier applied to the summed sines.

    Returns
    -------
    callable
        ``multiharmonic(t)`` giving the current at time ``t``.
    """
    two_pi = 2*math.pi

    def multiharmonic(t):
        # Summed strictly left to right, one term per frequency.
        summed = (
            np.sin(two_pi*t/10)
            + np.sin(two_pi*t/1)
            + np.sin(two_pi*10*t)
            + np.sin(two_pi*100*t)
        )
        return c*summed

    return multiharmonic

In [None]:
# Initial states of charge to sweep: 0.0, 0.1, ..., 1.0 (11 evenly spaced values).
soc_list = np.linspace(start=0.0, stop=1.0, num=11)

In [None]:
# Single Particle Model with electrolyte (SPMe) on PyBaMM's default parameter
# set, with transport parameters overridden. The three diffusivities set here
# are the "true" values the MCMC step later tries to recover.
model = pybamm.lithium_ion.SPMe()

param = model.default_parameter_values
param['Negative electrode diffusivity [m2.s-1]'] = 3.54e-14
param['Positive electrode diffusivity [m2.s-1]'] = 1.01e-13
param['Electrolyte diffusivity [m2.s-1]'] = 2.8e-10
param['Cation transference number'] = 0.4

# Multi-sine excitation (see current_function) with amplitude cur_amp.
cur_amp = 0.002
param["Current function [A]"] = current_function(cur_amp)

simulation = pybamm.Simulation(model, parameter_values=param)

# Output grid from 0 to t_max seconds, sampled every dt_sample seconds.
# np.linspace includes both end points, so an exact dt_sample spacing needs
# t_max // dt_sample + 1 points (num=t_max//2 alone gave ~2.07 s spacing).
t_max = 60
dt_sample = 2
t_eval = np.linspace(0, t_max, num=t_max // dt_sample + 1)

In [None]:
# Sweep the initial state of charge and overlay the terminal-voltage response
# for each starting SOC.
# NOTE: after the loop, `V` (and `solution`) hold the result for the LAST
# value, soc_list[-1]; the data-generation cell below reads `V` from here.
fig = go.Figure()

for soc in soc_list:
    simulation.solve(t_eval, initial_soc = soc)
    solution = simulation.solution
    t = solution["Time [s]"]
    V = solution["Terminal voltage [V]"]
    # Plot against the solution's own time points so x and y always have the
    # same length, even if the solve stops before t_eval[-1].
    fig.add_trace(go.Scatter(x=t.entries / 60, y=V.entries, mode='lines',
                             name=f"SOC = {soc:.1f}"))

fig.update_layout(title='SPMe terminal voltage for different initial SOC',
                  legend_title='Initial SOC')
fig.update_xaxes(title='Time (min)')
fig.update_yaxes(title='Voltage (V)')

In [None]:
# Render the SOC-sweep voltage figure assembled in the previous cell.
fig.show()

## Data Generation

In [None]:
# Synthetic "measured" voltage: the response from the LAST solve of the SOC
# sweep above (initial_soc = soc_list[-1]), because `V` is left over from that loop.
# Gaussian noise with σ = 3 mV can be added by setting add_noise = True; it is
# off by default, so the fit below uses the noise-free model output.
add_noise = False
noise_sigma = 3e-3  # standard deviation of the noise [V]
rng = np.random.default_rng(seed=0)  # fixed seed so noisy runs are reproducible
noise = rng.normal(loc=0, scale=noise_sigma, size=V.entries.shape)

# Copy so the data do not alias the solution's array.
voltage_data = (V.entries + noise) if add_noise else V.entries.copy()

## Parameter Estimation

In [None]:
def dfn_model(theta):
    """Simulate the terminal voltage for a trial set of log-diffusivities.

    Despite the name, this runs the SPMe ``model`` defined above, not a DFN.

    Parameters
    ----------
    theta : sequence of 3 floats
        Natural logs of the negative-electrode, positive-electrode and
        electrolyte diffusivities [m2.s-1], in that order. The MCMC samples in
        log space, so they are exponentiated here.

    Returns
    -------
    numpy.ndarray
        Terminal voltage [V] at the times in ``t_eval``. It may be shorter if
        the solver stops early.
    """
    Ds_n, Ds_p, De = theta

    # Work on a copy so that sampling does not overwrite the shared ``param``
    # used to generate the synthetic data.
    trial_param = param.copy()
    trial_param['Negative electrode diffusivity [m2.s-1]'] = np.exp(Ds_n)
    trial_param['Positive electrode diffusivity [m2.s-1]'] = np.exp(Ds_p)
    trial_param['Electrolyte diffusivity [m2.s-1]'] = np.exp(De)

    # The current profile is already stored in ``param`` as a callable, so no
    # ``inputs`` are needed (the old ``inputs={...: cur_app}`` referenced an
    # undefined name and raised NameError).
    # NOTE(review): the data were generated with initial_soc=soc_list[-1], but
    # this solve uses whatever initial concentrations ``param`` holds. Confirm
    # they match; otherwise pass ``initial_soc`` here too.
    # NOTE(review): a new Simulation is built on every call, which is costly for
    # 1e5 MCMC iterations. Declaring the diffusivities as pybamm.InputParameter
    # and passing them via ``inputs`` would let one built model be reused.
    simulation = pybamm.Simulation(model, parameter_values=trial_param)
    simulation.solve(t_eval)

    return simulation.solution["Terminal voltage [V]"].entries

In [None]:
def ss_error(theta, pass_arg):
    """Sum-of-squares misfit between the synthetic data and the model.

    This is the pymcmcstat ``sos_function``.

    Parameters
    ----------
    theta : sequence of float
        Log-diffusivities forwarded to ``dfn_model``.
    pass_arg : object
        Second argument supplied by the sampler (presumably its data
        structure). It is unused; the data come from ``voltage_data``.

    Returns
    -------
    float
        ``sum((voltage_data - y_hat)**2)``. Returns ``inf`` when the simulation
        yields fewer points than the data or non-finite values. Such proposals
        are then rejected, rather than rewarded with a smaller sum over fewer
        residuals.
    """
    y_hat = np.asarray(dfn_model(theta), dtype=float).ravel()
    data = np.asarray(voltage_data, dtype=float).ravel()

    # A truncated (early-terminated) or failed solve must not look like a good fit.
    if y_hat.size < data.size or not np.all(np.isfinite(y_hat[:data.size])):
        return np.inf

    residual = data - y_hat[:data.size]
    return float(np.dot(residual, residual))
    

## Set Up MCMC Simulation

In [None]:
# Starting point for the sampler: every diffusivity begins at 5e-12 m2.s-1,
# stored as a natural log because the chain samples log-diffusivities.
theta0 = {key: np.log(5e-12) for key in ("D_sn", "D_sp", "D_e")}

In [None]:
# Initialize MCMC object
mcstat = MCMC()

# Register the synthetic voltage trace against its sample times.
mcstat.data.add_data_set(x=t_eval, y=voltage_data)

# All three parameters are natural logs of diffusivities [m2.s-1] and share the
# same bounds: 1e-16 to 1e-8 m2.s-1.
log_D_min = np.log(1e-16)
log_D_max = np.log(1e-8)

# Each entry is (plot label, key into theta0). The order must match the
# unpacking in dfn_model: negative electrode, positive electrode, electrolyte.
# The LaTeX labels brace the subscript so that "s,n" renders entirely subscripted.
for label, key in [(r'$D_{s,n}$', 'D_sn'),
                   (r'$D_{s,p}$', 'D_sp'),
                   (r'$D_e$', 'D_e')]:
    mcstat.parameters.add_model_parameter(name=label,
                                          theta0=theta0[key],
                                          minimum=log_D_min,
                                          maximum=log_D_max,
                                          sample=True)

In [None]:
# Objective: the sum-of-squares misfit defined in ss_error.
mcstat.model_settings.define_model_settings(sos_function=ss_error)

# Sampler options: 1e5 chain iterations, and also update the observation-error
# variance (updatesigma) alongside the model parameters.
sampler_options = {'nsimu': 1e5, 'updatesigma': True}
mcstat.simulation_options.define_simulation_options(**sampler_options)

In [None]:
# Run the sampler from scratch, without reusing results of an earlier run.
mcstat.run_simulation(use_previous_results=False)

# Unpack the raw chains and drop the first 5% of iterations as burn-in.
results = mcstat.simulation_results.results
names = results['names']
nsimu = results['nsimu']
fullchain, fulls2chain = results['chain'], results['s2chain']
burnin = int(nsimu/20)
chain, s2chain = fullchain[burnin:, :], fulls2chain[burnin:, :]

# Print summary statistics for the post-burn-in parameter chain.
mcstat.chainstats(chain, results)

# Chain diagnostics: trace panel, marginal densities, pairwise correlations.
panel_size = (4, 4)
mcstat.mcmcplot.plot_chain_panel(chain, names, figsizeinches=panel_size);
mcstat.mcmcplot.plot_density_panel(chain, names, figsizeinches=panel_size);
mcstat.mcmcplot.plot_pairwise_correlation_panel(chain, names,
                                                figsizeinches=panel_size);