<a href="https://colab.research.google.com/github/RCortez25/Scientific-Machine-Learning/blob/main/Differential_equations/SIR(NODE)_(autonomous%2C_loss_only_wrapper).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction
---

# Code walkthrough
---

## Ground truth
---

```julia
# A
using ComponentArrays, Lux, DiffEqFlux, OrdinaryDiffEq, Optimization, OptimizationOptimJL,
    OptimizationOptimisers, Random, Plots, ModelingToolkit, Statistics

# B
random_number_generator = Random.default_rng()
Random.seed!(random_number_generator, 42)

# C
#------ Generate ground-truth data
# Autonomous system where t is the independent variable
@parameters β γ N
@independent_variables t
@variables S(t) I(t) R(t)
Dt = Differential(t)

eqs = [
    Dt(S) ~ -(β*S*I)/N,
    Dt(I) ~ ((β*S*I)/N) - γ*I,
    Dt(R) ~ γ*I
]

@named system = ODESystem(eqs, t, [S, I, R], [β, γ, N])
simplified = structural_simplify(system)

parameter_map = Dict(β => 0.3, γ => 0.1, N => 1000)
initial_conditions = Dict(I => 1, R => 0, S => 1000 - 1 - 0)
timespan = (0.0, 160.0) # in days

problem = ODEProblem(simplified,
                     merge(initial_conditions, parameter_map),
                     timespan)

solution = solve(problem, Tsit5(), saveat=1)
ground_truth = Array(solution)
```

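**A** - Importing the packages
*   `ComponentArrays` packs the network parameters into a single flat vector that keeps their names and structure
*   `Lux` defines the neural network
*   `DiffEqFlux` provides `NeuralODE`
*   `OrdinaryDiffEq` provides ODE solvers such as `Tsit5`
*   `Optimization` is the common interface for defining and solving the training problem
*   `OptimizationOptimJL` is the adapter for `Optim.jl`, used here for BFGS
*   `OptimizationOptimisers` is the adapter for `Optimisers.jl`, used here for Adam
*   `Random` seeds random number generators for reproducibility
*   `Plots` is used for plotting
*   `ModelingToolkit` writes the SIR equations symbolically to generate the ground truth
*   `Statistics` provides `mean`, used in the loss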

**B** - Create a seeded random number generator for reproducibility

**C** - The rest of the code solves the ODE system and generates the ground-truth data against which the NN results are compared. `ground_truth` stores the integrator's output as a $3 \times 161$ matrix, with one row per state and one column per saved time point. `saveat=1` saves the solution once per day, from day 0 to day 160.
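The ground-truth step can also be written in plain Julia without ModelingToolkit, which makes the equations and the daily save grid concrete. This is a minimal sketch, assuming a hand-written fixed-step fourth-order Runge-Kutta (RK4) integrator rather than the adaptive `Tsit5` used above. `sir_rhs` and `rk4_trajectory` are hypothetical helpers, not part of any package.

```julia
# Hypothetical plain-Julia version of the SIR right-hand side used in the MTK model
function sir_rhs(u, p)
    S, I, R = u
    β, γ, N = p
    infection = β * S * I / N          # new infections per day
    return [-infection, infection - γ * I, γ * I]
end

# Fixed-step RK4 that stores every step (dt = 1 day mimics saveat = 1)
function rk4_trajectory(f, u0, p, tspan, dt)
    ts = tspan[1]:dt:tspan[2]
    U = zeros(length(u0), length(ts))
    U[:, 1] = u0
    u = copy(u0)
    for k in 2:length(ts)
        k1 = f(u, p)
        k2 = f(u .+ (dt / 2) .* k1, p)
        k3 = f(u .+ (dt / 2) .* k2, p)
        k4 = f(u .+ dt .* k3, p)
        u = u .+ (dt / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        U[:, k] = u
    end
    return collect(ts), U
end

ts_rk4, U_rk4 = rk4_trajectory(sir_rhs, [999.0, 1.0, 0.0], (0.3, 0.1, 1000.0), (0.0, 160.0), 1.0)
population = vec(sum(U_rk4; dims = 1))   # S + I + R on each day
```

The three derivatives sum to zero, so the total population should stay at $N = 1000$. Runge-Kutta methods preserve such linear invariants up to rounding, which `population` lets you check.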

## Neural ODE
---

```julia
# A
const input_dimension = 3 # S, I, R
const output_dimension = 3 # Derivatives of S, I, R

# B
layer_0 = Lux.Dense(input_dimension, 32, Lux.tanh)
layer_1 = Lux.Dense(32, 32, Lux.tanh)
layer_2 = Lux.Dense(32, output_dimension)

# C
NN = Lux.Chain(
    layer_0,
    layer_1,
    layer_2
)

# D
NN64 = Lux.f64(NN)

# E
parameters, state = Lux.setup(random_number_generator, NN64)
```

**A** - Defining the input and output dimensions. `const` marks them as constant global bindings, which guards against accidentally reassigning the dimensions later. In a NeuralODE the network predicts the derivative of the state. Because this system is autonomous, time is not an input: the inputs are $S,I,R$, in that order, and the outputs are their derivatives $\dot{S},\dot{I},\dot{R}$.

**B** - Definition of the NN architecture, in this case, 3 layers:
*   `layer_0` is the input layer. It receives 3 inputs ($S,I,R$) and outputs 32 numbers (the width 32 is arbitrary and can be changed), using the `tanh` activation function.
*   `layer_1` is a hidden layer with 32 inputs (from the previous layer) and 32 outputs, also using the `tanh` activation function.
*   `layer_2` is the output layer. It receives 32 inputs (from the previous layer) and outputs 3 numbers, namely the derivatives stated before. It has no activation function, so its outputs are not restricted to a bounded range.

**C** - Creation of the NN by chaining the defined layers.

**D** - Converts the network to `Float64`, so that it works in the same precision as the `Float64` initial conditions and time span used by the solver and mixed-precision arithmetic is avoided.

**E** - Initializing the network. It returns two objects:
*   `parameters`: all trainable weights and biases, to be optimized later during training.
*   `state`: all non-trainable internal states that some layers keep (BatchNorm running means, etc.). For this chain of `Dense` layers it holds no information.
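The layer sizes determine how many trainable numbers `parameters` holds. A dense layer mapping `in` to `out` has an `out × in` weight matrix plus `out` biases. The sketch below counts them and runs one forward pass with the same 3 → 32 → 32 → 3 shapes. It assumes plain matrices in place of Lux layers, with random normal weights rather than Lux's default initializer. `dense_param_count` and `toy_mlp` are hypothetical names.

```julia
using Random

# Parameters of a dense layer mapping n_in -> n_out: an n_out × n_in weight matrix plus n_out biases
dense_param_count(n_in, n_out) = n_out * n_in + n_out

layer_sizes = [(3, 32), (32, 32), (32, 3)]   # same shapes as layer_0, layer_1, layer_2
total_parameters = sum(dense_param_count(n_in, n_out) for (n_in, n_out) in layer_sizes)

# Hypothetical plain-matrix network; weights are random normal, NOT Lux's default initializer
rng_demo = Xoshiro(42)
weights_demo = [randn(rng_demo, n_out, n_in) for (n_in, n_out) in layer_sizes]
biases_demo = [zeros(n_out) for (_, n_out) in layer_sizes]

function toy_mlp(u)
    h = tanh.(weights_demo[1] * u .+ biases_demo[1])   # 3 -> 32, tanh
    h = tanh.(weights_demo[2] * h .+ biases_demo[2])   # 32 -> 32, tanh
    return weights_demo[3] * h .+ biases_demo[3]       # 32 -> 3, no activation
end

du_demo = toy_mlp([999.0, 1.0, 0.0])   # a guess for (dS/dt, dI/dt, dR/dt)
```

The count comes to 128 + 1056 + 99 = 1283 trainable parameters. Inputs of order 1000, like $S$ at the start of the epidemic, are likely to push the first layer's `tanh` units into saturation. This is one reason such inputs are often rescaled (e.g. divided by $N$), which the notebook does not do.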

```julia
# A
neural_ODE_problem = NeuralODE(NN64, timespan, Tsit5(); saveat = 1)

# B
function neural_ode_predictions(parameters)
    u0 = Float64[
        initial_conditions[S],
        initial_conditions[I],
        initial_conditions[R],
    ]
    trajectory, new_state = neural_ODE_problem(u0, parameters, state)
    return Array(trajectory)
end
```

**A** - Builds the solver layer. It prepares everything so that, when called, it integrates an ODE whose right-hand side (RHS) is the neural network passed to it. It:

*   Sets the ODE's RHS (the ODE rule) to the neural network passed to it.
*   Stores how to solve it: the time span, the algorithm, the time points at which to save the solution, and, optionally, solver tolerances such as `reltol` and `abstol` (not used in this example).

This is just the constructor: everything is configured, but nothing is run yet. The returned object is called as `neural_ODE_problem(u0, parameters, state)`, where `u0` is the vector of initial conditions.

**B** - Function for a NeuralODE forward pass. It takes only the trainable parameters of the NN as an argument. The layer state (BatchNorm statistics, etc.) is read from the global `state` defined above. In this case the network is stateless (just `Dense` + `tanh`), so the state carries no information.

*   `u0` is the vector of initial conditions. Earlier, these were defined in a dictionary, so here they are unpacked into a vector in the order $S, I, R$.
*   The `NeuralODE` is run, passing it the initial conditions, the trainable parameters and the state. This integrates the ODE with the NN as the RHS, using the solver, time span, etc., defined above. It returns two objects:
    *   `trajectory`: the solution. When the integration runs to the end, it converts to a matrix of shape `(3, 161)`: 3 rows (one each for $S$, $I$, and $R$) and 161 columns, one for each saved day from 0 to 160 (`saveat = 1`).
    *   `new_state`: the updated state (BatchNorm statistics, etc.). Since the network is stateless, this variable is not needed here.

Changing `parameters` amounts to changing the NN weights, and this changes the predicted values and the integrated trajectory.

In the end, the function returns `trajectory` as a Julia `Array`, discarding `new_state` in this particular case.
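The idea behind `NeuralODE` can be sketched without DiffEqFlux: wrap a parameterized model in the autonomous `f(u, p, t)` form and integrate it on the same daily grid. This is a minimal sketch, assuming a single linear layer (`toy_model`) as a stand-in for the network and explicit Euler steps instead of `Tsit5`. All names are hypothetical. The sketch shows the `(3, 161)` trajectory shape and shows that changing the parameters changes the trajectory.

```julia
# Hypothetical stand-in for the network: one linear layer u -> W*u + b
toy_model(u, p) = p.W * u .+ p.b

# Autonomous RHS in the f(u, p, t) form used by ODE solvers: t is accepted but ignored
neural_rhs(u, p, t) = toy_model(u, p)

# Explicit Euler with dt = 1 day, storing every step (mimics saveat = 1)
function euler_trajectory(f, u0, p, tspan, dt)
    ts = tspan[1]:dt:tspan[2]
    U = zeros(length(u0), length(ts))
    U[:, 1] .= u0
    for k in 2:length(ts)
        U[:, k] .= U[:, k - 1] .+ dt .* f(U[:, k - 1], p, ts[k - 1])
    end
    return U
end

u0_demo = [999.0, 1.0, 0.0]
toy_params = (W = [-0.01 0.0 0.0; 0.01 -0.1 0.0; 0.0 0.1 0.0], b = zeros(3))
frozen_params = (W = zeros(3, 3), b = zeros(3))   # weights that predict du/dt = 0

trajectory_demo = euler_trajectory(neural_rhs, u0_demo, toy_params, (0.0, 160.0), 1.0)
trajectory_frozen = euler_trajectory(neural_rhs, u0_demo, frozen_params, (0.0, 160.0), 1.0)
```

With `frozen_params` the predicted derivative is zero, so every column stays equal to `u0_demo`. With `toy_params` the state evolves.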

```julia
# A
function loss_neuralode(parameters)
    predicted_trajectory = neural_ode_predictions(parameters)
    @assert size(predicted_trajectory) == size(ground_truth)
    loss = mean((ground_truth .- predicted_trajectory).^2)
    return loss, predicted_trajectory
end

# B
only_loss(x) = first(loss_neuralode(x))
```

**A** - Creation of the function that calculates the loss. It accepts the trainable parameters of the network.
*   `predicted_trajectory = neural_ode_predictions(parameters)` calls `neural_ode_predictions`, which runs the `NeuralODE` and returns the integrated trajectory.
*   `@assert size(predicted_trajectory) == size(ground_truth)` ensures both arrays have the same shape. Without this check, a trajectory of the wrong shape (for example, one cut short because the solver stopped early) would fail with a less informative broadcasting error. If one of its dimensions were 1, it would be silently broadcast and give a wrong loss.
*   `loss = mean((ground_truth .- predicted_trajectory).^2)` calculates the mean squared error (MSE). Both `ground_truth` and `predicted_trajectory` are matrices, so the broadcast dot `.` makes the operations element-wise: `.-` is element-wise subtraction and `.^` is element-wise exponentiation. This way each state on each day is compared with its counterpart: Day 1 with Day 1, Day 2 with Day 2, etc.

In the end, we return both the loss and the trajectory predicted by the solver.
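The role of the broadcast dot can be checked on small matrices. The matrices in this sketch are made up for illustration. `.^` squares each entry while `^` computes a matrix power. Broadcasting also silently stretches a row vector over a matrix, which is the kind of mismatch the `@assert` guards against. For two matrices of the same size, `-` and `.-` agree, so the dot matters most for `^`.

```julia
A_demo = [1 2; 3 4]

elementwise_square = A_demo .^ 2   # squares each entry: [1 4; 9 16]
matrix_square = A_demo ^ 2         # matrix product A_demo * A_demo: [7 10; 15 22]

# Broadcasting stretches a 1×2 row over both rows instead of raising an error
stretched = A_demo .- [10 20]      # [-9 -18; -7 -16]
```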

**B** - Since `loss_neuralode` returns both the loss and the predicted trajectory, `only_loss` keeps just the first element, the scalar loss. The objective handed to `Optimization.jl` is expected to return a single scalar, which is what the automatic differentiation differentiates. Returning the tuple instead can make the gradient computation fail or behave unexpectedly during training, so only the loss is passed on.
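The loss-only wrapper pattern can be seen in isolation with a toy model. This sketch assumes a hypothetical `loss_and_prediction_demo` that, like `loss_neuralode`, returns a `(loss, prediction)` tuple. `first` then extracts the scalar, as `only_loss` does.

```julia
using Statistics

ground_truth_demo = [1.0 2.0; 3.0 4.0]

# Hypothetical loss mirroring loss_neuralode: returns (scalar loss, prediction)
function loss_and_prediction_demo(x)
    prediction = ground_truth_demo .+ x    # toy "model": shift every entry by x
    @assert size(prediction) == size(ground_truth_demo)
    loss = mean((ground_truth_demo .- prediction) .^ 2)
    return loss, prediction
end

# Loss-only wrapper: keep just the scalar
scalar_loss_demo(x) = first(loss_and_prediction_demo(x))

scalar_loss_demo(2.0)   # every entry is off by 2, so the MSE is 4.0
```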

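```julia
# A
callback = function (current_parameters, current_loss; do_plot=true)
    # Recent Optimization.jl versions pass an OptimizationState whose field `u` holds the
    # weights; older versions (and the manual call in C) pass the weights directly
    current_parameters = current_parameters isa AbstractVector ? current_parameters : current_parameters.u
    println()
    if do_plot
        current_pred = neural_ode_predictions(current_parameters)
        # This takes only the first row, namely, the S variable
        plt = scatter(tsteps, ground_truth[1, :]; label = "Ground truth")
        scatter!(plt, tsteps, current_pred[1, :]; label = "Prediction")
        display(plot(plt))
    end
    return false
end

# B
tsteps = solution.t
initial_parameters = ComponentArray(parameters)

# C
callback(initial_parameters, only_loss(initial_parameters); do_plot = true)
```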

**A** - Creates a callback function that lets us perform side tasks while the training is running. The optimizer "calls back" into this function at every iteration of the training loop. Common reasons for defining it are
*   **Monitoring progress**: like printing the loss
*   **Plotting and visualization**: comparing predictions to the ground truth as the training evolves
*   **Early stopping**: stopping the training when the loss plateaus or gets "good enough"
*   **Debugging and safety**: detecting NaNs or divergence
*   **Checkpointing**: saving the best parameters whenever there's an improvement
*   **Scheduling tweaks**: adjusting the learning rate or tolerances during training

So this is a way to monitor the training loop. In this case, we only use it for plotting, comparing the ground truth to the predictions as the training evolves. Here, the arguments are

*   `current_parameters`: the current weights
*   `current_loss`: the current value of the loss
*   `do_plot=true`: keyword that enables plotting

These can vary in other cases. In the body of the function, `println()` only prints a blank line, and then the plotting starts. The line `current_pred = neural_ode_predictions(current_parameters)` calculates the predicted trajectory with the current parameters. Only the first row of the ground truth and of the predicted trajectory is plotted, which corresponds to the $S$ variable of the SIR model. To plot $I$ instead, write `ground_truth[2, :]` and `current_pred[2, :]`. It is also important to note that `return false` keeps the training loop running, whereas `return true` stops it.
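The callback contract (called after every step, with `true` stopping the loop) can be illustrated with a toy optimizer. This is a hypothetical gradient-descent loop on a quadratic, not how `Optimization.solve` is implemented. It shows early stopping from a callback that returns `true` once the loss falls below a threshold.

```julia
# Hypothetical optimizer loop: the callback runs after every step; returning true halts it
function toy_gradient_descent(loss, grad, x0; learning_rate = 0.1, maxiters = 100,
                              callback = (x, l) -> false)
    x = copy(x0)
    iterations = 0
    for _ in 1:maxiters
        x = x .- learning_rate .* grad(x)
        iterations += 1
        callback(x, loss(x)) && break
    end
    return x, iterations
end

quadratic_demo(x) = sum(abs2, x)
quadratic_demo_grad(x) = 2 .* x

# Early stopping: halt once the loss drops below 1e-3
stop_when_small(x, l) = l < 1e-3
x_final, n_iter = toy_gradient_descent(quadratic_demo, quadratic_demo_grad, [1.0, -2.0];
                                       callback = stop_when_small)
```

Each step multiplies `x` by 0.8, so the loss is $5 \cdot 0.64^k$ after $k$ steps. It first drops below $10^{-3}$ at $k = 20$, well before `maxiters = 100`.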

**B** - To plot time on the $x$-axis, we need the values of the variable $t$. Here they are retrieved from the ground-truth solution `solution` (days 0 to 160).

The initial parameters, that is, the initial weights, are transformed via `ComponentArray` into a single flat vector that can still be accessed by name. Lux creates them as arrays nested inside NamedTuples, whereas optimizers such as Adam and BFGS work on a single parameter vector.
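Flattening nested parameters into one vector can be sketched in plain Julia. This is a minimal sketch with a hypothetical `flatten_parameters` function and made-up weights, laid out like Lux's `(layer_1 = (weight, bias), ...)` NamedTuples. Unlike `ComponentArray`, it discards the names, so it only illustrates the flattening step.

```julia
# Arrays are flattened column by column (Julia's column-major order)
flatten_parameters(x::AbstractArray{<:Real}) = vec(collect(Float64, x))

# NamedTuples are flattened field by field, recursing into nested layers
flatten_parameters(nt::NamedTuple) =
    isempty(nt) ? Float64[] : reduce(vcat, (flatten_parameters(v) for v in values(nt)))

nested_demo = (layer_1 = (weight = [1.0 2.0; 3.0 4.0], bias = [5.0, 6.0]),
               layer_2 = (weight = [7.0 8.0], bias = [9.0]))

flat_demo = flatten_parameters(nested_demo)
```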

**C** - Calls the callback once as a sanity check, using the initial parameters and their loss. Training does not start yet. This single call

*   Verifies shapes
*   Catches bugs early
*   Tests the callback itself
*   Gives a quick look at the baseline predictions (before training) against the ground truth
*   Gives a feeling for the time each iteration takes

This is done before launching a long optimization.

NOTE: Without the loss-only wrapper, the splat operator `...` would be needed, as in `loss_neuralode(initial_parameters)...`. It unpacks a collection or tuple into separate positional arguments. Since `loss_neuralode` returns two objects, `loss` and `predicted_trajectory`, the call

`callback(initial_parameters, loss_neuralode(initial_parameters)...; do_plot = true)`

becomes

`callback(initial_parameters, loss, predicted_trajectory; do_plot = true)`

and the callback would then need to accept that extra argument.
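The effect of the splat operator can be checked directly. In this sketch, `two_outputs_demo` is a hypothetical stand-in for `loss_neuralode`, and `count_extra_args` simply reports how many positional arguments follow the first one.

```julia
# Hypothetical stand-in for loss_neuralode: returns a tuple of two values
two_outputs_demo(x) = (x^2, [x, 2x])

# Reports how many positional arguments follow the first one
count_extra_args(first_arg, rest...; do_plot = false) = length(rest)

without_splat = count_extra_args(1.0, two_outputs_demo(3.0); do_plot = true)  # tuple arrives as 1 argument
with_splat = count_extra_args(1.0, two_outputs_demo(3.0)...; do_plot = true)  # tuple unpacked into 2
```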

```julia
# A
adtype = Optimization.AutoZygote()

# B
objective_function = Optimization.OptimizationFunction((x, p) -> only_loss(x), adtype)

# C
optimization_problem = Optimization.OptimizationProblem(objective_function, initial_parameters)
```

**A** - We tell `Optimization.jl` which automatic differentiation (AD) engine to use to get the gradients of the loss. `adtype` stands for *automatic differentiation type*. In this case we use the **Zygote reverse-mode AD backend**. Zygote is useful when the parameter vector (the weights) is large and one has a single scalar loss. Also, here one needs gradients through both the NN and the ODE integration, so Zygote is commonly used here.

**B** - Creates an objective function object. The objective is the MSE, because the function passed to the constructor calls `only_loss(x)`. `(x, p) -> only_loss(x)` is an anonymous function that takes two arguments `(x, p)` and passes only $x$ to the loss function defined earlier. $x$ are the decision variables (the NN parameters in this case). $p$ is a "problem data" bucket that stays fixed while the decision variables change, and it is not used here. $p$ can include
*   Datasets: inputs/targets, batch indices, train/validation splits
*   Physical constraints or initial conditions: known parameters, time grids
*   Regularization knobs: $\lambda$ for L1 and L2, noise variance
*   Loss options: which norm to use, weights per state, masking
*   Control problems: reference trajectories, penalties, constraints

We also pass to the constructor how to obtain the gradients, in this case, using the engine defined in **A** as `adtype`.

The constructor does not compute anything yet. It returns an `OptimizationFunction` object which, when evaluated at some parameters, gives the scalar loss.
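The `(x, p)` convention can be sketched with plain functions. In this hypothetical example, `weighted_objective` reads fixed targets and per-state weights from `p`, while `ignores_data` mirrors the notebook's objective, which uses only `x`. In `Optimization.jl`, such data would be supplied as the third argument of `OptimizationProblem`.

```julia
# x: decision variables; p: fixed problem data (here, targets and per-state weights)
weighted_objective(x, p) = sum(p.weights .* (p.targets .- x) .^ 2)

problem_data = (targets = [1.0, 0.0], weights = [1.0, 0.5])

# Same (x, p) signature, but p is ignored, like (x, p) -> only_loss(x)
ignores_data = (x, p) -> sum(abs2, x)
```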

**C** - Defining the optimization problem to solve. It does not evaluate the loss and does not run any optimization, it just packs the problem to be solved. It accepts the objective function to be utilized and the starting point `initial_parameters`, the initial guess for the weights to be optimized. That is, this is an immutable problem description:

Minimize `objective_function` starting at `initial_parameters`.

```julia
# A
result_neuralode = Optimization.solve(optimization_problem,
                                      OptimizationOptimisers.Adam(0.01);
                                      callback = callback,
                                      maxiters = 300)
```

**A** - The training begins. The problem defined as `optimization_problem` is solved using Adam as the optimizer with a learning rate of 0.01. It starts from `initial_parameters`, uses Zygote for the gradients, and updates the parameters at each step. After each step the `callback` is called, which draws the plot. The loop stops after 300 iterations (it can be stopped earlier by returning `true` from the callback).

The returned object contains the optimized parameters in the field `u`, the loss, and other statistics.

Note that Adam can be tuned further. For example, `Adam(0.05, (0.9, 0.999))` sets the learning rate together with the decay rates $\beta_1$ and $\beta_2$ of its moment estimates. A learning-rate schedule would be a separate addition. The callback can also be refined, etc.
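Here is a minimal sketch of the Adam update rule in plain Julia, for intuition about the Adam steps used in training. `adam_demo` is a hypothetical name, and this is not the `Optimisers.jl` code. It keeps running averages of the gradient and of its square, corrects their bias, and scales each step by their ratio. As a result, the very first step has a size close to the learning rate whatever the gradient's magnitude.

```julia
# Minimal sketch of the Adam update rule (not the Optimisers.jl implementation)
function adam_demo(grad, x0; η = 0.01, β1 = 0.9, β2 = 0.999, ϵ = 1e-8, steps = 1)
    x = copy(x0)
    m = zero(x)   # running average of gradients (first moment)
    v = zero(x)   # running average of squared gradients (second moment)
    for t in 1:steps
        g = grad(x)
        m .= β1 .* m .+ (1 - β1) .* g
        v .= β2 .* v .+ (1 - β2) .* g .^ 2
        m_hat = m ./ (1 - β1^t)   # bias corrections for the zero initialization
        v_hat = v ./ (1 - β2^t)
        x .-= η .* m_hat ./ (sqrt.(v_hat) .+ ϵ)
    end
    return x
end

bowl_loss(x) = sum(abs2, x)
bowl_grad(x) = 2 .* x

x_one_step = adam_demo(bowl_grad, [1.0, -2.0])   # each coordinate moves by about η = 0.01
x_many_steps = adam_demo(bowl_grad, [1.0, -2.0]; steps = 300)
```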

```julia
# A
refined_optimization_problem = remake(optimization_problem; u0 = result_neuralode.u)

# B
refined_result_neuralode = Optimization.solve(refined_optimization_problem,
                                       Optim.BFGS(; initial_stepnorm = 0.01);
                                       callback = callback,
                                       allow_f_increases = false)
```

**A** - Copies the setup of the `optimization_problem` defined above, but replaces the initial parameters with the result of the previous optimization loop. This is achieved via `u0 = result_neuralode.u`: `u0`, the starting point, is now set to the weights learned by Adam.

This allows a new optimization loop with BFGS that keeps the same objective and AD settings but starts from a good point instead of a random guess. Adam copes well with a poor starting point but can be noisy and slow to settle precisely into a minimum. BFGS converges quickly and precisely once it is close to one. The result is `refined_optimization_problem`.

**B** - The new training loop begins, this time using the quasi-Newton method BFGS. It uses gradients together with an approximation of the local curvature, built from successive gradients, to approach the minimum more precisely. Adam does not use curvature information. In `Optim.BFGS(; initial_stepnorm = 0.01)`, the `;` separates positional from keyword arguments. No positional arguments are passed, only the keyword argument `initial_stepnorm = 0.01`, which sets the scale of the first step.

The callback is called again for plotting.

`allow_f_increases = false` tells the solver to stop iterating if a step increases the loss, instead of carrying on from a worse point.

The result, `refined_result_neuralode`, contains the new weights (field `u`), the final loss, and other statistics.
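The sketch below shows how BFGS builds curvature information from gradients alone. It runs BFGS with a backtracking (Armijo) line search on a two-variable quadratic. `bfgs_demo` and the quadratic are hypothetical, and `Optim.BFGS` uses a more careful line search and additional safeguards.

```julia
using LinearAlgebra

function bfgs_demo(f, grad, x0; maxiters = 100, gtol = 1e-6)
    n = length(x0)
    x = copy(x0)
    H = Matrix{Float64}(I, n, n)   # inverse-Hessian approximation, starts as the identity
    g = grad(x)
    for _ in 1:maxiters
        norm(g) < gtol && break
        d = -H * g                 # quasi-Newton search direction
        α = 1.0
        # Backtracking line search: halve the step until the Armijo condition holds
        while α > 1e-12 && f(x .+ α .* d) > f(x) + 1e-4 * α * dot(g, d)
            α /= 2
        end
        s = α .* d
        x_new = x .+ s
        g_new = grad(x_new)
        y = g_new .- g
        ρ = 1 / dot(y, s)
        # BFGS update: H ← (I - ρ s yᵀ) H (I - ρ y sᵀ) + ρ s sᵀ
        H = (I - ρ * s * y') * H * (I - ρ * y * s') + ρ * s * s'
        x, g = x_new, g_new
    end
    return x
end

A_quad = [3.0 1.0; 1.0 2.0]   # symmetric positive definite
b_quad = [1.0, -1.0]
f_quad(x) = 0.5 * dot(x, A_quad * x) - dot(b_quad, x)
grad_quad(x) = A_quad * x - b_quad

x_bfgs = bfgs_demo(f_quad, grad_quad, [5.0, 5.0])
```

On a quadratic, the difference in gradients `y` equals `A_quad * s`, so each update feeds true curvature into `H`. The iterates then approach the exact minimizer `A_quad \ b_quad` quickly.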

# Notes
---

1.  The `NeuralODE` constructor takes a model (here, the Lux network `NN64`) and uses it as the right-hand side of the ODE. Conceptually, it turns the network into a function of the form

$$f(u, p, t) \longrightarrow \frac{du}{dt}$$

(often written simply as $f(u,p,t) \longrightarrow du$), where $p$ are the network's weights. For this autonomous system, $t$ is not fed to the network. The solver used inside `NeuralODE` (here, `Tsit5()`) integrates $du/dt$ to obtain $u=[S,I,R]$. That is, the object returned by `NeuralODE` is a layer that, when called with initial conditions, parameters and state, returns a whole trajectory of $u$.

2. When calculating the loss, `mean(ground_truth .- pred).^2` would first calculate the mean of the differences and then square that scalar. The code instead uses `mean((ground_truth .- predicted_trajectory).^2)`, which first squares the differences and then averages the squares, as the MSE requires.
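The precedence issue can be verified with two differences that cancel. In this made-up example, the mean of the differences is zero, so squaring it afterwards gives 0, while the mean of the squared differences (the MSE) is 1.

```julia
using Statistics

differences = [1.0, -1.0]                  # made-up prediction errors that cancel on average
square_of_mean = mean(differences) .^ 2    # mean first, then square: 0.0
mean_of_squares = mean(differences .^ 2)   # square first, then mean (the MSE): 1.0
```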