In [None]:
import datasets
from tokenizers import TRIETokenizerFast
from matplotlib import pyplot as plt
import json
from tqdm.notebook import tqdm
from dataloader import DatasetWriter, DatasetReader
from typing import *
import numpy as np
from torch.utils.data import DataLoader

In [None]:
# Module-level tokenizer, read as a global by chunk_texts() below.
# NOTE(review): judging by the file name this is presumably a pruned ~32k-entry
# LLaMA vocabulary -- confirm; serialize_dataset() stores ids as uint16.
tokenizer = TRIETokenizerFast('llama_vocab_pruned_32k.json')

In [None]:
# Load the 'JeanKaddour/minipile' dataset, using ./corpus as the cache directory.
# The result is indexed by split name and then by column further down
# (mini_pile['validation']['text']).
mini_pile = datasets.load_dataset('JeanKaddour/minipile', cache_dir='./corpus')

In [None]:
def chunk_texts(texts: Iterable[str], min_tokens: int, max_tokens: int, chunk_size: int,
                prefix: str = '', postfix: str = '',
                drop_unaligned: bool = False):
    """Pack tokenised texts into consecutive chunks of ``chunk_size`` tokens.

    Each text is wrapped as ``prefix + text + postfix`` and encoded with the
    module-level ``tokenizer``. Texts whose encoded length (wrapping included)
    is below ``min_tokens`` or above ``max_tokens`` are skipped entirely. The
    tokens of the remaining texts are concatenated in input order and the
    resulting stream is cut every ``chunk_size`` tokens, so a single chunk may
    contain the tail of one text followed by the head of the next.

    Args:
        texts: Raw texts to encode; iterated once, wrapped in a tqdm progress bar.
        min_tokens: Minimum encoded length for a text to be kept (inclusive).
        max_tokens: Maximum encoded length for a text to be kept (inclusive).
        chunk_size: Number of tokens per yielded chunk; must be positive.
        prefix: String prepended to every text before encoding.
        postfix: String appended to every text before encoding.
        drop_unaligned: If True, a trailing chunk shorter than ``chunk_size``
            is discarded instead of being yielded.

    Yields:
        Lists of tokens. Every chunk holds exactly ``chunk_size`` tokens except
        possibly the last one, which is only yielded when ``drop_unaligned``
        is False.

    Raises:
        ValueError: If ``chunk_size`` is not positive. Because this is a
            generator function, the error surfaces when the generator is first
            advanced, not when ``chunk_texts`` is called.
    """
    # A zero chunk size would never advance the cursor (endless stream of empty
    # chunks) and a negative one would produce negative slice lengths.
    if chunk_size <= 0:
        raise ValueError(f'chunk_size must be a positive integer, got {chunk_size!r}')

    chunk = []
    for text in tqdm(texts):
        encoded = tokenizer.encode(prefix + text + postfix)
        if len(encoded) < min_tokens or len(encoded) > max_tokens:
            continue
        encoded_cursor = 0
        while encoded_cursor < len(encoded):
            # `chunk` is always shorter than chunk_size here (full chunks are
            # flushed immediately below), so at least one token is taken.
            take = min(chunk_size - len(chunk), len(encoded) - encoded_cursor)
            chunk += encoded[encoded_cursor:encoded_cursor + take]
            encoded_cursor += take
            if len(chunk) == chunk_size:
                yield chunk
                # Rebind instead of clearing: the consumer may still hold the
                # list that was just yielded.
                chunk = []
    # Whatever is left is a partial chunk (full ones were already emitted).
    if chunk and not drop_unaligned:
        yield chunk


def serialize_dataset(file: str, texts: Iterable[str], min_tokens: int, max_tokens: int, chunk_size: int,
                      prefix: str = '', postfix: str = ''):
    """Tokenise ``texts`` into full-size chunks and write them to ``file``.

    The texts are chunked by ``chunk_texts`` with ``drop_unaligned=True``, so
    only chunks of exactly ``chunk_size`` tokens are written. Each chunk is
    converted to a ``np.uint16`` array and added to a ``DatasetWriter`` under
    the ``'token_ids'`` field; the writer is finished once all chunks are added.

    Args:
        file: Path handed to ``DatasetWriter``.
        texts: Raw texts to encode and pack.
        min_tokens: Texts encoding to fewer tokens than this are skipped.
        max_tokens: Texts encoding to more tokens than this are skipped.
        chunk_size: Number of tokens per written entry.
        prefix: String prepended to every text before encoding.
        postfix: String appended to every text before encoding.
    """
    token_dtype = np.uint16  # NOTE(review): assumes every token id fits in 16 bits -- confirm
    dataset_writer = DatasetWriter(file, {'token_ids': token_dtype})
    full_chunks = chunk_texts(
        texts, min_tokens, max_tokens, chunk_size,
        prefix, postfix,
        drop_unaligned=True,
    )
    for token_chunk in full_chunks:
        entry_array = np.array(token_chunk, dtype=token_dtype)
        dataset_writer.add_entry(token_ids=entry_array)
    dataset_writer.finish()


# Serialise the MiniPile validation texts to datasets/minipile_validation.bin.
# Every document is wrapped as '<s>' + text + '</s>' before encoding; documents
# that encode to fewer than 128 or more than 2048 * 8 = 16384 tokens are skipped.
# The remaining tokens are packed into entries of exactly 2048 tokens
# (a trailing partial chunk is dropped by serialize_dataset).
serialize_dataset('datasets/minipile_validation.bin', mini_pile['validation']['text'], min_tokens=128, max_tokens=2048 * 8, chunk_size=2048,
                  prefix='<s>', postfix='</s>')