In [1]:
import math
import random
import time
from typing import List, Literal

import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import torch
from torch import Tensor
import torch.nn.functional as F

%matplotlib inline

In [2]:
# Prefer the GPU when one is available; every tensor below is created on `device`.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print('Using CUDA' if use_cuda else 'Using CPU')

Using CUDA


In [3]:
RANDOM_SEED = 42
TORCH_GENERATOR_SEED = 2147483647

# `random` drives the in-place shuffle of the dataset before splitting.
random.seed(RANDOM_SEED)
# Also seed torch's global RNG: any torch random call that does not pass
# `generator=g` would otherwise produce different values on every run.
torch.manual_seed(TORCH_GENERATOR_SEED)
# Dedicated generator, passed explicitly to parameter init and minibatch sampling.
g = torch.Generator(device=device).manual_seed(TORCH_GENERATOR_SEED)

In [4]:
# Load the dataset: one name per line. The `with` block closes the file handle.
with open('names.txt', 'r', encoding='utf-8') as f:
    words = f.read().splitlines()
words[:8]

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [5]:
# Number of names in the dataset.
len(words)

32033

In [6]:
# Character vocabulary: every distinct character gets an index 1..N in sorted
# order, and '.' gets index 0 (used as padding and as the end-of-name marker).
chars = sorted(set(''.join(words)))
stoi = {ch: idx for idx, ch in enumerate(chars, start=1)}
stoi['.'] = 0
itos = {idx: ch for ch, idx in stoi.items()}
print(itos)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


In [7]:
# One class per character, including the '.' delimiter at index 0.
vocab_size = len(itos)
vocab_size

27

In [8]:
def build_dataset(words, block_size, device, char_to_ix=None):
    """Build (context, next-character) training pairs for a character-level model.

    Each word is left-padded with ``block_size`` copies of index 0 and terminated
    with '.', so every character (and the terminator) becomes one target whose
    input is the indices of the ``block_size`` characters before it.

    Args:
        words: iterable of strings.
        block_size: number of previous characters used as context.
        device: torch device on which the tensors are created.
        char_to_ix: optional character -> index mapping. Defaults to the
            notebook-level ``stoi``. Index 0 is used for padding, so the mapping
            should send '.' to 0.

    Returns:
        X: int64 tensor of shape (num_examples, block_size) with context indices.
        Y: int64 tensor of shape (num_examples,) with target indices.
    """
    mapping = stoi if char_to_ix is None else char_to_ix
    X_data, Y_data = [], []
    for word in words:
        context = [0] * block_size
        for ch in word + '.':
            ix = mapping[ch]

            X_data.append(context)
            Y_data.append(ix)

            # Slide the window: drop the oldest index and append the new one.
            context = context[1:] + [ix]

    # Explicit dtype and shape keep the result well-formed even with no examples
    # (torch.tensor([]) would otherwise be float32 with shape (0,)).
    X = torch.tensor(X_data, dtype=torch.long, device=device).reshape(len(X_data), block_size)
    Y = torch.tensor(Y_data, dtype=torch.long, device=device)
    return X, Y

In [9]:
# Context length: number of previous characters used to predict the next one.
block_size = 3

# Shuffle the names in place (seeded through `random`), then split them by index
# into 80% train / 10% dev / 10% test.
random.shuffle(words)
n1, n2 = (int(frac * len(words)) for frac in (0.8, 0.9))

X_tr, Y_tr = build_dataset(words[:n1], block_size, device)
X_dev, Y_dev = build_dataset(words[n1:n2], block_size, device)
X_te, Y_te = build_dataset(words[n2:], block_size, device)

print(f'{X_tr.shape=}, {Y_tr.shape=}\n{X_dev.shape=}, {Y_dev.shape=}\n{X_te.shape=}, {Y_te.shape=}')

X_tr.shape=torch.Size([182625, 3]), Y_tr.shape=torch.Size([182625])
X_dev.shape=torch.Size([22655, 3]), Y_dev.shape=torch.Size([22655])
X_te.shape=torch.Size([22866, 3]), Y_te.shape=torch.Size([22866])


In [10]:
def cmp(s, dt, t):
    """Compare a manually computed gradient with the one autograd stored on ``t``.

    Prints one line reporting whether ``dt`` equals ``t.grad`` exactly, whether it
    matches within ``torch.allclose``'s default tolerances, and the largest
    absolute element-wise difference.

    Args:
        s: label printed at the start of the line.
        dt: manually computed gradient.
        t: tensor whose ``.grad`` was populated by ``loss.backward()``.

    A shape mismatch is reported rather than hidden. Comparing a ``(1, n)``
    gradient against an ``(n,)`` one would otherwise broadcast and look like a match.
    """
    grad = t.grad
    if grad is None:
        print(f'{s:15s} | no gradient stored on the reference tensor (missing retain_grad()?)')
        return
    if dt.shape != grad.shape:
        print(f'{s:15s} | shape mismatch: {tuple(dt.shape)} vs {tuple(grad.shape)}')
        return
    ex = torch.all(dt == grad).item()
    app = torch.allclose(dt, grad)
    maxdiff = (dt - grad).abs().max().item()
    print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')

In [11]:
n_embd = 10     # dimensionality of each character embedding vector
n_hidden = 64   # number of neurons in the hidden layer

# Every parameter is drawn from the seeded generator `g` so the run is
# reproducible. W1 is scaled by the tanh gain (5/3) over sqrt(fan_in); the other
# weights and biases are scaled by 0.1 so they start small.
# b1 has no effect on the output, since batch norm subtracts the batch mean right
# after it; presumably it is kept so its gradient can be checked.
C = torch.randn((vocab_size, n_embd),             generator=g, device=device)
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g, device=device) * (5.0/3.0) / (math.sqrt(n_embd * block_size))
b1 = torch.randn(n_hidden,                        generator=g, device=device) * 0.1
W2 = torch.randn((n_hidden, vocab_size),          generator=g, device=device) * 0.1
b2 = torch.randn(vocab_size,                      generator=g, device=device) * 0.1

# Batch-norm gain starts near 1 and bias near 0. These are also drawn from `g`;
# otherwise they would come from torch's global RNG and could differ between runs.
bngain = torch.randn((1, n_hidden), generator=g, device=device) * 0.1 + 1.0
bnbias = torch.randn((1, n_hidden), generator=g, device=device) * 0.1

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
for p in parameters:
    p.requires_grad = True

print(f'Number of parameters: {sum(p.nelement() for p in parameters)}')

Number of parameters: 4137


In [12]:
# Draw one minibatch of training examples, sampling indices with the seeded generator.
batch_size = 32
ix = torch.randint(0, len(X_tr), (batch_size,), generator=g, device=device)
Xb = X_tr[ix]
Yb = Y_tr[ix]

In [13]:
# Forward pass, deliberately broken into small steps so that every intermediate
# tensor can be backpropagated through by hand in the cells below.

# Embedding lookup: (batch, block_size) indices -> (batch, block_size, n_embd),
# then flatten the context embeddings into one vector per example.
emb = C[Xb]
embcat = emb.view(emb.shape[0], -1)

# Linear layer 1 (before batch norm).
hprebn = embcat @ W1 + b1

# Batch normalization over the batch dimension (dim 0).
bnmeani = 1.0 / batch_size * hprebn.sum(0, keepdim=True)
bndiff = hprebn - bnmeani
bndiff2 = bndiff**2
# Bessel-corrected variance: divide by (n - 1) rather than n.
bnvar = 1 / (batch_size-1) * bndiff2.sum(0, keepdim=True)
# 1e-5 keeps the inverse square root finite when the variance is ~0.
bnvar_inv = (bnvar + 1e-5)**-0.5
bnraw = bndiff * bnvar_inv

# Learnable scale and shift, then the tanh non-linearity.
hpreact = bngain * bnraw + bnbias
h = torch.tanh(hpreact)

# Linear layer 2 -> one logit per vocabulary entry.
logits = h @ W2 + b2

# Cross-entropy loss, written out step by step.
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes               # Subtract by max for numerical safety

counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum**-1
probs = counts * counts_sum_inv                 # Multiplying by inverse instead of dividing
logprobs = probs.log()
# Mean negative log-probability of each example's target character.
loss = -logprobs[range(batch_size), Yb].mean()

# Clear gradients from any previous backward pass.
for p in parameters:
    p.grad = None

# Keep .grad on the intermediate (non-leaf) tensors so cmp() can check them.
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, norm_logits, logit_maxes, logits, h, hpreact, bnraw, bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani, embcat, emb]:
    t.retain_grad()
loss.backward()
loss

tensor(3.3605, device='cuda:0', grad_fn=<NegBackward0>)

In [14]:
# Pairwise indexing selects one log-probability per example.
logprobs[range(batch_size),Yb].shape

torch.Size([32])

### Derivative of `loss` with respect to `logprobs`.

The `logprobs` tensor has dimension $32 \times 27$.

In [15]:
# One row per example in the batch, one column per vocabulary entry.
logprobs.shape

torch.Size([32, 27])

Let $m = 32$ be the batch size, $n = 27$ be the vocabulary size, and $\mathbf{LP}$ be the `logprobs` tensor.

$$
\mathbf{LP} = \begin{bmatrix}
lp_{1 \, 1} & lp_{1 \, 2} & \dots & lp_{1 \, n}\\
lp_{2 \, 1} & lp_{2 \, 2} & \dots & lp_{2 \, n}\\
\vdots   & \vdots   & \ddots & \vdots\\
lp_{m \, 1} & lp_{m \, 2} & \dots & lp_{m \, n}\\
\end{bmatrix}
$$

Let $\mathscr{L}$ be the loss. We want to find $\frac{\partial \mathscr{L}}{\partial \mathbf{LP}}$.

We calculate loss as

```python
loss = -logprobs[range(batch_size), Yb].mean()
```

So let's first find an expression for `logprobs[range(batch_size), Yb]`; the loss is the negative mean of this tensor. We call this tensor $\mathbf{LP}_b$. Indexing the rows with `range(batch_size)` selects every row $0, 1, \dots, m-1$ (rows $1, \dots, m$ in the 1-based notation of the equations). For each row, `Yb` selects one column: the index of that example's target character. This means the size of `Yb` must equal the batch size, and indeed it does.

In [16]:
# Row indices and column indices must have matching lengths for pairwise indexing.
print(f'{len(range(batch_size))=}, {Yb.shape=}')

len(range(batch_size))=32, Yb.shape=torch.Size([32])


The vector $\mathbf{LP}_b$ can be expressed as

$$
\mathbf{LP}_b = \begin{bmatrix}
    \mathbf{LP}_{1 \; {\mathbf{y}_b}_1}\\
    \mathbf{LP}_{2 \; {\mathbf{y}_b}_2}\\
    \vdots\\
    \mathbf{LP}_{m \; {\mathbf{y}_b}_m}\\
\end{bmatrix}
$$

And the loss is the negative mean of the elements in this vector.

$$
\mathscr{L} = - \frac{\mathbf{LP}_{1 \; {\mathbf{y}_b}_1} + \mathbf{LP}_{2 \; {\mathbf{y}_b}_2} + \dots + \mathbf{LP}_{m \; {\mathbf{y}_b}_m}}{m}
$$

For the elements of $\mathbf{LP}$ that are part of $\mathbf{LP}_b$ (that is, $j = {\mathbf{y}_b}_i$), the derivative $\frac{\partial \mathscr{L}}{\partial \mathbf{LP}_{i \, j}}$ is

$$
\begin{align*}
    \frac{\partial \mathscr{L}}{\partial \mathbf{LP}_{i \, j}} &= \frac{\partial}{\partial \mathbf{LP}_{i \, j}} \left( - \frac{\mathbf{LP}_{1 \; {\mathbf{y}_b}_1} + \mathbf{LP}_{2 \; {\mathbf{y}_b}_2} + \dots + \mathbf{LP}_{m \; {\mathbf{y}_b}_m}}{m} \right)\\
    &= 0 + 0 + \dots + \frac{\partial}{\partial \mathbf{LP}_{i \, j}} \left( - \frac{\mathbf{LP}_{i \, j}}{m} \right) + \dots + 0 + 0\\
    &= -\frac{1}{m}
\end{align*}
$$

For the elements of $\mathbf{LP}$ that are *not* part of $\mathbf{LP}_b$, the derivative $\frac{\partial \mathscr{L}}{\partial \mathbf{LP}_{i \, j}}$ is $0$, because they do not appear in the loss at all.

In [17]:
# Only the target entries logprobs[i, Yb[i]] appear in the loss, each with weight
# -1/batch_size. Scatter that value into one column per row; every other entry
# keeps a zero gradient.
dlogprobs = torch.zeros_like(logprobs).scatter_(1, Yb.unsqueeze(1), -1 / batch_size)
cmp('logprobs', dlogprobs, logprobs)

logprobs        | exact: True  | approximate: True  | maxdiff: 0.0


### Derivative of `loss` with respect to `probs`.

We want to find $\frac{\partial \mathscr{L}}{\partial \mathbf{P}}$, where $\mathbf{P}$ is the tensor `probs`.

Note that `logprobs` is defined as

```python
logprobs = probs.log()
```

So, instead of calculating $\frac{\partial \mathscr{L}}{\partial \mathbf{P}}$ directly, we can use the chain rule.

$$
\frac{\partial \mathscr{L}}{\partial \mathbf{P}} = \frac{\partial \mathscr{L}}{\partial \mathbf{LP}} \frac{\partial \mathbf{LP}}{\partial \mathbf{P}}
$$

We already know $\frac{\partial \mathscr{L}}{\partial \mathbf{LP}}$, so we just need to calculate $\frac{\partial \mathbf{LP}}{\partial \mathbf{P}}$.

If $\mathbf{LP}$ is defined as

$$
\mathbf{LP} = \begin{bmatrix}
lp_{1 \, 1} & lp_{1 \, 2} & \dots & lp_{1 \, n}\\
lp_{2 \, 1} & lp_{2 \, 2} & \dots & lp_{2 \, n}\\
\vdots   & \vdots   & \ddots & \vdots\\
lp_{m \, 1} & lp_{m \, 2} & \dots & lp_{m \, n}\\
\end{bmatrix}
$$

and $\mathbf{P}$ is defined as

$$
\mathbf{P} = \begin{bmatrix}
p_{1 \, 1} & p_{1 \, 2} & \dots & p_{1 \, n}\\
p_{2 \, 1} & p_{2 \, 2} & \dots & p_{2 \, n}\\
\vdots   & \vdots   & \ddots & \vdots\\
p_{m \, 1} & p_{m \, 2} & \dots & p_{m \, n}\\
\end{bmatrix}
$$

then $\mathbf{LP}$ can also be defined as

$$
\mathbf{LP} = \log(\mathbf{P}) = \log\left(\begin{bmatrix}
p_{1 \, 1} & p_{1 \, 2} & \dots & p_{1 \, n}\\
p_{2 \, 1} & p_{2 \, 2} & \dots & p_{2 \, n}\\
\vdots   & \vdots   & \ddots & \vdots\\
p_{m \, 1} & p_{m \, 2} & \dots & p_{m \, n}\\
\end{bmatrix}\right) = \begin{bmatrix}
\log(p_{1 \, 1}) & \log(p_{1 \, 2}) & \dots & \log(p_{1 \, n})\\
\log(p_{2 \, 1}) & \log(p_{2 \, 2}) & \dots & \log(p_{2 \, n})\\
\vdots   & \vdots   & \ddots & \vdots\\
\log(p_{m \, 1}) & \log(p_{m \, 2}) & \dots & \log(p_{m \, n})\\
\end{bmatrix}
$$

So we know that $\mathbf{LP}_{i \, j} = \log(\mathbf{P}_{i \, j})$. This means

$$
\frac{\partial \mathbf{LP}_{i\,j}}{\partial \mathbf{P}_{i\,j}} = \frac{\partial}{\partial \mathbf{P}_{i\,j}} \left( \log(\mathbf{P}_{i \, j}) \right) = \frac{1}{\mathbf{P}_{i \, j}}
$$

$$
\frac{\partial \mathbf{LP}}{\partial \mathbf{P}} = \begin{bmatrix}
\frac{1}{p_{1 \, 1}} & \frac{1}{p_{1 \, 2}} & \dots & \frac{1}{p_{1 \, n}}\\
\frac{1}{p_{2 \, 1}} & \frac{1}{p_{2 \, 2}} & \dots & \frac{1}{p_{2 \, n}}\\
\vdots   & \vdots   & \ddots & \vdots\\
\frac{1}{p_{m \, 1}} & \frac{1}{p_{m \, 2}} & \dots & \frac{1}{p_{m \, n}}\\
\end{bmatrix}
$$

Then we can do element-wise multiplication between $\frac{\partial \mathscr{L}}{\partial \mathbf{LP}}$ and $\frac{\partial \mathbf{LP}}{\partial \mathbf{P}}$ to get $\frac{\partial \mathscr{L}}{\partial \mathbf{P}}$.

In [18]:
# d(log p)/dp = 1/p, applied element-wise to the upstream gradient.
dprobs = dlogprobs * probs.pow(-1)
cmp('probs', dprobs, probs)

probs           | exact: True  | approximate: True  | maxdiff: 0.0


### Derivative of `loss` with respect to `counts_sum_inv`

Note that `probs` is defined as

```python
probs = counts * counts_sum_inv
```

But the sizes of `counts` and `counts_sum_inv` are different.

In [19]:
# counts_sum_inv has a single column, so it is broadcast across the columns of counts.
print(f'{counts.shape=}, {counts_sum_inv.shape=}')

counts.shape=torch.Size([32, 27]), counts_sum_inv.shape=torch.Size([32, 1])


This means PyTorch broadcasts the multiplication like so:

$$
\mathbf{C} \circ \mathbf{CSI} = \begin{bmatrix}
c_{1 \: 1} & c_{1 \: 2} & \dots & c_{1 \: 27}\\
c_{2 \: 1} & c_{2 \: 2} & \dots & c_{2 \: 27}\\
\vdots   & \vdots   & \ddots & \vdots\\
c_{32 \: 1} & c_{32 \: 2} & \dots & c_{32 \: 27}\\
\end{bmatrix} \circ \begin{bmatrix}
csi_{1}\\
csi_{2}\\
\vdots\\
csi_{32}\\
\end{bmatrix} = \begin{bmatrix}
c_{1 \: 1} & c_{1 \: 2} & \dots & c_{1 \: 27}\\
c_{2 \: 1} & c_{2 \: 2} & \dots & c_{2 \: 27}\\
\vdots   & \vdots   & \ddots & \vdots\\
c_{32 \: 1} & c_{32 \: 2} & \dots & c_{32 \: 27}\\
\end{bmatrix} \circ \begin{bmatrix}
csi_{1} & csi_{1} & \dots & csi_{1}\\
csi_{2} & csi_{2} & \dots & csi_{2}\\
\vdots  & \vdots & \ddots & \vdots\\
csi_{32} & csi_{32} & \dots & csi_{32}\\
\end{bmatrix} = \begin{bmatrix}
c_{1 \: 1} \: csi_{1} & c_{1 \: 2} \: csi_{1} & \dots & c_{1 \: 27} \: csi_{1}\\
c_{2 \: 1} \: csi_{2} & c_{2 \: 2} \: csi_{2} & \dots & c_{2 \: 27} \: csi_{2}\\
\vdots  & \vdots & \ddots & \vdots\\
c_{32 \: 1} \: csi_{32} & c_{32 \: 2} \: csi_{32} & \dots & c_{32 \: 27} \: csi_{32}\\
\end{bmatrix}
$$

First let's call this broadcasted tensor $\mathbf{CSI}'$ and calculate its partial derivative.

$$
\mathbf{CSI}' = \begin{bmatrix}
csi_{1} & csi_{1} & \dots & csi_{1}\\
csi_{2} & csi_{2} & \dots & csi_{2}\\
\vdots  & \vdots & \ddots & \vdots\\
csi_{32} & csi_{32} & \dots & csi_{32}\\
\end{bmatrix}
$$

Using the chain rule

$$
\frac{\partial \mathscr{L}}{\partial \mathbf{CSI}'} = \frac{\partial \mathscr{L}}{\partial \mathbf{P}} \frac{\partial \mathbf{P}}{\partial \mathbf{CSI}'}
$$

Since $\mathbf{P} = \mathbf{C} \circ \mathbf{CSI}'$, $\frac{\partial \mathbf{P}}{\partial \mathbf{CSI}'}$ can be expressed as

$$
\frac{\partial \mathbf{P}}{\partial \mathbf{CSI}'} = \frac{\partial}{\partial \mathbf{CSI}'} \left( \mathbf{C} \circ \mathbf{CSI}' \right) = \mathbf{C}
$$

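Each element $csi_i$ of $\mathbf{CSI}$ is used once per column of the broadcast, so its gradient is the sum of all those contributions:

$$
\frac{\partial \mathscr{L}}{\partial csi_i} = \sum_{j} \frac{\partial \mathscr{L}}{\partial \mathbf{P}_{i j}} \, c_{i j}
$$

In code, this is an element-wise product followed by a sum over each row.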

In [20]:
# probs = counts * counts_sum_inv broadcasts counts_sum_inv across columns, so its
# gradient sums the per-column contributions back into a single column.
dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)
cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)

counts_sum_inv  | exact: True  | approximate: True  | maxdiff: 0.0


### Derivative of `loss` with respect to `counts_sum`

Note that `counts_sum_inv` is defined as

```python
counts_sum_inv = counts_sum**-1
```

Denoting `counts_sum_inv` as $\mathbf{CSI}$ and `counts_sum` as $\mathbf{CS}$

$$
\mathbf{CSI} = \mathbf{CS}^{-1}
$$

Using the chain rule

$$
\frac{\partial \mathscr{L}}{\partial \mathbf{CS}} = \frac{\partial \mathscr{L}}{\partial \mathbf{CSI}} \frac{\partial \mathbf{CSI}}{\partial \mathbf{CS}}
$$

All we need to compute is $\frac{\partial \mathbf{CSI}}{\partial \mathbf{CS}}$.

$$
\frac{\partial \mathbf{CSI}}{\partial \mathbf{CS}} = \frac{\partial}{\partial \mathbf{CS}} \left( \mathbf{CS}^{-1} \right) = - \mathbf{CS}^{-2}
$$

In [21]:
# d(x^-1)/dx = -x^-2.
dcounts_sum = dcounts_sum_inv * - counts_sum.pow(-2)
cmp('counts_sum', dcounts_sum, counts_sum)

counts_sum      | exact: True  | approximate: True  | maxdiff: 0.0


### Derivative of `loss` with respect to `counts`

Note that `probs` is defined as

```python
probs = counts * counts_sum_inv
```

and `counts_sum` is defined as

```python
counts_sum = counts.sum(1, keepdim=True)
```

Representing `probs` as $\mathbf{P}$, `counts` as $\mathbf{C}$, `counts_sum` as $\mathbf{CS}$, and `counts_sum_inv` as $\mathbf{CSI}$, we get two equations involving $\mathbf{C}$.

$$
\begin{align*}
    \mathbf{P} &= \mathbf{C} \circ \mathbf{CSI}'\\
    \mathbf{CS}_i &= \sum_{j} \mathbf{C}_{i j}
\end{align*}
$$

To calculate $\frac{\partial \mathscr{L}}{\partial \mathbf{C}}$, we must calculate the gradient for both of the above equations and add their contributions to the gradient of $\mathbf{C}$.

1. Remember, the sizes of `counts` and `counts_sum_inv` are different, causing `counts_sum_inv` to be broadcasted.

    Using the chain rule

    $$
    \frac{\partial \mathscr{L}}{\partial \mathbf{C}} = \frac{\partial \mathscr{L}}{\partial \mathbf{P}} \frac{\partial \mathbf{P}}{\partial \mathbf{C}}
    $$

    All we need to compute is $\frac{\partial \mathbf{P}}{\partial \mathbf{C}}$, denoting the broadcasted tensor of $\mathbf{CSI}$ as $\mathbf{CSI}'$.

    $$
    \frac{\partial \mathbf{P}}{\partial \mathbf{C}} = \frac{\partial}{\partial \mathbf{C}} \left( \mathbf{C} \circ \mathbf{CSI}' \right) = \mathbf{CSI}'
    $$

2. Defining $\mathbf{C}$ as

    $$
    \mathbf{C} = \begin{bmatrix}
    c_{1 \: 1} & c_{1 \: 2} & \dots & c_{1 \: 27}\\
    c_{2 \: 1} & c_{2 \: 2} & \dots & c_{2 \: 27}\\
    \vdots   & \vdots   & \ddots & \vdots\\
    c_{32 \: 1} & c_{32 \: 2} & \dots & c_{32 \: 27}\\
    \end{bmatrix}
    $$

    we know that $\mathbf{CS}$ is defined as

    $$
    \mathbf{CS} = \begin{bmatrix}
    c_{1 \: 1} + c_{1 \: 2} + \dots + c_{1 \: 27}\\
    c_{2 \: 1} + c_{2 \: 2} + \dots + c_{2 \: 27}\\
    \vdots\\
    c_{32 \: 1} + c_{32 \: 2} + \dots + c_{32 \: 27}\\
    \end{bmatrix}
    $$

    Using the chain rule

    $$
    \frac{\partial \mathscr{L}}{\partial \mathbf{C}} = \frac{\partial \mathscr{L}}{\partial \mathbf{CS}} \frac{\partial \mathbf{CS}}{\partial \mathbf{C}}
    $$

    We only need to calculate $\frac{\partial \mathbf{CS}}{\partial \mathbf{C}}$.

    Since each element $c_{i j}$ of $\mathbf{C}$ participates in a sum in some element of $\mathbf{CS}$, the partial derivative $\frac{\partial \mathbf{CS}_i}{\partial \mathbf{C}_{i j}}$ is

    $$
    \frac{\partial \mathbf{CS}_i}{\partial \mathbf{C}_{i j}} = \frac{\partial}{\partial \mathbf{C}_{i j}} \left( c_{i \, 1} + c_{i \, 2} + \dots + c_{i j} + \dots + c_{i \, 27} \right) = 1
    $$

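    So, the partial derivative $\frac{\partial \mathbf{CS}}{\partial \mathbf{C}}$ is a matrix the shape of $\mathbf{C}$ with all elements being $1$. Because $\mathbf{CS}$ has one entry per row, the upstream gradient $\frac{\partial \mathscr{L}}{\partial \mathbf{CS}_i}$ is broadcast to every column of row $i$ of $\mathbf{C}$.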

In [22]:
# counts feeds probs directly and also feeds counts_sum, so both contributions are
# added. The single-column dcounts_sum is broadcast to every column of its row.
dcounts = (counts_sum_inv * dprobs) + (torch.ones_like(counts) * dcounts_sum)
cmp('counts', dcounts, counts)

counts          | exact: True  | approximate: True  | maxdiff: 0.0


### Derivative of `loss` with respect to `norm_logits`

Note that `counts` is defined as

```python
counts = norm_logits.exp()
```

Representing `counts` as $\mathbf{C}$ and `norm_logits` as $\mathbf{NL}$, we have:

$$
\mathbf{NL} = \begin{bmatrix}
nl_{1 \: 1} & nl_{1 \: 2} & \dots & nl_{1 \: 27}\\
nl_{2 \: 1} & nl_{2 \: 2} & \dots & nl_{2 \: 27}\\
\vdots   & \vdots   & \ddots & \vdots\\
nl_{32 \: 1} & nl_{32 \: 2} & \dots & nl_{32 \: 27}\\
\end{bmatrix}
$$

This way, $\mathbf{C}$ can be expressed as

$$
\mathbf{C} = \exp(\mathbf{NL}) = \begin{bmatrix}
e^{nl_{1 \: 1}} & e^{nl_{1 \: 2}} & \dots & e^{nl_{1 \: 27}}\\
e^{nl_{2 \: 1}} & e^{nl_{2 \: 2}} & \dots & e^{nl_{2 \: 27}}\\
\vdots   & \vdots   & \ddots & \vdots\\
e^{nl_{32 \: 1}} & e^{nl_{32 \: 2}} & \dots & e^{nl_{32 \: 27}}\\
\end{bmatrix}
$$

Using the chain rule

$$
\frac{\partial \mathscr{L}}{\partial \mathbf{NL}} = \frac{\partial \mathscr{L}}{\partial \mathbf{C}} \frac{\partial \mathbf{C}}{\partial \mathbf{NL}}
$$

For each partial derivative $\frac{\partial \mathbf{C}_{ij}}{\partial \mathbf{NL}_{ij}}$, we have

$$
\frac{\partial \mathbf{C}_{ij}}{\partial \mathbf{NL}_{ij}} = \frac{\partial}{\partial \mathbf{NL}_{ij}}\left( e^{nl_{ij}} \right) = e^{nl_{ij}}
$$

Therefore, $\frac{\partial \mathbf{C}}{\partial \mathbf{NL}} = \mathbf{C}$.

In [23]:
# d(e^x)/dx = e^x, which is exactly counts.
dnorm_logits = counts * dcounts
cmp('norm_logits', dnorm_logits, norm_logits)

norm_logits     | exact: True  | approximate: True  | maxdiff: 0.0


### Derivative of `loss` with respect to `logit_maxes`

Note that `norm_logits` is defined as

```python
norm_logits = logits - logit_maxes
```

Since `logits` and `logit_maxes` are not the same shape, the operation will be broadcasted.

In [24]:
# logit_maxes has one column, so the subtraction broadcasts it across logits.
print(f'{logits.shape=}, {logit_maxes.shape=}')

logits.shape=torch.Size([32, 27]), logit_maxes.shape=torch.Size([32, 1])


This means the vector `logit_maxes` will be repeated $27$ times. So `norm_logits` is defined as

$$
\mathbf{NL} = \begin{bmatrix}
l_{1 \: 1} - lm_1 & l_{1 \: 2} - lm_1 & \dots & l_{1 \: 27} - lm_1\\
l_{2 \: 1} - lm_2 & l_{2 \: 2} - lm_2 & \dots & l_{2 \: 27} - lm_2\\
\vdots   & \vdots   & \ddots & \vdots\\
l_{32 \: 1} - lm_{32} & l_{32 \: 2} - lm_{32} & \dots & l_{32 \: 27} - lm_{32}\\
\end{bmatrix}
$$

The gradient $\frac{\partial \mathbf{NL}}{\partial \mathbf{LM}}$ is

$$
\frac{\partial \mathbf{NL}}{\partial \mathbf{LM}} = \begin{bmatrix}
-1\\
-1\\
\vdots\\
-1\\
\end{bmatrix}
$$

Again using the chain rule, $\frac{\partial \mathscr{L}}{\partial \mathbf{LM}} = \frac{\partial \mathscr{L}}{\partial \mathbf{NL}} \frac{\partial \mathbf{NL}}{\partial \mathbf{LM}}$. Again, due to broadcasting, $\frac{\partial \mathbf{NL}}{\partial \mathbf{LM}}$ is repeated. Since all its entries are $-1$, we can sum across $\frac{\partial \mathscr{L}}{\partial \mathbf{NL}}$ and negate it.

In [25]:
# norm_logits = logits - logit_maxes: the broadcast logit_maxes receives -1 times
# the row sum of the upstream gradient.
dlogit_maxes = - dnorm_logits.sum(1, keepdim=True)
cmp('logit_maxes', dlogit_maxes, logit_maxes)

logit_maxes     | exact: True  | approximate: True  | maxdiff: 0.0


Note that the logit normalization is performed only for numerical stability: it prevents `exp` from overflowing when the logits are large. Softmax is invariant to subtracting the same constant from every logit in a row, so `logit_maxes` has no effect on the loss. This means the gradient `dlogit_maxes` should be $0$, up to floating-point error.

In [26]:
# Expected to be ~0: subtracting the row max does not change the softmax output.
dlogit_maxes

tensor([[ 9.3132e-10],
        [-5.5879e-09],
        [-1.8626e-09],
        [ 1.8626e-09],
        [-9.3132e-10],
        [-9.3132e-10],
        [ 2.7940e-09],
        [-0.0000e+00],
        [-9.3132e-10],
        [-3.7253e-09],
        [ 9.3132e-10],
        [ 9.3132e-10],
        [-0.0000e+00],
        [-9.3132e-10],
        [ 1.8626e-09],
        [-0.0000e+00],
        [ 3.7253e-09],
        [ 1.8626e-09],
        [-4.6566e-09],
        [-5.5879e-09],
        [-3.7253e-09],
        [ 9.3132e-10],
        [ 3.7253e-09],
        [ 2.7940e-09],
        [-9.3132e-10],
        [ 4.6566e-09],
        [-3.7253e-09],
        [ 9.3132e-10],
        [-3.7253e-09],
        [-0.0000e+00],
        [-0.0000e+00],
        [-9.3132e-10]], device='cuda:0', grad_fn=<NegBackward0>)

### Derivative of `loss` with respect to `logits`

The tensor `logits` is used twice, so its gradient is a sum of the two separate gradients.

```python
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes
```

1. The `logits` tensor $\mathbf{L}$ is defined as

    $$
    \mathbf{L} = \begin{bmatrix}
    l_{1 \: 1} & l_{1 \: 2} & \dots & l_{1 \: 27}\\
    l_{2 \: 1} & l_{2 \: 2} & \dots & l_{2 \: 27}\\
    \vdots   & \vdots   & \ddots & \vdots\\
    l_{32 \: 1} & l_{32 \: 2} & \dots & l_{32 \: 27}\\
    \end{bmatrix}
    $$

    The `logit_maxes` tensor is defined as

    $$
    \mathbf{LM} = \begin{bmatrix}
    lm_1\\
    lm_2\\
    \vdots\\
    lm_{32}\\
    \end{bmatrix} = \begin{bmatrix}
    \max(l_{1 \: 1}, l_{1 \: 2}, \dots, l_{1 \: 27})\\
    \max(l_{2 \: 1}, l_{2 \: 2}, \dots, l_{2 \: 27})\\
    \vdots\\
    \max(l_{32 \: 1}, l_{32 \: 2}, \dots, l_{32 \: 27})\\
    \end{bmatrix}
    $$

    The partial derivative $\frac{\partial \mathbf{LM}_{i}}{\partial \mathbf{L}_{i j}}$ is $1$ if $l_{i j}$ is the maximum of the row $l_{i \, 1}, l_{i \, 2}, \dots l_{i \, j}, \dots, l_{i \, 27}$, and is $0$ otherwise. From here we can use the chain rule $\frac{\partial \mathscr{L}}{\partial \mathbf{L}} = \frac{\partial \mathscr{L}}{\partial \mathbf{LM}} \frac{\partial \mathbf{LM}}{\partial \mathbf{L}}$

2. The `norm_logits` tensor $\mathbf{NL}$ is defined as

    $$
    \mathbf{NL} = \begin{bmatrix}
    l_{1 \: 1} - lm_1 & l_{1 \: 2} - lm_1 & \dots & l_{1 \: 27} - lm_1\\
    l_{2 \: 1} - lm_2 & l_{2 \: 2} - lm_2 & \dots & l_{2 \: 27} - lm_2\\
    \vdots   & \vdots   & \ddots & \vdots\\
    l_{32 \: 1} - lm_{32} & l_{32 \: 2} - lm_{32} & \dots & l_{32 \: 27} - lm_{32}\\
    \end{bmatrix}
    $$

    The partial derivative $\frac{\partial \mathbf{NL}_{i j}}{\partial \mathbf{L}_{i j}}$ is $1$ because all $l_{i j}$ participate in the $\mathbf{NL}$ matrix with coefficient $1$. This means

    $$
    \frac{\partial \mathbf{NL}}{\partial \mathbf{L}} = \begin{bmatrix}
    1 & 1 & \dots & 1\\
    1 & 1 & \dots & 1\\
    \vdots   & \vdots   & \ddots & \vdots\\
    1 & 1 & \dots & 1\\
    \end{bmatrix}
    $$

    From here we can use the chain rule $\frac{\partial \mathscr{L}}{\partial \mathbf{L}} = \frac{\partial \mathscr{L}}{\partial \mathbf{NL}} \frac{\partial \mathbf{NL}}{\partial \mathbf{L}}$.

In [27]:
# Two paths lead into logits: the direct one through norm_logits, and the one
# through logit_maxes, which routes gradient only to the argmax position of each
# row (the one-hot mask).
dlogits = (F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes) + (dnorm_logits)
cmp('logits', dlogits, logits)

logits          | exact: True  | approximate: True  | maxdiff: 0.0


### Derivative of `loss` with respect to `h`, `W2`, and `b2`

The `logits` tensor is defined as

```
logits = h @ W2 + b2
```

Defining the $\mathbf{H}$, $\mathbf{W}_2$, and $\mathbf{b}_2$ tensors as

$$
\mathbf{H} = \begin{bmatrix}
h_{1 \: 1} & h_{1 \: 2} & \dots & h_{1 \: 64}\\
h_{2 \: 1} & h_{2 \: 2} & \dots & h_{2 \: 64}\\
\vdots   & \vdots   & \ddots & \vdots\\
h_{32 \: 1} & h_{32 \: 2} & \dots & h_{32 \: 64}\\
\end{bmatrix} \quad\quad \mathbf{W}_2 = \begin{bmatrix}
w_{1 \: 1} & w_{1 \: 2} & \dots & w_{1 \: 27}\\
w_{2 \: 1} & w_{2 \: 2} & \dots & w_{2 \: 27}\\
\vdots   & \vdots   & \ddots & \vdots\\
w_{64 \: 1} & w_{64 \: 2} & \dots & w_{64 \: 27}\\
\end{bmatrix} \quad\quad \mathbf{b}_2 = \begin{bmatrix}
b_{1} & b_{2} & \dots & b_{27}
\end{bmatrix}
$$

The product $\mathbf{H} \mathbf{W}_2$ is performed as

$$
\mathbf{H} \mathbf{W}_2 = \begin{bmatrix}
h_{1 \: 1} \, w_{1 \: 1} + h_{1 \: 2} \, w_{2 \: 1} + \dots + h_{1 \: 64} \, w_{64 \: 1} & h_{1 \: 1} \, w_{1 \: 2} + h_{1 \: 2} \, w_{2 \: 2} + \dots + h_{1 \: 64} \, w_{64 \: 2} & \dots & h_{1 \: 1} \, w_{1 \: 27} + h_{1 \: 2} \, w_{2 \: 27} + \dots + h_{1 \: 64} \, w_{64 \: 27}\\
h_{2 \: 1} \, w_{1 \: 1} + h_{2 \: 2} \, w_{2 \: 1} + \dots + h_{2 \: 64} \, w_{64 \: 1} & h_{2 \: 1} \, w_{1 \: 2} + h_{2 \: 2} \, w_{2 \: 2} + \dots + h_{2 \: 64} \, w_{64 \: 2} & \dots & h_{2 \: 1} \, w_{1 \: 27} + h_{2 \: 2} \, w_{2 \: 27} + \dots + h_{2 \: 64} \, w_{64 \: 27}\\
\vdots   & \vdots   & \ddots & \vdots\\
h_{32 \: 1} \, w_{1 \: 1} + h_{32 \: 2} \, w_{2 \: 1} + \dots + h_{32 \: 64} \, w_{64 \: 1} & h_{32 \: 1} \, w_{1 \: 2} + h_{32 \: 2} \, w_{2 \: 2} + \dots + h_{32 \: 64} \, w_{64 \: 2} & \dots & h_{32 \: 1} \, w_{1 \: 27} + h_{32 \: 2} \, w_{2 \: 27} + \dots + h_{32 \: 64} \, w_{64 \: 27}\\
\end{bmatrix}
$$

Then the full expression $\mathbf{L} = \mathbf{H} \mathbf{W}_2 + \mathbf{b}_2$, which causes $\mathbf{b}_2$ to be broadcasted, is

$$
\mathbf{L} = \mathbf{H} \mathbf{W}_2 + \mathbf{b}_2 = \begin{bmatrix}
h_{1 \: 1} \, w_{1 \: 1} + h_{1 \: 2} \, w_{2 \: 1} + \dots + h_{1 \: 64} \, w_{64 \: 1} + b_1 & h_{1 \: 1} \, w_{1 \: 2} + h_{1 \: 2} \, w_{2 \: 2} + \dots + h_{1 \: 64} \, w_{64 \: 2} + b_2 & \dots & h_{1 \: 1} \, w_{1 \: 27} + h_{1 \: 2} \, w_{2 \: 27} + \dots + h_{1 \: 64} \, w_{64 \: 27} + b_{27}\\
h_{2 \: 1} \, w_{1 \: 1} + h_{2 \: 2} \, w_{2 \: 1} + \dots + h_{2 \: 64} \, w_{64 \: 1} + b_1 & h_{2 \: 1} \, w_{1 \: 2} + h_{2 \: 2} \, w_{2 \: 2} + \dots + h_{2 \: 64} \, w_{64 \: 2} + b_2 & \dots & h_{2 \: 1} \, w_{1 \: 27} + h_{2 \: 2} \, w_{2 \: 27} + \dots + h_{2 \: 64} \, w_{64 \: 27} + b_{27}\\
\vdots   & \vdots   & \ddots & \vdots\\
h_{32 \: 1} \, w_{1 \: 1} + h_{32 \: 2} \, w_{2 \: 1} + \dots + h_{32 \: 64} \, w_{64 \: 1} + b_{1} & h_{32 \: 1} \, w_{1 \: 2} + h_{32 \: 2} \, w_{2 \: 2} + \dots + h_{32 \: 64} \, w_{64 \: 2} + b_{2} & \dots & h_{32 \: 1} \, w_{1 \: 27} + h_{32 \: 2} \, w_{2 \: 27} + \dots + h_{32 \: 64} \, w_{64 \: 27} + b_{27}\\
\end{bmatrix}
$$

For $\frac{\partial \mathbf{L}}{\partial \mathbf{H}}$, note that entry $\mathbf{L}_{i j}$ depends on every element $h_{i k}$ of row $i$ of $\mathbf{H}$:

$$
\frac{\partial \mathbf{L}_{i j}}{\partial h_{i k}} = \frac{\partial}{\partial h_{i k}} \left( h_{i \: 1} \, w_{1 \: j} + h_{i \: 2} \, w_{2 \: j} + \dots + h_{i \: k} \, w_{k \: j} + \dots + h_{i \: 64} \, w_{64 \: j} + b_j \right) = w_{k \: j}
$$

Since $h_{i k}$ contributes to all $27$ logits in row $i$, its gradient sums over $j$:

$$
\frac{\partial \mathscr{L}}{\partial h_{i k}} = \sum_{j} \frac{\partial \mathscr{L}}{\partial \mathbf{L}_{i j}} \, w_{k \: j}
$$

In matrix form, this is a product with the transpose of $\mathbf{W}_2$:

$$
\frac{\partial \mathscr{L}}{\partial \mathbf{H}} = \frac{\partial \mathscr{L}}{\partial \mathbf{L}} \, {\mathbf{W}_2}^\mathrm{T}
$$

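By similar logic, $\mathbf{L}_{i j}$ depends on $w_{k j}$ only through the term $h_{i \: k} \, w_{k \: j}$, so $\frac{\partial \mathbf{L}_{i j}}{\partial w_{k j}} = h_{i \: k}$. Since $w_{k j}$ is used in every row $i$, its gradient sums over the batch:

$$
\frac{\partial \mathscr{L}}{\partial w_{k j}} = \sum_{i} h_{i \: k} \, \frac{\partial \mathscr{L}}{\partial \mathbf{L}_{i j}}
$$

In matrix form:

$$
\frac{\partial \mathscr{L}}{\partial {\mathbf{W}_2}} = \mathbf{H}^\mathrm{T} \, \frac{\partial \mathscr{L}}{\partial \mathbf{L}}
$$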

Lastly, for $\mathbf{b}_2$, the partial derivative of entry $\mathbf{L}_{i j}$ with respect to $b_j$ is

$$
\frac{\partial \mathbf{L}_{i j}}{\partial b_j} = \frac{\partial}{\partial b_j} \left( h_{i \: 1} \, w_{1 \: j} + h_{i \: 2} \, w_{2 \: j} + \dots + h_{i \: 64} \, w_{64 \: j} + b_j \right) = 1
$$

Since $\mathbf{b}_2$ is broadcasted across the batch, each $b_j$ receives $32$ contributions to its gradient, one for each row of $\mathbf{L}$. Therefore $\frac{\partial \mathscr{L}}{\partial b_j} = \sum_i \frac{\partial \mathscr{L}}{\partial \mathbf{L}_{i j}}$.

Putting this together with the chain rule, the derivatives of the loss $\mathscr{L}$ with respect to $\mathbf{H}$, $\mathbf{W}_2$, and $\mathbf{b}_2$ are

$$
\begin{align*}
    \frac{\partial \mathscr{L}}{\partial \mathbf{H}} &= \frac{\partial \mathscr{L}}{\partial \mathbf{L}} \, {\mathbf{W}_2}^\mathrm{T}\\[10pt]
    \frac{\partial \mathscr{L}}{\partial \mathbf{W}_2} &= \mathbf{H}^\mathrm{T} \, \frac{\partial \mathscr{L}}{\partial \mathbf{L}}\\[10pt]
    \frac{\partial \mathscr{L}}{\partial b_j} &= \sum_{i} \frac{\partial \mathscr{L}}{\partial \mathbf{L}_{i j}}
\end{align*}
$$

Each gradient must have the same shape as the tensor it belongs to. In particular, $\frac{\partial \mathscr{L}}{\partial \mathbf{b}_2}$ is a vector of length $27$, just like $\mathbf{b}_2$.

In [28]:
# The shapes determine the order and transposes in the matrix products of the backward pass.
print('Knowing the shapes of the tensors involved in the matrix multiplication can help you figure out the correct way to multiply the gradients.')
print(f'{logits.shape=}, {W2.shape=}, {h.shape=}, {b2.shape=}')

Knowing the shapes of the tensors involved in the matrix multiplication can help you figure out the correct way to multiply the gradients.
logits.shape=torch.Size([32, 27]), W2.shape=torch.Size([64, 27]), h.shape=torch.Size([32, 64]), b2.shape=torch.Size([27])


In [29]:
# logits = h @ W2 + b2: apply the matrix-product rule for h and W2. b2 is
# broadcast across the batch, so its gradient sums dlogits over rows.
dh = dlogits @ W2.T
dW2 = h.T @ dlogits
# b2 is 1-D (vocab_size,), so reduce without keepdim. A (1, vocab_size) gradient
# only "matches" b2.grad through broadcasting, and an in-place update such as
# b2.data += -lr * db2 would then fail.
db2 = dlogits.sum(0)
cmp('h', dh, h)
cmp('W2', dW2, W2)
cmp('b2', db2, b2)

h               | exact: True  | approximate: True  | maxdiff: 0.0
W2              | exact: True  | approximate: True  | maxdiff: 0.0
b2              | exact: True  | approximate: True  | maxdiff: 0.0


### Derivative of `loss` with respect to `hpreact`

The tensor `h` is defined as

```python
h = torch.tanh(hpreact)
```

Note the derivative of $\mathrm{tanh}$

$$
y = \mathrm{tanh}(x) \Rightarrow \frac{dy}{dx} = 1 - y^2
$$

Since $\mathbf{H} = \mathrm{tanh}(\mathbf{H}_{pre \, activation})$, we know that

$$
\frac{\partial \mathbf{H}}{\partial \mathbf{H}_{pre \,activation}} = 1 - \mathbf{H}^2
$$

From here we can use the chain rule

$$
\frac{\partial \mathscr{L}}{\partial \mathbf{H}_{pre \, activation}} = \frac{\partial \mathscr{L}}{\partial \mathbf{H}} \frac{\partial \mathbf{H}}{\partial \mathbf{H}_{pre \, activation}}
$$

In [30]:
# tanh'(x) = 1 - tanh(x)^2, reusing the forward output h.
# NOTE(review): the recorded cmp output is approximate rather than bit-exact here,
# presumably because autograd's tanh backward kernel rounds differently.
dhpreact = dh * (1.0 - h.pow(2))
cmp('hpreact', dhpreact, hpreact)

hpreact         | exact: False | approximate: True  | maxdiff: 4.656612873077393e-10


### Derivative of `loss` with respect to `bngain`, `bnraw`, and `bnbias`

The tensor `hpreact` is defined as

```python
hpreact = bngain * bnraw + bnbias
```

Note the shapes of each tensor.

In [38]:
# bngain and bnbias have a single row and are broadcast across the batch.
print(f'{hpreact.shape=}, {bngain.shape=}, {bnraw.shape=}, {bnbias.shape=}')

hpreact.shape=torch.Size([32, 64]), bngain.shape=torch.Size([1, 64]), bnraw.shape=torch.Size([32, 64]), bnbias.shape=torch.Size([1, 64])


Defining the ${\mathbf{b}_{n}}_{gain}$, ${\mathbf{B}_{n}}_{raw}$, and ${\mathbf{b}_{n}}_{bias}$ tensors as

$$
{\mathbf{b}_{n}}_{gain} = \begin{bmatrix}
{{b_n}_{gain}}_{1} & {{b_n}_{gain}}_{2} & \dots & {{b_n}_{gain}}_{64}
\end{bmatrix} \quad\quad {\mathbf{B}_{n}}_{raw} = \begin{bmatrix}
{{b_n}_{raw}}_{1 \: 1} & {{b_n}_{raw}}_{1 \: 2} & \dots & {{b_n}_{raw}}_{1 \: 64}\\
{{b_n}_{raw}}_{2 \: 1} & {{b_n}_{raw}}_{2 \: 2} & \dots & {{b_n}_{raw}}_{2 \: 64}\\
\vdots   & \vdots   & \ddots & \vdots\\
{{b_n}_{raw}}_{32 \: 1} & {{b_n}_{raw}}_{32 \: 2} & \dots & {{b_n}_{raw}}_{32 \: 64}\\
\end{bmatrix} \quad\quad {\mathbf{b}_{n}}_{bias} = \begin{bmatrix}
{{b_n}_{bias}}_{1} & {{b_n}_{bias}}_{2} & \dots & {{b_n}_{bias}}_{64}
\end{bmatrix}
$$

The full operation $\mathbf{H}_{pre \, activation} = {\mathbf{b}_{n}}_{gain} \circ {\mathbf{B}_{n}}_{raw} + {\mathbf{b}_{n}}_{bias}$, including broadcasting, is

$$
\mathbf{H}_{pre \, activation} = {\mathbf{b}_{n}}_{gain} \circ {\mathbf{B}_{n}}_{raw} + {\mathbf{b}_{n}}_{bias} = \begin{bmatrix}
{{b_n}_{gain}}_{1} \: {{b_n}_{raw}}_{1 \: 1} + {{b_n}_{bias}}_{1} & {{b_n}_{gain}}_{2} \: {{b_n}_{raw}}_{1 \: 2} + {{b_n}_{bias}}_{2} & \dots & {{b_n}_{gain}}_{64} \: {{b_n}_{raw}}_{1 \: 64} + {{b_n}_{bias}}_{64}\\
{{b_n}_{gain}}_{1} \: {{b_n}_{raw}}_{2 \: 1} + {{b_n}_{bias}}_{1} & {{b_n}_{gain}}_{2} \: {{b_n}_{raw}}_{2 \: 2} + {{b_n}_{bias}}_{2} & \dots & {{b_n}_{gain}}_{64} \: {{b_n}_{raw}}_{2 \: 64} + {{b_n}_{bias}}_{64}\\
\vdots   & \vdots   & \ddots & \vdots\\
{{b_n}_{gain}}_{1} \: {{b_n}_{raw}}_{32 \: 1} + {{b_n}_{bias}}_{1} & {{b_n}_{gain}}_{2} \: {{b_n}_{raw}}_{32 \: 2} + {{b_n}_{bias}}_{2} & \dots & {{b_n}_{gain}}_{64} \: {{b_n}_{raw}}_{32 \: 64} + {{b_n}_{bias}}_{64}\\
\end{bmatrix}
$$

Calculating the derivative with respect to ${\mathbf{b}_{n}}_{gain}$:

$$
\frac{\partial {\mathbf{H}_{pre \, activation}}_{i j}}{\partial {{b_n}_{gain}}_{j}} = \frac{\partial}{\partial {{b_n}_{gain}}_{j}} \left( {{b_n}_{gain}}_{j} \: {{b_n}_{raw}}_{i \: j} + {{b_n}_{bias}}_{j} \right) = {{b_n}_{raw}}_{i \: j}
$$

So, element-wise, $\frac{\partial \mathbf{H}_{pre \, activation}}{\partial {\mathbf{b}_{n}}_{gain}} = {\mathbf{B}_{n}}_{raw}$. Because ${\mathbf{b}_{n}}_{gain}$ is broadcast across all $32$ rows, each ${{b_n}_{gain}}_{j}$ collects one contribution per row. These contributions must be summed over the batch dimension.

Calculating the derivative $\frac{\partial {\mathbf{H}_{pre \, activation}}}{\partial {{\mathbf{B}_{n}}_{raw}}}$

$$
\begin{align*}
\frac{\partial {\mathbf{H}_{pre \, activation}}_{i j}}{\partial {{\mathbf{B}_{n}}_{raw}}_{i j}} = \frac{\partial}{\partial {{b_n}_{raw}}_{i \: j}} \left( {{b_n}_{gain}}_{j} \: {{b_n}_{raw}}_{i \: j} + {{b_n}_{bias}}_{j} \right) = {{b_n}_{gain}}_{j}
\end{align*}
$$

$$
\begin{align*}
\frac{\partial {\mathbf{H}_{pre \, activation}}}{\partial {{\mathbf{B}_{n}}_{raw}}} = {\mathbf{b}_{n}}_{gain}
\end{align*}
$$

Calculating the derivative $\frac{\partial {\mathbf{H}_{pre \, activation}}}{\partial {{\mathbf{b}_{n}}_{bias}}}$

$$
\begin{align*}
\frac{\partial {\mathbf{H}_{pre \, activation}}_{i j}}{\partial {{\mathbf{b}_{n}}_{bias}}_{j}} = \frac{\partial}{\partial {{{b}_{n}}_{bias}}_{j}} \left( {{b_n}_{gain}}_{j} \: {{b_n}_{raw}}_{i \: j} + {{b_n}_{bias}}_{j} \right) = 1
\end{align*}
$$

$$
\begin{align*}
\frac{\partial {\mathbf{H}_{pre \, activation}}}{\partial {{\mathbf{b}_{n}}_{bias}}} = 1
\end{align*}
$$

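From here we can use the chain rule. The products below are element-wise.

${\mathbf{b}_{n}}_{gain}$ and ${\mathbf{b}_{n}}_{bias}$ are broadcast across all 32 rows of the batch. Each of their elements therefore affects every row of $\mathbf{H}_{pre \, activation}$, so their gradients are summed over the rows (`.sum(0, keepdim=True)` in the code below). ${\mathbf{B}_{n}}_{raw}$ has the same shape as $\mathbf{H}_{pre \, activation}$ and needs no reduction.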

$$
\begin{align*}
\frac{\partial \mathscr{L}}{\partial {\mathbf{b}_{n}}_{gain}} &= \frac{\partial \mathscr{L}}{\partial \mathbf{H}_{pre \, activation}} \frac{\partial \mathbf{H}_{pre \, activation}}{\partial {\mathbf{b}_{n}}_{gain}}\\
\frac{\partial \mathscr{L}}{\partial {\mathbf{B}_{n}}_{raw}} &= \frac{\partial \mathscr{L}}{\partial \mathbf{H}_{pre \, activation}} \frac{\partial \mathbf{H}_{pre \, activation}}{\partial {\mathbf{B}_{n}}_{raw}}\\
\frac{\partial \mathscr{L}}{\partial {\mathbf{b}_{n}}_{bias}} &= \frac{\partial \mathscr{L}}{\partial \mathbf{H}_{pre \, activation}} \frac{\partial \mathbf{H}_{pre \, activation}}{\partial {\mathbf{b}_{n}}_{bias}}
\end{align*}
$$

In [42]:
# Backprop through hpreact = bngain * bnraw + bnbias.
# bngain and bnbias are broadcast over the batch rows, so their gradients are
# accumulated over dim 0; bnraw has hpreact's shape and gets an element-wise gradient.
dbnbias = torch.sum(dhpreact, dim=0, keepdim=True)
dbngain = torch.sum(bnraw * dhpreact, dim=0, keepdim=True)
dbnraw = bngain * dhpreact
cmp('bngain', dbngain, bngain)
cmp('bnraw', dbnraw, bnraw)
cmp('bnbias', dbnbias, bnbias)

bngain          | exact: False | approximate: True  | maxdiff: 1.862645149230957e-09
bnraw           | exact: False | approximate: True  | maxdiff: 9.313225746154785e-10
bnbias          | exact: False | approximate: True  | maxdiff: 3.725290298461914e-09


### Derivative of `loss` with respect to `bnvar_inv`

The tensor `bnraw` is defined as

```python
bnraw = bndiff * bnvar_inv
```

Note the shapes are different

In [46]:
# Show the operand shapes of `bnraw = bndiff * bnvar_inv` to see which operand is broadcast.
print(f'{bnraw.shape=}, {bndiff.shape=}, {bnvar_inv.shape=}')

bnraw.shape=torch.Size([32, 64]), bndiff.shape=torch.Size([32, 64]), bnvar_inv.shape=torch.Size([1, 64])


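So `bnvar_inv` is broadcast across the 32 rows of the batch.

We know how to back-propagate through element-wise multiplication. We multiply ${\mathbf{B}_n}_{diff}$ element-wise by $\frac{\partial \mathscr{L}}{\partial {\mathbf{B}_n}_{raw}}$ and then sum over the rows (dimension 0). The sum is needed because the same `bnvar_inv` row was used for every row of the batch, so its gradient accumulates all of those contributions.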

In [47]:
# bnraw = bndiff * bnvar_inv, so the local gradient w.r.t. bnvar_inv is bndiff.
# bnvar_inv is broadcast over the batch rows, so the product is accumulated over dim 0.
dbnvar_inv = torch.sum(bndiff * dbnraw, dim=0, keepdim=True)
cmp('bnvar_inv', dbnvar_inv, bnvar_inv)

bnvar_inv       | exact: False | approximate: True  | maxdiff: 1.862645149230957e-09


### Derivative of `loss` with respect to `bnvar`

The tensor `bnvar_inv` is defined as

```python
bnvar_inv = (bnvar + 1e-5)**-0.5
```

Taking the derivative 

$$
\frac{\partial {{\mathbf{B}_n}_{var}}_{inv}}{\partial {\mathbf{B}_n}_{var}} = \frac{\partial}{\partial {\mathbf{B}_n}_{var}} \left( \left( {\mathbf{B}_n}_{var} + 1 \times 10^{-5} \right)^{-\frac{1}{2}} \right) = - \frac{1}{2} \left( {\mathbf{B}_n}_{var} + 1 \times 10^{-5} \right)^{-\frac{3}{2}}
$$

From here we can use the chain rule

$$
\frac{\partial \mathscr{L}}{\partial {\mathbf{B}_n}_{var}} = \frac{\partial \mathscr{L}}{\partial {{\mathbf{B}_n}_{var}}_{inv}} \frac{\partial {{\mathbf{B}_n}_{var}}_{inv}}{\partial {\mathbf{B}_n}_{var}}
$$

In [51]:
# d/dx (x + 1e-5)**-0.5 = -0.5 * (x + 1e-5)**-1.5, applied element-wise and
# chained with the upstream gradient dbnvar_inv (same epsilon as the forward pass).
dbnvar = -0.5 * (bnvar + 1e-5)**-1.5 * dbnvar_inv
cmp('bnvar', dbnvar, bnvar)

bnvar           | exact: False | approximate: True  | maxdiff: 9.313225746154785e-10


In [49]:
# Check whether bnvar and bnvar_inv share a shape; if they do, the element-wise
# derivative in the dbnvar cell needs no reduction over any dimension.
print(f'{bnvar.shape=}, {bnvar_inv.shape=}')

bnvar.shape=torch.Size([1, 64]), bnvar_inv.shape=torch.Size([1, 64])


In [31]:
# Reference only: the forward pass through the batch-norm hidden layer, kept as
# comments to read alongside the backward-pass cells (nothing in this cell executes).
# emb = C[Xb]
# embcat = emb.view(emb.shape[0], -1)

# hprebn = embcat @ W1 + b1

# bnmeani = 1.0 / batch_size * hprebn.sum(0, keepdim=True)
# bndiff = hprebn - bnmeani
# bndiff2 = bndiff**2
# bnvar = 1 / (batch_size-1) * bndiff2.sum(0, keepdim=True)  # note: divides by n-1 (Bessel's correction)
# bnvar_inv = (bnvar + 1e-5)**-0.5
# bnraw = bndiff * bnvar_inv

# hpreact = bngain * bnraw + bnbias

In [32]:
# Dimensions of W2.
W2.size()

torch.Size([64, 27])

In [33]:
# Dimensions of the logits tensor.
logits.size()

torch.Size([32, 27])

In [34]:
# Dimensions of the gradient w.r.t. logit_maxes.
dlogit_maxes.size()

torch.Size([32, 1])

In [35]:
# Dimensions of the gradient w.r.t. counts_sum.
dcounts_sum.size()

torch.Size([32, 1])