# Tokenize Expresso

In [27]:
from datasets import load_from_disk

ds = load_from_disk("../datasets/encoded_expresso")
ds = ds.with_format('torch')

In [2]:
import pandas as pd

print(ds.column_names)
pd.value_counts(pd.Series(ds['style']))


['text', 'speaker_id', 'style', 'id', 'codes']


  pd.value_counts(pd.Series(ds['style']))


confused      1520
enunciated    1520
happy         1520
laughing      1520
default       1519
sad           1519
whisper       1518
emphasis       800
essentials     160
singing         10
longform         8
Name: count, dtype: int64

Going to drop everything else

In [28]:
supported_styles = ["confused", "enunciated", "happy", "laughing", "default", "sad", "whisper", "emphasis"]
ds = ds.filter(lambda r: r["style"] in supported_styles, num_proc=12)

## Create control tokens

In [5]:
import os
from transformers import AutoTokenizer

init_folder = "../inits/csm-1b-expresso"
os.makedirs(init_folder, exist_ok=True)
tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-1B")
n_added_tokens = tokenizer.add_special_tokens({"additional_special_tokens": [
    "<|confused|>",
    "<|enunciated|>",
    "<|happy|>",
    "<|laughing|>",
    "<|default|>",
    "<|sad|>",
    "<|whisper|>",
    "<|emphasis|>",
    ]
})

tokenizer.save_pretrained(init_folder)

('../inits/csm-1b-expresso/tokenizer_config.json',
 '../inits/csm-1b-expresso/special_tokens_map.json',
 '../inits/csm-1b-expresso/tokenizer.json')

In [7]:
# Patch the checkpoint
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file, save_file
import torch

model_path = hf_hub_download(repo_id="sesame/csm-1b", filename="model.safetensors")

state_dict = load_file(model_path, device="cpu")

mean_embedding = state_dict["text_embeddings.weight"].mean(dim=0, keepdim=True)
expanded_embedding = mean_embedding.expand(n_added_tokens, -1)
state_dict["text_embeddings.weight"] = torch.cat([state_dict["text_embeddings.weight"], expanded_embedding], dim=0)

save_file(state_dict, f"{init_folder}/model.safetensors")


In [21]:
import json

config = hf_hub_download(repo_id="sesame/csm-1b", filename="config.json")
with open(config, 'r') as f:
    config_json = json.load(f)

config_json["text_vocab_size"] += n_added_tokens
with open(f"{init_folder}/config.json", 'w') as f:
    json.dump(config_json, f, indent=2)


## Tokenize to CSM format

Now let's load our new tokenizer back again:

In [8]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(init_folder)
encoded = tokenizer.encode("[0]<|confused|>test")
print(encoded)
tokenizer.decode(encoded)

[128000, 58, 15, 60, 128256, 1985]


'<|begin_of_text|>[0]<|confused|>test'

In [9]:
from modeling.utils import PromptEncoder

prompt_encoder = PromptEncoder(tokenizer=tokenizer)

Finally, we prepare the inputs:

In [12]:
import torch

def tokenize_row(row: dict):
    text_tokens, text_masks = prompt_encoder._tokenize_text_segment(
        f'<|{row["style"]}|>{row["text"]}', 0
    )
    audio_tokens, audio_masks = prompt_encoder._tokenize_audio(row['codes'])

    return {
        "ground_truth": torch.cat([text_tokens, audio_tokens], dim=0), 
        "ground_truth_masks": torch.cat([text_masks, audio_masks], dim=0),
    }

# TODO speed this up and/or move it to the collate fn: for libritts it doesn't really matter
# ds = ds.map(get_targets, remove_columns=orig_colnames)

In [29]:
from datasets import DatasetDict

orig_colnames = ds.column_names
ds = ds.map(tokenize_row, num_proc=12, remove_columns=orig_colnames)

ds = DatasetDict({
    "train": ds
})
ds.save_to_disk("../datasets/tokenized_expresso")

Saving the dataset (0/1 shards):   0%|          | 0/11436 [00:00<?, ? examples/s]

Saving the dataset (1/1 shards): 100%|██████████| 11436/11436 [00:00<00:00, 217993.86 examples/s]


In [15]:
example_row = ds[0]['ground_truth']
example_row.shape

torch.Size([51, 33])

In [16]:
prompt_encoder._text_tokenizer.decode(example_row[:, -1])

'<|begin_of_text|>[0]<|confused|>Why are you beating up my jukebox?!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'

In [8]:
torch.stack([example_row[:, 0], example_row[:, -1]])

tensor([[     0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,   1049,
           1102,   1686,   1258,   1258,   1689,   1528,   1987,    978,    312,
           2039,    753,    969,    598,   1084,   1268,    621,   1757,    560,
           1734,   1527,   1117,    622,    628,    510,    623,    623,    918,
            689,    997,   1069,   1941,    294,    774,    518,   1987,    769,
              0],
        [128000,     58,     15,   1483,    791,  18266,     60,    358,    342,
          28109,    449,    264,   3169,    315,   5895,     13, 128001,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
          

In [9]:
row = ds['train'][0]
audio_positions = row['ground_truth_masks'][1:, :-1].all(dim=1)
labels = row['ground_truth'][1:, :-1]
labels[~audio_positions] = -100
labels[:, 0]

tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, 1049, 1102, 1686, 1258, 1258, 1689, 1528, 1987,
         978,  312, 2039,  753,  969,  598, 1084, 1268,  621, 1757,  560, 1734,
        1527, 1117,  622,  628,  510,  623,  623,  918,  689,  997, 1069, 1941,
         294,  774,  518, 1987,  769,    0])

In [30]:
my_range = torch.arange(0, 32 * 2051, 2051)
official_range = 2051 * torch.arange(32)
assert my_range.eq(official_range).all()

## Testing collation function

In [38]:
from huggingface_hub import hf_hub_download
from moshi.models import loaders

ds_dev = ds['dev'].map(tokenize_row, num_proc=12, remove_columns=ds['dev'].column_names)

mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
mimi = loaders.get_mimi(mimi_weight, device="cpu")

quantizer = mimi.quantizer.acoustic_quantizer.vq

In [65]:
batch = ds_dev[:32]

B = len(batch["ground_truth"])
CODEBOOK_SIZE=32

height = CODEBOOK_SIZE + 1
max_input_len = max(item.shape[0] - 1 for item in batch["ground_truth"])

B = len(batch["ground_truth"])
tokens = torch.full((B, max_input_len, height), 0, dtype=torch.long)  # 2=some <PAD>
targets = torch.full((B, max_input_len, 256), 0, dtype=torch.float32)

pad_mask = torch.ones(B, max_input_len)

for i in range(B):
    ground_truth = batch["ground_truth"][i]
    ground_truth_masks = batch["ground_truth_masks"][i]

    seq_len = ground_truth.shape[0] - 1
    tokens[i, :seq_len, :] = ground_truth[:-1, :].clone()

    label = ground_truth[1:, :]
    # full block of zeros for audio codes
    codes = label[:, 1:-1].T
    final_residuals = quantizer.decode(codes.unsqueeze(-1)).squeeze(-1)
    # zero text positions with the mask
    mask = ground_truth_masks[1:, :-1].all(dim=1)
    final_residuals[~mask] = 0
    targets[i, :seq_len, :] = final_residuals.unsqueeze(0)
