In [1]:
!pip install transformers datasets tokenizers
!wget http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip
!unzip -qq cornell_movie_dialogs_corpus.zip
!rm cornell_movie_dialogs_corpus.zip
!mkdir datasets
!mv cornell\ movie-dialogs\ corpus/movie_conversations.txt ./datasets
!mv cornell\ movie-dialogs\ corpus/movie_lines.txt ./datasets

import os
from pathlib import Path
import torch
import re
import random
import transformers, datasets
from tokenizers import BertWordPieceTokenizer
from transformers import BertTokenizer
import tqdm
from torch.utils.data import Dataset, DataLoader
import itertools
import math
import torch.nn.functional as F
import numpy as np
from torch.optim import Adam

--2025-03-28 11:40:29--  http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip
Resolving www.cs.cornell.edu (www.cs.cornell.edu)... 132.236.207.53
Connecting to www.cs.cornell.edu (www.cs.cornell.edu)|132.236.207.53|:80... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip [following]
--2025-03-28 11:40:30--  https://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip
Connecting to www.cs.cornell.edu (www.cs.cornell.edu)|132.236.207.53|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 9916637 (9.5M) [application/zip]
Saving to: ‘cornell_movie_dialogs_corpus.zip’


2025-03-28 11:40:33 (5.62 MB/s) - ‘cornell_movie_dialogs_corpus.zip’ saved [9916637/9916637]



In [2]:
from collections import defaultdict
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

In [3]:
MAX_LEN = 64

### loading all data into memory
corpus_movie_conv = './datasets/movie_conversations.txt'
corpus_movie_lines = './datasets/movie_lines.txt'
with open(corpus_movie_conv, 'r', encoding='iso-8859-1') as c:
    conv = c.readlines()
with open(corpus_movie_lines, 'r', encoding='iso-8859-1') as l:
    lines = l.readlines()

### splitting text using special lines
lines_dic = {}
for line in lines:
    objects = line.split(" +++$+++ ")
    lines_dic[objects[0]] = objects[-1]

### generate question answer pairs
pairs = []
for con in conv:
    ids = eval(con.split(" +++$+++ ")[-1])
    for i in range(len(ids)):
        qa_pairs = []

        if i == len(ids) - 1:
            break

        first = lines_dic[ids[i]].strip()
        second = lines_dic[ids[i+1]].strip()

        qa_pairs.append(' '.join(first.split()[:MAX_LEN]))
        qa_pairs.append(' '.join(second.split()[:MAX_LEN]))
        pairs.append(qa_pairs)

# sample
from sklearn.model_selection import train_test_split
train_pairs, val_pairs = train_test_split(pairs,test_size = 0.1, random_state = 42)
print(train_pairs[20])
print(val_pairs[20])

["What difference does it make if it's true? It's a <u>story</u>, and, it <u>breaks</u> they're gonna have to <u>run</u> with it -- How long've we got til it breaks?", 'Front page. Washington Post. Tomorrow.']
["Yeah. Everybody's talkin' about it. They're makin' a big deal out of it.", 'I know.']


In [4]:
os.mkdir('./data')
text_data = []
file_count = 0

for sample in tqdm.tqdm([x[0] for x in pairs]):
    text_data.append(sample)

    # once we hit the 10K mark, save to file
    if len(text_data) == 10000:
        with open(f'./data/text_{file_count}.txt', 'w', encoding='utf-8') as fp:
            fp.write('\n'.join(text_data))
        text_data = []
        file_count += 1

paths = [str(x) for x in Path('./data').glob('**/*.txt')]
print(len(paths))



100%|██████████| 221616/221616 [00:00<00:00, 1700881.39it/s]

22





In [5]:
### training own tokenizer
tokenizer = BertWordPieceTokenizer(
    clean_text=True,
    handle_chinese_chars=False,
    strip_accents=False,
    lowercase=True
)

tokenizer.train(
    files=paths,
    vocab_size=30_000,
    min_frequency=5,
    limit_alphabet=1000,
    wordpieces_prefix='##',
    special_tokens=['[PAD]', '[CLS]', '[SEP]', '[MASK]', '[UNK]']
    )

os.mkdir('./bert-it-1')
tokenizer.save_model('./bert-it-1', 'bert-it')
tokenizer = BertTokenizer.from_pretrained('./bert-it-1/bert-it-vocab.txt', local_files_only=True)
token_ids = tokenizer('I like surfboarding!')['input_ids']
print(token_ids)
print(tokenizer.convert_ids_to_tokens(token_ids))

[1, 48, 250, 4033, 3588, 154, 5, 2]
['[CLS]', 'i', 'like', 'surf', '##board', '##ing', '!', '[SEP]']




In [6]:
class BERTDataset(Dataset):
    def __init__(self, data_pair, tokenizer, seq_len=64):

        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.corpus_lines = len(data_pair)
        self.lines = data_pair

    def __len__(self):
        return self.corpus_lines

    def __getitem__(self, item):

        # Step 1: get random sentence pair, either negative or positive (saved as is_next_label)
        t1, t2, is_next_label = self.get_sent(item)

        # Step 2: replace random words in sentence with mask / random words
        t1_random, t1_label = self.random_word(t1)
        t2_random, t2_label = self.random_word(t2)

        # Step 3: Adding CLS and SEP tokens to the start and end of sentences
        # Adding PAD token for labels
        t1 = [self.tokenizer.vocab['[CLS]']] + t1_random + [self.tokenizer.vocab['[SEP]']]
        t2 = t2_random + [self.tokenizer.vocab['[SEP]']]
        t1_label = [self.tokenizer.vocab['[PAD]']] + t1_label + [self.tokenizer.vocab['[PAD]']]
        t2_label = t2_label + [self.tokenizer.vocab['[PAD]']]

        # Step 4: combine sentence 1 and 2 as one input
        # adding PAD tokens to make the sentence same length as seq_len
        segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:self.seq_len]
        bert_input = (t1 + t2)[:self.seq_len]
        bert_label = (t1_label + t2_label)[:self.seq_len]
        padding = [self.tokenizer.vocab['[PAD]'] for _ in range(self.seq_len - len(bert_input))]
        bert_input.extend(padding), bert_label.extend(padding), segment_label.extend(padding)
        output = {"bert_input": bert_input,
                  "bert_label": bert_label,
                  "segment_label": segment_label,
                  "is_next": is_next_label}

        return {key: torch.tensor(value) for key, value in output.items()}

    def random_word(self, sentence):
        tokens = sentence.split()
        output_label = []
        output = []

        # 15% of the tokens would be replaced
        for i, token in enumerate(tokens):
            prob = random.random()

            # remove cls and sep token
            token_id = self.tokenizer(token)['input_ids'][1:-1]

            # 15% chance of altering token
            if prob < 0.15:
                prob /= 0.15

                # 80% chance change token to mask token
                if prob < 0.8:
                    for i in range(len(token_id)):
                        output.append(self.tokenizer.vocab['[MASK]'])

                # 10% chance change token to random token
                elif prob < 0.9:
                    for i in range(len(token_id)):
                        output.append(random.randrange(len(self.tokenizer.vocab)))

                else:
                    output.append(token_id)

                output_label.append(token_id)

            else:
                output.append(token_id)
                for i in range(len(token_id)):
                    output_label.append(0)

        # flattening
        output = list(itertools.chain(*[[x] if not isinstance(x, list) else x for x in output]))
        output_label = list(itertools.chain(*[[x] if not isinstance(x, list) else x for x in output_label]))
        assert len(output) == len(output_label)
        return output, output_label

    def get_sent(self, index):
        '''return random sentence pair'''
        t1, t2 = self.get_corpus_line(index)

        # negative or positive pair, for next sentence prediction
        if random.random() > 0.5:
            return t1, t2, 1
        else:
            return t1, self.get_random_line(), 0

    def get_corpus_line(self, item):
        '''return sentence pair'''
        return self.lines[item][0], self.lines[item][1]

    def get_random_line(self):
        '''return random single sentence'''
        return self.lines[random.randrange(len(self.lines))][1]

In [7]:
print("\n")
train_data = BERTDataset(pairs, seq_len=MAX_LEN, tokenizer=tokenizer)
val_data = BERTDataset(val_pairs, seq_len=MAX_LEN, tokenizer=tokenizer)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True, pin_memory=True)
val_loader = DataLoader(train_data, batch_size=32, shuffle=True, pin_memory=True)
sample_data = next(iter(train_loader))
print('Batch Size', sample_data['bert_input'].size())

# 3 is MASK
result = train_data[random.randrange(len(train_data))]
result



Batch Size torch.Size([32, 64])


{'bert_input': tensor([    1,   192, 11389,  1041, 12089,    15,   211,     3,  4404,    15,
             3,   150,  8736,  6110, 13755,   408,  1847,  7603,    17,     2,
            48,   230,    11,    59,     3,   153,   432, 20042,    17,   934,
           870, 12758,   256,   173,    40,  6110,   465,     3,  3047,    15,
           464,  1669,   153, 14331,     3,     3,   179,   400,   162,    11,
            58,     3,   237,   422,    15,   422,    15,   422,    15,   248,
           150,   368,    17,     2]),
 'bert_label': tensor([   0,    0,  150,    0,    0,    0,    0, 1341,    0,    0,  179,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
         3219,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,  994,    0,    0,    0,    0,    0,    0,  146,   15,    0,    0,
            0,    0,    0,  538,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0]),
 'segment_label': t

In [8]:
class PositionalEmbedding(torch.nn.Module):
    def __init__(self, d_model, max_len=128):
        super().__init__()

        
        pe = torch.zeros(max_len, d_model).float()
        pe.requires_grad = False  

        for pos in range(max_len):  # Iterate over all positions
            for i in range(0, d_model, 2):  # Iterate over dimensions (step = 2)
                pe[pos, i] = math.sin(pos / (10000 ** ((2 * i) / d_model)))
                if i + 1 < d_model:
                    pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1)) / d_model)))

        self.pe = pe.unsqueeze(0)  # Shape becomes (1, max_len, d_model)

    def forward(self, x):
        # Slice the positional encodings to match input sequence length
        return self.pe[:, :x.size(1), :].to(x.device)

class BERTEmbedding(torch.nn.Module):
    def __init__(self, vocab_size, embed_size, seq_len=64, dropout=0.1):
        super().__init__()

        self.embed_size = embed_size

        # Token embedding: Converts each token ID into an embedding vector
        self.token = torch.nn.Embedding(vocab_size, embed_size, padding_idx=0)

        # Segment embedding: Indicates which sentence a token belongs to
        self.segment = torch.nn.Embedding(3, embed_size, padding_idx=0)

        # Positional embedding: Adds order information
        self.position = PositionalEmbedding(d_model=embed_size, max_len=seq_len)

        # Dropout layer to prevent overfitting
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, sequence, segment_label):
        """
        :param sequence: tokenized input (batch_size, seq_len)
        :param segment_label: segment IDs (batch_size, seq_len)
        :return: Embedded representation of shape (batch_size, seq_len, embed_size)
        """
        x = self.token(sequence)  # Token embeddings (batch_size, seq_len, embed_size)
        x += self.position(sequence)  # Add positional encodings
        x += self.segment(segment_label)  # Add segment embeddings
        return self.dropout(x)  # Apply dropout


### testing
embed_layer = BERTEmbedding(vocab_size=len(tokenizer.vocab), embed_size=768, seq_len=MAX_LEN)
embed_result = embed_layer(sample_data['bert_input'], sample_data['segment_label'])
print(embed_result.size())

torch.Size([32, 64, 768])


In [9]:
class MultiHeadedAttention(torch.nn.Module):

    def __init__(self, heads, d_model, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()

        assert d_model % heads == 0
        self.d_k = d_model // heads
        self.heads = heads
        self.dropout = torch.nn.Dropout(dropout)

        self.query = torch.nn.Linear(d_model, d_model)
        self.key = torch.nn.Linear(d_model, d_model)
        self.value = torch.nn.Linear(d_model, d_model)
        self.output_linear = torch.nn.Linear(d_model, d_model)
        
    def forward(self, query, key, value, mask):
        """
        query, key, value of shape: (batch_size, max_len, d_model)
        mask of shape: (batch_size, 1, 1, max_words)
        """
        # (batch_size, max_len, d_model)
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # (batch_size, max_len, d_model) --> (batch_size, max_len, h, d_k) --> (batch_size, h, max_len, d_k)
        query = query.view(query.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)
        key = key.view(key.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)
        value = value.view(value.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)

        # (batch_size, h, max_len, d_k) matmul (batch_size, h, d_k, max_len) --> (batch_size, h, max_len, max_len)
        scores = torch.matmul(query, key.permute(0, 1, 3, 2)) / math.sqrt(query.size(-1))

        # fill 0 mask with super small number so it wont affect the softmax weight
        # (batch_size, h, max_len, max_len)
        scores = scores.masked_fill(mask == 0, -1e9)

        # (batch_size, h, max_len, max_len)
        # softmax to put attention weight for all non-pad tokens
        # max_len X max_len matrix of attention
        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)

        # (batch_size, h, max_len, max_len) matmul (batch_size, h, max_len, d_k) --> (batch_size, h, max_len, d_k)
        context = torch.matmul(weights, value)

        # (batch_size, h, max_len, d_k) --> (batch_size, max_len, h, d_k) --> (batch_size, max_len, d_model)
        context = context.permute(0, 2, 1, 3).contiguous().view(context.shape[0], -1, self.heads * self.d_k)

        # (batch_size, max_len, d_model)
        return self.output_linear(context)
        
class FeedForward(torch.nn.Module):
    "Implements FFN equation"

    def __init__(self, d_model, middle_dim=2048, dropout=0.1):
        super(FeedForward, self).__init__()

        self.fc1 = torch.nn.Linear(d_model, middle_dim)
        self.fc2 = torch.nn.Linear(middle_dim, d_model)
        self.dropout = torch.nn.Dropout(dropout)
        self.activation = torch.nn.GELU()

    def forward(self, x):
        out = self.activation(self.fc1(x))
        out = self.fc2(self.dropout(out))
        return out

class EncoderLayer(torch.nn.Module):
    def __init__(
        self,
        d_model=768,
        heads=12,
        feed_forward_hidden=768 * 4,
        dropout=0.1
        ):
        super(EncoderLayer, self).__init__()
        self.layernorm = torch.nn.LayerNorm(d_model)
        self.self_multihead = MultiHeadedAttention(heads, d_model)
        self.feed_forward = FeedForward(d_model, middle_dim=feed_forward_hidden)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, embeddings, mask):
        # embeddings: (batch_size, max_len, d_model)
        # encoder mask: (batch_size, 1, 1, max_len)
        # result: (batch_size, max_len, d_model)
        interacted = self.dropout(self.self_multihead(embeddings, embeddings, embeddings, mask))
        # residual layer
        interacted = self.layernorm(interacted + embeddings)
        # bottleneck
        feed_forward_out = self.dropout(self.feed_forward(interacted))
        encoded = self.layernorm(feed_forward_out + interacted)
        return encoded

### testing
mask = (sample_data['bert_input'] > 0).unsqueeze(1).repeat(1, sample_data['bert_input'].size(1), 1).unsqueeze(1)
transformer_block = EncoderLayer()
transformer_result = transformer_block(embed_result, mask)
transformer_result.size()

torch.Size([32, 64, 768])

In [10]:
class BERT(torch.nn.Module):
    """
    BERT model : Bidirectional Encoder Representations from Transformers.
    """

    def __init__(self, vocab_size, d_model=768, n_layers=12, heads=12, dropout=0.1):
        """
        :param vocab_size: vocab_size of total words
        :param hidden: BERT model hidden size
        :param n_layers: numbers of Transformer blocks(layers)
        :param attn_heads: number of attention heads
        :param dropout: dropout rate
        """

        super().__init__()
        self.d_model = d_model
        self.n_layers = n_layers
        self.heads = heads

        # paper noted they used 4*hidden_size for ff_network_hidden_size
        self.feed_forward_hidden = d_model * 4

        # embedding for BERT, sum of positional, segment, token embeddings
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=d_model)

        # multi-layers transformer blocks, deep network
        self.encoder_blocks = torch.nn.ModuleList(
            [EncoderLayer(d_model, heads, d_model * 4, dropout) for _ in range(n_layers)])

    def forward(self, x, segment_info):
        # attention masking for padded token
        # (batch_size, 1, seq_len, seq_len)
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)

        # embedding the indexed sequence to sequence of vectors
        x = self.embedding(x, segment_info)

        # running over multiple transformer blocks
        for encoder in self.encoder_blocks:
            x = encoder.forward(x, mask)
        return x

class NextSentencePrediction(torch.nn.Module):
    """
    2-class classification model : is_next, is_not_next
    """

    def __init__(self, hidden):
        """
        :param hidden: BERT model output size
        """
        super().__init__()
        self.linear = torch.nn.Linear(hidden, 2)
        self.softmax = torch.nn.LogSoftmax(dim=-1)

    def forward(self, x):
        # use only the first token which is the [CLS]
        return self.softmax(self.linear(x[:, 0]))

class MaskedLanguageModel(torch.nn.Module):
    """
    predicting origin token from masked input sequence
    n-class classification problem, n-class = vocab_size
    """

    def __init__(self, hidden, vocab_size):
        """
        :param hidden: output size of BERT model
        :param vocab_size: total vocab size
        """
        super().__init__()
        self.linear = torch.nn.Linear(hidden, vocab_size)
        self.softmax = torch.nn.LogSoftmax(dim=-1)

    def forward(self, x):
        return self.softmax(self.linear(x))

class BERTLM(torch.nn.Module):
    """
    BERT Language Model
    Next Sentence Prediction Model + Masked Language Model
    """

    def __init__(self, bert: BERT, vocab_size):
        """
        :param bert: BERT model which should be trained
        :param vocab_size: total vocab size for masked_lm
        """

        super().__init__()
        self.bert = bert
        self.next_sentence = NextSentencePrediction(self.bert.d_model)
        self.mask_lm = MaskedLanguageModel(self.bert.d_model, vocab_size)

    def forward(self, x, segment_label):
        x = self.bert(x, segment_label)
        return self.next_sentence(x), self.mask_lm(x)

### test
bert_model = BERT(len(tokenizer.vocab))
bert_result = bert_model(sample_data['bert_input'], sample_data['segment_label'])
print(bert_result.size())

bert_lm = BERTLM(bert_model, len(tokenizer.vocab))
final_result = bert_lm(sample_data['bert_input'], sample_data['segment_label'])
print(final_result[0].size(), final_result[1].size())

torch.Size([32, 64, 768])
torch.Size([32, 2]) torch.Size([32, 64, 21159])


In [11]:
class ScheduledOptim():
    '''A simple wrapper class for learning rate scheduling'''

    def __init__(self, optimizer, d_model, n_warmup_steps):
        self._optimizer = optimizer
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = 0
        self.init_lr = np.power(d_model, -0.5)

    def step_and_update_lr(self):
        "Step with the inner optimizer"
        self._update_learning_rate()
        self._optimizer.step()

    def zero_grad(self):
        "Zero out the gradients by the inner optimizer"
        self._optimizer.zero_grad()

    def _get_lr_scale(self):
        return np.min([
            np.power(self.n_current_steps, -0.5),
            np.power(self.n_warmup_steps, -1.5) * self.n_current_steps])

    def _update_learning_rate(self):
        ''' Learning rate scheduling per step '''

        self.n_current_steps += 1
        lr = self.init_lr * self._get_lr_scale()

        for param_group in self._optimizer.param_groups:
            param_group['lr'] = lr

In [12]:
class BERTTrainer:
    def __init__(
        self,
        model,
        train_dataloader,
        val_dataloader=None,
        lr= 1e-4,
        weight_decay=0.01,
        betas=(0.9, 0.999),
        warmup_steps=10000,
        log_freq=1000,
        device='cuda',
        save_path = "bert_pretrained.pt "
        ):

        self.device = device
        self.model = model.to(self.device)
        self.train_data = train_dataloader
        self.val_data = val_dataloader
        self.save_path = save_path

        # Setting the Adam optimizer with hyper-param
        self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
        self.optim_schedule = ScheduledOptim(
            self.optim, self.model.bert.d_model, n_warmup_steps=warmup_steps
            )

        # Using Negative Log Likelihood Loss function for predicting the masked_token
        self.criterion = torch.nn.NLLLoss(ignore_index=0)
        self.log_freq = log_freq
        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))

    def train(self, epoch):
        self.iteration(epoch, self.train_data)
        self.save_model(epoch)

    def test(self, epoch):
        self.iteration(epoch, self.test_data, train=False)

    def iteration(self, epoch, data_loader, train=True):

        avg_loss = 0.0
        total_correct = 0
        total_element = 0

        mode = "train" if train else "test"

        data_iter = tqdm.tqdm(
            enumerate(data_loader),
            desc=f"EP_{mode}:{epoch}",
            total=len(data_loader),
            bar_format="{l_bar}{r_bar}"
        )

        for i, data in data_iter:

            # 0. batch_data will be sent into the device(GPU or cpu)
            data = {key: value.to(self.device) for key, value in data.items()}

            # 1. forward the next_sentence_prediction and masked_lm model
            next_sent_output, mask_lm_output = self.model.forward(data["bert_input"].to(self.device), data["segment_label"].to(self.device))

            # 2-1. NLL(negative log likelihood) loss of is_next classification result
            next_loss = self.criterion(next_sent_output, data["is_next"])

            # 2-2. NLLLoss of predicting masked token word
            # transpose to (m, vocab_size, seq_len) vs (m, seq_len)
            # criterion(mask_lm_output.view(-1, mask_lm_output.size(-1)), data["bert_label"].view(-1))
            mask_loss = self.criterion(mask_lm_output.transpose(1, 2), data["bert_label"])

            # 2-3. Adding next_loss and mask_loss : 3.4 Pre-training Procedure
            loss = next_loss + mask_loss

            if train:
                self.optim_schedule.zero_grad()
                loss.backward()
                self.optim_schedule.step_and_update_lr()

            # next sentence prediction accuracy
            correct = next_sent_output.argmax(dim=-1).eq(data["is_next"]).sum().item()
            avg_loss += loss.item()
            total_correct += correct
            total_element += data["is_next"].nelement()

            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": avg_loss / (i + 1),
                "avg_acc": total_correct / total_element * 100,
                "loss": loss.item()
            }

            if i % self.log_freq == 0:
                data_iter.write(str(post_fix))
        print(
            f"EP{epoch}, {mode}: \
            avg_loss={avg_loss / len(data_iter)}, \
            total_acc={total_correct * 100.0 / total_element}"
        )
    def save_model(self, epoch):
        """Save the model and optimizer state."""
        save_dict = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optim.state_dict(),
        }
        torch.save(save_dict, self.save_path)
        print(f"Model saved at {self.save_path}")

### test
train_data = BERTDataset(train_pairs, seq_len=MAX_LEN, tokenizer=tokenizer)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True, pin_memory=True)
bert_model = BERT(len(tokenizer.vocab))
bert_lm = BERTLM(bert_model, len(tokenizer.vocab))
device = "cuda" if torch.cuda.is_available() else "cpu"
bert_trainer = BERTTrainer(bert_lm, train_loader, device=device)
epochs = 10

for epoch in range(epochs):
    bert_trainer.train(epoch)


    

Total Parameters: 117561257


EP_train:0:   0%|| 1/6233 [00:00<1:19:50,  1.30it/s]

{'epoch': 0, 'iter': 0, 'avg_loss': 11.375048637390137, 'avg_acc': 34.375, 'loss': 11.375048637390137}


EP_train:0:  16%|| 1001/6233 [05:14<26:59,  3.23it/s]

{'epoch': 0, 'iter': 1000, 'avg_loss': 7.132251117374752, 'avg_acc': 50.10614385614386, 'loss': 5.945916175842285}


EP_train:0:  32%|| 2001/6233 [10:28<21:45,  3.24it/s]

{'epoch': 0, 'iter': 2000, 'avg_loss': 6.487824701893515, 'avg_acc': 50.0734007996002, 'loss': 5.502036094665527}


EP_train:0:  48%|| 3001/6233 [15:43<17:33,  3.07it/s]

{'epoch': 0, 'iter': 3000, 'avg_loss': 6.254598715590541, 'avg_acc': 50.09580139953349, 'loss': 6.026843547821045}


EP_train:0:  64%|| 4001/6233 [20:59<11:51,  3.14it/s]

{'epoch': 0, 'iter': 4000, 'avg_loss': 6.133127266274366, 'avg_acc': 50.051549612596844, 'loss': 6.039668083190918}


EP_train:0:  80%|| 5001/6233 [26:15<06:31,  3.15it/s]

{'epoch': 0, 'iter': 5000, 'avg_loss': 6.056759302817781, 'avg_acc': 50.040616876624675, 'loss': 5.612262725830078}


EP_train:0:  96%|| 6001/6233 [31:30<01:12,  3.20it/s]

{'epoch': 0, 'iter': 6000, 'avg_loss': 6.009902024265131, 'avg_acc': 50.06040659890019, 'loss': 5.717437267303467}


EP_train:0: 100%|| 6233/6233 [32:44<00:00,  3.17it/s]


EP0, train:             avg_loss=6.002341919796015,             total_acc=50.05515056103162
Model saved at bert_pretrained.pt 


EP_train:1:   0%|| 1/6233 [00:00<33:49,  3.07it/s]

{'epoch': 1, 'iter': 0, 'avg_loss': 5.655872821807861, 'avg_acc': 46.875, 'loss': 5.655872821807861}


EP_train:1:  16%|| 1001/6233 [05:14<26:56,  3.24it/s]

{'epoch': 1, 'iter': 1000, 'avg_loss': 5.801678434594885, 'avg_acc': 49.85014985014985, 'loss': 5.480400562286377}


EP_train:1:  32%|| 2001/6233 [10:27<21:51,  3.23it/s]

{'epoch': 1, 'iter': 2000, 'avg_loss': 5.7837699823889475, 'avg_acc': 49.77355072463768, 'loss': 5.523154258728027}


EP_train:1:  48%|| 3001/6233 [15:41<17:14,  3.13it/s]

{'epoch': 1, 'iter': 3000, 'avg_loss': 5.751955590221096, 'avg_acc': 49.777157614128626, 'loss': 5.872507572174072}


EP_train:1:  64%|| 4001/6233 [20:54<11:56,  3.11it/s]

{'epoch': 1, 'iter': 4000, 'avg_loss': 5.724056673300204, 'avg_acc': 49.810984753811546, 'loss': 6.324580669403076}


EP_train:1:  80%|| 5001/6233 [26:08<06:36,  3.11it/s]

{'epoch': 1, 'iter': 5000, 'avg_loss': 5.7010146210465855, 'avg_acc': 49.84753049390122, 'loss': 5.41267204284668}


EP_train:1:  96%|| 6001/6233 [31:23<01:12,  3.21it/s]

{'epoch': 1, 'iter': 6000, 'avg_loss': 5.677674365587938, 'avg_acc': 49.938551908015334, 'loss': 5.438734531402588}


EP_train:1: 100%|| 6233/6233 [32:36<00:00,  3.19it/s]


EP1, train:             avg_loss=5.674316647078393,             total_acc=49.96841376959098
Model saved at bert_pretrained.pt 


EP_train:2:   0%|| 1/6233 [00:00<35:41,  2.91it/s]

{'epoch': 2, 'iter': 0, 'avg_loss': 5.531171798706055, 'avg_acc': 34.375, 'loss': 5.531171798706055}


EP_train:2:  16%|| 1001/6233 [05:13<27:59,  3.11it/s]

{'epoch': 2, 'iter': 1000, 'avg_loss': 5.516070365905762, 'avg_acc': 49.93131868131868, 'loss': 5.493192672729492}


EP_train:2:  32%|| 2001/6233 [10:27<22:45,  3.10it/s]

{'epoch': 2, 'iter': 2000, 'avg_loss': 5.516763420000129, 'avg_acc': 49.85163668165917, 'loss': 5.749901294708252}


EP_train:2:  48%|| 3001/6233 [15:40<17:12,  3.13it/s]

{'epoch': 2, 'iter': 3000, 'avg_loss': 5.505030534776677, 'avg_acc': 49.862545818060646, 'loss': 5.478667259216309}


EP_train:2:  64%|| 4001/6233 [20:53<11:39,  3.19it/s]

{'epoch': 2, 'iter': 4000, 'avg_loss': 5.50099536902426, 'avg_acc': 50.00390527368158, 'loss': 5.8321685791015625}


EP_train:2:  80%|| 5001/6233 [26:07<06:28,  3.17it/s]

{'epoch': 2, 'iter': 5000, 'avg_loss': 5.492196642096294, 'avg_acc': 50.02062087582484, 'loss': 5.324150562286377}


EP_train:2:  96%|| 6001/6233 [31:20<01:15,  3.07it/s]

{'epoch': 2, 'iter': 6000, 'avg_loss': 5.48365729253623, 'avg_acc': 50.01614314280953, 'loss': 5.9090142250061035}


EP_train:2: 100%|| 6233/6233 [32:33<00:00,  3.19it/s]


EP2, train:             avg_loss=5.480744227845883,             total_acc=50.036599917775526
Model saved at bert_pretrained.pt 


EP_train:3:   0%|| 1/6233 [00:00<34:54,  2.98it/s]

{'epoch': 3, 'iter': 0, 'avg_loss': 5.4354729652404785, 'avg_acc': 40.625, 'loss': 5.4354729652404785}


EP_train:3:  16%|| 1001/6233 [05:14<26:47,  3.25it/s]

{'epoch': 3, 'iter': 1000, 'avg_loss': 5.399745056083748, 'avg_acc': 49.71278721278721, 'loss': 5.334134578704834}


EP_train:3:  32%|| 2001/6233 [10:28<22:08,  3.18it/s]

{'epoch': 3, 'iter': 2000, 'avg_loss': 5.405172184310753, 'avg_acc': 49.959395302348824, 'loss': 5.112679958343506}


EP_train:3:  48%|| 3001/6233 [15:42<16:40,  3.23it/s]

{'epoch': 3, 'iter': 3000, 'avg_loss': 5.404750120715275, 'avg_acc': 50.09580139953349, 'loss': 5.276096820831299}


EP_train:3:  64%|| 4001/6233 [20:56<11:27,  3.25it/s]

{'epoch': 3, 'iter': 4000, 'avg_loss': 5.396503540373718, 'avg_acc': 50.10231817045738, 'loss': 5.703319549560547}


EP_train:3:  80%|| 5001/6233 [26:09<06:22,  3.22it/s]

{'epoch': 3, 'iter': 5000, 'avg_loss': 5.388595895036843, 'avg_acc': 50.10122975404919, 'loss': 5.7265753746032715}


EP_train:3:  96%|| 6001/6233 [31:23<01:12,  3.21it/s]

{'epoch': 3, 'iter': 6000, 'avg_loss': 5.384401605161423, 'avg_acc': 50.08071571404766, 'loss': 5.649881839752197}


EP_train:3: 100%|| 6233/6233 [32:36<00:00,  3.19it/s]


EP3, train:             avg_loss=5.383055643344526,             total_acc=50.069188885657844
Model saved at bert_pretrained.pt 


EP_train:4:   0%|| 1/6233 [00:00<35:42,  2.91it/s]

{'epoch': 4, 'iter': 0, 'avg_loss': 5.146773338317871, 'avg_acc': 43.75, 'loss': 5.146773338317871}


EP_train:4:  16%|| 1001/6233 [05:14<26:59,  3.23it/s]

{'epoch': 4, 'iter': 1000, 'avg_loss': 5.345130040571764, 'avg_acc': 50.36525974025974, 'loss': 5.692777633666992}


EP_train:4:  32%|| 2001/6233 [10:27<22:10,  3.18it/s]

{'epoch': 4, 'iter': 2000, 'avg_loss': 5.343270956903979, 'avg_acc': 50.1358695652174, 'loss': 5.236797332763672}


EP_train:4:  48%|| 3001/6233 [15:40<16:43,  3.22it/s]

{'epoch': 4, 'iter': 3000, 'avg_loss': 5.340750195351651, 'avg_acc': 50.12704098633789, 'loss': 5.123149394989014}


EP_train:4:  64%|| 4001/6233 [20:54<11:35,  3.21it/s]

{'epoch': 4, 'iter': 4000, 'avg_loss': 5.336971686858292, 'avg_acc': 50.047644338915276, 'loss': 4.996855735778809}


EP_train:4:  80%|| 5001/6233 [26:07<06:35,  3.12it/s]

{'epoch': 4, 'iter': 5000, 'avg_loss': 5.330730498969328, 'avg_acc': 49.93001399720056, 'loss': 5.160671234130859}


EP_train:4:  96%|| 6001/6233 [31:20<01:13,  3.18it/s]

{'epoch': 4, 'iter': 6000, 'avg_loss': 5.3261842770569325, 'avg_acc': 49.98073237793701, 'loss': 5.162665367126465}


EP_train:4: 100%|| 6233/6233 [32:32<00:00,  3.19it/s]


EP4, train:             avg_loss=5.326214358796283,             total_acc=49.96139460727787
Model saved at bert_pretrained.pt 


EP_train:5:   0%|| 1/6233 [00:00<38:06,  2.73it/s]

{'epoch': 5, 'iter': 0, 'avg_loss': 5.02604866027832, 'avg_acc': 53.125, 'loss': 5.02604866027832}


EP_train:5:  16%|| 1001/6233 [05:13<28:12,  3.09it/s]

{'epoch': 5, 'iter': 1000, 'avg_loss': 5.299585878789484, 'avg_acc': 49.294455544455545, 'loss': 4.825554847717285}


EP_train:5:  32%|| 2001/6233 [10:26<22:20,  3.16it/s]

{'epoch': 5, 'iter': 2000, 'avg_loss': 5.295010340565267, 'avg_acc': 49.86256871564218, 'loss': 5.204926490783691}


EP_train:5:  48%|| 3001/6233 [15:40<17:00,  3.17it/s]

{'epoch': 5, 'iter': 3000, 'avg_loss': 5.293472433995898, 'avg_acc': 49.97500833055648, 'loss': 5.508352756500244}


EP_train:5:  64%|| 4001/6233 [20:53<11:41,  3.18it/s]

{'epoch': 5, 'iter': 4000, 'avg_loss': 5.292860332652051, 'avg_acc': 49.99609472631842, 'loss': 5.13086462020874}


EP_train:5:  80%|| 5001/6233 [26:07<06:22,  3.22it/s]

{'epoch': 5, 'iter': 5000, 'avg_loss': 5.292273519993114, 'avg_acc': 50.07186062787442, 'loss': 4.764576435089111}


EP_train:5:  96%|| 6001/6233 [31:20<01:12,  3.22it/s]

{'epoch': 5, 'iter': 6000, 'avg_loss': 5.28947074745838, 'avg_acc': 50.041138976837196, 'loss': 5.074993133544922}


EP_train:5: 100%|| 6233/6233 [32:33<00:00,  3.19it/s]


EP5, train:             avg_loss=5.288258502365937,             total_acc=50.06216972334473
Model saved at bert_pretrained.pt 


EP_train:6:   0%|| 1/6233 [00:00<34:43,  2.99it/s]

{'epoch': 6, 'iter': 0, 'avg_loss': 4.716089248657227, 'avg_acc': 56.25, 'loss': 4.716089248657227}


EP_train:6:  16%|| 1001/6233 [05:12<27:28,  3.17it/s]

{'epoch': 6, 'iter': 1000, 'avg_loss': 5.249583516325746, 'avg_acc': 50.134240759240754, 'loss': 5.20048713684082}


EP_train:6:  32%|| 2001/6233 [10:24<22:06,  3.19it/s]

{'epoch': 6, 'iter': 2000, 'avg_loss': 5.236154454401408, 'avg_acc': 50.14836331834083, 'loss': 5.5730085372924805}


EP_train:6:  48%|| 3001/6233 [15:38<17:05,  3.15it/s]

{'epoch': 6, 'iter': 3000, 'avg_loss': 5.233455294968485, 'avg_acc': 50.15099133622126, 'loss': 5.04054594039917}


EP_train:6:  64%|| 4001/6233 [20:52<11:43,  3.17it/s]

{'epoch': 6, 'iter': 4000, 'avg_loss': 5.236323529855813, 'avg_acc': 50.02499375156211, 'loss': 5.0901970863342285}


EP_train:6:  80%|| 5001/6233 [26:05<06:24,  3.21it/s]

{'epoch': 6, 'iter': 5000, 'avg_loss': 5.235412043634593, 'avg_acc': 50.031243751249754, 'loss': 5.813170433044434}


EP_train:6:  96%|| 6001/6233 [31:19<01:13,  3.15it/s]

{'epoch': 6, 'iter': 6000, 'avg_loss': 5.233428472241766, 'avg_acc': 49.99531328111981, 'loss': 4.633340358734131}


EP_train:6: 100%|| 6233/6233 [32:32<00:00,  3.19it/s]


EP6, train:             avg_loss=5.2322719446751735,             total_acc=50.00852326852307
Model saved at bert_pretrained.pt 


EP_train:7:   0%|| 1/6233 [00:00<33:55,  3.06it/s]

{'epoch': 7, 'iter': 0, 'avg_loss': 4.847684383392334, 'avg_acc': 59.375, 'loss': 4.847684383392334}


EP_train:7:  16%|| 1001/6233 [05:14<27:22,  3.19it/s]

{'epoch': 7, 'iter': 1000, 'avg_loss': 5.224219123562137, 'avg_acc': 50.42145354645354, 'loss': 4.9453125}


EP_train:7:  32%|| 2001/6233 [10:27<22:02,  3.20it/s]

{'epoch': 7, 'iter': 2000, 'avg_loss': 5.219521251813821, 'avg_acc': 49.868815592203894, 'loss': 5.0566325187683105}


EP_train:7:  48%|| 3001/6233 [15:41<16:38,  3.24it/s]

{'epoch': 7, 'iter': 3000, 'avg_loss': 5.211161744233093, 'avg_acc': 49.998958680439856, 'loss': 5.340446472167969}


EP_train:7:  64%|| 4001/6233 [20:55<11:47,  3.15it/s]

{'epoch': 7, 'iter': 4000, 'avg_loss': 5.211153686836164, 'avg_acc': 49.983597850537365, 'loss': 5.262521743774414}


EP_train:7:  80%|| 5001/6233 [26:10<06:26,  3.19it/s]

{'epoch': 7, 'iter': 5000, 'avg_loss': 5.209232713622681, 'avg_acc': 50.03999200159968, 'loss': 4.9712233543396}


EP_train:7:  96%|| 6001/6233 [31:23<01:13,  3.14it/s]

{'epoch': 7, 'iter': 6000, 'avg_loss': 5.211998883812969, 'avg_acc': 49.962506248958505, 'loss': 4.872304439544678}


EP_train:7: 100%|| 6233/6233 [32:36<00:00,  3.19it/s]


EP7, train:             avg_loss=5.211958598539762,             total_acc=49.95136723254485
Model saved at bert_pretrained.pt 


EP_train:8:   0%|| 1/6233 [00:00<34:39,  3.00it/s]

{'epoch': 8, 'iter': 0, 'avg_loss': 5.424551486968994, 'avg_acc': 56.25, 'loss': 5.424551486968994}


EP_train:8:  16%|| 1001/6233 [05:14<27:00,  3.23it/s]

{'epoch': 8, 'iter': 1000, 'avg_loss': 5.203215847244034, 'avg_acc': 50.06868131868132, 'loss': 5.450161933898926}


EP_train:8:  32%|| 2001/6233 [10:27<22:09,  3.18it/s]

{'epoch': 8, 'iter': 2000, 'avg_loss': 5.196544789719856, 'avg_acc': 49.86100699650175, 'loss': 5.124629020690918}


EP_train:8:  48%|| 3001/6233 [15:40<16:36,  3.24it/s]

{'epoch': 8, 'iter': 3000, 'avg_loss': 5.200739224804755, 'avg_acc': 49.88962012662446, 'loss': 5.585949420928955}


EP_train:8:  64%|| 4001/6233 [20:53<11:28,  3.24it/s]

{'epoch': 8, 'iter': 4000, 'avg_loss': 5.198894407057339, 'avg_acc': 49.90939765058735, 'loss': 4.937994956970215}


EP_train:8:  80%|| 5001/6233 [26:07<06:20,  3.24it/s]

{'epoch': 8, 'iter': 5000, 'avg_loss': 5.194027931874715, 'avg_acc': 49.973755248950205, 'loss': 5.078270435333252}


EP_train:8:  96%|| 6001/6233 [31:20<01:12,  3.20it/s]

{'epoch': 8, 'iter': 6000, 'avg_loss': 5.195546485447641, 'avg_acc': 49.962506248958505, 'loss': 4.9364166259765625}


EP_train:8: 100%|| 6233/6233 [32:32<00:00,  3.19it/s]


EP8, train:             avg_loss=5.194762604014365,             total_acc=49.95938913233127
Model saved at bert_pretrained.pt 


EP_train:9:   0%|| 1/6233 [00:00<35:41,  2.91it/s]

{'epoch': 9, 'iter': 0, 'avg_loss': 5.148815155029297, 'avg_acc': 37.5, 'loss': 5.148815155029297}


EP_train:9:  16%|| 1001/6233 [05:12<26:58,  3.23it/s]

{'epoch': 9, 'iter': 1000, 'avg_loss': 5.196970063608724, 'avg_acc': 50.44018481518482, 'loss': 5.578878402709961}


EP_train:9:  32%|| 2001/6233 [10:24<22:05,  3.19it/s]

{'epoch': 9, 'iter': 2000, 'avg_loss': 5.18601803801049, 'avg_acc': 50.07496251874063, 'loss': 4.748006343841553}


EP_train:9:  48%|| 3001/6233 [15:37<16:53,  3.19it/s]

{'epoch': 9, 'iter': 3000, 'avg_loss': 5.188583624756205, 'avg_acc': 50.07080973008997, 'loss': 4.808182716369629}


EP_train:9:  64%|| 4001/6233 [20:51<11:40,  3.19it/s]

{'epoch': 9, 'iter': 4000, 'avg_loss': 5.185565255695925, 'avg_acc': 50.14058985253686, 'loss': 5.6133036613464355}


EP_train:9:  80%|| 5001/6233 [26:04<06:38,  3.09it/s]

{'epoch': 9, 'iter': 5000, 'avg_loss': 5.186321753688966, 'avg_acc': 50.10622875424915, 'loss': 5.160183429718018}


EP_train:9:  96%|| 6001/6233 [31:18<01:15,  3.09it/s]

{'epoch': 9, 'iter': 6000, 'avg_loss': 5.18693307097247, 'avg_acc': 50.16143142809531, 'loss': 4.8067450523376465}


EP_train:9: 100%|| 6233/6233 [32:31<00:00,  3.19it/s]


EP9, train:             avg_loss=5.186073271780321,             total_acc=50.17547905782787
Model saved at bert_pretrained.pt 
