## Backpropagation and automatic differentiation:

[My purpose: to write down everything I learned, and to explain it as well as I can.]
A few months ago, while trying to understand how learning actually happens, I realized that my understanding of backpropagation was shallow. So I decided to study it in depth and see how it is actually implemented in large neural networks. This will be a long post. It is not intended as a quick read on backprop; we will look at it in depth, and I hope you find the article useful.

Notes:
1) The article may seem math-heavy, but it isn't. I have tried to use animations to make it easier to follow.

### Intro to backprop:
Central to backprop is the idea of a cost function. A cost function lets us specify what we want to achieve: either we want to maximize rewards or we want to minimize errors. It formalizes the uncertainty of our model [see Appendix A] so that we can reduce it.

Let's say our model has parameters $\theta$ and our cost function is $\mathcal{J}(\theta)$. If we have the gradients $\frac{\partial \mathcal{J}(\theta)}{\partial \theta}$, we can update the parameters $\theta$ using gradient descent [for more on gradients and gradient descent, see Appendix B].

A forward pass in a simple Multilayer Perceptron looks like this:

![Forward pass](https://raw.githubusercontent.com/akashe/gifsandvids/main/forward_pass.gif)





The model above has the parameters $(W_1,b_1,W_2,b_2,W_3,b_3,W_4,b_4)$. The partial derivative of the cost function with respect to $W_1$ tells us how small changes in $W_1$ change the value of $\mathcal{J}(\theta)$. But here lies a catch: $W_1$ doesn't directly affect the cost function. The value of $W_1$ affects the output of Node 1. That output in turn affects the outputs of Nodes 3 and 4, which feed the output node O, and the output of O directly determines $\mathcal{J}(\theta)$. As this gif shows:



You can think of these as information paths. Every path that connects a parameter to the cost function contributes to the total derivative of the cost function with respect to that parameter. In the case above, since there are 2 paths connecting $W_1$ with $\mathcal{J}(\theta)$, $\frac{\partial \mathcal{J}(\theta)}{\partial W_1}$ has two terms:
\begin{aligned}\frac{\partial \mathcal{J}(\theta)}{\partial W_1} &= \text{derivative from path 1} + \text{derivative from path 2} \\
&= \frac{\partial \mathcal{J}(\theta)}{\partial Z_5}\frac{\partial Z_5}{\partial Z_3}\frac{\partial Z_3}{\partial Z_1}\frac{\partial Z_1}{\partial W_1} + \frac{\partial \mathcal{J}(\theta)}{\partial Z_5}\frac{\partial Z_5}{\partial Z_4}\frac{\partial Z_4}{\partial Z_1}\frac{\partial Z_1}{\partial W_1}
\end{aligned}

where $Z_i$ is the output of the $i$-th node, and the output node O is treated as Node 5. These expressions come from the chain rule [see Appendix C].
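To make the path sum concrete, here is a minimal Python sketch on a hypothetical scalar version of the network: every node holds a single number, all weight values (`A31`, `C3`, etc.) are made up, nodes 3 and 4 use `tanh` while the output node is linear (my assumptions, not taken from the gif), and the cost is assumed to be squared error. It adds the two path products for $\frac{\partial \mathcal{J}(\theta)}{\partial W_1}$ and checks the result against a finite-difference estimate.

```python
import math

# Hypothetical scalar network: input I -> nodes 1, 2 -> nodes 3, 4 -> node 5 (output O).
# All weight values are made up for illustration; tanh in nodes 3, 4 is an assumption.
I, Y_TRUE = 1.5, 0.7
W2, B1, B2 = -0.4, 0.1, 0.9
A31, A32, B3 = 0.5, -0.6, 0.05   # weights from Z1, Z2 into node 3
A41, A42, B4 = -0.3, 0.9, 0.2    # weights from Z1, Z2 into node 4
C3, C4 = 1.2, -0.7               # weights from Z3, Z4 into node 5


def relu(x):
    return max(0.0, x)


def forward(w1):
    """Return every node output and the cost J for a given value of W1."""
    z1 = relu(I * w1 + B1)
    z2 = relu(I * W2 + B2)
    z3 = math.tanh(A31 * z1 + A32 * z2 + B3)
    z4 = math.tanh(A41 * z1 + A42 * z2 + B4)
    z5 = C3 * z3 + C4 * z4
    j = 0.5 * (z5 - Y_TRUE) ** 2  # assumed squared-error cost
    return z1, z2, z3, z4, z5, j


def two_path_gradient(w1):
    """dJ/dW1 as the sum of the two path products."""
    _, _, z3, z4, z5, _ = forward(w1)
    dj_dz5 = z5 - Y_TRUE
    dz5_dz3, dz5_dz4 = C3, C4
    dz3_dz1 = A31 * (1.0 - z3 ** 2)          # derivative of tanh
    dz4_dz1 = A41 * (1.0 - z4 ** 2)
    dz1_dw1 = I if I * w1 + B1 > 0 else 0.0  # ReLU passes gradient only when active
    path1 = dj_dz5 * dz5_dz3 * dz3_dz1 * dz1_dw1
    path2 = dj_dz5 * dz5_dz4 * dz4_dz1 * dz1_dw1
    return path1 + path2


def numeric_gradient(w1, eps=1e-6):
    """Central finite difference, used only to check the path sum."""
    return (forward(w1 + eps)[-1] - forward(w1 - eps)[-1]) / (2 * eps)


w1 = 0.8
print(two_path_gradient(w1), numeric_gradient(w1))
```

The two printed numbers should agree: no path through Node 2 contributes, because $Z_2$ does not depend on $W_1$.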

    start with the idea of a cost function
    list of good references for understanding the basics of backprop
    chain rule primer
0. Why we can't directly use what we discussed above:
    show a gif of gradient calculation, with the long multiplication of Jacobians and how these patterns repeat; discuss the chain rule.
1. Modular structure of backprop:
    the three functions
    the error itself as a layer (since I am not showing the error as a node in the first forward pass)
2. Calculating gradients using Jacobian products:
    introduce automatic differentiation here; a gif for how you can rewrite the same forward pass as steps in autodiff
    introduce the vjp (vector-Jacobian product) function
3. Calculating gradients using automatic differentiation
    how it converts into sequential steps
    autodiff library
4. Computational graphs
    static vs dynamic
    
References:

Appendix:

### Appendix A: Cost Functions
To understand the need for a cost function, we need to understand what learning is. In the context of machine learning and deep learning algorithms, learning means the ability to predict. In our models we want to reduce the uncertainty involved in our predictions, i.e. in what we have learned from the data. We want our predictions to be spot on.

We measure uncertainty using entropy. By definition,

\begin{equation*} \mathcal{E} = - \sum_{x} \mathcal{P}(x) \ln \mathcal{P}(x) \end{equation*}
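A minimal Python sketch of this formula (natural log, as in the equation; the three distributions are made-up examples) shows that entropy is largest when every outcome is equally likely and drops to 0 for a certain prediction:

```python
import math


def entropy(probs):
    """E = -sum p ln p; outcomes with p = 0 contribute nothing by convention."""
    return sum(-p * math.log(p) for p in probs if p > 0)


# Made-up distributions over 4 outcomes
uniform = [0.25, 0.25, 0.25, 0.25]   # maximally uncertain
peaked = [0.97, 0.01, 0.01, 0.01]    # fairly confident
certain = [1.0, 0.0, 0.0, 0.0]       # spot-on prediction

print(entropy(uniform))  # ≈ 1.386, i.e. ln 4
print(entropy(peaked))
print(entropy(certain))  # 0 for a certain prediction
```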

Let's see how cost functions reduce the entropy of predictions for different tasks.

#### For classification tasks:
Let's take a classification task with n classes that uses softmax.
So, the probability that a particular sample $x$ belongs to class n is:

\begin{equation*} \mathcal{P}(x = n) = \left[\frac{e^{k_n}}{\sum_{j=1}^{n} e^{k_j}}\right]^{\Omega_n} \end{equation*}
where $\Omega_n = \begin{cases} 1 & \text{if the label for class } n \text{ is 1, i.e. } \bar{y} = 1 \\ 0 & \text{otherwise} \end{cases}$

So, our likelihood for one data sample is:
\begin{equation*} \mathcal{P}(x) = \prod_{i=1}^{n}\left[\frac{e^{k_i}}{\sum_{j=1}^{n} e^{k_j}}\right]^{\Omega_i} \end{equation*}
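The exponent $\Omega_i$ acts as a selector: every non-label factor is raised to the power 0 and becomes 1, so the likelihood reduces to the softmax probability of the label class. A small Python sketch with hypothetical logits $k_i$ and a one-hot label (subtracting the largest logit is a standard numerical-stability trick that does not change the softmax output):

```python
import math


def softmax(logits):
    m = max(logits)  # shift by the max logit for numerical stability
    exps = [math.exp(k - m) for k in logits]
    total = sum(exps)
    return [e / total for e in exps]


def likelihood(logits, one_hot):
    """prod_i p_i ** Omega_i for a single sample."""
    result = 1.0
    for p, omega in zip(softmax(logits), one_hot):
        result *= p ** omega  # p ** 0 == 1 for every non-label class
    return result


logits = [2.0, 0.5, -1.0]  # hypothetical logits k_1, k_2, k_3
label = [1, 0, 0]          # hypothetical one-hot label: class 1
print(likelihood(logits, label), softmax(logits)[0])  # the same number twice
```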
 
So the entropy of the classification for one data sample becomes:

\begin{equation*} \mathcal{E} = - \sum_{i=1}^{n} \left[\frac{e^{k_i}}{\sum_{j=1}^{n} e^{k_j}}\right]^{\Omega_i} \Omega_i \ln \left[\frac{e^{k_i}}{\sum_{j=1}^{n} e^{k_j}}\right] \end{equation*}

The expression above may seem daunting, but notice that $\Omega_i$ is 0 for every non-label class, so only the label-class term survives. For that term ($\Omega_i = 1$), the expression reaches its lowest value (0) when the model assigns probability 1 to the label class, which is a spot-on prediction for the label class.

#### For regression tasks:
The go-to option for regression tasks is Mean Squared Error (MSE). But MSE has no notion of a probability $\mathcal{P}(y)$. So we assume that by predicting $y$ we are learning a distribution over our logits. A common choice is a normal distribution: we assume that our logits are the mean of that distribution, and we compute $\mathcal{P}(\bar{y})$ using the probability density function (PDF) of that distribution. The idea is then to maximize $\mathcal{P}(\bar{y})$, which happens when $\bar{y} = y$.

Using the PDF of a normal distribution,
\begin{equation*} \mathcal{P}(x) = \frac{e^{-(\bar{y} - y)^{2}/(2\sigma^{2}) }} {\sigma\sqrt{2\pi}}\end{equation*}
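As a quick check that this density peaks at $\bar{y} = y$, here is a Python sketch of the PDF above with $\sigma = 1$ by default; the target and the candidate model outputs are hypothetical. It also writes out $-\ln \mathcal{P}$, which is the squared error scaled by $1/(2\sigma^2)$ plus a constant, one way to see the link back to MSE.

```python
import math


def normal_pdf(y_bar, y, sigma=1.0):
    """Normal density with mean y (the model output) evaluated at the target y_bar."""
    return math.exp(-((y_bar - y) ** 2) / (2 * sigma ** 2)) / (sigma * math.sqrt(2 * math.pi))


def neg_log_likelihood(y_bar, y, sigma=1.0):
    """-ln P = (y_bar - y)^2 / (2 sigma^2) + ln(sigma * sqrt(2 pi)): squared error plus a constant."""
    return (y_bar - y) ** 2 / (2 * sigma ** 2) + math.log(sigma * math.sqrt(2 * math.pi))


y_bar = 3.0                             # hypothetical target
candidates = [2.0, 2.5, 3.0, 3.5, 4.0]  # hypothetical model outputs
best = max(candidates, key=lambda pred: normal_pdf(y_bar, pred))
print(best)  # 3.0
```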

So our entropy becomes,
\begin{align*} \mathcal{E} &= - \frac{e^{-(\bar{y} - y)^{2}/(2\sigma^{2})}}{\sigma\sqrt{2\pi}}\left(-\frac{(\bar{y} - y)^{2}}{2\sigma^{2}} - \ln(\sigma\sqrt{2\pi})\right) \\
&\propto \frac{e^{-(\bar{y} - y)^{2}/(2\sigma^{2})}}{\sigma\sqrt{2\pi}}(\bar{y} - y)^{2}
\end{align*}

Again, a daunting expression, but it is minimized when $\bar{y} = y$.

So, as we saw in both cases, the point of a cost function is to measure the uncertainty of the model and reduce it.

### Appendix B: Gradients & Gradient Descent
A gradient tells us how a function is changing at a particular point. Key points about gradients:
1. A gradient points in the direction in which the function increases fastest.
2. When the gradient is zero, the function has reached a peak, a trough, or a flat point such as a saddle point.

If we take a small step in the direction given by the gradient, we move to a point where the function value is greater than at the previous position. If we take this step again and again, we reach a maximum of that function. This is the basis of many optimization algorithms: find the gradients of the cost function with respect to the parameters and keep updating the parameters in the direction of those gradients.

#### Gradient Descent:
Gradient Descent is one such optimization algorithm. It has a very simple rule:
\begin{equation}\theta_{i+1} = \theta_i + \alpha\frac{\partial J(\theta)}{\partial \theta}  \end{equation}
where $\alpha$ is the step size which regulates the amount of movement in the direction of gradient.
We use the above update rule when we want to maximize our cost function. To minimize it, we move in the opposite direction of the gradient, so the update rule becomes:
\begin{equation}\theta_{i+1} = \theta_i - \alpha\frac{\partial J(\theta)}{\partial \theta}  \end{equation}
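Both update rules fit in a few lines of Python. This sketch uses made-up one-dimensional functions with known optima so the results are easy to check: a cost $(\theta - 3)^2$ minimized with the descent rule, and a reward $-(\theta + 1)^2$ maximized with the ascent rule, both with $\alpha = 0.1$.

```python
def gradient_step(theta, grad, alpha, maximize=False):
    """One update: theta + alpha*grad to maximize, theta - alpha*grad to minimize."""
    return theta + alpha * grad if maximize else theta - alpha * grad


def run(grad_fn, theta0, alpha=0.1, steps=100, maximize=False):
    theta = theta0
    for _ in range(steps):
        theta = gradient_step(theta, grad_fn(theta), alpha, maximize)
    return theta


# Hypothetical cost J(theta) = (theta - 3)^2, minimized at 3; dJ/dtheta = 2(theta - 3)
theta_min = run(lambda t: 2 * (t - 3), theta0=0.0)
# Hypothetical reward R(theta) = -(theta + 1)^2, maximized at -1; dR/dtheta = -2(theta + 1)
theta_max = run(lambda t: -2 * (t + 1), theta0=5.0, maximize=True)
print(round(theta_min, 4), round(theta_max, 4))  # 3.0 -1.0
```

With this $\alpha$, each step shrinks the distance to the optimum by a factor of 0.8; a much larger $\alpha$ would overshoot and could diverge.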

#### Stochastic Gradient Descent:
As simple as gradient descent seems, implementing it for huge datasets has its own problem. Why? The problem lies in the gradient $\frac{\partial J(\theta)}{\partial \theta}$. Let me explain. For most applications, the cost function is defined something like this:
\begin{equation} J(\theta) = \sum_{i=1}^{n} f(\theta,i)\quad \text{where } n \text{ is the total size of the dataset}\end{equation}
If we had a dataset of 1000 instances, we would compute $f(\theta,i)$ for each instance and sum these to get $J(\theta)$. We then use $J(\theta)$ to find $\frac{\partial J(\theta)}{\partial \theta}$ and perform one optimization step for the entire dataset. Herein lies the problem: *Gradient Descent performs one optimization step per pass over the entire dataset*. With a dataset of 1 billion instances, we would have to iterate over all 1 billion of them before we could move once in the direction of the gradient.

Stochastic Gradient Descent comes to the rescue here. SGD says that we can estimate the true gradient using the gradient at a single instance (or a small batch of instances), and over many such estimates we come pretty close to the true gradient. This lets us update the parameters after every batch instead of after the whole dataset.
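A minimal Python sketch of the difference, assuming a toy linear model $\theta x$ fit to hypothetical noise-free data $y = 2x$: full-batch gradient descent touches every instance before each update, while SGD updates after every mini-batch of 8 shuffled instances. Gradients are averaged over the batch rather than summed as in the equation above; this only rescales the gradient and lets the same step size work for any batch size.

```python
import random

# Hypothetical toy dataset: y = 2x for 101 points in [-1, 1], so the best theta is 2.
DATA = [(i / 50 - 1, 2 * (i / 50 - 1)) for i in range(101)]


def per_sample_grad(theta, x, y):
    # f(theta, i) = (theta * x_i - y_i)^2  =>  df/dtheta = 2 * (theta * x_i - y_i) * x_i
    return 2.0 * (theta * x - y) * x


def batch_grad(theta, batch):
    # Mean of the per-sample gradients in the batch
    return sum(per_sample_grad(theta, x, y) for x, y in batch) / len(batch)


def gradient_descent(data, lr=0.1, steps=300):
    theta = 0.0
    for _ in range(steps):  # every update uses the whole dataset
        theta -= lr * batch_grad(theta, data)
    return theta


def sgd(data, lr=0.1, batch_size=8, epochs=30, seed=0):
    rng = random.Random(seed)
    data = list(data)  # copy so shuffling does not modify the caller's list
    theta = 0.0
    for _ in range(epochs):
        rng.shuffle(data)
        for start in range(0, len(data), batch_size):
            # one update per mini-batch, using a noisy estimate of the full gradient
            theta -= lr * batch_grad(theta, data[start:start + batch_size])
    return theta


print(gradient_descent(DATA), sgd(DATA))
```

Because the toy data has no noise, both end up at $\theta \approx 2$. With noisy data and a constant step size, the SGD estimate would instead keep fluctuating around the minimum.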

Note: Gradient descent moves steadily toward a maximum or minimum (in general, a local one). SGD wanders here and there, but with a suitably chosen step size it will, over time, reach the same place.

### Appendix C: Chain Rule:
You can actually think of the cost function as a function of functions. Using the notation $Z_i$ for the output of node i and treating the output node as node 5, we can say:
\begin{aligned}
\mathcal{J}(\theta) &= f(\bar{y},Z_5)\\
&= f(\bar{y},g(Z_3,Z_4)) \\
&= f(\bar{y},g(h(Z_1,Z_2),h(Z_1,Z_2)))
\end{aligned}


here $Z_1 = \text{ReLU}(I \cdot W_1 + b_1)$ and $Z_2 = \text{ReLU}(I \cdot W_2 + b_2)$


The chain rule helps in finding gradients of composite functions such as $\mathcal{J}(\theta)$. The multivariate chain rule says,
\begin{equation} \frac{\partial}{\partial t}f(x(t),y(t))= \frac{\partial f}{\partial x}\frac{\partial x}{\partial t} + \frac{\partial f}{\partial y}\frac{\partial y}{\partial t}\end{equation}
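A quick numerical check of the multivariate chain rule in Python, on a hypothetical example $f(x, y) = xy$ with $x(t) = \sin t$ and $y(t) = t^2$: the two chain-rule terms should match a finite-difference estimate of the derivative with respect to $t$.

```python
import math


# Hypothetical example: f(x, y) = x * y with x(t) = sin(t) and y(t) = t^2
def f_of_t(t):
    return math.sin(t) * t ** 2


def chain_rule_derivative(t):
    x, y = math.sin(t), t ** 2
    df_dx, df_dy = y, x              # partial derivatives of f(x, y) = x * y
    dx_dt, dy_dt = math.cos(t), 2 * t
    return df_dx * dx_dt + df_dy * dy_dt  # the two terms of the multivariate chain rule


def finite_difference(t, eps=1e-6):
    return (f_of_t(t + eps) - f_of_t(t - eps)) / (2 * eps)


print(chain_rule_derivative(1.3), finite_difference(1.3))
```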

Now we can use the chain rule to find $\frac{\partial \mathcal{J}(\theta)}{\partial W_1}$
\begin{aligned}
\frac{\partial \mathcal{J}(\theta)}{\partial W_1} &= \frac{\partial \mathcal{J}(\theta)}{\partial \bar{y}}\frac{\partial \bar{y}}{\partial W_1} + \frac{\partial \mathcal{J}(\theta)}{\partial Z_5}\frac{\partial Z_5}{\partial W_1} \quad \quad \text{expanding } \frac{\partial Z_5}{\partial W_1}\\
&= 0 + \frac{\partial \mathcal{J}(\theta)}{\partial Z_5} ( \frac{\partial Z_5}{\partial Z_3}\frac{\partial Z_3}{\partial W_1} + \frac{\partial Z_5}{\partial Z_4}\frac{\partial Z_4}{\partial W_1}) \quad \text{expanding } \frac{\partial Z_3}{\partial W_1} \text{ and } \frac{\partial Z_4}{\partial W_1}\\
&= \frac{\partial \mathcal{J}(\theta)}{\partial Z_5}(\frac{\partial Z_5}{\partial Z_3}(\frac{\partial Z_3}{\partial Z_1}\frac{\partial Z_1}{\partial W_1} + \frac{\partial Z_3}{\partial Z_2}\frac{\partial Z_2}{\partial W_1} ) +\frac{\partial Z_5}{\partial Z_4}(\frac{\partial Z_4}{\partial Z_1}\frac{\partial Z_1}{\partial W_1} + \frac{\partial Z_4}{\partial Z_2}\frac{\partial Z_2}{\partial W_1}) )\\
&= \frac{\partial \mathcal{J}(\theta)}{\partial Z_5}(\frac{\partial Z_5}{\partial Z_3}(\frac{\partial Z_3}{\partial Z_1}\frac{\partial Z_1}{\partial W_1} + 0 ) +\frac{\partial Z_5}{\partial Z_4}(\frac{\partial Z_4}{\partial Z_1}\frac{\partial Z_1}{\partial W_1} + 0) ) \quad \text{Since, } \frac{\partial Z_2}{\partial W_1}=0 \\
&= \frac{\partial \mathcal{J}(\theta)}{\partial Z_5}\frac{\partial Z_5}{\partial Z_3}\frac{\partial Z_3}{\partial Z_1}\frac{\partial Z_1}{\partial W_1} + \frac{\partial \mathcal{J}(\theta)}{\partial Z_5}\frac{\partial Z_5}{\partial Z_4}\frac{\partial Z_4}{\partial Z_1}\frac{\partial Z_1}{\partial W_1}
\end{aligned}