## Before you begin

If you're new to calculus or need a refresher of the concepts and formulas, check out the Khan Academy course on YouTube. 

Link: https://youtube.com/playlist?list=PL19E79A0638C8D449

Specifically for derivatives: https://youtu.be/rAof9Ld5sOg

## Simple differentiation

Say we have this function: $y = x^{4} - 5x^{3} + 2x^{2} +4x - 1$

And we want to differentiate it with respect to $x$, that is, find its derivative $\frac{dy}{dx}$.

In [19]:
import jax as J
import jax.numpy as jnp  # used later for log, sin and cos

In [20]:
def y(x):
    return x**4 - 5 * x**3 + 2 * x**2 + 4 * x - 1

Why a function? Good question. JAX was designed with functional programming in mind, hence this function stuff. You can also use lambdas. For example, this can also be written as:

```python
y = lambda x: x**4 - 5 * x**3 + 2 * x**2 + 4 * x - 1
```

In [21]:
dy_dx = J.grad(y)
dy_dx

<function __main__.y(x)>

Where's the derivative? Huh Charlston! Where's the derivative? 

Well, `J.grad` hands you back a function, and you didn't tell me at which value of $x$ you need the derivative of $y$!

In [22]:
dy_dx(2.) # make sure to pass in floats, that's how JAX likes it!

DeviceArray(-16., dtype=float32, weak_type=True)

Let's verify that, shall we? 

$\frac{dy}{dx} = 4x^{3} - 15x^{2} + 4x + 4$

So for $x = 2.0$, or say $x = 2$:

$\frac{dy}{dx} = 4x^{3} - 15x^{2} + 4x + 4 = 4 \times 2^{3} - 15 \times 2^{2} + 4 \times 2 + 4 = -16$
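If you'd rather let the machine do the checking, here is a minimal sketch that compares `J.grad(y)` with the formula we just worked out by hand. `dy_dx_by_hand` and `int_rejected` are made-up names for illustration. The last part also shows why the float comment above matters: as far as I know, `grad` raises a `TypeError` when it gets an integer input.

```python
import jax as J

def y(x):
    return x**4 - 5 * x**3 + 2 * x**2 + 4 * x - 1

def dy_dx_by_hand(x):
    # hypothetical helper: the derivative worked out by hand above
    return 4 * x**3 - 15 * x**2 + 4 * x + 4

dy_dx = J.grad(y)

# compare autodiff with the hand-derived formula at a few points
for x in (0.0, 1.0, 2.0, 3.5):
    print(x, float(dy_dx(x)), dy_dx_by_hand(x))

# integer inputs are rejected by grad, so this lands in the except branch
try:
    dy_dx(2)
    int_rejected = False
except TypeError:
    int_rejected = True
print("int rejected:", int_rejected)
```

The two columns should agree up to float32 rounding.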


## Higher order differentiation

What we saw above was the first order derivative of $y$ with respect to $x$. A very simple definition: differentiate once and you get the first order derivative, differentiate twice and you get the second order derivative, and so on. Anything above the first order is called a higher order derivative.

Let's try the second and third order derivatives:

$\frac{d}{dx}(\frac{dy}{dx}) = \frac{d}{dx}(4x^{3} - 15x^{2} + 4x + 4) = 12x^{2} - 30x + 4$

This second order derivative, $\frac{d}{dx}(\frac{dy}{dx})$, can also be written as $\frac{d^{2}y}{dx^{2}}$ or simply ${y}''$. Guess what the first order derivative is then? ${y}'$.

So, differentiating once more gives the third order derivative:

${y}''' = 24x - 30$

Fun fact: for a fourth order derivative you don't have to pile up four ' signs. You can replace them with a number and write something like $y_{4}$ or $y^{(4)}$ or $y^{iv}$.

In [23]:
y2 = J.grad(dy_dx)
y2(2.)

DeviceArray(-8., dtype=float32, weak_type=True)

In [24]:
y3 = J.grad(y2)
y3(1.)

DeviceArray(-6., dtype=float32, weak_type=True)

In [25]:
y4 = J.grad(y3)
y4(0.)

DeviceArray(24., dtype=float32, weak_type=True)

In [26]:
y5 = J.grad(y4)
y5(1.)

DeviceArray(0., dtype=float32, weak_type=True)

Why's this 0? Like all those books out there, this has been left as an exercise for the reader. 
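Calling `J.grad` over and over gets repetitive. A minimal sketch of a helper that applies it `n` times could look like this. `nth_derivative` is a made-up name for illustration, not something JAX ships with.

```python
import jax as J

def y(x):
    return x**4 - 5 * x**3 + 2 * x**2 + 4 * x - 1

def nth_derivative(fn, n):
    # hypothetical helper: wrap fn in J.grad n times
    for _ in range(n):
        fn = J.grad(fn)
    return fn

y2 = nth_derivative(y, 2)  # same as J.grad(J.grad(y))
y3 = nth_derivative(y, 3)
y4 = nth_derivative(y, 4)

print(y2(2.0), y3(1.0), y4(0.0))
```

These should match the values from the cells above for the second, third and fourth order derivatives.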

## Partial Derivatives And Multivariate Differentiation

So far the derivatives have been in terms of only one variable, $x$. What if you have multiple variables? For example:
$f(x, y, z) = 3x + 2z + 3xy + 5zx$

--------
**Note:**  
$y = f(x) = \dots$
`y` was defined above (in the code) as a function with only one parameter, `x`. By default JAX takes derivatives with respect to the first parameter, but you can change that with `argnums`.

------------------------------
The partial derivatives of $f$ here are:

$\frac{\partial f}{\partial x}$, $\frac{\partial f}{\partial y}$, $\frac{\partial f}{\partial z}$

In [27]:
# define the function f
def f(x, y, z):
    return 3*x + 2*z + 3*x* y + 5*z*x

In [28]:
f1_y = J.grad(f, argnums=1) # in terms of the second param, y
f1_y(1., 2., 3.)

DeviceArray(3., dtype=float32, weak_type=True)

In [29]:
f1_z = J.grad(f, argnums=2) # in terms of z
f1_z(1., 2., 3.)

DeviceArray(7., dtype=float32, weak_type=True)

What about $x$? Do we leave it hanging there? Okay...

In [30]:
# by default argnums is 0 so even if you don't 
# mention anything it will use x

f1_x = J.grad(f, argnums=0)
f1_x(1., 2., 3.)

DeviceArray(24., dtype=float32, weak_type=True)
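You don't have to build one gradient function per variable. `argnums` also accepts a tuple, in which case the returned function gives back one partial derivative per listed argument. A small sketch, where `all_partials` is just an illustrative name:

```python
import jax as J

def f(x, y, z):
    return 3*x + 2*z + 3*x*y + 5*z*x

# ask for the partials with respect to x, y and z in one go
all_partials = J.grad(f, argnums=(0, 1, 2))

df_dx, df_dy, df_dz = all_partials(1., 2., 3.)
print(df_dx, df_dy, df_dz)
```

By hand, $\frac{\partial f}{\partial x} = 3 + 3y + 5z$, $\frac{\partial f}{\partial y} = 3x$ and $\frac{\partial f}{\partial z} = 2 + 5x$, which at $(1, 2, 3)$ gives the same 24, 3 and 7 as the separate cells above.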

## Chain Rule

So there can be multiple variables. What if they're linked? Like one depends on the other? Like a chain? 

![./images/godzilla.jpg](./images/godzilla.jpg)


Actually it's not that difficult. You just have to understand how some variables can depend on each other.

Let's think in terms of functions, since JAX looooves functions. 

However, I want to take a different approach, enough with x y and z's. Let's say I want to cook some noodles. Nothing fancy, just some good noodles boiled up, tossed into a pan with diced potatoes, onions and fried eggs. Smells good. 


So first you have to `boil()` the noodles. Then you need to prepare the pan with oil and `fry()` the eggs, then you need to pray to Satan and throw all the noodles into the pan.

So in a functional way, you can say it like this: `fry(boil(noodles), eggs, offerings_to_satan)`

But wait, that's a function inside a function! Oh no! But anyway. Functions like this, where the output of one function is fed into another, are called composite functions ([more on this here](https://www.khanacademy.org/math/ap-calculus-ab/ab-differentiation-2-new/ab-3-1a/a/chain-rule-review)). And when you try to differentiate them, you need the chain rule.

![why](./images/giphy.gif)

Differentiation has a simple goal: to find out how much effect the tiniest change in one thing has on another. Making cookies? How much does a small change in sugar change the taste? Baking potatoes? How badly can a slight change in temperature ruin them for you? Real-life problems are far more intricate than this, but let's not get into those.

So when you get composite functions which depend on each other, like you can't fry without boiling the noodles first, you need to use the chain rule to differentiate and observe the effects. Enough chit chat and bad examples, we'll dive into code now.

In [37]:
def some_composite_function(x):
    return jnp.log(
        jnp.sin(
            jnp.cos(x)
        )
    )

# from a naive point of view, if you call composites nested functions, you may not be altogether wrong!

In [38]:
d_composite = J.grad(some_composite_function)
d_composite(45.)

DeviceArray(-1.4679605, dtype=float32, weak_type=True)

Seems complicated? Let's break them up and then see how it goes.

In [39]:
def a(x):
    return jnp.log(x)

def b(x):
    return jnp.sin(x)

def c(x):
    return jnp.cos(x)

def comp(x):
    return a(b(c(x)))

So what's happening here is that `a` relies on the output of `b`, and `b` relies on the output of `c`. And when you run grad,

In [40]:
d_c = J.grad(comp)
d_c(45.)

DeviceArray(-1.4679605, dtype=float32, weak_type=True)
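To see the chain rule itself at work, you can multiply the three individual derivatives yourself. For this composite the rule says $\frac{d}{dx} a(b(c(x))) = a'(b(c(x))) \cdot b'(c(x)) \cdot c'(x)$. The sketch below evaluates each link separately and compares the product with what `J.grad(comp)` returns. The names `inner`, `middle`, `chain` and `direct` are only for illustration.

```python
import jax as J
import jax.numpy as jnp

def a(x):
    return jnp.log(x)

def b(x):
    return jnp.sin(x)

def c(x):
    return jnp.cos(x)

def comp(x):
    return a(b(c(x)))

x = 45.0            # radians, since jnp.cos and jnp.sin work in radians
inner = c(x)        # cos(x)
middle = b(inner)   # sin(cos(x))

# each link of the chain, evaluated at the value it actually receives
chain = J.grad(a)(middle) * J.grad(b)(inner) * J.grad(c)(x)

# what JAX gives when differentiating the whole composite at once
direct = J.grad(comp)(x)

print(chain, direct)
```

Both numbers should agree up to float32 rounding, which is the chain rule doing its job behind the scenes.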

Where's the effect though? I understand that this isn't enough to give you a feeling for where the effect actually lands, but if you keep your patience until we check backpropagation on a learning algorithm (stop calling these glorious learning functions AI, they're not!), you'll see it for yourself!