Trying to understand scaling laws by [chinchilla]() using [nanogpt (scaling_laws.ipynb)]()

In [238]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
import pandas as pd
import scipy
%matplotlib inline

In [3]:
def gpt_params(seq_len, vocab_size, d_model, num_heads, num_layers):
    """ Given a GPT calculate total number of parameters """
    ffw_size = 4*d_model # hidden size of the MLP; in GPT it is always 4*d_model
    # token and position embeddings
    embeddings = d_model * vocab_size + d_model * seq_len
    # transformer blocks
    attention = 3*d_model**2 + 3*d_model # weights and biases
    attproj = d_model**2 + d_model
    ffw = d_model*(ffw_size) + ffw_size
    ffwproj = ffw_size*d_model + d_model
    layernorms = 2*2*d_model
    # dense
    ln_f = 2*d_model
    dense = d_model*vocab_size # note: no bias here
    # note: embeddings are not included in the param count!
    total_params = num_layers*(attention + attproj + ffw + ffwproj + layernorms) + ln_f + dense
    return total_params

# Explanation

A model's parameter count can be read off directly by summing `p.numel()` over `model.parameters()`. Still, it helps to build some intuition for how the GPT parameter count is calculated.
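To make the count concrete, the sketch below evaluates the same formula term by term. The hyperparameters are assumed GPT-2 small values (context 1024, vocabulary 50257, width 768, 12 layers); they are not stated in this notebook, and `gpt_param_breakdown` is a hypothetical helper name used only for illustration.

```python
def gpt_param_breakdown(seq_len, vocab_size, d_model, num_layers):
    """Per-component parameter counts, mirroring gpt_params term by term."""
    ffw_size = 4 * d_model
    per_layer = {
        "attention": 3 * d_model**2 + 3 * d_model,  # Q, K, V weights and biases
        "attproj": d_model**2 + d_model,            # attention output projection
        "ffw": d_model * ffw_size + ffw_size,       # MLP up-projection
        "ffwproj": ffw_size * d_model + d_model,    # MLP down-projection
        "layernorms": 2 * 2 * d_model,              # two LayerNorms, scale and shift
    }
    block = sum(per_layer.values())
    return {
        # Reported for reference only; gpt_params leaves embeddings out of the total.
        "embeddings": d_model * vocab_size + d_model * seq_len,
        "per_layer": block,
        "ln_f": 2 * d_model,
        "dense": d_model * vocab_size,
        "total": num_layers * block + 2 * d_model + d_model * vocab_size,
    }

# Assumed GPT-2 small hyperparameters (not taken from this notebook).
counts = gpt_param_breakdown(seq_len=1024, vocab_size=50257, d_model=768, num_layers=12)
for name, n in counts.items():
    print(f"{name:>10}: {n:,}")
```

Under these assumptions each transformer block contributes 7,087,872 parameters and the total, excluding embeddings, comes to 123,653,376.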

`embeddings = d_model * vocab_size + d_model * seq_len`

`embeddings` is the total number of parameters in the embedding layer, which has two parts: the _token embedding_ and the _positional embedding_. The token embedding weight has size `(vocab_size, d_model)`, where `d_model` is the number of channels (aka. the embedding output dimension),

since
```
token embedding =  Train Data        @ Weight
(B, T, d_model) = (B, T, vocab_size) @ (vocab_size, d_model)
```

After the training data is one-hot encoded it has size `B x T x vocab_size`, where `T` is the number of tokens in each sequence. The weight therefore has dimensions `Size(vocab_size, d_model)`, and the total number of parameters in it is `vocab_size x d_model`.

Similarly, in the positional embedding the weight dimension is dictated by the input, which here is the one-hot encoded position of each token.

```
positional embedding =  Positions   @ Weight
(T, d_model)         = (T, seq_len) @ (seq_len, d_model)
```

Positional embedding simply numbers the inputs within a fixed context window. The size of this context window is `seq_len` (the sequence length).
The weight that embeds the positions is therefore of `Size(seq_len, d_model)`, so the positional embedding has `seq_len x d_model` parameters.

The total number of parameters in the embedding layer is the sum of the two, because they are separate weights. Their outputs are added together (the positional embedding broadcasts over the batch) before being passed to the next layer.
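The token-embedding shape can be checked on a toy example. The sketch below uses plain Python lists and made-up sizes (`vocab_size = 5`, `d_model = 3`); `one_hot` and `vec_matmul` are hypothetical helpers. It shows that multiplying a one-hot vector by a weight of shape `(vocab_size, d_model)` just selects one row, and that the weight holds `vocab_size * d_model` numbers.

```python
import random

vocab_size, d_model = 5, 3  # toy sizes, chosen only for illustration

random.seed(0)
# Token embedding weight: one row of d_model numbers per vocabulary entry.
weight = [[random.random() for _ in range(d_model)] for _ in range(vocab_size)]

def one_hot(index, size):
    """Vector of zeros with a single 1.0 at position `index`."""
    return [1.0 if i == index else 0.0 for i in range(size)]

def vec_matmul(vec, matrix):
    """Multiply a row vector of length n by a matrix with n rows."""
    n_cols = len(matrix[0])
    return [sum(v * row[j] for v, row in zip(vec, matrix)) for j in range(n_cols)]

token = 3
via_matmul = vec_matmul(one_hot(token, vocab_size), weight)  # (vocab_size,) @ (vocab_size, d_model)
via_lookup = weight[token]                                   # plain row lookup

num_params = sum(len(row) for row in weight)
print(via_matmul == via_lookup, num_params == vocab_size * d_model)
```

The same argument applies to the positional weight, with `seq_len` rows in place of `vocab_size`.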

```
attention = 3*d_model**2 + 3*d_model
```

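You may have come across __Quadratic Scaling__ or __Squared attention__: the observation that when you double the length of the input sequence, the computational cost of the attention mechanism increases by a factor of four, because that cost scales quadratically with the sequence length. The parameter count above is a different quantity. It comes from the query, key, and value projections, each a `d_model x d_model` weight with a `d_model` bias, so it is quadratic in `d_model` and does not depend on the sequence length.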
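The difference between parameter count and compute cost can be seen with two small functions. This is only a sketch: `attention_qkv_params` and `attention_score_entries` are hypothetical names, and `d_model = 768` is an assumed width.

```python
def attention_qkv_params(d_model):
    """Parameters of the Q, K, V projections: three d_model x d_model weights plus biases."""
    return 3 * d_model**2 + 3 * d_model

def attention_score_entries(seq_len):
    """Number of query-key scores per head: every position is compared with every position."""
    return seq_len * seq_len

d_model = 768  # assumed width, for illustration only
for seq_len in (512, 1024, 2048):
    # The parameter count stays fixed while the score matrix grows with seq_len**2.
    print(seq_len, attention_qkv_params(d_model), attention_score_entries(seq_len))
```

Doubling `seq_len` quadruples the number of scores, while the projection parameters stay the same.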

How does attention work?
```python
import math

def dot_product(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))

def softmax(scores):
    """Normalize a list of scores into weights that sum to 1."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention(query, keys, values):
    # Attention for a single query.
    # query: (C,), keys: (T, C), values: (T, C)

    # Calculate one attention score per key
    attention_weights = []  # (T,)
    for key in keys:
        score = dot_product(query, key)  # scalar
        attention_weights.append(score)

    # Normalize attention weights using softmax
    attention_weights = softmax(attention_weights)  # (T,)

    # Calculate weighted sum of values
    context_vector = [0.0] * len(values[0])  # (C,)
    for i in range(len(values)):
        for j in range(len(values[0])):
            context_vector[j] += attention_weights[i] * values[i][j]

    return context_vector
```


In [269]:
B, T, C = 1, 4, 2

In [273]:
# first implementation
np.random.seed(10)
x = np.linspace(1, 8, 8).reshape(B, T, C) 
# x /= B*T*C
sol_1 = np.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        x_prev = x[b, :t+1] #(t+1, C)
        sol_1[b, t] = np.mean(x_prev, 0)
        

x, sol_1

(array([[[1., 2.],
         [3., 4.],
         [5., 6.],
         [7., 8.]]]),
 array([[[1., 2.],
         [2., 3.],
         [3., 4.],
         [4., 5.]]]))

In [277]:
# second implementation

wei = torch.tril(torch.ones((T, T))) # the weights have no batch dimension (B); they broadcast over it
print(wei)
wei = wei / wei.sum(1, keepdim=True)
sol_2 = wei.double() @ torch.from_numpy(x) # convert both operands to float64 tensors

assert np.allclose(sol_1, sol_2.numpy())
sol_2

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])


tensor([[[1.0000, 2.0000],
         [2.0000, 3.0000],
         [3.0000, 4.0000],
         [4.0000, 5.0000]]], dtype=torch.float64)

In [120]:
# third implementation
softmax = lambda v: np.exp(v) / np.sum(np.exp(v), axis=-1, keepdims=True) # normalize each row separately

c = np.tril(np.ones((T, T)))
print(c)
masked_weight = np.ma.masked_array(c, mask=(c==0))
filled_weight = np.ma.filled(masked_weight, fill_value=-np.inf)
wei = softmax(filled_weight) # NOTE: without axis=-1 the sum runs over all 10 unmasked entries, so every weight becomes 0.1
print(wei)
sol_3 = wei @ x
assert np.allclose(sol_1, sol_3)
sol_3


[[1. 0. 0. 0.]
 [1. 1. 0. 0.]
 [1. 1. 1. 0.]
 [1. 1. 1. 1.]]
[[1.         0.         0.         0.        ]
 [0.5        0.5        0.         0.        ]
 [0.33333333 0.33333333 0.33333333 0.        ]
 [0.25       0.25       0.25       0.25      ]]


array([[[1., 2.],
        [2., 3.],
        [3., 4.],
        [4., 5.]]])
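Softmax over a causal mask has to be normalized one row at a time. The sketch below is written in plain Python, with `softmax_rows` and `softmax_global` as hypothetical helper names. It compares the two normalizations for `T = 4`: summing over the whole matrix divides by its 10 unmasked entries and gives 0.1 everywhere, while summing per row gives each visible position an equal share of that row.

```python
import math

T = 4
# Logits: 1.0 on and below the diagonal, -inf above it (future positions are masked out).
logits = [[1.0 if j <= i else -math.inf for j in range(T)] for i in range(T)]

def softmax_rows(matrix):
    """Softmax applied to each row independently."""
    out = []
    for row in matrix:
        exps = [math.exp(v) for v in row]  # math.exp(-inf) is 0.0
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def softmax_global(matrix):
    """Softmax normalized over every entry of the matrix at once."""
    total = sum(math.exp(v) for row in matrix for v in row)
    return [[math.exp(v) / total for v in row] for row in matrix]

row_wise = softmax_rows(logits)
whole = softmax_global(logits)

for r, w in zip(row_wise, whole):
    print([round(v, 3) for v in r], [round(v, 3) for v in w])
```

Row `t` of the row-wise result holds `1/(t+1)` for each visible position, which is the running mean computed by the first two implementations.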

In [None]:
# fourth implementation: self-attention with q, k, v all taken from x
k = x # (B, T, C)
q = x
v = x
attn = q @ np.transpose(k, (0, 2, 1))  # (B, T, T)
attn = attn / np.sqrt(C) # scale the scores by the square root of the channel size
attn_e = np.exp(attn - attn.max(axis=-1, keepdims=True)) # subtract the row max for numerical stability
attn_sum = np.sum(attn_e, axis=-1, keepdims=True) # one normalizer per row
attn_w = attn_e / attn_sum # softmax computed inline because there is no backward pass
print(attn_w.shape)
print(x.shape, attn.shape)
sol_4 = attn_w @ v # (B, T, C)
print(sol_4)



In [None]:
x = np.linspace(1, 8, 8).reshape(4, 2, 1)
x.shape, x.T.shape, np.transpose(x, (0, 2, 1)).shape