# Tokenize LibriTTS-R

In [1]:
from modeling.utils import PromptEncoder

prompt_encoder = PromptEncoder()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from datasets import load_from_disk, concatenate_datasets, DatasetDict

ds = load_from_disk("../datasets/encoded_libritts")
ds = DatasetDict({
    'train': concatenate_datasets([ds['train.clean.100'], ds['train.clean.360']]),
    'test': ds['test.clean'],
    'dev': ds['dev.clean']
})
ds = ds.with_format('torch')

In [3]:
ds['dev'][0]

{'text_normalized': 'The weapon must still have been there.',
 'text_original': 'The weapon must still have been there.',
 'speaker_id': '3081',
 'path': '/root/.cache/huggingface/datasets/downloads/extracted/5551a515e85b9e463062524539c2e1cb52ba32affe128dffd866db0205248bdd/LibriTTS_R/dev-clean/3081/166546/3081_166546_000101_000001.wav',
 'chapter_id': '166546',
 'id': '3081_166546_000101_000001',
 'codes': tensor([[1049, 1114, 1609,  784,  499,  260, 1011,    8, 1407,  540, 1615,  561,
          1945,  201, 1324,  668,  376, 1849,    9, 1921, 1921, 1683,  228,  897,
          1677,  518],
         [ 811, 1149,  739,  410, 1367, 1305, 2046, 1287,  886, 1995, 1727,  678,
          1455,  352, 1914, 1504, 1138, 1154,  669, 1217, 1450, 1003, 1711,  488,
           342,  844],
         [ 373, 1464, 2013, 1306,  102,  561,  852,  267,  442,  718, 1501, 1455,
           233, 1015,  963,   29,  496, 1728,  783, 1870,  879, 1802, 1523,  231,
           333,  199],
         [ 131, 1957,   58,  5

In [14]:
from transformers import MimiModel

mimi_model = MimiModel.from_pretrained('kyutai/mimi')

  self.register_buffer("padding_total", torch.tensor(kernel_size - stride, dtype=torch.int64), persistent=False)


Let's prove that restoring the residual restores the codes (or close to it):

In [17]:
import torch

# Deliberately break the model for testing purposes
mimi_model.quantizer.acoustic_residual_vector_quantizer.output_proj = None

def get_residual(codes: torch.Tensor, start_depth=16):
    if codes.ndim == 2:
        codes = codes.unsqueeze(0)
    print(codes.shape)

    # Full residual sum before out_proj
    acoustic_sum = mimi_model.quantizer.acoustic_residual_vector_quantizer.decode(codes[:, 1:, :])
    # Sum of first 15 acoustic codes (aka the remaining decoder)
    acoustic_autoreg_sum = mimi_model.quantizer.acoustic_residual_vector_quantizer.decode(codes[:, 1:start_depth, :])


    # Remaining sum (aka what we'll have to predict)
    acoustic_predicted_residual = torch.zeros_like(acoustic_autoreg_sum)
    for i in range(start_depth, 32):
        layer = mimi_model.quantizer.acoustic_residual_vector_quantizer.layers[i - 1]
        quantized = layer.decode(codes[:, i, :])
        acoustic_predicted_residual = acoustic_predicted_residual + quantized

    # it's just a sum
    assert torch.allclose(acoustic_predicted_residual + acoustic_autoreg_sum, acoustic_sum)

    # Demonstrate inference: reverse the process
    all_indices = []
    # supposing we had the full (predicted) sum
    fake_residual = acoustic_sum - acoustic_autoreg_sum
    for i in range(start_depth, 32):
        layer = mimi_model.quantizer.acoustic_residual_vector_quantizer.layers[i - 1]
        indices = layer.encode(fake_residual)
        quantized = layer.decode(indices)
        fake_residual = fake_residual - quantized
        all_indices.append(indices)
    out_indices = torch.stack(all_indices, dim=1)
    print(codes[:, start_depth:, 2])
    print(out_indices[:, :, 2])


get_residual(ds['train'][0]['codes'])

KeyError: 'codes'

## Tokenize to CSM format

Finally, we prepare the inputs:

In [4]:
import torch

def tokenize_row(row: dict):
    text_tokens, text_masks = prompt_encoder._tokenize_text_segment(
        row["text_normalized"], 0
    )
    audio_tokens, audio_masks = prompt_encoder._tokenize_audio(row['codes'])

    return {
        "ground_truth": torch.cat([text_tokens, audio_tokens], dim=0), 
        "ground_truth_masks": torch.cat([text_masks, audio_masks], dim=0),
    }

# TODO speed this up and/or move it to the collate fn: for libritts it doesn't really matter
# ds = ds.map(get_targets, remove_columns=orig_colnames)

In [5]:
orig_colnames = ds['train'].column_names
ds = ds.map(tokenize_row, num_proc=12, remove_columns=orig_colnames)
ds.save_to_disk("../datasets/tokenized_libritts")

Saving the dataset (9/9 shards): 100%|██████████| 149658/149658 [00:05<00:00, 26543.10 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 4837/4837 [00:00<00:00, 27806.20 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 5736/5736 [00:00<00:00, 29566.94 examples/s]


In [8]:
example_row = ds['train'][0]['ground_truth']
example_row.shape

torch.Size([55, 33])

In [10]:
prompt_encoder._text_tokenizer.decode(example_row[:, -1])

'<|begin_of_text|>[0][The moon] I gazed with a kind of wonder.<|end_of_text|>!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'

In [15]:
torch.stack([example_row[:, 0], example_row[:, -1]])

tensor([[     0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,   1049,
           1102,   1686,   1258,   1258,   1689,   1528,   1987,    978,    312,
           2039,    753,    969,    598,   1084,   1268,    621,   1757,    560,
           1734,   1527,   1117,    622,    628,    510,    623,    623,    918,
            689,    997,   1069,   1941,    294,    774,    518,   1987,    769,
              0],
        [128000,     58,     15,   1483,    791,  18266,     60,    358,    342,
          28109,    449,    264,   3169,    315,   5895,     13, 128001,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
          

In [27]:
row = ds['train'][0]
audio_positions = row['ground_truth_masks'][1:, :-1].all(dim=1)
labels = row['ground_truth'][1:, :-1]
labels[~audio_positions] = -100
labels[:, 0]

tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, 1049, 1102, 1686, 1258, 1258, 1689, 1528, 1987,
         978,  312, 2039,  753,  969,  598, 1084, 1268,  621, 1757,  560, 1734,
        1527, 1117,  622,  628,  510,  623,  623,  918,  689,  997, 1069, 1941,
         294,  774,  518, 1987,  769,    0])

In [29]:
my_range = torch.arange(0, 32 * 2051, 2051)
official_range = 2051 * torch.arange(32)
assert torch.allclose(my_range, official_range)

## Testing collation function

In [38]:
from huggingface_hub import hf_hub_download
from moshi.models import loaders

ds_dev = ds['dev'].map(tokenize_row, num_proc=12, remove_columns=ds['dev'].column_names)

mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
mimi = loaders.get_mimi(mimi_weight, device="cpu")

quantizer = mimi.quantizer.acoustic_quantizer.vq

In [65]:
batch = ds_dev[:32]

B = len(batch["ground_truth"])
CODEBOOK_SIZE=32

height = CODEBOOK_SIZE + 1
max_input_len = max(item.shape[0] - 1 for item in batch["ground_truth"])

B = len(batch["ground_truth"])
tokens = torch.full((B, max_input_len, height), 0, dtype=torch.long)  # 2=some <PAD>
targets = torch.full((B, max_input_len, 256), 0, dtype=torch.float32)

pad_mask = torch.ones(B, max_input_len)

for i in range(B):
    ground_truth = batch["ground_truth"][i]
    ground_truth_masks = batch["ground_truth_masks"][i]

    seq_len = ground_truth.shape[0] - 1
    tokens[i, :seq_len, :] = ground_truth[:-1, :].clone()

    label = ground_truth[1:, :]
    # full block of zeros for audio codes
    codes = label[:, 1:-1].T
    final_residuals = quantizer.decode(codes.unsqueeze(-1)).squeeze(-1)
    # zero text positions with the mask
    mask = ground_truth_masks[1:, :-1].all(dim=1)
    final_residuals[~mask] = 0
    targets[i, :seq_len, :] = final_residuals.unsqueeze(0)
