# Pre-training End To End Example

In [1]:
import torch
from transformer_mlm import TransformerForMaskedLM
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
import torch
import numpy as np
import torch
import random
import math
import yaml
from tokenizer import WordPieceTokenizerWrapper
from pathlib import Path
from tqdm.auto import tqdm
import torch
from torch.nn.utils import clip_grad_norm_
from transformers import get_linear_schedule_with_warmup
import wandb
import numpy as np

SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"  

## Train tokenizer and encode training data

In [2]:
input_train = 'text1.txt' 
input_encode = 'text1.txt' #each line is treated as one training example

tokenizer = WordPieceTokenizerWrapper()
tokenizer.train(tokenizer_dir='my_tokenizer', input=input_train)

ds = tokenizer.encode(
    tokenizer_dir='my_tokenizer',
    input=input_encode, 
    max_length=24 #max_sequence_length
)
train_loader = DataLoader(ds, batch_size=12, shuffle=True,  pin_memory=True)
val_loader = DataLoader(ds, batch_size=12, shuffle=False, pin_memory=True)






In [3]:
print(f'vocab size: {len(tokenizer.tokenizer)}')

vocab size: 229


## Load hyperparameters from `config.yaml` and initialize model

In [4]:
with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

architecture = config['architecture']
training_params = config['training']
LEARNING_RATE = float(training_params['learning_rate'])
WEIGHT_DECAY = training_params['weight_decay']
WARMUP_RATIO = training_params['warmup_ratio']
MAX_GRAD_NORM = training_params['max_grad_norm']
GRAD_ACCUM_STEPS = training_params['grad_accum_steps']
USE_AMP = training_params['use_amp']

NUM_EPOCHS = 50

model = TransformerForMaskedLM(**architecture)
model.to(device)

TransformerForMaskedLM(
  (embeddings): TransformerTextEmbeddings(
    (word_embeddings): Embedding(229, 32, padding_idx=0)
    (position): LearnedPositionalEmbedding(
      (position_embeddings): Embedding(24, 32)
    )
    (layer_norm): LayerNorm((32,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (layers): ModuleList(
    (0-3): 4 x TransformerEncoderBlock(
      (msa_block): MultiheadSelfAttentionBlock(
        (multihead_attn): MultiheadSelfAttention(
          (attn_drop): Dropout(p=0.0, inplace=False)
          (out_drop): Dropout(p=0.1, inplace=False)
          (Uqkv): Linear(in_features=32, out_features=96, bias=True)
          (Uout): Linear(in_features=32, out_features=32, bias=True)
        )
        (layer_norm): LayerNorm((32,), eps=1e-12, elementwise_affine=True)
      )
      (mlp_block): MLPBlock(
        (mlp): Sequential(
          (0): Linear(in_features=32, out_features=128, bias=True)
          (1): GELU(approximate='none'

## Summary check

In [5]:
from torchinfo import summary
import torch

model.eval()

B, N = 1, 24
vocab_size, pad_id = 229, 0
device = "cpu"  

input_ids = torch.randint(1, vocab_size, (B, N), dtype=torch.long)
input_ids[:, -4:] = pad_id  
inputs = {
    "input_ids": input_ids,
    "attention_mask": (input_ids == pad_id),            
}

info = summary(
    model,
    input_data=inputs,     
    device=device,
    depth=3,
    col_names=("input_size", "output_size", "num_params", "mult_adds"),
    verbose=1,
    return_sequence=False,  
)


Layer (type:depth-idx)                             Input Shape               Output Shape              Param #                   Mult-Adds
TransformerForMaskedLM                             --                        [1, 24, 229]              --                        --
├─TransformerTextEmbeddings: 1-1                   [1, 24]                   [1, 24, 32]               --                        --
│    └─Embedding: 2-1                              [1, 24]                   [1, 24, 32]               7,328                     7,328
│    └─LearnedPositionalEmbedding: 2-2             [1, 24]                   [1, 24, 32]               --                        --
│    │    └─Embedding: 3-1                         [1, 24]                   [1, 24, 32]               768                       768
│    └─LayerNorm: 2-3                              [1, 24, 32]               [1, 24, 32]               64                        64
│    └─Dropout: 2-4                                [1, 24, 32]   

## Initialize optimizer loss_fn, scheduler and scaler

In [6]:
no_decay = ["bias", "LayerNorm.weight"]
param_groups = [
    {
        "params": [p for n, p in model.named_parameters()
                   if not any(nd in n for nd in no_decay)],
        "weight_decay": WEIGHT_DECAY,
    },
    {
        "params": [p for n, p in model.named_parameters()
                   if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = torch.optim.AdamW(param_groups, lr=LEARNING_RATE)
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)

num_update_steps_per_epoch = math.ceil(len(train_loader) / GRAD_ACCUM_STEPS)
num_training_steps = NUM_EPOCHS * num_update_steps_per_epoch
num_warmup_steps = int(WARMUP_RATIO * num_training_steps)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps
)

scaler = torch.amp.GradScaler(device=device, enabled=USE_AMP)


## Training loop and wand logging

Login into wandb:
```bash
wandb login 
#paste api key
```

In [7]:
wandb.init(
    project="bert-training-2",
    config={
        "epochs": NUM_EPOCHS,
        "lr": LEARNING_RATE,
        "weight_decay": WEIGHT_DECAY,
        "warmup_ratio": WARMUP_RATIO,
        "max_grad_norm": MAX_GRAD_NORM,
        "grad_accum_steps": GRAD_ACCUM_STEPS,
        "amp": USE_AMP,
        "optimizer": "AdamW",
        "scheduler": "linear_with_warmup",
        "model": type(model).__name__
    }
)
wandb.watch(model, log="all", log_freq=50)

[34m[1mwandb[0m: Currently logged in as: [33miwaniuk-michal03[0m ([33miwaniuk-michal03-politechnika-warszawska[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [8]:
best_val_loss = float("inf")


def train_step():
    model.train()
    train_loss = 0.0
    pbar = tqdm(enumerate(train_loader, start=1), total=len(
        train_loader), desc=f"Epoch {epoch}/{NUM_EPOCHS} [train]")

    for step, (input_ids, attention_mask) in pbar:
        input_ids = input_ids.to(device, non_blocking=True)
        attention_mask = attention_mask.to(device, non_blocking=True)

        input_ids_masked, labels = tokenizer.mask_input_for_mlm(input_ids=input_ids)

        with torch.amp.autocast(device_type=device, enabled=USE_AMP):
            out = model(input_ids=input_ids_masked, attention_mask=attention_mask)
            logits = out['logits'].transpose(1, 2)
            loss = loss_fn(logits, labels)

        train_loss += loss.item()

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss / GRAD_ACCUM_STEPS).backward() # multiplies a given loss by scaler's current scale factor <-> scale gradient

        grad_norm = None
        if step % GRAD_ACCUM_STEPS == 0:
            scaler.unscale_(optimizer) # uncscale gradient
            grad_norm = clip_grad_norm_( #clip gradient
                model.parameters(), MAX_GRAD_NORM).item()
            scaler.step(optimizer) #calls optimizer.step()
            scaler.update()  # updates scaler's scale factor

            scheduler.step() #update lr

        if grad_norm is not None:
            wandb.log({
                "train/step_loss": loss.item(),
                "train/lr": scheduler.get_last_lr()[0],
                "train/grad_norm": grad_norm,
                "train/epoch": epoch
            })

        pbar.set_postfix({
            "loss": f"{train_loss/step:.4f}",
            "lr":   f"{scheduler.get_last_lr()[0]:.2e}"
        })
    epoch_train_loss = train_loss / len(train_loader)
    return epoch_train_loss

def val_step():
    model.eval()
    val_loss = 0.0
    with torch.inference_mode():
        for input_ids, attention_mask in tqdm(val_loader, desc=f"Epoch {epoch}/{NUM_EPOCHS} [val]"):
            input_ids = input_ids.to(device, non_blocking=True)
            attention_mask = attention_mask.to(device, non_blocking=True)

            input_ids_masked, labels = tokenizer.mask_input_for_mlm(input_ids=input_ids)

            out = model(input_ids=input_ids_masked, attention_mask=attention_mask)
            logits = out['logits'].transpose(1, 2)
            loss = loss_fn(logits, labels)

            val_loss += loss.item()

    epoch_val_loss = val_loss / len(val_loader)
    return epoch_val_loss


for epoch in range(1, NUM_EPOCHS + 1):
    
    epoch_train_loss = train_step()
    epoch_val_loss = val_step()
     
    wandb.log({
        "epoch": epoch,
        "train/epoch_loss": epoch_train_loss,
        "val/epoch_loss": epoch_val_loss,
    })

    if epoch_val_loss < best_val_loss:
        best_val_loss = epoch_val_loss
        torch.save({
            "model_state": model.state_dict(),
            "epoch": epoch,
            "val_loss": epoch_val_loss
        }, Path("best_model.pt"))
        wandb.run.summary["best_val_loss"] = best_val_loss
        wandb.alert(
            title="New best model",
            text=f"Epoch {epoch} | val_loss={epoch_val_loss:.4f}"
        )
    # torch.save({
    #         "model_state": model.state_dict(),
    #         "epoch": epoch,
    #         "val_loss": epoch_val_loss
    #     }, Path(f"epoch-{epoch}-model.pt"))

wandb.finish()



Epoch 1/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 1/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 2/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 2/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 3/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 3/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 4/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 4/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 5/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 5/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 6/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 6/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 7/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 7/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 8/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 8/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 9/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 9/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 10/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 10/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 11/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 11/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 12/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 12/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 13/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 13/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 14/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 14/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 15/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 15/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 16/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 16/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 17/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 17/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 18/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 18/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 19/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 19/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 20/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 20/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 21/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 21/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 22/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 22/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 23/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 23/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 24/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 24/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 25/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 25/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 26/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 26/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 27/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 27/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 28/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 28/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 29/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 29/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 30/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 30/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 31/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 31/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 32/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 32/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 33/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 33/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 34/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 34/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 35/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 35/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 36/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 36/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 37/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 37/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 38/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 38/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 39/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 39/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 40/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 40/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 41/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 41/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 42/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 42/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 43/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 43/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 44/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 44/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 45/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 45/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 46/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 46/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 47/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 47/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 48/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 48/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 49/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 49/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 50/50 [train]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 50/50 [val]:   0%|          | 0/3 [00:00<?, ?it/s]

0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
train/epoch,▁▂▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇███
train/epoch_loss,▆▇▅▆▄▂▃▂▄▅▅▅▅▅▁▆▄▅▅▅▆▅▂▄▃▅▄▃▅▄▃█▄▅▆▃▃▄▄▄
train/grad_norm,▃▃▅▄▃▄▅▄▂▇▂▄▄▇▄▅▅▁▄▄▂▄▄▅▆▆▅▄▁█▃▁▅▂▆█▃▄▃▃
train/lr,▁▃▄▅▅███▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▄▄▄▄▃▃▃▂▂▂▂▂▂▂▁▁▁
train/step_loss,▇▂▆▇▄▅▂▁▃▁▇▇▇▄▅▆▂▄▆▄▄▄▇▆▆▂▅▅▂▆▆▄█▆▂▄▄▅▃▁
val/epoch_loss,▇▆▆▅▅█▅▅▆▅▇▂▄▅▄▅▅▅▄▅▄▄▄▄▃▁▄▄▅▃▃▄▅▆▅▃▂▄▄▅

0,1
best_val_loss,5.43912
epoch,50.0
train/epoch,50.0
train/epoch_loss,5.53488
train/grad_norm,1.68815
train/lr,0.0
train/step_loss,5.54601
val/epoch_loss,5.57803
