
This notebook investigates why 768 is used as the hidden size in NLP transformers such as BERT-base, and why that particular number matters.

In [1]:
# Record the installed package versions of this environment (the listing below is cut off).
!pip list


Package                  Version
------------------------ --------------------
anyio                    3.7.0
argon2-cffi              21.3.0
argon2-cffi-bindings     21.2.0
arrow                    1.2.3
asttokens                2.2.1
async-lru                2.0.2
attrs                    23.1.0
Babel                    2.12.1
backcall                 0.2.0
beautifulsoup4           4.12.2
bleach                   6.0.0
certifi                  2019.11.28
cffi                     1.15.1
chardet                  3.0.4
charset-normalizer       3.1.0
comm                     0.1.3
dbus-python              1.2.16
debugpy                  1.6.7
decorator                5.1.1
defusedxml               0.7.1
exceptiongroup           1.1.1
executing                1.2.0
fastjsonschema           2.17.1
fqdn                     1.5.1
idna                     2.8
ipykernel                6.23.3
ipython                  8.14.0
isoduration              20.11.0
jedi                     0.18.2
Jinja2

In [2]:

# Exploratory: inspect the working directory. The visible cells below do not use this output.
!ls


NGC-DL-CONTAINER-LICENSE  home	       media	      run	usr
bin			  jupyter.log  mnt	      sbin	var
boot			  lib	       opt	      srv	workspace
dev			  lib32        post_start.sh  start.sh
etc			  lib64        proc	      sys
get-pip.py		  libx32       root	      tmp


In [8]:

def is_prime(n):
    """Return True if ``n`` is a prime number.

    Uses trial division by every integer in ``[2, isqrt(n)]``. Integers below 2
    (0, 1 and negatives) are not prime.
    """
    # Exact integer square root: ``int(n**0.5)`` loses precision above 2**53,
    # and ``n**0.5`` is complex for negative n.
    from math import isqrt

    if n < 2:
        return False
    return all(n % d for d in range(2, isqrt(n) + 1))


def sqrt(n: int):
    """Return ``isqrt(n) + 1``, the exclusive upper bound for trial divisors of ``n``.

    Despite its name this is *not* the square root. It is meant to be used as
    ``range(2, sqrt(n))``, so that every candidate divisor ``d <= floor(sqrt(n))``
    is visited.

    Raises:
        ValueError: if ``n`` is negative.
    """
    # math.isqrt is exact for arbitrarily large ints; int(n**0.5) goes through a
    # float and can be off by one above 2**53 (e.g. n = 10**16 - 1).
    from math import isqrt

    return isqrt(n) + 1


def factorise(n: int, verbose: bool = True):
    """Return the prime factorisation of ``n``.

    Args:
        n: integer to factorise. Values below 2 have no prime factors.
        verbose: when True (the default), print ``remaining_quotient divisor``
            after every division, as a trace of the factorisation.

    Returns:
        ``(factors, exponents)``:
        - ``factors`` is the list of prime factors in ascending order, with
          repetition (e.g. ``[2, 2, 3]`` for 12).
        - ``exponents`` maps each distinct prime to its multiplicity, with keys
          in ascending order (e.g. ``{2: 2, 3: 1}``).
    """
    factors = []
    divisor = 2
    # Only divisors up to sqrt(n) need checking, and that bound shrinks as n is
    # divided down. The integer comparison avoids float rounding for large n.
    # The divisor only moves forward, so no candidate is re-scanned.
    while divisor * divisor <= n:
        if n % divisor == 0:
            factors.append(divisor)
            n //= divisor
            if verbose:
                print(n, divisor)
        else:
            divisor += 1
    # Any remainder above 1 has no divisor <= its square root, so it is prime.
    if n > 1:
        factors.append(n)

    # Count multiplicities in one pass; dict insertion order keeps primes ascending.
    exponents = {}
    for prime in factors:
        exponents[prime] = exponents.get(prime, 0) + 1
    return factors, exponents



# Factorise 768 and print it as "p^k * ..." terms (shown below: 2^8 * 3^1).
f, exponents = factorise(768)
print(' * '.join([f'{i}^{exp}' for i, exp in exponents.items()]))


384 2
192 2
96 2
48 2
24 2
12 2
6 2
3 2
2^8 * 3^1



## BERT_base

### Layers
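Sequential([
12 Encoder blocks
Linear (hidden, hidden) applied only to the first token [CLS], whose output has shape (batch, hidden)
Linear (hidden, num_classes)   (for masked-LM pre-training: (hidden, vocab_size), applied to every token)
Softmax layer
])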
Encoder block:
Sequential([
    MultiHeadAttention( 12 heads, 768 dim)
    Add & Norm
    ProjectionLayer( 768 , 3072 )
    Activation (GELU)
    ProjectionLayer( 3072 , 768 )
    Add & Norm
])
## Explanation

The tokens are prepended with a special [CLS] token and appended with a special [SEP] token.
The tokens are then passed through the embedding layer, and positional information is added (BERT uses learned position embeddings).
The result is passed through the 12 encoder blocks.
So the activation shape is (batch_size, seq_len, 768) throughout the encoder.
For classification, the output of the last encoder block at the [CLS] position is passed through a linear layer to get the logits.
Note that 768 is both the hidden size of the model and the size of each token embedding. Since 768 = 2^8 · 3 = 12 × 64, it splits evenly into 12 attention heads of 64 dimensions each.
## Training
Only the output of the last encoder block at the first position ([CLS]) is kept.
For classification, the output size is the number of classes (for masked-LM pre-training it is the vocabulary size).
The encoder blocks are trained during pre-training. During fine-tuning a classification layer is added, and the model (usually including the encoder) is trained to predict the class of each sequence from its [CLS] representation.


In [15]:

# Path: bert.ipynb
import torch
import torch.nn as nn

class BERTBlock(nn.Module):
    """One BERT encoder block (post-LayerNorm transformer layer).

    Computes::

        h = attn_norm(x + Dropout(MultiHeadSelfAttention(x)))
        y = ffn_norm(h + Dropout(proj_out(GELU(proj_h(h)))))

    Input and output both have shape ``(batch, seq_len, token_size)`` (batch-first).

    Args:
        token_size: model (hidden) width. Defaults to ``hidden_layer`` (768, the
            BERT-base hidden size). Must be divisible by ``num_heads``.
        num_heads: number of attention heads. 12 for BERT-base, giving
            768 / 12 = 64 dimensions per head.
        dropout: dropout probability used on the attention weights and on
            both residual branches.
        ffn_dim: inner width of the feed-forward network. Defaults to
            ``4 * token_size`` (3072 for BERT-base).
    """

    hidden_layer = 768  # BERT-base hidden size

    def __init__(self, token_size=hidden_layer, num_heads=12, dropout=0.1, ffn_dim=None):
        super().__init__()
        if ffn_dim is None:
            ffn_dim = 4 * token_size
        # nn.MultiheadAttention owns the Q/K/V input projections and the output
        # projection, so no separate qkv Linear is needed in front of it.
        self.attn = nn.MultiheadAttention(
            token_size, num_heads=num_heads, dropout=dropout, batch_first=True
        )
        self.attn_norm = nn.LayerNorm(token_size)
        # Position-wise feed-forward network: expand, non-linearity, project back.
        self.proj_h = nn.Linear(token_size, ffn_dim)
        self.proj_out = nn.Linear(ffn_dim, token_size)
        self.act = nn.GELU()
        self.ffn_norm = nn.LayerNorm(token_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, key_padding_mask=None):
        """Apply the block to ``x`` of shape ``(batch, seq_len, token_size)``.

        Args:
            x: input token representations (batch-first).
            key_padding_mask: optional ``(batch, seq_len)`` bool mask. ``True``
                marks padding positions that attention should ignore.

        Returns:
            Tensor with the same shape as ``x``.
        """
        attn_out, _ = self.attn(
            x, x, x, key_padding_mask=key_padding_mask, need_weights=False
        )
        # Residual connection + LayerNorm ("Add & Norm") after attention.
        x = self.attn_norm(x + self.dropout(attn_out))
        ffn_out = self.proj_out(self.act(self.proj_h(x)))
        # Residual connection + LayerNorm after the feed-forward network.
        return self.ffn_norm(x + self.dropout(ffn_out))



In [None]:

class Bert(nn.Module):
    """BERT-style sequence classifier over already-embedded inputs.

    Pipeline:
    1. A stack of ``BERTBlock`` encoders.
    2. Keep only the first ([CLS]) position.
    3. Pooler (Linear + tanh).
    4. Dropout.
    5. Linear to ``output_size`` logits.

    There is no token/position embedding layer here. ``forward`` expects input
    of shape ``(batch, seq_len, input_size)`` whose first position is the
    [CLS] token.

    Args:
        input_size: feature size of each input token. It is passed to every
            ``BERTBlock``, whose output keeps this size.
        output_size: number of output classes.
        num_layers: number of encoder blocks (12 for BERT-base).
    """

    hid = 768  # pooler width (BERT-base hidden size)

    def __init__(self, input_size, output_size, num_layers=12):
        super().__init__()
        self.bert_blocks = nn.ModuleList(
            [BERTBlock(input_size) for _ in range(num_layers)]
        )
        # Pooler: maps the [CLS] vector (input_size features) to ``hid`` features.
        self.final = nn.Linear(input_size, self.hid)
        self.out = nn.Linear(self.hid, output_size)
        self.act = nn.Tanh()  # pooler activation, as in BERT
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Return classification logits of shape ``(batch, output_size)``.

        ``x`` is expected to be ``(batch, seq_len, input_size)``. No softmax is
        applied, so the result can be passed directly to ``nn.CrossEntropyLoss``.
        """
        for block in self.bert_blocks:
            x = block(x)
        # Blocks are batch-first, so index 0 on dim 1 is the [CLS] position.
        cls = x[:, 0]
        pooled = self.act(self.final(cls))
        return self.out(self.dropout(pooled))


In [14]:

# A fused QKV tensor of width 3 * 768 can be sliced either with torch.chunk
# (3 equal pieces) or torch.split (pieces of 768). Check that both agree.
HIDDEN_SIZE = 768
rand_qkv = torch.rand(1, 10, HIDDEN_SIZE * 3)
print(rand_qkv.shape)
q, k, v = torch.chunk(rand_qkv, 3, dim=-1)
q1, k1, v1 = torch.split(rand_qkv, HIDDEN_SIZE, dim=-1)
print(q.shape, k.shape, v.shape)
print(q1.shape, k1.shape, v1.shape)


def l2_norm(x):
    """Frobenius norm over the last two dims (one value per leading index)."""
    return torch.sqrt(torch.sum(x**2, dim=(-1, -2)))


print(l2_norm(q - q1), l2_norm(k - k1), l2_norm(v - v1))
# Inline check: the two slicing methods produce identical tensors.
assert all(torch.equal(a, b) for a, b in ((q, q1), (k, k1), (v, v1)))


torch.Size([1, 10, 2304])
torch.Size([1, 10, 768]) torch.Size([1, 10, 768]) torch.Size([1, 10, 768])
torch.Size([1, 10, 768]) torch.Size([1, 10, 768]) torch.Size([1, 10, 768])
tensor([0.]) tensor([0.]) tensor([0.])
