In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import math

from src.models.blm.config import ModelArgs
from src.models.blm.block import Block

from src.cells.normalization import RMSLayerNorm
from typing import Optional
from src.cells.position import RotaryEmbedding
from src.cells.optim_func import config_optimizer

import lightning.pytorch as pl
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import functools

from datasets import load_from_disk, load_dataset
from tqdm.notebook import tqdm

import torch._dynamo
torch._dynamo.config.suppress_errors = True

torch.manual_seed(123)
torch.cuda.manual_seed(123)

In [2]:
class TinyStoriesDataloader(pl.LightningDataModule):
    def __init__(
        self, data_path_train, data_path_val, tokenizer_path, batch_size, num_workers
    ):
        super().__init__()
        self.data_path_train = data_path_train
        self.data_path_val = data_path_val

        self.batch_size = batch_size
        self.num_workers = num_workers

        self.tokenizer = self._load_tokenizer(tokenizer_path)

    def prepare_data(self):
        pass

    def _load_tokenizer(self, tokenizer_path):
        from src.tokenize.tokenizer import Tokenizer

        return Tokenizer(tokenizer_path)

    def _collate_fn(self, batch: int, padding_id: int):
        batch = pad_sequence(
            (torch.LongTensor(_["idx"]) for _ in batch),
            batch_first=True,
            padding_value=padding_id,
        )  # TODO : ShortTensor suffice our need but nn.Embedding don't support it. Using LOngTensor is a unnecessary waste of GPU memory
        x_batch = torch.stack(
            [en[:-1] for en in batch]
        )  # Extract x (remove last token)
        y_batch = torch.stack(
            [en[1:] for en in batch]
        )  # Extract y (remove first token)
        return x_batch, y_batch

    def setup(self, stage):

        self.train_data = load_from_disk(self.data_path_train)
        self.val_data = load_from_disk(self.data_path_val)

    def train_dataloader(self):
        return DataLoader(
            self.train_data,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=functools.partial(
                self._collate_fn, padding_id=self.tokenizer.eos_id()
            ),
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_data,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=functools.partial(
                self._collate_fn, padding_id=self.tokenizer.eos_id()
            ),
        )

In [3]:


class Transformer(pl.LightningModule):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.save_hyperparameters()
        self.max_seq_len = args.max_seq_len
        self.tok_embd = nn.Embedding(
            args.vocab_size, args.emebdding_dim, padding_idx=args.padding_idx
        )
        self.dropout = nn.Dropout(args.embedding_dropout)
        self.rope_q = RotaryEmbedding(
            args.emebdding_dim // args.num_attention_heads,
            args.max_seq_len,
            device=args.device,
        )
        self.rope_k = RotaryEmbedding(
            args.emebdding_dim // args.num_key_value_heads,
            args.max_seq_len,
            device=args.device,
        )

        # Freeze the parameters rope_q and rope_k
        self.rope_q.requires_grad_(False)
        self.rope_k.requires_grad_(False)

        self.layers = nn.ModuleList([Block(args) for lid in range(args.num_layers)])

        self.norm = RMSLayerNorm(args.emebdding_dim, eps=args.rms_norm_eps)
        self.output = nn.Linear(args.emebdding_dim, args.vocab_size, bias=False)

        # share the unembedding parameters with the embedding parameters
        self.tok_embd.weight = (
            self.output.weight
        )  # https://paperswithcode.com/method/weight-tying

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith("wo.weight"):
                torch.nn.init.normal_(
                    p, mean=0.0, std=0.02 / math.sqrt(2 * args.num_layers)
                )
        self.lr = 1e-4
        
    def __repr__(self):
        return f"{self.get_num_params()} Million Params Model"

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.tok_embd.weight.numel()
        return n_params / 1e6  # In Million

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self, tokens: torch.Tensor
    ) -> torch.Tensor:
        x = self.dropout(self.tok_embd(tokens))
        for layer in self.layers:
            x = layer(
                x, self.rope_q, self.rope_k
            )  ## How about we add residual connection here also ?
        x = self.norm(x)
        return x

    def _common_step(self,batch,batch_index):
        x, targets = batch
        logits = self.output(self.forward(x))
        loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
            )
        return loss
        
    def training_step(self,batch,batch_idx):
        x,y = batch
        
        loss = self._common_step(batch,batch_idx)
        if batch_idx % int(1e4) == 0:
            self.log_dict({'train_loss':loss},prog_bar=True,on_step=False,on_epoch=True)
        return loss

    def validation_step(self,batch,batch_idx):
        x,y = batch
        
        loss = self._common_step(batch,batch_idx)
        if batch_idx % int(1e4) == 0:
            self.log_dict({'val_loss':loss},prog_bar=True,on_step=False,on_epoch=True)
        return loss
        
    def configure_optimizers(self):
        
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {"params": decay_params, "weight_decay": 1e-2},
            {"params": nodecay_params, "weight_decay": 0.0},
        ]
        return torch.optim.AdamW(optim_groups,lr=self.lr,betas=(0.9, 0.95),fused=False)

        
    def predict_step(self,batch,batch_idx,max_new_tokens=30,temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            # trim the token to the max_len 
            if batch.shape[1] > self.max_seq_len:
                batch = batch[:,-self.max_seq_len:]
            # inference-time mini-optimization: only forward the output on the very last position
            logits = self.output(
                    self(batch)[:, [-1], :]
                )  # note: using list [-1] to preserve the time dim
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            batch = torch.cat((batch, idx_next), dim=1)
        return batch

In [4]:
BASE_URL = "/home/pranav-pc/projects/OpenTransformer/multiformer"
data_path_train = BASE_URL + "/data/interim/TinyStories_train_65>tk>512.hf"
data_path_val = BASE_URL + "/data/interim/TinyStories_val_65>tk>512.hf"
tokenizer_path = BASE_URL + "/tokenizer_checkpoints"

batch_size = 16
num_workers = 26
ds = TinyStoriesDataloader(
    data_path_train, data_path_val, tokenizer_path, batch_size, num_workers
)

In [5]:
conf = {
    "vocab_size": 32000,
    "emebdding_dim": 768,
    "max_seq_len": 512,
    "embedding_dropout": 0.0,
    "rms_norm_eps": 1e-05,
    "rope_scaling": 1.0,
    "rope_theta": 10000.0,
    "attention_bias": False,
    "attention_dropout": 0.0,
    "num_attention_heads": 12,
    "num_key_value_heads": 12,
    "use_cache": True,
    "use_sliding_window": True,
    "residual_dropout": 0.1,
    "mlp_dropout": 0.0,
    "mlp_hidden_size": int(1.3 * 768),
    "num_layers": 4,
    "device": ("cuda"
        if torch.cuda.is_available()
        else "mps" if torch.backend.mps.is_available() else "cpu"),
    "padding_idx": ds.tokenizer.eos_id(),
}

config = ModelArgs(**conf)
model = Transformer(config)
model = torch.compile(model,dynamic=True)

Process ForkProcess-18:
Process ForkProcess-8:
Process ForkProcess-29:
Process ForkProcess-21:
Process ForkProcess-30:
Process ForkProcess-31:
Process ForkProcess-7:
Process ForkProcess-4:
Process ForkProcess-3:
Process ForkProcess-13:
Process ForkProcess-16:
Process ForkProcess-10:
Process ForkProcess-20:
Process ForkProcess-25:
Process ForkProcess-2:
Process ForkProcess-12:
Process ForkProcess-23:
Process ForkProcess-19:
Process ForkProcess-11:
Process ForkProcess-24:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Process ForkProcess-6:
Traceback (most recent call last):
Traceback (most recent call last):
Process ForkProcess-17:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run

In [6]:
from lightning.pytorch.callbacks import GradientAccumulationScheduler, StochasticWeightAveraging, ModelCheckpoint,EarlyStopping
accumulator = GradientAccumulationScheduler(scheduling={0: 6, 4: 4, 8: 3, 20:1})

logger = pl.loggers.TensorBoardLogger(save_dir='./blm-log/', name='blm', version=0.1)
# profiler = pl.profilers.PyTorchProfiler(
#     on_trace_ready=torch.profiler.tensorboard_trace_handler('./blm-log/'),
#     schedule=torch.profiler.schedule(skip_first=10, wait=10, warmup=1, active=2)
# )
# saves top-K checkpoints based on "train_loss" metric
checkpoint_callback = ModelCheckpoint(
    save_top_k=3,
    monitor="train_loss",
    mode="min",
    dirpath="/home/pranav-pc/projects/OpenTransformer/multiformer/model_checkpoints/blm/",
    filename="baby-llm-{epoch:02d}-{train_loss:.3f}",
    save_last= True,
    every_n_train_steps= int(1e4),
    save_on_train_epoch_end= True
)
early_stop = EarlyStopping('train_loss',patience=10,verbose=True)
stochastic_weight_avg = StochasticWeightAveraging(swa_lrs=1e-2)

trainer = pl.Trainer(
    logger=logger,
    min_epochs=1,
    max_epochs=100,
    precision='bf16-mixed',
    enable_model_summary=True,
    # profiler=profiler,
    callbacks=[early_stop,checkpoint_callback,accumulator,stochastic_weight_avg],
    default_root_dir="/home/pranav-pc/projects/OpenTransformer/multiformer/model_checkpoints/blm/",
    enable_checkpointing  = True,
    # fast_dev_run=True,
    log_every_n_steps=int(1e2),
    enable_progress_bar = True,
    gradient_clip_val =1.0)
torch.set_float32_matmul_precision('medium')
model.train()
trainer.fit(model, ds,ckpt_path="/home/pranav-pc/projects/OpenTransformer/multiformer/model_checkpoints/blm/last.ckpt")


Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/pranav-pc/.env/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:653: Checkpoint directory /home/pranav-pc/projects/OpenTransformer/multiformer/model_checkpoints/blm exists and is not empty.
Restoring states from the checkpoint path at /home/pranav-pc/projects/OpenTransformer/multiformer/model_checkpoints/blm/last.ckpt
/home/pranav-pc/.env/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:263: Be aware that when using `ckpt_path`, callbacks used to create the checkpoint need to be provided during `Trainer` instantiation. Please add the following callbacks: ["ModelCheckpoint{'monitor': 'train_loss', 'mode': 'min', 'every_n_train_steps': 8, 'every_n_epochs': 0, 'train_time_interval': None}"].
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    

Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

[2024-04-11 01:54:44,388] [15/0_1] torch._guards: [ERROR] Error while creating guard:
[2024-04-11 01:54:44,388] [15/0_1] torch._guards: [ERROR] Name: "L['self']"
[2024-04-11 01:54:44,388] [15/0_1] torch._guards: [ERROR]     Source: local
[2024-04-11 01:54:44,388] [15/0_1] torch._guards: [ERROR]     Create Function: NN_MODULE
[2024-04-11 01:54:44,388] [15/0_1] torch._guards: [ERROR]     Guard Types: ['ID_MATCH']
[2024-04-11 01:54:44,388] [15/0_1] torch._guards: [ERROR]     Code List: ["___check_obj_id(L['self'], 132017970055184)"]
[2024-04-11 01:54:44,388] [15/0_1] torch._guards: [ERROR]     Object Weakref: <weakref at 0x7811d43e20c0; to '_ResultMetric' at 0x7811d4358810>
[2024-04-11 01:54:44,388] [15/0_1] torch._guards: [ERROR]     Guarded Class Weakref: <weakref at 0x7812c0fd3ab0; to 'ABCMeta' at 0x8edbf40 (_ResultMetric)>
[2024-04-11 01:54:44,390] [15/0_1] torch._guards: [ERROR] Created at:
[2024-04-11 01:54:44,390] [15/0_1] torch._guards: [ERROR]   File "/home/pranav-pc/.env/lib/pyt

Training: |                                               | 0/? [00:00<?, ?it/s]

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7813f9701d10>>
Traceback (most recent call last):
  File "/home/pranav-pc/.env/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 770, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 


RuntimeError: DataLoader worker (pid(s) 330314, 330315, 330316, 330317, 330318, 330319, 330320, 330321, 330322, 330323, 330324, 330325, 330326, 330327, 330328, 330329, 330330, 330331, 330332, 330333, 330334, 330335, 330336, 330337, 330338, 330339) exited unexpectedly

In [7]:
checkpoint_callback.best_model_path

''

In [8]:
from lightning.pytorch.tuner import Tuner
# Create a Tuner
tuner = Tuner(trainer)

In [13]:
tuner.lr_find(model,ds)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/pranav-pc/.env/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py:161: You're resuming from a checkpoint that ended before the epoch ended and your dataloader is not resumable. This can cause unreliable results if further training is done. Consider using an end-of-epoch checkpoint or make your dataloader resumable by implementing the `state_dict` / `load_state_dict` interface.


Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]

LR finder stopped early after 19 steps due to diverging loss.
Failed to compute suggestion for learning rate because there are not enough points. Increase the loop iteration limits or the size of your dataset/dataloader.
Restoring states from the checkpoint path at /home/pranav-pc/projects/OpenTransformer/multiformer/model_checkpoints/blm/.lr_find_f6a18406-a3d3-4a00-889b-842d1cd2f774.ckpt
Restored all states from the checkpoint at /home/pranav-pc/projects/OpenTransformer/multiformer/model_checkpoints/blm/.lr_find_f6a18406-a3d3-4a00-889b-842d1cd2f774.ckpt


<lightning.pytorch.tuner.lr_finder._LRFinder at 0x781120d27fd0>

In [12]:
model.eval()
trainer.test(ckpt_path="best",dataloaders=ds)

ValueError: `.test(ckpt_path="best")` is set but `ModelCheckpoint` is not configured.

In [15]:
model.eval()
model.to('cuda')
trainer.validate(model, ds)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


[{'val_loss': 1.539431095123291}]

In [136]:
text = "who is the president of india"
tokens = torch.LongTensor(ds.tokenizer.encode(text)).to('cuda:0').view(1,-1)
tokens

tensor([[   1, 1058,  338,  278, 6673,  310, 1399,  423,    2]],
       device='cuda:0')