# How to Use Autograd in JAX

In [1]:
import jax.numpy as jnp
from jax import random, tree, grad, jit
from plugins.minitorch.utils import softmax, cross_entropy_loss, one_hot

key = random.PRNGKey(42)

How we use `grad` determines how we manage our trainable parameters.  
Fortunately, JAX makes computing gradients very convenient: we can pass a (nested) dict of parameters to `grad`, and it returns a dict of gradients with the same structure.
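As a minimal illustration (a toy function chosen for this sketch, not part of the MLP below), `jax.grad` applied to a function of a nested dict returns gradients under the same keys, so each one can be looked up by name. The analytic derivatives are easy to check by hand.

```python
import jax
import jax.numpy as jnp

# Toy parameter dict; any nested dict of arrays is a valid JAX pytree.
params = {'w': jnp.array([1.0, 2.0, 3.0]), 'b': jnp.array(0.5)}

def f(params):
    # f = sum(w^2) + 3*b, so df/dw = 2*w and df/db = 3
    return jnp.sum(params['w'] ** 2) + 3.0 * params['b']

grads = jax.grad(f)(params)
print(grads['w'])  # [2. 4. 6.]
print(grads['b'])  # 3.0
```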

In [2]:
params = {
    'fc:0': {
        'w': jnp.ones((2, 3)),
        'b': jnp.ones((3,)),
    },
    'fc:1': {
        'w': jnp.ones((3, 4)),
        'b': jnp.ones((4,)),
    },
}

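key_x, key_y = random.split(key)  # independent keys for inputs and labels
x_train = random.normal(key_x, (1000, 2))
y_train = random.randint(key_y, (1000, 1), 0, 4)  # maxval is exclusive: labels 0..3
y_train = one_hot(y_train, 4)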
print(f' * shape of x_train: {x_train.shape}')
print(f' * shape of y_train: {y_train.shape}')

params_shape = tree.map(lambda x: str(x.shape), params)

import json
print(' * shape of params dict')
print(json.dumps(params_shape, indent=4))

 * shape of x_train: (1000, 2)
 * shape of y_train: (1000, 4)
 * shape of params dict
{
    "fc:0": {
        "b": "(3,)",
        "w": "(2, 3)"
    },
    "fc:1": {
        "b": "(4,)",
        "w": "(3, 4)"
    }
}


Notice a small change: the keys 'w' and 'b' have swapped places. This is not caused by multi-threading. JAX treats nested dicts as pytrees, and when it flattens a dict it visits the keys in sorted order; `tree.map` then rebuilds the dict in that same sorted order, so 'b' now comes before 'w'. Always access values by key rather than relying on dict order, and avoid patterns like: 

```python
w, b = params.values()
```

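Otherwise this small detail can silently become a serious bug: applied to a dict returned by `tree.map` or `grad`, the line above would bind `w` to the bias and `b` to the weight matrix.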
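The ordering rule can be checked directly. This sketch uses `jax.tree.leaves` and `jax.tree.map` (the same `jax.tree` module imported at the top of the notebook, available in recent JAX releases): leaves come out in sorted-key order regardless of insertion order, which is why positional unpacking is fragile and key lookup is safe.

```python
import jax
import jax.numpy as jnp

# Insert 'w' before 'b' on purpose.
layer = {'w': jnp.ones((2, 3)), 'b': jnp.zeros((3,))}

# Flattening a dict visits its values in sorted-key order: 'b' first, then 'w'.
leaves = jax.tree.leaves(layer)
print([leaf.shape for leaf in leaves])  # [(3,), (2, 3)]

# The dict rebuilt by tree.map follows that sorted order (as in the output above),
# so positional unpacking like `w, b = doubled.values()` would mix up the arrays.
doubled = jax.tree.map(lambda x: 2 * x, layer)

# Safe: access by key, which does not depend on ordering.
w, b = doubled['w'], doubled['b']
print(w.shape, b.shape)  # (2, 3) (3,)
```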

In [3]:
def forward(x, params):
    res = x @ params['fc:0']['w'] + params['fc:0']['b']
    res = jnp.maximum(0, res)
    res = res @ params['fc:1']['w'] + params['fc:1']['b']
    res = jnp.maximum(0, res)

    return softmax(res)

_loss = lambda params: cross_entropy_loss(y_train, forward(x_train, params))
_loss = jit(_loss)

In [5]:
import time

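s = time.time()
grad_res = grad(_loss, argnums=0)(params)
# Rough wall-clock figure only: on a first run this call also traces and compiles `_loss`,
# and JAX dispatches work asynchronously; call `.block_until_ready()` on results for precise timing.
print(f' * time cost: {time.time() - s} s')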

grad_res_shape = tree.map(lambda x: str(x.shape), grad_res)
print(json.dumps(grad_res_shape, indent=4))

 * time cost: 0.003015279769897461 s
{
    "fc:0": {
        "b": "(3,)",
        "w": "(2, 3)"
    },
    "fc:1": {
        "b": "(4,)",
        "w": "(3, 4)"
    }
}


Here is a very simple MLP example. As you can see, we get a gradient dict with the same structure as the trainable parameters we initialized earlier, and we can feed it to gradient-descent optimizers such as SGD or Adam. Easy, right?  
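To make "feed it to SGD" concrete, here is a self-contained sketch of a plain SGD loop on the same two-layer structure. Assumptions: it uses `jax.nn.log_softmax` and `jax.nn.one_hot` instead of the minitorch helpers, feeds raw logits to log-softmax, uses hypothetical quadrant labels as data, a random initialization, and an arbitrary learning rate of 0.1; `loss_fn` and `sgd_step` are illustrative names, not minitorch APIs.

```python
import jax
import jax.numpy as jnp

def forward(x, params):
    h = jax.nn.relu(x @ params['fc:0']['w'] + params['fc:0']['b'])
    return h @ params['fc:1']['w'] + params['fc:1']['b']  # raw logits

def loss_fn(params, x, y):
    # Mean cross-entropy against one-hot targets y.
    log_probs = jax.nn.log_softmax(forward(x, params))
    return -jnp.mean(jnp.sum(y * log_probs, axis=-1))

@jax.jit
def sgd_step(params, x, y, lr):
    # Gradient w.r.t. the first argument (the params dict).
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # p <- p - lr * g for every leaf; tree.map pairs params and grads by key.
    new_params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

k_x, k0, k1 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k_x, (512, 2))
# Hypothetical 4-class labels: one class per quadrant of the input plane.
labels = (x[:, 0] > 0).astype(jnp.int32) + 2 * (x[:, 1] > 0).astype(jnp.int32)
y = jax.nn.one_hot(labels, 4)

params = {
    'fc:0': {'w': 0.5 * jax.random.normal(k0, (2, 3)), 'b': jnp.zeros((3,))},
    'fc:1': {'w': 0.5 * jax.random.normal(k1, (3, 4)), 'b': jnp.zeros((4,))},
}

initial_loss = loss_fn(params, x, y)
for step in range(200):
    params, loss = sgd_step(params, x, y, 0.1)
print(f'loss: {float(initial_loss):.3f} -> {float(loss):.3f}')
```

Because the update is a single `tree.map` over two dicts with identical structure, it never depends on key order; an optimizer like Adam would follow the same pattern with extra state dicts for the moment estimates.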
But this is still not what we want: initializing and optimizing parameters by hand like this quickly becomes complicated. So we apply a Pipeline Pattern to make this procedure easier for users to manage:  

<p align="center">
  <img src="../assets/notebook_docs/minitorch.svg" alt="Overview of framework" width="50%">
</p>

<p align="center">
Overview of Framework
</p>

In the figure, the red line represents the parameter-initialization process:
- Step 1: take the user's input arguments, such as 'input dim', 'output channel', ..., and use the parameterizer in `nn.layers` to convert them into a dict;
- Step 2: pass this dict to the Initer, which filters it down to the dict of trainable parameters;
- Step 3: create JAX arrays to hold the trainable parameters and return the actual trainable-parameter dict, just like the one we built by hand above.
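The three steps can be sketched in plain JAX as follows. Everything here is hypothetical: `linear_parameterizer`, `filter_trainable`, `init_trainable`, the spec format (shape, initializer name, `trainable` flag) and the `train_bias` argument are illustrative stand-ins, not the actual `nn.layers` parameterizer or Initer API of minitorch.

```python
import jax
import jax.numpy as jnp

# Step 1 (hypothetical parameterizer): user arguments -> spec dict.
def linear_parameterizer(name, in_dim, out_dim, train_bias=True):
    return {name: {
        'w': {'shape': (in_dim, out_dim), 'init': 'normal', 'trainable': True},
        'b': {'shape': (out_dim,), 'init': 'zeros', 'trainable': train_bias},
    }}

# Step 2 (hypothetical Initer filter): keep only entries marked trainable.
def filter_trainable(spec):
    return {layer: {k: v for k, v in entries.items() if v['trainable']}
            for layer, entries in spec.items()}

# Step 3: allocate a JAX array for every remaining entry.
def init_trainable(spec, key):
    params = {}
    for layer, entries in spec.items():
        params[layer] = {}
        for name, s in entries.items():
            key, sub = jax.random.split(key)  # fresh subkey per array
            if s['init'] == 'normal':
                params[layer][name] = 0.1 * jax.random.normal(sub, s['shape'])
            else:
                params[layer][name] = jnp.zeros(s['shape'])
    return params

spec = {**linear_parameterizer('fc:0', 2, 3), **linear_parameterizer('fc:1', 3, 4)}
params = init_trainable(filter_trainable(spec), jax.random.PRNGKey(0))
print(jax.tree.map(lambda x: str(x.shape), params))
```

The result has the same nested layout as the hand-written `params` dict, so it can be passed straight to `grad`. Passing `train_bias=False` would drop that layer's bias in step 2, so it never becomes an array that `grad` differentiates.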