In [1]:
import datasets
from tokenizers import TRIETokenizerFast
from matplotlib import pyplot as plt
import json
from tqdm.notebook import tqdm
from dataloader import DatasetWriter, SingleDatasetReader
from typing import *
import numpy as np
from torch.utils.data import DataLoader
from functools import partial

In [2]:
# Tokenizer used as a global by the cells below: `dialogues_to_chunks` encodes with it,
# `preview_chunk` decodes with it. Vocabulary is loaded from 'llama_vocab_pruned_32k.json'.
tokenizer = TRIETokenizerFast('llama_vocab_pruned_32k.json')

In [4]:
# Alpaca-GPT4 instruction data, cached under ./corpus.
alpaca_gpt = datasets.load_dataset(path='vicgalle/alpaca-gpt4', cache_dir='./corpus')

In [8]:
# Airoboros 2.2.1 instruction/response data, cached under ./corpus.
airoboros = datasets.load_dataset(path='jondurbin/airoboros-2.2.1', cache_dir='./corpus')

In [11]:
# WizardLM evol-instruct V2 multi-turn conversations, cached under ./corpus.
wizardlm = datasets.load_dataset(path='WizardLM/WizardLM_evol_instruct_V2_196k', cache_dir='./corpus')

Repo card metadata block was not found. Setting CardData to empty.


In [142]:
def preview_chunk(token_ids, attn_mask, loss_mask):
    """Pretty-print every dialogue packed into one chunk.

    Each distinct non-zero value in ``attn_mask`` identifies one dialogue
    (0 is skipped). For every dialogue the tokens at its positions are decoded
    twice: all of them under "Full text:", and only those whose ``loss_mask``
    entry equals 1 under "Loss text:".

    Args:
        token_ids: token ids of the chunk.
        attn_mask: per-position dialogue index, indexed in parallel with ``token_ids``.
        loss_mask: per-position loss flag, indexed in parallel with ``token_ids``.
    """
    dialogue_indices = [idx for idx in set(attn_mask) if idx != 0]
    print(f'Total {len(dialogue_indices)} dialogues within chunk.')
    for dialogue_index in dialogue_indices:
        print(f'Dialogue index {dialogue_index}')
        full_ids, trained_ids = [], []
        for pos, token in enumerate(token_ids):
            if attn_mask[pos] != dialogue_index:
                continue
            full_ids.append(token)
            if loss_mask[pos] == 1:
                trained_ids.append(token)
        print('Full text:')
        print(tokenizer.decode(full_ids))
        print('-' * 80)
        print('Loss text:')
        print(tokenizer.decode(trained_ids))
        print('=' * 80)

In [111]:
def dialogues_to_chunks(dialogues: List[List[Tuple[str, str]]], chunk_length: int, max_message_length: int, overlap_count: int):
    """Pack dialogues into fixed-length chunks for supervised fine-tuning.

    Every dialogue is rendered as a shared header (``start_tokens``) followed by
    one encoded message per turn: ``'User:<text>\\n'`` or
    ``'Assistant:<text></s>\\n'``. Dialogues are packed back to back; a chunk
    that has no room left for another header plus content is right-padded
    with 0 and emitted.

    Yields:
        ``(token_ids, attn_mask, loss_mask)`` -- three Python lists, each exactly
        ``chunk_length`` long:

        * ``attn_mask`` -- index of the dialogue a token belongs to, numbered
          from 1 within each chunk; 0 for padding.
        * ``loss_mask`` -- 1 on Assistant tokens emitted for the first time,
          0 elsewhere (header, User turns, replayed context, padding).

    When a dialogue overflows the current chunk, the chunk is emitted and the
    dialogue continues in a new chunk (without the header) that first replays
    up to ``overlap_count`` preceding messages as loss-free context. If the cut
    fell inside a message, that message is replayed from its start: the prefix
    already emitted gets loss 0, the not-yet-emitted tail gets loss 1.

    Dialogues that are empty, or contain a message longer than
    ``max_message_length`` tokens, are skipped. The skip count is printed once
    the generator is exhausted.

    Args:
        dialogues: list of dialogues, each a list of ``(role, text)`` pairs with
            role ``'User'`` or ``'Assistant'``.
        chunk_length: number of tokens per emitted chunk.
        max_message_length: maximum encoded length (tokens) of any message.
        overlap_count: number of preceding messages replayed at the start of
            a continuation chunk.

    Raises:
        AssertionError: if ``max_message_length * (overlap_count + 1) >= chunk_length``.
            A continuation chunk must have room for the replayed messages plus
            the message being continued; otherwise the same cut can repeat
            forever.
        KeyError: if a role is neither ``'User'`` nor ``'Assistant'``.
    """
    # The replayed context (<= overlap_count messages) plus the continued
    # message must fit strictly inside one chunk, or the loop may never advance.
    assert max_message_length * (overlap_count + 1) < chunk_length, \
        'max_message_length * (overlap_count + 1) >= chunk_length can cause infinite loop'

    role_suffix = {'User': '\n', 'Assistant': '</s>\n'}
    start_tokens = tokenizer.encode('<s>A chat between User and Assistant.\n')
    skip_dialogue_count = 0

    def padded(token_ids, attn_mask, loss_mask):
        # Return new lists right-padded with 0 up to chunk_length.
        pad = [0] * (chunk_length - len(token_ids))
        chunk = (token_ids + pad, attn_mask + pad, loss_mask + pad)
        assert len(chunk[0]) == len(chunk[1]) == len(chunk[2]) == chunk_length
        return chunk

    mask_index = 0
    token_ids, attn_mask, loss_mask = [], [], []

    for dial in tqdm(dialogues):
        dial_encoded = [(role, tokenizer.encode(f'{role}:{text}' + role_suffix[role])) for role, text in dial]
        if not dial_encoded or any(len(msg) > max_message_length for _, msg in dial_encoded):
            skip_dialogue_count += 1
            continue

        mask_index += 1
        if chunk_length - len(token_ids) <= len(start_tokens):
            # No room for the header plus any message: flush, then open a new
            # chunk that starts with this dialogue's header.
            yield padded(token_ids, attn_mask, loss_mask)
            mask_index = 1
            token_ids = start_tokens.copy()
            attn_mask = [mask_index] * len(start_tokens)
            loss_mask = [0] * len(start_tokens)
        else:
            token_ids += start_tokens
            attn_mask += [mask_index] * len(start_tokens)
            loss_mask += [0] * len(start_tokens)

        # emitted[i]: number of leading tokens of message i already written to some chunk.
        emitted = [0] * len(dial_encoded)
        msg_index = 0
        while msg_index < len(dial_encoded):
            role, msg = dial_encoded[msg_index]
            append_length = min(chunk_length - len(token_ids), len(msg))
            token_ids += msg[:append_length]
            attn_mask += [mask_index] * append_length
            if role == 'User':
                loss_mask += [0] * append_length
            else:
                # Tokens already emitted in an earlier chunk are context only;
                # the rest of the message (e.g. the tail of a cut reply) is trained on.
                already = min(emitted[msg_index], append_length)
                loss_mask += [0] * already + [1] * (append_length - already)
            emitted[msg_index] = max(emitted[msg_index], append_length)
            if append_length == len(msg):
                msg_index += 1

            if len(token_ids) == chunk_length:
                yield padded(token_ids, attn_mask, loss_mask)
                token_ids, attn_mask, loss_mask = [], [], []
                if msg_index < len(dial_encoded):
                    # Continue this dialogue in the next chunk, preceded by up
                    # to overlap_count already-seen messages as context.
                    mask_index = 1
                    msg_index -= min(overlap_count, msg_index)
                else:
                    # Dialogue finished exactly at the chunk boundary: nothing to
                    # replay; the next dialogue gets index 1 in the fresh chunk.
                    mask_index = 0

    if len(token_ids) > len(start_tokens):
        yield padded(token_ids, attn_mask, loss_mask)

    print(f'Skipped {skip_dialogue_count}/{len(dialogues)} dialogues.')

In [146]:
def write_out_dataset(file, entries):
    """Serialize packed chunks to ``file`` through ``DatasetWriter``.

    Each entry is a ``(token_ids, attn_mask, loss_mask)`` triple. Every sequence
    is converted to an ``np.uint16`` array and stored under the field of the
    same name; ``writer.finish()`` is called after the last entry.

    NOTE(review): values are assumed to fit in uint16 (token ids from a ~32k
    vocabulary, small mask values) -- confirm if the tokenizer changes.

    Args:
        file: output path handed to ``DatasetWriter``.
        entries: iterable of ``(token_ids, attn_mask, loss_mask)`` triples.
    """
    dtype = np.uint16
    writer = DatasetWriter(file, {'token_ids': dtype, 'attn_mask': dtype, 'loss_mask': dtype})
    for entry in tqdm(entries):
        token_ids, attn_mask, loss_mask = entry
        writer.add_entry(
            token_ids=np.array(token_ids, dtype=dtype),
            attn_mask=np.array(attn_mask, dtype=dtype),
            loss_mask=np.array(loss_mask, dtype=dtype),
        )
    writer.finish()

In [112]:
# Shared packing configuration: 1024-token chunks, messages of at most 450
# tokens, one message of context replayed when a dialogue spills over.
dialogues_to_chunks_1024 = partial(
    dialogues_to_chunks,
    chunk_length=1024,
    max_message_length=450,
    overlap_count=1,
)

In [135]:
def alpaca_to_dialogue(alpaca_sample):
    """Convert one Alpaca-format record into a two-turn (User, Assistant) dialogue.

    A non-empty ``input`` field is appended to the instruction on a new line.
    """
    prompt = alpaca_sample['instruction']
    extra_input = alpaca_sample['input']
    if extra_input != '':
        prompt = prompt + f'\n{extra_input}'
    return [('User', prompt), ('Assistant', alpaca_sample['output'])]


# One (User, Assistant) dialogue per Alpaca-GPT4 training record.
alpaca_diags = list(map(alpaca_to_dialogue, alpaca_gpt['train']))

In [136]:
# Materialize every packed chunk (the packer is a generator).
alpaca_chunks = list(dialogues_to_chunks_1024(alpaca_diags))

  0%|          | 0/52002 [00:00<?, ?it/s]

Skipped 2888 dialogues.


In [147]:
# Serialize the packed Alpaca-GPT4 chunks.
write_out_dataset(file='datasets/sft/alpaca_gpt4.bin', entries=alpaca_chunks)

  0%|          | 0/9956 [00:00<?, ?it/s]

In [155]:
# Single-turn dialogues from Airoboros, excluding records whose category
# contains 'contextual'.
airoboros_diags = [
    [('User', d['instruction']), ('Assistant', d['response'])]
    for d in airoboros['train']
    if 'contextual' not in d['category']
]

In [158]:
# Materialize every packed chunk (the packer is a generator).
airoboros_chunks = list(dialogues_to_chunks_1024(airoboros_diags))

  0%|          | 0/40331 [00:00<?, ?it/s]

Skipped 11984 dialogues.


In [160]:
# Serialize the packed Airoboros chunks.
write_out_dataset(file='datasets/sft/airoboros_2.2.1.bin', entries=airoboros_chunks)

  0%|          | 0/7447 [00:00<?, ?it/s]

In [113]:
# Turns whose 'from' is 'human' become 'User'; every other speaker is mapped to 'Assistant'.
wizardlm_diags = [
    [('User' if turn['from'] == 'human' else 'Assistant', turn['value']) for turn in conversation]
    for conversation in wizardlm['train']['conversations']
]

In [117]:
# Materialize every packed chunk (the packer is a generator).
wizardlm_chunks = list(dialogues_to_chunks_1024(wizardlm_diags))

  0%|          | 0/143000 [00:00<?, ?it/s]

Skipped 61171 dialogues.


In [151]:
# Serialize the packed WizardLM chunks.
write_out_dataset(file='datasets/sft/wizardlm_evol_2.bin', entries=wizardlm_chunks)

  0%|          | 0/35655 [00:00<?, ?it/s]