Import necessary packages

In [None]:
using Random
using CSV, DataFrames
using Plots
using OrdinaryDiffEq, DiffEqSensitivity
using Zygote, ForwardDiff
using Flux
using Flux.Optimise: update!
using Flux.Losses: mae, mse
using BSON: @save, @load
using ProgressBars, Printf
using LinearAlgebra
using Statistics

Load and visualize experimental data

In [None]:
# Load the calorimetry dataset from a hard-coded location on the author's machine.
file_path = raw"C:\Users\Arjun\Desktop\CRNN\Anode_from_scratch\dataset.csv"
data = CSV.read(file_path, DataFrame)

# Column layout (as annotated by the author):
#   1 = time [min], 2 = sample temperature [°C], 3 = heat flow [W/kg]
# NOTE(review): getsampletemp divides the heating rate by 60 (K/min -> K/s), i.e. it
# treats times as seconds; confirm which unit column 1 is really in.
time_exp, temp_exp, heat_flow = data[:, 1], data[:, 2], data[:, 3];

# Quick visual check of the raw heat-flow trace against sample temperature.
plot(temp_exp, heat_flow;
     xlabel="Temperature", ylabel="Heat Flow", title="Temperature vs. Heat Flow",
     lw=2, legend=false)


MODEL PARAMS (training schedule, ODE-solver tolerances, optimizer settings)

In [None]:
# Fix the global RNG so the random perturbation of the initial parameters is reproducible.
Random.seed!(1234);

# --- Training schedule -----------------------------------------------------------
n_epochs = 60_000;   # number of optimisation epochs
n_plot = 100;        # run diagnostics/plots every n_plot epochs

# --- ODE solver settings ---------------------------------------------------------
tstep = 0.4;         # timestep chosen for simulations
atol = 1e-5;         # absolute tolerance
rtol = 1e-2;         # relative tolerance
maxiters = 10_000;   # iteration cap for the ODE solver

# --- Optimiser settings ----------------------------------------------------------
grad_max = 100.0;    # gradient-norm clipping threshold
lr_adam = 0.001;     # AdamW learning rate
w_decay = 1.0e-7;    # AdamW weight decay

TRAINABLE KINETIC PARAMETERS SETUP

In [None]:
# TRAINABLE KINETIC PARAMETERS DEFINITION + INITIALIZATION
# Layout of the flat trainable vector p (decoded by p2vec):
#   p[1:5]   pre-exponential (lnA) factors, one per reaction
#   p[6:10]  activation energies
#   p[11:15] temperature exponents b
#   p[16:20] reaction enthalpies
#   p[21:25] reaction orders
#   p[26]    alpha conversion factor
#   p[27]    slope (multiplies the lnA factors inside p2vec)
np = 27;  # 25 kinetic parameters (5 groups x 5 reactions) + alpha + slope
p = randn(np) .* 1.e-2;  # small Float64 noise added to the initial guesses below

# Scalar lower/upper bounds, built from eltype(p) so they follow p's number type
lb = zero(eltype(p)) + 1e-5;
ub = zero(eltype(p)) + 1e1;

# Initial guesses; groups 1-5 are added on top of the noise
p[1:5] .+= 35;                              # pre-exponential factors
Ea_IC = [0.3, 0.6, 1.4, 1.2, 1];            # initial activation energies
p[6:10] .+= Ea_IC;
p[11:15] .+= 1;                             # temperature exponents b
p[16:20] .+= [0.2, 1, 0.8, 0.5, 0.3];       # reaction enthalpies
p[21:25] .+= 1;                             # reaction orders
# The last two entries are set outright (their random noise is discarded)
p[26] = 0.5;                                # alpha conversion factor
p[27] = 0.1;                                # slope


Decode the raw trainable vector into kinetic parameters, clamping each group to a bounded range to prevent under- or over-estimation

In [None]:
"""
    p2vec(p)

Decode the flat trainable vector `p` into bounded kinetic-parameter groups.

Returns the tuple `(w_in_Ea, w_in_b, w_delH, w_in_order, w_A, w_alpha)`:
- `w_in_Ea`    : `|p[6:10]|` clamped to `[0, 3]`
- `w_in_b`     : `p[11:15]`, unclamped (a copy)
- `w_delH`     : `100 * |p[16:20]|` clamped to `[10, 3e7]`
- `w_in_order` : `p[21:25]` clamped to `[0.01, 10]`
- `w_A`        : `p[1:5]` scaled by `20 * (10 * p[end])`, clamped to `[0, 50]`
- `w_alpha`    : scalar `p[26]` clamped to `[0, 1]`
"""
function p2vec(p)
    lo = zero(eltype(p))

    # p[end] acts as a slope that rescales the pre-exponential group.
    slope = p[end] * 10.0
    w_A = clamp.(p[1:5] .* (slope * 20.0), lo, 50)

    w_in_Ea = clamp.(abs.(p[6:10]), lo, 3)
    w_in_b = p[11:15]
    w_delH = clamp.(abs.(p[16:20]) .* 100, 10, 30_000_000)
    w_in_order = clamp.(p[21:25], 0.01, 10)
    w_alpha = clamp(p[26], lo, one(eltype(p)))

    return w_in_Ea, w_in_b, w_delH, w_in_order, w_A, w_alpha
end


In [None]:
"""
    display_p(p)

Print the decoded kinetic parameters to stdout: reaction orders, activation
energies, temperature exponents b, enthalpies and lnA. Each group is shown as a
1x5 row rounded to three decimals. The alpha factor is not printed.
"""
function display_p(p)
    Ea, b, delH, order, lnA, _ = p2vec(p)
    println("species (column) reaction (row)")
    sections = (
        "rxn ord" => order,
        "\nEa" => Ea,
        "\nb" => b,
        "\ndelH" => delH,
        "\nlnA" => lnA,
    )
    for (header, vals) in sections
        println(header)
        show(stdout, "text/plain", round.(vals', digits=3))
    end
end

Function to convert ODE time points to sample temperatures for a linear heating ramp with heating rate β (in K/min)

In [None]:
"""
    getsampletemp(t, T0, beta)

Return the sample temperature `T0 + (beta / 60) * t` for a linear heating ramp.

`beta` is a heating rate in K/min. Dividing by 60 converts it to K/s, so `t` is
expected in seconds. `t` may be a scalar (returns a scalar) or an array of times
(returns a Vector).

NOTE(review): the dataset's time column is annotated as minutes where it is loaded.
If those times reach this function unchanged, the /60 would be off by 60x — confirm units.
"""
getsampletemp(t, T0, beta) = @. T0 + t * (beta / 60)

In [None]:
# Kinetic model
# State variables: [xSEI, xLiC1, xLiC2, xAnElIV, xAnElV]
# Dynamics: [dxSEI/dt, dxLiC1/dt, dxLiC2/dt, dxAnElIV/dt, dxAnElV/dt]


In [None]:
"""
    anode_crnn_factory(p)

Build the in-place ODE right-hand side `anode_crnn!(du, u, p, t)` for the
five-reaction anode CRNN. It closes over the kinetic parameters decoded from `p`
by `p2vec`.

State layout: `u = [xSEI, xLiC1, xLiC2, xAnElIV, xAnElV]`. This is the same
layout `HRR_getter` uses, where unreacted lithiated graphite is `1 - u[2] - u[3]`.

Each reaction i runs at
`rate_i = exp(lnA_i + b_i*ln(T) - Ea_i*1e5/(8.314*T) + n_i*ln(x_i))`.
Here `x_i` is that reaction's reactant amount, clamped to `[1e-5, 10]` before the
log, and `T` is the ramped sample temperature at time `t`. The RHS sets
`du[i] = rate_i`.

The RHS ignores its own `p` argument; the parameters captured when the factory
is called are used instead. This is why ForwardDiff works when the factory is
called with Dual-valued `p`.

NOTE(review): every `du[i]` is a positive rate, so the states that also act as
reactant amounts (`u[1]`, `u[4]`, `u[5]`) grow over time rather than being
consumed — confirm this is the intended model.
"""
function anode_crnn_factory(p)
    w_in_Ea, w_in_b, _, w_in_order, w_A, _ = p2vec(p)
    beta = one(eltype(p)) * 5                # heating rate handed to getsampletemp
    T0 = one(eltype(p)) * (80 + 273.15)      # initial sample temperature (80 °C in K)
    R = -one(eltype(p)) / 8.314              # holds -1/R_gas, so R / T == -1/(R_gas*T)
    lb_local = zero(eltype(p)) + 1e-5        # floor on reactant amounts before the log

    function anode_crnn!(du, u, p, t)
        # Unreacted lithiated graphite: what reactions 2 and 3 have not yet consumed
        # (u[2] = xLiC1, u[3] = xLiC2).
        xLiC = 1.0 - u[2] - u[3]

        # Reactant amount seen by each reaction; reactions 2 and 3 share xLiC.
        conc = [u[1], xLiC, xLiC, u[4], u[5]]
        logX = @. log(clamp(conc, lb_local, 10.0))

        T = getsampletemp(t, T0, beta)

        # ln(rate_i) = b_i*ln(T) - Ea_i*1e5/(8.314*T) + n_i*ln(x_i) + lnA_i
        log_rates = w_in_b .* log(T) .+ (R / T) .* (w_in_Ea .* 1e5) .+
                    w_in_order .* logX .+ w_A

        # SEI-thickness inhibition of reaction 2 is currently disabled; original form:
        #     log_rates[2] -= (0.05 - u[1]) / 0.033

        du .= exp.(log_rates)
        return nothing
    end
    return anode_crnn!
end


In [None]:
"""
    HRR_getter(times, u_outputs, p)

Evaluate the five CRNN reaction rates along a saved trajectory.

# Arguments
- `times`: time points at which the states were saved (length N).
- `u_outputs`: state matrix with one row per state and one column per saved time
  (5xN; the callers in this file pass `sol[:, :]`).
- `p`: flat trainable parameter vector, decoded with `p2vec`.

# Returns
An Nx5 matrix whose column i holds the rate of reaction i:
`exp(b_i*ln(T) - Ea_i*1e5/(8.314*T) + n_i*ln(x_i) + lnA_i)`. The reactant amounts
`x_i` are clamped to `[1e-5, 10]` before the log.
"""
function HRR_getter(times, u_outputs, p)
    Ea, b, _, order, lnA, _ = p2vec(p)
    unit = one(eltype(p))
    T0 = unit * (80 + 273.15)
    beta = unit * 5
    neg_inv_R = -unit / 8.314            # -1/R_gas, so neg_inv_R / T == -1/(R_gas*T)
    floor_x = zero(eltype(p)) + 1e-5

    # Reactant amount for each reaction at each saved time. Reactions 2 and 3 both
    # consume unreacted lithiated graphite, 1 - u[2] - u[3].
    x = zeros(eltype(p), length(times), 5)
    x[:, 1] .= u_outputs[1, :]
    x[:, 2] .= 1 .- u_outputs[2, :] .- u_outputs[3, :]
    x[:, 3] .= x[:, 2]
    x[:, 4] .= u_outputs[4, :]
    x[:, 5] .= u_outputs[5, :]
    log_x = @. log(clamp(x, floor_x, 10.0))

    T = getsampletemp(times, T0, beta)
    Ea_scaled = Ea * 1e5

    # (N-vector) x (1x5 rows) broadcasts to Nx5. The SEI-thickness correction to
    # reaction 2 is currently disabled.
    return exp.(log.(T) .* b' .+ (neg_inv_R ./ T) .* Ea_scaled' .+
                log_x .* order' .+ lnA')
end

In [None]:
"""
    pred_n_ode(p)

Integrate the anode CRNN over the experimental time grid `data[:, 1]` and
compute the per-reaction heat release.

# Returns
`(heat_rel, times, sol)`:
- `heat_rel`: Nx5 matrix; column i is reaction i's rate times its enthalpy
  `w_delH[i]`.
- `times`: the time points `heat_rel` refers to. This is the experimental grid
  when the solve completes; otherwise it is the shorter `sol.t`.
- `sol`: the raw ODE solution.

A failed solve logs a warning instead of throwing.
"""
function pred_n_ode(p)
    ts = @view data[:, 1]
    tspan = (ts[1], ts[end])

    # Initial state [xSEI, xLiC1, xLiC2, xAnElIV, xAnElV], promoted to eltype(p)
    # so ForwardDiff Dual numbers propagate through the solve.
    u0 = one(eltype(p)) .* [1.0, 0.0, 0.0, 1.0, 1.0]

    anode_crnn! = anode_crnn_factory(p)
    prob = ODEProblem(anode_crnn!, u0, tspan)

    sol = solve(
        prob,
        AutoTsit5(Rosenbrock23()),
        p = p,
        saveat = ts,
        sensealg = ForwardSensitivity(autojacvec = true),
        maxiters = maxiters,
        abstol = atol,
        reltol = rtol,
    )

    # Report failure before post-processing. Previously a truncated solution hit
    # a dimension mismatch in HRR_getter before this warning could run.
    if sol.retcode != :Success
        @warn "Solver failed with retcode: $(sol.retcode)"
    end

    # An aborted solve saves fewer points than requested. Evaluate the rates on
    # the times actually reached so the rate matrix lines up with the states.
    times = length(sol.t) == length(ts) ? ts : sol.t

    w_delH = p2vec(p)[3]
    heat_rel = HRR_getter(times, sol[:, :], p) .* w_delH'

    return heat_rel, times, sol
end


In [None]:
"""
    loss_neuralode(p)

Return the mean absolute error between the predicted total heat release and the
measured heat flow `data[:, 3]`. The prediction is summed over the five reactions
at each time point.
"""
function loss_neuralode(p)
    heat_rel = first(pred_n_ode(p))
    total_pred = vec(sum(Array(heat_rel); dims = 2))
    measured = view(data, :, 3)
    return mae(total_pred, measured)
end


In [None]:
"""
    plot_sol(HR1, HR2, HR3, HR4, HR5, exp_data, Tlist)

Build a two-panel figure comparing the CRNN heat-release prediction with
experiment, and return the combined plot.

- Top panel: experimental points (scatter), each per-reaction heat-release curve
  `HR1`...`HR5`, and their sum.
- Bottom panel: the summed prediction on its own.

All series are plotted against `Tlist`, the sample temperature at each saved time
point (in K, as produced by `getsampletemp` with `T0 = 80 + 273.15`).
"""
function plot_sol(HR1, HR2, HR3, HR4, HR5, exp_data, Tlist)
    peaks = (HR1, HR2, HR3, HR4, HR5)
    peak_labels = (
        "Peak 1: SEI growth",
        "Peak 2: SEI Decomposition",
        "Peak 3: Li-EC reaction",
        "Peak 4: An-El IV reaction",
        "Peak 5: An-El V reaction",
    )
    sol = HR1 + HR2 + HR3 + HR4 + HR5

    plt = plot(Tlist, exp_data, seriestype = :scatter, label = "Exp")
    for (hr, lbl) in zip(peaks, peak_labels)
        plot!(plt, Tlist, hr, lw = 3, legend = :left, label = lbl)
    end
    plot!(plt, Tlist, sol, lw = 3, legend = :left, label = "CRNN sum")

    # The x data is temperature, not time.
    xlabel!(plt, "Temperature [K]")
    ylabel!(plt, "HRR")

    p2 = plot(Tlist, sol, lw = 2, legend = :right, label = "Heat release")
    xlabel!(p2, "Temperature [K]")
    # NOTE(review): the measured heat-flow column is annotated as W/kg where it is
    # loaded; confirm which unit label is correct here.
    ylabel!(p2, "W/g")

    plt = plot(plt, p2, framestyle = :box, layout = @layout [a; b])
    plot!(plt, size = (800, 800))
    return plt
end

"""
    cbi(p)

Plot the current CRNN prediction (per-reaction and total heat release) against
the measured heat flow, and save it as `pred_exp.png` in the figures directory.
The directory is created if missing. Returns `false`.
"""
function cbi(p)
    heat_rel, times, _ = pred_n_ode(p)

    # pred_n_ode already returns per-reaction heat release (rates .* w_delH'), so
    # column i equals HRR[:, i] * w_delH[i]; no need to re-evaluate the rates.
    HR1, HR2, HR3, HR4, HR5 = [heat_rel[:, i] for i in 1:5]

    # Align the measurements with however many points the prediction covers.
    exp_data = data[1:length(times), 3]

    T0 = 80 + 273.15
    beta = 5  # K/min
    Tlist = getsampletemp(times, T0, beta)

    plt = plot_sol(HR1, HR2, HR3, HR4, HR5, exp_data, Tlist)

    fig_dir = raw"C:\Users\Arjun\Desktop\CRNN\Anode_from_scratch\figs"
    mkpath(fig_dir)  # png fails if the directory does not exist yet
    png(plt, joinpath(fig_dir, "pred_exp.png"))

    return false
end


In [None]:
# Per-epoch histories appended to by the training callback `cb`.
# Typed as Float64 so they are not Vector{Any}.
l_loss_train = Float64[]
list_grad = Float64[]
iter = 1  # callback invocation counter (1-based)

"""
    plot_loss(l_loss_train; yscale = :log10, grad_hist = list_grad)

Save a side-by-side plot of the training-loss history and the gradient-norm
history as `loss_grad.png` in the loss directory. The directory is created if
missing.

# Arguments
- `l_loss_train`: per-epoch training losses.
- `yscale`: y-axis scale used for both panels.
- `grad_hist`: per-epoch gradient norms. Defaults to the global `list_grad` that
  the training callback appends to.
"""
function plot_loss(l_loss_train; yscale = :log10, grad_hist = list_grad)
    plt_loss = plot(l_loss_train, yscale = yscale, label = "train")
    plt_grad = plot(grad_hist, yscale = yscale, label = "grad_norm")
    xlabel!(plt_loss, "Epoch")
    ylabel!(plt_loss, "Training Loss")
    xlabel!(plt_grad, "Epoch")
    ylabel!(plt_grad, "Gradient Norm")

    plt_all = plot(plt_loss, plt_grad, legend = :top, framestyle = :box)
    plot!(
        plt_all,
        size = (1000, 450),
        xtickfontsize = 11,
        ytickfontsize = 11,
        xguidefontsize = 12,
        yguidefontsize = 12,
    )

    loss_dir = raw"C:\Users\Arjun\Desktop\CRNN\Anode_from_scratch\loss"
    mkpath(loss_dir)  # png fails if the directory does not exist yet
    png(plt_all, joinpath(loss_dir, "loss_grad.png"))
    return nothing
end

# Training callback, called once per epoch with the updated parameters, their loss
# and the pre-update gradient norm.
# - Records the loss and gradient norm in the global histories.
# - Snapshots the best-so-far parameters into the global `p_opt`.
# - Every `n_plot` epochs (and on the first call), prints parameters, saves the
#   prediction figure and saves the loss curves.
# Returns the incremented global `iter`.
cb = function (p, loss_train, g_norm)
    global l_loss_train, list_grad, iter

    # Snapshot whenever this is the best loss seen so far. This includes the first
    # epoch, when the history is still empty; otherwise p_opt would never be set
    # if epoch 1 stayed the best.
    if isempty(l_loss_train) || loss_train < minimum(l_loss_train)
        global p_opt = copy(p)  # p is a flat numeric vector; a shallow copy suffices
    end
    push!(l_loss_train, loss_train)
    push!(list_grad, g_norm)

    if iter % n_plot == 0 || iter == 1
        display_p(p)
        if @isdefined p_opt
            @printf("Parameters of lowest yet loss:\n")
            display_p(p_opt)
        end

        @printf("Min Loss train: %.2e\n", minimum(l_loss_train))
        cbi(p)
        plot_loss(l_loss_train; yscale = :log10)
    end
    iter += 1
end


In [None]:
# Full-batch AdamW training loop with ForwardDiff gradients and norm clipping.
opt = ADAMW(lr_adam, (0.9, 0.999), w_decay)

epochs = ProgressBar(1:n_epochs)
loss_epoch = zeros(Float64, n_epochs)   # loss after each epoch's update
grad_norm = zeros(Float64, n_epochs)    # L2 norm of the raw (unclipped) gradient

for epoch in epochs
    global p

    grad = ForwardDiff.gradient(loss_neuralode, p)
    grad_norm[epoch] = norm(grad, 2)

    if all(isfinite, grad)
        # Rescale to norm grad_max when the gradient is too large.
        if grad_norm[epoch] > grad_max
            grad = grad ./ grad_norm[epoch] .* grad_max
        end
        update!(opt, p, grad)
    else
        # A NaN/Inf gradient (e.g. from a diverged solve) would poison every
        # parameter through update!, so keep the current p for this epoch.
        @warn "Non-finite gradient at epoch $epoch; skipping parameter update"
    end

    loss_epoch[epoch] = loss_neuralode(p)

    set_description(
        epochs,
        @sprintf(
            "Epoch %d: Loss: %.2e Grad Norm: %.2e",
            epoch,
            loss_epoch[epoch],
            grad_norm[epoch],
        ),
    )

    cb(p, loss_epoch[epoch], grad_norm[epoch])
end
