In [17]:
import datasets
from tokenization import TRIETokenizerFast
from matplotlib import pyplot as plt
import json
from tqdm.notebook import tqdm
from dataloader import DatasetWriter, SingleDatasetReader
from typing import *
import numpy as np
from torch.utils.data import DataLoader
from functools import partial

In [2]:
# Module-level tokenizer; chunk_texts reads it as a global when encoding documents.
# The vocab file name suggests a pruned ~32k LLaMA vocabulary (not verified here).
# serialize_dataset stores token ids as np.uint16, so every id this tokenizer emits
# must be < 65536.
tokenizer = TRIETokenizerFast('llama_vocab_pruned_32k.json')

In [13]:
class DatasetColumnIter:
    """Single-pass iterator that yields ``row[col_name]`` for each row of ``dataset``.

    Rows are pulled one at a time from ``iter(dataset)`` instead of indexing the
    whole column up front. ``__len__`` forwards to the wrapped dataset so that
    consumers such as ``tqdm`` can display a total.

    The underlying iterator is created once, in ``__init__``, and ``__iter__``
    returns ``self``. A second pass over the same object therefore yields
    nothing.
    """

    def __init__(self, dataset, col_name):
        self.dataset = dataset
        self.iter = iter(dataset)
        self.col_name = col_name

    def __iter__(self):
        return self

    def __next__(self):
        # StopIteration from the wrapped iterator propagates unchanged.
        row = next(self.iter)
        value = row[self.col_name]
        return value

    def __len__(self):
        return len(self.dataset)

In [3]:
def chunk_texts(texts: Iterable[str], min_tokens: int, max_tokens: int, chunk_size: int, return_attn_mask: bool,
                prefix: str = '', postfix: str = '',
                drop_unaligned: bool = False):
    """Tokenize documents and pack them back-to-back into fixed-size token chunks.

    Each text is encoded as ``prefix + text + postfix`` with the module-level
    ``tokenizer``. Documents whose token count falls outside the inclusive range
    ``[min_tokens, max_tokens]`` are skipped entirely; they are never truncated.
    The surviving token streams are concatenated and cut into chunks of exactly
    ``chunk_size`` tokens. A document that crosses a chunk boundary is split and
    continues at the start of the next chunk.

    The mask assigns a segment id to every token. Ids restart at 1 in each chunk
    and increase by one for each document piece placed into that chunk, so
    tokens that share an id come from the same document.

    Args:
        texts: Iterable of raw document strings (wrapped in ``tqdm`` for progress).
        min_tokens: Minimum encoded length (inclusive) for a document to be kept.
        max_tokens: Maximum encoded length (inclusive) for a document to be kept.
        chunk_size: Number of tokens per emitted chunk; must be >= 1.
        return_attn_mask: If True, yield ``(chunk, mask)`` tuples; otherwise
            yield just ``chunk``.
        prefix: String prepended to each text before encoding.
        postfix: String appended to each text before encoding.
        drop_unaligned: If True, discard the trailing chunk when it holds fewer
            than ``chunk_size`` tokens. Every earlier chunk is always full.

    Yields:
        ``list[int]`` token chunks, or ``(chunk, mask)`` pairs of equal length.

    Raises:
        ValueError: If ``chunk_size < 1``. Such a value could never make
            progress through a document. Because this is a generator, the
            error surfaces on the first ``next()``.
    """
    if chunk_size < 1:
        raise ValueError(f'chunk_size must be >= 1, got {chunk_size}')

    chunk, mask, mask_index = [], [], 1
    for text in tqdm(texts):
        encoded = tokenizer.encode(prefix + text + postfix)
        if len(encoded) < min_tokens or len(encoded) > max_tokens:
            continue
        encoded_cursor = 0
        while encoded_cursor < len(encoded):
            # Flush lazily: a full chunk is emitted only when more tokens need a
            # home, so the last (possibly partial) chunk is handled after the loop.
            if len(chunk) == chunk_size:
                yield (chunk, mask) if return_attn_mask else chunk
                # Rebind (do not clear) so the yielded lists are never mutated.
                chunk, mask, mask_index = [], [], 1
            chunk_append_size = min(chunk_size - len(chunk), len(encoded) - encoded_cursor)
            chunk += encoded[encoded_cursor:encoded_cursor + chunk_append_size]
            mask += [mask_index] * chunk_append_size
            mask_index += 1
            encoded_cursor += chunk_append_size
    if chunk and (not drop_unaligned or len(chunk) == chunk_size):
        yield (chunk, mask) if return_attn_mask else chunk


def serialize_dataset(file: str, texts: Iterable[str], min_tokens: int, max_tokens: int, chunk_size: int, enable_attn_mask: bool,
                      prefix: str = '', postfix: str = ''):
    """Tokenize ``texts``, pack them into ``chunk_size``-token entries and write them to ``file``.

    This is a thin driver around ``chunk_texts`` and ``DatasetWriter``.
    ``chunk_texts`` is always called with ``drop_unaligned=True``, so a trailing
    partial chunk is discarded and every stored entry holds exactly
    ``chunk_size`` tokens.

    Args:
        file: Output path handed to ``DatasetWriter``.
        texts: Raw documents; any iterable of strings (e.g. a list or a
            ``DatasetColumnIter``).
        min_tokens: Documents encoding to fewer tokens are skipped.
        max_tokens: Documents encoding to more tokens are skipped.
        chunk_size: Tokens per stored entry.
        enable_attn_mask: If True, also store an ``attn_mask`` column holding
            the per-chunk document segment ids produced by ``chunk_texts``.
        prefix: String prepended to each document before tokenizing.
        postfix: String appended to each document before tokenizing.
    """
    # Every column is stored as uint16: token ids and segment ids must fit in [0, 65535].
    column_dtypes = {'token_ids': np.uint16}
    if enable_attn_mask:
        column_dtypes['attn_mask'] = np.uint16
    writer = DatasetWriter(file, column_dtypes)

    packed = chunk_texts(texts, min_tokens, max_tokens, chunk_size, enable_attn_mask,
                         prefix, postfix, drop_unaligned=True)
    for item in packed:
        if not enable_attn_mask:
            writer.add_entry(token_ids=np.array(item, dtype=np.uint16))
            continue
        ids, segment_ids = item
        writer.add_entry(token_ids=np.array(ids, dtype=np.uint16),
                         attn_mask=np.array(segment_ids, dtype=np.uint16))
    writer.finish()

In [None]:
# MiniPile corpus from the Hugging Face Hub, cached locally under ./corpus.
mini_pile = datasets.load_dataset(
    path='JeanKaddour/minipile',
    cache_dir='./corpus',
)

In [None]:
# MiniPile validation split -> 1024-token chunks with document segment mask.
serialize_dataset(
    'datasets/minipile_valid_masked_1024.bin',
    mini_pile['validation']['text'],
    min_tokens=128,
    max_tokens=2048 * 8,
    chunk_size=1024,
    enable_attn_mask=True,
    prefix='<s>',
    postfix='</s>',
)

In [None]:
# MiniPile train split -> 1024-token chunks with document segment mask.
serialize_dataset(
    'datasets/minipile_train_masked_1024.bin',
    mini_pile['train']['text'],
    min_tokens=128,
    max_tokens=2048 * 8,
    chunk_size=1024,
    enable_attn_mask=True,
    prefix='<s>',
    postfix='</s>',
)

In [None]:
# with open('corpus/TinyStoriesV2-GPT4-train.txt', 'r') as temp_file:
#     tinystories_train = [l.strip() for l in temp_file.read().split('<|endoftext|>')][:-1]
# serialize_dataset('datasets/tinystories_train_masked.bin', tinystories_train, min_tokens=128, max_tokens=2048 * 8, chunk_size=2048, enable_attn_mask=True,
#                   prefix='<s>', postfix='</s>')

In [None]:
# with open('corpus/TinyStoriesV2-GPT4-valid.txt', 'r') as temp_file:
#     tinystories_valid = [l.strip() for l in temp_file.read().split('<|endoftext|>')][:-1]
# serialize_dataset('datasets/tinystories_valid_masked.bin', tinystories_valid, min_tokens=128, max_tokens=2048 * 8, chunk_size=2048, enable_attn_mask=True,
#                   prefix='<s>', postfix='</s>')

In [None]:
# English Wikipedia sample (train split only), cached locally under ./corpus.
enwiki = datasets.load_dataset(
    path='teven/enwiki_100k',
    cache_dir='./corpus',
)['train']

In [None]:
def enwiki_filter(row, min_length):
    """Return True if ``row['text']`` should be kept.

    A row is rejected when its text has fewer than ``min_length`` characters
    (checked first) or when the text ends with ':'. The colon rule presumably
    drops headers or stubs that only introduce a list or table (NOTE(review):
    confirm intent).
    """
    text = row['text']
    # `and` short-circuits, so endswith is only consulted for long-enough texts.
    return len(text) >= min_length and not text.endswith(':')


# Keep articles with at least 128 characters that do not end in ':'.
# NOTE: this rebinds `enwiki`, so re-running the cell filters the already-filtered
# split again. The predicate is the same, so the rows kept do not change.
enwiki = enwiki.filter(partial(enwiki_filter, min_length=128))

In [None]:
# Filtered enwiki -> 1024-token chunks with document segment mask.
serialize_dataset(
    'datasets/enwiki_train_masked_1024.bin',
    enwiki['text'],
    min_tokens=128,
    max_tokens=2048 * 8,
    chunk_size=1024,
    enable_attn_mask=True,
    prefix='<s>',
    postfix='</s>',
)

In [None]:
# tiny-textbooks corpus (train split only), cached locally under ./corpus.
tiny_textbooks = datasets.load_dataset(
    path='nampdn-ai/tiny-textbooks',
    cache_dir='./corpus',
)['train']

In [None]:
# tiny-textbooks ('textbook' column) -> 1024-token chunks with document segment mask.
serialize_dataset(
    'datasets/tinytextbooks_train_masked_1024.bin',
    tiny_textbooks['textbook'],
    min_tokens=128,
    max_tokens=2048 * 8,
    chunk_size=1024,
    enable_attn_mask=True,
    prefix='<s>',
    postfix='</s>',
)

In [5]:
# SlimPajama-6B sample (all splits), cached locally under ./corpus.
slimpajama = datasets.load_dataset(
    path='DKYoon/SlimPajama-6B',
    cache_dir='./corpus',
)

Resolving data files:   0%|          | 0/48 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/48 [00:00<?, ?it/s]

In [15]:
# SlimPajama validation split, streamed row by row through DatasetColumnIter
# instead of indexing the whole 'text' column up front.
serialize_dataset(
    'datasets/slimpajama_valid_masked_1024.bin',
    DatasetColumnIter(slimpajama['validation'], 'text'),
    min_tokens=128,
    max_tokens=2048 * 8,
    chunk_size=1024,
    enable_attn_mask=True,
    prefix='<s>',
    postfix='</s>',
)

  0%|          | 0/9347 [00:00<?, ?it/s]

In [26]:
# SlimPajama train split, streamed row by row through DatasetColumnIter
# instead of indexing the whole 'text' column up front.
serialize_dataset(
    'datasets/slimpajama_train_masked_1024.bin',
    DatasetColumnIter(slimpajama['train'], 'text'),
    min_tokens=128,
    max_tokens=2048 * 8,
    chunk_size=1024,
    enable_attn_mask=True,
    prefix='<s>',
    postfix='</s>',
)

  0%|          | 0/5489000 [00:00<?, ?it/s]