<a href="https://colab.research.google.com/github/DanielWarfield1/MLWritingAndResearch/blob/main/BERTFromScratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Reference
- BERT reference: https://colab.research.google.com/drive/13FjI_uXaw8JJGjzjVX3qKSLyW9p3b6OV?usp=sharing#scrollTo=AhX8b1ydtrVf from https://neptune.ai/blog/how-to-code-bert-using-pytorch-tutorial
    - which is inspired by https://github.com/graykode
- Tokenization reference: https://tinkerd.net/blog/machine-learning/bert-tokenization/

In [None]:
!pip install datasets
!pip install nltk


# Setting Up pre-training dataset
Using Wikipedia data, divided into paragraphs, then divided into sentences. What we end up with is a list of paragraphs, each of which is itself a list of sentences. Each paragraph consists of sequential text, but contiguous paragraphs may not be related.

In [None]:
from datasets import load_dataset
#the dataset is big, to make things easier we're going to be streaming a subset
dataset = load_dataset("wikipedia", "20220301.en", trust_remote_code=True, streaming=True)

In [None]:
dataset_iter = iter(dataset['train'])

In [None]:
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [None]:
import itertools

num_articles = 10000
#getting n articles
articles = list(itertools.islice(dataset_iter, num_articles))

#getting paragraphs
paragraphs = []
for article in articles:
    paragraphs.extend(article['text'].splitlines())

#filtering paragraphs so they're hopefully actually paragraphs
paragraphs = [p for p in paragraphs if len(p)>50]

#dividing paragraphs into sentences
divided_paragraphs = []
for p in paragraphs:
    divided_paragraphs.append(nltk.sent_tokenize(p))

#only using paragraphs with 3 or more sentences
divided_paragraphs = [pls for pls in divided_paragraphs if len(pls)>=3]
divided_paragraphs

[['Anarchism is a political philosophy and movement that is sceptical of authority and rejects all involuntary, coercive forms of hierarchy.',
  'Anarchism calls for the abolition of the state, which it holds to be unnecessary, undesirable, and harmful.',
  'As a historically left-wing movement, placed on the farthest left of the political spectrum, it is usually described alongside communalism and libertarian Marxism as the libertarian wing (libertarian socialism) of the socialist movement, and has a strong historical association with anti-capitalism and socialism.'],
 ['Humans lived in societies without formal hierarchies long before the establishment of formal states, realms, or empires.',
  'With the rise of organised hierarchical bodies, scepticism toward authority also rose.',
  'Although traces of anarchist thought are found throughout history, modern anarchism emerged from the Enlightenment.',
  "During the latter half of the 19th and the first decades of the 20th century, the 

In [None]:
len(divided_paragraphs)

249990

# Constructing positive pairs and negative pairs
This isn't the most efficient way to do this, but for the sake of simplicity I'll just be keeping two lists: one of positive pairs and one of negative pairs. This means I'm storing duplicate data inefficiently, but whatever.

In [None]:
import random

positive_pairs = []
negative_pairs = []

num_paragraphs = len(divided_paragraphs)

for i, paragraph in enumerate(divided_paragraphs):
    for j in range(len(paragraph)-1):
        positive_pairs.append((paragraph[j], paragraph[j+1]))
        rand_par = i
        while rand_par == i:
            rand_par = random.randint(0, num_paragraphs-1)
        rand_sent = random.randint(0, len(divided_paragraphs[rand_par])-1)
        negative_pairs.append((paragraph[j], divided_paragraphs[rand_par][rand_sent]))

In [None]:
print(len(positive_pairs))
positive_pairs[:10]

936768


[('Anarchism is a political philosophy and movement that is sceptical of authority and rejects all involuntary, coercive forms of hierarchy.',
  'Anarchism calls for the abolition of the state, which it holds to be unnecessary, undesirable, and harmful.'),
 ('Anarchism calls for the abolition of the state, which it holds to be unnecessary, undesirable, and harmful.',
  'As a historically left-wing movement, placed on the farthest left of the political spectrum, it is usually described alongside communalism and libertarian Marxism as the libertarian wing (libertarian socialism) of the socialist movement, and has a strong historical association with anti-capitalism and socialism.'),
 ('Humans lived in societies without formal hierarchies long before the establishment of formal states, realms, or empires.',
  'With the rise of organised hierarchical bodies, scepticism toward authority also rose.'),
 ('With the rise of organised hierarchical bodies, scepticism toward authority also rose.',

In [None]:
negative_pairs[:10]

[('Anarchism is a political philosophy and movement that is sceptical of authority and rejects all involuntary, coercive forms of hierarchy.',
  'Wycliffite teachings on the Eucharist were declared heresy at the Blackfriars Council of 1382.'),
 ('Anarchism calls for the abolition of the state, which it holds to be unnecessary, undesirable, and harmful.',
  'While Erdoğan declared being against antisemitism, he has been accused of invoking antisemitic stereotypes in public statements.'),
 ('Humans lived in societies without formal hierarchies long before the establishment of formal states, realms, or empires.',
  "In 1939, DeMille's Union Pacific was successful through DeMille's collaboration with the Union Pacific Railroad."),
 ('With the rise of organised hierarchical bodies, scepticism toward authority also rose.',
  "that father and son each bore the same double name, or that Abiathar officiated during his father's lifetime and in his father's stead—have been supported by great name

# Setting up tokenization

In [None]:
from transformers.models.bert.tokenization_bert_fast import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained("google-bert/bert-base-uncased")

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [None]:
"""Playing around with the tokenizer
"""
sentence = "Here's a weird word: Withoutadoubticus."
print(f'original sentence: "{sentence}"')
demo_tokens = tokenizer([sentence])
print(f"token IDs: {demo_tokens['input_ids']}")
tokens = tokenizer.convert_ids_to_tokens(demo_tokens['input_ids'][0])
print(f'token values: {tokens}')

original sentence: "Here's a weird word: Withoutadoubticus."
token IDs: [[101, 2182, 1005, 1055, 1037, 6881, 2773, 1024, 2302, 9365, 12083, 29587, 1012, 102]]
token values: ['[CLS]', 'here', "'", 's', 'a', 'weird', 'word', ':', 'without', '##ado', '##ub', '##ticus', '.', '[SEP]']
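The `##` pieces in the output above come from WordPiece, which splits an unknown word greedily into the longest vocabulary entries it can find. Here is a minimal sketch of that greedy longest-match-first idea. The `wordpiece` function and `toy_vocab` are hypothetical and for illustration only; the real tokenizer also lowercases, splits punctuation, and uses the full 30,522-entry vocabulary.

```python
def wordpiece(word, vocab, unk="[UNK]"):
    """Greedily split one word into the longest pieces found in vocab."""
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        match = None
        # shrink the candidate from the right until it is in the vocabulary
        while end > start:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry a ## prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk]  # no piece fits, so the whole word becomes unknown
        pieces.append(match)
        start = end
    return pieces

# toy vocabulary, assumed for this example only
toy_vocab = {"without", "##ado", "##ub", "##ticus", "weird", "word"}
print(wordpiece("withoutadoubticus", toy_vocab))
```

With this toy vocabulary the split matches the pieces the real tokenizer printed for `Withoutadoubticus`.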


In [None]:
print(tokenizer(["Here's a weird word: Withoutadoubticus"]))


{'input_ids': [[101, 2182, 1005, 1055, 1037, 6881, 2773, 1024, 2302, 9365, 12083, 29587, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}


In [None]:
decoded_string = tokenizer.decode(demo_tokens['input_ids'][0], skip_special_tokens=False)
decoded_string

"[CLS] here's a weird word : withoutadoubticus [SEP]"

In [None]:
tokenizer.vocab_size

30522

# Defining Pad
At the end of the day our batch needs to be a rectangular tensor of size `[batch_size, sequence_length, model_dim]`, which means every example in the batch has to have the same sequence length. We're dealing with multiple sequences of various lengths, so we need to use padding to even them out.

There are a lot of ways we could do this. I'm doing it the simplest way I can imagine: I'm just cutting down sequences to fit within the max cumulative input size, and I'm padding to get to that point.

This has to be done carefully, as we have special utility tokens that exist within the data. We'll cross that bridge in the batch creation implementation.
For now, I'm printing out the tokenizer to confirm that `0` is indeed padding.
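To make the truncate-then-pad idea concrete, here's a rough sketch on plain Python lists. `pad_to_length` is a hypothetical helper, not something the notebook uses later, and it assumes `0` is the padding ID.

```python
def pad_to_length(token_ids, max_len, pad_id=0):
    """Cut a sequence down to max_len, then right-pad with pad_id."""
    trimmed = token_ids[:max_len]
    return trimmed + [pad_id] * (max_len - len(trimmed))

# made-up token IDs: 101 = [CLS], 102 = [SEP]
short_seq = [101, 7, 8, 102]
long_seq = [101, 5, 6, 7, 8, 9, 10, 102]
batch = [pad_to_length(seq, 6) for seq in (short_seq, long_seq)]
print(batch)  # every row now has length 6
```

Note that the naive cut drops the trailing `[SEP]` from the long sequence. That's the kind of special token problem that needs care when building real batches.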

In [None]:
tokenizer

BertTokenizerFast(name_or_path='google-bert/bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [None]:
tokenizer.vocab_size

30522

In [None]:
tokenizer.added_tokens_decoder.keys()

dict_keys([0, 100, 101, 102, 103])

# Batch Creation
This might not be exactly what the OG BERT paper does, but it's the spirit.

- For batch size n:
    - get n random sentences
    - for half of those, get the next sentence.
        - assuming sentences in the dataset that follow are related. This may not always be true, but it's the assumption we'll use in training.
    - for half of those grab some random sentence.
        - assuming that random sentence is not the next sentence. The likelihood of this happening is negligible.
    - mask out 15% of input tokens that are not `[CLS]` or `[SEP]`
        - 80% of the time we'll replace with `[MASK]`
        - 10% of the time we'll replace with a random word
        - 10% of the time we'll keep it the same.

We'll also keep track of masked indexes within the total input sequence.

In [None]:
# Shuffling positive and negative pairs
from random import shuffle
shuffle(positive_pairs)
shuffle(negative_pairs)

In [None]:
"""Parallelized
"""

from tqdm import tqdm
import torch
from multiprocessing import Pool, cpu_count

#defining the device the data ends up living on
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#number of examples in the batch
batch_size = 128  # should be divisible by 2
#sequence length of model
max_input_length = 64

#defining parallelizable function for processing batches
def process_batch(batch_index):
    #establishing bounds of the batch
    start_index = batch_index * batch_size
    end_index = start_index + batch_size

    if end_index > len(positive_pairs):
        return None, None, None

    #getting the sentence pairs of the batch, and if they're pos or neg
    sentence_pairs = []
    is_positives = []

    # Creating positive pairs
    sentence_pairs.extend(positive_pairs[start_index:start_index + int(batch_size / 2)])
    is_positives.extend([1] * int(batch_size / 2))

    # Creating negative pairs
    sentence_pairs.extend(negative_pairs[start_index + int(batch_size / 2):end_index])
    is_positives.extend([0] * int(batch_size / 2))

    # Defining outputs
    # At the end of the day we need to know three things:
    #   - the tokens for the sequences in a batch
    #   - which sentence the tokens belong to, for the segment embedding
    #   - if the examples in the batch are positive or negative
    # these keep track of the first two
    batch_sentence_location_tokens = []
    batch_sequence_tokens = []

    # Tokenizing pairs
    for sentence_pair in sentence_pairs:
        sentence1 = sentence_pair[0]
        sentence2 = sentence_pair[1]

        # Tokenizing both sentences
        tokens = tokenizer([sentence1, sentence2])
        sentence1_tokens = tokens['input_ids'][0]
        sentence2_tokens = tokens['input_ids'][1]

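        # Trimming down tokens; only a sentence longer than half the input is cut,
        # otherwise a short sentence would end up with a duplicate [CLS] (101) or [SEP] (102)
        half = max_input_length // 2
        if len(sentence1_tokens) + len(sentence2_tokens) > max_input_length:
            if len(sentence1_tokens) > half:
                sentence1_tokens = [101] + sentence1_tokens[-(half - 1):]
            if len(sentence2_tokens) > half:
                sentence2_tokens = sentence2_tokens[:half - 1] + [102]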

        # Creating sentence tokens
        sentence_tokens = [0] * len(sentence1_tokens) + [1] * len(sentence2_tokens)

        # Combining and padding
        pad_num = max_input_length - (len(sentence1_tokens) + len(sentence2_tokens))
        sequence_tokens = sentence1_tokens + sentence2_tokens + [0] * pad_num
        sentence_location_tokens = sentence_tokens + [1] * pad_num

        # Adding to batch
        batch_sequence_tokens.append(sequence_tokens)
        batch_sentence_location_tokens.append(sentence_location_tokens)

    return torch.tensor(batch_sentence_location_tokens), torch.tensor(batch_sequence_tokens), torch.tensor(is_positives)

# Determine the number of batches
num_batches = len(positive_pairs) // batch_size

# Use a Pool of workers equal to the number of CPU cores
with Pool(processes=cpu_count()) as pool:
    results = list(tqdm(pool.imap(process_batch, range(num_batches)), total=num_batches))

# Filter out None results from the process_batch function
results = [result for result in results if result[0] is not None]

# Unpack results into batches
sentence_location_batches, sequence_tokens_batches, is_positives_batches = zip(*results)

# Stack tensors into final batches
sentence_location_batches = torch.stack(sentence_location_batches).to(device)
sequence_tokens_batches = torch.stack(sequence_tokens_batches).to(device)
is_positives_batches = torch.stack(is_positives_batches).to(device)

  self.pid = os.fork()
 86%|████████▋ | 6324/7318 [01:32<00:13, 72.58it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (514 > 512). Running this sequence through the model will result in indexing errors
100%|██████████| 7318/7318 [01:46<00:00, 68.61it/s]


In [None]:
sentence_location_batches.shape

torch.Size([7318, 128, 64])

# Setting Up Masking
I decided to separate masking into its own step.

- mask out 15% of input tokens that are not already special tokens
    - 80% of the time we'll replace with `[MASK]`
    - 10% of the time we'll replace with a random word
    - 10% of the time we'll keep it the same.

Because masked values can be random word tokens we can't rely on implicitly knowing which words are masked and which are not. So, we need to keep and record a mask vector so we know what we messed with.

Currently the plan is to run this before exposing the model to a batch in the training loop.
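Before the batched version, here's a small sketch of the same 80/10/10 rule on a single list of token IDs, in plain Python. `mask_tokens` and its `rng` parameter are hypothetical names for illustration; the special token IDs are the ones the tokenizer printed earlier.

```python
import random

MASK_ID = 103
SPECIAL_IDS = {0, 100, 101, 102, 103}  # [PAD], [UNK], [CLS], [SEP], [MASK]

def mask_tokens(token_ids, vocab_size, mask_prob=0.15, rng=random):
    """Return (masked copy of token_ids, 0/1 indicator of selected positions)."""
    masked = list(token_ids)
    selected = [0] * len(token_ids)
    eligible = [i for i, t in enumerate(token_ids) if t not in SPECIAL_IDS]
    num_to_select = int(len(eligible) * mask_prob)
    for i in rng.sample(eligible, num_to_select):
        selected[i] = 1
        roll = rng.random()
        if roll < 0.8:
            masked[i] = MASK_ID                    # 80%: replace with [MASK]
        elif roll < 0.9:
            masked[i] = rng.randrange(vocab_size)  # 10%: replace with a random token ID
        # remaining 10%: leave the token unchanged, but still record it as selected
    return masked, selected

# made-up sequence: [CLS], 20 ordinary tokens, [SEP], then padding
tokens = [101] + list(range(1000, 1020)) + [102] + [0, 0, 0]
masked, selected = mask_tokens(tokens, vocab_size=30522, rng=random.Random(0))
print(sum(selected), "positions selected")
```

The indicator is what lets us recover the prediction targets later, since a selected position may still hold its original token or a random one.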


In [None]:
#listing out vocab for random token masking
vocab = tokenizer.get_vocab()
valid_token_ids = list(vocab.values())

def mask_batch(batch_tokens, clone=True):
    if clone:
        batch_tokens = torch.clone(batch_tokens)

    # Define the percentage of tokens to potentially mask
    replace_percentage = 0.15

    # Define tokens that should not be replaced
    excluded_tokens = {0, 100, 101, 102, 103}

    # Create a mask to identify tokens that are eligible for replacement
    eligible_mask = ~torch.isin(batch_tokens, torch.tensor(list(excluded_tokens)).to(device))

    # Count the number of eligible tokens
    num_eligible_tokens = eligible_mask.sum().item()

    # Calculate the number of tokens to potentially mask
    num_tokens_to_mask = int(num_eligible_tokens * replace_percentage)

    # Create a random permutation of eligible token indices
    eligible_indices = eligible_mask.nonzero(as_tuple=True)
    random_indices = torch.randperm(num_eligible_tokens)[:num_tokens_to_mask]

    # Create a probability distribution for replacement
    replacement_probs = torch.tensor([0.8, 0.1, 0.1])  # Probabilities for [103, random token, leave unchanged]
    replacement_choices = torch.multinomial(replacement_probs, num_tokens_to_mask, replacement=True)

    # Vector to store if a token was masked (0: not masked, 1: masked)
    masked_indicator = torch.zeros_like(batch_tokens, dtype=torch.int32)

    # Apply replacements based on sampled choices
    for i, idx in enumerate(random_indices):
        row = eligible_indices[0][idx]
        col = eligible_indices[1][idx]

        if replacement_choices[i] == 0:
            batch_tokens[row, col] = 103
            masked_indicator[row, col] = 1
        elif replacement_choices[i] == 1:
            batch_tokens[row, col] = random.choice(valid_token_ids)
            masked_indicator[row, col] = 1
        elif replacement_choices[i] == 2:
            masked_indicator[row, col] = 1

    return batch_tokens, masked_indicator

batch_tokens, masked_indicator = mask_batch(sequence_tokens_batches[0])
batch_tokens

tensor([[ 101, 5366, 2018,  ...,    0,    0,    0],
        [ 101, 2122, 6702,  ...,    0,    0,    0],
        [ 101, 1999,  103,  ...,    0,    0,    0],
        ...,
        [ 101, 2008, 2087,  ...,    0,    0,    0],
        [ 101,  101, 1999,  ...,    0,    0,    0],
        [ 101, 2019,  103,  ...,    0,    0,    0]], device='cuda:0')

# Setting up embedding
This is the first component of the model, which converts tokens into vectors. There's essentially a lookup table with one vector per token, and those vectors are learned throughout the training process.
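As a rough picture of what "lookup table" means here, the sketch below builds three tiny tables out of plain Python lists and adds the token, position, and segment vectors for each input token. All names and sizes are made up for illustration, and the layer normalization a real embedding layer might apply afterwards is left out.

```python
import random

def make_table(rows, dim, rng):
    """A lookup table: one random vector per row (these would be learned in training)."""
    return [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(rows)]

def embed(token_ids, segment_ids, tok_table, pos_table, seg_table):
    """Sum the token, position, and segment vectors for every position."""
    vectors = []
    for position, (tok, seg) in enumerate(zip(token_ids, segment_ids)):
        rows = zip(tok_table[tok], pos_table[position], seg_table[seg])
        vectors.append([t + p + s for t, p, s in rows])
    return vectors

rng = random.Random(0)
tok_table = make_table(rows=10, dim=4, rng=rng)  # toy vocabulary of 10 tokens
pos_table = make_table(rows=6, dim=4, rng=rng)   # toy maximum sequence length of 6
seg_table = make_table(rows=2, dim=4, rng=rng)   # first or second sentence

vectors = embed([5, 5, 3], [0, 0, 1], tok_table, pos_table, seg_table)
print(len(vectors), len(vectors[0]))  # one 4-dimensional vector per input token
```

The same token ID (5) at two positions ends up with two different vectors, because the position vector differs.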

In [None]:
import torch.nn as nn
import torch

vocab_size = tokenizer.vocab_size
d_model = 256
n_segments = 2

class Embedding(nn.Module):
    def __init__(self):
        super(Embedding, self).__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)  # token embedding
        self.pos_embed = nn.Embedding(max_input_length, d_model)  # position embedding
        self.seg_embed = nn.Embedding(n_segments, d_model)  # segment(token type) embedding
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, seg):
        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long).to(device)
        pos = pos.unsqueeze(0).expand_as(x)  # (seq_len,) -> (batch_size, seq_len)
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(embedding)

e = Embedding()
e.to(device)

Embedding(
  (tok_embed): Embedding(30522, 256)
  (pos_embed): Embedding(64, 256)
  (seg_embed): Embedding(2, 256)
  (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
)

In [None]:
#in order for the input to work all sequences need to be padded to the same length
dummy_embedding = e(sequence_tokens_batches[0], sentence_location_batches[0])
print(dummy_embedding.shape)
print(dummy_embedding)

torch.Size([128, 64, 256])
tensor([[[-0.4109,  0.1544,  0.3778,  ..., -1.9995,  1.3578,  0.3117],
         [-0.5452, -0.7935, -0.6296,  ...,  1.0046, -0.1871, -0.3125],
         [-2.2820,  0.4665, -1.1026,  ..., -0.5876,  1.4205, -1.5876],
         ...,
         [ 1.2866,  0.9395,  0.7138,  ...,  0.4223,  0.3374,  0.6935],
         [-0.3787,  1.4489, -0.7226,  ...,  0.3139,  0.3640,  0.4926],
         [ 1.1291,  1.4248, -0.2899,  ...,  0.8080,  0.7977,  1.4257]],

        [[-0.4109,  0.1544,  0.3778,  ..., -1.9995,  1.3578,  0.3117],
         [-0.9470, -0.4977, -1.0789,  ...,  0.5366,  0.5290, -1.7874],
         [-1.5527, -0.2966, -0.3398,  ..., -0.5468,  1.3547, -0.6128],
         ...,
         [ 1.2866,  0.9395,  0.7138,  ...,  0.4223,  0.3374,  0.6935],
         [-0.3787,  1.4489, -0.7226,  ...,  0.3139,  0.3640,  0.4926],
         [ 1.1291,  1.4248, -0.2899,  ...,  0.8080,  0.7977,  1.4257]],

        [[-0.4109,  0.1544,  0.3778,  ..., -1.9995,  1.3578,  0.3117],
         [-0.9972,

# Defining Model
Now that we have tokenization, and we've shown that we can build an embedding model that can work with that tokenization, we can build the model itself.

The model consists of
- embedding
- encoder block
    - multi-headed self-attention
    - feed forward
- a fully connected layer for outputting the next sentence prediction
- a fully connected layer for turning masked token vectors into token predictions (logits). This predicts across all masked tokens in a batch in training.

Naturally the whole point of this is pre-training, so we'll encapsulate the core components and the head for training as separate objects, allowing us to somewhat easily port the core components of BERT into fine tuning.



In [None]:
"""no masking
"""
import numpy as np

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V):
        #Q, K, V of size [batch x sequence_length x dim]
        #scaling by the square root of the query/key dimension (the last axis)
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(Q.shape[-1])
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)
        return context, attn

#sanity checking
q = torch.tensor([[[1.1,1.3],[0.9,0.8]]]).to(device)
k = torch.tensor([[[0.9,1],[0.2,2.1]]]).to(device)
v = torch.tensor([[[1.1,1.3],[0.9,0.8]]]).to(device)
sample = ScaledDotProductAttention().to(device)
sample(q,k,v)

(tensor([[[0.9771, 0.9927],
          [0.9912, 1.0280]]], device='cuda:0'),
 tensor([[[0.3854, 0.6146],
          [0.4559, 0.5441]]], device='cuda:0'))
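To see where those numbers come from, here's the same computation written out by hand in plain Python for a single example (no batch dimension). The `softmax` and `attention` helpers are illustrative only. Scores are divided by the square root of the query/key dimension, which is 2 in this sanity check.

```python
import math

def softmax(row):
    """Softmax over a list of scores, shifted by the max for numerical stability."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """Scaled dot-product attention for one sequence given as lists of vectors."""
    d_k = len(Q[0])
    weights = []
    for q in Q:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in K]
        weights.append(softmax(scores))
    # each output vector is a weighted average of the value vectors
    context = [[sum(w * v[j] for w, v in zip(row, V)) for j in range(len(V[0]))]
               for row in weights]
    return context, weights

Q = [[1.1, 1.3], [0.9, 0.8]]
K = [[0.9, 1.0], [0.2, 2.1]]
V = [[1.1, 1.3], [0.9, 0.8]]
context, weights = attention(Q, K, V)
print([[round(w, 4) for w in row] for row in weights])
print([[round(c, 4) for c in row] for row in context])
```

For the first query, the raw scores are 2.29 and 2.95; after scaling and softmax they become roughly 0.385 and 0.615, in line with the attention weights printed above.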

## Understanding Shape Transformations
In order for MHSA to work there are some pretty complex shape transformations that have to go right. This explores just those operations.

In [None]:
#defining sample value matrix
#[batch_size x sequence_len x (query_key_dim * n_heads)]
#in this matrix, [0,1,2,3] represents the values for 2 heads across a single word vector
samp_val = torch.tensor([[[0,1,2,3],[4,5,6,7]],[[0,-1,-2,-3],[-4,-5,-6,-7]]])

#dividing into two heads
#[batch_size x sequence_len x query_key_dim x n_heads]
samp_val = samp_val.view(2,2,2,2)

#moving the head dimension next to the batch dimension
#[batch_size x n_heads x sequence_len x query_key_dim]
samp_val = samp_val.permute(0, 3, 1, 2)

#combining batch and head dimension
#[batch_size*n_heads x sequence_len x query_key_dim]
samp_val = samp_val.reshape(-1, 2, 2)

#that would be the input into mhsa, which would give back the same shape output
#now I want to unpack the mhsa back into the original shape
#[batch_size x sequence_len x (query_key_dim * n_heads)]
#if I do this right, the values should be exactly identical

#separating heads
#[batch_size x n_heads x sequence_len x query_key_dim]
samp_val = samp_val.reshape(2,2,2,2)

#moving the head dimension to the end
#[batch_size x sequence_len x query_key_dim x n_heads]
samp_val = samp_val.permute(0, 2, 3, 1)

#combining the last dim to effectively concatenate the result of the heads
#[batch_size x sequence_len x query_key_dim*n_heads]
samp_val = samp_val.reshape(2, 2, -1)
samp_val

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7]],

        [[ 0, -1, -2, -3],
         [-4, -5, -6, -7]]])

## Implementing MHSA
Now that I understand which shape transformations are necessary, I can get this rolling.

In [None]:
import torch
import torch.nn as nn

# Define constants
n_heads = 3
query_key_dim = 64
value_dim = 64

class MultiHeadSelfAttention(nn.Module):
    def __init__(self):
        super(MultiHeadSelfAttention, self).__init__()
        # Defining the linear layers that construct the query, key, and value
        self.W_Q = nn.Linear(d_model, query_key_dim * n_heads)   # Projects input to [batch x sequence x (q/k_dim*num_heads)]
        self.W_K = nn.Linear(d_model, query_key_dim * n_heads)   # Projects input to [batch x sequence x (q/k_dim*num_heads)]
        self.W_V = nn.Linear(d_model, value_dim * n_heads)       # Projects input to [batch x sequence x (v_dim*num_heads)]
        self.dot_prod_attn = ScaledDotProductAttention()         # Parameterless system that calculates attention
        self.proj_back = nn.Linear(value_dim * n_heads, d_model) # Projects final output of mhsa back into model dimension

    def forward(self, embedding):

        # passing embedding through dense networks
        qs = self.W_Q(embedding)  # [batch_size x sequence_len x (query_key_dim * n_heads)]
        ks = self.W_K(embedding)  # [batch_size x sequence_len x (query_key_dim * n_heads)]
        vs = self.W_V(embedding)  # [batch_size x sequence_len x (value_dim * n_heads)]

        #dividing out heads
        #[batch_size, sequence_len, q/k/v_dim, n_heads]
        qs = qs.view(batch_size, max_input_length, query_key_dim, n_heads)
        ks = ks.view(batch_size, max_input_length, query_key_dim, n_heads)
        vs = vs.view(batch_size, max_input_length, value_dim, n_heads)

        #moving the head dimension next to the batch dimension
        #[batch_size x n_heads x sequence_len x q/k/v_dim]
        qs = qs.permute(0, 3, 1, 2)
        ks = ks.permute(0, 3, 1, 2)
        vs = vs.permute(0, 3, 1, 2)

        #combining batch and head dimension
        #[batch_size*n_heads x sequence_len x q/k/v_dim]
        qs = qs.reshape(-1, max_input_length, query_key_dim)
        ks = ks.reshape(-1, max_input_length, query_key_dim)
        vs = vs.reshape(-1, max_input_length, value_dim)

        #passing batches/heads of self attention through attn
        #[batch_size*n_heads x sequence_len x q/k/v_dim]
        head_results, _ = self.dot_prod_attn(qs,ks,vs)

        #separating heads
        #[batch_size x n_heads x sequence_len x v_dim]
        head_results = head_results.reshape(batch_size,n_heads,max_input_length,value_dim)

        #moving the head dimension to the end
        #[batch_size x sequence_len x v_dim x n_heads]
        head_results = head_results.permute(0, 2, 3, 1)

        #combining the last dim to effectively concatenate the result of the heads
        #[batch_size x sequence_len x v_dim*n_heads]
        head_results = head_results.reshape(batch_size, max_input_length, -1)

        #projecting result of head back into model dimension
        return self.proj_back(head_results)

# Example usage
sample_embeddings = torch.tensor([[[1.1] * d_model] * max_input_length] * batch_size).to(device)
print("Sample embeddings shape:", sample_embeddings.shape)

sample = MultiHeadSelfAttention().to(device)
output = sample(sample_embeddings)
print('Output shape of mhsa:', output.shape)

Sample embeddings shape: torch.Size([128, 64, 256])
Output shape of mhsa: torch.Size([128, 64, 256])


## Point wise feed forward
The transformer uses point wise feed forward, allowing the model to filter each word vector independently.

In the OG transformer this is critical as it preserves independence between positions so causal masking stays relevant. We're not masking attention here, though, so the only benefit we're getting is that it's highly parallelizable. It's possible BERT might, in some way, benefit from allowing tokens to interact with one another at this step, I'm not sure.

We'll stick with the classic approach.
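To illustrate what "independently" means, here's a tiny plain Python sketch of a point wise feed forward network with made-up weights (model dimension 2, hidden dimension 4). The same two linear layers and GELU are applied to each word vector on its own, so reordering the sequence just reorders the outputs.

```python
import math

def gelu(x):
    """Exact (erf-based) GELU activation."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def linear(vec, weight, bias):
    """One dense layer: weight has one row per output unit."""
    return [sum(w * v for w, v in zip(row, vec)) + b for row, b in zip(weight, bias)]

def feed_forward(sequence, w1, b1, w2, b2):
    """Apply linear -> GELU -> linear to every vector in the sequence separately."""
    return [linear([gelu(h) for h in linear(vec, w1, b1)], w2, b2) for vec in sequence]

# made-up weights, for illustration only
w1 = [[0.5, -0.2], [0.1, 0.3], [-0.4, 0.8], [0.7, 0.6]]
b1 = [0.0, 0.1, -0.1, 0.2]
w2 = [[0.3, -0.5, 0.2, 0.1], [-0.6, 0.4, 0.9, -0.2]]
b2 = [0.05, -0.05]

sequence = [[1.0, 2.0], [0.5, -1.0], [-0.3, 0.7]]
outputs = feed_forward(sequence, w1, b1, w2, b2)
reversed_outputs = feed_forward(sequence[::-1], w1, b1, w2, b2)
print(outputs == reversed_outputs[::-1])  # each position is processed on its own
```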

In [None]:
d_ff = 4*d_model

class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        # (batch_size, len_seq, d_model) -> (batch_size, len_seq, d_ff) -> (batch_size, len_seq, d_model)
        return self.fc2(torch.nn.functional.gelu(self.fc1(x)))

sample = PoswiseFeedForwardNet().to(device)
sample_embeddings = torch.tensor([[[1.1] * d_model] * max_input_length] * batch_size).to(device)
sample(sample_embeddings).shape

torch.Size([128, 64, 256])

# Implementing an encoder block
Ok, we have mhsa and point wise feed forward. Now we need to combine them together to create a transformer block.

I'll be doing it this way
- input -> mhsa
- mhsa result + input -> skip connection
- skip connection result -> point wise feed forward
- pwff + skip conn result -> final skip connected output

In [None]:
class EncoderBlock(nn.Module):
    def __init__(self):
        super(EncoderBlock, self).__init__()
        self.mhsa = MultiHeadSelfAttention()
        self.pwff = PoswiseFeedForwardNet()

    def forward(self, x):
        mhsa_output = self.mhsa(x)
        skip1 = mhsa_output + x
        pwff_output = self.pwff(skip1)
        skip2 = skip1+pwff_output
        return skip2

sample = EncoderBlock().to(device)
sample_embeddings = torch.tensor([[[1.1] * d_model] * max_input_length] * batch_size).to(device)
sample(sample_embeddings).shape

torch.Size([128, 64, 256])

# Building BERT
Ok we have all the core components of the encoder blocks. Now we just need to build BERT.

BERT consists of the following:
- An embedding model
- A list of encoder blocks
- A dense network for classifying if sequences are a positive pair
- A dense network for projecting word vectors into probabilities

In [None]:
n_layers = 1

class BERT(nn.Module):
    def __init__(self):
        super(BERT, self).__init__()
        #for converting tokens into vector embeddings
        self.embedding = Embedding()
        #encoder blocks
        self.encoder_blocks = nn.ModuleList([EncoderBlock() for _ in range(n_layers)])
        #for decoding a word vector (or tensor of them) into token predictions
        self.decoder = nn.Linear(d_model, tokenizer.vocab_size, bias=False)
        #for converting the first output token into a binary classification
        self.classifier = nn.Linear(d_model, 1, bias=False)

    def forward(self, x, seg, masked_token_locations):

        #x of shape [batch x seq_len], embeddings of shape [batch x seq_len x model_dim]
        embeddings = self.embedding(x, seg)
        x = embeddings
        for block in self.encoder_blocks:
            x = block(x)

        #passing first token through classifier
        clsf_logits = self.classifier(x[:,0,:])

        #passing the encoder outputs at the masked positions through the decoder
        masked_token_embeddings = x[masked_token_locations.bool()]
        token_logits = self.decoder(masked_token_embeddings)

        return clsf_logits, token_logits

sample = BERT().to(device)
masked_tokens, masked_token_locations = mask_batch(sequence_tokens_batches[0])
clsf_logits, token_logits = sample(masked_tokens,sentence_location_batches[0], masked_token_locations)
clsf_logits

tensor([[-0.8603],
        [-0.9484],
        [-0.8859],
        ...

# PreTraining
Ok, finally.
We have a model that is set up and outputs predictions for both masked tokens and next-sentence classification. Now we can pass data through the model, generate predictions, and update the model based on how wrong it was.
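
To make "how wrong it was" concrete, here is a minimal pure-Python sketch of the two per-example loss terms that get added together during pre-training. The helper names `bce_with_logits` and `cross_entropy` and the toy numbers are assumptions for illustration; in practice PyTorch's loss modules compute these and average them over the batch by default.

```python
import math

def bce_with_logits(logit, target):
    """Binary cross-entropy on a raw logit (the sigmoid is applied here)."""
    prob = 1.0 / (1.0 + math.exp(-logit))
    return -(target * math.log(prob) + (1 - target) * math.log(1 - prob))

def cross_entropy(logits, target_index):
    """Cross-entropy over raw logits, given the index of the correct class."""
    max_logit = max(logits)  # subtracting the max keeps exp() stable
    log_sum_exp = max_logit + math.log(sum(math.exp(l - max_logit) for l in logits))
    return log_sum_exp - logits[target_index]

# Hypothetical toy example: one sentence pair, and one masked token
# predicted over a 4-word vocabulary with no preference between words.
loss_clsf = bce_with_logits(logit=-0.86, target=1.0)
loss_mlm = cross_entropy(logits=[0.0, 0.0, 0.0, 0.0], target_index=2)

# The total loss is the plain sum of the two terms.
loss = loss_mlm + loss_clsf
print(loss_clsf, loss_mlm, loss)
```

A uniform guess over a vocabulary of size V gives a token loss of ln(V). For the 30522-token vocabulary used here that is about 10.33, which is a handy reference point when reading the training loss.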

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# Assume the BERT class and necessary preprocessing functions are already defined

model = BERT().to(device)
token_criterion = nn.CrossEntropyLoss()  # Expect indices, not one-hot vectors
classification_criterion = nn.BCEWithLogitsLoss()  # For logits directly
optimizer = optim.Adam(model.parameters(), lr=0.001)

#one list of batch losses per epoch
losses = [[]]

#these epochs can take a while, keeping it at a fairly small number
for epoch in range(4):
    for sequence_batch, location_batch, classtarg_batch in tqdm(zip(sequence_tokens_batches, sentence_location_batches, is_positives_batches)):
        # Zeroing out gradients from last iteration
        optimizer.zero_grad()

        # Masking the tokens in the input sequence
        masked_tokens, masked_token_locations = mask_batch(sequence_batch)

        # Generating class and masked token predictions
        clsf_logits, token_logits = model(masked_tokens, location_batch, masked_token_locations)

        # Setting up target for masked token prediction
        masked_token_targets = sequence_batch[masked_token_locations.bool()]

        # Calculating loss for next sentence classification
        # (squeezing only the last dimension so a batch of size 1 keeps its shape)
        loss_clsf = classification_criterion(clsf_logits.squeeze(-1), classtarg_batch.float())

        # Calculating loss for masked language modeling
        loss_mlm = token_criterion(token_logits, masked_token_targets)

        # Combining losses
        loss = loss_mlm + loss_clsf

        #keeping track of loss across the current epoch
        losses[-1].append(float(loss))

        # Backpropagation
        loss.backward()
        optimizer.step()

    print(f'=======Epoch {epoch} Completed=======')
    print(f'average loss in epoch: {np.mean(losses[-1])}')
    losses.append([])

7318it [13:50,  8.81it/s]


average loss in epoch: 7.652484150003233


7318it [13:50,  8.81it/s]


average loss in epoch: 7.468071281339797


7318it [13:50,  8.81it/s]


average loss in epoch: 7.4392927800674


7318it [13:49,  8.82it/s]

average loss in epoch: 7.4234244145145505





# Looking at some details about the model

In [None]:
model

BERT(
  (embedding): Embedding(
    (tok_embed): Embedding(30522, 256)
    (pos_embed): Embedding(64, 256)
    (seg_embed): Embedding(2, 256)
    (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  )
  (encoder_blocks): ModuleList(
    (0): EncoderBlock(
      (mhsa): MultiHeadSelfAttention(
        (W_Q): Linear(in_features=256, out_features=192, bias=True)
        (W_K): Linear(in_features=256, out_features=192, bias=True)
        (W_V): Linear(in_features=256, out_features=192, bias=True)
        (dot_prod_attn): ScaledDotProductAttention()
        (proj_back): Linear(in_features=192, out_features=256, bias=True)
      )
      (pwff): PoswiseFeedForwardNet(
        (fc1): Linear(in_features=256, out_features=1024, bias=True)
        (fc2): Linear(in_features=1024, out_features=256, bias=True)
      )
    )
  )
  (decoder): Linear(in_features=256, out_features=30522, bias=False)
  (classifier): Linear(in_features=256, out_features=1, bias=False)
)

In [None]:
sum(p.numel() for p in model.parameters() if p.requires_grad)

16367936
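
The total above can be reproduced by hand from the layer shapes in the printed model. The sketch below is plain arithmetic; the variable names (`attn_dim`, `ff_dim`, `max_len`, and so on) are my own labels for the numbers in the printout, and it assumes `ScaledDotProductAttention` holds no parameters of its own, which the printout suggests.

```python
# Sizes read off the printed model
vocab_size, d_model, max_len, n_segments = 30522, 256, 64, 2
attn_dim, ff_dim = 192, 1024

def linear(in_features, out_features, bias=True):
    """Parameter count of a dense layer: weights plus optional bias."""
    return in_features * out_features + (out_features if bias else 0)

# Token, position and segment tables, plus LayerNorm weight and bias
embedding = (vocab_size + max_len + n_segments) * d_model + 2 * d_model

# W_Q, W_K, W_V and the projection back to d_model
attention = 3 * linear(d_model, attn_dim) + linear(attn_dim, d_model)

# The two position-wise feed-forward layers
feed_forward = linear(d_model, ff_dim) + linear(ff_dim, d_model)

# Output heads (both without bias)
decoder = linear(d_model, vocab_size, bias=False)
classifier = linear(d_model, 1, bias=False)

total = embedding + attention + feed_forward + decoder + classifier
print(total)  # 16367936
```

The token embedding table and the decoder each hold about 7.8 million parameters, so together they account for roughly 95% of the model; the single encoder block is small by comparison.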

# Fine Tuning on Sentiment Analysis
Now that we have a model that hopefully knows a thing or two about language, we can fine-tune it for sentiment analysis.

First we need to get the dataset.
The [fancyzhx/amazon_polarity](https://huggingface.co/datasets/fancyzhx/amazon_polarity) dataset was released with [this paper](https://arxiv.org/pdf/1509.01626) and is also referenced on the [Registry of Open Data on AWS](https://registry.opendata.aws/fast-ai-nlp/). On Hugging Face it's licensed under [Apache 2.0](https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/apache-2.0.md).


## Getting Data

In [None]:
fine_tune_ds = load_dataset("fancyzhx/amazon_polarity")

In [None]:
for elem in fine_tune_ds['train']:
    print(elem)
    break

{'label': 1, 'title': 'Stuning even for the non-gamer', 'content': 'This sound track was beautiful! It paints the senery in your mind so well I would recomend it even to people who hate vid. game music! I have played the game Chrono Cross but out of all of the games I have ever played it has the best music! It backs away from crude keyboarding and takes a fresher step with grate guitars and soulful orchestras. It would impress anyone who cares to listen! ^_^'}


## Constructing Train and Test Data
This dataset happens to have two sequences (the title and the review content), which plays nicely with the preprocessing we did for pre-training. I just need to get the data into a similar format, and everything should be pretty easy to get rolling.
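
As a simplified sketch of that format, the function below packs two already-tokenized sequences into one fixed-length input with matching segment ids. The name `pack_pair` and the toy token ids are hypothetical (only 101 and 102, the `[CLS]` and `[SEP]` ids, come from the tokenizer used here), and trimming of overlong pairs is left out to keep the idea visible.

```python
def pack_pair(first_ids, second_ids, max_len, pad_id=0):
    """Join two tokenized sequences into one padded input plus segment ids."""
    n_used = len(first_ids) + len(second_ids)
    if n_used > max_len:
        raise ValueError("pair is longer than max_len; trim it first")
    pad_num = max_len - n_used

    # Tokens: first sequence, second sequence, then padding
    tokens = first_ids + second_ids + [pad_id] * pad_num

    # Segment ids: 0 for the first sequence, 1 for the second;
    # padding is also given segment 1 in this sketch
    segments = [0] * len(first_ids) + [1] * (len(second_ids) + pad_num)
    return tokens, segments

# Toy ids standing in for a tokenized title and review
title_ids = [101, 7, 8, 102]
content_ids = [101, 9, 10, 11, 102]

tokens, segments = pack_pair(title_ids, content_ids, max_len=12)
print(tokens)    # [101, 7, 8, 102, 101, 9, 10, 11, 102, 0, 0, 0]
print(segments)  # [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
```

Every example ends up with the same length, so the packed lists can be stacked directly into tensors.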

In [None]:
def preprocess_data(data, max_num = 100000):
    data_tokens = []
    data_positional = []
    data_targets = []

    #unpacking data
    for i, elem in enumerate(data):

        #stopping once max_num examples have been processed
        if i >= max_num: break

        #tokenizing the title and content
        #(the tokenizer wraps each one in [CLS] (101) and [SEP] (102))
        sentence1 = elem['title']
        sentence2 = elem['content']
        tokens = tokenizer([sentence1, sentence2])
        sentence1_tokens = tokens['input_ids'][0]
        sentence2_tokens = tokens['input_ids'][1]

        # Trimming down tokens so the pair fits in max_input_length.
        # A title shorter than the slice keeps its own [CLS], so trimmed
        # rows can start with two 101 tokens.
        if len(sentence1_tokens) + len(sentence2_tokens) > max_input_length:
            sentence1_tokens = [101] + sentence1_tokens[-int(max_input_length / 2) + 1:]
            sentence2_tokens = sentence2_tokens[:int(max_input_length / 2) - 1] + [102]

        # Creating segment ids: 0 for the title, 1 for the content
        sentence_tokens = [0] * len(sentence1_tokens) + [1] * len(sentence2_tokens)

        # Combining and padding (padding reuses segment id 1)
        pad_num = max_input_length - (len(sentence1_tokens) + len(sentence2_tokens))
        sequence_tokens = sentence1_tokens + sentence2_tokens + [0] * pad_num
        sentence_location_tokens = sentence_tokens + [1] * pad_num

        data_tokens.append(sequence_tokens)
        data_positional.append(sentence_location_tokens)
        data_targets.append(elem['label'])

    #return order: segment ids, then tokens, then targets
    return torch.tensor(data_positional), torch.tensor(data_tokens), torch.tensor(data_targets)

#processing data into modeling data
train_pos, train_tok, train_targ = preprocess_data(fine_tune_ds['train'])
test_pos, test_tok, test_targ = preprocess_data(fine_tune_ds['test'])

#moving training data to the device
train_pos = train_pos.to(device)
train_tok = train_tok.to(device)
train_targ = train_targ.to(device)

#moving test data to the device
test_pos = test_pos.to(device)
test_tok = test_tok.to(device)
test_targ = test_targ.to(device)

print(train_tok)

tensor([[  101,   101, 24646,  ...,     0,     0,     0],
        [  101,   101,  1996,  ...,     0,     0,     0],
        [  101,   101,  6429,  ...,     0,     0,     0],
        ...,
        [  101,   101,  2438,  ...,     0,     0,     0],
        [  101,   101,  2307,  ...,     0,     0,     0],
        [  101,   101,  9458,  ...,     0,     0,     0]], device='cuda:0')


# Fine Tuning
Iterating over batches of the training data to fine-tune the model.
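
A minimal sketch of that batching pattern, assuming the data lives in one big indexable container and is sliced in steps of `batch_size`. The helper name `batch_starts` and the sizes used are hypothetical.

```python
def batch_starts(n_examples, batch_size):
    """Start indices of full batches; a trailing partial batch is skipped."""
    return [
        start
        for start in range(0, n_examples, batch_size)
        if start + batch_size <= n_examples
    ]

# Hypothetical sizes: 1000 examples in batches of 128
starts = batch_starts(n_examples=1000, batch_size=128)
print(len(starts), starts[-1])  # 7 768

# Each batch is then a plain slice, e.g. data[start:start + batch_size]
data = list(range(1000))
first_batch = data[starts[0]:starts[0] + 128]
print(len(first_batch))  # 128
```

Skipping the trailing partial batch keeps every batch the same shape, at the cost of leaving a few examples unused in each pass.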

In [None]:
# Replacing classification head with a new head
# the new training objective is still binary classification,
# except these parameters will be used to decide if a
# review was positive or negative
model.classifier = nn.Linear(d_model, 1, bias=False).to(device)

# resetting the optimizer to have access to the parameters of the new head
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
#one list of batch losses per epoch
ft_losses = [[]]
ft_test_acc = []

for epoch in range(5):
    for i in tqdm(range(0, train_pos.shape[0], batch_size)):

        if i+batch_size>=train_pos.shape[0]:
            break

        #getting batch
        train_pos_batch = train_pos[i:i+batch_size]
        train_tok_batch = train_tok[i:i+batch_size]
        train_targ_batch = train_targ[i:i+batch_size]

        # Zeroing out gradients from last iteration
        optimizer.zero_grad()

        # Masking the tokens in the input sequence
        masked_tokens, masked_token_locations = mask_batch(train_tok_batch)

        # Generating class and masked token predictions
        clsf_logits, token_logits = model(masked_tokens, train_pos_batch, masked_token_locations)

        # Setting up target for masked token prediction
        # (taken from the current fine-tuning batch)
        masked_token_targets = train_tok_batch[masked_token_locations.bool()]

        # Calculating loss for sentiment classification
        loss_clsf = classification_criterion(clsf_logits.squeeze(-1), train_targ_batch.float())

        # Calculating loss for masked language modeling
        loss_mlm = token_criterion(token_logits, masked_token_targets)

        # Combining losses
        loss = loss_mlm + loss_clsf

        ft_losses[-1].append(float(loss))

        # Backpropagation
        loss.backward()
        optimizer.step()

    print(f'=======Epoch {epoch} Completed=======')
    print(f'average loss in epoch: {np.mean(ft_losses[-1])}')
    #starting a fresh list so the next epoch's average is not cumulative
    ft_losses.append([])



100%|█████████▉| 781/782 [01:13<00:00, 10.60it/s]


average loss in epoch: 5.904687740433384


100%|█████████▉| 781/782 [01:13<00:00, 10.65it/s]


average loss in epoch: 5.426271478894731


100%|█████████▉| 781/782 [01:13<00:00, 10.57it/s]


average loss in epoch: 5.234847757687082


100%|█████████▉| 781/782 [01:13<00:00, 10.60it/s]


average loss in epoch: 5.128661873245972


100%|█████████▉| 781/782 [01:13<00:00, 10.61it/s]

average loss in epoch: 5.058861758614319





In [None]:
is_correct = []
predicted_class = []
original_class = []

#switching to evaluation mode; no gradients are needed for testing
model.eval()
with torch.no_grad():
    for i in tqdm(range(0, test_pos.shape[0], batch_size)):

        if i+batch_size>=test_pos.shape[0]:
            break

        #getting batch
        test_pos_batch = test_pos[i:i+batch_size]
        test_tok_batch = test_tok[i:i+batch_size]
        test_targ_batch = test_targ[i:i+batch_size]

        #making prediction, not masking anything
        #(the all-zero mask is created on the same device as the tokens)
        no_mask = torch.zeros_like(test_tok_batch)
        clsf_logits, _ = model(test_tok_batch, test_pos_batch, no_mask)

        #converting logits to probabilities then rounding to classifications
        res = torch.sigmoid(clsf_logits).round().squeeze(-1)

        #keeping track of the original class (positive or negative) and if the model was correct
        original_class.extend(test_targ_batch.cpu().numpy())
        is_correct.extend((res == test_targ_batch).cpu().numpy())
        predicted_class.extend(res.cpu().numpy())

100%|█████████▉| 781/782 [00:02<00:00, 379.67it/s]


In [None]:
#accuracy
sum(list(is_correct))/len(is_correct)

0.7686259603072984

In [None]:
from sklearn.metrics import confusion_matrix, classification_report
print(classification_report(original_class, predicted_class))

              precision    recall  f1-score   support

           0       0.76      0.77      0.77     49405
           1       0.77      0.77      0.77     50563

    accuracy                           0.77     99968
   macro avg       0.77      0.77      0.77     99968
weighted avg       0.77      0.77      0.77     99968

