# Tokenize Expresso

In [41]:
from datasets import load_from_disk

# Load the encoded Expresso dataset from disk and switch its output format
# to torch in a single expression.
ds = load_from_disk("../datasets/encoded_expresso").with_format('torch')

Keep only the rows whose style is in the supported list below; rows with any other style are dropped.

In [33]:
# Styles kept for training; rows with any other style are discarded.
supported_styles = ["confused", "enunciated", "happy", "laughing", "default", "sad", "whisper", "emphasis"]

# Only the "style" column is needed to decide whether a row is kept, so pass
# just that column to the predicate instead of the whole formatted row. This
# avoids materialising every other column for each row during the filter.
ds = ds.filter(
    lambda style: style in supported_styles,
    input_columns=["style"],
    num_proc=12,
)

## Add control tokens to model and tokenizer

In [None]:
import os
from transformers import AutoTokenizer

# Folder that receives the extended tokenizer (and, in later cells, the
# patched checkpoint and config).
init_folder = "../inits/csm-1b-expresso"
os.makedirs(init_folder, exist_ok=True)

tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-1B")

# One control token per supported style. The tokens are derived from
# `supported_styles` (defined in the filtering cell) so that the dataset
# filter, the tokenizer vocabulary and the `<|style|>` prefix used when
# tokenizing rows cannot drift apart.
style_tokens = [f"<|{style}|>" for style in supported_styles]

# `add_special_tokens` returns how many tokens were actually added to the
# vocabulary; later cells use this count to grow the embedding table and
# the configured vocabulary size.
n_added_tokens = tokenizer.add_special_tokens({"additional_special_tokens": style_tokens})

tokenizer.save_pretrained(init_folder)

In [35]:
# Patch the checkpoint
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file, save_file
import torch

# Download the published CSM-1B weights and load them on CPU.
model_path = hf_hub_download(repo_id="sesame/csm-1b", filename="model.safetensors")
state_dict = load_file(model_path, device="cpu")

# Append one embedding row per added token. Each new row is initialised to
# the mean of the existing text embedding rows.
text_embeddings = state_dict["text_embeddings.weight"]
mean_embedding = text_embeddings.mean(dim=0, keepdim=True)
expanded_embedding = mean_embedding.expand(n_added_tokens, -1)
state_dict["text_embeddings.weight"] = torch.cat(
    (text_embeddings, expanded_embedding),
    dim=0,
)

# Write the patched checkpoint next to the extended tokenizer.
save_file(state_dict, f"{init_folder}/model.safetensors")


In [36]:
import json

# Fetch the original model config and parse it.
config = hf_hub_download(repo_id="sesame/csm-1b", filename="config.json")
with open(config, 'r') as f:
    config_json = json.loads(f.read())

# Account for the newly added tokens in the text vocabulary size.
config_json["text_vocab_size"] = config_json["text_vocab_size"] + n_added_tokens

# Store the updated config alongside the patched checkpoint.
with open(f"{init_folder}/config.json", 'w') as f:
    f.write(json.dumps(config_json, indent=2))


## Tokenize to CSM format

Now let's reload the extended tokenizer from disk:

In [38]:
from modeling.utils import PromptEncoder

# Reload the tokenizer from `init_folder`, where the version extended with
# the style control tokens was saved, replacing the in-memory instance.
tokenizer = AutoTokenizer.from_pretrained(init_folder)
# Encoder used below to turn text and audio codes into token/mask tensors.
prompt_encoder = PromptEncoder(tokenizer=tokenizer)

Finally, we prepare the inputs:

In [39]:
import torch

def tokenize_row(row: dict):
    """Turn one dataset row into a combined text + audio token sequence.

    The row's style is prepended to its transcript as a ``<|style|>`` control
    token, the result is tokenized together with a speaker index, and the
    row's audio codes are tokenized separately. Text tokens come first and
    audio tokens second, concatenated along dim 0; the masks are joined the
    same way.

    Args:
        row: Mapping with the keys ``"style"``, ``"text"``, ``"speaker_id"``
            and ``"codes"``. ``"speaker_id"`` must be a string ending in a
            decimal number (e.g. ``"ex01"``).

    Returns:
        A dict with ``"ground_truth"`` (tokens) and ``"ground_truth_masks"``
        (the matching masks).

    Raises:
        ValueError: If ``row["speaker_id"]`` does not end in a digit.
    """
    speaker_id = row["speaker_id"]

    # Parse the complete trailing number rather than only the last character,
    # so multi-digit ids are handled ("ex10" -> 9 instead of -1). Single-digit
    # ids such as "ex01" give the same zero-based index as before.
    trailing_digits = speaker_id[len(speaker_id.rstrip("0123456789")):]
    if not trailing_digits:
        raise ValueError(
            f"speaker_id {speaker_id!r} does not end in a number"
        )
    speaker_index = int(trailing_digits) - 1

    # Abuse turn ID as voice prompt, this is just for testing
    text_tokens, text_masks = prompt_encoder._tokenize_text_segment(
        f'<|{row["style"]}|>{row["text"]}', speaker_index
    )
    audio_tokens, audio_masks = prompt_encoder._tokenize_audio(row['codes'])

    return {
        "ground_truth": torch.cat([text_tokens, audio_tokens], dim=0),
        "ground_truth_masks": torch.cat([text_masks, audio_masks], dim=0),
    }

In [None]:
from datasets import DatasetDict

# Remember the input columns so they can be removed once each row has been
# replaced by its tokenized form.
orig_colnames = ds.column_names

# Tokenize every row in parallel and wrap the result as the "train" split.
ds = DatasetDict({
    "train": ds.map(tokenize_row, remove_columns=orig_colnames, num_proc=12),
})

ds.save_to_disk("../datasets/tokenized_expresso")