# Continual Learning and Infinite width limits

We are interested in understanding the origin and behaviour of catastrophic forgetting in the Continual Learning framework.

To gain better insight into it, we study the infinite width limit of the model, where the training dynamics become deterministic thanks to the concentration of measure given by the Central Limit Theorem.

The main deterministic quantity to observe and study is the so-called Neural Tangent Kernel (NTK), which can be defined as:

#### $$K_t^{NTK}(x,x') = \sum_{\theta}\frac{\partial f(x,\theta_t)}{\partial \theta} \cdot \frac{\partial f(x',\theta_t)}{\partial \theta} \in \mathbb{R}^{P \times P}$$

where:
* $P$ is the number of samples in the batch; evaluating the kernel on every pair of samples gives the $P \times P$ matrix
* $\theta$ are the network's parameters
* $f(x,\theta)$ is the network's output given input $x$ and the configuration $\theta$ of the parameters.
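To make the definition concrete, here is a minimal NumPy sketch of the empirical (finite-width) NTK. The two-layer `tanh` network, its $1/\sqrt{N}$ output scaling, and the helper names `forward`, `jacobian` and `empirical_ntk` are assumptions chosen for illustration; they are not taken from this repository.

```python
import numpy as np

def forward(X, W, w):
    """Hypothetical two-layer network: f(x) = w . tanh(W x / sqrt(D)) / sqrt(N)."""
    d_in, width = X.shape[1], w.shape[0]
    return np.tanh(X @ W.T / np.sqrt(d_in)) @ w / np.sqrt(width)

def jacobian(X, W, w):
    """Gradient of f w.r.t. all parameters; one row per sample, W flattened first."""
    d_in, width = X.shape[1], w.shape[0]
    a = np.tanh(X @ W.T / np.sqrt(d_in))                # (P, N) activations
    back = w * (1.0 - a ** 2)                           # (P, N): w_i * phi'(h_i)
    dW = back[:, :, None] * X[:, None, :] / np.sqrt(d_in * width)  # (P, N, D)
    dw = a / np.sqrt(width)                             # (P, N)
    return np.concatenate([dW.reshape(len(X), -1), dw], axis=1)

def empirical_ntk(X1, X2, W, w):
    """K[mu, alpha] = sum over parameters of df(x_mu)/dtheta * df(x_alpha)/dtheta."""
    return jacobian(X1, W, w) @ jacobian(X2, W, w).T

rng = np.random.default_rng(0)
X = rng.standard_normal((5, 3))        # P = 5 samples, D = 3 input dimensions
W = rng.standard_normal((64, 3))       # N = 64 hidden units
w = rng.standard_normal(64)
K = empirical_ntk(X, X, W, w)          # (5, 5), symmetric positive semi-definite
```

Calling `empirical_ntk(X1, X2, W, w)` with two different batches gives the rectangular cross-kernel between them.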

## Parametrisations
We consider two parametrisations of the network, which differ in the factor $1/\gamma$ applied to the output layer and in the learning rate $\eta$:
* Mean Field or Maximum Update Parametrisation ($\mu P$): $\gamma = \mathcal{O}(\sqrt{N}), \eta = \mathcal{O}(\gamma^2)$
* Neural Tangent Parametrisation (NTP): $\gamma = \mathcal{O}(1), \eta = \mathcal{O}(1)$

where $N$ is the network width.
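As a small illustration of the two scalings, the sketch below returns the output factor $\gamma$ and the learning rate $\eta$ for a given width. The function name and the order-one constants `gamma0` and `eta0` are assumptions introduced here to fix the prefactors hidden in the $\mathcal{O}(\cdot)$ notation.

```python
import math

def output_scale_and_lr(width, parametrisation, gamma0=1.0, eta0=1.0):
    """Return (gamma, eta) for a network of the given width.

    gamma0 and eta0 are hypothetical order-one constants.
    """
    if parametrisation == "mup":
        gamma = gamma0 * math.sqrt(width)   # gamma = O(sqrt(N))
        eta = eta0 * gamma ** 2             # eta = O(gamma^2)
    elif parametrisation == "ntp":
        gamma = gamma0                      # gamma = O(1)
        eta = eta0                          # eta = O(1)
    else:
        raise ValueError(f"unknown parametrisation: {parametrisation!r}")
    return gamma, eta

gamma_mup, eta_mup = output_scale_and_lr(100, "mup")
gamma_ntp, eta_ntp = output_scale_and_lr(100, "ntp")
```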

From basic assumptions on the gradient dynamics and the use of the L2 loss, we obtain the evolution of the residuals $\Delta = y - f(x,\theta)$:
####  $$ \frac{\partial \Delta(t)}{\partial t} = - K^{NTK}(t)\Delta(t) $$

In the NTP, the NTK of an infinitely wide network stays fixed at its value at initialisation, which gives a closed-form solution for the evolution of the residuals:
####  $$ \frac{\partial \Delta(t)}{\partial t} = - K^{NTK}\Delta(t) \quad \rightarrow \quad \Delta(t) = \exp(- K^{NTK}t)\,\Delta_0 $$
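The closed form can be checked numerically. The sketch below evaluates $\exp(-K^{NTK}t)\,\Delta_0$ through the eigendecomposition of a symmetric kernel and compares it with a plain Euler integration of the ODE. The random kernel and initial residuals are placeholders, not data from this project.

```python
import numpy as np

def residuals_closed_form(K, delta0, t):
    """Delta(t) = exp(-K t) Delta_0 for a symmetric kernel K."""
    evals, evecs = np.linalg.eigh(K)
    return evecs @ (np.exp(-evals * t) * (evecs.T @ delta0))

def residuals_euler(K, delta0, t, n_steps=20000):
    """Explicit Euler integration of dDelta/dt = -K Delta."""
    delta = delta0.copy()
    dt = t / n_steps
    for _ in range(n_steps):
        delta = delta - dt * (K @ delta)
    return delta

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 6))
K = A @ A.T / 6                        # placeholder symmetric PSD kernel
delta0 = rng.standard_normal(4)        # placeholder initial residuals
exact = residuals_closed_form(K, delta0, t=2.0)
approx = residuals_euler(K, delta0, t=2.0)
```

Each eigen-direction of the kernel decays independently at a rate set by its eigenvalue, which is why directions with small eigenvalues dominate the late-time residuals.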

In the $\mu P$ the NTK is free to evolve and thus moves away from its value at initialisation. To keep track of its evolution a Dynamical Mean Field Theory (DMFT) approach is needed. We introduce two fields:
* a forward field $h_\mu^\ell(t)$, which represents the population of hidden representations of input $x_\mu$ at layer $\ell$ and propagates from the input layer to the output layer
* a backward field $g_\mu^\ell(t)$, which represents the population of gradient updates relative to input $x_\mu$ at layer $\ell$ and propagates from the output layer back to the input layer.

Thanks to these two fields we can build the NTK as follows:
#### $$ K ^{NTK}_{\mu \alpha}(t,t)=\sum_{\ell=0}^{L} G_{\mu \alpha}^{\ell+1}(t, t) \Phi_{\mu \alpha}^{\ell}(t, t)  $$
where $\Phi_{\mu \alpha}^\ell(t, t) = \langle \phi( h_\mu^\ell(t)) \cdot \phi(h_\alpha^\ell(t) )\rangle $ and $G_{\mu \alpha}^\ell(t,t) = \langle g_\mu^\ell(t) \cdot g_\alpha^\ell(t) \rangle $
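For a finite-width two-layer network the decomposition can be verified directly. The sketch below assumes a `tanh` network with $1/\sqrt{N}$ output scaling, takes $\Phi^0$ to be the input kernel $x_\mu \cdot x_\alpha / D$ and $G^2 = 1$ for a scalar output, and uses $g = \phi'(h)\,w$ as the backward field. These conventions are assumptions made for the example. The products are taken entry by entry in the sample indices $(\mu, \alpha)$.

```python
import numpy as np

rng = np.random.default_rng(0)
P, D, N = 4, 3, 50
X = rng.standard_normal((P, D))
W = rng.standard_normal((N, D))        # first-layer weights
w = rng.standard_normal(N)             # readout weights

h = X @ W.T / np.sqrt(D)               # forward field h^1, shape (P, N)
phi = np.tanh(h)
g = w * (1.0 - phi ** 2)               # backward field g^1 = phi'(h) * w

Kx = X @ X.T / D                       # Phi^0: input kernel
Phi1 = phi @ phi.T / N                 # Phi^1: feature kernel
G1 = g @ g.T / N                       # G^1: gradient kernel
G2 = np.ones((P, P))                   # G^2: trivial for a scalar output

# Sum over layers of G^{l+1} * Phi^l (elementwise in the sample indices).
ntk_from_fields = G1 * Kx + G2 * Phi1

# Reference: NTK from the explicit parameter Jacobian.
J_W = g[:, :, None] * X[:, None, :] / np.sqrt(D * N)
J_w = phi / np.sqrt(N)
J = np.concatenate([J_W.reshape(P, -1), J_w], axis=1)
ntk_from_jacobian = J @ J.T
```

The two matrices agree to numerical precision, so in this setting the kernel is fully determined by the averages $\Phi$ and $G$ over hidden units.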

## Continual Learning and Parametrisations

### NTP
Applying the same derivation to the Continual Learning framework, that is, looking at the evolution of the residuals of Task 1 while training on Task 2, we obtain an equation of the same form as above for the NTP case:
####  $$ \frac{\partial \Delta_{\mu_1}(t)}{\partial t} = - K^{NTK}_{\mu_1 \alpha_2} \Delta_{\alpha_2}(t) $$  
The new object $K^{NTK}_{\mu_1 \alpha_2}$, called the NTK Across Tasks, governs the evolution of the Task 1 residuals, and therefore lets us compute the loss of a task while other tasks are being trained.
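With a frozen kernel the Task 1 residuals also admit a closed form during Task 2 training. The Task 2 residuals follow $\Delta_2(t) = \exp(-K_{22} t)\,\Delta_2(0)$, and integrating the cross-task equation gives $\Delta_1(t) = \Delta_1(0) - K_{12} \int_0^t \exp(-K_{22} s)\, ds \; \Delta_2(0)$, with $t$ measured from the task switch. The sketch below implements this and compares it with Euler integration. The kernel is a random placeholder and the function name is hypothetical.

```python
import numpy as np

def task1_residuals(K12, K22, delta1_0, delta2_0, t):
    """Task 1 residuals after training Task 2 for time t with a frozen NTK."""
    evals, evecs = np.linalg.eigh(K22)
    positive = evals > 1e-12
    safe = np.where(positive, evals, 1.0)
    # int_0^t exp(-lam s) ds = (1 - exp(-lam t)) / lam, which tends to t as lam -> 0
    integral = np.where(positive, -np.expm1(-evals * t) / safe, t)
    return delta1_0 - K12 @ (evecs @ (integral * (evecs.T @ delta2_0)))

rng = np.random.default_rng(0)
P = 3
A = rng.standard_normal((2 * P, 10))
K = A @ A.T / 10                       # placeholder joint kernel over both tasks
K12, K22 = K[:P, P:], K[P:, P:]
d1_0, d2_0 = rng.standard_normal(P), rng.standard_normal(P)
closed = task1_residuals(K12, K22, d1_0, d2_0, t=1.5)

# Euler reference for the coupled system.
d1, d2 = d1_0.copy(), d2_0.copy()
n_steps = 20000
dt = 1.5 / n_steps
for _ in range(n_steps):
    d1, d2 = d1 - dt * (K12 @ d2), d2 - dt * (K22 @ d2)
```

If `K12` is zero the Task 1 residuals stay at their initial value, which is the no-interference case.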

<p align="center">
<img src="lazy_regime/residuals_ntk.png" alt="ntp_residuals" width="1000"/>
</p>

### $\mu P$

In this case we need the update equations for the fields $h_\mu^\ell(t)$ and $g_\mu^\ell(t)$ to obtain the evolution of the NTK and thus of the residuals. They follow from applying the gradient descent dynamics to the recursive relations that define the forward and backward fields. In the equations below, $\Theta$ is the Heaviside step function and $t_1$ is the time at which training switches from task 1 to task 2:

$$ \frac{d}{dt}\boldsymbol{h}_{\mathcal{T}_1}^1(t) =  \gamma_0 \left[ \Theta(t_1-t) \boldsymbol{\Delta}_{\mathcal{T}_1}(t) \boldsymbol{g}_{{\mathcal{T}_1}}^{1}(t) K^x_{\mathcal{T}_1 \mathcal{T}_1}  + \Theta (t-t_1)\boldsymbol{\Delta}_{{\mathcal{T}_2}}(t) \boldsymbol{g}_{\mathcal{T}_2}^{1}(t) K^x_{\mathcal{T}_1 \mathcal{T}_2} \right]  $$

$$ \frac{d}{dt}\boldsymbol{z}^1_{\mathcal{T}_1}(t) =  \gamma_0 \left[ \Theta(t_1-t) 
    \boldsymbol{\Delta}_{\mathcal{T}_1}(t) \phi(\boldsymbol{h}_{\mathcal{T}_1}^1(t))   + 
    \Theta (t-t_1)
    \boldsymbol{\Delta}_{\mathcal{T}_2}(t) \phi(\boldsymbol{h}_{\mathcal{T}_2}^1(t))\right] $$ 

$$ \frac{\partial}{\partial t} 
    \boldsymbol{\Delta}_{\mathcal{T}_1}(t)  = -[\Phi^1_{\mathcal{T}_1\mathcal{T}_2}(t) + G^1_{\mathcal{T}_1\mathcal{T}_2}(t) \boldsymbol{K}^x_{\mathcal{T}_1 \mathcal{T}_2} ] \boldsymbol{\Delta}_{\mathcal{T}_2}(t)  $$
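The three equations can be integrated with a simple particle (Monte Carlo) scheme. The sketch below is one possible reading of them, under assumptions that are not stated in this document: a two-layer network with `tanh` activation, backward field $g = \phi'(h)\,z$, a unit learning-rate constant, fields initialised as $h \sim \mathcal{N}(0, K^x)$ and $z \sim \mathcal{N}(0, 1)$, random placeholder targets with $\Delta(0) = y$, loss $\frac{1}{2}\,\mathrm{mean}(\Delta^2)$, and the product $G K^x$ taken entry by entry. The function name and all default values are hypothetical.

```python
import numpy as np

def simulate_two_tasks(eps=0.3, gamma0=1.0, P=4, n_particles=2000,
                       t1=3.0, t_end=6.0, dt=0.01, seed=0):
    """Euler integration of the two-task field equations with sampled particles."""
    rng = np.random.default_rng(seed)
    phi = np.tanh
    dphi = lambda h: 1.0 - np.tanh(h) ** 2

    eye = np.eye(P)
    Kx = np.block([[eye, eps * eye], [eps * eye, eye]])   # input kernel, both tasks
    task1, task2 = np.arange(P), np.arange(P, 2 * P)

    h = rng.multivariate_normal(np.zeros(2 * P), Kx, size=n_particles)  # (N, 2P)
    z = rng.standard_normal(n_particles)                                # (N,)
    delta = rng.standard_normal(2 * P)      # placeholder targets, Delta(0) = y

    n_steps = int(round(t_end / dt))
    switch_step = int(round(t1 / dt))
    loss1, loss2 = [], []
    for step in range(n_steps):
        loss1.append(0.5 * np.mean(delta[task1] ** 2))
        loss2.append(0.5 * np.mean(delta[task2] ** 2))
        active = task1 if step < switch_step else task2   # Heaviside switch
        g = dphi(h) * z[:, None]                          # assumed g = phi'(h) z
        Phi = phi(h).T @ phi(h) / n_particles             # feature kernel
        G = g.T @ g / n_particles                         # gradient kernel
        ntk = Phi + G * Kx
        d_act = delta[active]
        dh = gamma0 * (g[:, active] * d_act) @ Kx[active, :]
        dz = gamma0 * phi(h[:, active]) @ d_act
        ddelta = -ntk[:, active] @ d_act
        h, z, delta = h + dt * dh, z + dt * dz, delta + dt * ddelta
    return np.array(loss1), np.array(loss2)

loss1, loss2 = simulate_two_tasks()
```

The returned arrays hold the task 1 and task 2 losses at every time step, so the change of `loss1` after step `int(t1 / dt)` measures how much task 2 training interferes with task 1 in this sketch.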

We can thus obtain the evolution of the residuals and of the internal representations of both fields:

<p align="center">
<img src="rich_regime/h_dist_mup_1.png" alt="h_dist" width="630" height="350"/>
<img src="rich_regime/z_dist_mup_1.png" alt="z_dist" width="630" height="350"/>
</p>

<p align="center">
<img src="rich_regime/phi_mup_1.png" alt="phi_evol" width="540" height="270"/>
<img src="rich_regime/g_mup_1.png" alt="g_evol" width="540" height="270"/>
</p>

<p align="center">
<img src="rich_regime/loss_mup_1.png" alt="loss_evol" width="630" height="400"/>
</p>

## Forgetting

Once we know which quantities drive the evolution of the losses of all tasks, we can try to identify the source of forgetting and what can be done to mitigate it.
We define forgetting on task 1 as negative backward transfer, where the backward transfer is:
#### $$ BWT(t) = \mathcal{L}(t_1) - \mathcal{L}(t) $$
Here $\mathcal{L}$ is the loss on task 1 and $t_1$ is the time at which training switches from task 1 to task 2.
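Given a recorded task 1 loss curve, the backward transfer can be computed as in the sketch below. The function name and the toy loss values are hypothetical, and the switch time is matched to the closest recorded sample.

```python
import numpy as np

def backward_transfer(times, loss_task1, t1):
    """Return (times, BWT) for t >= t1, with BWT(t) = L(t1) - L(t)."""
    times = np.asarray(times, dtype=float)
    loss_task1 = np.asarray(loss_task1, dtype=float)
    i1 = int(np.argmin(np.abs(times - t1)))   # sample closest to the task switch
    return times[i1:], loss_task1[i1] - loss_task1[i1:]

# Toy curve: the loss falls while task 1 is trained, then rises again.
times = [0.0, 1.0, 2.0, 3.0, 4.0]
loss_task1 = [1.0, 0.5, 0.1, 0.2, 0.4]
t_after, bwt = backward_transfer(times, loss_task1, t1=2.0)
forgetting = -bwt                             # forgetting is negative BWT
```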

If we look at the time evolution of $BWT(t)$ and require it to be 0, we find that the features of the forward and backward fields of different tasks must be orthogonal to each other:
#### $$ \langle \phi(h^\ell_{\mu_1}(t)) \cdot \phi(h^\ell_{\alpha_2}(t)) \rangle = 0, \quad \langle g^\ell_{\mu_1}(t) \cdot g^\ell_{\alpha_2}(t) \rangle = 0 $$
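A short sketch of the reasoning, assuming the task 1 loss is $\mathcal{L}(t) = \frac{1}{2}\sum_{\mu_1} \Delta_{\mu_1}(t)^2$ (the normalisation is an assumption) and using the residual dynamics during task 2 training, $t > t_1$:

$$ \frac{d}{dt} BWT(t) = -\frac{d\mathcal{L}(t)}{dt} = -\sum_{\mu_1} \Delta_{\mu_1}(t) \frac{\partial \Delta_{\mu_1}(t)}{\partial t} = \sum_{\mu_1, \alpha_2} \Delta_{\mu_1}(t)\, K^{NTK}_{\mu_1 \alpha_2}(t,t)\, \Delta_{\alpha_2}(t) $$

$$ K^{NTK}_{\mu_1 \alpha_2}(t,t) = \sum_{\ell=0}^{L} G^{\ell+1}_{\mu_1 \alpha_2}(t,t)\, \Phi^{\ell}_{\mu_1 \alpha_2}(t,t) $$

Each term of the sum is a product of a cross-task gradient kernel and a cross-task feature kernel, so it vanishes as soon as one of the two factors does. If the cross-task $\Phi^\ell$ and $G^\ell$ vanish at every layer, the NTK Across Tasks is zero and $BWT(t)$ stays at its initial value $BWT(t_1) = 0$. In this sketch the orthogonality condition is therefore sufficient, though it may be stronger than necessary.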
We are still working on a looser characterization of the condition for non-negative backward transfer.

## Similarity
We also impose two assumptions on the input covariance matrices, within each task and across tasks: the $K^x$ of each task is the identity, $K^x_{ii} = \mathbb{I}$, while the $K^x$ across tasks is $K^x_{ij} = \epsilon \mathbb{I}$ for $i \neq j$.
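One way to generate inputs that satisfy these assumptions exactly is sketched below. It builds $2P$ orthogonal directions with a QR decomposition and mixes them with weight $\epsilon$. This construction, the function name, and the requirement $D \geq 2P$ are choices made for the example, not necessarily how the data in this project is generated.

```python
import numpy as np

def make_two_task_inputs(P, D, eps, seed=0):
    """Inputs with K11 = K22 = I and K12 = eps * I, where K = X X^T / D."""
    if D < 2 * P:
        raise ValueError("need D >= 2P to fit 2P orthogonal directions")
    if not -1.0 <= eps <= 1.0:
        raise ValueError("eps must lie in [-1, 1]")
    rng = np.random.default_rng(seed)
    Q, _ = np.linalg.qr(rng.standard_normal((D, 2 * P)))   # orthonormal columns
    U = np.sqrt(D) * Q[:, :P].T        # (P, D), U U^T / D = I
    V = np.sqrt(D) * Q[:, P:].T        # (P, D), orthogonal to U
    X1 = U
    X2 = eps * U + np.sqrt(1.0 - eps ** 2) * V
    return X1, X2

X1, X2 = make_two_task_inputs(P=4, D=16, eps=0.3)
K11, K22, K12 = X1 @ X1.T / 16, X2 @ X2.T / 16, X1 @ X2.T / 16
```

Setting `eps=0` gives orthogonal tasks and `eps=1` makes the two tasks share the same inputs.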

We are trying to characterize how $\epsilon$ and $\gamma_0$ jointly affect forgetting. Simulations based on the DMFT equations are in close agreement with the empirical results, and the forgetting vs similarity plots show a concave shape.

## Forgetting
Top curve: Linear

Bottom curve: ReLU

<p align="center">
<img src="forgetting/forg_vs_eps.png" alt="forg_evol_relu"  height="450"/>
</p>

## Normalized Forgetting
<p align="center">
<img src="forgetting/forg_vs_eps_normalized.png" alt="forg_evol_relu"  height="450"/>
</p>

## Forgetting
<p align="center">
<img src="forgetting/gamma0_vs_eps.png" alt="forg_evol_relu"  height="550"/>
</p>

## Normalized Forgetting
<p align="center">
<img src="forgetting/gamma0_vs_eps_normalized.png" alt="forg_evol_relu"  height="550"/>
</p>