In [15]:
import functools
import math
from typing import Optional

import lightning.pytorch as pl
import torch
import torch._dynamo
import torch.nn.functional as F
from datasets import load_dataset, load_from_disk
from src.cells.normalization import RMSLayerNorm
from src.cells.optim_func import config_optimizer
from src.cells.position import RotaryEmbedding
from src.models.blm.block import Block
from src.models.blm.config import ModelArgs
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger, CSVLogger

# Ask TorchDynamo to suppress its own errors rather than raise them
# (the model is wrapped with torch.compile in a later cell).
torch._dynamo.config.suppress_errors = True

# Fix every random seed to 123 so that runs are repeatable.
pl.seed_everything(123, workers=True)
torch.manual_seed(123)
torch.cuda.manual_seed(123)

# Pick the accelerator used for the rest of the notebook:
# CUDA first, then Apple MPS, otherwise CPU.
# The MPS probe lives under `torch.backends` (plural); the previous
# `torch.backend` spelling raised AttributeError whenever CUDA was absent.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

Seed set to 123


In [16]:
class TinyStoriesDataloader(pl.LightningDataModule):
    """Lightning data module that serves next-token-prediction batches.

    Both splits are read from disk with ``datasets.load_from_disk``. Every row
    is expected to expose its token ids under the ``"idx"`` key (that is the
    only field the collate function reads).

    Args:
        data_path_train: Location of the saved training split.
        data_path_val: Location of the saved validation split.
        tokenizer_path: Path handed to the project ``Tokenizer``.
        batch_size: Number of rows per batch for both loaders.
        num_workers: Worker processes used by both loaders.
    """

    def __init__(
        self, data_path_train, data_path_val, tokenizer_path, batch_size, num_workers
    ):
        super().__init__()
        self.data_path_train = data_path_train
        self.data_path_val = data_path_val

        self.batch_size = batch_size
        self.num_workers = num_workers

        self.tokenizer = self._load_tokenizer(tokenizer_path)

    def prepare_data(self):
        """No download/preprocessing step; the splits are loaded in ``setup``."""

    def _load_tokenizer(self, tokenizer_path):
        """Build the project tokenizer from ``tokenizer_path``."""
        # Imported here rather than at the top of the notebook.
        from src.tokenize.tokenizer import Tokenizer

        return Tokenizer(tokenizer_path)

    def _collate_fn(self, batch: list, padding_id: int):
        """Pad a list of rows to equal length and split into inputs/targets.

        Args:
            batch: Rows from the dataset; each row's ``"idx"`` entry holds its
                token ids.
            padding_id: Token id used to right-pad the shorter rows.

        Returns:
            ``(x_batch, y_batch)`` where ``x_batch`` drops the last token of
            every padded row and ``y_batch`` drops the first, so ``y_batch`` is
            ``x_batch`` shifted left by one position.
        """
        # TODO: ShortTensor would suffice for our vocabulary, but nn.Embedding
        # does not support it, so LongTensor is used at the cost of extra memory.
        padded = pad_sequence(
            [torch.LongTensor(row["idx"]) for row in batch],
            batch_first=True,
            padding_value=padding_id,
        )
        # Slice once for the whole batch instead of stacking row by row.
        # `.contiguous()` keeps the result laid out like the stacked copy so
        # later `.view(-1)` calls on the targets keep working.
        x_batch = padded[:, :-1].contiguous()  # inputs: drop the last token
        y_batch = padded[:, 1:].contiguous()  # targets: drop the first token
        return x_batch, y_batch

    def setup(self, stage=None):
        """Load both splits from disk (``stage`` is accepted but not used)."""
        self.train_data = load_from_disk(self.data_path_train)
        self.val_data = load_from_disk(self.data_path_val)

    def _make_loader(self, dataset):
        """Build a DataLoader with the settings shared by train and val."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            # Rows are padded with the tokenizer's EOS id.
            collate_fn=functools.partial(
                self._collate_fn, padding_id=self.tokenizer.eos_id()
            ),
        )

    def train_dataloader(self):
        """Loader over the training split.

        NOTE(review): the training split is served unshuffled -- confirm this
        is intentional.
        """
        return self._make_loader(self.train_data)

    def val_dataloader(self):
        """Loader over the validation split."""
        return self._make_loader(self.val_data)

In [17]:
class Transformer(pl.LightningModule):
    """Decoder-style language model wrapped as a LightningModule.

    The network is: token embedding -> dropout -> ``args.num_layers`` ``Block``
    layers (each receives the query/key rotary embeddings) -> RMS norm.
    ``forward`` returns the normalised hidden states; the vocabulary
    projection (``self.output``) is applied by the step methods. The embedding
    matrix and the output projection share a single weight tensor.

    Args:
        args: ``ModelArgs`` holding the model hyper-parameters read below.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.save_hyperparameters()
        self.max_seq_len = args.max_seq_len
        self.tok_embd = nn.Embedding(
            args.vocab_size, args.embedding_dim, padding_idx=args.padding_idx
        )
        self.dropout = nn.Dropout(args.embedding_dropout)
        self.rope_q = RotaryEmbedding(
            args.embedding_dim // args.num_attention_heads,
            args.max_seq_len,
            device=args.device,
        )
        # NOTE(review): the key rotary dim is derived from num_key_value_heads;
        # it only matches the query dim when both head counts are equal --
        # confirm this is intended for grouped-query configurations.
        self.rope_k = RotaryEmbedding(
            args.embedding_dim // args.num_key_value_heads,
            args.max_seq_len,
            device=args.device,
        )

        # Freeze the parameters of rope_q and rope_k
        self.rope_q.requires_grad_(False)
        self.rope_k.requires_grad_(False)

        self.layers = nn.ModuleList([Block(args) for _ in range(args.num_layers)])

        self.norm = RMSLayerNorm(args.embedding_dim, eps=args.rms_norm_eps)
        self.output = nn.Linear(args.embedding_dim, args.vocab_size, bias=False)

        # share the unembedding parameters with the embedding parameters
        self.tok_embd.weight = (
            self.output.weight
        )  # https://paperswithcode.com/method/weight-tying

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith("wo.weight"):
                torch.nn.init.normal_(
                    p, mean=0.0, std=0.02 / math.sqrt(2 * args.num_layers)
                )
        # Learning rate handed to AdamW and logged at every training step.
        self.lr = 1e-4

    def __repr__(self):
        return f"{self.get_num_params()} Million Params Model"

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model, in millions.

        With ``non_embedding=True`` the token-embedding weight is subtracted.
        Because that tensor is shared with ``self.output``, the output
        projection is excluded from the count as well.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.tok_embd.weight.numel()
        return n_params / 1e6  # In Million

    def _init_weights(self, module):
        """Normal(0, 0.02) init for Linear/Embedding weights; zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """Embed ``tokens``, run every block, and return the normalised states.

        The result has NOT been projected to vocabulary logits; callers apply
        ``self.output`` themselves.
        """
        x = self.dropout(self.tok_embd(tokens))
        for layer in self.layers:
            x = layer(
                x, self.rope_q, self.rope_k
            )  ## How about we add residual connection here also ?
        x = self.norm(x)
        return x

    def _common_step(self, batch, batch_index):
        """Compute the token-level cross-entropy for one ``(inputs, targets)`` batch."""
        x, targets = batch
        logits = self.output(self.forward(x))
        # NOTE(review): ignore_index is -1 while the collate function pads with
        # the tokenizer's EOS id, so padded positions contribute to the loss --
        # confirm this is intended.
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
        )
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss and log it together with the learning rate."""
        loss = self._common_step(batch, batch_idx)
        # Use the trainer/loggers attached to this module instead of notebook
        # globals, and only touch wandb when a WandbLogger is actually in use.
        if self.trainer.global_step == 0:
            for active_logger in self.loggers:
                if isinstance(active_logger, WandbLogger):
                    active_logger.experiment.define_metric(
                        "train_loss", summary="mean"
                    )
        self.log_dict(
            {"train_loss": loss, "lr": self.lr},
            prog_bar=True,
            on_step=True,
            on_epoch=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Return the validation loss and log it as ``val_loss``."""
        loss = self._common_step(batch, batch_idx)
        self.log_dict({"val_loss": loss}, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Build AdamW with weight decay on >=2-D parameters only."""
        # start with all of the candidate parameters that require grad
        trainable = [p for _, p in self.named_parameters() if p.requires_grad]
        # create optim groups. Any parameter that is 2D will be weight decayed, otherwise no.
        decay_params = [p for p in trainable if p.dim() >= 2]
        nodecay_params = [p for p in trainable if p.dim() < 2]
        optim_groups = [
            {"params": decay_params, "weight_decay": 1e-2},
            {"params": nodecay_params, "weight_decay": 0.0},
        ]
        return torch.optim.AdamW(
            optim_groups, lr=self.lr, betas=(0.9, 0.95), fused=False
        )

    def predict_step(
        self,
        batch,
        batch_idx,
        max_new_tokens=30,
        temperature=1.0,
        top_k=None,
        conditional_break=(13, 13, 1),
    ):
        """Autoregressively sample up to ``max_new_tokens`` tokens.

        Args:
            batch: Tensor of token ids, one prompt per row.
            batch_idx: Unused; present for the Lightning hook signature.
            max_new_tokens: Upper bound on the number of sampled tokens.
            temperature: Divisor applied to the last-position logits.
            top_k: If given, only the ``top_k`` most likely tokens stay eligible.
            conditional_break: Token-id sequence that stops generation early
                once it appears at the end of the LAST row of ``batch``.
                A falsy value disables the check.

        Returns:
            ``batch`` with the sampled tokens appended along dim 1. The whole
            sequence is returned even when it grows past ``max_seq_len``;
            only the context fed to the model is cropped.
        """
        stop_tokens = None
        if conditional_break:
            # Built once instead of on every iteration.
            stop_tokens = torch.LongTensor(list(conditional_break)).to(batch.device)

        # Sampling needs no autograd graph; this matters when the method is
        # called directly rather than through ``trainer.predict``.
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # crop only the model input to max_seq_len, keep the full sequence
                context = batch[:, -self.max_seq_len :]

                # inference-time mini-optimization: only project the very last position
                logits = self.output(
                    self(context)[:, [-1], :]
                )  # note: using list [-1] to preserve the time dim
                # pluck the logits at the final step and scale by desired temperature
                logits = logits[:, -1, :] / temperature
                # optionally crop the logits to only the top k options
                if top_k is not None:
                    top_values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < top_values[:, [-1]]] = -float("Inf")
                # apply softmax to convert logits to (normalized) probabilities
                probs = F.softmax(logits, dim=-1)
                # sample from the distribution
                idx_next = torch.multinomial(probs, num_samples=1)
                # append sampled index to the running sequence and continue
                batch = torch.cat((batch, idx_next), dim=1)
                if stop_tokens is not None and torch.equal(
                    batch[-1][-stop_tokens.numel() :], stop_tokens
                ):
                    break

        return batch

In [18]:
#### Training

In [19]:
# Project root; every other location in this cell is derived from it.
BASE_URL = "/home/pranav-pc/projects/OpenTransformer/multiformer"
MODEL_CHECKPOINT_PATH = f"{BASE_URL}/model_checkpoints/blm/last-v3.ckpt"
data_path_train = f"{BASE_URL}/data/interim/TinyStories_train_65>tk>512.hf"
data_path_val = f"{BASE_URL}/data/interim/TinyStories_val_65>tk>512.hf"
tokenizer_path = f"{BASE_URL}/tokenizer_checkpoints"

# Run switches read by the cells below.
load_from_checkpoint = False
train = True

# Loader settings shared by the train and validation dataloaders.
batch_size = 16
num_workers = 26
ds = TinyStoriesDataloader(
    data_path_train=data_path_train,
    data_path_val=data_path_val,
    tokenizer_path=tokenizer_path,
    batch_size=batch_size,
    num_workers=num_workers,
)

In [20]:
# Hyper-parameters forwarded to ModelArgs when a fresh model is built.
conf = {
    "vocab_size": 32000,
    "embedding_dim": 768,
    "max_seq_len": 512,
    "embedding_dropout": 0.0,
    "rms_norm_eps": 1e-05,
    "rope_scaling": 1.0,
    "rope_theta": 10000.0,
    "attention_bias": False,
    "attention_dropout": 0.0,
    "num_attention_heads": 12,
    "num_key_value_heads": 12,
    "use_cache": True,
    "use_sliding_window": True,
    "residual_dropout": 0.1,
    "mlp_dropout": 0.0,
    "mlp_hidden_size": int(1.3 * 768),
    "num_layers": 4,
    "device": device,
    # The padding index is the tokenizer's EOS id.
    "padding_idx": ds.tokenizer.eos_id(),
}

if not load_from_checkpoint:
    # Fresh model from the configuration above.
    config = ModelArgs(**conf)
    model = Transformer(config)
else:
    # Restore weights and hyper-parameters from the saved checkpoint.
    model = Transformer.load_from_checkpoint(MODEL_CHECKPOINT_PATH)

# Either way the model is compiled with dynamic shapes enabled.
model = torch.compile(model, dynamic=True)

In [21]:
from lightning.pytorch.callbacks import (
    BatchSizeFinder,
    EarlyStopping,
    GradientAccumulationScheduler,
    LearningRateFinder,
    ModelCheckpoint,
    StochasticWeightAveraging,
)

# class DynamicBatchSizeFinder(BatchSizeFinder):
#     def __init__(self, milestones, *args, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.milestones = milestones

#     def on_fit_start(self, *args, **kwargs):
#         return

#     def on_train_batch_start(self, trainer, pl_module,batch,batch_idx):
#         if trainer.current_epoch in self.milestones or trainer.current_epoch == 0:
#             if batch_idx % 5 == 0:
#                 self.scale_batch_size(trainer, pl_module)


class FineTuneLearningRateFinder(LearningRateFinder):
    """LearningRateFinder that re-runs the LR search at chosen epochs.

    The search runs at epoch 0 and again at the start of every epoch listed in
    ``milestones``.

    Args:
        milestones: Collection of epoch indices at which to repeat the search.
        *args, **kwargs: Forwarded unchanged to ``LearningRateFinder``.
    """

    def __init__(self, milestones, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.milestones = milestones

    def on_fit_start(self, *args, **kwargs):
        # Skip the base class's search at fit start: epoch 0 is already handled
        # in on_train_epoch_start, so without this override the search would
        # run twice before training begins.
        return

    def on_train_epoch_start(self, trainer, pl_module):
        epoch = trainer.current_epoch
        if epoch == 0 or epoch in self.milestones:
            self.lr_find(trainer, pl_module)

In [22]:
from lightning.pytorch.loggers import WandbLogger

# Weights & Biases logger for the "tiny-stories" project (online mode).
# NOTE(review): the Trainer in the training cell is constructed with the
# TensorBoard `logger`, not this object -- confirm which logger is intended.
wandb_logger = WandbLogger(
    name="blm",
    save_dir="blm/",
    version="v1",
    offline=False,
    project="tiny-stories",
    log_model="all",
)
import wandb

# wandb.login()

In [None]:
# Accumulate 4 batches per optimizer step from epoch 0, 3 from epoch 4, 1 from epoch 6.
accumulator = GradientAccumulationScheduler(scheduling={0: 4, 4: 3, 6: 1})

# logger = pl.loggers.TensorBoardLogger(save_dir="./blm-log/", name="blm", version=1.0)
# profiler = pl.profilers.PyTorchProfiler(
#     on_trace_ready=torch.profiler.tensorboard_trace_handler('./blm-log/'),
#     schedule=torch.profiler.schedule(skip_first=10, wait=10, warmup=1, active=2)
# )
# saves top-K checkpoints based on "train_loss" metric
# A loss is better when it is lower, so the best checkpoints are selected with
# mode="min" (mode="max" would keep the two worst ones).
# NOTE(review): every_n_train_steps (10_000) exceeds the Trainer's max_steps
# (5_000) below -- confirm that step-triggered saves can fire in this run.
checkpoint_callback = ModelCheckpoint(
    save_top_k=2,
    monitor="train_loss",
    mode="min",
    dirpath="model_checkpoints/",
    filename="baby-llm-{epoch:02d}-{train_loss:.3f}",
    save_last=True,
    every_n_train_steps=int(1e4),
    save_on_train_epoch_end=True,
)
logger = TensorBoardLogger(
    save_dir="./lightning-log/", name="TinnyStories", version=0.1
)
# The three callbacks below are constructed but currently left out of the
# Trainer's callback list.
early_stop = EarlyStopping("train_loss", patience=10, verbose=True)
stochastic_weight_avg = StochasticWeightAveraging(swa_lrs=1e-6)
# dynamic_batch_size = DynamicBatchSizeFinder(milestones=(6, 20))
lr_finder = FineTuneLearningRateFinder(milestones=(5, 20))

trainer = pl.Trainer(
    logger=logger,
    min_epochs=1,
    max_epochs=100,
    precision="bf16-mixed",
    enable_model_summary=True,
    profiler="simple",
    callbacks=[
        # early_stop,
        checkpoint_callback,
        accumulator,
        # stochastic_weight_avg,
        # lr_finder,
    ],
    default_root_dir="model_checkpoints/",
    enable_checkpointing=True,
    # fast_dev_run=True,
    log_every_n_steps=5,
    enable_progress_bar=True,
    gradient_clip_val=1.0,
    max_steps=5000,
    val_check_interval=400,
    check_val_every_n_epoch=None,
)
torch.set_float32_matmul_precision("medium")

if train:
    model.train()
    trainer.fit(model, ds)

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/pranav-pc/.env/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:652: Checkpoint directory /home/pranav-pc/projects/OpenTransformer/multiformer/notebooks/models/baby-language-model/model_checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type            | Params | Mode 
-----------------------------------------------------
0 | tok_embd | Embedding       | 24.6 M | train
1 | dropout  | Dropout         | 0      | train
2 | rope_q   | RotaryEmbedding | 0      | train
3 | rope_k   | RotaryEmbedding | 0      | train
4 | layers   | ModuleList      | 18.6 M | train
5 | norm     | RMSLayerNorm    | 768    | train
6 | output   | Linear          | 24.6 M | train
-----------------------------------------------------
43.2 M    Trainable params
0         Non-trai

Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

[2024-04-18 01:04:34,265] [15/0_1] torch._guards: [ERROR] Error while creating guard:
[2024-04-18 01:04:34,265] [15/0_1] torch._guards: [ERROR] Name: "L['self']"
[2024-04-18 01:04:34,265] [15/0_1] torch._guards: [ERROR]     Source: local
[2024-04-18 01:04:34,265] [15/0_1] torch._guards: [ERROR]     Create Function: NN_MODULE
[2024-04-18 01:04:34,265] [15/0_1] torch._guards: [ERROR]     Guard Types: ['ID_MATCH']
[2024-04-18 01:04:34,265] [15/0_1] torch._guards: [ERROR]     Code List: ["___check_obj_id(L['self'], 140097472225296)"]
[2024-04-18 01:04:34,265] [15/0_1] torch._guards: [ERROR]     Object Weakref: <weakref at 0x7f6b3a32fd80; to '_ResultMetric' at 0x7f6afc10e010>
[2024-04-18 01:04:34,265] [15/0_1] torch._guards: [ERROR]     Guarded Class Weakref: <weakref at 0x7f6b4de2c6d0; to 'ABCMeta' at 0x976acd0 (_ResultMetric)>
[2024-04-18 01:04:34,266] [15/0_1] torch._guards: [ERROR] Created at:
[2024-04-18 01:04:34,266] [15/0_1] torch._guards: [ERROR]   File "/home/pranav-pc/.env/lib/pyt

Training: |                                               | 0/? [00:00<?, ?it/s]

In [37]:
def get_subset_sampler(dataset, subset_ratio=0.02):
    """Return a ``Subset`` over the leading fraction of ``dataset``.

    The subset size is ``int(len(dataset) * subset_ratio)``; it is printed and
    the first that many indices (0, 1, 2, ...) are selected.

    Args:
        dataset: Any object supporting ``len()`` and integer indexing.
        subset_ratio: Fraction of the dataset to keep (default 0.02).

    Returns:
        ``torch.utils.data.Subset`` wrapping ``dataset``.
    """
    subset_size = int(len(dataset) * subset_ratio)
    print(subset_size)
    leading_indices = [*range(subset_size)]
    return torch.utils.data.Subset(dataset, leading_indices)

In [38]:
# Size of the default (2%) leading subset of the training split;
# get_subset_sampler also prints this number itself.
len(get_subset_sampler(load_from_disk(data_path_train)))

25245


25245

TypeError: object of type 'TinyStoriesDataloader' has no len()

In [None]:
# Switch the model to evaluation mode, then run the Trainer's validation loop
# with the data module `ds`.
model.eval()
trainer.validate(model, ds)

In [11]:
#### Inference
# Evaluation mode for inference; move the model to the accelerator selected in
# the first cell (`device`) instead of assuming CUDA is present.
model.eval()
model = model.to(device)

import os

os.environ["WANDB_DISABLED"] = "true"

In [12]:
# text = "Write a story containing the words: dive, job, sorry. Story summary: Bob the big fish finds a shiny rock while searching for food for his friends, but when he tells them about it, they are excited to play with it instead of being sad about not having food."

In [13]:
# text = "Sita wanted to watch either a movie or a cartoon. Her mother didn’t let her watch a cartoon so instead she"
text = (
    "Tim is a good boy. one day his father called and asked for the school exam result"
)
# Put the prompt on whatever device the model's weights live on rather than
# hard-coding "cuda:0".
model_device = next(model.parameters()).device
tokens = torch.LongTensor(ds.tokenizer.encode(text)).to(model_device).view(1, -1)
# tokens
# Sampling needs no gradients; skipping autograd keeps memory flat over the
# 450 generation steps.
with torch.no_grad():
    generated = model.predict_step(
        tokens, None, max_new_tokens=450, temperature=0.9, top_k=None
    )
print(ds.tokenizer.decode_ids(generated[0].tolist()))

Tim is a good boy. one day his father called and asked for the school exam resultend. Tim was excited and said yes. He looked at the exam paper and read the words. It said: "Very well done, Tim, you have got an answer from your father. Come on, let us check the answers together."

When the exam was ready, Tim stepped into the classroom and tried on the paper. It was very hard but he did it! He felt so proud of himself. His father was very patient and hugged him.

At the end of the exam, all of the answers Tim worked together. He was so happy to have the result. He smiled and thanked his father, who was very proud and patient. He was glad that he had finished the exam.




In [None]:
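# NOTE: LinearWarmupCosineAnnealingLR -- name kept as a reminder; it is not imported anywhere in this notebook, so it is left as a comment rather than a bare (undefined) expression.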