# Pre-training: End-to-End Example

In [None]:
import torch
from transformer_mlm import TransformerForMaskedLM
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
import torch
import numpy as np
import torch
import random
import math
import yaml
from tokenizer import WordPieceTokenizerWrapper
from pathlib import Path
from tqdm.auto import tqdm
import torch
from torch.nn.utils import clip_grad_norm_
from transformers import get_linear_schedule_with_warmup
import wandb
import numpy as np

SEED = 42

# Seed every RNG source used in this notebook so runs are repeatable.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer deterministic cuDNN kernels over autotuned (possibly nondeterministic) ones.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Train tokenizer and encode training data

In [5]:
input_train = 'text1.txt'  # corpus used to train the WordPiece vocabulary
input_encode = 'text2.txt'  # each line is treated as one training example

# NOTE(review): "berat" may be a typo for "bert" — confirm against WordPieceTokenizerWrapper.
tokenizer = WordPieceTokenizerWrapper(tokenizer_type="berat")
tokenizer.train(tokenizer_dir='my_tokenizer', input=input_train)

ds = tokenizer.encode(
    tokenizer_dir='my_tokenizer',
    input=input_encode,
    max_length=24
)

# Hold out a seeded fraction of the encoded examples for validation. If both
# loaders wrapped the same `ds`, "validation" loss would be measured on training
# examples and the best-checkpoint selection would be meaningless.
VAL_FRACTION = 0.1
num_val = max(1, int(len(ds) * VAL_FRACTION))
if num_val >= len(ds):
    raise ValueError(f"Dataset too small to split into train/val ({len(ds)} examples)")
train_ds, val_ds = torch.utils.data.random_split(
    ds,
    [len(ds) - num_val, num_val],
    generator=torch.Generator().manual_seed(SEED),  # same split on every run
)

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)






## Load hyperparameters from `config.yaml` and initialize model

In [None]:
# Hyperparameters live in config.yaml: `architecture` holds the model
# constructor kwargs and `training` holds the optimisation settings.
with open("config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

architecture = config['architecture']
training_params = config['training']
LEARNING_RATE = training_params['learning_rate']
WEIGHT_DECAY = training_params['weight_decay']
WARMUP_RATIO = training_params['warmup_ratio']
MAX_GRAD_NORM = training_params['max_grad_norm']
GRAD_ACCUM_STEPS = training_params['grad_accum_steps']
USE_AMP = training_params['use_amp']

# Optional `training.num_epochs` in config.yaml; falls back to 3 when absent.
NUM_EPOCHS = training_params.get('num_epochs', 3)

model = TransformerForMaskedLM(**architecture)
model.to(device)

TransformerForMaskedLM(
  (embeddings): TransformerTextEmbeddings(
    (word_embeddings): Embedding(30000, 384, padding_idx=0)
    (position): LearnedPositionalEmbedding(
      (position_embeddings): Embedding(256, 384)
    )
    (layer_norm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (layers): ModuleList(
    (0-7): 8 x TransformerEncoderBlock(
      (msa_block): MultiheadSelfAttentionBlock(
        (multihead_attn): MultiheadSelfAttention(
          (attn_drop): Dropout(p=0.0, inplace=False)
          (out_drop): Dropout(p=0.1, inplace=False)
          (Uqkv): Linear(in_features=384, out_features=1152, bias=True)
          (Uout): Linear(in_features=384, out_features=384, bias=True)
        )
        (layer_norm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
      )
      (mlp_block): MLPBlock(
        (mlp): Sequential(
          (0): Linear(in_features=384, out_features=1024, bias=True)
          (1): GELU(appr

## Initialize optimizer, loss_fn, scheduler, and scaler

In [None]:
# Exclude biases and LayerNorm affine parameters from weight decay, following the
# common transformer recipe. LayerNorm parameters are found by module *type*, not
# by name: this model's norm modules are named `layer_norm`, so the Hugging
# Face-style substring "LayerNorm.weight" would never match them.
no_decay = {
    f"{module_name}.{param_name}" if module_name else param_name
    for module_name, module in model.named_modules()
    for param_name, _ in module.named_parameters(recurse=False)
    if param_name == "bias" or isinstance(module, torch.nn.LayerNorm)
}
param_groups = [
    {
        "params": [p for n, p in model.named_parameters() if n not in no_decay],
        "weight_decay": WEIGHT_DECAY,
    },
    {
        "params": [p for n, p in model.named_parameters() if n in no_decay],
        "weight_decay": 0.0,
    },
]
optimizer = torch.optim.AdamW(param_groups, lr=LEARNING_RATE)
# Label positions equal to -100 are ignored by the loss (presumably the unmasked
# tokens returned by tokenizer.mask_input_for_mlm — confirm).
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)

# One optimizer update per GRAD_ACCUM_STEPS micro-batches. `ceil` assumes the
# training loop also steps on a trailing partial window at the end of each epoch.
num_update_steps_per_epoch = math.ceil(len(train_loader) / GRAD_ACCUM_STEPS)
num_training_steps = NUM_EPOCHS * num_update_steps_per_epoch
num_warmup_steps = int(WARMUP_RATIO * num_training_steps)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps
)

# GradScaler's `device` argument is a device-type string ("cuda"/"cpu"), not a
# torch.device object.
scaler = torch.amp.GradScaler(device=device.type, enabled=USE_AMP)


## Training loop and wandb logging

In [None]:
# Lowest validation loss seen so far; the epoch loop checkpoints whenever it improves.
best_val_loss = float("inf")

# Run metadata recorded alongside the metrics in wandb.
run_config = dict(
    epochs=NUM_EPOCHS,
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    warmup_ratio=WARMUP_RATIO,
    max_grad_norm=MAX_GRAD_NORM,
    grad_accum_steps=GRAD_ACCUM_STEPS,
    amp=USE_AMP,
    optimizer="AdamW",
    scheduler="linear_with_warmup",
    model=type(model).__name__,
)
wandb.init(project="bert-training", config=run_config)

# Log gradients and parameters of `model` every 50 steps.
wandb.watch(model, log="all", log_freq=50)

def train_step():
    """Train ``model`` for one epoch over ``train_loader`` and return the mean loss.

    Reads notebook globals: ``model``, ``train_loader``, ``tokenizer``,
    ``loss_fn``, ``optimizer``, ``scheduler``, ``scaler``, ``device``, the
    hyperparameter constants, and the loop variable ``epoch``. ``epoch`` is used
    only for the progress-bar label and wandb logging.

    Gradients are accumulated over ``GRAD_ACCUM_STEPS`` micro-batches. The
    optimizer and scheduler step once per accumulation window, plus once for a
    trailing partial window at the end of the epoch. This matches the
    ``math.ceil`` used to size the LR schedule.

    Returns:
        float: sum of per-batch (unscaled) losses divided by ``len(train_loader)``.
    """
    model.train()
    train_loss = 0.0
    num_batches = len(train_loader)
    pbar = tqdm(enumerate(train_loader, start=1), total=num_batches,
                desc=f"Epoch {epoch}/{NUM_EPOCHS} [train]")

    # Start from clean gradients so nothing leaks in from a previous epoch.
    optimizer.zero_grad(set_to_none=True)

    for step, (input_ids, attention_mask) in pbar:
        input_ids = input_ids.to(device, non_blocking=True)
        attention_mask = attention_mask.to(device, non_blocking=True)

        input_ids_masked, labels = tokenizer.mask_input_for_mlm(input_ids=input_ids)

        # autocast takes the device *type* string ("cuda"/"cpu"), not a torch.device.
        with torch.amp.autocast(device_type=device.type, enabled=USE_AMP):
            out = model(input_ids=input_ids_masked, attention_mask=attention_mask)
            logits = out['logits']
            # NOTE(review): CrossEntropyLoss expects class scores on dim 1. If the
            # model returns (batch, seq, vocab) logits, flatten them to
            # (batch*seq, vocab) and labels to (batch*seq,) — confirm output shape.
            loss = loss_fn(logits, labels)

        train_loss += loss.item()

        # Divide by GRAD_ACCUM_STEPS so the accumulated gradient averages over the
        # window; scaler.scale multiplies by the current loss-scale factor.
        # Gradients are deliberately NOT zeroed here: they accumulate across
        # micro-batches and are reset only after the optimizer step below.
        scaler.scale(loss / GRAD_ACCUM_STEPS).backward()

        if step % GRAD_ACCUM_STEPS == 0 or step == num_batches:
            scaler.unscale_(optimizer)  # unscale grads in place so clipping sees true magnitudes
            grad_norm = clip_grad_norm_(model.parameters(), MAX_GRAD_NORM).item()
            scaler.step(optimizer)  # calls optimizer.step() (skipped if grads are inf/NaN)
            scaler.update()  # update the scaler's scale factor
            scheduler.step()  # advance the LR schedule once per optimizer update
            optimizer.zero_grad(set_to_none=True)  # window consumed; start the next one clean

            wandb.log({
                "train/step_loss": loss.item(),
                "train/lr": scheduler.get_last_lr()[0],
                "train/grad_norm": grad_norm,
                "train/epoch": epoch
            })

        pbar.set_postfix({
            "loss": f"{train_loss/step:.4f}",
            "lr":   f"{scheduler.get_last_lr()[0]:.2e}"
        })
    epoch_train_loss = train_loss / len(train_loader)
    return epoch_train_loss

def val_step():
    """Evaluate ``model`` on ``val_loader`` and return the mean per-batch loss.

    Reads notebook globals: ``model``, ``val_loader``, ``tokenizer``,
    ``loss_fn``, ``device``, ``epoch`` and ``NUM_EPOCHS``. Runs under
    ``torch.inference_mode`` without autocast.

    NOTE(review): masks come from ``tokenizer.mask_input_for_mlm`` on every call.
    If that masking is random, val loss will vary between epochs even for the
    same weights — consider a fixed validation mask.

    Returns:
        float: sum of per-batch losses divided by ``len(val_loader)``.
    """
    model.eval()
    batch_losses = []
    progress = tqdm(val_loader, desc=f"Epoch {epoch}/{NUM_EPOCHS} [val]")
    with torch.inference_mode():
        for ids, mask in progress:
            ids = ids.to(device, non_blocking=True)
            mask = mask.to(device, non_blocking=True)

            masked_ids, labels = tokenizer.mask_input_for_mlm(input_ids=ids)

            logits = model(input_ids=masked_ids, attention_mask=mask)['logits']
            batch_losses.append(loss_fn(logits, labels).item())

    return sum(batch_losses) / len(val_loader)


# Set to True to also keep a checkpoint for every epoch, not just the best one.
SAVE_EVERY_EPOCH = False

# `epoch` stays a module-level name: train_step/val_step read it for labels/logging.
try:
    for epoch in range(1, NUM_EPOCHS + 1):
        epoch_train_loss = train_step()
        epoch_val_loss = val_step()

        wandb.log({
            "epoch": epoch,
            "train/loss": epoch_train_loss,
            "val/loss": epoch_val_loss,
        })

        checkpoint = {
            "model_state": model.state_dict(),
            "epoch": epoch,
            "val_loss": epoch_val_loss
        }

        # Keep only the checkpoint with the lowest validation loss seen so far.
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            torch.save(checkpoint, Path("best_model.pt"))
            wandb.run.summary["best_val_loss"] = best_val_loss
            wandb.alert(
                title="New best model",
                text=f"Epoch {epoch} | val_loss={epoch_val_loss:.4f}"
            )

        if SAVE_EVERY_EPOCH:
            torch.save(checkpoint, Path(f"epoch-{epoch}-model.pt"))
finally:
    # Close the wandb run even if training raises or is interrupted, so the
    # notebook kernel is not left holding an open run.
    wandb.finish()