In [96]:
import torch
import torch.nn as nn
import lightning as pl

from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import functools

from datasets import load_from_disk, load_dataset

from tqdm.notebook import tqdm

In [97]:
class TinyStoriesDataloader(pl.LightningDataModule):
    def __init__(
        self, data_path_train, data_path_val, tokenizer_path, batch_size, num_workers
    ):
        super().__init__()
        self.data_path_train = data_path_train
        self.data_path_val = data_path_val

        self.batch_size = batch_size
        self.num_workers = num_workers

        self.tokenizer = self._load_tokenizer(tokenizer_path)

    def prepare_data(self):
        pass

    def _load_tokenizer(self, tokenizer_path):
        from src.tokenize.tokenizer import Tokenizer

        return Tokenizer(tokenizer_path)

    def _collate_fn(self, batch: int, padding_id: int):
        batch = pad_sequence(
            (torch.LongTensor(_["idx"]) for _ in batch),
            batch_first=True,
            padding_value=padding_id,
        )  # TODO : ShortTensor suffice our need but nn.Embedding don't support it. Using LOngTensor is a unnecessary waste of GPU memory
        x_batch = torch.stack(
            [en[:-1] for en in batch]
        )  # Extract x (remove last token)
        y_batch = torch.stack(
            [en[1:] for en in batch]
        )  # Extract y (remove first token)
        return x_batch, y_batch

    def setup(self, stage):

        self.train_data = load_from_disk(self.data_path_train)
        self.val_data = load_from_disk(self.data_path_val)

    def train_dataloader(self):
        return DataLoader(
            self.train_data,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=functools.partial(
                self._collate_fn, padding_id=self.tokenizer.eos_id()
            ),
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_data,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=functools.partial(
                self._collate_fn, padding_id=self.tokenizer.eos_id()
            ),
        )

In [98]:
BASE_URL = "/home/pranav-pc/projects/OpenTransformer/multiformer"
data_path_train = BASE_URL + "/data/interim/TinyStories_train_65>tk>512.hf"
data_path_val = BASE_URL + "/data/interim/TinyStories_val_65>tk>512.hf"
tokenizer_path = BASE_URL + "/tokenizer_checkpoints"

batch_size = 16
num_workers = 26
ds = TinyStoriesDataloader(
    data_path_train, data_path_val, tokenizer_path, batch_size, num_workers
)

In [99]:
data = load_from_disk(BASE_URL + "/data/interim/TinyStories_val_65>tk>512.hf")
ds.tokenizer.decode_ids(list(set(i[-1] for i in data["idx"])))

''

In [102]:
ds.setup("val")
val_dataloader = ds.val_dataloader()


(tensor([[    1,  4335,  5148,  ...,  1009, 17487, 29889],
         [    1,  2259,   471,  ...,  1250,  4720, 29889],
         [    1,  3118,  2462,  ...,   310,  3974, 29889],
         ...,
         [    1,  4976,   322,  ...,  3698,  4459, 29889],
         [    1,  9038,  2501,  ...,  4056, 17724, 29889],
         [    1,   365,  2354,  ...,  5121, 22296,  1213]]),
 tensor([[ 4335,  5148,   714,  ..., 17487, 29889,     2],
         [ 2259,   471,  6365,  ...,  4720, 29889,     2],
         [ 3118,  2462,   297,  ...,  3974, 29889,     2],
         ...,
         [ 4976,   322,   670,  ...,  4459, 29889,     2],
         [ 9038,  2501,   263,  ..., 17724, 29889,     2],
         [  365,  2354,   471,  ..., 22296,  1213,     2]]))

In [105]:
data,label = next(iter(val_dataloader))
data,label

(tensor([[    1,  4335,  5148,  ...,  1009, 17487, 29889],
         [    1,  2259,   471,  ...,  1250,  4720, 29889],
         [    1,  3118,  2462,  ...,   310,  3974, 29889],
         ...,
         [    1,  4976,   322,  ...,  3698,  4459, 29889],
         [    1,  9038,  2501,  ...,  4056, 17724, 29889],
         [    1,   365,  2354,  ...,  5121, 22296,  1213]]),
 tensor([[ 4335,  5148,   714,  ..., 17487, 29889,     2],
         [ 2259,   471,  6365,  ...,  4720, 29889,     2],
         [ 3118,  2462,   297,  ...,  3974, 29889,     2],
         ...,
         [ 4976,   322,   670,  ...,  4459, 29889,     2],
         [ 9038,  2501,   263,  ..., 17724, 29889,     2],
         [  365,  2354,   471,  ..., 22296,  1213,     2]]))

In [None]:
## Training
conf = {
    "vocab_size": 32000,
    "emebdding_dim": 768,
    "max_seq_len": block_size,
    "embedding_dropout": 0.0,
    "rms_norm_eps": 1e-05,
    "rope_scaling": 1.0,
    "rope_theta": 10000.0,
    "attention_bias": False,
    "attention_dropout": 0.0,
    "num_attention_heads": 12,
    "num_key_value_heads": 12,
    "use_cache": True,
    "use_sliding_window": True,
    "residual_dropout": 0.1,
    "mlp_dropout": 0.0,
    "mlp_hidden_size": int(1.3 * 768),
    "num_layers": 4,
    "device": device,
    "padding_idx": tokenizer.eos_id(),
}


from src.models.blm.config import ModelArgs
from src.models.blm.model import Transformer


config = ModelArgs(**conf)
model = Transformer(config)