In [6]:
def clean_wikitext(line):
  line = line.strip()

  if line.startswith("=") and line.endswith("="):
    return None

  line = line.replace("@-@","-")

  line = " ".join(line.split())

  return line if len(line) > 10 else None

In [7]:
with open("/content/drive/MyDrive/BERT/train.txt") as f:
  raw_lines = f.readlines()

texts = []
for line in raw_lines:
  clean = clean_wikitext(line)
  if clean:
    texts.append(clean)

In [8]:
for i in range(5):
  print(texts[i])

The 2013 – 14 season was the <unk> season of competitive association football and 77th season in the Football League played by York City Football Club , a professional football club based in York , North Yorkshire , England . Their 17th - place finish in 2012 – 13 meant it was their second consecutive season in League Two . The season ran from 1 July 2013 to 30 June 2014 .
Nigel Worthington , starting his first full season as York manager , made eight permanent summer signings . By the turn of the year York were only above the relegation zone on goal difference , before a 17 - match unbeaten run saw the team finish in seventh - place in the 24 - team 2013 – 14 Football League Two . This meant York qualified for the play - offs , and they were eliminated in the semi - final by Fleetwood Town . York were knocked out of the 2013 – 14 FA Cup , Football League Cup and Football League Trophy in their opening round matches .
35 players made at least one appearance in nationally organised firs

#### Imports & Config

In [9]:
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [10]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [11]:
VOCAB_SIZE = 30000
SEQ_LEN = 64
BATCH_SIZE = 16
EPOCHS = 1
LR = 3e-4

### Simple Tokenizer

In [12]:
class SimpleTokenizer:
  def __init__(self, texts, vocab_size):
    words = set()
    for text in texts:
      words.update(text.split())

    words = list(words)[: vocab_size - 5]

    self.word2id = {
        "[PAD]":0,
        "[CLS]":1,
        "[SEP]":2,
        "[MASK]":3,
        "<unk>":4
    }

    for i, w in enumerate(words):
      self.word2id[w] = i + 4

    self.id2word = {v: k for k, v in self.word2id.items()}

  def encode(self, text):
    return [self.word2id.get(w, self.word2id["<unk>"]) for w in text.split()]

### MLM Dataset Generator

In [13]:
class MLMDataset(Dataset):
  def __init__(self, texts, tokenizer):
    self.texts = texts
    self.tokenizer = tokenizer

  def mask_tokens(self, tokens):
    labels = [-100]* len(tokens)

    for i in range(len(tokens)):
      if random.random()< 0.15:
        labels[i] = tokens[i]

        prob = random.random()
        if prob < 0.8:
          tokens[i] = self.tokenizer.word2id["[MASK]"]
        elif prob < 0.9:
          tokens[i] = random.randint(0, VOCAB_SIZE - 1)
    return tokens, labels

  def __getitem__(self, idx):
    tokens = self.tokenizer.encode(self.texts[idx])
    tokens = tokens[:SEQ_LEN - 2]

    tokens = (
        [self.tokenizer.word2id["[CLS]"]]
        + tokens
        + [self.tokenizer.word2id["[SEP]"]]
    )

    tokens, labels = self.mask_tokens(tokens)

    padding = SEQ_LEN - len(tokens)
    tokens += [0] * padding
    labels += [-100] * padding # Pad labels with ignore_index

    return (
        torch.tensor(tokens),
        torch.tensor(labels)
    )

  def __len__(self):
    return len(self.texts)

In [14]:
tokenizer = SimpleTokenizer(texts, VOCAB_SIZE)
mlm_dataset = MLMDataset(texts, tokenizer)

# Let's take the first item from the dataset as an example
index_to_sample = 0
sampled_tokens, sampled_labels = mlm_dataset[index_to_sample]

# Decode tokens and labels for better understanding
original_text = texts[index_to_sample]

# To see the original tokens before masking, we'd need to re-encode the text
# For this example, let's focus on what the dataset outputs

print(f"Original text: {original_text}\n")

print("Masked tokens (numerical IDs):\n", sampled_tokens.tolist())
print("Labels (numerical IDs, -100 for unmasked):\n", sampled_labels.tolist())

print("\n--- Decoded Example ---")
decoded_masked_tokens = [tokenizer.id2word.get(token_id.item(), '<unk>') for token_id in sampled_tokens]
decoded_labels = []
for i, label_id in enumerate(sampled_labels):
    if label_id.item() != -100:
        decoded_labels.append(f"({decoded_masked_tokens[i]} -> {tokenizer.id2word.get(label_id.item(), '<unk>')})")
    else:
        decoded_labels.append(tokenizer.id2word.get(sampled_tokens[i].item(), '<unk>')) # Show the unmasked token

print(f"Decoded Masked Sequence: {' '.join(decoded_masked_tokens)}")
print(f"Decoded Labels (masked words with original, unmasked words shown as is): {' '.join(decoded_labels)}")

Original text: The 2013 – 14 season was the <unk> season of competitive association football and 77th season in the Football League played by York City Football Club , a professional football club based in York , North Yorkshire , England . Their 17th - place finish in 2012 – 13 meant it was their second consecutive season in League Two . The season ran from 1 July 2013 to 30 June 2014 .

Masked tokens (numerical IDs):
 [1, 588, 20195, 6056, 7184, 10457, 20271, 5033, 7184, 10457, 14175, 7184, 8844, 21582, 1150, 22942, 10457, 21056, 5033, 3, 20615, 10515, 7184, 17755, 12001, 2693, 25678, 13613, 361, 6759, 21582, 7805, 3901, 21056, 17755, 13613, 6658, 4885, 3, 24974, 29551, 3, 6212, 15496, 1422, 7131, 3, 29169, 6056, 3, 19999, 3, 20271, 9098, 28082, 21083, 3, 21056, 20615, 19635, 3, 588, 10457, 2]
Labels (numerical IDs, -100 for unmasked):
 [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 14140, -100, -100, -100, -100, -100, -100, 2693, -100, -100, -100, -100, -10

### Attention & Transformer

In [15]:
def attention(Q, K, V):
  scores = torch.matmul(Q, K.transpose(-2,-1))
  scores /= math.sqrt(Q.size(-1))
  attn = F.softmax(scores, dim=-1)
  return torch.matmul(attn, V)

### Multi-Head Self Attention

In [16]:
class MultiHeadAttention(nn.Module):
  def __init__(self, d_model, heads):
    super().__init__()
    self.heads = heads
    self.d_k = d_model // heads

    self.q = nn.Linear(d_model, d_model)
    self.k = nn.Linear(d_model, d_model)
    self.v = nn.Linear(d_model, d_model)
    self.out = nn.Linear(d_model, d_model)

  def forward(self, x):
    B, T, C  = x.shape

    Q = self.q(x).view(B, T, self.heads, self.d_k).transpose(1,2)
    K = self.q(x).view(B, T, self.heads, self.d_k).transpose(1,2)
    V = self.q(x).view(B, T, self.heads, self.d_k).transpose(1,2)

    out = attention(Q, K, V)
    out = out.transpose(1, 2).contiguous().view(B, T, C)
    return self.out(out)


### Transformer Encoder Block

In [17]:
class TransformerBlock(nn.Module):
  def __init__(self, d_model, heads):
    super().__init__()
    self.attn = MultiHeadAttention(d_model, heads)
    self.ff = nn.Sequential(
        nn.Linear(d_model, 4 * d_model),
        nn.GELU(),
        nn.Linear(4 * d_model, d_model)
    )

    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)

  def forward(self, x):
    x = self.norm1(x + self.attn(x))
    x = self.norm2(x + self.ff(x))
    return x

### BERT Model

In [18]:
class MiniBERT(nn.Module):
  def __init__(self, vocab_size, d_model = 256, layers = 4, heads = 4):
    super().__init__()

    self.token_emb = nn.Embedding(vocab_size, d_model)
    self.pos_emb = nn.Embedding(SEQ_LEN, d_model)

    self.blocks = nn.ModuleList([
        TransformerBlock(d_model, heads) for _ in range(layers)
    ])

    self.mlm_head = nn.Linear(d_model, vocab_size)

  def forward(self, x):
    B, T = x.shape
    pos = torch.arange(T, device=x.device)

    x = self.token_emb(x) + self.pos_emb(pos)

    for block in self.blocks:
      x = block(x)

    return self.mlm_head(x)

### Training Loop

In [19]:
from tqdm.auto import tqdm

def train(model, loader):
  model.train()
  opt = torch.optim.AdamW(model.parameters(), lr = LR)

  for epoch in range(EPOCHS):
    total_loss = 0
    # Wrap the loader with tqdm to display a progress bar
    progress_bar = tqdm(loader, desc=f"Epoch {epoch + 1}/{EPOCHS}")

    for tokens, labels in progress_bar:
      tokens = tokens.to(DEVICE)
      labels = labels.to(DEVICE)

      logits = model(tokens)
      loss = F.cross_entropy(
          logits.view(-1, VOCAB_SIZE),
          labels.view(-1),
          ignore_index = -100
      )

      opt.zero_grad()
      loss.backward()
      opt.step()

      total_loss += loss.item()
      progress_bar.set_postfix({'loss': loss.item()})

    print(f"Epoch {epoch + 1} | Loss : {total_loss/len(loader):.4f}")

In [20]:
if __name__ == "__main__":
  tokenizer = SimpleTokenizer(texts, VOCAB_SIZE)
  dataset = MLMDataset(texts, tokenizer)
  loader = DataLoader(dataset, batch_size = BATCH_SIZE, shuffle=True)

  model = MiniBERT(VOCAB_SIZE).to(DEVICE)
  train(model, loader)

Epoch 1/1:   0%|          | 0/1085 [00:00<?, ?it/s]

Epoch 1 | Loss : 6.5982


### Save Model


In [21]:
model_path = '/content/drive/MyDrive/BERT/mini_bert_model.pth'
torch.save(model.state_dict(), model_path)
print(f"Model saved to {model_path}")

Model saved to mini_bert_model.pth


In [22]:
test_sentences = [
    "The quick brown fox jumps over the lazy dog .",
    "Artificial intelligence is transforming many industries around the world .",
    "Natural language processing is a subfield of artificial intelligence .",
    "The sun shines brightly in the summer sky .",
    "Learning deep neural networks can be challenging but rewarding ."
]

### Testing Model


In [23]:
def predict_masked_words(text, model, tokenizer, device, seq_len, vocab_size):
    # a. Encode the input text and truncate/pad to SEQ_LEN - 2
    tokens = tokenizer.encode(text)
    tokens = tokens[:seq_len - 2]

    # b. Add [CLS] and [SEP] tokens
    cls_id = tokenizer.word2id["[CLS]"]
    sep_id = tokenizer.word2id["[SEP]"]
    mask_id = tokenizer.word2id["[MASK]"]
    unk_id = tokenizer.word2id["<unk>"]

    full_tokens = [
        cls_id
    ] + tokens + [
        sep_id
    ]

    # c. Create a masked version and keep track of original tokens for masked positions
    masked_input_tokens = list(full_tokens)
    labels = [-100] * len(full_tokens) # -100 for ignore_index

    # Only mask non-special tokens and within the SEQ_LEN limit
    maskable_indices = [i for i, token_id in enumerate(full_tokens) if token_id not in [cls_id, sep_id, unk_id]]
    random.shuffle(maskable_indices)

    num_masks = int(len(maskable_indices) * 0.15)
    masked_indices = maskable_indices[:num_masks]

    for i in masked_indices:
        original_token_id = full_tokens[i]
        labels[i] = original_token_id # Store original ID as label

        prob = random.random()
        if prob < 0.8:
            masked_input_tokens[i] = mask_id  # 80% change to [MASK]
        elif prob < 0.9:
            masked_input_tokens[i] = random.randint(0, vocab_size - 1)  # 10% change to random word
        # 10% change to keep as is (original_token_id remains in masked_input_tokens[i])

    # d. Pad to SEQ_LEN
    padding = seq_len - len(masked_input_tokens)
    if padding > 0:
        masked_input_tokens += [tokenizer.word2id["[PAD]"]] * padding
        labels += [-100] * padding
    else:
        masked_input_tokens = masked_input_tokens[:seq_len]
        labels = labels[:seq_len]

    # e. Convert to torch.tensor and move to device
    input_tensor = torch.tensor([masked_input_tokens], device=device)

    # f. Pass through the model
    with torch.no_grad():
        logits = model(input_tensor)

    # g. Get predicted token IDs
    predicted_token_ids = torch.argmax(logits, dim=-1).squeeze().cpu().tolist()

    # h. Decode and print results
    decoded_masked_input = []
    predictions_str = []
    for i, token_id in enumerate(input_tensor.squeeze().cpu().tolist()):
        word = tokenizer.id2word.get(token_id, '<unk>')
        decoded_masked_input.append(word)
        if labels[i] != -100: # If it was a masked position
            predicted_word = tokenizer.id2word.get(predicted_token_ids[i], '<unk>')
            predictions_str.append(f"({word} -> {predicted_word})")

    print(f"Original Sentence: {text}")
    print(f"Masked Input: {' '.join(decoded_masked_input)}")
    print(f"Predictions for Masked Words: {' '.join(predictions_str)}")
    print("\n" + "-"*50 + "\n")

In [24]:
model = MiniBERT(VOCAB_SIZE).to(DEVICE)
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
model.eval()

print("Generating predictions for test sentences...")
for sentence in test_sentences:
    predict_masked_words(sentence, model, tokenizer, DEVICE, SEQ_LEN, VOCAB_SIZE)


Generating predictions for test sentences...
Original Sentence: The quick brown fox jumps over the lazy dog .
Masked Input: [CLS] The quick [MASK] fox jumps over the <unk> <unk> . [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
Predictions for Masked Words: ([MASK] -> <unk>)

--------------------------------------------------

Original Sentence: Artificial intelligence is transforming many industries around the world .
Masked Input: [CLS] Artificial intelligence [MASK] transforming many industries around the world . [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD