# Tokenize LibriTTS-R Mimi for target LM

For our dataset, we currently simply use the Fish Speech TTS format:
- Text-only data formatted using [ChatML](https://gist.github.com/edwardzjl/8df07c1f7140c9a3e2f48d33a8032090) as a separate sequence "above" the audio code stream
- During sections where audio is being modeled, text stream 0 predicts the first semantic token index $n$ of the 8 Mimi residual codes as special token `<|semantic:n|>`
- For audio, the "semantic" codes from Mimi (strictly speaking these are neural codec codes, but there is no strong distinction between the two here), padded with 0s during text sections

It's possible this tokenization strategy can be improved, e.g. as in [Defossez et al. 2024](https://arxiv.org/html/2410.00037v2#S3.SS4.SSS4), where the base transformer predicts the Whisper-timestamped word timings as an "inner monologue" and a delay is applied between codebook timesteps. This is left for later.

In [None]:
from dotenv import load_dotenv
from datasets import load_dataset, DatasetDict, concatenate_datasets, load_from_disk
from data_pipeline.utils.prompt import PromptEncoder, TokenizationConfig
import os

# Load variables from a local .env file into the environment; HUGGINGFACE_TOKEN is read from it below.
load_dotenv()
# If creating the libritts dataset for the first time
# dataset = load_from_disk("../../Kokoro-82M/libritts_r_mimi_kokoro")
# Pull the dataset from the Hugging Face Hub, authenticating with the token from the environment.
dataset = load_dataset("jkeisling/project-gutenberg-kokoro-2K", token=os.getenv("HUGGINGFACE_TOKEN"))
# full_train = concatenate_datasets([dataset["train.clean.100"], dataset["train.clean.360"]])

# dataset = DatasetDict({
#     "train": full_train,
#     "val": dataset["dev.clean"],
#     "test": dataset["test.clean"]
# })
# dataset = DatasetDict({"full": dataset})
# Return columns as torch tensors; later cells call tensor methods such as `.size(-1)` on "codes".
dataset = dataset.with_format("torch")
# dataset = dataset.remove_columns(["chapter_id", "text_original"])
# dataset = dataset.rename_column(original_column_name="text_normalized", new_column_name="normalized_text")
# The tokenization functions further down read the transcript from "text_normalized".
dataset = dataset.rename_column(original_column_name="sentences", new_column_name="text_normalized")

# Default tokenization settings; handed to PromptEncoder in a later cell.
config = TokenizationConfig()

In [None]:
# Used to convert the 15-unit duration cap below into a number of code frames.
# NOTE(review): presumably the codec's frames-per-second rate (12.5 Hz) — confirm.
FRAMERATE = 12.5
# NOTE: DELETE THIS, HARD-CODED ASSUMPTION
# Keep only rows whose "codes" tensor has at most 15 * FRAMERATE entries along its last dimension.
dataset = dataset.filter(lambda row: row["codes"].size(-1) <= 15 * FRAMERATE, num_proc=12)

**NOTE! This is PATH DEPENDENT on ADDING THE SEMANTIC TOKENS TO THE TOKENIZER EARLIER using `create_bytelevel_init.ipynb`. DO NOT SKIP THIS STEP OR THE MODEL WILL BE IRRETRIEVABLY BROKEN! YOU HAVE BEEN WARNED.**

==**THIS IS BYTE LEVEL!**==

In [None]:
from transformers import AutoTokenizer

# Load the tokenizer from the local init directory (see the warning above: it must
# already contain the semantic tokens added by `create_bytelevel_init.ipynb`).
tokenizer = AutoTokenizer.from_pretrained("../../inits/smoltts_byte_kokoro")
# NOTE(review): presumably prevents a default system prompt from being injected by the
# chat template — confirm the loaded tokenizer actually honors this attribute.
tokenizer.use_default_system_prompt = False

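Check this carefully: for the byte-level tokenizer, the base `vocab_size` should be 256. `len(tokenizer)` may be larger, since it also counts any added special tokens.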

In [None]:
len(tokenizer), tokenizer.vocab_size

Please manually verify that the text is tokenized correctly. Note, however, that DECODE will not work.

In [None]:
# Test the tokenizer by encoding and decoding some example text.
# The sample is a ChatML system turn containing only control tokens, so the printed
# IDs can be inspected by hand.
example_text = "<|im_start|>system\n<|american|><|male|><|im_end|>"
encoded = tokenizer(example_text, return_tensors="pt")
print(f"Encoded: {encoded['input_ids']}")
# Decode the first (only) sequence in the batch back to text.
decoded = tokenizer.decode(encoded['input_ids'][0])

# Print the results
decoded

In [None]:
sequence = tokenizer.apply_chat_template([{"role": "user", "content": "help me i am trapped in this computer"}], add_generation_prompt=True,  return_tensors="pt")
sequence

In [None]:
import torch

# Shared encoder used by the tokenization functions below.
prompt_encoder = PromptEncoder(tokenizer, config)
# Example system turn with a fixed speaker tag; the packing helper in the appendix
# (`batch_pack_sequences`) prepends this tensor to every pack.
tts_sysprompt = prompt_encoder.encode_text_turn(role="system", content="<|speaker:40|>", add_generation_prompt=False)
# Decode row 0 of the encoded turn to eyeball the result.
tokenizer.decode(tts_sysprompt[0,:])

Note that this assumes you're using ChatML. If you're NOT, then there's quite a bit more to fix.

In [None]:
# Encode the codes of the first example and decode row 0 to inspect it.
# NOTE(review): `dataset["full"]` only exists if the commented-out
# `DatasetDict({"full": dataset})` wrapper in the setup cell is applied — confirm the split name.
out = prompt_encoder.encode_vq(dataset["full"][0]["codes"])
tokenizer.decode(out[0,:])

In [None]:
out_corrupt = prompt_encoder.encode_vq_corrupt(dataset["full"][0]["codes"])
tokenizer.decode(out_corrupt[0,:])

In [None]:
# Known voice names; a name's position in this list is its numeric speaker id.
speaker_names = [
    "default",
    "sarah",
    "sky",
    "adam",
    "emma",
    "isabella",
    "george",
    "lewis",
]
# Reverse lookup: voice name -> index in `speaker_names`.
speaker_ids = dict(zip(speaker_names, range(len(speaker_names))))
speaker_ids["adam"]

In [None]:
from typing import Dict
# import random

# TODO: Not doing ASR for now
def tts_tokenize_row(row: Dict):
    """
    Turn one dataset row into causally shifted ``tokens`` / ``labels`` tensors.

    The sequence is the concatenation (along dim 1) of a system turn holding the
    speaker tag, a user turn holding the text, and the encoded audio codes.

    Args:
        row: Mapping with ``"speaker_id"`` (a key of ``speaker_ids``),
            ``"text_normalized"`` (str) and ``"codes"`` (passed to
            ``prompt_encoder.encode_vq``).

    Returns:
        Dict with ``"tokens"`` (the sequence without its last column) and
        ``"labels"`` (the sequence without its first column). In ``labels``,
        rows 1 and up are set to -100 over the text prompt and over the final
        two columns.
    """
    # TODO: Fix this upstream in the data gen! Gender / accent tags and random
    # dropping of the speaker tag are not produced yet; the speaker tag is
    # always kept for now, generalization will be tested later.
    speaker_tag = f"<|speaker:{speaker_ids[row['speaker_id']]}|>"
    system_turn = prompt_encoder.encode_text_turn(role="system", content=speaker_tag)

    # Re-interpret the UTF-8 bytes as latin-1 so each byte becomes exactly one character.
    byte_text = row["text_normalized"].encode("utf-8").decode("latin-1")
    user_turn = prompt_encoder.encode_text_turn(
        role="user",
        content=byte_text,
        add_generation_prompt=True,
    )
    audio_turn = prompt_encoder.encode_vq(row["codes"])

    full_sequence = torch.cat((system_turn, user_turn, audio_turn), dim=1)

    # Causal shift: the input at step t is trained to predict the entry at step t + 1.
    model_inputs = full_sequence[:, :-1].clone()
    targets = full_sequence[:, 1:].clone()

    # The user turn already carries the assistant prefix; subtract 1 because the
    # targets were shifted left by one column.
    prompt_span = system_turn.size(1) + user_turn.size(1) - 1
    targets[1:, :prompt_span] = -100
    # Mask out im_end and newline at the end of the sequence.
    targets[1:, -2:] = -100

    return {"tokens": model_inputs, "labels": targets}
    


example_row = tts_tokenize_row(dataset["full"][10])
tokenizer.decode(example_row["tokens"][0,:])

In [None]:

from typing import Dict
# import random

# TODO: Not doing ASR for now
def tts_tokenize_row_dropout(row: Dict, dropout: float = 0.3):
    """
    Tokenize one row with corrupted audio inputs and clean audio targets.

    The model input uses ``prompt_encoder.encode_vq_corrupt`` for the audio
    section, while the labels are taken from the uncorrupted
    ``prompt_encoder.encode_vq`` output. Both are prefixed with the same system
    turn (speaker tag) and user turn (text).

    Args:
        row: Mapping with ``"speaker_id"`` (a key of ``speaker_ids``),
            ``"text_normalized"`` (str) and ``"codes"``.
        dropout: Value forwarded as ``dropout=`` to ``encode_vq_corrupt``.
            Defaults to 0.3, the rate that was previously hard-coded.

    Returns:
        Dict with ``"tokens"`` (corrupted sequence without its last column) and
        ``"labels"`` (clean sequence without its first column). In ``labels``,
        rows 1 and up are set to -100 over the text prompt and over the final
        two columns. Both tensors are independent copies.
    """
    # TODO: Fix this upstream in the data gen!
    # gender = "<|male|>" if row["speaker_id"] in ["george", "lewis", "adam", "michael"] else "<|female|>"
    # accent = f"<|{row['accent']}|>"
    # speaker = f"<|speaker:{speaker_ids[row['speaker_id']]}|>" if random.random() < 0.7 else ""
    speaker = f"<|speaker:{speaker_ids[row['speaker_id']]}|>"

    # Just keep it all for now, will test generalization later
    system_line = prompt_encoder.encode_text_turn(role="system", content="".join([speaker]))
    user_line = prompt_encoder.encode_text_turn(
        role="user",
        # Re-interpret the UTF-8 bytes as latin-1 so each byte becomes one character.
        content=row["text_normalized"].encode("utf-8").decode("latin-1"),
        add_generation_prompt=True
    )
    assistant_line_true = prompt_encoder.encode_vq(row["codes"])
    # NOTE(review): assumes the corrupted encoding has the same width as the clean
    # one, otherwise tokens and labels are misaligned — confirm in PromptEncoder.
    assistant_line_dropout = prompt_encoder.encode_vq_corrupt(row["codes"], dropout=dropout)
    messy_input = torch.cat([system_line, user_line, assistant_line_dropout], dim=1)
    ground_truth = torch.cat([system_line, user_line, assistant_line_true], dim=1)
    # Causal shift. Clone (as tts_tokenize_row does) so the returned tensors do not
    # alias the concatenated buffers and the in-place masking below stays local.
    tokens = messy_input[:, :-1].clone()
    labels = ground_truth[:, 1:].clone()

    # Assuming user line took care of assistant prefix
    # Offsetting by 1 since labels were shifted
    text_only_length = system_line.size(1) + user_line.size(1) - 1
    labels[1:, :text_only_length] = -100
    # Mask out im_end and newline
    labels[1:, -2:] = -100

    return {
        "tokens": tokens,
        "labels": labels
    }
    


example_row = tts_tokenize_row_dropout(dataset["full"][10])
tokenizer.decode(example_row["tokens"][0,:])

In [None]:
example_row["tokens"]

In [None]:
example_row["labels"]

In [None]:
# DO NOT INCREASE batch size
# Tokenize row by row (`batched` is not set, so tts_tokenize_row receives one row per call)
# using 24 worker processes; the raw "codes" column is dropped from the output.
dataset = dataset.map(tts_tokenize_row, remove_columns="codes", num_proc=24)

In [None]:
dataset.save_to_disk("../../datasets/tokenized_project_gutenberg_bytes_kokoro_tau")

## Appendix: Testing

In [None]:
import torch

def collate_fn(batch, semantic_pad_id: int):
    """
    Right-pad a list of variable-length examples into batched tensors.

    Args:
        batch: Non-empty list of dicts. Each dict has "tokens" of shape [S, T]
            and "labels" of shape [S, T] (S is 9 in this notebook: one text
            row plus 8 code rows). T may differ between items; S must not.
        semantic_pad_id: Fill value for row 0 of ``tokens`` at padded positions.

    Returns:
        Dict with:
            "tokens":   long tensor [B, S, T_max]; padding is ``semantic_pad_id``
                        in row 0 and 0 in the remaining rows.
            "labels":   long tensor [B, S, T_max]; padding is -100 (ignore index).
            "pad_mask": bool tensor [B, T_max]; True at padding, False at content.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")

    batch_size = len(batch)
    # Take the stream count from the data instead of hard-coding 9.
    num_streams = batch[0]["tokens"].shape[0]
    max_input_len = max(item["tokens"].shape[1] for item in batch)

    # Start fully padded, then copy each item's content over the left part.
    tokens = torch.zeros((batch_size, num_streams, max_input_len), dtype=torch.long)
    tokens[:, 0, :] = semantic_pad_id
    labels = torch.full(
        (batch_size, num_streams, max_input_len), -100, dtype=torch.long
    )  # default is ignore_index

    # A real boolean mask: previously this was a float tensor of 0.0 / 1.0.
    pad_mask = torch.ones(batch_size, max_input_len, dtype=torch.bool)

    for i, item in enumerate(batch):
        seq_len = item["tokens"].shape[1]
        tokens[i, :, :seq_len] = item["tokens"]
        labels[i, :, :seq_len] = item["labels"][:, :seq_len]
        pad_mask[i, :seq_len] = False

    return {"tokens": tokens, "labels": labels, "pad_mask": pad_mask}

# Manual smoke test for collate_fn: the checks below only print values for
# visual inspection, nothing is asserted.
# Create two test sequences of different lengths
seq1 = torch.randint(1, 100, (9, 5))  # Short sequence
seq2 = torch.randint(1, 100, (9, 8))  # Longer sequence

# Labels reuse the token tensors; only shapes and padding matter here.
batch = [
    {"tokens": seq1, "labels": seq1},
    {"tokens": seq2, "labels": seq2}
]

# Test the collation
# 999 lies outside the randint range [1, 100), so padding in row 0 is easy to spot.
semantic_pad_id = 999
result = collate_fn(batch, semantic_pad_id)

print("Tokens shape:", result["tokens"].shape)
print("\nFirst sequence tokens:")
print(result["tokens"][0])
print("\nSecond sequence tokens:")
print(result["tokens"][1])
print("\nPadding mask:")
print(result["pad_mask"])

# Let's verify:
# 1. Sequences are left-aligned
# 2. Padding is applied correctly
# 3. Padding mask matches content

# Check alignment of first sequence (should be at start)
print("\nFirst 5 tokens of first sequence row 1:")
print(result["tokens"][0, 1, :5])
print("Next 3 tokens (should be 0s):")
print(result["tokens"][0, 1, 5:8])

# Check padding of first row
print("\nFirst row padding for batch item 0:")
print(result["tokens"][0, 0, :8])  # Should be semantic_pad_id from index 5 onward

# Check mask alignment
print("\nFirst sequence mask (False=content, True=padding):")
print(result["pad_mask"][0])

In [None]:
def get_length(example):
    """Return ``{'length': n}`` where ``n`` is the size of dim 1 of ``example['labels']``."""
    label_shape = example['labels'].shape
    return {'length': label_shape[1]}

# Running maximum of the 'length' values seen by update_max.
max_len = 0
def update_max(example):
    """Raise the module-level ``max_len`` to ``example['length']`` if that is larger.

    The example is returned unchanged so the function can be used with ``map``.
    """
    global max_len
    candidate = example['length']
    if candidate > max_len:
        max_len = candidate
    return example

# Apply the transformations
# First pass adds a 'length' column; second pass folds it into the global `max_len`.
# NOTE(review): this relies on update_max's side effect running in this process.
# If `map` reuses a cached result or runs with worker processes, `max_len` may
# stay at 0 — confirm before trusting the printed value.
dataset["train"].map(
    get_length,
    desc="Getting sequence lengths"
).map(
    update_max,
    desc="Finding maximum"
)

print(f"Maximum sequence length: {max_len}")

In [None]:
import numpy as np

# Characters-per-timestep statistics for one split.
# All lengths AND the outlier lookups use the same split object, so an outlier
# index always refers to the row it was computed from (previously the indices
# came from "train" but were looked up in "val").
stats_split = "train"
# Transcript column as named by the rename in the setup cell and read by
# tts_tokenize_row (the old name 'normalized_text' is not present in this dataset).
stats_text_column = "text_normalized"
stats_ds = dataset[stats_split]

# Get arrays from dataset
text_lengths = np.array([len(x) for x in stats_ds[stats_text_column]])
seq_lengths = np.array([x.shape[1] for x in stats_ds['labels']])

# Calculate ratios
ratios = text_lengths / seq_lengths

# Basic stats
print(f"Mean ratio: {ratios.mean():.3f}")
print(f"Std ratio: {ratios.std():.3f}")
print(f"\nPercentile distribution:")
for p in [1, 5, 25, 50, 75, 95, 99]:
    print(f"{p}th percentile: {np.percentile(ratios, p):.3f}")

# Find extreme outliers (3 std from mean)
mean, std = ratios.mean(), ratios.std()
outliers = np.where(np.abs(ratios - mean) > 3 * std)[0]
if len(outliers) > 0:
    print(f"\nFound {len(outliers)} outliers")
    print("\nSample of 5 outlier examples:")
    for idx in outliers[:5]:
        row_index = int(idx)  # Convert numpy int to Python int for dataset indexing
        print(f"\nIndex {row_index}")
        print(f"Text ({text_lengths[idx]} chars): {stats_ds[row_index][stats_text_column][:100]}...")
        print(f"Sequence length: {seq_lengths[idx]}")
        print(f"Ratio: {ratios[idx]:.3f}")

In [None]:
# Single-column separator inserted between packed items: the token id(s) for "\n"
# followed by 8 zeros, reshaped to one column.
# NOTE(review): assumes tokenizer.encode("\n") yields exactly one id, giving shape [9, 1] — confirm.
NEWLINE_SEPARATOR = torch.tensor(tokenizer.encode("\n") + [0] * 8).unsqueeze(1)

def _append_pack(packed, pack_tokens, pack_labels, speaker, length, item_count):
    """Concatenate one finished pack along dim 1 and record it in ``packed``."""
    packed['tokens'].append(torch.cat(pack_tokens, dim=1))
    packed['labels'].append(torch.cat(pack_labels, dim=1))
    packed['speaker_id'].append(speaker)
    packed['pack_length'].append(length)
    packed['items_in_pack'].append(item_count)


def batch_pack_sequences(examples, window_size=768, max_items=5):
    """
    Pack consecutive same-speaker sequences into windows, with metrics.

    Every pack starts with ``tts_sysprompt``; items inside a pack are joined by
    ``NEWLINE_SEPARATOR``. A new pack is started when the next item would not
    fit in the window, when the speaker changes, or when ``max_items`` is
    reached. Items are never reordered and packs never span two calls.

    Args:
        examples: Batched mapping with parallel lists under 'tokens', 'labels'
            and 'speaker_id'. Each tokens / labels entry is a 2-D tensor whose
            dim 1 is time.
        window_size: Target maximum pack length in time steps, including the
            system prompt.
        max_items: Maximum number of sequences per pack.

    Returns:
        Dict of parallel lists: 'tokens', 'labels', 'speaker_id',
        'pack_length' (time steps including prompt and separators) and
        'items_in_pack'. All lists are empty when ``examples`` has no rows.

    Notes:
        * A single item longer than the window still becomes its own
          (oversized) pack; nothing is truncated.
        * The prompt and separator columns are masked with -100 in ``labels``.
          The per-item labels are already causally shifted, so copying the
          inserted tokens into ``labels`` would train the model to predict the
          current token rather than the next one.
    """
    tokens = examples['tokens']
    labels = examples['labels']
    speakers = examples['speaker_id']

    prompt_len = tts_sysprompt.shape[1]
    separator_len = NEWLINE_SEPARATOR.shape[1]
    # Account for system prompt in window size
    effective_window = window_size - prompt_len

    # Label counterparts of the inserted pieces: fully ignored by the loss.
    prompt_labels = torch.full_like(tts_sysprompt, -100)
    separator_labels = torch.full_like(NEWLINE_SEPARATOR, -100)

    packed = {
        'tokens': [],
        'labels': [],
        'speaker_id': [],
        'pack_length': [],
        'items_in_pack': [],
    }

    # State of the pack being built; `current_items == 0` means "no open pack".
    current_tokens = []
    current_labels = []
    current_speaker = None
    current_length = 0
    current_items = 0

    for seq_tokens, seq_labels, speaker in zip(tokens, labels, speakers):
        seq_len = seq_tokens.shape[1]

        # Appending costs the item plus its separator column; leaving the
        # separator out of this check let packs exceed the window by one.
        starts_new_pack = (
            current_items == 0
            or current_length + separator_len + seq_len > effective_window
            or current_speaker != speaker
            or current_items >= max_items
        )

        if starts_new_pack:
            # Save previous pack if it exists
            if current_items:
                _append_pack(
                    packed, current_tokens, current_labels,
                    current_speaker, current_length + prompt_len, current_items,
                )

            # Initialize new pack with system prompt
            current_tokens = [tts_sysprompt, seq_tokens]
            current_labels = [prompt_labels, seq_labels]
            current_speaker = speaker
            current_length = seq_len
            current_items = 1
            continue

        # Add to current pack with separator
        current_tokens.extend([NEWLINE_SEPARATOR, seq_tokens])
        current_labels.extend([separator_labels, seq_labels])
        current_length += separator_len + seq_len
        current_items += 1

    # Don't forget last pack (skipped when the batch was empty)
    if current_items:
        _append_pack(
            packed, current_tokens, current_labels,
            current_speaker, current_length + prompt_len, current_items,
        )

    return packed

In [None]:
# Usage:
# Packs are built independently inside each batch of 1000 rows; the last pack of a
# batch is closed even if it still has room.
# NOTE(review): `dataset['val']` must exist for `remove_columns` to resolve — confirm
# the split names of the loaded dataset.
packed_dataset = dataset.map(
    lambda row: batch_pack_sequences(row, max_items=3),
    batched=True,
    remove_columns=dataset['val'].column_names,
    batch_size=1000  # Adjust based on memory constraints
)

In [None]:
example_row = packed_dataset['val'][0]
tokenizer.decode(example_row["tokens"][0,:])

In [None]:
packed_dataset.save_to_disk("tokenized_libritts_packed_3")