In [1]:
import datasets
from tokenizers import TRIETokenizerFast
from matplotlib import pyplot as plt
import json
from tqdm.notebook import tqdm
from dataloader import DatasetWriter, DatasetReader
from typing import *
import numpy as np
from torch.utils.data import DataLoader

In [2]:
# Vocabulary file for the trie tokenizer; `chunk_texts` encodes every document through this instance.
# NOTE(review): serialize_dataset stores token ids as uint16, so this vocab presumably must stay
# below 65536 entries (the file name suggests ~32k) — confirm when swapping vocabularies.
TOKENIZER_VOCAB_PATH = 'llama_vocab_pruned_32k.json'
tokenizer = TRIETokenizerFast(TOKENIZER_VOCAB_PATH)

In [3]:
def chunk_texts(texts: Iterable[str], min_tokens: int, max_tokens: int, chunk_size: int, return_attn_mask: bool = False,
                prefix: str = '', postfix: str = '',
                drop_unaligned: bool = False):
    """Tokenize ``texts`` and pack the token streams into chunks of ``chunk_size`` tokens.

    Each text is wrapped as ``prefix + text + postfix`` and encoded with the
    module-level ``tokenizer``. Texts whose token count lies outside
    ``[min_tokens, max_tokens]`` (inclusive) are skipped. The remaining token
    streams are concatenated back to back and cut into chunks of exactly
    ``chunk_size`` tokens, so one document may be split across consecutive chunks.

    Args:
        texts: Documents to tokenize; iterated once, wrapped in a tqdm progress bar.
        min_tokens: Smallest accepted encoded length (inclusive).
        max_tokens: Largest accepted encoded length (inclusive).
        chunk_size: Number of tokens per emitted chunk; must be positive.
        return_attn_mask: If true, yield ``(token_ids, mask)`` pairs instead of bare token lists.
        prefix: String prepended to every text before encoding.
        postfix: String appended to every text before encoding.
        drop_unaligned: If true, the trailing chunk is dropped unless it is exactly ``chunk_size`` long.

    Yields:
        ``token_ids`` lists, or ``(token_ids, mask)`` tuples of equal-length lists.
        In ``mask`` every contiguous document segment inside a chunk gets its own
        id, counting up from 1 and restarting at 1 in each new chunk (a document
        continued from the previous chunk therefore starts the new chunk with id 1).

    Raises:
        ValueError: if ``chunk_size`` is not positive (raised on the first ``next``,
            since this is a generator); previously this looped forever.
    """
    if chunk_size <= 0:
        raise ValueError(f'chunk_size must be positive, got {chunk_size}')

    chunk, mask, mask_index = [], [], 1
    for text in tqdm(texts):
        encoded = tokenizer.encode(prefix + text + postfix)
        if not min_tokens <= len(encoded) <= max_tokens:
            continue
        encoded_cursor = 0
        while encoded_cursor < len(encoded):
            # A full chunk is only emitted once more tokens need room, so the final
            # full chunk is handled by the tail check after the loop.
            if len(chunk) == chunk_size:
                yield (chunk, mask) if return_attn_mask else chunk
                chunk, mask, mask_index = [], [], 1
            chunk_append_size = min(chunk_size - len(chunk), len(encoded) - encoded_cursor)
            chunk += encoded[encoded_cursor:encoded_cursor + chunk_append_size]
            mask += [mask_index] * chunk_append_size
            mask_index += 1
            encoded_cursor += chunk_append_size
    if len(chunk) > 0 and (not drop_unaligned or len(chunk) == chunk_size):
        yield (chunk, mask) if return_attn_mask else chunk


def _as_uint16(values, field: str) -> np.ndarray:
    """Convert ``values`` to a uint16 array, raising instead of silently wrapping out-of-range values."""
    # Go through int64 first: on NumPy < 2, np.array(..., dtype=np.uint16) wraps large ints modulo 2**16.
    array = np.asarray(values, dtype=np.int64)
    if array.size and (array.min() < 0 or array.max() > np.iinfo(np.uint16).max):
        raise ValueError(f'{field} contains values outside the uint16 range '
                         f'[{int(array.min())}, {int(array.max())}]')
    return array.astype(np.uint16)


def serialize_dataset(file: str, texts: Iterable[str], min_tokens: int, max_tokens: int, chunk_size: int, return_attn_mask: bool = False,
                      prefix: str = '', postfix: str = ''):
    """Tokenize and chunk ``texts`` and write the chunks to ``file`` via ``DatasetWriter``.

    Chunks come from ``chunk_texts(..., drop_unaligned=True)``, so every stored
    entry holds exactly ``chunk_size`` tokens and a trailing partial chunk is discarded.

    Args:
        file: Output path passed to ``DatasetWriter``.
        texts: Documents to serialize.
        min_tokens: Smallest accepted encoded document length (inclusive).
        max_tokens: Largest accepted encoded document length (inclusive).
        chunk_size: Tokens per stored entry.
        return_attn_mask: If true, also store an ``attn_mask`` field holding per-segment ids.
            Defaults to ``False`` so calls that omit it (e.g. the MiniPile cells) work.
        prefix: String prepended to every document before tokenization.
        postfix: String appended to every document before tokenization.

    Raises:
        ValueError: if a token id or mask value does not fit in uint16.
    """
    fields = {'token_ids': np.uint16}
    if return_attn_mask:
        fields['attn_mask'] = np.uint16
    writer = DatasetWriter(file, fields)
    # NOTE(review): if chunking raises partway (including KeyboardInterrupt), finish() is never
    # called; whether a partially written file is usable depends on DatasetWriter — confirm.
    for entry in chunk_texts(texts, min_tokens, max_tokens, chunk_size, return_attn_mask, prefix, postfix, drop_unaligned=True):
        if return_attn_mask:
            token_ids, attn_mask = entry
            writer.add_entry(token_ids=_as_uint16(token_ids, 'token_ids'),
                             attn_mask=_as_uint16(attn_mask, 'attn_mask'))
        else:
            writer.add_entry(token_ids=_as_uint16(entry, 'token_ids'))
    writer.finish()

In [4]:
# mini_pile = datasets.load_dataset('JeanKaddour/minipile', cache_dir='./corpus')

In [5]:
# serialize_dataset('datasets/minipile_valid.bin', mini_pile['validation']['text'], min_tokens=128, max_tokens=2048 * 8, chunk_size=2048,
#                   prefix='<s>', postfix='</s>')

In [6]:
# serialize_dataset('datasets/minipile_train.bin', mini_pile['train']['text'], min_tokens=128, max_tokens=2048 * 8, chunk_size=2048,
#                   prefix='<s>', postfix='</s>')

In [7]:
# Load the TinyStories V2 (GPT-4) training split; stories are separated by '<|endoftext|>'.
# The encoding is pinned so the result does not depend on the platform's locale default
# (assumes the corpus is UTF-8).
with open('corpus/TinyStoriesV2-GPT4-train.txt', 'r', encoding='utf-8') as temp_file:
    tinystories_train = [
        story
        for story in (raw.strip() for raw in temp_file.read().split('<|endoftext|>'))
        # Filtering empties drops the blank tail after the last delimiter without assuming the
        # file ends with one (a blind [:-1] would discard a real story in that case).
        if story
    ]

In [8]:
# Pack TinyStories into uint16 chunks with per-document segment ids (attn_mask).
# NOTE(review): chunk_size=2049 is presumably 2048 + 1 so that a one-token input/target
# shift still leaves 2048 positions per example — confirm.
serialize_dataset(
    'datasets/tinystories_train_masked.bin',
    tinystories_train,
    chunk_size=2049,
    min_tokens=128,
    max_tokens=2048 * 8,
    prefix='<s>',
    postfix='</s>',
    return_attn_mask=True,
)

  0%|          | 0/2717699 [00:00<?, ?it/s]

In [None]:
# Open a serialized dataset for the debug cells below.
# NOTE(review): this path is not written anywhere in this notebook (the cell above writes
# 'datasets/tinystories_train_masked.bin'), so Restart & Run All relies on the file already
# existing on disk — confirm which file is intended.
reader = DatasetReader('datasets/debug_data_masked.bin')

In [None]:
import torch

In [None]:
def attn_mask_to_seq_indices(attn_mask: torch.Tensor):
    """Return the flattened exclusive end offset of every contiguous segment in ``attn_mask``.

    ``attn_mask`` is a 2-D tensor (rows x positions) of segment ids, such as the
    masks produced by ``chunk_texts``. A segment is a maximal run of equal ids
    within a row, and every row starts a new segment. Rows are treated as
    concatenated end to end. The result lists each segment's exclusive end index
    in that flattened coordinate system, so its last element is always
    ``rows * positions``.

    Args:
        attn_mask: 2-D tensor of segment ids.

    Returns:
        int64 tensor of shape ``(num_segments, 1)`` on ``attn_mask.device``.
        NOTE(review): it has no leading 0 and is 2-D, unlike the ``cu_seqlens`` layout
        some varlen attention kernels expect; callers presumably adapt — confirm.
    """
    batch_size, row_len = attn_mask.shape
    seq_len = batch_size * row_len
    # Create helper tensors on the input's device; CPU-only literals made torch.cat fail for GPU masks.
    row_starts = torch.ones((batch_size, 1), dtype=torch.bool, device=attn_mask.device)
    segment_changes = attn_mask[:, 1:] != attn_mask[:, :-1]
    disjoint_point = torch.cat([row_starts, segment_changes], dim=1)
    segment_starts = torch.nonzero(disjoint_point.reshape(-1))
    # Each segment ends where the next one starts: drop the first start (always 0) and close with seq_len.
    total = torch.tensor([[seq_len]], dtype=segment_starts.dtype, device=attn_mask.device)
    return torch.cat([segment_starts[1:], total])


# Pull a small batch from the reader and derive segment boundaries from its attention masks.
# The arrays are widened to int32 before torch.tensor. NOTE(review): they are presumably uint16
# (as written by serialize_dataset), and widening works around limited torch uint16 support — confirm.
NUM_DEBUG_SAMPLES = 16

iterator = iter(reader)
token_id_rows = []
masks = []
for i in range(NUM_DEBUG_SAMPLES):
    try:
        t = next(iterator)
    except StopIteration:
        # A bare StopIteration at cell level gives no hint about what ran out.
        raise ValueError(f'debug dataset has only {i} entries; expected {NUM_DEBUG_SAMPLES}') from None
    masks.append(t['attn_mask'])
    token_id_rows.append(t['token_ids'])
ids = torch.tensor(np.stack(token_id_rows, dtype=np.int32))

s = attn_mask_to_seq_indices(torch.tensor(np.stack(masks, dtype=np.int32)))