In [None]:
using SciMLSensitivity
using DifferentialEquations
using SciMLSensitivity   # or DiffEqSensitivity if you prefer
using Zygote
using Optimisers         # for optimizer & update
using LinearAlgebra
using DifferentialEquations
using Flux
using Plots
using Optimization
using OptimizationOptimisers
using Zygote
using DataFrames

using Random
# Seed the global RNG with a fixed value so repeated runs draw the same random numbers.
# NOTE(review): presumably this is what makes the network initialisation repeatable — confirm.
Random.seed!(1234)
println("All the nessecessary packages have been imported")

In [None]:

# Hodgkin-Huxley Model Parameters (Global Constants)
# Declared `const` so the ODE right-hand sides can read them as typed globals.
# g_Na and E_Na are only used by the reference model; the hybrid UDE replaces the
# sodium term with a neural network.

# Physical Constants
const Cm = 1.0        # μF/cm^2 — membrane capacitance (divides the net current in dV/dt)
const g_Na = 120.0    # mS/cm^2 — sodium conductance scale
const g_K = 36.0      # mS/cm^2 — potassium conductance scale
const g_L = 0.3       # mS/cm^2 — leak conductance
const E_Na = 50.0     # mV — sodium reversal potential
const E_K = -77.0     # mV — potassium reversal potential
const E_L = -54.387   # mV — leak reversal potential

In [None]:
# --- Cell 2: Known Physics & Stimulus ---

# Helper for the rate laws of the form x / (1 - exp(-x / s)).
# The expression is 0/0 at x = 0 (V = -55 mV for α_n, V = -40 mV for α_m); its limit
# there is s, so a first-order Taylor expansion (s + x/2) is used in a tiny window
# around the singularity to keep the rates finite.
_rate_ratio(x, s) = abs(x) < 1e-6 ? s * (1 + x / (2s)) : x / (1 - exp(-x / s))

# Voltage-gated ion channel kinetics (opening rates α, closing rates β; V in mV)
α_n(V) = 0.01 * _rate_ratio(V + 55, 10)
β_n(V) = 0.125 * exp(-(V + 65) / 80)
α_m(V) = 0.1 * _rate_ratio(V + 40, 10)
β_m(V) = 4.0 * exp(-(V + 65) / 18)
α_h(V) = 0.07 * exp(-(V + 65) / 20)
β_h(V) = 1 / (1 + exp(-(V + 35) / 10))

# Steady-state & time-constant functions for the 2D model.
# m and h are evaluated at their steady state; only n keeps its own dynamics,
# relaxing towards n_inf(V) with time constant tau_n(V).
m_inf(V) = α_m(V) / (α_m(V) + β_m(V))
h_inf(V) = α_h(V) / (α_h(V) + β_h(V))
n_inf(V) = α_n(V) / (α_n(V) + β_n(V))
tau_n(V) = 1 / (α_n(V) + β_n(V))

println("Physics of neural dynamics has been defined")



In [None]:
"""
    Stimulus(t)

Return the external current injected at time `t`: a rectangular pulse of amplitude
`20.0` that is on for `10.0 < t < 11.0` (a 1 ms pulse starting at 10 ms) and `0.0`
everywhere else.

Both branches return a `Float64`, so the result type does not depend on the value of `t`.
"""
function Stimulus(t)
    # The lower bound is +10.0: with a negative bound the pulse would already be
    # active at t = 0 and last 11 ms instead of 1 ms.
    return (t > 10.0 && t < 11.0) ? 20.0 : 0.0
end

println(" An extra current form neighbour to generate a pulse")

In [None]:
# --- Cell 3: Data Generation ---

"""
    hodgkin_huxley_reduced!(du, u, p, t)

In-place right-hand side of the reduced (two-state) Hodgkin-Huxley model.

`u` holds the membrane voltage and the potassium gating variable `n`, in that order;
their time derivatives are written to `du[1]` and `du[2]`. `p` is not used. The sodium
gates are not state variables here: they are evaluated at `m_inf(V)` and `h_inf(V)`.
"""
function hodgkin_huxley_reduced!(du, u, p, t)
    Vm, gate_n = u
    drive = Stimulus(t)

    # Ionic currents of the fully known model
    sodium = g_Na * m_inf(Vm)^3 * h_inf(Vm) * (Vm - E_Na)
    potassium = g_K * gate_n^4 * (Vm - E_K)
    leak = g_L * (Vm - E_L)

    # Voltage: net current over capacitance. Gate: first-order relaxation to n_inf.
    du[1] = (drive - sodium - potassium - leak) / Cm
    du[2] = (n_inf(Vm) - gate_n) / tau_n(Vm)
end

# Generate Data
# Simulate the fully known reduced model once; its voltage trace is the training target.
# Initial state: V = -65 with the gate n placed at its steady state for that voltage.
u0_true = [-65.0, n_inf(-65.0)]
tspan = (0.0, 50.0)
# No parameter object is passed: the right-hand side reads the global constants directly.
prob_true = ODEProblem(hodgkin_huxley_reduced!, u0_true, tspan)
# saveat=0.5 stores the solution on a fixed 0.5-spaced time grid over tspan.
sol_true = solve(prob_true, Rodas5P(), saveat=0.5)

# Extract and structure the training data
# Row 1 of the solution is the first state component (V); n is not kept as training data.
data_V = sol_true[1, :]
t_train = sol_true.t

# (Optional) Verify data shape and content
df = DataFrame(t=t_train, V=data_V )
println("Generated Training Data:")
display(first(df, 5))

In [None]:
# Neural network that stands in for the unknown current: scalar (scaled) voltage in,
# scalar current out, one hidden tanh layer of width 15.
# Flux layers are Float32 by default while the ODE state is Float64; feeding Float64
# input into Float32 layers triggers Flux's "input will be converted ... may be very
# slow" warning on every right-hand-side call. Converting the model to Float64 keeps
# the whole solve in a single precision.
U = Flux.f64(Chain(
    Dense(1, 15, tanh; init=Flux.glorot_uniform),
    Dense(15, 1; init=Flux.glorot_uniform),
))

In [None]:
# Extract the trainable parameters (p_nn) and the re-structuring function (re)
# p_nn is the flat parameter vector handed to the optimiser; re(p) rebuilds a network
# with the same layout as U from any such vector.
p_nn, re = Flux.destructure(U)
println("Recruit Constructed. Parameters: ", length(p_nn))


## The hybrid UDE

In [None]:


"""
    ude_dynamics!(du, u, p, t)

In-place right-hand side of the hybrid (universal differential equation) model.

Same state layout as the reference model (`u = [V, n]`), but the sodium current is not
modelled explicitly. A neural network, rebuilt from the flat parameter vector `p` via
`re`, supplies a learned current instead. That term enters with a plus sign where the
reference model subtracts `I_Na`, so the network has to learn the negated sodium current.
"""
function ude_dynamics!(du, u, p, t)
    Vm, gate_n = u

    # Known physics: stimulus, potassium and leak currents
    stim = Stimulus(t)
    potassium = g_K * gate_n^4 * (Vm - E_K)
    leak = g_L * (Vm - E_L)

    # Learned term. The voltage is divided by 100 so the network sees inputs of
    # order one (e.g. -65 mV becomes -0.65).
    scaled_V = Vm / 100.0
    network = re(p)
    learned_current = network([scaled_V])[1]

    # Hybrid dynamics: known currents plus the learned one
    du[1] = (stim + learned_current - potassium - leak) / Cm
    du[2] = (n_inf(Vm) - gate_n) / tau_n(Vm)
end
println("Hybrid Engine Assembled.")

In [None]:
# ---- Forward solve of the UDE (InterpolatingAdjoint sensitivities, Zygote VJPs) ----

# Template problem built once with the initial network parameters; each prediction
# only swaps in a new parameter vector.
prob_nn = ODEProblem(ude_dynamics!, u0_true, tspan, p_nn)

"""
    predict_ude(p)

Solve the hybrid model with the flat network parameters `p` and return the solution
saved at the training time points `t_train`.
"""
function predict_ude(p)
    candidate = remake(prob_nn; p=p)
    adjoint = InterpolatingAdjoint(; autojacvec=ZygoteVJP())
    return solve(candidate, Rodas5P(); saveat=t_train, sensealg=adjoint)
end


# ---- Loss function (keep as Float64) ----
"""
    loss(p)

Sum of squared differences between the predicted voltage trace for parameters `p` and
the training trace `data_V`. Returns the fixed penalty `1e6` when the solver does not
report success.
"""
function loss(p)
    sol = predict_ude(p)
    # A failed solve yields no usable trajectory; answer with a large constant instead.
    sol.retcode == :Success || return 1e6
    return sum(abs2, sol[1, :] .- data_V)
end

println("Objective Functions Defined.")

In [None]:

using JLD2
# File that receives the checkpoint (current parameters + loss history) during training.
# NOTE(review): JLD2 files conventionally use the `.jld2` extension — confirm that
# `load` resolves this `.jld` name to JLD2 in your environment.
const CHECKPOINT_FILE = "neuron_mission_log.jld" # name was previously misspelled CKECKPOINT_FILE


#### Creating a `callback()`
A `callback()` function that records the loss at every iteration, checkpoints the parameters and loss history to disk, and lets training resume from where it stopped.

In [None]:
# Container to hold all losses in memory for plotting later.
# Shared by every training stage, so its length doubles as the global iteration count.
all_losses = Float64[]

"""
    create_callback(phase_name)

Build an optimisation callback `(p, l) -> Bool` for the training stage `phase_name`.

On every call the callback:
1. returns `true` (halt) with a warning when the loss `l` is NaN, recording nothing;
2. appends `l` to `all_losses`;
3. writes the current parameters and the full loss history to `CHECKPOINT_FILE`;
4. prints a status line every 50 recorded iterations;
and then returns `false` so the optimiser continues.

The first argument may be either the raw parameter vector or an optimiser state object
that exposes the vector as `.u`; the vector itself is what gets checkpointed.
"""
function create_callback(phase_name)
    return function (p, l)
        # 1. Check for failure (Exploding Gradients)
        if isnan(l)
            @warn "!!! ABORT MISSION: Loss is NaN in $phase_name !!!"
            return true # halting the optimization
        end

        # 2. Record the loss
        push!(all_losses, l)

        # 3. SAVE TO DISK (The Checkpoint)
        # Newer Optimization.jl versions pass an OptimizationState instead of the bare
        # vector; unwrap it so a resumed run always receives a usable parameter vector.
        params = hasproperty(p, :u) ? p.u : p
        # Saving every iteration on a small model (46 params) is fast and safe.
        jldsave(CHECKPOINT_FILE; params=params, loss_history=all_losses)

        # 4. Status Report (Every 50 steps)
        current_iter = length(all_losses)
        if current_iter % 50 == 0
            println("[$phase_name] Iter: $current_iter | Loss: $l")
        end

        return false # Continue training
    end
end

"""
    load_checkpoint_if_exists(initial_params)

Return the parameters to start training from.

If `CHECKPOINT_FILE` exists, its loss history is restored into `all_losses` and the
saved parameters are returned. Otherwise `initial_params` is returned unchanged and
`all_losses` is left as it is.
"""
function load_checkpoint_if_exists(initial_params)
    if !isfile(CHECKPOINT_FILE)
        println("🌟 NO INTEL: Starting fresh recruit training.")
        return initial_params
    end

    println("📂 INTEL FOUND: Loading previous checkpoint...")
    data = load(CHECKPOINT_FILE)

    # Refill the existing vector in place rather than rebinding the global: every
    # holder of `all_losses` keeps seeing the same Float64 vector.
    append!(empty!(all_losses), data["loss_history"])

    if isempty(all_losses)
        # `last` would throw on an empty history; report it instead of crashing.
        println("   -> Resuming from Iteration 0 (checkpoint has no recorded loss)")
    else
        println("   -> Resuming from Iteration $(length(all_losses)) with Loss $(last(all_losses))")
    end
    return data["params"]
end

In [None]:
# `loss` is declared again in this cell with the same body as the earlier definition,
# so the cell can be evaluated on its own.
"""
    loss(p)

Squared-error mismatch between the voltage predicted with parameters `p` and `data_V`;
`1e6` when the solve did not finish successfully.
"""
function loss(p)
    trajectory = predict_ude(p)
    if trajectory.retcode == :Success
        residual = trajectory[1, :] .- data_V
        return sum(abs2, residual)
    else
        return 1e6
    end
end

println("Objective Functions Defined.")

# Define the OptimizationFunction
# The objective is called with two arguments; the second one is not needed here and is
# ignored. Gradients come from Zygote.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), Optimization.AutoZygote())

Objective Functions Defined.


OptimizationFunction{true, AutoZygote, var"#27#28", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}(var"#27#28"(), AutoZygote(), nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, SciMLBase.DEFAULT_OBSERVED_NO_TIME, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing)

In [None]:
# Stage 3 calls `OptimizationOptimJL.LBFGS()`. Load the package up front: without it
# the call raises an UndefVarError that the Stage 3 `catch` would misreport as
# "instability", and a missing package is better discovered before training starts.
using OptimizationOptimJL

# --- LOAD STARTING STATE ---
# Check if we crashed before. If so, resume from there.
# If not, start with fresh recruit parameters (p_nn).
current_params = load_checkpoint_if_exists(p_nn)

# The stages share one iteration budget: `length(all_losses)` counts every recorded
# iteration across stages, so a resumed run only performs what is still missing.

# ==========================================
# STAGE 1: ROUGH TRAINING (Adam lr=0.05)
# Goal: Quickly find the general shape of the valley
# ==========================================
if length(all_losses) < 1000
    println("\n--- STAGE 1: ROUGH MANEUVERS (Adam 0.05) ---")
    optprob = Optimization.OptimizationProblem(optf, current_params)

    res1 = Optimization.solve(optprob, OptimizationOptimisers.Adam(0.05),
                              callback=create_callback("Stage1"),
                              maxiters=1000 - length(all_losses)) # Only run what's left

    global current_params = res1.u # Update our best params
else
    println("✅ Stage 1 completed previously.")
end

# ==========================================
# STAGE 2: FINE TUNING (Adam lr=0.01)
# Goal: Settle into the stable orbit (limit cycle/spike)
# ==========================================
if length(all_losses) < 2000
    println("\n--- STAGE 2: FINE TUNING (Adam 0.01) ---")
    optprob2 = Optimization.OptimizationProblem(optf, current_params)

    # Iterations still needed to reach 2000; positive by construction of this branch.
    remaining_iters = 2000 - length(all_losses)

    res2 = Optimization.solve(optprob2, OptimizationOptimisers.Adam(0.01),
                              callback=create_callback("Stage2"),
                              maxiters=remaining_iters)
    global current_params = res2.u
else
    println("✅ Stage 2 completed previously.")
end

# ==========================================
# STAGE 3: FINAL POLISH (L-BFGS)
# Goal: Mathematical precision.
# Note: L-BFGS requires `OptimizationOptimJL` and can result in NaNs if unstable.
# ==========================================
println("\n--- STAGE 3: THE POLISH (L-BFGS) ---")

# Checkpointing L-BFGS is harder because it keeps internal memory (Hessian).
# We essentially restart L-BFGS from our best point every time we run this block.
optprob3 = Optimization.OptimizationProblem(optf, current_params)

# Best effort by design: any failure or interrupt here keeps the parameters reached so
# far (they were already checkpointed by the callback) instead of aborting the cell.
try
    res3 = Optimization.solve(optprob3, OptimizationOptimJL.LBFGS(),
                              callback=create_callback("Stage3_LBFGS"),
                              maxiters=500) # Run for 500 steps of polishing
    global current_params = res3.u
catch e
    println("⚠️ L-BFGS encountered instability or interrupt. Parameters saved safely.")
    println("Error: ", e)
end

println("\n🎉 MISSION COMPLETE. Total Iterations: $(length(all_losses))")
if isempty(all_losses)
    # Happens when the very first loss was NaN: nothing was recorded, `last` would throw.
    println("   Final Saved Loss: none recorded")
else
    println("   Final Saved Loss: $(last(all_losses))")
end

🌟 NO INTEL: Starting fresh recruit training.

--- STAGE 1: ROUGH MANEUVERS (Adam 0.05) ---


│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(1 => 15, tanh)
│   summary(x) = 1-element Vector{Float64}
└ @ Flux C:\Users\nirbh\.julia\packages\Flux\uRn8o\src\layers\stateless.jl:60


[Stage1] Iter: 50 | Loss: 2604.732604184652
[Stage1] Iter: 100 | Loss: 1761.3410719238284
[Stage1] Iter: 150 | Loss: 1557.4089777287425
[Stage1] Iter: 200 | Loss: 1314.7096864810308
[Stage1] Iter: 250 | Loss: 991.2153812671663
[Stage1] Iter: 300 | Loss: 754.2350997330925
[Stage1] Iter: 350 | Loss: 680.12886148631
[Stage1] Iter: 400 | Loss: 668.3740903523076
[Stage1] Iter: 450 | Loss: 636.6761430952537
[Stage1] Iter: 500 | Loss: 613.9090528627605
[Stage1] Iter: 550 | Loss: 593.9281586054842
[Stage1] Iter: 600 | Loss: 576.4011893335945
[Stage1] Iter: 650 | Loss: 561.1952538639416
[Stage1] Iter: 700 | Loss: 550.9160434889776
[Stage1] Iter: 750 | Loss: 534.6316304026655
[Stage1] Iter: 800 | Loss: 516.4121136681152
[Stage1] Iter: 850 | Loss: 499.91727665554146
[Stage1] Iter: 900 | Loss: 506.9220221127258
[Stage1] Iter: 950 | Loss: 427.91116827945206
[Stage1] Iter: 1000 | Loss: 263.0439166032882

--- STAGE 2: FINE TUNING (Adam 0.01) ---
[Stage2] Iter: 1050 | Loss: 187.11952826689017
[Stage2]

In [None]:
# Plot the loss history recorded by the training callbacks. The history lives in
# `all_losses`; no variable called `losses` is ever defined in this notebook.
plot(all_losses,
     xlabel="Iteration",
     ylabel="Loss",
     title="Training Loss (Linear Scale)",
     label="Loss",
     lw=2)


UndefVarError: UndefVarError: `losses` not defined in `Main`
Suggestion: check for spelling errors or missing imports.

In [None]:
# 2. Visualizing the Recruit vs The Master
# Run a prediction with the TRAINED parameters. The training cell keeps them in
# `current_params`; its per-stage results are res1/res2/res3, there is no `res`.
final_sol = predict_ude(current_params)

p2 = plot(t_train, data_V, label="Ground Truth", lw=3, alpha=0.5, color=:green)
plot!(p2, final_sol.t, final_sol[1,:], label="UDE Prediction", lw=2, color=:red, linestyle=:dash)
title!(p2, "Neural Network Performance")
# Address p2 explicitly so the labels cannot land on some other "current" plot.
xlabel!(p2, "Time (ms)")
ylabel!(p2, "Voltage (mV)")
display(p2)

UndefVarError: UndefVarError: `res` not defined in `Main`
Suggestion: check for spelling errors or missing imports.

In [None]:
# Combined Plotting
# All stages record into the single `all_losses` vector, so the per-optimizer
# histories are recovered by splitting it. The two Adam stages are budgeted to end at
# iteration 2000; everything recorded after that comes from L-BFGS.
# NOTE(review): if an Adam stage halted early (NaN loss) the real hand-off happened
# before iteration 2000 and this split is only approximate.
n_adam = min(length(all_losses), 2000)
losses_adam = all_losses[1:n_adam]
losses_lbfgs = all_losses[(n_adam + 1):end]

total_iterations = 1:(length(losses_adam) + length(losses_lbfgs))

# Setup the canvas
p_combined = plot(title="Dual-Phase Training (Adam -> L-BFGS)",
                  xlabel="Iteration", ylabel="Loss (Log Scale)", yaxis=:log)

# Plot Phase 1: Adam (Blue)
plot!(p_combined, 1:length(losses_adam), losses_adam,
      label="Phase I: Adam (Coarse)", color=:blue, lw=2)

# Plot Phase 2: L-BFGS (Red)
# We shift the x-axis so it starts exactly where Adam ended.
# Skipped when L-BFGS recorded nothing (e.g. Stage 3 failed immediately).
if !isempty(losses_lbfgs)
    range_phase2 = (length(losses_adam) + 1):length(total_iterations)
    plot!(p_combined, range_phase2, losses_lbfgs,
          label="Phase II: L-BFGS (Fine)", color=:red, lw=2)
end

# Add a vertical line to mark the hand-off
vline!(p_combined, [length(losses_adam)], label="Optimizer Switch", color=:black, linestyle=:dash)

display(p_combined)

UndefVarError: UndefVarError: `losses_adam` not defined in `Main`
Suggestion: check for spelling errors or missing imports.