# Reverse Mode Sensitivity Analysis 

* With *reverse-mode adjoint sensitivity analysis*, the approach from [Chen et al.](https://arxiv.org/abs/1806.07366), we can compute derivatives through the trajectories of an ODE:

* Consider an ODE $\dot{\mathbf{x}} = f(\mathbf{x},t;\theta)$, integrated from $t_0$ to $t_1$ with parameters $\theta$, and a scalar loss function $\mathcal{L}(\mathbf{x}(t_1))$ of the ODE solution that the training procedure is supposed to minimize
* The parameters $\theta$ can include those of data-driven models such as artificial neural networks (ANNs)
* To compute the gradient $\frac{\partial\mathcal{L}}{\partial\theta}$, the ODE is augmented with the adjoint
$$\begin{align}
    \mathbf{a}(t) = \frac{\partial \mathcal{L}}{\partial \mathbf{x}(t)}.
\end{align}$$
* The adjoint follows the dynamics
$$\begin{align}
    \frac{d \mathbf{a}}{d t} = - \mathbf{a}^T(t) \frac{\partial f(\mathbf{x}(t),t,\theta)}{\partial \mathbf{x}(t)}
\end{align}$$
and tracks how the loss depends on the state along the trajectory. It is needed to compute the desired $\frac{\partial\mathcal{L}}{\partial\theta}$ via a second adjoint
\begin{align}
    \mathbf{a}_\theta &= \frac{\partial\mathcal{L}}{\partial\theta(t)} \\
    \frac{d \mathbf{a}_\theta}{d t}& = - \mathbf{a}^T(t) \frac{\partial f(\mathbf{x}(t),t,\theta)}{\partial \theta}
\end{align}
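
For intuition, here is a worked scalar example. The specific choice of $f$ and $\mathcal{L}$ is an assumption made for illustration and is not taken from the references. Take $f(x,t;\theta) = \theta x$ with $x(t_0) = x_0$ and $\mathcal{L} = x(t_1)$, and assume the second adjoint starts from $a_\theta(t_1) = 0$. Since $x(t_1) = x(t)\,e^{\theta(t_1-t)}$, we get
\begin{align}
    a(t) &= \frac{\partial \mathcal{L}}{\partial x(t)} = e^{\theta(t_1 - t)}, & \frac{d a}{d t} &= -\theta\, a(t) = -a(t)\frac{\partial f}{\partial x}, \\
    \frac{d a_\theta}{d t} &= -a(t)\, x(t) = -x_0 e^{\theta(t_1-t_0)}, & a_\theta(t_0) &= (t_1 - t_0)\, x_0 e^{\theta(t_1-t_0)},
\end{align}
which agrees with differentiating $x(t_1) = x_0 e^{\theta(t_1-t_0)}$ with respect to $\theta$ directly.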

## Proof
The proof of these dynamics can be seen as a continuous form of backpropagation. As in traditional, discrete backpropagation, the chain rule is applied with
\begin{align}
 \frac{d\mathcal{L}}{d\mathbf{x}(t)}&=\frac{d\mathcal{L}}{d\mathbf{x}(t+\epsilon)}\frac{d\mathbf{x}(t+\epsilon)}{d \mathbf{x}(t)}\label{eq:cont-L}
\end{align}
by inserting a state of the trajectory evolved by an incremental time step $\epsilon$. Evolving the trajectory can be approximated with 
\begin{align}
     \mathbf{x}(t+\epsilon) &= \int_t^{t+\epsilon} f(\mathbf{x}(t'),t';\theta)\,dt' + \mathbf{x}(t) = T_\epsilon(\mathbf{x}(t),t;\theta) \overset{\epsilon\rightarrow 0}{\approx} \epsilon f(\mathbf{x}(t),t;\theta) + \mathbf{x}(t)
\end{align}
to rewrite the chain rule equation above as 
\begin{align}
\mathbf{a}(t)&=\mathbf{a}(t+\epsilon)\frac{\partial T_\epsilon(\mathbf{x}(t),t;\theta)}{\partial \mathbf{x}(t)}.
\end{align}
These results can be used to get the dynamics of $\mathbf{a}(t)$ by inserting them into the definition of its derivative:  
\begin{align}
      \frac{d\mathbf{a}(t)}{dt} &= \lim_{\epsilon\rightarrow 0} \frac{1}{\epsilon}\left(\mathbf{a}(t+\epsilon) - \mathbf{a}(t)\right) \\
    &= \lim_{\epsilon\rightarrow 0} \frac{1}{\epsilon}\left(\mathbf{a}(t+\epsilon) - \mathbf{a}(t+\epsilon)\frac{\partial}{\partial\mathbf{x}(t)}\left(\mathbf{x}(t)+\epsilon f(\mathbf{x}(t),t;\theta)\right)\right)\\
    &= \lim_{\epsilon\rightarrow 0} \frac{1}{\epsilon}\left(\mathbf{a}(t+\epsilon) - \mathbf{a}(t+\epsilon)\left(\mathbf{I}+\epsilon \frac{\partial f(\mathbf{x}(t),t;\theta)}{\partial\mathbf{x}(t)}\right)\right)\\
    &= \lim_{\epsilon\rightarrow 0} - \mathbf{a}(t+\epsilon)\frac{\partial f(\mathbf{x}(t),t;\theta)}{\partial\mathbf{x}(t)}=- \mathbf{a}(t) \frac{\partial f(\mathbf{x}(t),t,\theta)}{\partial \mathbf{x}(t)} .
\end{align}
The dynamics of $\mathbf{a}_\theta$ can be derived analogously. 
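
The chain-rule step used in the proof can be checked numerically. The sketch below is only an illustration under stated assumptions: the right-hand side `f` (autonomous, scalar, $\theta \sin x$), the loss $\mathcal{L} = x(t_1)$, and all function names are hypothetical choices, and the derivative `df_dx` is written by hand rather than obtained with AD. It evolves the state with explicit Euler steps of size $\epsilon$, applies $a(t) = a(t+\epsilon)\,\partial T_\epsilon/\partial x(t)$ step by step from $t_1$ back to $t_0$, and compares the result with a finite difference of the same discretized map.

```python
import math

def f(x, theta):
    # Hypothetical autonomous right-hand side, chosen only for illustration.
    return theta * math.sin(x)

def df_dx(x, theta):
    # Hand-written partial derivative of f with respect to x.
    return theta * math.cos(x)

def euler_forward(x0, theta, t0, t1, n_steps):
    """Explicit Euler: x(t + eps) = x(t) + eps * f(x(t)). Returns all states."""
    eps = (t1 - t0) / n_steps
    xs = [x0]
    for _ in range(n_steps):
        xs.append(xs[-1] + eps * f(xs[-1], theta))
    return xs, eps

def backpropagate(xs, eps, theta, a_end):
    """Apply a(t) = a(t + eps) * dT_eps/dx(t) step by step, from t1 back to t0."""
    a = a_end
    for x in reversed(xs[:-1]):
        a *= 1.0 + eps * df_dx(x, theta)
    return a

x0, theta, t0, t1, n_steps = 0.5, 0.8, 0.0, 1.0, 1000
xs, eps = euler_forward(x0, theta, t0, t1, n_steps)

# Loss L = x(t1), so the adjoint starts at a(t1) = dL/dx(t1) = 1.
a_t0 = backpropagate(xs, eps, theta, 1.0)

# Central finite difference of the same discretized map, for comparison.
h = 1e-6
x_plus, _ = euler_forward(x0 + h, theta, t0, t1, n_steps)
x_minus, _ = euler_forward(x0 - h, theta, t0, t1, n_steps)
fd = (x_plus[-1] - x_minus[-1]) / (2.0 * h)

print(a_t0, fd)
```

The two printed numbers should agree up to finite-difference error, because the backward loop is exactly discrete backpropagation through the Euler steps. Letting $\epsilon \rightarrow 0$ turns this recursion into the adjoint ODE derived above.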


## Reverse-mode Adjoint Sensitivity Problem 

Similar to how traditional backpropagation traverses the chain of layers of an ANN from the output back to the input, the augmented ODE with the adjoints needs to be solved backwards in time, because the initial values of the augmented dynamics are only known at the end point of the integration, $t_1$. To compute $\frac{\partial\mathcal{L}}{\partial\theta}$, we thus need to solve the augmented ODE
\begin{align} 
    \begin{pmatrix}
        \frac{d \mathbf{x}}{d t}\\
        \frac{d \mathbf{a}}{d t}\\ 
        \frac{d \mathbf{a}_\theta}{d t}\\
    \end{pmatrix} =
     \begin{pmatrix}
        f(\mathbf{x},t,\theta)\\
        - \mathbf{a}^T(t) \frac{\partial f(\mathbf{x}(t),t,\theta)}{\partial \mathbf{x}(t)} \\ 
        - \mathbf{a}^T(t) \frac{\partial f(\mathbf{x}(t),t,\theta)}{\partial \theta}\label{eq:node-train}
    \end{pmatrix} 
\end{align}
backwards in time from $t_1$ to $t_0$ with initial conditions $[\mathbf{x}(t_1); \frac{\partial\mathcal{L}}{\partial\mathbf{x}(t_1)}; \mathbf{0}]$ given at $t_1$, which eventually yields $\frac{\partial\mathcal{L}}{\partial\theta} = \mathbf{a}_\theta(t_0)$. The partial derivatives of $f$ are computed using automatic differentiation (AD).
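
The backward solve can be sketched in a few lines of plain Python. This is a minimal illustration under stated assumptions, not the implementation used by any library: the scalar model $f = \theta x$, the loss $\mathcal{L} = \tfrac{1}{2} x(t_1)^2$, the fixed-step RK4 integrator and all names are hypothetical, and the partial derivatives are written by hand where a real implementation would use AD. The example is chosen because the gradient is known in closed form, $\frac{\partial\mathcal{L}}{\partial\theta} = (t_1 - t_0)\, x_0^2\, e^{2\theta(t_1-t_0)}$.

```python
import math

def f(x, t, theta):
    # Hypothetical scalar model dx/dt = theta * x, chosen for illustration.
    return theta * x

def df_dx(x, t, theta):
    # Hand-written partial derivative (AD would supply this in practice).
    return theta

def df_dtheta(x, t, theta):
    return x

def forward_rhs(state, t, theta):
    (x,) = state
    return (f(x, t, theta),)

def augmented_rhs(state, t, theta):
    # Right-hand side of the augmented system [x, a, a_theta].
    x, a, a_theta = state
    return (
        f(x, t, theta),
        -a * df_dx(x, t, theta),
        -a * df_dtheta(x, t, theta),
    )

def rk4(rhs, state, t_start, t_end, theta, n_steps=200):
    """Fixed-step RK4; the step h is negative when t_end < t_start."""
    h = (t_end - t_start) / n_steps
    t = t_start
    for _ in range(n_steps):
        k1 = rhs(state, t, theta)
        k2 = rhs(tuple(s + 0.5 * h * k for s, k in zip(state, k1)), t + 0.5 * h, theta)
        k3 = rhs(tuple(s + 0.5 * h * k for s, k in zip(state, k2)), t + 0.5 * h, theta)
        k4 = rhs(tuple(s + h * k for s, k in zip(state, k3)), t + h, theta)
        state = tuple(
            s + h / 6.0 * (p + 2.0 * q + 2.0 * r + w)
            for s, p, q, r, w in zip(state, k1, k2, k3, k4)
        )
        t += h
    return state

theta, x0, t0, t1 = -0.5, 1.0, 0.0, 2.0

# Forward pass: only x(t1) is needed to start the backward solve.
(x1,) = rk4(forward_rhs, (x0,), t0, t1, theta)

# Loss L = 0.5 * x(t1)**2, hence dL/dx(t1) = x(t1); a_theta starts at 0.
x_back, a_t0, grad_theta = rk4(augmented_rhs, (x1, x1, 0.0), t1, t0, theta)

exact = (t1 - t0) * x0**2 * math.exp(2.0 * theta * (t1 - t0))
print(grad_theta, exact)
```

Note that the state $\mathbf{x}$ is re-integrated backwards together with the adjoints, so no forward trajectory has to be stored. In this sketch the backward solve also recovers $x(t_0)$ in `x_back`, which gives a quick sanity check on the integration.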

## Other Sensitivity Algorithms 

There are many other sensitivity algorithms, several of which are implemented in `DiffEqSensitivity`. The most commonly used ones are:

* `QuadratureAdjoint`, `InterpolatingAdjoint` and `BacksolveAdjoint`: [Kim et al.](https://arxiv.org/abs/2103.15341) outline these algorithms, especially with regard to their properties for stiff differential equation problems
* For ergodic, chaotic systems, we can also use [least squares shadowing methods](https://arxiv.org/abs/1204.0159), with `AdjointLSS` or `NILSAS`