# Tokenize LibriTTS-R Mimi for target LM

For our dataset, we currently simply use the Fish Speech TTS format:
- Text-only data formatted using [ChatML](https://gist.github.com/edwardzjl/8df07c1f7140c9a3e2f48d33a8032090) as a separate sequence "above" the audio code stream
- During sections where audio is being modeled, text stream 0 predicts the first semantic token index $n$ of the 8 Mimi residual codes as special token `<|semantic:n|>`
- For audio, "semantic" (neural, there's not a strong distinction between) codes (from Mimi) padded with 0s during text sections

It's possible this tokenization strategy can be improved, e.g. in [Defossez et al. 2024](https://arxiv.org/html/2410.00037v2#S3.SS4.SSS4) with the base transformer predicting the Whisper-timestamped word timings as an "inner monologue" and a delay between codebook timesteps. lol i'll do it later

In [2]:
from dotenv import load_dotenv
from datasets import load_dataset, DatasetDict, concatenate_datasets, load_from_disk
from data_pipeline.utils.prompt import PromptEncoder, TokenizationConfig
import os

load_dotenv()
# If creating the libritts dataset for the first time
# dataset = load_from_disk("../../Kokoro-82M/libritts_r_mimi_kokoro")
dataset = load_dataset("jkeisling/project-gutenberg-kokoro-2K", token=os.getenv("HUGGINGFACE_TOKEN"))
# full_train = concatenate_datasets([dataset["train.clean.100"], dataset["train.clean.360"]])

# dataset = DatasetDict({
#     "train": full_train,
#     "val": dataset["dev.clean"],
#     "test": dataset["test.clean"]
# })
# dataset = DatasetDict({"full": dataset})
dataset = dataset.with_format("torch")
# dataset = dataset.remove_columns(["chapter_id", "text_original"])
# dataset = dataset.rename_column(original_column_name="text_normalized", new_column_name="normalized_text")
dataset = dataset.rename_column(original_column_name="sentences", new_column_name="text_normalized")

config = TokenizationConfig()

Generating train split: 100%|██████████| 576820/576820 [00:01<00:00, 302411.41 examples/s]


In [3]:
FRAMERATE = 12.5
# NOTE: DELETE THIS, HARD-CODED ASSUMPTION
dataset = dataset.filter(lambda row: row["codes"].size(-1) <= 15 * FRAMERATE, num_proc=12)

Filter (num_proc=12): 100%|██████████| 576820/576820 [00:02<00:00, 248711.59 examples/s]


**NOTE! This is PATH DEPENDENT on ADDING THE SEMANTIC TOKENS TO THE TOKENIZER EARLIER using `create_bytelevel_init.ipynb`. DO NOT SKIP THIS STEP OR THE MODEL WILL BE IRRETRIEVABLY BROKEN! YOU HAVE BEEN WARNED.**

==**THIS IS BYTE LEVEL!**==

In [4]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("../../inits/smoltts_byte_kokoro")
tokenizer.use_default_system_prompt = False

Check this carefully: for byte level, it should be 256.

In [5]:
len(tokenizer), tokenizer.vocab_size

(2368, 256)

Please manually verify the text is done correctly. However, DECODE will not work.

In [6]:
# Test the tokenizer by encoding and decoding some example text
example_text = "<|im_start|>system\n<|american|><|male|><|im_end|>"
encoded = tokenizer(example_text, return_tensors="pt")
print(f"Encoded: {encoded['input_ids']}")
decoded = tokenizer.decode(encoded['input_ids'][0])

# Print the results
decoded

Encoded: tensor([[269, 256,  10, 260, 261, 270]])


'<|im_start|>system\n<|american|><|male|><|im_end|>'

In [7]:
sequence = tokenizer.apply_chat_template([{"role": "user", "content": "help me i am trapped in this computer"}], add_generation_prompt=True,  return_tensors="pt")
sequence

tensor([[269, 257,  10, 104, 101, 108, 112,  32, 109, 101,  32, 105,  32,  97,
         109,  32, 116, 114,  97, 112, 112, 101, 100,  32, 105, 110,  32, 116,
         104, 105, 115,  32,  99, 111, 109, 112, 117, 116, 101, 114, 270,  10,
         269, 258,  10]])

In [8]:
import torch

prompt_encoder = PromptEncoder(tokenizer, config)
tts_sysprompt = prompt_encoder.encode_text_turn(role="system", content="<|speaker:40|>", add_generation_prompt=False)
tokenizer.decode(tts_sysprompt[0,:])

'<|im_start|>system\n<|speaker:40|><|im_end|>\n'

Note that this assumes you're using ChatML. if you're NOT, then there's quite a bit more to fix.

In [9]:
out = prompt_encoder.encode_vq(dataset["full"][0]["codes"])
tokenizer.decode(out[0,:])

'<|semantic:1049|><|semantic:127|><|semantic:1880|><|semantic:1031|><|semantic:1031|><|semantic:1031|><|semantic:1492|><|semantic:1492|><|semantic:1926|><|semantic:1926|><|semantic:1826|><|semantic:1268|><|semantic:1001|><|semantic:382|><|semantic:587|><|semantic:178|><|semantic:1380|><|semantic:1380|><|semantic:1790|><|semantic:117|><|semantic:130|><|semantic:531|><|semantic:2036|><|semantic:722|><|semantic:1470|><|semantic:1725|><|semantic:371|><|semantic:728|><|semantic:774|><|semantic:1677|><|semantic:518|><|semantic:769|><|semantic:666|><|semantic:84|><|semantic:84|><|semantic:752|><|semantic:752|><|semantic:752|><|semantic:752|><|semantic:752|><|semantic:1926|><|semantic:1926|><|semantic:1926|><|semantic:1926|><|im_end|>\n'

In [10]:
out_corrupt = prompt_encoder.encode_vq_corrupt(dataset["full"][0]["codes"])
tokenizer.decode(out_corrupt[0,:])

'<|semantic:1049|><|semantic:127|><|semantic:1880|><|semantic:1031|><|semantic:1031|><|semantic:1031|><|semantic:1492|><|semantic:1492|><|semantic:1926|><|semantic:1926|><|semantic:1826|><|semantic:1268|><|semantic:1001|><|semantic:382|><|semantic:587|><|semantic:178|><|semantic:1380|><|semantic:1380|><|semantic:1790|><|semantic:117|><|semantic:130|><|semantic:531|><|semantic:2036|><|semantic:722|><|semantic:1470|><|semantic:1725|><|semantic:371|><|semantic:728|><|semantic:774|><|semantic:1677|><|semantic:518|><|semantic:769|><|semantic:666|><|semantic:84|><|semantic:84|><|semantic:752|><|semantic:752|><|semantic:752|><|semantic:752|><|semantic:752|><|semantic:1926|><|semantic:1926|><|semantic:1926|><|semantic:1926|><|im_end|>\n'

In [11]:
speaker_names = ["default", "sarah", "sky", "adam", "emma", "isabella", "george", "lewis"]
speaker_ids = {value: index for index, value in enumerate(speaker_names)}
speaker_ids["adam"]

3

In [12]:
from typing import Dict
# import random

# TODO: Not doing ASR for now
def tts_tokenize_row(row: Dict):
    """
    NOTE: Deliberately ignores sysprompt line for now, can be done in packing
    """
    # TODO: Fix this upstream in the data gen!
    # gender = "<|male|>" if row["speaker_id"] in ["george", "lewis", "adam", "michael"] else "<|female|>"
    # accent = f"<|{row['accent']}|>"
    # speaker = f"<|speaker:{speaker_ids[row['speaker_id']]}|>" if random.random() < 0.7 else ""
    speaker = f"<|speaker:{speaker_ids[row['speaker_id']]}|>"

    # Just keep it all for now, will test generalization later
    system_line = prompt_encoder.encode_text_turn(role="system", content="".join([speaker]))
    user_line = prompt_encoder.encode_text_turn(
        role="user", 
        content=row["text_normalized"].encode("utf-8").decode("latin-1"), 
        add_generation_prompt=True
    )
    assistant_line = prompt_encoder.encode_vq(row["codes"])
    ground_truth = torch.cat([system_line, user_line, assistant_line], dim=1)
    # ground_truth = torch.cat([user_line, assistant_line], dim=1)
    # Causal shift
    tokens = ground_truth[:,:-1].clone()
    labels = ground_truth[:,1:].clone()

    # Assuming user line took care of assistant prefix 
    # Offsetting by 1 since labels were shifted
    text_only_length = system_line.size(1) + user_line.size(1) - 1
    labels[1:, :text_only_length] = -100
    # Mask out im_end and newline
    labels[1:, -2:] = -100

    return({
        "tokens": tokens,
        "labels": labels
    })
    


example_row = tts_tokenize_row(dataset["full"][10])
tokenizer.decode(example_row["tokens"][0,:])

'<|im_start|>system\n<|speaker:1|><|im_end|>\n<|im_start|>assistant\n<|im_start|>user\nUncle Roland was gone.<|im_end|>\n<|im_start|>assistant\n<|semantic:1049|><|semantic:127|><|semantic:1880|><|semantic:1031|><|semantic:1031|><|semantic:1031|><|semantic:1492|><|semantic:1492|><|semantic:1926|><|semantic:1926|><|semantic:1826|><|semantic:824|><|semantic:1754|><|semantic:1177|><|semantic:1823|><|semantic:526|><|semantic:1260|><|semantic:1998|><|semantic:95|><|semantic:622|><|semantic:382|><|semantic:1238|><|semantic:160|><|semantic:1597|><|semantic:1663|><|semantic:1528|><|semantic:1987|><|semantic:666|><|semantic:716|><|semantic:84|><|semantic:84|><|semantic:752|><|semantic:752|><|semantic:752|><|semantic:752|><|semantic:1926|><|semantic:1926|><|semantic:1926|><|semantic:1926|><|im_end|>'

In [13]:

from typing import Dict
# import random

# TODO: Not doing ASR for now
def tts_tokenize_row_dropout(row: Dict):
    """
    NOTE: Deliberately ignores sysprompt line for now, can be done in packing
    """
    # TODO: Fix this upstream in the data gen!
    # gender = "<|male|>" if row["speaker_id"] in ["george", "lewis", "adam", "michael"] else "<|female|>"
    # accent = f"<|{row['accent']}|>"
    # speaker = f"<|speaker:{speaker_ids[row['speaker_id']]}|>" if random.random() < 0.7 else ""
    speaker = f"<|speaker:{speaker_ids[row['speaker_id']]}|>"

    # Just keep it all for now, will test generalization later
    system_line = prompt_encoder.encode_text_turn(role="system", content="".join([speaker]))
    user_line = prompt_encoder.encode_text_turn(
        role="user", 
        content=row["text_normalized"].encode("utf-8").decode("latin-1"), 
        add_generation_prompt=True
    )
    assistant_line_true = prompt_encoder.encode_vq(row["codes"])
    assistant_line_dropout = prompt_encoder.encode_vq_corrupt(row["codes"], dropout=0.3)
    messy_input = torch.cat([system_line, user_line, assistant_line_dropout], dim=1)
    ground_truth = torch.cat([system_line, user_line, assistant_line_true], dim=1)
    # Causal shift
    tokens = messy_input[:,:-1]
    labels = ground_truth[:,1:]

    # Assuming user line took care of assistant prefix 
    # Offsetting by 1 since labels were shifted
    text_only_length = system_line.size(1) + user_line.size(1) - 1
    labels[1:, :text_only_length] = -100
    # Mask out im_end and newline
    labels[1:, -2:] = -100

    return({
        "tokens": tokens,
        "labels": labels
    })
    


example_row = tts_tokenize_row_dropout(dataset["full"][10])
tokenizer.decode(example_row["tokens"][0,:])

'<|im_start|>system\n<|speaker:1|><|im_end|>\n<|im_start|>assistant\n<|im_start|>user\nUncle Roland was gone.<|im_end|>\n<|im_start|>assistant\n<|semantic:1049|><|semantic:127|><|semantic:1880|><|semantic:1031|><|semantic:1031|><|semantic:1031|><|semantic:1492|><|semantic:1492|><|semantic:1926|><|semantic:1926|><|semantic:1826|><|semantic:824|><|semantic:1754|><|semantic:1177|><|semantic:1823|><|semantic:526|><|semantic:1260|><|semantic:1998|><|semantic:95|><|semantic:622|><|semantic:382|><|semantic:1238|><|semantic:160|><|semantic:1597|><|semantic:1663|><|semantic:1528|><|semantic:1987|><|semantic:666|><|semantic:716|><|semantic:84|><|semantic:84|><|semantic:752|><|semantic:752|><|semantic:752|><|semantic:752|><|semantic:1926|><|semantic:1926|><|semantic:1926|><|semantic:1926|><|im_end|>'

In [14]:
example_row["tokens"]

tensor([[ 269,  256,   10,  272,  270,   10,  269,  258,   10,  269,  257,   10,
           85,  110,   99,  108,  101,   32,   82,  111,  108,   97,  110,  100,
           32,  119,   97,  115,   32,  103,  111,  110,  101,   46,  270,   10,
          269,  258,   10, 1369,  447, 2200, 1351, 1351, 1351, 1812, 1812, 2246,
         2246, 2146, 1144, 2074, 1497, 2143,  846, 1580, 2318,  415,  942,  702,
         1558,  480, 1917, 1983, 1848, 2307,  986, 1036,  404,  404, 1072, 1072,
         1072, 1072, 2246, 2246, 2246, 2246,  270],
        [   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0, 1049,  127, 1880, 1031, 1031, 1031, 1492, 1492, 1926,
         1926, 1826,  824, 1754, 1177, 1823,  526, 1260, 1998,   95,  622,  382,
         1238,  160, 1597, 1663, 1528, 1987,  666,  716, 

In [15]:
example_row["labels"]

tensor([[ 256,   10,  272,  270,   10,  269,  258,   10,  269,  257,   10,   85,
          110,   99,  108,  101,   32,   82,  111,  108,   97,  110,  100,   32,
          119,   97,  115,   32,  103,  111,  110,  101,   46,  270,   10,  269,
          258,   10, 1369,  447, 2200, 1351, 1351, 1351, 1812, 1812, 2246, 2246,
         2146, 1144, 2074, 1497, 2143,  846, 1580, 2318,  415,  942,  702, 1558,
          480, 1917, 1983, 1848, 2307,  986, 1036,  404,  404, 1072, 1072, 1072,
         1072, 2246, 2246, 2246, 2246,  270,   10],
        [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, 1049,  127, 1880, 1031, 1031, 1031, 1492, 1492, 1926, 1926,
         1826,  824, 1754, 1177, 1823,  526, 1260, 1998,   95,  622,  382, 1238,
          160, 1597, 1663, 1528, 1987,  666,  716,   84, 

In [16]:
# DO NOT INCREASE batch size
dataset = dataset.map(tts_tokenize_row, remove_columns="codes", num_proc=24)

Map (num_proc=24): 100%|██████████| 539921/539921 [00:46<00:00, 11506.33 examples/s]


In [17]:
dataset.save_to_disk("../../datasets/tokenized_project_gutenberg_bytes_kokoro_tau")

Saving the dataset (36/36 shards): 100%|██████████| 539921/539921 [01:03<00:00, 8533.77 examples/s] 


## Appendix: Testing

In [23]:
import torch

def collate_fn(batch, semantic_pad_id: int):
    """
    batch is a list of dicts: each dict has "tokens" shape [9, T],
    and "labels" shape [9, T].
    We pad them into [B, 9, T_max].
    """
    max_input_len = max(item["tokens"].shape[1] for item in batch)

    B = len(batch)
    # We'll create padded arrays:
    tokens = torch.full((B, 9, max_input_len), 0, dtype=torch.long)  # 2=some <PAD>
    tokens[:, 0, :] = semantic_pad_id
    labels = torch.full(
        (B, 9, max_input_len), -100, dtype=torch.long
    )  # default is ignore_index

    pad_mask = torch.ones(B, max_input_len)

    for i, item in enumerate(batch):
        seq_len = item["tokens"].shape[1]
        tokens[i, :, :seq_len] = item["tokens"]
        labels[i, :, :seq_len] = item["labels"][:, :seq_len]
        pad_mask[i, :seq_len] = False

    return {"tokens": tokens, "labels": labels, "pad_mask": pad_mask}

# Create two test sequences of different lengths
seq1 = torch.randint(1, 100, (9, 5))  # Short sequence
seq2 = torch.randint(1, 100, (9, 8))  # Longer sequence

batch = [
    {"tokens": seq1, "labels": seq1},
    {"tokens": seq2, "labels": seq2}
]

# Test the collation
semantic_pad_id = 999
result = collate_fn(batch, semantic_pad_id)

print("Tokens shape:", result["tokens"].shape)
print("\nFirst sequence tokens:")
print(result["tokens"][0])
print("\nSecond sequence tokens:")
print(result["tokens"][1])
print("\nPadding mask:")
print(result["pad_mask"])

# Let's verify:
# 1. Sequences are left-aligned
# 2. Padding is applied correctly
# 3. Padding mask matches content

# Check alignment of first sequence (should be at start)
print("\nFirst 5 tokens of first sequence row 1:")
print(result["tokens"][0, 1, :5])
print("Next 3 tokens (should be 0s):")
print(result["tokens"][0, 1, 5:8])

# Check padding of first row
print("\nFirst row padding for batch item 0:")
print(result["tokens"][0, 0, :8])  # Should be semantic_pad_id

# Check mask alignment
print("\nFirst sequence mask (False=content, True=padding):")
print(result["pad_mask"][0])

Tokens shape: torch.Size([2, 9, 8])

First sequence tokens:
tensor([[ 21,  17,  10,  78,  48, 999, 999, 999],
        [ 58,   6,  85,  54,  88,   0,   0,   0],
        [ 97,  89,   7,  37,  36,   0,   0,   0],
        [ 91,  37,  32,   9,  76,   0,   0,   0],
        [ 73,  23,  67,  80,  82,   0,   0,   0],
        [ 91,  63,  17,  87,  50,   0,   0,   0],
        [ 28,  65,  62,  31,  80,   0,   0,   0],
        [  5,  30,  70,  23,  44,   0,   0,   0],
        [ 36,  82,  58,  98,  18,   0,   0,   0]])

Second sequence tokens:
tensor([[37, 99, 29, 58, 78, 19, 17, 26],
        [23, 59,  2, 53, 83, 66,  1, 38],
        [74, 24,  9, 70, 49, 61, 70, 54],
        [51, 61, 86, 16, 99, 93, 12, 14],
        [39, 79,  4, 84, 33, 17, 98, 79],
        [19, 68,  8, 97, 55, 93, 65, 65],
        [ 6, 56, 99, 23,  6, 93, 64,  1],
        [39, 57, 27, 59, 85, 61, 49, 27],
        [91, 28, 76, 47, 86, 95, 30, 19]])

Padding mask:
tensor([[0., 0., 0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 0.,

In [26]:
def get_length(example):
    return {'length': example['labels'].shape[1]}

max_len = 0
def update_max(example):
    global max_len
    max_len = max(max_len, example['length'])
    return example

# Apply the transformations
dataset["train"].map(
    get_length,
    desc="Getting sequence lengths"
).map(
    update_max,
    desc="Finding maximum"
)

print(f"Maximum sequence length: {max_len}")

Getting sequence lengths: 100%|██████████| 149658/149658 [01:09<00:00, 2150.41 examples/s]
Finding maximum: 100%|██████████| 149658/149658 [01:20<00:00, 1859.10 examples/s]

Maximum sequence length: 893





In [33]:
import numpy as np

# Get arrays from dataset
text_lengths = np.array([len(x) for x in dataset["train"]['normalized_text']])
seq_lengths = np.array([x.shape[1] for x in dataset["train"]['labels']])

# Calculate ratios
ratios = text_lengths / seq_lengths

# Basic stats
print(f"Mean ratio: {ratios.mean():.3f}")
print(f"Std ratio: {ratios.std():.3f}")
print(f"\nPercentile distribution:")
for p in [1, 5, 25, 50, 75, 95, 99]:
    print(f"{p}th percentile: {np.percentile(ratios, p):.3f}")

# Find extreme outliers (3 std from mean)
mean, std = ratios.mean(), ratios.std()
outliers = np.where(np.abs(ratios - mean) > 3 * std)[0]
if len(outliers) > 0:
    print(f"\nFound {len(outliers)} outliers")
    print("\nSample of 5 outlier examples:")
    for idx in outliers[:5]:
        print(f"\nIndex {int(idx)}")  # Convert numpy int to Python int
        print(f"Text ({text_lengths[idx]} chars): {dataset['val'][int(idx)]['normalized_text'][:100]}...")  # Convert idx
        print(f"Sequence length: {seq_lengths[idx]}")
        print(f"Ratio: {ratios[idx]:.3f}")

Mean ratio: 0.442
Std ratio: 0.139

Percentile distribution:
1th percentile: 0.089
5th percentile: 0.189
25th percentile: 0.366
50th percentile: 0.455
75th percentile: 0.515
95th percentile: 0.691
99th percentile: 0.816

Found 521 outliers

Sample of 5 outlier examples:

Index 245
Text (278 chars): The passers by were immediately struck with wonder....
Sequence length: 318
Ratio: 0.874

Index 1426
Text (285 chars): "What would you suggest!"...
Sequence length: 330
Ratio: 0.864

Index 1446
Text (368 chars): As he did so, Tad felt himself gradually sinking into the sombre depths....
Sequence length: 426
Ratio: 0.864

Index 2947
Text (371 chars): It had its points....
Sequence length: 419
Ratio: 0.885

Index 2971
Text (408 chars): No answer; though I allowed a more than decent interval....
Sequence length: 474
Ratio: 0.861


In [12]:
NEWLINE_SEPARATOR = torch.tensor(tokenizer.encode("\n") + [0] * 8).unsqueeze(1)

def batch_pack_sequences(examples, window_size=768, max_items=5):
   """
   Pack sequences with system prompt and metrics
   """
   packed_tokens = []
   packed_labels = []
   packed_speakers = []
   pack_lengths = []
   items_per_pack = []
   
   tokens = examples['tokens']
   labels = examples['labels']
   speakers = examples['speaker_id']
   
   # Account for system prompt in window size
   effective_window = window_size - tts_sysprompt.shape[1]
   
   for i in range(len(tokens)):
       seq_len = tokens[i].shape[1]
       
       # Start new pack
       if i == 0 or current_length + seq_len > effective_window or \
          current_speaker != speakers[i] or current_items >= max_items:
           
           # Save previous pack if it exists
           if i > 0 and current_tokens:
               packed_tokens.append(torch.cat(current_tokens, dim=1))
               packed_labels.append(torch.cat(current_labels, dim=1))
               packed_speakers.append(current_speaker)
               pack_lengths.append(current_length + tts_sysprompt.shape[1])
               items_per_pack.append(current_items)
           
           # Initialize new pack with system prompt
           current_tokens = [tts_sysprompt, tokens[i]]
           current_labels = [tts_sysprompt, labels[i]]
           current_speaker = speakers[i]
           current_length = seq_len
           current_items = 1
           continue
           
       # Add to current pack with separator
       current_tokens.extend([NEWLINE_SEPARATOR, tokens[i]])
       current_labels.extend([NEWLINE_SEPARATOR, labels[i]])
       current_length += seq_len + 1
       current_items += 1
   
   # Don't forget last pack
   if current_tokens:
       packed_tokens.append(torch.cat(current_tokens, dim=1))
       packed_labels.append(torch.cat(current_labels, dim=1))
       packed_speakers.append(current_speaker)
       pack_lengths.append(current_length + tts_sysprompt.shape[1])
       items_per_pack.append(current_items)
   
   return {
       'tokens': packed_tokens,
       'labels': packed_labels,
       'speaker_id': packed_speakers,
       'pack_length': pack_lengths,
       'items_in_pack': items_per_pack
   }

In [16]:
# Usage:
packed_dataset = dataset.map(
    lambda row: batch_pack_sequences(row, max_items=3),
    batched=True,
    remove_columns=dataset['val'].column_names,
    batch_size=1000  # Adjust based on memory constraints
)

Map: 100%|██████████| 149658/149658 [00:31<00:00, 4684.03 examples/s]
Map: 100%|██████████| 5736/5736 [00:01<00:00, 4746.63 examples/s]
Map: 100%|██████████| 4837/4837 [00:01<00:00, 4760.96 examples/s]


In [14]:
example_row = packed_dataset['val'][0]
tokenizer.decode(example_row["tokens"][0,:])

'<|im_start|>system\nSpeak out the provided text<|im_end|>\n<|im_start|>assistant\n<|im_start|>user\nThe weapon must still have been there.<|im_end|>\n<|im_start|>assistant\n<|semantic:1049|><|semantic:1114|><|semantic:1609|><|semantic:784|><|semantic:499|><|semantic:260|><|semantic:1011|><|semantic:8|><|semantic:1407|><|semantic:540|><|semantic:1615|><|semantic:561|><|semantic:1945|><|semantic:201|><|semantic:1324|><|semantic:668|><|semantic:376|><|semantic:1849|><|semantic:9|><|semantic:1921|><|semantic:1921|><|semantic:1683|><|semantic:228|><|semantic:897|><|semantic:1677|><|semantic:518|><|im_end|>\n<|im_start|>user\nHow quickly he disappeared!"<|im_end|>\n<|im_start|>assistant\n<|semantic:1698|><|semantic:1848|><|semantic:1021|><|semantic:414|><|semantic:972|><|semantic:1252|><|semantic:1545|><|semantic:1363|><|semantic:307|><|semantic:722|><|semantic:1169|><|semantic:170|><|semantic:1701|><|semantic:1967|><|semantic:886|><|semantic:1540|><|semantic:1540|><|semantic:1113|><|semant

In [17]:
packed_dataset.save_to_disk("tokenized_libritts_packed_3")

Saving the dataset (5/5 shards): 100%|██████████| 50735/50735 [00:01<00:00, 33495.02 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 1937/1937 [00:00<00:00, 33569.13 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 1649/1649 [00:00<00:00, 31396.69 examples/s]
