<a href="https://colab.research.google.com/github/JAZ201107/DL-Experiments/blob/main/BERT_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# BERT from scratch
In this project, I will build a simple BERT from scratch using **IMDB reviews data** with ~72k words, and then use the trained model to do **sentiment classification**.


In [None]:
!pip install kaggle

In [7]:
import os
import random
import time
from collections import Counter
from datetime import datetime
from getpass import getpass
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import vocab
from tqdm.notebook import tqdm

# import tensorboard
from torch.utils.tensorboard import SummaryWriter

# Dataset

In [5]:
# !pip install kaggle

os.environ['KAGGLE_USERNAME'] = "x"
os.environ['KAGGLE_KEY'] = "x"

# download datasets
!kaggle datasets download -d lakshmi25npathi/imdb-dataset-of-50k-movie-reviews
!unzip imdb-dataset-of-50k-movie-reviews.zip


Archive:  imdb-dataset-of-50k-movie-reviews.zip
  inflating: IMDB Dataset.csv        


In [6]:
# Load the CSV extracted by the previous cell. The path is absolute
# (presumably the Colab working directory — adjust when running elsewhere).
df = pd.read_csv('/content/IMDB Dataset.csv')
# Bare last expression: the notebook renders the first rows as the cell output.
df.head()

Unnamed: 0,review,sentiment
0,One of the other reviewers has mentioned that ...,positive
1,A wonderful little production. <br /><br />The...,positive
2,I thought this was a wonderful way to spend ti...,positive
3,Basically there's a family where a little boy ...,negative
4,"Petter Mattei's ""Love in the Time of Money"" is...",positive


In [10]:
class PARAMS:
    """Notebook-wide configuration shared by the dataset, the model and the trainer.

    The class itself is passed around (e.g. ``BERT(PARAMS)``) and its
    attributes are read directly; no instance is created.
    """

    # Dataset: special tokens
    CLS = '[CLS]'
    PAD = '[PAD]'
    SEP = '[SEP]'
    MASK = '[MASK]'
    UNK = '[UNK]'

    # Dataset: masking ratio and column names of the prepared frame
    MASK_PERCENTAGE = 0.15
    MASKED_INDICES_COLUMN = 'masked_indices'
    TARGET_COLUMN = 'indices'
    NSP_TARGET_COLUMN = 'is_next'
    TOKEN_MASK_COLUMN = 'token_mask'

    OPTIMAL_LENGTH_PERCENTILE = 70

    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")

    # Joint Embedding
    # The vocabulary size depends on the corpus, so it cannot be fixed here:
    # assign it (e.g. from the dataset vocabulary) before building the model.
    vocab_size = None
    # Width of the token/segment embeddings.
    # NOTE(review): placeholder value — the original cell left it empty. TODO confirm.
    size = 64

    # Attention
    # The encoder consumes the embedding output directly, so its input width
    # has to equal the embedding width.
    dim_inp = size
    # Width of each attention head and of the feed-forward hidden layer.
    # NOTE(review): placeholder value — the original cell left it empty. TODO confirm.
    dim_out = 36
    num_heads = 4
    # Was written as `0,1`, which Python parses as the tuple (0, 1).
    dropout = 0.1


In [11]:
# Display the device selected in PARAMS (cuda when available, otherwise cpu).
PARAMS.device

device(type='cpu')

In [None]:
class IMBDBertDataset(Dataset):
    """IMDB review dataset for BERT pre-training (unfinished).

    The constructor reads the ``review`` column of a CSV file, optionally
    slices it, sets up the tokenizer and bookkeeping attributes, chooses the
    output column names and then delegates to ``prepare_dataset``.

    ``prepare_dataset`` and ``__getitem__`` are not implemented yet in this
    notebook; they raise ``NotImplementedError`` so the gap is reported
    explicitly.

    Args:
        path: CSV file containing a ``review`` column.
        ds_from: Optional start of the row slice.
        ds_to: Optional end of the row slice.
        should_include_text: When true, the human-readable ``masked_sentence``
            and ``sentence`` columns are kept next to the index columns.
        params: Configuration object providing the ``*_COLUMN`` names.
            Defaults to the notebook-level ``PARAMS``.
    """

    def __init__(
            self,
            path,
            ds_from=None,
            ds_to=None,
            should_include_text=False,
            params=None,
    ):
        # Resolved at call time so the class can be defined before PARAMS is.
        if params is None:
            params = PARAMS
        self.params = params

        self.ds = pd.read_csv(path)['review']

        if ds_from is not None or ds_to is not None:
            self.ds = self.ds[ds_from:ds_to]

        self.tokenizer = get_tokenizer('basic_english')
        self.counter = Counter()
        self.vocab = None
        self.optimal_sentence_length = None
        self.should_include_text = should_include_text

        # The column-name constants live on the configuration object, not on
        # this class.
        if should_include_text:
            self.columns = ['masked_sentence', params.MASKED_INDICES_COLUMN, 'sentence', params.TARGET_COLUMN,
                            params.TOKEN_MASK_COLUMN,
                            params.NSP_TARGET_COLUMN]
        else:
            self.columns = [params.MASKED_INDICES_COLUMN, params.TARGET_COLUMN, params.TOKEN_MASK_COLUMN,
                            params.NSP_TARGET_COLUMN]

        self.df = self.prepare_dataset()

    def prepare_dataset(self):
        """Build the frame stored in ``self.df``. TODO: not implemented yet."""
        raise NotImplementedError("IMBDBertDataset.prepare_dataset is not implemented yet")

    def __getitem__(self, idx):
        """Return the training sample at ``idx``. TODO: not implemented yet."""
        raise NotImplementedError("IMBDBertDataset.__getitem__ is not implemented yet")


# Correctly spelled alias, so code written against either spelling resolves.
IMDBBertDataset = IMBDBertDataset

# Define Bert Model

In [12]:
class JointEmbedding(nn.Module):
    """Sum of token, segment and positional embeddings, followed by LayerNorm.

    Args:
        params: Configuration object exposing ``vocab_size`` (number of rows of
            the embedding tables) and ``size`` (embedding width).
    """

    def __init__(self, params):
        super().__init__()

        # Embedding width. The positional tensor is built with this width and
        # added to the token embeddings, so it must be `size`, not `vocab_size`.
        self.size = params.size

        self.token_emb = nn.Embedding(params.vocab_size, params.size)
        self.segment_emb = nn.Embedding(params.vocab_size, params.size)

        self.norm = nn.LayerNorm(params.size)

    def forward(self, input_tensor, params=None):
        """Embed a batch of token indices.

        Args:
            input_tensor: Integer tensor indexed as (batch, sentence).
            params: Accepted for backward compatibility and ignored; every
                helper tensor is created on ``input_tensor``'s device.

        Returns:
            Normalised embeddings with a trailing dimension of ``self.size``.
        """
        sentence_size = input_tensor.size(-1)
        pos_tensor = self.attention_position(self.size, input_tensor)

        # Segment ids: 0 for the leading part, 1 from index
        # sentence_size // 2 + 1 onwards. zeros_like keeps dtype and device.
        segment_tensor = torch.zeros_like(input_tensor)
        segment_tensor[:, sentence_size // 2 + 1:] = 1

        output = self.token_emb(input_tensor) + self.segment_emb(segment_tensor) + pos_tensor
        return self.norm(output)

    def attention_position(self, dim, input_tensor, params=None):
        """Build sinusoidal position values of width ``dim`` for every batch row.

        Even feature columns hold ``sin`` and odd columns hold ``cos`` of
        ``position / 10000 ** (2 * column / dim)``.

        Returns:
            Float tensor of shape (batch, sentence, dim).
        """
        batch_size = input_tensor.size(0)
        sentence_size = input_tensor.size(-1)
        device = input_tensor.device

        pos = torch.arange(sentence_size, dtype=torch.long, device=device)
        d = torch.arange(dim, dtype=torch.long, device=device)
        d = (2 * d / dim)

        # (sentence, 1) / (dim,) broadcasts to (sentence, dim).
        pos = pos.unsqueeze(1)
        pos = pos / (1e4 ** d)

        pos[:, ::2] = torch.sin(pos[:, ::2])
        pos[:, 1::2] = torch.cos(pos[:, 1::2])

        return pos.expand(batch_size, *pos.size())

    def numeric_position(self, dim, input_tensor, params=None):
        """Return ``arange(dim)`` expanded to the shape of ``input_tensor``.

        ``expand_as`` only succeeds when ``dim`` equals the sentence length of
        ``input_tensor``.
        """
        pos_tensor = torch.arange(dim, dtype=torch.long, device=input_tensor.device)
        return pos_tensor.expand_as(input_tensor)



In [None]:
class AttentionHead(nn.Module):
    """One scaled dot-product self-attention head.

    Args:
        params: Configuration object exposing ``dim_inp`` (input width) and
            ``dim_out`` (width of the query/key/value projections).
    """

    def __init__(
            self,
            params
    ):
        super().__init__()

        self.dim_inp = params.dim_inp

        self.q = nn.Linear(params.dim_inp, params.dim_out)
        self.k = nn.Linear(params.dim_inp, params.dim_out)
        self.v = nn.Linear(params.dim_inp, params.dim_out)

    def forward(self, input_tensor, attention_mask):
        """Attend over the sentence dimension of ``input_tensor``.

        Args:
            input_tensor: Float tensor indexed as (batch, sentence, dim_inp).
            attention_mask: Boolean tensor broadcastable to
                (batch, sentence, sentence); ``True`` marks score entries that
                must not be attended to.

        Returns:
            Context tensor of shape (batch, sentence, dim_out).
        """
        query, key, value = self.q(input_tensor), self.k(input_tensor), self.v(input_tensor)

        # Scale by sqrt(d_k), the width of the key vectors (last dimension),
        # not by the sentence length found at dimension 1.
        scale = query.size(-1) ** 0.5
        scores = torch.bmm(query, key.transpose(1, 2)) / scale

        # A large negative score becomes ~0 probability after the softmax.
        scores = scores.masked_fill(attention_mask, -1e9)
        attn = F.softmax(scores, dim=-1)
        context = torch.bmm(attn, value)

        return context

In [None]:
class MultiHeadAttention(nn.Module):
    """Runs ``params.num_heads`` attention heads side by side and merges them.

    The head outputs are concatenated on the last dimension, projected back to
    ``params.dim_inp`` by a linear layer and layer-normalised.
    """

    def __init__(self, params):
        super().__init__()
        head_modules = []
        for _ in range(params.num_heads):
            head_modules.append(AttentionHead(params))
        self.heads = nn.ModuleList(head_modules)

        concat_width = params.dim_out * params.num_heads
        self.linear = nn.Linear(concat_width, params.dim_inp)
        self.norm = nn.LayerNorm(params.dim_inp)

    def forward(self, input_tensor, attention_mask):
        """Apply every head to the same input and mask, then merge the results."""
        per_head = []
        for head in self.heads:
            per_head.append(head(input_tensor, attention_mask))

        merged = torch.cat(per_head, dim=-1)
        projected = self.linear(merged)
        return self.norm(projected)

In [None]:
class Encoder(nn.Module):
    """Encoder layer: multi-head attention, a feed-forward stack, then LayerNorm.

    The feed-forward stack maps ``dim_inp -> dim_out -> dim_inp`` with dropout
    after each linear layer and a GELU in between.
    """

    def __init__(self, params):
        super().__init__()

        width = params.dim_inp
        hidden = params.dim_out
        drop_p = params.dropout

        # Attention output is expected as batch_size x sentence size x dim_inp.
        self.attention = MultiHeadAttention(params)

        ff_layers = [
            nn.Linear(width, hidden),
            nn.Dropout(drop_p),
            nn.GELU(),
            nn.Linear(hidden, width),
            nn.Dropout(drop_p),
        ]
        self.feed_forward = nn.Sequential(*ff_layers)
        self.norm = nn.LayerNorm(width)

    def forward(self, input_tensor, attention_mask):
        """Attend, feed forward and normalise."""
        attended = self.attention(input_tensor, attention_mask)
        return self.norm(self.feed_forward(attended))


In [None]:
class BERT(nn.Module):
    """Embedding + encoder with a token-prediction head and a 2-way classification head.

    Args:
        params: Configuration object forwarded to ``JointEmbedding`` and
            ``Encoder``; ``dim_inp`` and ``vocab_size`` size the output heads.
    """

    def __init__(self, params):
        super().__init__()

        # Kept because the embedding's forward pass takes the configuration
        # as an argument.
        self._params = params

        self.embedding = JointEmbedding(params)
        self.encoder = Encoder(params)

        self.token_prediction_layer = nn.Linear(params.dim_inp, params.vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)
        self.classification_layer = nn.Linear(params.dim_inp, 2)

    def forward(self, input_tensor: torch.Tensor, attention_mask: torch.Tensor):
        """Run the model on a batch.

        Args:
            input_tensor: Token indices, indexed as (batch, sentence).
            attention_mask: Mask handed unchanged to the encoder.

        Returns:
            A tuple ``(token_log_probs, classification_logits)``: the
            log-softmax over the vocabulary for every position, and the raw
            2-way scores computed from the encoding of the first position.
        """
        # The embedding's forward signature is (input_tensor, params); calling
        # it with the tokens alone raises a TypeError.
        embedded = self.embedding(input_tensor, self._params)
        encoded = self.encoder(embedded, attention_mask)

        token_predictions = self.token_prediction_layer(encoded)

        # Only the first position feeds the classification head.
        first_word = encoded[:, 0, :]
        return self.softmax(token_predictions), self.classification_layer(first_word)

# Helper functions to train model

In [14]:
def percentage(batch_size, max_index, current_index):
    """
    Calculate epoch progress percentage.

    ``max_index`` is the dataset length and ``current_index`` the number of
    batches processed so far. The number of whole batches is clamped to at
    least 1 so that a dataset smaller than one batch does not divide by zero.
    """
    batched_max = max(max_index // batch_size, 1)
    return round(current_index / batched_max * 100, 2)

def nsp_accuracy(result, target):
    """
    Share of rows whose arg-max class in ``result`` equals the arg-max class
    in ``target`` (both reduced over dimension 1), rounded to two decimals.
    """
    predicted_class = result.argmax(1)
    expected_class = target.argmax(1)
    hits = (predicted_class == expected_class).sum()
    ratio = hits / result.size(0)
    return round(float(ratio), 2)

def token_accuracy(result, target, inverse_token_mask):
    """
    Calculate MLM accuracy between ONLY masked words.

    Args:
        result: Per-token class scores; the arg-max over the last dimension is
            the predicted token.
        target: Expected token indices, one per position.
        inverse_token_mask: Boolean tensor, ``True`` for positions that were
            NOT masked; its negation selects the positions that are scored.

    Returns:
        Correct predictions divided by the number of masked positions, rounded
        to two decimals; 0.0 when no position is masked.
    """
    r = result.argmax(-1).masked_select(~inverse_token_mask)
    t = target.masked_select(~inverse_token_mask)

    # Normalise by the number of scored (masked) tokens. Dividing by every
    # token in the batch would cap the value at the masking ratio.
    masked_count = t.numel()
    if masked_count == 0:
        return 0.0

    s = (r == t).sum()
    return round(float(s / masked_count), 2)



In [None]:
class BertTrainer:
    """Training loop for the BERT model: optimisation, logging and checkpoints.

    Calling the trainer instance runs the remaining epochs; after each epoch a
    checkpoint is written when ``checkpoint_dir`` is set.

    Args:
        model: Model returning ``(token_log_probs, nsp_logits)`` for
            ``model(inp, mask)``.
        dataset: Dataset yielding 5-tuples
            ``(inp, mask, inverse_token_mask, token_target, nsp_target)``; it
            must also expose ``vocab`` and ``optimal_sentence_length`` for
            ``print_summary``.
        log_dir: Directory handed to the TensorBoard ``SummaryWriter``.
        checkpoint_dir: Directory for checkpoints; nothing is saved when unset.
        print_progress_every: Print/log the running losses every N batches.
        print_accuracy_every: Additionally report accuracies every N batches
            (only checked on batches where progress is printed).
        batch_size: Batch size of the internal ``DataLoader``.
        learning_rate: Adam learning rate.
        epochs: Index of the last epoch (exclusive).
    """

    def __init__(self,
                 model: BERT,
                 dataset: Dataset,
                 log_dir: Path,
                 checkpoint_dir: Path = None,
                 print_progress_every: int = 10,
                 print_accuracy_every: int = 50,
                 batch_size: int = 24,
                 learning_rate: float = 0.005,
                 epochs: int = 5,
                 ):
        self.model = model
        self.dataset = dataset

        # Train wherever the model's parameters already live instead of
        # relying on a module-level `device` variable.
        self.device = next(model.parameters()).device

        self.batch_size = batch_size
        self.epochs = epochs
        self.current_epoch = 0

        self.loader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True)

        self.writer = SummaryWriter(str(log_dir))
        # Accept both str and Path; an empty value disables checkpointing.
        self.checkpoint_dir = Path(checkpoint_dir) if checkpoint_dir else None

        self.criterion = nn.BCEWithLogitsLoss().to(self.device)
        self.ml_criterion = nn.NLLLoss(ignore_index=0).to(self.device)
        self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.015)

        self._splitter_size = 35

        self._ds_len = len(self.dataset)
        self._batched_len = self._ds_len // self.batch_size

        self._print_every = print_progress_every
        self._accuracy_every = print_accuracy_every

    def print_summary(self):
        """Print device, dataset and batching information."""
        ds_len = len(self.dataset)

        print("Model Summary\n")
        print('=' * self._splitter_size)
        print(f"Device: {self.device}")
        print(f"Training dataset len: {ds_len}")
        print(f"Max / Optimal sentence len: {self.dataset.optimal_sentence_length}")
        print(f"Vocab size: {len(self.dataset.vocab)}")
        print(f"Batch size: {self.batch_size}")
        print(f"Batched dataset len: {self._batched_len}")
        print('=' * self._splitter_size)
        print()

    def __call__(self):
        """Train from ``self.current_epoch`` up to ``self.epochs``, checkpointing after each epoch."""
        for epoch in range(self.current_epoch, self.epochs):
            self.current_epoch = epoch
            loss = self.train(epoch)
            self.save_checkpoint(epoch, step=-1, loss=loss)

    def train(self, epoch: int):
        """Run one pass over the loader.

        Returns:
            The detached combined loss of the last batch, or ``None`` when the
            loader produced no batch.
        """
        print(f"Begin epoch {epoch}")

        prev = time.time()
        # Plain floats: accumulating the loss tensors themselves would keep
        # every batch's autograd graph alive until the next reset.
        average_nsp_loss = 0.0
        average_mlm_loss = 0.0
        loss = None
        for i, value in enumerate(self.loader):
            index = i + 1
            # Assumes every element of the batch is a tensor; .to() is a no-op
            # when the data is already on the right device.
            inp, mask, inverse_token_mask, token_target, nsp_target = (
                item.to(self.device) for item in value
            )
            self.optimizer.zero_grad()

            token, nsp = self.model(inp, mask)

            # Zero the predictions at positions flagged by inverse_token_mask.
            tm = inverse_token_mask.unsqueeze(-1).expand_as(token)
            token = token.masked_fill(tm, 0)

            # NLLLoss expects the class dimension at index 1.
            loss_token = self.ml_criterion(token.transpose(1, 2), token_target)
            loss_nsp = self.criterion(nsp, nsp_target)

            loss = loss_token + loss_nsp
            average_nsp_loss += loss_nsp.item()
            average_mlm_loss += loss_token.item()

            loss.backward()
            self.optimizer.step()

            if index % self._print_every == 0:
                elapsed = time.gmtime(time.time() - prev)
                s = self.training_summary(elapsed, index, average_nsp_loss, average_mlm_loss)

                if index % self._accuracy_every == 0:
                    s += self.accuracy_summary(index, token, nsp, token_target, nsp_target, inverse_token_mask)

                print(s)

                average_nsp_loss = 0.0
                average_mlm_loss = 0.0
        return loss if loss is None else loss.detach()

    def training_summary(self, elapsed, index, average_nsp_loss, average_mlm_loss):
        """Format the progress line and log both running losses to TensorBoard."""
        passed = percentage(self.batch_size, self._ds_len, index)
        global_step = self.current_epoch * len(self.loader) + index

        print_nsp_loss = average_nsp_loss / self._print_every
        print_mlm_loss = average_mlm_loss / self._print_every

        s = f"{time.strftime('%H:%M:%S', elapsed)}"
        s += f" | Epoch {self.current_epoch + 1} | {index} / {self._batched_len} ({passed}%) | " \
             f"NSP loss {print_nsp_loss:6.2f} | MLM loss {print_mlm_loss:6.2f}"

        self.writer.add_scalar("NSP loss", print_nsp_loss, global_step=global_step)
        self.writer.add_scalar("MLM loss", print_mlm_loss, global_step=global_step)
        return s

    def accuracy_summary(self, index, token, nsp, token_target, nsp_target, inverse_token_mask):
        """Compute both accuracies for the current batch, log them and return the text fragment."""
        global_step = self.current_epoch * len(self.loader) + index
        nsp_acc = nsp_accuracy(nsp, nsp_target)
        token_acc = token_accuracy(token, token_target, inverse_token_mask)

        self.writer.add_scalar("NSP train accuracy", nsp_acc, global_step=global_step)
        self.writer.add_scalar("Token train accuracy", token_acc, global_step=global_step)

        return f" | NSP accuracy {nsp_acc} | Token accuracy {token_acc}"

    def save_checkpoint(self, epoch, step, loss):
        """Write model and optimizer state to ``checkpoint_dir`` (no-op when it is unset)."""
        if self.checkpoint_dir is None:
            return

        prev = time.time()
        # time.time() is the real Unix timestamp; calling .timestamp() on the
        # naive result of utcnow() interprets it as local time.
        name = f"bert_epoch{epoch}_step{step}_{time.time():.0f}.pt"

        # torch.save does not create missing directories.
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': loss,
        }, self.checkpoint_dir.joinpath(name))

        print()
        print('=' * self._splitter_size)
        print(f"Model saved as '{name}' for {time.time() - prev:.2f}s")
        print('=' * self._splitter_size)
        print()

    def load_checkpoint(self, path: Path):
        """Restore epoch counter, model weights and optimizer state from ``path``."""
        print('=' * self._splitter_size)
        print(f"Restoring model {path}")
        # map_location lets a checkpoint written on GPU be restored on a
        # CPU-only machine (and the other way round).
        checkpoint = torch.load(path, map_location=self.device)
        self.current_epoch = checkpoint['epoch']
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print("Model is restored.")
        print('=' * self._splitter_size)

# Train Model

In [None]:
# NOTE(review): BASE_DIR, LOG_DIR, CHECKPOINT_DIR and BATCH_SIZE are not defined
# in any cell of this notebook; define them in a configuration cell before
# running this one.
# The dataset class is defined above under the name `IMBDBertDataset`.
ds = IMBDBertDataset(BASE_DIR.joinpath('data/imdb.csv'), ds_from=0, ds_to=1000)

# The embedding table and the token-prediction head are sized by vocab_size.
# Assumes the dataset has populated `ds.vocab` by now — TODO confirm.
PARAMS.vocab_size = len(ds.vocab)

bert_model = BERT(PARAMS).to(PARAMS.device)
trainer = BertTrainer(
        model=bert_model,  # was `bert`, a name that is never defined
        dataset=ds,
        log_dir=LOG_DIR,
        checkpoint_dir=CHECKPOINT_DIR,
        print_progress_every=20,
        print_accuracy_every=200,
        batch_size=BATCH_SIZE,
        learning_rate=0.00007,
        epochs=15
    )

trainer.print_summary()
trainer()