In [1]:
!git clone https://github.com/H-N-Chavda/BERT.git

Cloning into 'BERT'...
remote: Enumerating objects: 44, done.[K
remote: Counting objects: 100% (44/44), done.[K
remote: Compressing objects: 100% (34/34), done.[K
remote: Total 44 (delta 7), reused 43 (delta 6), pack-reused 0 (from 0)[K
Receiving objects: 100% (44/44), 4.64 MiB | 10.54 MiB/s, done.
Resolving deltas: 100% (7/7), done.


In [3]:
import os
import torch
import random
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from datasets import load_from_disk

from BERT.config.bert_config import BertConfig
from BERT.src.bert import BertForPreTraining
from BERT.src.tokenizer import Tokenizer
from dataset import WikiTextBERTDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

ModuleNotFoundError: No module named 'src'

In [None]:
dataset = load_from_disk("Wikitext2")
print(dataset)

In [None]:
texts = [x['text'] for x in dataset['train'] if x['text'].strip() != ""]
corpus = " ".join(texts)
tokenizer.train(corpus)

In [None]:
train_ds = WikiTextBERTDataset("train", tokenizer, max_len=64, mlm_prob=0.15)
val_ds = WikiTextBERTDataset("validation", tokenizer, max_len=64, mlm_prob=0.15)

train_dl = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=2)
val_dl = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=2)

print(f"Train size: {len(train_ds)} | Val size: {len(val_ds)}")

In [None]:
config = BertConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=128,
    num_hidden_layers=4,
    num_attention_heads=4,
    intermediate_size=512,
    max_position_embeddings=64,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
)

model = BertForPreTraining(config).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.2, total_iters=5000)

print("Model parameters:", sum(p.numel() for p in model.parameters()) / 1e6, "M")


In [None]:
def train_epoch(model, dataloader, optimizer, scheduler):
    model.train()
    total_loss = 0
    for batch in tqdm(dataloader, desc="Training"):
        batch = {k: v.to(device) for k, v in batch.items()}
        out = model(**batch)
        loss = out["loss"]

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()
    return total_loss / len(dataloader)


def evaluate(model, dataloader):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            batch = {k: v.to(device) for k, v in batch.items()}
            out = model(**batch)
            total_loss += out["loss"].item()
    return total_loss / len(dataloader)


In [None]:
n_epochs = 5
train_losses, val_losses = [], []

for epoch in range(n_epochs):
    print(f"\nEpoch {epoch+1}/{n_epochs}")
    train_loss = train_epoch(model, train_dl, optimizer, scheduler)
    val_loss = evaluate(model, val_dl)

    train_losses.append(train_loss)
    val_losses.append(val_loss)

    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    torch.save(model.state_dict(), f"checkpoints/bert_epoch{epoch+1}.pt")

print("Training complete ✅")

In [None]:
plt.figure(figsize=(6,4))
plt.plot(train_losses, label="Train Loss")
plt.plot(val_losses, label="Val Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.title("BERT Pretraining Losses")
plt.show()

In [None]:
model.eval()
sample = next(iter(val_dl))
for k in sample:
    sample[k] = sample[k].to(device)

with torch.no_grad():
    out = model(**sample)

masked_positions = (sample["mlm_labels"] != -100).nonzero(as_tuple=True)
for i, j in zip(*masked_positions):
    pred_id = out["mlm_logits"][i, j].argmax(-1).item()
    true_id = sample["mlm_labels"][i, j].item()
    print(f"True: {tokenizer.decode([true_id])} | Pred: {tokenizer.decode([pred_id])}")

In [None]:
os.makedirs("checkpoints", exist_ok=True)
torch.save(model.state_dict(), "checkpoints/bert_final.pt")
print("✅ Model saved to checkpoints/bert_final.pt")