In [1]:
import random
import typing
import torch
from collections import Counter
from transformers import BertTokenizer
from datasets import load_dataset
from torch.utils.data import Dataset
import numpy as np
import pandas as pd
from tqdm import tqdm

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
class FitnessChatBertDataset(Dataset):
    CLS = '[CLS]'
    PAD = '[PAD]'
    SEP = '[SEP]'
    MASK = '[MASK]'
    UNK = '[UNK]'

    MASK_PERCENTAGE = 0.15

    MASKED_INDICES_COLUMN = 'masked_indices'
    TARGET_COLUMN = 'indices'
    NSP_TARGET_COLUMN = 'is_next'
    TOKEN_MASK_COLUMN = 'token_mask'

    OPTIMAL_LENGTH_PERCENTILE = 70

    def __init__(self, dataset_name: str, should_include_text=False, ds_from=None, ds_to=None):
        self.optimal_sentence_length = 128
        self.ds = load_dataset(dataset_name, split='train')  


        dataset_size = len(self.ds)
        print(f"Dataset size: {dataset_size}")  

        if ds_from is not None and ds_to is not None:

            ds_from = min(ds_from, dataset_size)
            ds_to = min(ds_to, dataset_size)

            print(f"Slicing dataset from {ds_from} to {ds_to}")

            self.ds = self.ds.select(range(ds_from, ds_to))  
        else:
            print("No slicing applied, using the entire dataset.")

        print("First few samples in the dataset:", self.ds[:3])  

        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.counter = Counter()
        self.vocab = self.tokenizer.get_vocab()

        self.optimal_sentence_length = None
        self.should_include_text = should_include_text

        if should_include_text:
            self.columns = ['masked_prompt', self.MASKED_INDICES_COLUMN, 'prompt', self.TARGET_COLUMN,
                            self.TOKEN_MASK_COLUMN, self.NSP_TARGET_COLUMN]
        else:
            self.columns = [self.MASKED_INDICES_COLUMN, self.TARGET_COLUMN, self.TOKEN_MASK_COLUMN,
                            self.NSP_TARGET_COLUMN]

        self.df = self.prepare_dataset()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        item = self.df.iloc[idx]

        inp = torch.Tensor(item[self.MASKED_INDICES_COLUMN]).long()
        token_mask = torch.Tensor(item[self.TOKEN_MASK_COLUMN]).bool()

        mask_target = torch.Tensor(item[self.TARGET_COLUMN]).long()

        if len(mask_target) != len(token_mask):

            if len(mask_target) > len(token_mask):
                mask_target = mask_target[:len(token_mask)]
 
            elif len(mask_target) < len(token_mask):
                padding_length = len(token_mask) - len(mask_target)
                mask_target = torch.cat([mask_target, torch.zeros(padding_length).long()])

        mask_target = mask_target.masked_fill_(token_mask, 0)

        attention_mask = (inp == self.vocab[self.PAD]).unsqueeze(0)

        if item[self.NSP_TARGET_COLUMN] == 0:
            t = [1, 0]
        else:
            t = [0, 1]

        nsp_target = torch.Tensor(t)

        return (
            inp.to(device),
            attention_mask.to(device),
            token_mask.to(device),
            mask_target.to(device),
            nsp_target.to(device)
        )


    def prepare_dataset(self) -> pd.DataFrame:
        sentences = []
        nsp = []
        sentence_lens = []

        for entry in self.ds:

            prompt = entry['instruction']  
            completion = entry['output']   

            # Tokenize the prompt and completion
            prompt_tokens = self.tokenizer.tokenize(prompt)
            completion_tokens = self.tokenizer.tokenize(completion)

            sentences.append((prompt_tokens, completion_tokens))
            sentence_lens.append(len(prompt_tokens) + len(completion_tokens))

        self.optimal_sentence_length = self._find_optimal_sentence_length(sentence_lens)

        print("Preprocessing dataset")
        for prompt_tokens, completion_tokens in tqdm(sentences):
            # True NSP item
            nsp.append(self._create_item(prompt_tokens, completion_tokens, 1))

            # False NSP item
            first, second = self._select_false_nsp_sentences(sentences)
            nsp.append(self._create_item(first, second, 0))

        df = pd.DataFrame(nsp, columns=self.columns)
        print(f"Processed dataset with {len(df)} samples.")
        return df

    def _find_optimal_sentence_length(self, lengths: typing.List[int]):
        arr = np.array(lengths)
        optimal_length = int(np.percentile(arr, self.OPTIMAL_LENGTH_PERCENTILE))
        print(f"Optimal sentence length: {optimal_length}")
        return optimal_length

    def _create_item(self, prompt: typing.List[str], completion: typing.List[str], target: int = 1):

        updated_prompt, prompt_mask = self._preprocess_sentence(prompt)
        updated_completion, completion_mask = self._preprocess_sentence(completion)

        nsp_sentence = updated_prompt + [self.SEP] + updated_completion

        nsp_sentence_flat = [token for sublist in nsp_sentence for token in (sublist if isinstance(sublist, list) else [sublist])]

        nsp_indices = self.tokenizer.convert_tokens_to_ids(nsp_sentence_flat)

        inverse_token_mask = prompt_mask + [True] + completion_mask

        prompt, _ = self._preprocess_sentence(prompt, should_mask=False)
        completion, _ = self._preprocess_sentence(completion, should_mask=False)
        original_nsp_sentence = prompt + [self.SEP] + completion
        original_nsp_indices = self.tokenizer.convert_tokens_to_ids(original_nsp_sentence)

        if self.should_include_text:

            return (
                nsp_sentence,               
                nsp_indices,                 
                original_nsp_sentence,      
                original_nsp_indices,       
                inverse_token_mask,         
                target                      
            )
        else:

            return (
                nsp_indices,                 
                inverse_token_mask,          
                original_nsp_indices,        
                target                      
            )

    def _preprocess_sentence(self, sentence: typing.List[str], should_mask: bool = True):

        inverse_token_mask = [] if not should_mask else None

        if should_mask:
            sentence, inverse_token_mask = self._mask_sentence(sentence)

        sentence = list(sentence) if isinstance(sentence, tuple) else sentence
 
        inverse_token_mask = list(inverse_token_mask) if isinstance(inverse_token_mask, tuple) else inverse_token_mask
        inverse_token_mask = inverse_token_mask or []  
 
        sentence, inverse_token_mask = self._pad_sentence([self.CLS] + sentence, inverse_token_mask)

        return sentence, inverse_token_mask

    def _pad_sentence(self, sentence: typing.List[str], inverse_token_mask: typing.List[bool]):
        padding_length = self.optimal_sentence_length - len(sentence)
        if padding_length > 0:
            sentence += [self.PAD] * padding_length
            inverse_token_mask += [False] * padding_length

        return sentence, inverse_token_mask

    def _mask_sentence(self, sentence: typing.List[str]):
        """Mask words in the sentence randomly"""
        masked_sentence = sentence[:]
        token_mask = [False] * len(sentence)
        num_tokens_to_mask = int(len(sentence) * self.MASK_PERCENTAGE)

        masked_indices = random.sample(range(len(sentence)), num_tokens_to_mask)
        
        for idx in masked_indices:
            masked_sentence[idx] = self.MASK
            token_mask[idx] = True
        
        print(f"Masked sentence: {masked_sentence}")
        print(f"Token mask: {token_mask}")
        
        return masked_sentence, token_mask

    def _select_false_nsp_sentences(self, sentences: typing.List[typing.Tuple[list, list]]):
        """Select two sentences that are not consecutive for the false NSP task."""

        index1 = random.randint(0, len(sentences) - 1)
        index2 = random.randint(0, len(sentences) - 1)

        while abs(index1 - index2) == 1 or index1 == index2:
            index2 = random.randint(0, len(sentences) - 1)

        return sentences[index1] 


ds = FitnessChatBertDataset('chibbss/fitness-chat-prompt-completion-dataset')
print(ds[0])

Dataset size: 245
No slicing applied, using the entire dataset.
First few samples in the dataset: {'output': ['1. Develop a consistent exercise routine – Exercise is essential for physical and mental health. Aim for at least 30 minutes of physical activity five days a week.\n\n2. Follow a healthy diet – Incorporate more fruits, vegetables, and whole grains into your diet while avoiding processed and fast foods.\n\n3. Get enough sleep – Give your body time to rest and repair by getting the recommended seven to nine hours of sleep every night.\n\n4. Practice relaxation techniques – Take a break to practice mindfulness, deep breathing, and other forms of relaxation to reduce stress and maintain emotional balance.\n\n5. Talk', 'A balanced diet is one that includes all the essential nutrients that your body needs to function properly. It should include an adequate amount of protein, carbohydrates, fat, vitamins, minerals, and water. It should also include a variety of whole grains, fruits, 



Optimal sentence length: 137
Preprocessing dataset


100%|███████████████████████████████████████| 245/245 [00:00<00:00, 3792.50it/s]


Masked sentence: ['what', 'are', 'some', 'practical', 'steps', '[MASK]', 'can', 'take', 'to', 'improve', 'my', 'overall', 'health', 'and', 'well', '[MASK]', 'being', '?']
Token mask: [False, False, False, False, False, True, False, False, False, False, False, False, False, False, False, True, False, False]
Masked sentence: ['1', '.', 'develop', 'a', 'consistent', 'exercise', 'routine', '–', 'exercise', 'is', 'essential', 'for', 'physical', 'and', 'mental', '[MASK]', '.', 'aim', '[MASK]', '[MASK]', 'least', '30', 'minutes', 'of', 'physical', 'activity', 'five', '[MASK]', 'a', 'week', '.', '2', '[MASK]', 'follow', 'a', 'healthy', '[MASK]', '–', 'incorporate', 'more', 'fruits', ',', 'vegetables', ',', 'and', 'whole', 'grains', 'into', 'your', 'diet', 'while', 'avoiding', 'processed', 'and', 'fast', '[MASK]', '.', '3', '[MASK]', 'get', 'enough', 'sleep', '–', 'give', 'your', 'body', 'time', 'to', '[MASK]', 'and', 'repair', 'by', 'getting', 'the', 'recommended', '[MASK]', 'to', 'nine', 'hou

In [4]:
import torch

from torch import nn
import torch.nn.functional as f


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class JointEmbedding(nn.Module):

    def __init__(self, vocab_size, size):
        super(JointEmbedding, self).__init__()

        self.size = size

        self.token_emb = nn.Embedding(vocab_size, size)
        self.segment_emb = nn.Embedding(vocab_size, size)

        self.norm = nn.LayerNorm(size)

    def forward(self, input_tensor):
        sentence_size = input_tensor.size(-1)
        pos_tensor = self.attention_position(self.size, input_tensor)

        segment_tensor = torch.zeros_like(input_tensor).to(device)
        segment_tensor[:, sentence_size // 2 + 1:] = 1

        output = self.token_emb(input_tensor) + self.segment_emb(segment_tensor) + pos_tensor
        return self.norm(output)

    def attention_position(self, dim, input_tensor):
        sentence_size = input_tensor.size(-1)
    
        pos = torch.arange(sentence_size, dtype=torch.float32, device=device).unsqueeze(1)
        d = torch.arange(dim, dtype=torch.float32, device=device) / dim
        div_term = 1e4 ** d
    
        pos_encoding = pos / div_term
        pos_encoding[:, ::2] = torch.sin(pos_encoding[:, ::2])  
        pos_encoding[:, 1::2] = torch.cos(pos_encoding[:, 1::2])  
    
        return pos_encoding.unsqueeze(0).expand(input_tensor.size(0), -1, -1)  


    def numeric_position(self, dim, input_tensor):
        pos_tensor = torch.arange(dim, dtype=torch.long).to(device)
        return pos_tensor.expand_as(input_tensor)


class AttentionHead(nn.Module):

    def __init__(self, dim_inp, dim_out):
        super(AttentionHead, self).__init__()

        self.dim_inp = dim_inp

        self.q = nn.Linear(dim_inp, dim_out)
        self.k = nn.Linear(dim_inp, dim_out)
        self.v = nn.Linear(dim_inp, dim_out)

    def forward(self, input_tensor: torch.Tensor, attention_mask: torch.Tensor = None):
        query, key, value = self.q(input_tensor), self.k(input_tensor), self.v(input_tensor)
    
        scale = query.size(-1) ** 0.5  
        scores = torch.bmm(query, key.transpose(1, 2)) / scale

        if attention_mask is not None:
            scores = scores.masked_fill(attention_mask.bool(), -1e9)  
    
        attn = f.softmax(scores, dim=-1)
        context = torch.bmm(attn, value)
        return context



class MultiHeadAttention(nn.Module):

    def __init__(self, num_heads, dim_inp, dim_out):
        super(MultiHeadAttention, self).__init__()

        self.heads = nn.ModuleList([
            AttentionHead(dim_inp, dim_out) for _ in range(num_heads)
        ])
        self.linear = nn.Linear(dim_out * num_heads, dim_inp)
        self.norm = nn.LayerNorm(dim_inp)

    def forward(self, input_tensor: torch.Tensor, attention_mask: torch.Tensor):
        s = [head(input_tensor, attention_mask) for head in self.heads]
        scores = torch.cat(s, dim=-1)
        scores = self.linear(scores)
        return self.norm(scores)


class Encoder(nn.Module):

    def __init__(self, dim_inp, dim_out, attention_heads=4, dropout=0.1):
        super(Encoder, self).__init__()

        self.attention = MultiHeadAttention(attention_heads, dim_inp, dim_out)  
        self.feed_forward = nn.Sequential(
            nn.Linear(dim_inp, dim_out),
            nn.Dropout(dropout),
            nn.GELU(),
            nn.Linear(dim_out, dim_inp),
            nn.Dropout(dropout)
        )
        self.norm = nn.LayerNorm(dim_inp)

    def forward(self, input_tensor: torch.Tensor, attention_mask: torch.Tensor):
        context = self.attention(input_tensor, attention_mask)
        res = self.feed_forward(context)
        return self.norm(res)


class BERT(nn.Module):

    def __init__(self, vocab_size, dim_inp, dim_out, attention_heads=4):
        super(BERT, self).__init__()

        self.embedding = JointEmbedding(vocab_size, dim_inp)
        self.encoder = Encoder(dim_inp, dim_out, attention_heads)

        self.token_prediction_layer = nn.Linear(dim_inp, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)
        self.classification_layer = nn.Linear(dim_inp, 2)

    def forward(self, input_tensor: torch.Tensor, attention_mask: torch.Tensor):
        embedded = self.embedding(input_tensor)
        encoded = self.encoder(embedded, attention_mask)

        token_predictions = self.token_prediction_layer(encoded)

        first_word = encoded[:, 0, :]
        return self.softmax(token_predictions), self.classification_layer(first_word)

In [5]:
import time
from datetime import datetime
from pathlib import Path

import torch

from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def percentage(batch_size: int, max_index: int, current_index: int):
    
    batched_max = max_index // batch_size
    return round(current_index / batched_max * 100, 2)


def nsp_accuracy(result: torch.Tensor, target: torch.Tensor):
    
    s = (result.argmax(1) == target.argmax(1)).sum()
    return round(float(s / result.size(0)), 2)


def token_accuracy(result: torch.Tensor, target: torch.Tensor, inverse_token_mask: torch.Tensor):
    
    r = result.argmax(-1).masked_select(~inverse_token_mask)
    t = target.masked_select(~inverse_token_mask)
    s = (r == t).sum()
    return round(float(s / (result.size(0) * result.size(1))), 2)


class BertTrainer:

    def __init__(self,
                 model: BERT,
                 dataset: FitnessChatBertDataset,
                 log_dir: Path = None,
                 checkpoint_dir: Path = None,
                 print_progress_every: int = 10,
                 print_accuracy_every: int = 50,
                 batch_size: int = 24,
                 learning_rate: float = 0.005,
                 epochs: int = 5,
                 ):
        self.model = model
        self.dataset = dataset

        self.batch_size = batch_size
        self.epochs = epochs
        self.current_epoch = 0

        self.loader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True)

        self.writer = SummaryWriter(str(log_dir))
        self.checkpoint_dir = checkpoint_dir

        self.criterion = nn.BCEWithLogitsLoss().to(device)
        self.ml_criterion = nn.NLLLoss(ignore_index=0).to(device)
        self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.015)

        self._splitter_size = 35

        self._ds_len = len(self.dataset)
        self._batched_len = self._ds_len // self.batch_size

        self._print_every = print_progress_every
        self._accuracy_every = print_accuracy_every

    def print_summary(self):
        ds_len = len(self.dataset)

        print("Model Summary\n")
        print('=' * self._splitter_size)
        print(f"Device: {device}")
        print(f"Training dataset len: {ds_len}")
        print(f"Max / Optimal sentence len: {self.dataset.optimal_sentence_length}")
        print(f"Vocab size: {len(self.dataset.vocab)}")
        print(f"Batch size: {self.batch_size}")
        print(f"Batched dataset len: {self._batched_len}")
        print('=' * self._splitter_size)
        print()

    def __call__(self):
        for self.current_epoch in range(self.current_epoch, self.epochs):
            loss = self.train(self.current_epoch)
            self.save_checkpoint(self.current_epoch, step=-1, loss=loss)

    def train(self, epoch: int):
        print(f"Begin epoch {epoch}")

        prev = time.time()
        average_nsp_loss = 0
        average_mlm_loss = 0
        for i, value in enumerate(self.loader):
            index = i + 1
            inp, mask, inverse_token_mask, token_target, nsp_target = value
            self.optimizer.zero_grad()

            token, nsp = self.model(inp, mask)

            tm = inverse_token_mask.unsqueeze(-1).expand_as(token)
            token = token.masked_fill(tm, 0)

            loss_token = self.ml_criterion(token.transpose(1, 2), token_target)  
            loss_nsp = self.criterion(nsp, nsp_target)

            loss = loss_token + loss_nsp
            average_nsp_loss += loss_nsp
            average_mlm_loss += loss_token

            loss.backward()
            self.optimizer.step()

            if index % self._print_every == 0:
                elapsed = time.gmtime(time.time() - prev)
                s = self.training_summary(elapsed, index, average_nsp_loss, average_mlm_loss)

                if index % self._accuracy_every == 0:
                    s += self.accuracy_summary(index, token, nsp, token_target, nsp_target)

                print(s)

                average_nsp_loss = 0
                average_mlm_loss = 0
        return loss

    def training_summary(self, elapsed, index, average_nsp_loss, average_mlm_loss):
        passed = percentage(self.batch_size, self._ds_len, index)
        global_step = self.current_epoch * len(self.loader) + index

        print_nsp_loss = average_nsp_loss / self._print_every
        print_mlm_loss = average_mlm_loss / self._print_every

        s = f"{time.strftime('%H:%M:%S', elapsed)}"
        s += f" | Epoch {self.current_epoch + 1} | {index} / {self._batched_len} ({passed}%) | " \
             f"NSP loss {print_nsp_loss:6.2f} | MLM loss {print_mlm_loss:6.2f}"

        self.writer.add_scalar("NSP loss", print_nsp_loss, global_step=global_step)
        self.writer.add_scalar("MLM loss", print_mlm_loss, global_step=global_step)
        return s

    def accuracy_summary(self, index, token, nsp, token_target, nsp_target, inverse_token_mask):
        global_step = self.current_epoch * len(self.loader) + index
        nsp_acc = nsp_accuracy(nsp, nsp_target)
        token_acc = token_accuracy(token, token_target, inverse_token_mask)

        self.writer.add_scalar("NSP train accuracy", nsp_acc, global_step=global_step)
        self.writer.add_scalar("Token train accuracy", token_acc, global_step=global_step)

        return f" | NSP accuracy {nsp_acc} | Token accuracy {token_acc}"

    def save_checkpoint(self, epoch, step, loss):
        if self.checkpoint_dir:
            return

        prev = time.time()
        name = f"bert_epoch{epoch}_step{step}_{datetime.datetime.utcnow().timestamp():.0f}.pt"

        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': loss,
        }, self.checkpoint_dir.joinpath(name))

        print()
        print('=' * self._splitter_size)
        print(f"Model saved as '{name}' for {time.time() - prev:.2f}s")
        print('=' * self._splitter_size)
        print()

    def load_checkpoint(self, path: Path):
        print('=' * self._splitter_size)
        print(f"Restoring model {path}")
        checkpoint = torch.load(path)
        self.current_epoch = checkpoint['epoch']
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print("Model is restored.")
        print('=' * self._splitter_size)

In [6]:
import datetime

import torch

from pathlib import Path


EMB_SIZE = 64
HIDDEN_SIZE = 36
EPOCHS = 4
BATCH_SIZE = 12
NUM_HEADS = 4

CHECKPOINT_DIR = 'Desktop/Temp/bert_checkpoints'

timestamp = datetime.datetime.utcnow().timestamp()
LOG_DIR = 'Desktop/Temp/logs/bert_experiment_{timestamp}'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    torch.cuda.empty_cache()

if __name__ == '__main__':
    print("Prepare dataset")
    ds = FitnessChatBertDataset('chibbss/fitness-chat-prompt-completion-dataset')

    bert = BERT(len(ds.vocab), EMB_SIZE, HIDDEN_SIZE, NUM_HEADS).to(device)
    trainer = BertTrainer(
        model=bert,
        dataset=ds,
        log_dir=LOG_DIR,
        checkpoint_dir=CHECKPOINT_DIR,
        print_progress_every=20,
        print_accuracy_every=200,
        batch_size=BATCH_SIZE,
        learning_rate=0.00007,
        epochs=15
    )

    trainer.print_summary()
    trainer()

Prepare dataset
Dataset size: 245
No slicing applied, using the entire dataset.
First few samples in the dataset: {'output': ['1. Develop a consistent exercise routine – Exercise is essential for physical and mental health. Aim for at least 30 minutes of physical activity five days a week.\n\n2. Follow a healthy diet – Incorporate more fruits, vegetables, and whole grains into your diet while avoiding processed and fast foods.\n\n3. Get enough sleep – Give your body time to rest and repair by getting the recommended seven to nine hours of sleep every night.\n\n4. Practice relaxation techniques – Take a break to practice mindfulness, deep breathing, and other forms of relaxation to reduce stress and maintain emotional balance.\n\n5. Talk', 'A balanced diet is one that includes all the essential nutrients that your body needs to function properly. It should include an adequate amount of protein, carbohydrates, fat, vitamins, minerals, and water. It should also include a variety of whole 

100%|███████████████████████████████████████| 245/245 [00:00<00:00, 3819.28it/s]

Masked sentence: ['what', 'are', '[MASK]', 'practical', 'steps', 'i', 'can', 'take', 'to', 'improve', '[MASK]', 'overall', 'health', 'and', 'well', '-', 'being', '?']
Token mask: [False, False, True, False, False, False, False, False, False, False, True, False, False, False, False, False, False, False]
Masked sentence: ['1', '.', 'develop', 'a', 'consistent', 'exercise', 'routine', '[MASK]', 'exercise', 'is', 'essential', 'for', '[MASK]', 'and', 'mental', 'health', '.', 'aim', 'for', 'at', 'least', '[MASK]', 'minutes', 'of', 'physical', 'activity', 'five', 'days', 'a', 'week', '.', '[MASK]', '.', 'follow', 'a', 'healthy', 'diet', '–', 'incorporate', 'more', 'fruits', ',', 'vegetables', '[MASK]', 'and', 'whole', '[MASK]', 'into', 'your', '[MASK]', 'while', 'avoiding', 'processed', 'and', 'fast', 'foods', '[MASK]', '[MASK]', '.', 'get', 'enough', 'sleep', '–', 'give', 'your', 'body', 'time', 'to', 'rest', 'and', 'repair', 'by', '[MASK]', 'the', 'recommended', 'seven', 'to', 'nine', '[MAS




Model Summary

Device: cpu
Training dataset len: 490
Max / Optimal sentence len: 137
Vocab size: 30522
Batch size: 12
Batched dataset len: 40

Begin epoch 0
00:00:07 | Epoch 1 | 20 / 40 (50.0%) | NSP loss   0.73 | MLM loss   8.27
00:00:15 | Epoch 1 | 40 / 40 (100.0%) | NSP loss   0.72 | MLM loss   7.42
Begin epoch 1
00:00:07 | Epoch 2 | 20 / 40 (50.0%) | NSP loss   0.71 | MLM loss   7.09
00:00:14 | Epoch 2 | 40 / 40 (100.0%) | NSP loss   0.70 | MLM loss   6.78
Begin epoch 2
00:00:07 | Epoch 3 | 20 / 40 (50.0%) | NSP loss   0.72 | MLM loss   6.55
00:00:15 | Epoch 3 | 40 / 40 (100.0%) | NSP loss   0.73 | MLM loss   6.33
Begin epoch 3
00:00:07 | Epoch 4 | 20 / 40 (50.0%) | NSP loss   0.72 | MLM loss   6.17
00:00:15 | Epoch 4 | 40 / 40 (100.0%) | NSP loss   0.71 | MLM loss   5.99
Begin epoch 4
00:00:07 | Epoch 5 | 20 / 40 (50.0%) | NSP loss   0.71 | MLM loss   5.81
00:00:15 | Epoch 5 | 40 / 40 (100.0%) | NSP loss   0.70 | MLM loss   5.62
Begin epoch 5
00:00:07 | Epoch 6 | 20 / 40 (50.0%) |

In [7]:
# Save the entire trained model
model_save_path = '/Users/likeshkoya/code/NLP_Project/bert_model.pt'

torch.save(trainer.model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")


Model saved to /Users/likeshkoya/code/NLP_Project/bert_model.pt


In [8]:
# Load the saved model
bert = BERT(len(ds.vocab), EMB_SIZE, HIDDEN_SIZE, NUM_HEADS).to(device)
bert.load_state_dict(torch.load(model_save_path, map_location=device))
bert.eval()  # Set model to evaluation mode
print("Model loaded and ready for testing.")


Model loaded and ready for testing.


In [9]:
from transformers import BertTokenizer

def test_model(model, tokenizer, vocab, optimal_length):

    model.eval()  
    
    while True:
        print("\nEnter two sentences for NSP prediction (or type 'exit' to quit):")
        first_sentence = input("First sentence: ")
        if first_sentence.lower() == "exit":
            break
        second_sentence = input("Second sentence: ")
        if second_sentence.lower() == "exit":
            break

        first_tokens = tokenizer.tokenize(first_sentence)
        second_tokens = tokenizer.tokenize(second_sentence)
        input_tokens = [tokenizer.cls_token] + first_tokens + [tokenizer.sep_token] + second_tokens + [tokenizer.sep_token]

        input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
        attention_mask = [1] * len(input_ids)
        
        padding_length = max(0, optimal_length - len(input_ids))
        input_ids += [vocab['[PAD]']] * padding_length
        attention_mask += [0] * padding_length

        input_ids = torch.tensor(input_ids).unsqueeze(0).to(device)
        attention_mask = torch.tensor(attention_mask).unsqueeze(0).to(device)

        with torch.no_grad():
            token_output, nsp_output = model(input_ids, attention_mask)

        nsp_prediction = nsp_output.argmax(-1).item()
        nsp_result = "Next Sentence" if nsp_prediction == 1 else "Not Next Sentence"

        print(f"NSP Result: {nsp_result}")

        if '[MASK]' in input_tokens:
            predicted_tokens = tokenizer.convert_ids_to_tokens(token_output[0].argmax(-1).tolist())
            masked_sentence = " ".join(predicted_tokens).replace("[PAD]", "").strip()
            print(f"MLM Prediction: {masked_sentence}")


In [10]:
# Load tokenizer, vocab, and trained model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
vocab = tokenizer.get_vocab()

# Test with user input
test_model(bert, tokenizer, vocab, ds.optimal_sentence_length)





Enter two sentences for NSP prediction (or type 'exit' to quit):


First sentence:  What are some practical steps I can take to improve my overall health and well-being?
Second sentence:  What are some effective strategies for incorporating regular exercise into my daily routine?


NSP Result: Next Sentence

Enter two sentences for NSP prediction (or type 'exit' to quit):


First sentence:  exit
