In [1]:
import logging
from argparse import ArgumentParser
from functools import partial
from pathlib import Path
from transformers  import AutoTokenizer
import datasets as ds
import numpy as np
from rich.progress import track

log = logging.getLogger(__name__)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def filter(example):
    return example["n_labels"] > 0

In [3]:
def preprocess(example, tokenizer, max_seq_len: int):
    input_ids = [tokenizer.eos_token_id]
    label_mask = [False]

    for msg in example["messages"]:
        role_tokens = tokenizer.encode(f"<|{msg['role']}|>\n", add_special_tokens=False)
        label_mask += [False] * len(role_tokens)
        input_ids += role_tokens

        if msg["role"] == "assistant":
            content_tokens = tokenizer.encode(
                msg["content"].strip() + tokenizer.eos_token + "\n", add_special_tokens=False
            )
            label_mask += [True] * len(content_tokens)
            # mask out the last '\n'
            assert content_tokens[-2] == tokenizer.eos_token_id
            label_mask[-1] = False
        else:
            content_tokens = tokenizer.encode(msg["content"].strip() + "\n", add_special_tokens=False)
            label_mask += [False] * len(content_tokens)
        input_ids += content_tokens

    input_ids = input_ids[:max_seq_len]
    label_mask = label_mask[:max_seq_len]

    if len(input_ids) < max_seq_len:
        pad_len = max_seq_len - len(input_ids)
        input_ids += [tokenizer.pad_token_id] * pad_len
        label_mask += [False] * pad_len

    assert len(input_ids) == len(label_mask)
    n_labels = sum(label_mask)

    return {"input_ids": input_ids, "label_mask": label_mask, "n_labels": n_labels}

In [4]:
s = 2048
eos = 50279
pad = 1
num_proc=8
seq_len=2048
output_dir = "./out"

In [5]:
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
dataset = ds.load_dataset("allenai/tulu-v2-sft-mixture", split="train")

log.info("Tokenizing dataset...")
dataset = dataset.map(
    partial(preprocess, tokenizer=tokenizer, max_seq_len=seq_len),
    batched=False,
    remove_columns=["dataset", "id", "messages"],
    num_proc=num_proc,  # type: ignore
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Map (num_proc=8):   0%|          | 0/326154 [00:00<?, ? examples/s]Token indices sequence length is longer than the specified maximum sequence length for this model (2449 > 2048). Running this sequence through the model will result in indexing errors
Map (num_proc=8):   0%|          | 348/326154 [00:00<03:29, 1554.48 examples/s]Token indices sequence length is longer than the specified maximum sequence length for this model (2328 > 2048). Running this sequence through the model will result in indexing errors
Map (num_proc=8):   0%|          | 707/326154 [00:00<02:19, 2324.89 examples/s]Token indices sequence length is longer than the specified maximum sequence length for this model (2246 > 2048). Running this sequence through the model will result in indexing errors
Map (num_proc=8):   0%|          | 1113/326154 [00:00<01:53, 2870.91 examples/s]Token indices sequence le

In [6]:
log.info("Filtering dataset...")
n = len(dataset)  # type: ignore
dataset = dataset.filter(filter, batched=False, num_proc=num_proc)  # type: ignore
log.info(f"Filtered out {n - len(dataset):,d} examples")

log.info("Counting tokens...")
total_tokens = 0
for ex in track(dataset):
    assert len(ex["input_ids"]) == seq_len  # type: ignore
    total_tokens += len(ex["input_ids"])  # type: ignore
log.info(f"Total tokens: {total_tokens:,d}")


Filter (num_proc=8): 100%|██████████| 326154/326154 [00:46<00:00, 7030.42 examples/s] 


In [7]:
log.info(f"Saving results to '{output_dir}'...")
output_dir = Path(output_dir)
output_dir.mkdir(exist_ok=True, parents=True)

input_ids_file = np.memmap(
    str(output_dir / "input_ids.npy"), dtype=np.uint16, mode="w+", shape=(total_tokens,)
)
label_mask_file = np.memmap(
    str(output_dir / "label_mask.npy"), dtype=np.bool_, mode="w+", shape=(total_tokens,)
)

In [26]:
offset = 10
for ex in track(dataset):
    ex_len = len(ex["input_ids"])  # type: ignore
    print(ex["input_ids"])
    input_ids_file[offset : offset + 200] = ex["input_ids"][0:200]  # type: ignore
    label_mask_file[offset : offset + ex_len] = ex["label_mask"]  # type: ignore
    offset += ex_len
input_ids_file.flush()
label_mask_file.flush()

log.info("Done!")

TypeError: int() argument must be a string, a bytes-like object or a real number, not 'NoneType'

In [20]:
list = [1,2,3,4]
x=2
y=1
print(list[y:x])
list2=[1,4,4]
list[0:1] = list2
print(list)

[2]
[1, 4, 4, 2, 3, 4]
