# Training a character-level language model for names

In [None]:
import os, time, math, pickle
import numpy as np, pandas as pd
from contextlib import nullcontext
import torch

In [None]:
# Directory prefix (with trailing slash) for the tokenizer metadata (meta.pkl)
# and the train/val dataset files read by the cells below.
DATA_DIR = "data/"

In [None]:
# Load the tokenizer metadata: the character -> id table (stoi), the
# id -> character table (itos) and the vocabulary size.
with open(DATA_DIR + "meta.pkl", "rb") as meta_file:
    meta = pickle.load(meta_file)
stoi = meta["stoi"]
itos = meta["itos"]
vocab_size = meta["vocab_size"]


def encode(s):
    """Return the list of token ids for the characters of *s* (looked up in ``stoi``)."""
    return [stoi[c] for c in s]


def decode(l):
    """Return the string obtained by mapping each id in *l* through ``itos``."""
    return "".join(itos[i] for i in l)


print("vocab size:", vocab_size)

In [None]:
# Select the training configuration. Exactly one `training_config_file`
# assignment should be active; the commented lines are the alternative setups.
TRAINING_CONFIG_DIR = "config/"
training_config_file = "charlm_pre-training.py"
#training_config_file = "charlm_endswith_a.py"
#training_config_file = "charlm_gender_classification_head.py"
#training_config_file = "charlm_indian_classification_head.py"
#training_config_file = "charlm_indian_classification_adapter.py"
#training_config_file = "charlm_indian_classification_lora.py"
#training_config_file = "charlm_prompt_tuning.py"
#training_config_file = "charlm_instruction_tuning.py"
# NOTE(review): exec runs the config file as code in this namespace, so only
# point it at trusted files. Presumably this is where names used later
# (pre_training, task, batch_size, ...) get defined - verify against the configs.
# The `with` block closes the file handle, which the bare open(...).read() leaked.
with open(TRAINING_CONFIG_DIR + training_config_file) as config_source:
    exec(config_source.read())

In [None]:
# Read the train/validation rows into plain Python lists (one list per row).
if not pre_training:
    # task datasets keep every column of the file
    print("loading dataset for task:", task)
    train_names = pd.read_csv(f"{DATA_DIR}{task}_train.bin").values.tolist()
    val_names = pd.read_csv(f"{DATA_DIR}{task}_val.bin").values.tolist()
else:
    # pre-training only needs the "name" column
    print("loading pre-training dataset")
    train_names = pd.read_csv(f"{DATA_DIR}train.bin")[["name"]].values.tolist()
    val_names = pd.read_csv(f"{DATA_DIR}val.bin")[["name"]].values.tolist()

In [None]:
# Pick the compute device: CUDA first, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print("device =", device)

In [None]:
# Mixed-precision setup. Defines three names used by the later cells:
#   compile - whether the model gets wrapped with torch.compile (CUDA only)
#   ctx     - context manager wrapped around every forward pass
#   scaler  - gradient scaler (only enabled for float16 autocast)
compile = device == "cuda"
if not compile:
    # non-CUDA devices: no autocast, and the scaler is a disabled no-op
    ctx = nullcontext()
    scaler = torch.cuda.amp.GradScaler(enabled=False)
else:
    torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
    torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
    if not torch.cuda.is_bf16_supported():
        # float16 autocast: gradient scaling enabled
        ctx = torch.amp.autocast(device_type=device, dtype=torch.float16)
        scaler = torch.cuda.amp.GradScaler(enabled=True)
    else:
        # bfloat16 autocast: gradient scaling disabled
        ctx = torch.amp.autocast(device_type=device, dtype=torch.bfloat16)
        scaler = torch.cuda.amp.GradScaler(enabled=False)

<img src="assets/floating_point_numbers.png">

source: https://cloud.google.com/tpu/docs/bfloat16

<img src="assets/gradient_scaling.png">

source: https://pytorch.org/docs/stable/amp.html

## Mixed Precision Training

<img src="assets/mixed_precision.png">

source: https://hackernoon.com/rtx-2080ti-vs-gtx-1080ti-fastai-mixed-precision-training-comparisons-on-cifar-100-761d8f615d7f

In [None]:
# Starting a task model that is not resumed from its own checkpoint: load the
# input checkpoint so that its config and weights are available when the model
# is built further down.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
if not pre_training and from_scratch:
    print("loading pre-trained model")
    checkpoint = torch.load(MODEL_DIR + IN_CHECKPOINT, map_location=device)
    # the stored config becomes the starting configuration of the new model
    config = checkpoint["config"]
    print("best val loss of pre-trained model:", checkpoint["best_val_loss"])
    print(config)

In [None]:
def _strip_compile_prefix(state_dict, prefix="_orig_mod."):
    """Rename in place every key of *state_dict* that starts with *prefix*.

    Checkpoints saved from a torch.compile-wrapped model carry the
    ``_orig_mod.`` prefix on their keys; removing it lets the weights load
    into a plain (not yet compiled) model. Returns the same dict.
    """
    for key in list(state_dict):
        if key.startswith(prefix):
            state_dict[key[len(prefix):]] = state_dict.pop(key)
    return state_dict


if not from_scratch:
    # Resume: rebuild the exact model stored in the checkpoint and continue
    # from its iteration counter and best validation loss.
    print("loading model from checkpoint")
    checkpoint = torch.load(MODEL_DIR + IN_CHECKPOINT, map_location=device)
    config = checkpoint["config"]
    model = GPT(config)
    state_dict = _strip_compile_prefix(checkpoint["model"])
    model.load_state_dict(state_dict)
    iter_num = checkpoint["iter_num"]
    best_val_loss = checkpoint["best_val_loss"]
else:
    print("building model from scratch")
    # pre-training starts from model_config; otherwise keep the config that
    # was read from the pre-trained checkpoint
    config = model_config if pre_training else config
    # apply the overrides/extensions of the selected training setup
    config.update(model_extended_config)
    model = GPT(config)
    if not pre_training:
        # seed the new model with the pre-trained weights
        state_dict = _strip_compile_prefix(checkpoint["model"])
        # the new model may differ from the saved one, so load non-strictly
        # and report the differences
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        print("--- missing ---")
        for key in missing:
            print(key)
        print("--- unexpected ---")
        for key in unexpected:
            print(key)
    # fresh run: start the counters from their initial values
    iter_num = 0
    best_val_loss = 1e9
model = model.to(device)
print(config)
if compile:
    print("compiling the model... (takes a ~minute)")
    model = torch.compile(model)

In [None]:
# Bare expression: as the last statement of a notebook cell this displays the model.
model

In [None]:
# List the names of all parameters that will receive gradient updates.
print("--- learnable parameters ---")
for name in (pn for pn, p in model.named_parameters() if p.requires_grad):
    print(name)

In [None]:
# Build the optimizer through the model's own factory method.
optimizer = model.configure_optimizers(weight_decay, learning_rate)
if not from_scratch:
    # resuming: restore the optimizer state saved next to the model weights
    print("loading optimizer from checkpoint")
    optimizer.load_state_dict(checkpoint["optimizer"])
# the checkpoint is not needed past this point; drop the reference
checkpoint = None  # free-up memory

In [None]:
# Token positions processed per optimizer step:
# sequences per batch x accumulated micro-batches x block size.
tokens_per_iter = batch_size * gradient_accumulation_steps * config["block_size"]
print(f"tokens per iteration will be: {tokens_per_iter:,}")

## Pre-training input and target examples

<img src="assets/input_example.png">

In [None]:
# Preview the first ten training rows.
train_names[:10]

In [None]:
def get_batch(split, batch_index=None):
    """Build one batch of inputs, targets, prompt ids and padding mask.

    Args:
        split: "train" selects ``train_names``; any other value selects
            ``val_names``.
        batch_index: when given, the batch is the ``batch_index``-th
            consecutive slice of the split (wrapping around at the end), which
            makes evaluation deterministic. When ``None`` the rows are sampled
            uniformly at random with replacement.

    Returns:
        A tuple ``(x, y, prompts, pad_mask)``, all moved to ``device``:

        * classification: ``x`` is ``(batch_size, block_size)`` with the
          sample left-padded, ``y`` is ``(batch_size, 1)`` holding the label.
        * pre-training: ``x`` and ``y`` are ``(batch_size, block_size)``; each
          sample is placed at a random offset and ``y`` is ``x`` shifted by one.
        * instruction-/fine-tuning: ``x`` is
          ``(batch_size, block_size - prompt_vocab_size)`` and ``y`` is
          ``(batch_size, block_size)`` with the targets shifted right by
          ``prompt_vocab_size``.
        * ``prompts`` is ``(batch_size, prompt_vocab_size)`` with the ids
          ``0..prompt_vocab_size-1`` on every row, or ``None`` when
          ``prompt_vocab_size`` is 0.
        * ``pad_mask`` is an int8 tensor of shape
          ``(batch_size, 1, block_size, block_size)``; entry ``[b, 0, i, j]``
          is 1 only when positions ``i`` and ``j`` of row ``b`` are both
          non-padding.

    Raises:
        ValueError: if the selected split has no rows.
    """
    data = train_names if split == "train" else val_names
    if not data:
        raise ValueError(f"no samples available for split {split!r}")
    block_size = config["block_size"]

    if batch_index is not None:
        # deterministic batch; max(1, ...) and the final modulo let a split
        # with fewer than batch_size rows wrap around instead of dividing by zero
        max_batches = max(1, len(data) // batch_size)
        start = (batch_index % max_batches) * batch_size
        ix = torch.arange(start, start + batch_size) % len(data)
    else:
        # creating random batch
        ix = torch.randint(len(data), (batch_size,))
    rows = ix.tolist()
    pad_token = stoi["*"]

    def ids(tokens):
        # build the id tensor directly as int64 (torch.Tensor(list) would go
        # through float32 first)
        return torch.tensor(tokens, dtype=torch.long)

    if classification_task:
        x = torch.full((batch_size, block_size), pad_token, dtype=torch.long)
        y = torch.full((batch_size, 1), pad_token, dtype=torch.long)
        for i, index in enumerate(rows):
            encoded = encode("{" + data[index][0])
            # left padded
            x[i, -len(encoded):] = ids(encoded)
            if "gender" in task:
                y[i, 0] = 1 if data[index][1] == "male" else 0
            if "indian" in task:
                y[i, 0] = 1 if data[index][1] == "india" else 0
    elif pre_training:
        x = torch.full((batch_size, block_size), pad_token, dtype=torch.long)
        y = torch.full((batch_size, block_size), pad_token, dtype=torch.long)
        for i, index in enumerate(rows):
            encoded = encode("{" + data[index][0] + "}")
            # randomly selecting starting points in the block to ensure all
            # position embeddings are learnt
            start = torch.randint(block_size - len(encoded) + 1, (1,)).item()
            x[i, start:start + len(encoded)] = ids(encoded)
            y[i, start:start + len(encoded) - 1] = ids(encoded[1:])
    elif "0" in data[0][0]:
        # instruction-tuning: "0" separates the two parts of each sample;
        # only the second part contributes targets
        x = torch.full((batch_size, block_size - prompt_vocab_size), pad_token, dtype=torch.long)
        y = torch.full((batch_size, block_size), pad_token, dtype=torch.long)
        for i, index in enumerate(rows):
            text = data[index][0]
            sep = text.index("0")
            first = encode(text[:sep])
            second = encode(text[sep + 1:])
            encoded = first + second
            x[i, :len(encoded)] = ids(encoded)
            target_start = prompt_vocab_size + len(first)
            y[i, target_start:target_start + len(second) - 1] = ids(second[1:])
    else:
        # fine-tuning
        x = torch.full((batch_size, block_size - prompt_vocab_size), pad_token, dtype=torch.long)
        y = torch.full((batch_size, block_size), pad_token, dtype=torch.long)
        for i, index in enumerate(rows):
            encoded = encode("{" + data[index][0] + "}")
            x[i, :len(encoded)] = ids(encoded)
            y[i, prompt_vocab_size:prompt_vocab_size + len(encoded) - 1] = ids(encoded[1:])

    # 1 for real tokens, 0 for padding.
    # NOTE(review): padding is detected with config["pad_token"] (default -100)
    # while the tensors above are filled with stoi["*"]; when the config has no
    # "pad_token" entry nothing is masked - confirm this is intended.
    pad_mask = torch.ones_like(x)
    pad_mask.masked_fill_(x == config.get("pad_token", -100), 0)
    if prompt_vocab_size > 0:
        # prompt positions are never padding
        prompt_pad_mask = torch.ones(batch_size, prompt_vocab_size, dtype=pad_mask.dtype)
        pad_mask = torch.cat((prompt_pad_mask, pad_mask), dim=1)
    # per-row outer product via broadcasting: (B, T, 1) * (B, 1, T) -> (B, T, T)
    pad_mask = pad_mask.reshape(batch_size, block_size, 1) * pad_mask.reshape(batch_size, 1, block_size)
    pad_mask = pad_mask.reshape(batch_size, 1, block_size, block_size)
    # Tensor.type returns a new tensor; the result has to be assigned for the
    # int8 cast to take effect
    pad_mask = pad_mask.type(torch.int8)
    pad_mask = pad_mask.to(device)
    x, y = x.to(device), y.to(device)

    if prompt_vocab_size > 0:
        prompts = torch.arange(prompt_vocab_size).repeat(batch_size, 1).to(device)
    else:
        prompts = None
    return x, y, prompts, pad_mask

In [None]:
# Sanity check: draw one training batch and print its first sample token by token.
x, y, prompts, pad_mask = get_batch("train")
ip, op = x[0].tolist(), y[0].tolist()
if not classification_task:
    # the targets at the prompt positions come first ...
    for i in range(prompt_vocab_size):
        print("prompt", i, "-", itos[op[i]])
    # ... followed by each input token next to its target
    for i, token in enumerate(ip):
        print(itos[token], "-", itos[op[i + prompt_vocab_size]])
else:
    # classification: show the input characters, then the single class label
    for token in ip:
        print(itos[token])
    print("class", op[0])

In [None]:
@torch.no_grad()
def estimate_loss():
    """Return the mean loss over ``eval_iters`` batches for "train" and "val".

    The model is put in eval mode for the measurement and back in train mode
    before returning. Training batches are sampled randomly; validation
    batches are taken by index (0 .. eval_iters - 1).

    Returns:
        dict mapping "train" and "val" to a 0-dim tensor with the mean loss.
    """
    model.eval()
    averages = {}
    for split in ("train", "val"):
        per_batch = torch.zeros(eval_iters)
        for step in range(eval_iters):
            # only the validation split uses indexed batches
            batch_index = step if split == "val" else None
            X, Y, Prompts, Pad_mask = get_batch(split, batch_index=batch_index)
            with ctx:
                logits, loss = model(X, Y, prompts=Prompts, pad_mask=Pad_mask)
            per_batch[step] = loss.item()
        averages[split] = per_batch.mean()
    model.train()
    return averages


def get_lr(it):
    """Return the learning rate to use at iteration *it*.

    Schedule: linear warmup over ``warmup_iters`` steps; then either the
    constant ``learning_rate`` (when ``decay_lr`` is false) or a cosine decay
    from ``learning_rate`` down to ``min_lr`` that ends at ``lr_decay_iters``,
    after which ``min_lr`` is returned.
    """
    # warmup phase: scale linearly with the iteration number
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # no decay requested: stay at the maximum learning rate
    if not decay_lr:
        return learning_rate
    # past the decay window: floor at the minimum learning rate
    if it > lr_decay_iters:
        return min_lr
    # inside the decay window: cosine interpolation between max and min
    elapsed = it - warmup_iters
    window = lr_decay_iters - warmup_iters
    progress = elapsed / window
    assert 0 <= progress <= 1
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # goes from 1 to 0
    return min_lr + cosine * (learning_rate - min_lr)

In [None]:
# Training loop. Each iteration:
#   1. sets the scheduled learning rate on the optimizer,
#   2. every eval_interval steps evaluates train/val loss and may checkpoint,
#   3. accumulates gradients over gradient_accumulation_steps micro-batches,
#   4. optionally clips gradients, steps the optimizer and logs.
X, Y, Prompts, Pad_mask = get_batch("train")  # fetch the very first batch
t0 = time.time()
while True:
    # determine and set the learning rate for this iteration
    lr = get_lr(iter_num)
    # the schedule only takes effect once it is written into the optimizer's
    # parameter groups (assumes a torch.optim-style optimizer exposing param_groups)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    # evaluate the loss on train/val sets and write checkpoints
    if iter_num % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        if not eval_only and (losses["val"] < best_val_loss or always_save_checkpoint):
            # NOTE(review): with always_save_checkpoint set, best_val_loss is
            # overwritten even when the new validation loss is worse.
            best_val_loss = losses["val"]
            # nothing has been trained yet at iteration 0, so skip saving there
            if iter_num > 0:
                checkpoint = {
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "iter_num": iter_num,
                    "best_val_loss": best_val_loss,
                    "config": config,
                }
                print(f"saving checkpoint to {MODEL_DIR+OUT_CHECKPOINT}")
                torch.save(checkpoint, MODEL_DIR + OUT_CHECKPOINT)
    if eval_only:
        break

    for micro_step in range(gradient_accumulation_steps):
        with ctx:
            # forward pass
            logits, loss = model(X, Y, prompts=Prompts, pad_mask=Pad_mask)
            if gradient_accumulation_steps > 1:
                # scaling loss in case of gradient accumulation
                loss = loss / gradient_accumulation_steps
        # fetch the batch for the next forward pass before running backward
        X, Y, Prompts, Pad_mask = get_batch("train")
        # backward pass. And upscaling the loss if gradient scaling enabled
        scaler.scale(loss).backward()
    # clip the gradient
    if grad_clip != 0.0:
        # gradients must be unscaled before their norm is clipped
        scaler.unscale_(optimizer)
        # clipping gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    # step the optimizer (the scaler unscales here if that was not done above)
    scaler.step(optimizer)
    scaler.update()
    # flush gradients and free-up memory
    optimizer.zero_grad(set_to_none=True)

    # timing and logging
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if do_log and iter_num % log_interval == 0:
        # multiply loss to undo the division done for gradient accumulation;
        # this is the loss of the last micro-batch only
        lossf = loss.item() * gradient_accumulation_steps
        print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms")
    iter_num += 1

    # termination conditions
    if iter_num > max_iters:
        break

In [None]:
# best_val_loss is a 0-dim tensor once an evaluation has updated it, but it is
# still the plain number it was initialised with when no update happened
# (e.g. an eval_only run on a fresh model); float() handles both, whereas
# .item() would fail on the plain number.
print("best val loss:", round(float(best_val_loss), 2))