# Tokenize LibriTTS-R Mimi for target LM

For our dataset, we currently simply use the Fish Speech TTS format:
- Text-only data formatted using [ChatML](https://gist.github.com/edwardzjl/8df07c1f7140c9a3e2f48d33a8032090) as a separate sequence "above" the audio code stream
- During sections where audio is being modeled, text stream 0 predicts the index $n$ of the first of the 8 Mimi residual codebooks as the special token `<|semantic:n|>`
- The remaining 8 streams carry the Mimi audio codes (called "semantic" here, although for Mimi the line between semantic and acoustic/neural codes is blurry); they are padded with 0s during text sections

This tokenization strategy could likely be improved, e.g. following [Défossez et al. 2024](https://arxiv.org/html/2410.00037v2#S3.SS4.SSS4), where the base transformer predicts Whisper-timestamped word timings as an "inner monologue" and the codebooks are delayed relative to each other across timesteps. This is left for later.

In [None]:
from datasets import load_dataset, DatasetDict, concatenate_datasets

# If creating the libritts dataset for the first time
# from datasets import load_from_disk 
# dataset = load_from_disk("encoded_dataset")
# train_clean_100 = load_from_disk("encoded_libritts/train.clean.100/")
# train_clean_360 = load_from_disk("encoded_libritts/train.clean.360/")
# dev_clean = load_from_disk("encoded_libritts/dev.clean")
# test_clean = load_from_disk("encoded_libritts/test.clean")
# full_train = concatenate_datasets([train_clean_100, train_clean_360])
dataset = load_dataset("jkeisling/libritts-r-mimi")

# Merge both clean training subsets into a single "train" split.
full_train = concatenate_datasets(
    [dataset[split_name] for split_name in ("train.clean.100", "train.clean.360")]
)

# Re-key the splits, expose tensors, drop unused metadata and rename the text column.
dataset = (
    DatasetDict({"train": full_train, "val": dataset["dev.clean"], "test": dataset["test.clean"]})
    .with_format("torch")
    .remove_columns(["path", "chapter_id", "text_original"])
    .rename_column(original_column_name="text_normalized", new_column_name="normalized_text")
)

**NOTE: This notebook DEPENDS ON the semantic tokens having ALREADY BEEN ADDED TO THE TOKENIZER by `create_smoltts_init.ipynb`. DO NOT SKIP THAT STEP, OR THE MODEL WILL BE IRRETRIEVABLY BROKEN. YOU HAVE BEEN WARNED.**

In [None]:
from transformers import AutoTokenizer

# Tokenizer saved by create_smoltts_init.ipynb (it must already contain the <|semantic:n|> tokens).
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="../inits/smoltts_init")
# Chat templates that honour this flag will not inject a default system prompt.
tokenizer.use_default_system_prompt = False

Check this carefully: for SmolTTS, `len(tokenizer)` should be 51200.

In [None]:
# len(tokenizer) counts added tokens (including the <|semantic:n|> tokens);
# tokenizer.vocab_size is the base vocabulary only.
EXPECTED_VOCAB_LEN = 51200  # SmolTTS value; adjust for a different base model
if len(tokenizer) != EXPECTED_VOCAB_LEN:
    raise ValueError(
        f"Tokenizer has {len(tokenizer)} tokens, expected {EXPECTED_VOCAB_LEN}. "
        "Were the semantic tokens added with create_smoltts_init.ipynb?"
    )
len(tokenizer), tokenizer.vocab_size

Please manually verify that the text is tokenized correctly.

In [None]:
# Round-trip some example text through the tokenizer; the decoded text should match the input.
example_text = "This is a test sentence."
encoded = tokenizer(example_text, return_tensors="pt")
decoded = tokenizer.decode(encoded["input_ids"][0])

# Print the results (previously `decoded` was computed but never shown)
print(f"input:   {example_text!r}")
print(f"decoded: {decoded!r}")

# Also inspect one raw test row
dataset["test"][0]

In [None]:
demo_chat = [{"role": "user", "content": "help me i am trapped in this computer"}]
sequence = tokenizer.apply_chat_template(
    demo_chat,
    add_generation_prompt=True,
    return_tensors="pt",
)
sequence

In [None]:
tokenizer.decode(sequence[0])

In [None]:
import torch

def encode_text(
    role: str,
    content: str,
    add_generation_prompt: bool = True,
    num_codebooks: int = 8,
) -> torch.Tensor:
    """Encode one chat message as a ``(1 + num_codebooks, T)`` token grid.

    Row 0 holds the ids produced by the tokenizer's chat template for the single
    message ``{"role": role, "content": content}``, followed by the assistant
    generation prompt when ``add_generation_prompt`` is True. The codebook rows
    below it are zero padding, because text segments carry no audio codes.

    NOTE(review): some chat templates inject a default system turn when the first
    message is not a system message. Check the decoded output when changing tokenizers.
    """
    text_ids = tokenizer.apply_chat_template(
        [{"role": role, "content": content}],
        add_generation_prompt=add_generation_prompt,
        return_tensors="pt",
    )  # (1, T): a batch containing one conversation
    codebook_padding = text_ids.new_zeros((num_codebooks, text_ids.size(1)))
    return torch.cat([text_ids, codebook_padding])


# System prompts are placed directly in front of a user turn (tts_sysprompt in
# batch_pack_sequences; asr_sysprompt in the commented-out ASR path). They must NOT
# end with the "<|im_start|>assistant\n" generation prompt, or an empty assistant
# header ends up between the system and user turns.
tts_sysprompt = encode_text("system", "Speak out the provided text", add_generation_prompt=False)
asr_sysprompt = encode_text("system", "Transcribe the provided speech", add_generation_prompt=False)
print(tokenizer.decode(asr_sysprompt[0, :]))
tokenizer.decode(tts_sysprompt[0, :])

Note that this assumes you're using ChatML. If you're NOT, then there's quite a bit more to fix.

In [None]:
def _single_token_id(text: str) -> int:
    """Return the id of ``text``, raising if it does not encode to exactly one token.

    Catches tokenizers that are missing the added special tokens (they would split
    "<|semantic:0|>" into several pieces instead of failing).
    """
    ids = tokenizer.encode(text, add_special_tokens=False)
    if len(ids) != 1:
        raise ValueError(f"{text!r} encodes to {len(ids)} tokens, expected exactly 1: {ids}")
    return ids[0]


# Id of <|semantic:0|>. encode_vq maps Mimi code n to SEMANTIC_OFFSET + n, so it relies
# on the semantic tokens having contiguous ids.
SEMANTIC_OFFSET = _single_token_id("<|semantic:0|>")

# (9, 2) column block closing an audio turn: "<|im_end|>" then "\n" on row 0,
# zero padding on the 8 codebook rows.
TRAILING_IM_END = torch.tensor([
    [_single_token_id("<|im_end|>")] + [0] * 8,
    [_single_token_id("\n")] + [0] * 8,
]).T

# "<|im_start|>user\n" header (9 rows) for audio placed in a user turn. Built from an
# empty user turn WITHOUT a generation prompt, so the last two tokens are exactly the
# "<|im_end|>\n" that gets stripped here.
_empty_user_turn = encode_text(role="user", content="", add_generation_prompt=False)
if not torch.equal(_empty_user_turn[0, -2:], TRAILING_IM_END[0]):
    raise ValueError("Unexpected chat template: an empty user turn does not end with '<|im_end|>\\n'")
VQ_USER_PREFIX = _empty_user_turn[:, :-2]

def encode_vq(codes: torch.Tensor, is_assistant=True) -> torch.Tensor:
    """
    Turn a 2-D ``(C, T)`` Mimi code matrix into a ``(1 + C, T + 2)`` token grid.

    Row 0 repeats the first codebook shifted into the ``<|semantic:n|>`` token range
    (``SEMANTIC_OFFSET + n``); rows 1.. are the raw codes. The grid is closed with
    TRAILING_IM_END, which has 9 rows, so C must be 8. With ``is_assistant=False``,
    VQ_USER_PREFIX is prepended along time so the audio forms a user turn.

    Raises ValueError if ``codes`` is not 2-D.
    """
    if codes.ndim != 2:
        raise ValueError("Must be single batch")

    semantic_row = (SEMANTIC_OFFSET + codes[0]).unsqueeze(0)
    grid = torch.cat([semantic_row, codes], dim=0)

    segments = [grid, TRAILING_IM_END]
    if not is_assistant:
        segments = [VQ_USER_PREFIX, *segments]
    return torch.cat(segments, dim=1)


out = encode_vq(dataset["test"][0]["codes"], is_assistant=True)
tokenizer.decode(out[0])

In [None]:
from typing import Dict

# ASSISTANT_PREFIX_LEN = len(tokenizer.tokenize("<|im_start|>assistant\n"))
# USER_PREFIX_LEN  = len(tokenizer.tokenize("<|im_start|>user\n"))

# def tokenize_row(row: Dict, is_batch=True):
#     """
#     row["normalized_text"] is a string
#     row["codes"] is a torch.Tensor shaped [9, T_vq]
#     """
#     row = {
#         "normalized_text": row["normalized_text"][0],
#         "codes": row["codes"][0],
#         "speaker_id": row["speaker_id"],
#         "id": row["id"]
#     } if is_batch else row
#     tts_user_line = encode_text(role="user", content=row["normalized_text"])
#     asr_assistant_line = encode_text(role="assistant", content=row["normalized_text"], needs_initial_newline=True)
#     tts_assistant_codes = encode_vq(row["codes"])  # shape [9, T_vq]
#     asr_user_codes = encode_vq(row["codes"], is_assistant=False)  # shape [9, T_vq]
    
#     # Concatenate system prompt (row=1?), user line (row=1?), codebooks (row=9),
#     # but along the *time* dimension => final shape [9, T_total] 
#     #   (since sysprompt and user_line are [1, T_something], 
#     #    codes_9rows is [9, T_vq], so we pad them to 9 rows if needed)
#     # For demonstration, I'm just stacking them. You probably do:
#     tts_ground_truth = torch.cat([tts_sysprompt, tts_user_line, tts_assistant_codes], dim=1)
#     asr_ground_truth = torch.cat([asr_sysprompt, asr_user_codes, asr_assistant_line], dim=1)
#     tts_tokens = tts_ground_truth[:,:-1].clone()
#     asr_tokens = asr_ground_truth[:,:-1].clone()
#     # Clone for labels
#     tts_labels = tts_ground_truth[:, 1:].clone()
#     asr_labels = asr_ground_truth[:, 1:].clone()

#     # TTS MASKING (easy)
#     # labels = asr_ground_truth[:, 1:].clone()
#     # Let's define the "text portion" as sysprompt + user_line only
#     text_len = tts_sysprompt.size(1) + tts_user_line.size(1) + ASSISTANT_PREFIX_LEN - 1  # no VQ_WRAPPER or codes
#     # ONLY mask codebook rows for that text region
#     # row=0 is your "text" row, row=1..8 might be codebooks, or vice versa
#     # (Here I'm assuming row=0 is your actual text tokens. 
#     #  If it's reversed, tweak accordingly!)
#     tts_labels[1:, :text_len] = -100

#     asr_start_len = asr_sysprompt.size(1) + USER_PREFIX_LEN - 1
#     asr_labels[1:, :asr_start_len] = -100
#     asr_labels[1:, -asr_assistant_line.size(1):] = -100

#     out = {
#         "tokens": [tts_tokens, asr_tokens],
#         "labels": [tts_labels, asr_labels],
#         "task": ["tts", "asr"],
#         "normalized_text": [row["normalized_text"]] * 2,
#         "speaker_id": row["speaker_id"] * 2,
#         "id": row["id"] * 2,
#     }
#     return out

# TODO: Not doing ASR for now
def tts_tokenize_row(row: Dict):
    """
    Build causally-shifted ``tokens``/``labels`` grids for one TTS example.

    The sequence is the user turn holding ``row["normalized_text"]`` (ending with the
    assistant generation prompt) followed by ``encode_vq(row["codes"])``. Labels are
    the next-step targets. Codebook rows are set to -100 wherever the target still
    lies in the text prompt, and on the final "\\n" target.

    NOTE: Deliberately ignores sysprompt line for now, can be done in packing
    """
    prompt = encode_text(role="user", content=row["normalized_text"], add_generation_prompt=True)
    response = encode_vq(row["codes"])
    full_sequence = torch.cat([prompt, response], dim=1)

    # Causal shift: the input at step t is trained to predict full_sequence[:, t + 1].
    inputs, targets = full_sequence[:, :-1].clone(), full_sequence[:, 1:].clone()

    # Targets before the first audio frame are prompt text, so the codebook rows carry nothing.
    prompt_len = prompt.size(1)
    targets[1:, : prompt_len - 1] = -100
    # The last target is the trailing "\n"; its codebook rows are padding.
    targets[1:, -1] = -100
    # NOTE(review): the target just before it is "<|im_end|>", whose codebook rows are
    # zero padding but remain unmasked. Confirm that predicting code 0 there is intended.

    return {"tokens": inputs, "labels": targets}



example_row = tts_tokenize_row(dataset["test"][0])
tokenizer.decode(example_row["labels"][0, :])

In [None]:
labels_preview = example_row["labels"]
labels_preview

In [None]:
# tts_tokenize_row handles exactly one row per call (encode_vq rejects anything but a
# 2-D codes tensor), so keep this map unbatched: do NOT pass batched=True / a batch size.
dataset = dataset.map(tts_tokenize_row, remove_columns=["codes"])

In [None]:
# Single (9, 1) "\n" column inserted between packed rows; codebook rows are zero padding.
NEWLINE_SEPARATOR = torch.tensor(tokenizer.encode("\n") + [0] * 8).unsqueeze(1)


def _lookahead_labels(prefix: torch.Tensor, next_tokens: torch.Tensor) -> torch.Tensor:
    """Causal labels for an inserted text-only ``prefix`` that is followed by ``next_tokens``.

    Prefix position t is labelled with the token at t + 1. The last prefix position is
    labelled with the first column of ``next_tokens``. Codebook rows (1:) are set to
    -100 because the prefix carries no audio.
    """
    labels = torch.cat([prefix[:, 1:], next_tokens[:, :1]], dim=1)  # new tensor, safe to edit
    labels[1:, :] = -100
    return labels


def batch_pack_sequences(examples, window_size=768, max_items=5):
    """Greedily pack consecutive same-speaker TTS rows behind one system prompt.

    Used with ``Dataset.map(batched=True)``. Packing only sees the current map batch,
    so packs never cross batch boundaries. Consecutive rows are grouped while all of
    the following hold:

    - they share ``speaker_id``;
    - the pack has fewer than ``max_items`` rows;
    - the packed length (system prompt + rows + one newline separator between rows)
      stays within ``window_size``.

    A single row longer than the remaining window still gets its own, over-long pack.

    Each pack is ``tts_sysprompt, row_0, "\\n", row_1, "\\n", ...`` along time. Row
    tokens/labels are copied as produced by ``tts_tokenize_row``. The inserted system
    prompt and separators get causally shifted labels (next token on row 0, -100 on the
    codebook rows), so every label in a pack is a next-step target.

    Args:
        examples: batch dict with ``tokens`` and ``labels`` (sequences of ``(9, T_i)``
            tensors) and ``speaker_id``.
        window_size: maximum pack length, including the system prompt.
        max_items: maximum number of rows per pack.

    Returns:
        Dict of equal-length lists: ``tokens``, ``labels``, ``speaker_id``,
        ``pack_length`` (time steps in the pack) and ``items_in_pack``.
    """
    tokens = examples["tokens"]
    labels = examples["labels"]
    speakers = examples["speaker_id"]

    effective_window = window_size - tts_sysprompt.shape[1]

    # Pass 1: decide which consecutive rows share a pack (an empty batch yields no packs).
    groups = []
    group_length = 0
    group_speaker = None
    for i, (row_tokens, speaker) in enumerate(zip(tokens, speakers)):
        seq_len = row_tokens.shape[1]
        # "+ 1" accounts for the newline separator that precedes every non-first row.
        if (
            not groups
            or speaker != group_speaker
            or len(groups[-1]) >= max_items
            or group_length + 1 + seq_len > effective_window
        ):
            groups.append([i])
            group_speaker = speaker
            group_length = seq_len
        else:
            groups[-1].append(i)
            group_length += 1 + seq_len

    # Pass 2: materialize each pack.
    packed_tokens, packed_labels = [], []
    packed_speakers, pack_lengths, items_per_pack = [], [], []
    for group in groups:
        first = group[0]
        token_parts = [tts_sysprompt, tokens[first]]
        label_parts = [_lookahead_labels(tts_sysprompt, tokens[first]), labels[first]]
        for i in group[1:]:
            token_parts += [NEWLINE_SEPARATOR, tokens[i]]
            label_parts += [_lookahead_labels(NEWLINE_SEPARATOR, tokens[i]), labels[i]]

        pack = torch.cat(token_parts, dim=1)
        packed_tokens.append(pack)
        packed_labels.append(torch.cat(label_parts, dim=1))
        packed_speakers.append(speakers[first])
        pack_lengths.append(pack.shape[1])
        items_per_pack.append(len(group))

    return {
        "tokens": packed_tokens,
        "labels": packed_labels,
        "speaker_id": packed_speakers,
        "pack_length": pack_lengths,
        "items_in_pack": items_per_pack,
    }

In [None]:
# Usage: pack every split in map batches of 1000 rows (packs never span two batches).
packed_dataset = dataset.map(
    batch_pack_sequences,
    fn_kwargs={"max_items": 3},
    batched=True,
    batch_size=1000,  # Adjust based on memory constraints
    remove_columns=dataset["val"].column_names,
)

In [None]:
example_row = packed_dataset["val"][0]
tokenizer.decode(example_row["tokens"][0])

In [None]:
PACKED_OUTPUT_DIR = "tokenized_libritts_packed_3"
packed_dataset.save_to_disk(PACKED_OUTPUT_DIR)

## Appendix: Markdown

In [None]:
# Text-stream labels of the first packed validation example (example_row is a single
# packed row dict, so the old `example_row[0]` indexing raised KeyError).
tokenizer.decode(example_row["labels"][0, :])

In [None]:
# Text-stream labels of the second packed validation example (replaces the stale
# `example_row[1]`, which raised KeyError on the single-row dict).
tokenizer.decode(packed_dataset["val"][1]["labels"][0, :])

In [None]:
# First codebook row of the packed tokens (tokens is 2-D, so the old extra `[0]`
# reduced it to 1-D and `[1, :]` raised IndexError).
example_row["tokens"][1, :]

In [None]:
# First codebook row of the labels of the second packed validation example
# (replaces the stale `example_row[1]`, which raised KeyError on the single-row dict).
packed_dataset["val"][1]["labels"][1, :]

In [None]:
first_test_row = dataset["test"][0]
first_test_row