In [13]:
import torch
import torch.nn as nn
import lightning as pl

from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import functools

from datasets import load_from_disk, load_dataset

from tqdm.notebook import tqdm

In [2]:
class TinyStoriesDataloader(pl.LightningDataModule):
    def __init__(
        self, data_path_train, data_path_val, tokenizer_path, batch_size, num_workers
    ):
        super().__init__()
        self.data_path_train = data_path_train
        self.data_path_val = data_path_val

        self.batch_size = batch_size
        self.num_workers = num_workers

        self.tokenizer = self._load_tokenizer(tokenizer_path)

    def prepare_data(self):
        pass

    def _load_tokenizer(self, tokenizer_path):
        from src.tokenize.tokenizer import Tokenizer

        return Tokenizer(tokenizer_path)

    def _collate_fn(self, batch: int, padding_id: int):
        batch = pad_sequence(
            (torch.LongTensor(_["idx"]) for _ in batch),
            batch_first=True,
            padding_value=padding_id,
        )  # TODO : ShortTensor suffice our need but nn.Embedding don't support it. Using LOngTensor is a unnecessary waste of GPU memory
        x_batch = torch.stack(
            [en[:-1] for en in batch]
        )  # Extract x (remove last token)
        y_batch = torch.stack(
            [en[1:] for en in batch]
        )  # Extract y (remove first token)
        return x_batch, y_batch

    def setup(self, stage):

        self.train_data = load_from_disk(self.data_path_train)
        self.val_data = load_from_disk(self.data_path_val)

    def train_dataloader(self):
        return DataLoader(
            self.train_data,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=functools.partial(
                self._collate_fn, padding_id=self.tokenizer.eos_id()
            ),
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_data,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=functools.partial(
                self._collate_fn, padding_id=self.tokenizer.eos_id()
            ),
        )

In [3]:
BASE_URL = "/home/pranav-pc/projects/OpenTransformer/multiformer"
data_path_train = BASE_URL + "/data/interim/TinyStories_train_65>tk>512.hf"
data_path_val = BASE_URL + "/data/interim/TinyStories_val_65>tk>512.hf"
tokenizer_path = BASE_URL + "/tokenizer_checkpoints"

batch_size = 16
num_workers = 26
ds = TinyStoriesDataloader(
    data_path_train, data_path_val, tokenizer_path, batch_size, num_workers
)

In [10]:
data = load_from_disk(BASE_URL + "/data/interim/TinyStories_train_65>tk>512.hf")

In [11]:
ds.tokenizer.decode_ids(list(set(i[-1] for i in data["idx"])))

'displaying realized pocket garden leadingYes forcesoring ball hear patientatience StevenWatch collected finishlaimroom wings guard fell arrow bigger accomplished piece kick earthThere helpful ter aon orange decidedes datisreit thearle fasteral m surpriseedomioningiane talkedelas iniringet l couldn to Ichil ofam S A Although and strangeol vimadut guridig is M yveceieendodetermentation for L you bely performedte it on polarus " wh Comeirck an riverayand that vest fightumateess as withpe k circleiaantub jlaue do woods mom :( un thiske - judge The not thinillineopiz).og or was at priceaingeagene carry completedile by us z skills have can Inind from sleepable ch \' are sh trard he butwed herselfureackans playingple managedormak daughterase ifudite truthfer anywayive so smiled joined smokeies wild my we ment everywhere your seeing all battery er which par",oundishesactgratedier chanceally envportimmingll dependonearyma work earlier his useder will getousinkug noThank one upwe diver walked o

In [22]:
data_validation = load_dataset(
    BASE_URL + "/data/downloads/TinyStories",
    split="validation",
)

In [23]:
data_validation[0]

{'text': 'Spot. Spot saw the shiny car and said, "Wow, Kitty, your car is so bright and clean!" Kitty smiled and replied, "Thank you, Spot. I polish it every day."\n\nAfter playing with the car, Kitty and Spot felt thirsty. They found a small pond with clear water. They drank the water and felt very happy. They played together all day and became best friends.'}

In [81]:
def text2tokens(
    dataset,
    tokenizer,
    batch_size: int,
    batched: int,
    num_proc: int,
    text_col: str = "text",
):
    dataset = dataset.map(
        lambda x: {"text": [en.strip() for en in x[text_col]]},
        batch_size=batch_size,
        batched=batched,
        num_proc=num_proc,
    )

    dataset = dataset.map(
        lambda example: {"idx": print(tokenizer.encode(example["text"]))},
        batch_size=batch_size,
        batched=batched,
        num_proc=num_proc,
        remove_columns=["text"],
    )

    return dataset

In [86]:
tokens = text2tokens(data_validation, ds.tokenizer, int(1), True, 2, "text")

Map (num_proc=2):   0%|          | 0/21990 [00:00<?, ? examples/s]

Map (num_proc=2):   0%|          | 0/21990 [00:00<?, ? examples/s]

[[1706, 327, 29889, 1706, 327, 4446, 278, 528, 4901, 1559, 322, 1497, 29892, 376, 29956, 340, 29892, 476, 986, 29891, 29892, 596, 1559, 338, 577, 11785, 322, 5941, 3850, 476, 986, 29891, 25156, 322, 10352, 29892, 376, 25271, 366, 29892, 1706, 327, 29889, 306, 1248, 728, 372, 1432, 2462, 1213, 13, 13, 13555, 8743, 411, 278, 1559, 29892, 476, 986, 29891, 322, 1706, 327, 7091, 266, 765, 29891, 29889, 2688, 1476, 263, 2319, 282, 898, 411, 2821, 4094, 29889, 2688, 270, 10003, 278, 4094, 322, 7091, 1407, 9796, 29889, 2688, 5318, 4208, 599, 2462, 322, 3897, 1900, 7875, 29889]]
[[4335, 322, 365, 2354, 892, 1900, 7875, 29889, 2688, 23289, 304, 1708, 8090, 322, 1809, 12516, 4208, 29889, 3118, 2462, 29892, 896, 1476, 263, 4802, 9200, 6710, 297, 278, 14089, 29889, 739, 471, 528, 4901, 322, 4628, 322, 750, 263, 1472, 8014, 29889, 13, 13, 29908, 29956, 340, 29892, 1106, 472, 445, 3850, 4335, 1497, 29889, 376, 12024, 29915, 29879, 1018, 372, 3850, 13, 13, 15597, 3614, 12169, 304, 4808, 278, 9200, 671

TypeError: Provided `function` which is applied to all elements of table returns a `dict` of types [<class 'NoneType'>]. When using `batched=True`, make sure provided `function` returns a `dict` of types like `(<class 'list'>, <class 'numpy.ndarray'>, <class 'pandas.core.series.Series'>, <class 'torch.Tensor'>)`.

In [66]:
tokens[6]

{'idx': [9038,
  2501,
  263,
  931,
  29892,
  727,
  471,
  263,
  2217,
  17354,
  11203,
  4257,
  1706,
  327,
  29889,
  940,
  18012,
  304,
  1708,
  411,
  670,
  8287,
  297,
  278,
  14089,
  29889,
  3118,
  6575,
  1460,
  2462,
  29892,
  1706,
  327,
  4446,
  263,
  4802,
  7306,
  373,
  278,
  916,
  2625,
  310,
  278,
  14089,
  29889,
  940,
  5131,
  304,
  679,
  670,
  8287,
  964,
  278,
  7306,
  29889,
  13,
  13,
  5592,
  327,
  6350,
  5172,
  411,
  278,
  8287,
  297,
  670,
  13394,
  29889,
  940,
  1898,
  304,
  24817,
  278,
  8287,
  964,
  278,
  7306,
  29892,
  541,
  540,
  471,
  2086,
  2319,
  29889,
  1706,
  327,
  4687,
  304,
  21117,
  29889,
  940,
  1898,
  1449,
  322,
  1449,
  29892,
  541,
  278,
  8287,
  723,
  451,
  748,
  297,
  29889,
  13,
  13,
  11760,
  29892,
  1706,
  327,
  750,
  385,
  2969,
  29889,
  940,
  4433,
  670,
  5121,
  29892,
  263,
  4802,
  17354,
  10435,
  4257,
  7038,
  4518,
  29892,
  363,
  137

In [78]:
ds.setup("val")
val_dataloader = ds.val_dataloader()
for data in next(iter(val_dataloader)):
    print(data.shape)
    break

torch.Size([16, 262])


In [81]:
data, label = next(iter(val_dataloader))

In [82]:
data

tensor([[    1,  2296, 18691,  ...,  1568,   901,  2090],
        [    1,    13,    13,  ..., 29889,   450,  1095],
        [ 9038,  2501,   263,  ...,   347, 29915, 29879],
        ...,
        [ 9038,  2501,   263,  ...,   278,  6501,  4646],
        [ 9038,   727,   471,  ..., 29899, 23057,   287],
        [ 9038,  2501,   263,  ...,  1708,   777,   901]])

In [83]:
label

tensor([[ 2296, 18691,   701,  ...,   901,  2090,  1213],
        [   13,    13, 29931,  ...,   450,  1095, 29889],
        [ 2501,   263,   931,  ..., 29915, 29879, 29889],
        ...,
        [ 2501,   263,   931,  ...,  6501,  4646, 29889],
        [  727,   471,   263,  ..., 23057,   287, 29889],
        [ 2501,   263,   931,  ...,   777,   901, 29889]])

In [67]:
# Sanity Check

for idx, data in enumerate(ds.val_dataloader().dataset):
    print(data)
    if idx == 2:
        break

{'idx': [1, 2296, 18691, 701, 263, 12070, 322, 281, 10511, 372, 472, 278, 11460, 29889, 2296, 21272, 287, 322, 18318, 25702, 472, 372, 29889, 2296, 24936, 372, 723, 748, 3448, 29889, 13, 13, 1576, 11460, 471, 12327, 28133, 491, 365, 2354, 29915, 29879, 12070, 322, 25702, 29889, 739, 8459, 393, 4111, 322, 670, 1559, 892, 451, 7088, 278, 7458, 29889, 739, 5807, 18054, 322, 17096, 3448, 29892, 3063, 363, 385, 6775, 5807, 547, 29889, 13, 13, 20841, 7450, 278, 5970, 310, 278, 6872, 322, 298, 688, 3192, 365, 2354, 29889, 940, 471, 528, 5086, 322, 10901, 292, 29889, 940, 471, 7423, 363, 1641, 19281, 4939, 322, 17928, 728, 29889, 940, 6452, 287, 365, 2354, 363, 14238, 1075, 29889, 13, 13, 29908, 29902, 29915, 29885, 7423, 29892, 365, 2354, 29892, 366, 892, 1492, 29889, 450, 6872, 471, 451, 263, 14378, 29889, 739, 471, 451, 4023, 828, 404, 29889, 739, 471, 885, 653, 322, 4319, 29889, 306, 881, 505, 29616, 304, 366, 29889, 887, 526, 263, 1781, 5121, 29889, 3374, 366, 363, 19912, 592, 1213, 13, 1

In [33]:
ds.val_dataloader().__dict__

{'dataset': Dataset({
     features: ['idx'],
     num_rows: 12501
 }),
 'num_workers': 26,
 'prefetch_factor': 2,
 'pin_memory': True,
 'pin_memory_device': '',
 'timeout': 0,
 'worker_init_fn': None,
 '_DataLoader__multiprocessing_context': None,
 '_dataset_kind': 0,
 'batch_size': 16,
 'drop_last': False,
 'sampler': <torch.utils.data.sampler.SequentialSampler at 0x7fd4f262db90>,
 'batch_sampler': <torch.utils.data.sampler.BatchSampler at 0x7fd4f262fc90>,
 'generator': None,
 'collate_fn': functools.partial(<bound method TinyStoriesDataloader._collate_fn of <__main__.TinyStoriesDataloader object at 0x7fd4f2599850>>, padding_id=2),
 'persistent_workers': False,
 '_DataLoader__initialized': True,
 '_IterableDataset_len_called': None,
 '_iterator': None}