# 🥷 Makemore: becoming a backprop ninja

Hey there! 🙋‍♂️<br>
Here we are again, this time with **a serious challenge**. We will get inside backprop and **learn by doing** it.

**Some things to consider**:
- Here I will use Karpathy's code, but we will fill in the missing parts on our own *(of course)*
- There will be other explanations in the book to help understand backprop

Let's get started, looks promising 🤞

> **NOTE**: Our makemore was *supposed to* handle a total of `28` different characters (`A-Z`, `<`, `>`), while Karpathy's version supports `27`. We will not change that and will **keep** it at `27`, because the main goal of this notebook is to ***understand*** the internals of backprop, not to tweak the code.

<img src="./images/hero.png" height=500 width=1000>

# `1.` Loading & Creating the Dataset 

## `1.1` Creating the character mappings

In [2]:
import torch
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
# read in all the words
words = open('names.txt', 'r').read().splitlines()
print(len(words))
print(max(len(w) for w in words))
print(words[:8])

32033
15
['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']


In [4]:
### Build the vocabulary of characters and mappings to/from integers

# Total possible characters
chars = sorted(list(set(''.join(words))))

# 'a' -> 1
string_to_int = {s:i+1 for i,s in enumerate(chars)}

# add special character at 0
string_to_int['.'] = 0

# 1 -> 'a'
int_to_string = {i:s for s,i in string_to_int.items()}

vocab_size = len(int_to_string)
print(vocab_size)

27


In [5]:
int_to_string

{1: 'a',
 2: 'b',
 3: 'c',
 4: 'd',
 5: 'e',
 6: 'f',
 7: 'g',
 8: 'h',
 9: 'i',
 10: 'j',
 11: 'k',
 12: 'l',
 13: 'm',
 14: 'n',
 15: 'o',
 16: 'p',
 17: 'q',
 18: 'r',
 19: 's',
 20: 't',
 21: 'u',
 22: 'v',
 23: 'w',
 24: 'x',
 25: 'y',
 26: 'z',
 0: '.'}

## `1.2` Creating the dataset 

In [6]:
### Build the dataset
block_size = 3 # context length: how many characters do we take to predict the next one?

def build_dataset(words):
    X, Y = [], []

    for w in words:
        context = [0] * block_size
        for ch in w + '.':
            ix = string_to_int[ch]
            X.append(context)
            Y.append(ix)
            context = context[1:] + [ix]  # crop and append

    X = torch.tensor(X)
    Y = torch.tensor(Y)
    print(X.shape, Y.shape)
    return X, Y

import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

Xtr,  Ytr  = build_dataset(words[:n1])     # 80%
Xdev, Ydev = build_dataset(words[n1:n2])   # 10%
Xte,  Yte  = build_dataset(words[n2:])     # 10%

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


## `1.3` Utility function to check the fail or pass

In [7]:
# utility function we will use later when comparing manual gradients to PyTorch gradients
def cmp(s, dt, t):
    '''
    s: Name of the process
    dt: The grad calculated by us
    t: The torch's grad
    
    This function checks:
    1. Are ALL of our grads EXACTLY the same as torch's?
    2. Are they "close" enough to be called the same?
    3. What is the maximum absolute difference (ours - torch)?
    '''
    
    ex = torch.all(dt == t.grad).item()        # 1.
    app = torch.allclose(dt, t.grad)           # 2.
    maxdiff = (dt - t.grad).abs().max().item() # 3.
    print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')
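
To see how `cmp` is meant to be used, here is a minimal sketch on a toy problem. The tensor `x` and the function `y = (x**2).sum()` are made up for illustration, and `cmp` is repeated only so that the snippet runs on its own:

```python
import torch

def cmp(s, dt, t):
    # same logic as the utility above, repeated so this snippet is self-contained
    ex = torch.all(dt == t.grad).item()
    app = torch.allclose(dt, t.grad)
    maxdiff = (dt - t.grad).abs().max().item()
    print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')

x = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)  # toy input (hypothetical values)
y = (x ** 2).sum()
y.backward()             # PyTorch fills x.grad

dx = 2 * x.detach()      # our manual gradient: d(x^2)/dx = 2x
cmp('x', dx, x)
```

The first argument is only a label for the printed row, the second is "our" gradient and the third is the tensor whose `.grad` PyTorch has filled in.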

# `2.` Creating the network 

👉 The code is the same as in the previous notebooks, **but** the calculation **in the last layer** will change.

In [8]:
n_embd = 10 # the dimensionality of the character embedding vectors
n_hidden = 64 # the number of neurons in the hidden layer of the MLP

g = torch.Generator().manual_seed(2147483647)       # for reproducibility
C  = torch.randn((vocab_size, n_embd), generator=g) # embeddings

### Layer 1
W1 = torch.randn((n_embd * block_size, n_hidden), 
                 generator=g) * (5/3)/((n_embd * block_size)**0.5) # recall gain / sqrt(fan_in)
b1 = torch.randn(n_hidden, 
                 generator=g) * 0.1 # using b1 just for fun, it's useless because of BN


### Layer 2
W2 = torch.randn((n_hidden, vocab_size), 
                 generator=g) * 0.1
b2 = torch.randn(vocab_size,
                 generator=g) * 0.1

### BatchNorm parameters
bngain = torch.randn((1, n_hidden)) * 0.1 + 1.0
bnbias = torch.randn((1, n_hidden)) * 0.1

# Note: I am initializing many of these parameters in non-standard ways
# because sometimes initializing with e.g. all zeros could mask an incorrect
# implementation of the backward pass.

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
    p.requires_grad = True

4137


## 🧠 We need to "wrap our head" around the NN 

<img src="./images/NN_1.png">

▶ Here we **just have** the nodes and "weights"... next up we will **add the biases**.

<img src="./images/NN_2.png">

▶ And finally we will do the **scaling** and **shifting**

<img src="./images/NN_3.png">

▶ Alright, now we know what is happening in the net, *at least visually*... we can have a bit better confidence moving forward.

# `3.` Prepare training variables

##  `3.1` Prepare training batch

In [9]:
batch_size = 32
n = batch_size # a shorter variable also, for convenience

# construct a minibatch
ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y

In [10]:
Xb.shape, Yb.shape

(torch.Size([32, 3]), torch.Size([32]))

## `3.2` Ignite the forward pass

This is the 🧊 **coolest** part of this notebook. 
Here we will *break* the steps down into **manageable** chunks so that we can perform the backward pass.

I will *introduce* each part slowly to digest it properly. Let's begin.

### 👉 Have a look at the `original` code 

```python
'''
The following code is what we have used in the last notebook.
Should not make you freak out ;)
'''
losses = []
batch_size = 32 
epochs = 20_000

for i in range(epochs):
    # Data Sampling (FROM XTRAIN ONLY)
    sample_idx = torch.randint(0, Xtrain.shape[0], (batch_size,), generator=generator)
    
    # 1️⃣ Forward pass
    emb = embeddings[Xtrain[sample_idx]] 
    preact = emb.view(-1, block_size * emb_dim) @ W1 + b1 # separated
    preact_mean = preact.mean(0, keepdim=True)   # preact mean
    preact_std = preact.std(0, keepdim=True)     # preact std
    preact = (preact - preact_mean) / preact_std # standardization
    preact = (scaler * preact) + shifter         # scaling and shifting
    
    h = torch.tanh(preact)
    logits = h @ W2 + b2 
    loss = cross_entropy(logits, ytrain[sample_idx]) 
    
    # 2️⃣ Backward
    for p in parameters:
        p.grad = None
    loss.backward()
    
    # 3️⃣ Update - with decay
    learning_rate = 0.1 if i < 10_000 else 0.01
    for p in parameters:
        p.data += -learning_rate * p.grad
        
    losses.append(loss.item()) # for better visualization
    
plt.plot(losses, lw=0.1)
plt.title("Training log-loss");
```

### 👉 New `expanded` code

### `3.2.1` Transform the sampled data 

In [11]:
# forward pass, "chunkated" into smaller steps that are possible to backward one at a time

emb = C[Xb] # embed the characters into vectors
embcat = emb.view(emb.shape[0], -1) # concatenate the vectors

Thus now we have:

In [12]:
# 32 samples, each having 30 inputs (x1 to x30... keep the network in mind)
embcat.shape

torch.Size([32, 30])

### `3.2.2` Perform the matmul
> Now we are in the 1️⃣ code block of the "old" code.

In [13]:
# Linear layer 1
hprebn = embcat @ W1 + b1 # hidden layer pre-activation

### `3.2.3` Perform the batchnorm

☑ Find the constants

In [14]:
### BatchNorm layer ###

# 👉 Find the mean
bnmeani = 1 / n * hprebn.sum(0, keepdim=True)
# The above is equivalent to: `hprebn.sum(0) / n` ie. simply the mean :)

# 👉 Find the difference
bndiff = hprebn - bnmeani

# 👉 Squaring the difference to find variance below
bndiff2 = bndiff**2

# 👉 Finding the variance
bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)
# The above is equivalent to: `bndiff2.sum(0, keepdim=True) / (n-1)` ie. simply the variance :)

☑ Perform the standardization and so on

In [15]:
# 👉 Take the inverse square root of the variance (i.e. 1 / std) so that it can be "multiplied"
bnvar_inv = (bnvar + 1e-5)**-0.5
## Note: 1e-5 is a small constant to avoid division by zero

# 👉 Finally standardize the numbers
bnraw = bndiff * bnvar_inv

# 👉 Give the spice to it by gain and bias
hpreact = bngain * bnraw + bnbias

🎉 Yo! The batchnorm is done! **How simple that was!** <br>
Before we did that by:
```python
preact_mean = preact.mean(0, keepdim=True)   # preact mean
preact_std = preact.std(0, keepdim=True)     # preact std
preact = (preact - preact_mean) / preact_std # standardization
preact = (scaler * preact) + shifter         # scaling and shifting
```
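
To convince ourselves that the expanded steps really compute the same thing as the compact version, here is a small sketch on random data. The tensor `x` is a made-up stand-in for `hprebn` and the seed is arbitrary. The only difference between the two versions is the `1e-5` epsilon, so the gap should be tiny:

```python
import torch

g = torch.Generator().manual_seed(0)      # arbitrary seed, just for reproducibility
x = torch.randn(32, 64, generator=g)      # stand-in for hprebn (32 samples, 64 neurons)
n = x.shape[0]

# compact version (no epsilon; torch.std uses Bessel's correction by default)
compact = (x - x.mean(0, keepdim=True)) / x.std(0, keepdim=True)

# expanded version, mirroring the cells above
mean = 1 / n * x.sum(0, keepdim=True)
diff = x - mean
var = 1 / (n - 1) * (diff ** 2).sum(0, keepdim=True)
var_inv = (var + 1e-5) ** -0.5
expanded = diff * var_inv

max_gap = (compact - expanded).abs().max().item()
print(max_gap)   # very small, caused only by the epsilon
```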

### `3.2.4` Apply an activation function

In [16]:
# Non-linearity
h = torch.tanh(hpreact) # hidden layer

### `3.2.5` Matmul for the final layer

In [17]:
### Linear layer 2
logits = h @ W2 + b2 # output layer

### `3.2.6` The cross entropy loss!
> We are back with the softmax!

***(Want some refresher? Go to `03 - Makemore MLP/Make more with MLP.ipynb` and scroll to topic: "🔥 Boom Boom 🔥")***

In [18]:
### cross entropy loss (same as F.cross_entropy(logits, Yb))

# 👉 Find the largest logit for each of the 32 samples
logit_maxes = logits.max(1, keepdim=True).values
## Now shape is (32, 1)

norm_logits = logits - logit_maxes # subtract max for numerical stability
counts = norm_logits.exp()

In the code above, **we could have skipped the step** of finding the max logit and subtracting it from all logits, since it does not change the resulting probabilities. **We do it** to avoid numerical overflow in `.exp()`. I think we have discussed this somewhere before, but I can't recall where. 

> 💡 For now, just understand that "*subtracting the max logit from all logits will make the max logit `0` and others negative which will help `.exp()` not to introduce very large numbers*".
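
A minimal sketch of the overflow problem, using a made-up row of logits with one large value. In `float32`, `exp(100)` is too big to represent, so the naive version breaks while the shifted version stays fine:

```python
import torch

logits = torch.tensor([[1.0, 2.0, 100.0]])   # made-up row with one large logit

# naive softmax: exp(100) overflows float32 to inf, and inf / inf gives nan
naive_counts = logits.exp()
naive_probs = naive_counts / naive_counts.sum(1, keepdim=True)

# stable softmax: subtract the row max first, so the largest term is exp(0) = 1
shifted = logits - logits.max(1, keepdim=True).values
stable_counts = shifted.exp()
stable_probs = stable_counts / stable_counts.sum(1, keepdim=True)

print(naive_probs)    # contains nan
print(stable_probs)   # valid probabilities that sum to 1
```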


In [19]:
# Since we have the counts, we can now get the probabilities
counts_sum = counts.sum(1, keepdims=True)

# 👉 We want to divide each count by the sum of its row...
counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...

# 👉 ...but instead of a direct division we multiply by the inverse, to match the grads better
probs = counts * counts_sum_inv

# 👉 Finally, take the log of the probs
logprobs = probs.log()

# 👉 Find the loss by picking out the log-probs of each target character
loss = -logprobs[range(n), Yb].mean()

In [20]:
# Plain loss
loss

tensor(3.3426, grad_fn=<NegBackward0>)
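
As a sanity check of the claim that these steps are the same as `F.cross_entropy`, here is a small sketch on random stand-in data. The `logits` and `targets` below are made up and are not the notebook's batch:

```python
import torch
import torch.nn.functional as F

g = torch.Generator().manual_seed(0)                  # arbitrary seed
logits = torch.randn(32, 27, generator=g)             # stand-in logits
targets = torch.randint(0, 27, (32,), generator=g)    # stand-in for Yb
n = logits.shape[0]

# the same chain of small steps as in the cells above
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim=True)
probs = counts * counts_sum ** -1
logprobs = probs.log()
manual_loss = -logprobs[range(n), targets].mean()

builtin_loss = F.cross_entropy(logits, targets)
print(manual_loss.item(), builtin_loss.item())
```

Both numbers should agree up to floating point noise.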

## `3.3` Perform the backward pass! 🔥
> We are now in the 2️⃣ code block of our training code.

In [21]:
# PyTorch backward pass
for p in parameters:
    p.grad = None
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, # afaik there is no cleaner way
          norm_logits, logit_maxes, logits, h, hpreact, bnraw,
          bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,
          embcat, emb]:
    t.retain_grad()

loss.backward()
loss

tensor(3.3426, grad_fn=<NegBackward0>)
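
Why all the `retain_grad()` calls? By default PyTorch only keeps `.grad` for leaf tensors (like our parameters) and not for intermediate results. A tiny sketch with made-up values:

```python
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)   # leaf tensor (toy values)
mid_kept = x * 3          # intermediate tensor whose gradient we ask PyTorch to keep
mid_kept.retain_grad()
mid_dropped = x * 3       # same computation, but without retain_grad()
loss = (mid_kept + mid_dropped).sum()
loss.backward()

print(mid_kept.grad)      # tensor([1., 1.])
print(x.grad)             # tensor([6., 6.])
```

Only `mid_kept` has its gradient stored after `backward()`; the gradient of `mid_dropped` is used internally and then thrown away. Since we want to compare our manual gradient against PyTorch's for *every* intermediate variable, each of them needs `retain_grad()`.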

# `4.` Backward by us 😎

The section header:

    # Exercise 1: backprop through the whole thing manually,
    backpropagating through exactly all of the variables
    as they are defined in the forward pass above, one by one

## `4.1` Derivative of `logprobs` wrt `loss`

> 💭 Which is *"how does each value of `logprobs` impact the `loss`?"*

We can see that the `loss` is simply the **mean** of the "plucked out" logs of the next token. Thus here only `32` tokens participate in calculating the loss.

Working for this line:
```python
loss = -logprobs[range(n), Yb].mean()
```

Which goes through the following process: 
- **Addition** of all `32` picked logs (indices 0 to 31)
- **Division** by 32

Thus, a total of `2` operations are carried out *(plus the sign flip)*. If you recall from *micrograd*, we found the gradient like below:

In [22]:
# Create a function which resembles the operation
def f(a, b, c):
    '''here taking only 3 numbers instead of 32'''
    return -(a+b+c)/3

# inputs
a = 2.0
b = 3.0
c = 10.0
# a small change
h_ = 0.0001

old = f(a, b, c)
a += h_
new = f(a, b, c)

# slope
(new - old) / h_

-0.33333333333551707

Which is around `-1/3`.

In [23]:
def f(a, b, c, d, e):
    '''here taking only 5 numbers instead of 32'''
    return -(a+b+c+d+e)/5

# inputs
a, b, c, d, e = 2.0, 3.0, 10.0, 55.0, 0.44
old = f(a, b, c, d, e)
a += h_
new = f(a, b, c, d, e)

# slope
(new - old) / h_

-0.20000000001019203

Which is around `-1/5`.

> 💭 **Which means** the slope is `-1/n`. And for our case the slope will be `-1/32`.
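
The nudging we did by hand can be wrapped in a tiny helper. This is only a sketch: `slope` and `neg_mean` are hypothetical names that do not exist anywhere else in the notebook.

```python
def slope(f, args, i, h=1e-4):
    """Estimate d f / d args[i] with a forward difference (hypothetical helper)."""
    nudged = list(args)
    nudged[i] += h                       # nudge only the i-th input
    return (f(*nudged) - f(*args)) / h

def neg_mean(*xs):
    # the same shape of computation as the loss: negative mean of the inputs
    return -sum(xs) / len(xs)

vals = [2.0, 3.0, 10.0, 55.0, 0.44]
slopes = [slope(neg_mean, vals, i) for i in range(len(vals))]
print(slopes)   # every entry is close to -1/5
```

Every input gets the same slope, no matter what its value is, which is exactly the `-1/n` pattern.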

⚠ **REMEMBER**: We are only dealing with `1` log-probability from each of the `32` samples. So **DON'T MAKE** the mistake of assigning this derivative to all the other entries. In this state the other entries don't have ANY impact on the loss calculation, and thus their grad will be `0`.

In [24]:
### We will use the convention `d__` for the variables which are calculated by us.

dlogprobs = torch.zeros_like(logprobs)
dlogprobs.shape # the empty structure, which will then be filled in

torch.Size([32, 27])

In [25]:
# the slope / grad calculation
d = -1/n
d

-0.03125

In [26]:
# assign in the structure!
dlogprobs[range(n), Yb] = d

In [27]:
# Check-1 ✅
cmp('logprobs', dlogprobs, logprobs)

logprobs        | exact: True  | approximate: True  | maxdiff: 0.0


Yo! 🔥

## `4.2` Derivative of `probs` wrt `logprobs`

> 💭 Which is *"how performing `.log()` on `probs` impacted the `logprobs`?"*

Working for this line: 
```python
logprobs = probs.log()
```

In [28]:
# Create a function which resembles the operation
def f(a):
    '''here taking only 1 number because it works elementwise'''
    return np.log(a)

# inputs
a = 1.0
b = 2.0
c = 3.0

old = f(a)
a += h_
new = f(a)
d_a = (new - old) / h_

old = f(b)
b += h_
new = f(b)
d_b = (new - old) / h_

old = f(c)
c += h_
new = f(c)
d_c = (new - old) / h_

(d_a, d_b, d_c)

(0.9999500033329731, 0.49998750041746476, 0.3333277779016264)

> Oops! **This trick** only gives an obvious constant for simple addition, multiplication and subtraction. With `log()`, which is non-linear, the slope depends on the input, so we need something more to understand the numbers.

Searching the internet gives $\frac{1}{x}$ as the derivative of `log(x)`. <br>Thus...

In [29]:
1/a, 1/b, 1/c

(0.9999000099990001, 0.49997500124993743, 0.3333222225925802)

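These match the slopes above closely *(note that `a`, `b` and `c` were already nudged by `h_`, hence `0.9999...` instead of exactly `1.0`)*.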

⚠ **REMEMBER**: Here the ***"chain rule"*** will be applied! Also, *we are safe* to use all elements, because `log` is applied to each of the `32 * 27` elements, unlike the 32 selected elements in the exercise above 😊

In [30]:
#        the 1/x       chain rule (previous grads)
dprobs = (1 / probs) * dlogprobs

In [31]:
# Check-2 ✅
cmp('probs', dprobs, probs)

probs           | exact: True  | approximate: True  | maxdiff: 0.0


Yo! 🔥

### Insight 🧠
Andrej shares a beautiful insight at this point. The `grads` show "how much impact" the current element or node has on the loss *(because of the chain rule)*. Here the calculation is `(1 / probs) * dlogprobs`.

That means that if the correct character gets **a probability of 1.0** or close to it, `(1 / probs)` is about `1` and the **grads simply pass through**. But if the correct character gets a **low** probability, `(1 / probs)` becomes a big number and thus the gradients are boosted!

Thus those values get a larger correction, and the subsequent iterations will take care of them.
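
To put numbers on that boosting effect, here is a small sketch with made-up probabilities for the correct character (the values are assumptions, chosen only for illustration):

```python
n = 32                    # batch size used in this notebook
dlogprob = -1 / n         # gradient arriving at each picked log-prob

grads = {}
for p in [0.9, 0.5, 0.1, 0.01]:        # made-up probabilities of the correct character
    grads[p] = (1 / p) * dlogprob      # local derivative of log, times the incoming grad
    print(f'prob={p:<5} dprob={grads[p]:.4f}')
```

The lower the probability, the larger the magnitude of the gradient.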

## `4.3` Derivative of `counts_sum_inv` wrt `probs`

> 💭 Which is *"what is the effect of `counts` and `counts_sum_inv` on `probs`?"*

Working for this line: 
```python
probs = counts * counts_sum_inv
```

Of course, here we find `probs` by multiplying, **not** by dividing. We need to find the gradients of **both** operands individually, as both take part in the operation.

In [32]:
counts.shape

torch.Size([32, 27])

In [33]:
counts_sum_inv.shape

torch.Size([32, 1])

> 💭 Okay that means: "*for each of the 32 samples, we multiply its 27 counts by that sample's inverse sum to get the probability of each character*."

We have the `*` sign!! 

Yo? 😊  <br>
-- or -- <br>
No Yo? 🙁 

**I think Yo!!**

In [34]:
def f(a, b):
    return a*b

In [35]:
a, b = 2.0, 3.0
old = f(a, b)
a += h_
new = f(a, b)

# slope
(new - old) / h_

3.000000000010772

In [36]:
a, b = 2.0, 3.0
old = f(a, b)
b += h_
new = f(a, b)

(new - old) / h_

2.0000000000042206

`counts` and `counts_sum_inv` are related in the same way in `probs = counts * counts_sum_inv`. Thus...

In [37]:
### the line below uses .sum() function. Please view the "explainer" below
dcounts_sum_inv = (counts * dprobs).sum(1, keepdims=True)

# Check-3 ✅
cmp('count_sum_inv', dcounts_sum_inv, counts_sum_inv)

count_sum_inv   | exact: True  | approximate: True  | maxdiff: 0.0


Yo! 🔥

### 🔎 Explainer of what just happened.
Well, Andrej explains this much more clearly, so it is worth checking out [this clip](https://youtube.com/clip/Ugkx_b4RbW-o9WG7uBN-MsMQK1v0G2pn2gfI).

<img src="./images/broadcast.png">

### 🙄 But, 
when we perform `counts * dprobs` to get the derivative of `counts_sum_inv`, we are basically multiplying:

- counts (32 x 27)
- dprobs (32 x 27)

Thus the result will be `32 x 27` and not `32 x 1`, which is the shape of `counts_sum_inv`.

### 🤔 What to do now?
If you recall... in the `01 - Micrograd/Micrograd Foundations.ipynb` book, we had one section **"🐞 Bug - 001"**. When a single element receives gradients from multiple calculations, we used the `+=` sign to solve that issue.

> ***Here we will do the same***. Just sum things up. And hence you saw "`(counts * dprobs).sum(1, keepdims=True)`". 

Hope that's clear 🤞
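
The same rule on a tiny scale, as a sketch with made-up numbers: `a` plays the role of `counts` and `b` the role of `counts_sum_inv`. Autograd should agree with summing over the broadcast dimension:

```python
import torch

a = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]], requires_grad=True)   # like counts (2 x 3)
b = torch.tensor([[10.0],
                  [20.0]], requires_grad=True)            # like counts_sum_inv (2 x 1)

out = a * b                        # b is broadcast across the 3 columns
upstream = torch.ones_like(out)    # pretend gradient flowing in from above
out.backward(upstream)

# each element of b was reused 3 times, so its 3 gradients are summed
db_manual = (a.detach() * upstream).sum(1, keepdim=True)
print(db_manual)
print(b.grad)
```

A handy rule of thumb: the gradient must have the same shape as the variable. If it doesn't, a broadcast happened in the forward pass and a `.sum()` over that dimension is needed in the backward pass.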

## `4.4` Derivative of `counts` wrt `probs`

Still, working for this line:
```python
probs = counts * counts_sum_inv
```

But now, it is the turn for the `counts` derivatives!

In [38]:
dcounts = counts_sum_inv * dprobs

In [39]:
# Check-4 ❌
cmp('counts', dcounts, counts)

counts          | exact: False | approximate: False | maxdiff: 0.006148857995867729


🤯 **Oops!** <br> Looks like a "small" max diff... but still it is problematic. But what is this?

**Have a look at the `lifeline of counts`...**
```python
counts = norm_logits.exp()

counts_sum = counts.sum(1, keepdims=True)

# USAGE - 1
counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...

# USAGE - 2
probs = counts * counts_sum_inv
```

Unlike `counts_sum_inv`, which is used in a single place (calculating `probs`), `counts` is used in **2** places. Thus, its derivative has `2` branches. <br>
**Which looks like:**

<img src="./images/counts_branch.png">

> 🐸 We can clearly see that `counts` is associated with more than 1 operation, thus its gradient should not be calculated from a single operation alone, like we did above.

### 🤔 What to do then?
Here is the **exact** problem that we referred to in **🐞 Bug - 001**. Here we will "sum" all the effects of `counts`, which come from:
1. `counts_sum`
2. `probs` *(which we just saw above and failed)*

So we will calculate for the `counts_sum` first!

## `4.5` Derivative of `counts_sum` wrt `counts_sum_inv`

> 💭 How is `counts_sum_inv` affected by the `counts_sum` variable?

Working for this line:
```python
counts_sum_inv = counts_sum**-1
```

In [40]:
def f(a):
    return a**-1

a  = 2.0
old = f(a)
a += h_
new = f(a)

# slope
(new - old) / h_

-0.2499875006256591

In [41]:
b  = 10.0
old = f(b)
b += h_
new = f(b)

# slope
(new - old) / h_

-0.009999900000973172

Which translates to $-\frac{1}{x^2}$ <br>Thus...

In [42]:
# This confirms that we are correct
-(1/a**2), -(1/b**2)

(-0.24997500187487495, -0.009999800002999961)

In [43]:
dcounts_sum = -(1 / counts_sum**2) * dcounts_sum_inv

In [44]:
# Check-5 ✅
cmp('counts_sum', dcounts_sum, counts_sum)

counts_sum      | exact: True  | approximate: True  | maxdiff: 0.0


Yo! 🔥

### And now we are ready to find the gradient for `counts`

## `4.4` Derivative of `counts` wrt `probs` (again)

Still, working for this line:
```python
probs = counts * counts_sum_inv
```
But now we have enough data to calculate the gradient for the `counts`.

In [45]:
# Which is simply...

# USAGE - 1 (via counts_sum; the (32, 1) grad is broadcast to every column)
dcounts = 1.0 * dcounts_sum

# USAGE - 2 (NOTE: this is the `+=`, written out explicitly)
dcounts = dcounts + (counts_sum_inv * dprobs)

In [46]:
# Check-4 ✅ (finally)
cmp('counts', dcounts, counts)

counts          | exact: True  | approximate: True  | maxdiff: 0.0


Yo! 🔥
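
The two-branch rule can also be checked on a single made-up row, where the numbers are small enough to verify by hand. This is only a sketch: the values and the names `branch_probs` and `branch_sum` are assumptions for illustration.

```python
import torch

counts = torch.tensor([[1.0, 2.0, 5.0]], requires_grad=True)   # one made-up row
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum ** -1
probs = counts * counts_sum_inv

dprobs = torch.tensor([[1.0, 0.0, 0.0]])   # pretend only the first prob matters
probs.backward(dprobs)

c = counts.detach()
s = c.sum(1, keepdim=True)
branch_probs = (s ** -1) * dprobs                   # direct use in `probs`
dsum_inv = (c * dprobs).sum(1, keepdim=True)
dsum = -(s ** -2) * dsum_inv
branch_sum = torch.ones_like(c) * dsum              # indirect use via `counts_sum`

print(branch_probs)                # incomplete on its own
print(branch_probs + branch_sum)   # matches autograd
print(counts.grad)
```

With a row sum of `8`, the first entry comes out as `1/8 - 1/64 = 7/64`, and the other two entries get `-1/64` purely from the `counts_sum` branch.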

## `4.6` Derivative of `norm_logits` wrt `counts`

> 💭 How `norm_logits` will affect the `counts` variable?

Working for this line:
```python
counts = norm_logits.exp()
```
This again is the elementwise operation...

In [47]:
def f(a):
    return np.exp(a)

a  = 2.0
old = f(a)
a += h_
new = f(a)

# slope
(new - old) / h_

7.389425564063856

In [48]:
b  = 10.0
old = f(b)
b += h_
new = f(b)

# slope
(new - old) / h_

22027.567154727876

Which translates to $e^x$  itself!<br>

> 💭 **That means** the derivative of `exp(x)` is `exp(x)` again, i.e. the local gradient is simply the output (`counts`) itself. 

And also we will need to use the chain rule here. A small code snippet from **micrograd**:
```python
def exp(self):
    x = self.data
    out = Value(np.exp(x), (self, ), 'exp')
        
    def _backward():
        self.grad += out.data * out.grad # chain rule
    out._backward = _backward
    return out
```
Here we used `self.grad += out.data * out.grad` which confirms what I said above.

In [49]:
dnorm_logits = counts * dcounts

# Check-6 ✅
cmp('norm_logits', dnorm_logits, norm_logits)

norm_logits     | exact: True  | approximate: True  | maxdiff: 0.0


Yo! 🔥

## `4.7` Derivative of `logit_maxes` wrt `norm_logits`

> 💭 How `logit_maxes` will affect the `norm_logits` variable?

Working for this line:
```python
norm_logits = logits - logit_maxes # subtract max for numerical stability
```

Here again, we have *2* variables participating in the calculation of the `norm_logits`. Here, **luckily** we are ***experienced*** enough to know beforehand that the second variable `logits` will also be participating in 2 operations like `counts` did. 

But for now, `logit_maxes` will have an easy time.

In [50]:
# `-` has the derivative of -1.0
dlogit_maxes = (-1.0 * dnorm_logits).sum(1, keepdims=True)

# Check-7 ✅
cmp('logit_maxes', dlogit_maxes, logit_maxes)

logit_maxes     | exact: True  | approximate: True  | maxdiff: 0.0


Yo! 🔥

> ❓ Why `.sum`? Please revisit `4.3` 😊

### The `logits`' life
```python
logits = h @ W2 + b2 # output layer

# USAGE - 1
logit_maxes = logits.max(1, keepdim=True).values

# USAGE - 2
norm_logits = logits - logit_maxes # subtract max for numerical stability
```

Which is...

<img src="./images/logits_branch.png">

## `4.8` Derivative of `logits` wrt `logit_maxes`

> 💭 How `logits` will affect the `logit_maxes` variable?

Working for this line:
```python
logit_maxes = logits.max(1, keepdim=True).values
```

Alright, so first of all we will calculate the `logits'` derivative for `logit_maxes` and then we will find another for `norm_logits` and will sum them all up!

### 🧠 An insight
Before continuing, though, we **need to keep in mind** that *in the calculation of `norm_logits` we **only use** the max value of each sample to normalize it*. 

Thus, the gradient of `logits` in this phase **will only be counted** towards those values which were max and were chosen.

### 🙄 But... what will be their gradient?
See, we are **only** picking out elements, and **picking out** is **not** an arithmetic operation, so it doesn't have any *fancy* derivative. Still, since we pick and use them, they get `1` as the local gradient, and those which weren't used get `0`.

In [51]:
# recall these values had the indices
logits.max(1, keepdims=True).indices

tensor([[ 1],
        [ 2],
        [19],
        [15],
        [15],
        [25],
        [16],
        [ 3],
        [19],
        [ 8],
        [15],
        [ 3],
        [22],
        [18],
        [ 8],
        [ 5],
        [ 2],
        [ 1],
        [22],
        [ 6],
        [10],
        [19],
        [22],
        [22],
        [23],
        [ 5],
        [22],
        [20],
        [24],
        [ 8],
        [24],
        [13]])

These will be used to locate the elements in the `logits`. Thus, we will do something like below:

In [52]:
# I am following the old-school way :)
dlogits = torch.zeros_like(logits)
dlogits[range(logits.shape[0]), logits.max(1).indices] = 1

In [53]:
plt.imshow(dlogits);

In [54]:
### 🚨 THIS IS AN IMPORTANT STEP OTHERWISE THERE WILL BE A SLIGHT DIFFERENCE 🚨 ###

# because the gradients have to flow through
dlogits = (dlogits * dlogit_maxes)
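
An alternative to the index assignment is `F.one_hot`, which builds the same 0/1 mask in one call. Here is a sketch on made-up logits, checked against autograd (the values and the name `mask` are assumptions for illustration):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[1.0, 3.0, 2.0],
                       [0.5, 0.1, 0.9]], requires_grad=True)   # made-up logits (2 x 3)
logit_maxes = logits.max(1, keepdim=True).values
dlogit_maxes = torch.tensor([[2.0], [-1.0]])                   # pretend upstream gradient
logit_maxes.backward(dlogit_maxes)

# 1 at the position of each row's max, 0 elsewhere, then scaled by the upstream gradient
mask = F.one_hot(logits.max(1).indices, num_classes=logits.shape[1])
dlogits_manual = mask * dlogit_maxes

print(dlogits_manual)
print(logits.grad)
```

Only the entry that was picked as the max in each row receives the incoming gradient; every other entry gets `0`.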

## `4.9` Derivative of `logits` wrt `norm_logits`

> 💭 How `logits` will affect the `norm_logits` variable?

Again, Working for this line:
```python
norm_logits = logits - logit_maxes
```

Here we will find the **second part** of the logits.

In [55]:
# which is...
dlogits = dlogits + (1.0 * dnorm_logits)

In [56]:
# Check-8 ✅
cmp('logits', dlogits, logits)

Yo! 🔥

## 🤩 The crazy part of the backprop
Now we will be backproping through the `matmul` layer! Holy! 

Let's recap a little bit:
- We have been backing ourselves out from `loss`
- `loss` ← `logprobs` ← `probs` ← `counts` ← `norm_logits` ← `logits`

That's our **story till now**, with a little bit of extra things here and there, but now this is the *boss*! Let's tackle it now. <br>
Have a look at the following code as a refresher:
```python
# Layer 2
W2 = torch.randn((n_hidden, vocab_size),  generator=g) * 0.1
b2 = torch.randn(vocab_size, generator=g) * 0.1

### ...omitted code...

# Non-linearity
h = torch.tanh(hpreact) # hidden layer

### Linear layer 2
logits = h @ W2 + b2 # output layer
```

In [57]:
dlogits.shape

## `4.10` Derivative of `h` wrt `logits`

> 💭 How will `h`, through our matrix multiplication, affect the `logits` variable?

Again, Working for this line:
```python
logits = h @ W2 + b2 # output layer
```

Yo! This is getting crazy at this part.

<img src="./images/h_layer.png">

> See, here `h` is first matmuled with `W2` and *then* `b2` is added to the result.

And inside...

In [58]:
# Creating a simple function to help us out here as well
# taking two 2x2 matrices and doing the matmul manually

def f(a, b):
    """
    Internally the numpy `@` will perform the following:
    d1 = (x1*w1 + x2*w3)
    d2 = (x1*w2 + x2*w4)
    d3 = (x3*w1 + x4*w3)
    d4 = (x3*w2 + x4*w4)
    
    result = np.array([[d1, d2], [d3, d4]])
    """
    result = a @ b
    return result
    

x1, x2, x3, x4  = 1.0, 3.0, 5.0, 7.0
w1, w2, w3, w4  = 7.0, 9.0, 1.0, 2.0

a = np.array([[x1, x2], [x3, x4]])
b = np.array([[w1, w2], [w3, w4]])

old = f(a, b)
a[0, 0] += h_ # i.e. changing x1 slightly
new = f(a, b)

# slope
((new - old) / h_).sum()

In [59]:
a = np.array([[x1, x2], [x3, x4]])
b = np.array([[w1, w2], [w3, w4]])
a @ b

That means `x1` has the local gradient of `7 + 9 = 16` *(as you can see, `x1` is multiplied by `w1` and `w2`, which take the values `7` and `9` respectively)*.

In [60]:
# but suppose we have the gradient of the output of a@b (as we do with logits)
output_grad = np.array([[1., 2.], [3., 4.]])

In [61]:
# then the `a` layer will have the gradient
output_grad @ b.T

As discussed in Andrej's video. But we can verify. 

Say we are finding the same gradient for `x1` as found above, which has the local gradient of 16. **But** because of the chain rule we will need to consider the `d1` and `d2` gradients as well.

In [62]:
# d1 gradient (from `output_grad`)
d1_grad = output_grad[0, 0]
d2_grad = output_grad[0, 1]

# The gradient for `x1` since it is what affects d1 and d2
d1_grad * 7 + d2_grad * 9

🎉 That **confirms!!** the result!

> Here in the example I tried to find the slope of `x1`, which is related with `w1` and `w2`. Thus it gets the slopes `7` and `9` respectively, each weighted by its output gradient and then added up. 

You can change and play around with **any** value you like instead of `x1` and see how it affects.

Since all weights are **related with the `xn` values** through `*` multiplication, we can see how they get their values.
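
If you'd rather not trust the formula on a single element, here is a sketch that checks *every* element of `a` numerically. The helper `scalar_loss` is a made-up stand-in whose gradient with respect to `a @ b` is exactly `output_grad`:

```python
import numpy as np

a = np.array([[1.0, 3.0], [5.0, 7.0]])
b = np.array([[7.0, 9.0], [1.0, 2.0]])
output_grad = np.array([[1.0, 2.0], [3.0, 4.0]])   # pretend gradient of the loss wrt a @ b

def scalar_loss(a, b):
    # stand-in "loss" (hypothetical): its gradient wrt (a @ b) is output_grad
    return ((a @ b) * output_grad).sum()

h_ = 1e-4
da_numeric = np.zeros_like(a)
for i in range(a.shape[0]):
    for j in range(a.shape[1]):
        nudged = a.copy()
        nudged[i, j] += h_                       # nudge one element at a time
        da_numeric[i, j] = (scalar_loss(nudged, b) - scalar_loss(a, b)) / h_

da_formula = output_grad @ b.T
print(da_numeric)
print(da_formula)
```

The two matrices should agree up to floating point noise, and the top-left entry is the same `25` we computed by hand for `x1`.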

### 🤨 So what? Exactly?
I know it sounds a little confusing and isn't properly understood yet. Let's take some additional steps to see what happens behind the scenes.

But, now. **Please pay attention**. A little miss would be dangerous.

<img src="./images/the_h_story.png">

## Holy! 🤯
That was a lot for a small thing, but I had to, to make things **clearer**. Now we can **confidently** use the *hacky* thing here.

> Remember, we still are in the "*`4.10` Derivative of `h` wrt `logits`*" section!

In [63]:
dh = dlogits @ W2.T

# Check-9 ✅
cmp('h', dh, h)

Yo! 🔥

## `4.11` Derivative of `W2` wrt `logits`

> 💭 How will `W2`, through our matrix multiplication, affect the `logits` variable?

Again, Working for this line:
```python
logits = h @ W2 + b2 # output layer
```

Now, it's really easy, isn't it?

In [64]:
dW2 = h.T @ dlogits

# Check-10 ✅
cmp('W2', dW2, W2)

Yo! 🔥

## `4.12` Derivative of `b2` wrt `logits`

> 💭 How will our addition of `b2` affect the `logits` variable?

Again, Working for this line:
```python
logits = h @ W2 + b2 # output layer
```

This is the easy part, we actually could've done this way before `h` and `W2`!

In [65]:
dlogits.shape

In [66]:
b2.shape

In [67]:
# Since `b2` is related with (h @ W2) with the `+` sign, it will simply
# "pass" the grads of `dlogits`.

db2 = dlogits.sum(0) # need to perform the sum because `b2` is broadcast across the batch

# Check-11 ✅
cmp('b2', db2, b2)

Yo! 🔥

## 📜 Now, towards the layer-1

## `4.13` Derivative of `hpreact` wrt `h`

> 💭 How the `tanh(hpreact)` will affect the `h` variable?

Working for this line:
```python
# Non-linearity
h = torch.tanh(hpreact) # hidden layer
```

This is simply backpropagating through the `tanh` activation. Let's refer to micrograd.

```python
def tanh(self):
    x = self.data
    t = (np.exp(2*x) - 1) / (np.exp(2*x) + 1)
    out = Value(t, (self,), "tanh")
    
    def _backward():
        self.grad += (1 - t ** 2) * out.grad # chain rule here too!
    out._backward = _backward
    return out
```

Let's apply this here too.

In [68]:
dhpreact = (1 - h**2) * dh

# Check-12 ✅
cmp('hpreact', dhpreact, hpreact)

Yo! 🔥
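
For the skeptics, the `1 - tanh(x)**2` rule can be confirmed numerically with nothing but the standard library. A small sketch, where the test points are arbitrary and `tanh_grad_numeric` is a hypothetical helper:

```python
import math

def tanh_grad_numeric(x, h=1e-5):
    # central difference estimate of d tanh(x) / dx
    return (math.tanh(x + h) - math.tanh(x - h)) / (2 * h)

checks = []
for x in [-2.0, -0.5, 0.0, 0.5, 2.0]:     # arbitrary test points
    analytic = 1 - math.tanh(x) ** 2
    numeric = tanh_grad_numeric(x)
    checks.append((x, analytic, numeric))
    print(f'x={x:+.1f}  analytic={analytic:.6f}  numeric={numeric:.6f}')
```

Notice how the gradient is largest at `x = 0` and shrinks as `tanh` saturates towards `-1` or `1`.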

👉 Now we will be working with this portion of the forward pass

```python
# 👉 Take the inverse square root of the variance (i.e. 1 / std) so that it can be "multiplied"
bnvar_inv = (bnvar + 1e-5)**-0.5
## Note: 1e-5 is a small constant to avoid division by zero

# 👉 Finally standardize the numbers
bnraw = bndiff * bnvar_inv

# 👉 Give the spice to it by gain and bias
hpreact = bngain * bnraw + bnbias
```

## `4.14` Derivative of `bngain` wrt `hpreact`

> 💭 How the `bngain` will affect the `hpreact` variable?

Working for this line:
```python
hpreact = bngain * bnraw + bnbias
```
Once again **a simple step here**.

In [69]:
dbngain = (bnraw * dhpreact).sum(0, keepdims=True)

# Check-13 ✅
cmp('bngain', dbngain, bngain)

Yo! 🔥

## `4.15` Derivative of `bnraw` wrt `hpreact`

> 💭 How the `bnraw` will affect the `hpreact` variable?

**Again**, Working for this line:
```python
hpreact = bngain * bnraw + bnbias
```
Once, once again **a simple step here**.

In [70]:
dbnraw = (bngain * dhpreact)

# Check-14 ✅
cmp('bnraw', dbnraw, bnraw)

Yo! 🔥

## `4.16` Derivative of `bnbias` wrt `hpreact`

> 💭 How the `bnbias` will affect the `hpreact` variable?

**Again, again**, Working for this line:
```python
hpreact = bngain * bnraw + bnbias
```
Once, once, once again **a simple step here**.

In [71]:
dbnbias = dhpreact.sum(0, keepdims=True)

# Check-15 ✅
cmp('bnbias', dbnbias, bnbias)

Yo! 🔥
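
All three gradients of `hpreact = bngain * bnraw + bnbias` can be checked together on small random tensors. A sketch with made-up shapes (4 samples, 3 neurons) and an arbitrary seed:

```python
import torch

g = torch.Generator().manual_seed(0)                            # arbitrary seed
bnraw = torch.randn(4, 3, generator=g, requires_grad=True)      # stand-in (4 samples, 3 neurons)
bngain = torch.randn(1, 3, generator=g, requires_grad=True)
bnbias = torch.randn(1, 3, generator=g, requires_grad=True)

hpreact = bngain * bnraw + bnbias
dhpreact = torch.randn(4, 3, generator=g)                       # pretend upstream gradient
hpreact.backward(dhpreact)

dbngain = (bnraw.detach() * dhpreact).sum(0, keepdim=True)  # broadcast over rows, so sum over rows
dbnraw = bngain.detach() * dhpreact                         # same shape as bnraw, no sum needed
dbnbias = dhpreact.sum(0, keepdim=True)                     # broadcast over rows, so sum over rows

print(torch.allclose(dbngain, bngain.grad),
      torch.allclose(dbnraw, bnraw.grad),
      torch.allclose(dbnbias, bnbias.grad))
```

The pattern is the same as before: `bngain` and `bnbias` have shape `(1, n_hidden)` and are broadcast over the batch, so their gradients are summed over dimension `0`, while `bnraw` already has the full shape.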

## `4.17` Derivative of `bnvar_inv` wrt `bnraw`

> 💭 How the `bnvar_inv` will affect the `bnraw` variable?

Working for this line:
```python
# 👉 Finally standardize the numbers
bnraw = bndiff * bnvar_inv
```

Here, if you remember, we **are standardizing** the values of the first layer so that they stay "well behaved". On this `bnraw` we then use `bngain` and `bnbias` to scale and shift it.

In [72]:
dbnvar_inv = (bndiff * dbnraw).sum(0, keepdims=True)

# Check-16 ✅
cmp('bnvar_inv', dbnvar_inv, bnvar_inv)

Yo! 🔥
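To see why the `.sum(0, keepdim=True)` is there, it can help to spell the same computation out with explicit loops. This is only a sketch on small random tensors (the `4 x 3` shape is an assumption): every row `i` of column `j` was multiplied by the same `bnvar_inv[0, j]`, so all of those contributions add up.

```python
import torch

torch.manual_seed(0)
n, n_hidden = 4, 3  # tiny stand-ins for 32 x 64
bndiff = torch.randn(n, n_hidden)
dbnraw = torch.randn(n, n_hidden)  # pretend upstream gradient

# Vectorized version, as in the cell above
dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)

# Same thing, one element at a time
dbnvar_inv_loop = torch.zeros(1, n_hidden)
for i in range(n):
    for j in range(n_hidden):
        # bnvar_inv[0, j] was reused by every row i, so its gradient accumulates
        dbnvar_inv_loop[0, j] += bndiff[i, j] * dbnraw[i, j]

print(torch.allclose(dbnvar_inv, dbnvar_inv_loop))
```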

## `4.18` Derivative of `bndiff` wrt `bnraw`

> 💭 How the `bndiff` will affect the `bnraw` variable?

**Again**, Working for this line:
```python
# 👉 Finally standardize the numbers
bnraw = bndiff * bnvar_inv
```

### 🤚 Hold on!
There is some **history** behind `bndiff`.

### `bndiff`'s life cycle

```python
### BatchNorm layer ###

# 👉 Find the mean
bnmeani = 1 / n * hprebn.sum(0, keepdim=True)
# The above is equivalent to: `hprebn.sum(0) / n` ie. simply the mean :)

# 👉 Find the difference (BIRTH) 🔴
bndiff = hprebn - bnmeani

# 👉 Squaring the difference to find variance below (USAGE-1) 🔴
bndiff2 = bndiff**2

# 👉 Finding the variance
bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)
# The above is equivalent to: `bndiff2.sum(0, keepdim=True) / (n-1)` ie. simply the variance :)

# 👉 Inverse the variance so that it can be "multiplied"
bnvar_inv = (bnvar + 1e-5)**-0.5
## Note: 1e-5 is a small constant to avoid division by zero

# 👉 Finally standardize the numbers  (USAGE-2) 🔴
bnraw = bndiff * bnvar_inv 

# 👉 Give the spice to it by gain and bias
hpreact = bngain * bnraw + bnbias
```

<img src="./images/bndiff_life.png">

### A hell lot of chains! ⛓
Let's start with finding the gradients for `bnvar` *(because we are already done with the `bnvar_inv`)*.

## `4.19` Derivative of `bnvar` wrt `bnvar_inv`

> 💭 How performing `+ 1e-5` and `**-0.5` on `bnvar` affects the `bnvar_inv`?

Working for this line:
```python
# 👉 Inverse the variance so that it can be "multiplied"
bnvar_inv = (bnvar + 1e-5)**-0.5
```

To help you recall this line a little, we are basically **taking the inverse square root of the variance**, the $\frac{1}{\sqrt{\sigma^2 + \epsilon}}$ factor of $\frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}$, so that it can be multiplied. 

Thus, we are in the standardization step.

> 📝 We are **also** adding a small number `0.00001` (that is `1e-5`) to the variance aka `bnvar` so that we don't get a division by zero. *Thus, we need to take that into account as well.*

In [73]:
# A sample function to guide us through
def f(a):
    return (a + 1e-5) **-0.5

In [74]:
a = 3.0

old = f(a)
a += h_
new = f(a)

(new - old) / h_

In [75]:
a = 5.0

old = f(a)
a += h_
new = f(a)

(new - old) / h_

**Looking up the derivative** on the internet would give us:

$-\dfrac{1}{2\left(x+\frac{1}{100000}\right)^\frac{3}{2}}$
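For what it's worth, this is just the power rule $\frac{d}{dx}u^p = p\,u^{p-1}\,\frac{du}{dx}$ with $u = x + \epsilon$, $p = -\frac{1}{2}$ and $\epsilon = 10^{-5}$ (the symbol $\epsilon$ is my shorthand for the `1e-5` constant):

$$\frac{d}{dx}\,(x+\epsilon)^{-\frac{1}{2}} = -\frac{1}{2}\,(x+\epsilon)^{-\frac{3}{2}} \cdot 1 = -\dfrac{1}{2\left(x+\epsilon\right)^{\frac{3}{2}}}$$

So the fraction form and the "negative power" form are the same expression written in two ways.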

In [76]:
# Checks

### when a=3.0
print(
    -(1 / (2 * ((3.0 + 1e-5)**1.5)))
)


### when a=5.0
print(
    -(1 / (2 * ((5.0 + 1e-5)**1.5)))
)

The finite-difference slopes are only **approximate**, but they agree with the formula. So let's use it.

#### 1️⃣ The way we found
$-\dfrac{1}{2\left(x+\frac{1}{100000}\right)^\frac{3}{2}}$

In [77]:
dbnvar = -(1 / (2 * ((bnvar + 1e-5)**1.5))) * dbnvar_inv

# Check-17 ✅
cmp('bnvar', dbnvar, bnvar)

👉 As said, the check comes out only **approximate** for this form, and this is simply how I found the formula online. Let's check *what Andrej used* in the lecture.

#### 2️⃣ Andrej's way

In [78]:
dbnvar = -0.5*(bnvar + 1e-5)**-1.5 * dbnvar_inv

# Check-17 ✅
cmp('bnvar', dbnvar, bnvar)

👉 **Whoa!!** It works! Exact.

> 😊 We will keep *Andrej's way* because it gives `exact = True`. Mathematically both expressions are the same derivative; the difference is only floating-point rounding caused by the order of operations. We could take *our way* too, but the subsequent checks would inherit that rounding and report "approximate" results every time.

Yo! 🔥
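Here is a small plain-Python sketch comparing the two ways of writing the derivative. The function names and the sample values are made up for this illustration. Any gap between the two should be at the level of floating-point rounding, not a real mathematical difference.

```python
import math

EPS = 1e-5  # same small constant as in the forward pass

def grad_fraction_form(v):
    """The form found online: -1 / (2 * (v + eps)^(3/2))."""
    return -(1 / (2 * ((v + EPS) ** 1.5)))

def grad_power_form(v):
    """The form used in the lecture: -0.5 * (v + eps)^(-3/2)."""
    return -0.5 * (v + EPS) ** -1.5

for v in (0.5, 1.0, 3.0, 5.0):
    a, b = grad_fraction_form(v), grad_power_form(v)
    print(f"v={v}  fraction={a!r}  power={b!r}  |diff|={abs(a - b):.2e}")
```

In the notebook the tensors are `float32`, where this rounding is much coarser than in Python's 64-bit floats, which is probably why `cmp` can report "approximate" for one form and "exact" for the other.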

## `4.20` Derivative of `bndiff2` wrt `bnvar`

> 💭 How the calculation of the variance done via `bndiff2` will affect the `bnvar`?

Working for this line:
```python
# 👉 Finding the variance
bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)
# The above is equivalent to: `bndiff2.sum(0, keepdim=True) / (n-1)` ie. simply the variance :)
```

Now here, multiple things are going on together, but if we look at the graph once again, we can see that everything is still differentiable. 

Let's see how.

👀 Notice that the `.sum()` is simply `a + b + c + d`, which is then multiplied by the constant `1/(n - 1)`. The same thing we have been doing till now applies here too.

In [79]:
# the function at the rescue!
def f(a, b, c, d):
    sum_ = a + b + c + d
    constant = 1 / (4 - 1) # `4` is in place of `n`
    return constant * sum_

a, b, c, d = 1., 3., 5., 2.

old = f(a, b, c, d)
a += h_
new = f(a, b, c, d)

(new - old) / h_

This gives `1/3` as the slope. And having, say, `6` variables would give us `1/5` as the slope. Let's try.

In [80]:
# with 6 variables
def f(a, b, c, d, e, f_):
    sum_ = a + b + c + d  + e + f_
    constant = 1 / (6 - 1) # `6` is in place of `n`
    return constant * sum_

a, b, c, d, e, f_ = 1., 3., 5., 2., 3., 5.

old = f(a, b, c, d, e, f_)
a += h_
new = f(a, b, c, d, e, f_)

(new - old) / h_

Right? It is `1/5`.

> ⚠ Watch out for the shape! Here we are dealing with a tensor that was `32 x 64` before the operation and became `1 x 64` after it. <br> <br> In the reverse situation (a broadcast in the forward pass) we **sum** while backpropagating *(as we have done many times above)*. Here we do the opposite: we have to "cast" the gradient from `1 x 64` back to `32 x 64`, and an easy way to do this is multiplying by `torch.ones_like()`, as shown in the lecture 😉

In [81]:
propogated_shape = torch.ones_like(bndiff2)
dbndiff2 = (1/(n-1)) * propogated_shape * dbnvar

In [82]:
dbndiff2.shape, bndiff2.shape

In [83]:
# Check-18 ✅
cmp('bndiff2', dbndiff2, bndiff2)

Yo! 🔥
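A hedged way to double-check the "sum in the forward pass means broadcast in the backward pass" idea is to let autograd do the same step on a tiny example. The `4 x 3` shape and the random upstream gradient `dbnvar` are assumptions for illustration only.

```python
import torch

torch.manual_seed(0)
n, n_hidden = 4, 3  # tiny stand-ins for 32 x 64

bndiff2 = torch.rand(n, n_hidden, requires_grad=True)
bnvar = 1 / (n - 1) * bndiff2.sum(0, keepdim=True)  # shape (1, n_hidden)

dbnvar = torch.randn(1, n_hidden)  # pretend upstream gradient
bnvar.backward(dbnvar)

with torch.no_grad():
    # every row contributed to the sum with slope 1/(n-1),
    # so the (1, n_hidden) gradient is copied to all n rows
    dbndiff2 = (1 / (n - 1)) * torch.ones_like(bndiff2) * dbnvar

print(dbndiff2.shape, bndiff2.shape)
print(torch.allclose(dbndiff2, bndiff2.grad))
```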

## `4.21` Derivative of `bndiff` wrt `bndiff2` (usage-1)

> 💭 How the `bndiff` will affect the `bndiff2` variable?

Working for this line:
```python
# 👉 Squaring the difference to find variance below (USAGE-1) 🔴
bndiff2 = bndiff**2
```

This is simple.

In [84]:
# The part one calculation
dbndiff = (2 * bndiff) * dbndiff2

## `4.18` Derivative of `bndiff` wrt `bnraw` (again & finally | Usage - 2)

> 💭 How the `bndiff` will affect the `bnraw` variable?

Working for this line:
```python
# 👉 Finally standardize the numbers  (USAGE-2) 🔴
bnraw = bndiff * bnvar_inv
```

In [85]:
# The second half :)
dbndiff += bnvar_inv * dbnraw

In [86]:
# Check-19  ✅
cmp('bndiff', dbndiff, bndiff)

Yo! 🔥
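The `+=` matters because `bndiff` is used in two places, so its gradient is the sum of both branches. Here is a minimal scalar sketch of that rule in plain Python; the toy function `forward` and the constants `k` and `m` are invented for this illustration and are not from the notebook.

```python
def forward(x, k=3.0, m=2.0):
    a = x ** 2   # first use of x (like bndiff2 = bndiff**2)
    b = x * m    # second use of x (like bnraw = bndiff * bnvar_inv)
    return k * a + b

x, k, m = 1.5, 3.0, 2.0

grad_via_a = k * (2 * x)        # gradient reaching x through the squaring branch
grad_via_b = m                  # gradient reaching x through the multiply branch
total = grad_via_a + grad_via_b # the two branches simply add up

eps = 1e-6
numeric = (forward(x + eps) - forward(x - eps)) / (2 * eps)

print(total, numeric)
```

Using only one of the two branches would give the wrong slope, which is exactly what would happen if the second cell used `=` instead of `+=`.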

## `4.22` Derivative of `bnmeani` wrt `bndiff`

> 💭 How the `bnmeani` will affect the `bndiff` variable?

Working for this line:
```python
# 👉 Find the difference
bndiff = hprebn - bnmeani
```

Well, that is simple.

In [87]:
dbnmeani = (-1.0 * dbndiff).sum(0, keepdims=True)

# Check-20 ✅
cmp('bnmeani', dbnmeani, bnmeani)

Yo! 🔥

## `4.23` Derivative of `hprebn` wrt `bndiff`

> 💭 How the `hprebn` will affect the `bndiff` variable?

**Again**, Working for this line:
```python
# 👉 Find the difference
bndiff = hprebn - bnmeani
```

▶ Well, it has a little history. Let's see its lifecycle.

```python
# Linear layer 1 (BIRTH) 🔴
hprebn = embcat @ W1 + b1 # hidden layer pre-activation


# 👉 Find the mean (USAGE - 1) 🔴
bnmeani = 1 / n * hprebn.sum(0, keepdim=True)
# The above is equivalent to: `hprebn.sum(0) / n` ie. simply the mean :)

# 👉 Find the difference (USAGE - 2) 🔴
bndiff = hprebn - bnmeani
```

<img src="./images/hpredn_life.png">

That's why, first of all, we need to find the other derivatives along the way, which is not crazy.

## `4.24` Derivative of `hprebn.sum()` wrt `bnmeani`

> 💭 How the `hprebn` will affect the `bnmeani` variable?

Working for this line:
```python
# 👉 Find the mean (USAGE - 1) 🔴
bnmeani = 1 / n * hprebn.sum(0, keepdim=True)
# The above is equivalent to: `hprebn.sum(0) / n` ie. simply the mean :)
```

Here the `hprebn` is summed up and then multiplied with a constant `1/n`.

> 💭 That means `1/n` is its slope. After enough experience with the previous exercises, I think we don't need to use our `f` function 😉.

In [88]:
# First half of `dhprebn`
dhprebn = torch.ones_like(hprebn) * (1/n) * dbnmeani
# again, did `ones_like` to make the shape proper

## `4.23` Derivative of `hprebn` wrt `bndiff` (continued, again)

> 💭 How the `hprebn` will affect the `bndiff` variable?

**Again**, Working for this line:
```python
# 👉 Find the difference (USAGE - 2) 🔴
bndiff = hprebn - bnmeani
```

Which is simple.

In [89]:
# Second half
dhprebn += 1.0 * dbndiff

In [90]:
# Check-21 ✅
cmp('hprebn', dhprebn, hprebn)

Yo! 🔥
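Here is a small, hedged autograd check of the two halves of `dhprebn`, done in isolation on a tiny example. The `4 x 3` shape and the random `dbndiff` are assumptions; in the real notebook `dbndiff` already carries the variance branch too.

```python
import torch

torch.manual_seed(0)
n, n_hidden = 4, 3  # tiny stand-ins for 32 x 64

hprebn = torch.randn(n, n_hidden, requires_grad=True)
bnmeani = 1 / n * hprebn.sum(0, keepdim=True)
bndiff = hprebn - bnmeani

dbndiff = torch.randn(n, n_hidden)  # pretend upstream gradient
bndiff.backward(dbndiff)

with torch.no_grad():
    dbnmeani = (-1.0 * dbndiff).sum(0, keepdim=True)
    dhprebn = torch.ones_like(hprebn) * (1 / n) * dbnmeani  # first half (through the mean)
    dhprebn += 1.0 * dbndiff                                # second half (direct path)

    # In this isolated snippet the two halves collapse to "subtract the column mean"
    shortcut = dbndiff - dbndiff.mean(0, keepdim=True)

print(torch.allclose(dhprebn, hprebn.grad))
print(torch.allclose(dhprebn, shortcut))
```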

## `4.25` Derivative of `embcat` wrt `hprebn`

> 💭 How the embeddings `embcat` will affect the `hprebn` variable?

Working for this line:
```python
# Linear layer 1
hprebn = embcat @ W1 + b1 # hidden layer pre-activation
```

This is again the simple case which we **"demystified"** in the `4.10` section, do visit it once again 🤗

In [91]:
dembcat = dhprebn @ W1.T

# Check-22 ✅
cmp('embcat', dembcat, embcat)

Yo! 🔥

## `4.26` Derivative of `W1` wrt `hprebn`

> 💭 How the weights `W1` will affect the `hprebn` variable?

**Again**, Working for this line:
```python
# Linear layer 1
hprebn = embcat @ W1 + b1 # hidden layer pre-activation
```

In [92]:
dW1 = embcat.T @ dhprebn

# Check-23 ✅
cmp('W1', dW1, W1)

Yo! 🔥

## `4.27` Derivative of `b1` wrt `hprebn`

> 💭 How the bias `b1` will affect the `hprebn` variable?

**Again**, Working for this line:
```python
# Linear layer 1
hprebn = embcat @ W1 + b1 # hidden layer pre-activation
```

In [93]:
db1 = (1.0 * dhprebn).sum(0, keepdims=True)

# Check-24 ✅
cmp('b1', db1, b1)

Yo! 🔥
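A handy trick for the linear layer is that each gradient must have the same shape as the thing it belongs to, which leaves only one sensible way to arrange the matrix multiplies. Below is a sketch that checks all three against autograd; the sizes are made up, and `b1` is assumed to be a 1-D tensor here, so the bias gradient is summed without `keepdim`.

```python
import torch

torch.manual_seed(0)
n, n_in, n_hidden = 4, 6, 3  # tiny stand-ins for the real sizes

embcat = torch.randn(n, n_in, requires_grad=True)
W1 = torch.randn(n_in, n_hidden, requires_grad=True)
b1 = torch.randn(n_hidden, requires_grad=True)  # assumed 1-D for this sketch

hprebn = embcat @ W1 + b1
dhprebn = torch.randn(n, n_hidden)  # pretend upstream gradient
hprebn.backward(dhprebn)

with torch.no_grad():
    dembcat = dhprebn @ W1.T   # (n, n_hidden) @ (n_hidden, n_in) -> (n, n_in)
    dW1 = embcat.T @ dhprebn   # (n_in, n) @ (n, n_hidden)        -> (n_in, n_hidden)
    db1 = dhprebn.sum(0)       # b1 was broadcast over the batch  -> (n_hidden,)

for name, manual, param in [("embcat", dembcat, embcat), ("W1", dW1, W1), ("b1", db1, b1)]:
    print(name, tuple(manual.shape), torch.allclose(manual, param.grad, atol=1e-6))
```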

And finally...

## `4.28` Derivative of `emb` wrt `embcat`

> 💭 How changing the shape of `emb` will affect the `embcat`?

Working for these lines:
```python
# forward pass, "chunked" into smaller steps that are possible to backward one at a time

emb = C[Xb] # embed the characters into vectors
embcat = emb.view(emb.shape[0], -1) # concatenate the vectors
```

In [94]:
demb = dembcat.view(emb.shape)

# Check-25 ✅
cmp('emb', demb, emb)

Yo! 🔥
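Since `.view()` only re-arranges how the same numbers are indexed, its backward pass is just the opposite re-arrangement. A tiny sketch with assumed shapes `(4, 3, 2)`:

```python
import torch

torch.manual_seed(0)
emb = torch.randn(4, 3, 2, requires_grad=True)  # (batch, block_size, n_embd), made-up sizes
embcat = emb.view(emb.shape[0], -1)              # (4, 6)

dembcat = torch.randn(4, 6)  # pretend upstream gradient
embcat.backward(dembcat)

# No arithmetic at all: just give the gradient the original shape back
demb = dembcat.view(emb.shape)

print(demb.shape)
print(torch.equal(demb, emb.grad))
```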

## `4.29` Derivative of `C` wrt `emb`

> 💭 How selecting rows of `C` for the 32 examples in `Xb` affects the `emb`?

**Still**, Working for these lines:
```python
# forward pass, "chunked" into smaller steps that are possible to backward one at a time

emb = C[Xb] # embed the characters into vectors
embcat = emb.view(emb.shape[0], -1) # concatenate the vectors
```

<img src="./images/final_flow.png">

In [95]:
# It was simpler than I expected

## Create a grad matrix 
dC = torch.zeros_like(C)

In [96]:
# For each sample that passed in...
for th, sample in enumerate(Xb):
    for idx, loc in enumerate(sample):
        dC[loc] += demb[th, idx]

## 🙄 Explainer
```python
👉 for th, sample in enumerate(Xb):
```

    It will take:
    Xb = [[1, 24, 5],
          [6, 4, 3],
          [2, 24, 5],
          [1, 4, 7],
          [5, 6, 5]]

    [1, 24, 5] in the first iteration.
    
```python
👉 for idx, loc in enumerate(sample):
```

    It will take from: [1, 24, 5]
    One-by-one so that it can be used in `C`
    
Then it will simply work. Have a look at the code, **it will make sense**.

In [97]:
# Check-26 ✅
cmp('C', dC, C)

Yo! 🔥
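The double loop is easy to read, but the same scatter-add can probably be written without Python loops using `index_add_`. This is a sketch on made-up small tensors, not the notebook's data, and it is an alternative I am suggesting rather than what the lecture does.

```python
import torch

torch.manual_seed(0)
vocab_size, n_embd, block_size, n = 27, 2, 3, 5  # small made-up sizes

C = torch.randn(vocab_size, n_embd)
Xb = torch.randint(0, vocab_size, (n, block_size))
demb = torch.randn(n, block_size, n_embd)  # pretend upstream gradient

# Loop version, same as in the cell above
dC_loop = torch.zeros_like(C)
for th, sample in enumerate(Xb):
    for idx, loc in enumerate(sample):
        dC_loop[loc] += demb[th, idx]

# Vectorized version: flatten the (n, block_size) positions and
# add each gradient row into the row of dC that was looked up
dC_vec = torch.zeros_like(C)
dC_vec.index_add_(0, Xb.view(-1), demb.view(-1, n_embd))

print(torch.allclose(dC_loop, dC_vec, atol=1e-6))
```

Rows of `C` that were picked several times in the batch receive several contributions, and rows that were never picked keep a zero gradient.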

# 🔥🔥🔥 We have done this!!
**Exercise - 1** completed successfully!! 

# Yo! That was a LOT!
If you're not feeling something, you **should**! The job was incredible dude!

Now a simple note! *We are not going to solve the next exercises*. 😢 <br>
The simple reason is that, I am **not a calculus guy** to be frank and I think whatever we have done above is more than enough for us to understand what is happening under the hood.

The main points the exercise #3 and #4 cover are:
- Of course, not every atomic piece is going to be a separate step inside **loss.backward()**, because in practice that would take a **lot** of time to compute.
- There has to be a simplified expression which can give the same result.
- That will increase the computation speed and is practical.
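To give a taste of what such a simplified expression can look like, here is a hedged sketch for the cross-entropy loss. I picked this example myself; this notebook does not derive it. Instead of backpropagating through every atomic step, the gradient of the mean cross-entropy with respect to the logits can be written as "softmax, minus 1 at the correct class, divided by the batch size". The sizes and random tensors are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, vocab_size = 4, 27  # made-up batch size, 27 characters as in this notebook

logits = torch.randn(n, vocab_size, requires_grad=True)
Yb = torch.randint(0, vocab_size, (n,))

loss = F.cross_entropy(logits, Yb)  # mean over the batch
loss.backward()

with torch.no_grad():
    dlogits = F.softmax(logits, dim=1)    # predicted probabilities
    dlogits[torch.arange(n), Yb] -= 1     # subtract 1 at the correct class
    dlogits /= n                          # because the loss is a mean over n examples

print(torch.allclose(dlogits, logits.grad, atol=1e-6))
```

Three lines replace a whole chain of atomic backward steps, which is the kind of speed-up the points above are talking about.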

So for now, I am ending this notebook here and will see you at another exciting notebook with **Wavenet**!

___

PS: And yes we can still be proud and can say we have achieved `#loss.backward()` confidence! 😎

<img src="./images/dog_sunglasses.png">