# Seseme CSM Fine-Tuning

This notebook demonstrates how to fine-tune the Seseme CSM model using the Hugging Face `transformers` library.

## 1. Setup

Connect to the Hugging Face Hub and install the dataset. The dataset used is in `parquet` format, which is efficient for large datasets.

In [1]:
from huggingface_hub import login
import os

# initialize login token
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [2]:
!git clone https://huggingface.co/datasets/MrDragonFox/Elise

Cloning into 'Elise'...
remote: Enumerating objects: 17, done.[K
remote: Total 17 (delta 0), reused 0 (delta 0), pack-reused 17 (from 1)[K
Unpacking objects: 100% (17/17), 5.95 KiB | 1.98 MiB/s, done.


## 2. Preprocessing

### 2.1 Load The Model

In [3]:
from transformers import CsmForConditionalGeneration, AutoProcessor, Trainer, TrainingArguments
from datasets import load_dataset, Audio
import torch
import numpy as np
from tqdm import tqdm
import os

model_id = "sesame/csm-1b"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model and processor
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)
model.train()
model.codec_model.eval()  # Keep codec model in eval mode during training

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/271 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/50.6k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/449 [00:00<?, ?B/s]

chat_template.jinja:   0%|          | 0.00/2.00k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/3.28k [00:00<?, ?B/s]

transformers.safetensors.index.json:   0%|          | 0.00/59.7k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

transformers-00001-of-00002.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

transformers-00002-of-00002.safetensors:   0%|          | 0.00/2.19G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/264 [00:00<?, ?B/s]

MimiModel(
  (encoder): MimiEncoder(
    (layers): ModuleList(
      (0): MimiConv1d(
        (conv): Conv1d(1, 64, kernel_size=(7,), stride=(1,))
      )
      (1): MimiResnetBlock(
        (block): ModuleList(
          (0): ELU(alpha=1.0)
          (1): MimiConv1d(
            (conv): Conv1d(64, 32, kernel_size=(3,), stride=(1,))
          )
          (2): ELU(alpha=1.0)
          (3): MimiConv1d(
            (conv): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
          )
        )
        (shortcut): Identity()
      )
      (2): ELU(alpha=1.0)
      (3): MimiConv1d(
        (conv): Conv1d(64, 128, kernel_size=(8,), stride=(4,))
      )
      (4): MimiResnetBlock(
        (block): ModuleList(
          (0): ELU(alpha=1.0)
          (1): MimiConv1d(
            (conv): Conv1d(128, 64, kernel_size=(3,), stride=(1,))
          )
          (2): ELU(alpha=1.0)
          (3): MimiConv1d(
            (conv): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
          )
        )
        (s

### 2.2 Load The Data

In [4]:
import pandas as pd
from datasets import Dataset

# Path to the dataset after git clone
dataset_path = os.path.join("Elise", "data", "train-00000-of-00001.parquet")

df = pd.read_parquet(dataset_path)
ds = {"train": Dataset.from_pandas(df)}
print(f"Dataset structure: {ds}")

# Ensure the audio is 24kHz (CSM requirement)
# Check if 'audio' column exists, otherwise look for the correct audio column
audio_column = "audio"
if audio_column in ds["train"].column_names:
    ds["train"] = ds["train"].cast_column(audio_column, Audio(sampling_rate=24000))
else:
    print(f"Warning: Column '{audio_column}' not found. Available columns: {ds['train'].column_names}")
    # Try to identify the audio column if it has a different name
    # You might need to adjust this based on your dataset structure

Dataset structure: {'train': Dataset({
    features: ['audio', 'text'],
    num_rows: 1195
})}


In [5]:
def prepare_conversation_batch(batch_size=4, offset=0):
    """Prepare a batch of conversations from the dataset"""
    batch_indices = list(range(offset, min(offset + batch_size, len(ds["train"]))))
    conversation = []

    # Use a stringified integer for the speaker ID (e.g., "0")
    speaker_id = "0"  # Changed from "Elise" to "0"

    for idx in batch_indices:
        example = ds["train"][idx]

        # Extract text and audio from the dataset
        text = example.get("text", "")

        # Handle audio data correctly based on dataset structure
        if "audio" in example and isinstance(example["audio"], dict) and "array" in example["audio"]:
            audio_data = example["audio"]["array"]
        elif "audio" in example:
            # If audio is directly accessible
            audio_data = example["audio"]
        else:
            # Try to find audio under a different key
            print(f"Warning: Audio not found in example. Available keys: {example.keys()}")
            continue

        conversation.append({
            "role": speaker_id,  # Now using a stringified integer
            "content": [
                {"type": "text", "text": text},
                {"type": "audio", "path": audio_data}
            ],
        })

    return conversation

## 3. Modeling

In [6]:
# Add these imports at the top
from peft import LoraConfig, get_peft_model, TaskType
from peft import PeftModel, PeftConfig

# Replace the model loading section
model_id = "sesame/csm-1b"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model and processor
processor = AutoProcessor.from_pretrained(model_id)
base_model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)

# Configure LoRA
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,  # or TaskType.CAUSAL_LM depending on the model
    inference_mode=False,
    r=16,  # rank - lower values (8-32) for more conservative fine-tuning
    lora_alpha=32,  # scaling parameter
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],  # target attention and MLP layers
    bias="none",
)

# Apply LoRA to the model
model = get_peft_model(base_model, lora_config)
model.train()
model.base_model.codec_model.eval()  # Keep codec model in eval mode

# Print trainable parameters
model.print_trainable_parameters()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

trainable params: 14,516,224 || all params: 1,646,616,385 || trainable%: 0.8816


In [7]:
# Training configurations
output_dir = "./csm_elise_lora_model"
num_train_epochs = 4
per_device_train_batch_size = 2  # CSM models can be memory intensive
gradient_accumulation_steps = 2
learning_rate = 1e-4
warmup_steps = 100
logging_steps = 10
save_steps = 500
max_steps = 1000  # Adjust based on dataset size
weight_decay = 0.01  # Add weight decay
max_grad_norm = 1.0  # Add gradient clipping

# Initialize training arguments
training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    learning_rate=learning_rate,
    warmup_steps=warmup_steps,
    logging_steps=logging_steps,
    save_steps=save_steps,
    max_steps=max_steps,
    fp16=True,  # Use mixed precision training
    remove_unused_columns=False,
    report_to="tensorboard",
)

# Custom training loop (as alternative to using Trainer)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer,
                                             start_factor=1.0,
                                             end_factor=0.0,
                                             total_iters=max_steps)

# Training loop
model.train()
total_batches = min(max_steps, len(ds["train"]) // per_device_train_batch_size)

for step in tqdm(range(total_batches)):
    batch_offset = step * per_device_train_batch_size
    conversation = prepare_conversation_batch(per_device_train_batch_size, batch_offset)

    # Skip empty conversations
    if not conversation:
        continue

    # Process the conversation batch
    try:
        inputs = processor.apply_chat_template(
            conversation,
            tokenize=True,
            return_dict=True,
            output_labels=True,
        ).to(device)

        # Forward pass
        outputs = model(**inputs)
        loss = outputs.loss

        # Backward pass
        loss.backward()

        # Add gradient clipping here
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        # Optimizer step with gradient accumulation
        if (step + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        # Log progress
        if step % logging_steps == 0:
            print(f"Step {step}, Loss: {loss.item()}")

        # Save checkpoint
        if step % save_steps == 0 and step > 0:
            model.save_pretrained(f"{output_dir}/checkpoint-{step}")
            processor.save_pretrained(f"{output_dir}/checkpoint-{step}")

    except Exception as e:
        print(f"Error in batch at offset {batch_offset}: {e}")
        continue

  0%|          | 1/597 [00:15<2:37:36, 15.87s/it]

Step 0, Loss: 4.28127384185791


  2%|▏         | 11/597 [00:20<05:25,  1.80it/s]

Step 10, Loss: 5.3134284019470215


  4%|▎         | 21/597 [00:24<04:24,  2.18it/s]

Step 20, Loss: 4.729450225830078


  5%|▌         | 31/597 [00:29<04:09,  2.27it/s]

Step 30, Loss: 5.607455730438232


  7%|▋         | 41/597 [00:33<04:04,  2.28it/s]

Step 40, Loss: 5.40264892578125


  9%|▊         | 51/597 [00:38<04:01,  2.27it/s]

Step 50, Loss: 5.180002212524414


 10%|█         | 61/597 [00:42<03:38,  2.45it/s]

Step 60, Loss: 5.4841461181640625


 12%|█▏        | 71/597 [00:46<03:55,  2.24it/s]

Step 70, Loss: 4.891551494598389


 14%|█▎        | 81/597 [00:51<03:37,  2.38it/s]

Step 80, Loss: 5.411448001861572


 15%|█▌        | 91/597 [00:55<03:41,  2.29it/s]

Step 90, Loss: 5.713310241699219


 17%|█▋        | 101/597 [00:59<03:37,  2.28it/s]

Step 100, Loss: 4.976541519165039


 19%|█▊        | 111/597 [01:03<03:20,  2.42it/s]

Step 110, Loss: 4.600924491882324


 20%|██        | 121/597 [01:08<03:22,  2.35it/s]

Step 120, Loss: 4.553371429443359


 22%|██▏       | 131/597 [01:12<03:17,  2.36it/s]

Step 130, Loss: 2.891777753829956


 24%|██▎       | 141/597 [01:17<03:33,  2.14it/s]

Step 140, Loss: 5.7295331954956055


 25%|██▌       | 151/597 [01:21<03:03,  2.43it/s]

Step 150, Loss: 5.16084623336792


 27%|██▋       | 161/597 [01:25<02:54,  2.50it/s]

Step 160, Loss: 4.664568901062012


 29%|██▊       | 171/597 [01:29<03:09,  2.25it/s]

Step 170, Loss: 5.068541049957275


 30%|███       | 181/597 [01:34<02:57,  2.35it/s]

Step 180, Loss: 3.6774778366088867


 32%|███▏      | 191/597 [01:38<02:48,  2.40it/s]

Step 190, Loss: 5.533851623535156


 34%|███▎      | 201/597 [01:42<02:53,  2.28it/s]

Step 200, Loss: 4.424760341644287


 35%|███▌      | 211/597 [01:47<02:54,  2.22it/s]

Step 210, Loss: 5.0877532958984375


 37%|███▋      | 221/597 [01:51<02:39,  2.35it/s]

Step 220, Loss: 4.218507766723633


 39%|███▊      | 231/597 [01:55<02:35,  2.36it/s]

Step 230, Loss: 5.3765950202941895


 40%|████      | 241/597 [02:00<02:24,  2.47it/s]

Step 240, Loss: 5.467538833618164


 42%|████▏     | 251/597 [02:04<02:40,  2.15it/s]

Step 250, Loss: 4.970861434936523


 44%|████▎     | 261/597 [02:09<02:25,  2.30it/s]

Step 260, Loss: 4.6631083488464355


 45%|████▌     | 271/597 [02:13<02:15,  2.41it/s]

Step 270, Loss: 5.257638931274414


 47%|████▋     | 281/597 [02:17<02:13,  2.36it/s]

Step 280, Loss: 4.930272102355957


 49%|████▊     | 291/597 [02:22<02:18,  2.20it/s]

Step 290, Loss: 4.167793273925781


 50%|█████     | 301/597 [02:26<02:09,  2.29it/s]

Step 300, Loss: 5.005581378936768


 52%|█████▏    | 311/597 [02:30<02:01,  2.35it/s]

Step 310, Loss: 5.097028732299805


 54%|█████▍    | 321/597 [02:35<01:55,  2.39it/s]

Step 320, Loss: 4.555657863616943


 55%|█████▌    | 331/597 [02:39<02:00,  2.22it/s]

Step 330, Loss: 4.74074125289917


 57%|█████▋    | 341/597 [02:43<01:50,  2.32it/s]

Step 340, Loss: 4.319786071777344


 59%|█████▉    | 351/597 [02:48<01:50,  2.22it/s]

Step 350, Loss: 5.0367021560668945


 60%|██████    | 361/597 [02:52<01:41,  2.33it/s]

Step 360, Loss: 4.4825286865234375


 62%|██████▏   | 371/597 [02:56<01:31,  2.46it/s]

Step 370, Loss: 4.970941543579102


 64%|██████▍   | 381/597 [03:01<01:34,  2.28it/s]

Step 380, Loss: 4.586134910583496


 65%|██████▌   | 391/597 [03:05<01:26,  2.39it/s]

Step 390, Loss: 4.4637885093688965


 67%|██████▋   | 401/597 [03:09<01:20,  2.42it/s]

Step 400, Loss: 4.561644077301025


 69%|██████▉   | 411/597 [03:14<01:17,  2.39it/s]

Step 410, Loss: 3.348134994506836


 71%|███████   | 421/597 [03:18<01:13,  2.39it/s]

Step 420, Loss: 4.40899658203125


 72%|███████▏  | 431/597 [03:22<01:08,  2.42it/s]

Step 430, Loss: 4.389047622680664


 74%|███████▍  | 441/597 [03:26<01:09,  2.25it/s]

Step 440, Loss: 5.778710842132568


 76%|███████▌  | 451/597 [03:31<01:06,  2.20it/s]

Step 450, Loss: 4.798888683319092


 77%|███████▋  | 461/597 [03:35<00:59,  2.27it/s]

Step 460, Loss: 5.07237434387207


 79%|███████▉  | 471/597 [03:40<00:54,  2.31it/s]

Step 470, Loss: 4.579644203186035


 81%|████████  | 481/597 [03:44<00:50,  2.28it/s]

Step 480, Loss: 4.798603057861328


 82%|████████▏ | 491/597 [03:48<00:45,  2.33it/s]

Step 490, Loss: 4.950549125671387


 84%|████████▍ | 500/597 [03:52<00:40,  2.41it/s]

Step 500, Loss: 4.042849540710449


 86%|████████▌ | 511/597 [03:58<00:38,  2.22it/s]

Step 510, Loss: 4.0877299308776855


 87%|████████▋ | 521/597 [04:02<00:31,  2.38it/s]

Step 520, Loss: 5.257163047790527


 89%|████████▉ | 531/597 [04:07<00:30,  2.14it/s]

Step 530, Loss: 4.808886528015137


 91%|█████████ | 541/597 [04:11<00:23,  2.39it/s]

Step 540, Loss: 5.373443603515625


 92%|█████████▏| 551/597 [04:16<00:19,  2.32it/s]

Step 550, Loss: 5.503174781799316


 94%|█████████▍| 561/597 [04:20<00:15,  2.30it/s]

Step 560, Loss: 5.062353134155273


 96%|█████████▌| 571/597 [04:24<00:11,  2.32it/s]

Step 570, Loss: 4.094350337982178


 97%|█████████▋| 581/597 [04:28<00:06,  2.37it/s]

Step 580, Loss: 5.190174579620361


 99%|█████████▉| 591/597 [04:32<00:02,  2.38it/s]

Step 590, Loss: 4.886165618896484


100%|██████████| 597/597 [04:35<00:00,  2.17it/s]


Now save the trained model in local directory.

In [8]:
# Update the model saving section
# Save the LoRA adapter (much smaller than full model)
model.save_pretrained(f"{output_dir}/final")
processor.save_pretrained(f"{output_dir}/final")

# Also save the base model reference for easier loading later
with open(f"{output_dir}/final/base_model_name.txt", "w") as f:
    f.write(model_id)

## 4. Inference

Try the model with some sample inputs to see how it performs.

In [9]:
# Modified load_peft_model_for_inference function
def load_peft_model_for_inference(adapter_path, base_model_id=None):
    if base_model_id is None:
        try:
            with open(f"{adapter_path}/base_model_name.txt", "r") as f:
                base_model_id = f.read().strip()
        except:
            base_model_id = "sesame/csm-1b"  # fallback

    # Load base model
    base_model = CsmForConditionalGeneration.from_pretrained(base_model_id, device_map=device)

    # Load and merge PEFT model
    model = PeftModel.from_pretrained(base_model, adapter_path)
    # This is the key step - merge weights to remove adapter overhead
    merged_model = model.merge_and_unload()
    return merged_model

In [10]:
from transformers import BitsAndBytesConfig

def load_optimized_peft_model(adapter_path, base_model_id=None):
    # Similar setup as before
    base_model = CsmForConditionalGeneration.from_pretrained(
        base_model_id,
        device_map=device,
        torch_dtype=torch.float16,  # Use FP16 for faster inference
    )

    model = PeftModel.from_pretrained(base_model, adapter_path)
    merged_model = model.merge_and_unload()

    # Now convert to half precision after merging
    merged_model = merged_model.half()
    return merged_model

In [11]:
# Set these optimization parameters for generation
def optimize_generation_config(model):
    model.generation_config.max_length = 256  # Or appropriate value for your needs
    model.generation_config.use_cache = True
    model.generation_config.cache_implementation = "static"

    # If model has a depth decoder
    if hasattr(model, "depth_decoder"):
        model.depth_decoder.generation_config.cache_implementation = "static"

    return model

# Use after loading the model
model = load_peft_model_for_inference(f"{output_dir}/final")
model = optimize_generation_config(model)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

### 4.1 Generate a Sentence

In [12]:
from IPython.display import Audio, display
import soundfile as sf

os.makedirs("inference", exist_ok=True)

# another equivalent way to prepare the inputs
conversation = [
    {"role": "0", "content": [{"type": "text", "text": "The past is just a story we tell ourselves. You're making me blush <giggles>. I love you! <giggles>"}]},
]

inputs = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
).to(device)

# infer the model
audio = model.generate(**inputs, output_audio=True)

# Move tensor from GPU to CPU before converting to numpy
audio_cpu = audio[0].to(torch.float32).cpu().numpy()  # This fixes the TypeError

# Write to file and play
output_file = "inference/output_nc.wav"
sf.write(output_file, audio_cpu, 24000)
display(Audio(output_file))



In [13]:
conversation = [
    {"role": "0", "content": [{"type": "text", "text": "You're making me blush! I'm climaxing! Coming! Ohh! That's so good! More. Ohh. More. Go harder!"}]},
]

inputs = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
).to(device)

# infer the model
audio = model.generate(**inputs, output_audio=True)

# Move tensor from GPU to CPU before converting to numpy
audio_cpu = audio[0].to(torch.float32).cpu().numpy()  # This fixes the TypeError

# Write to file and play
output_file = "inference/output_2.wav"
sf.write(output_file, audio_cpu, 24000)
display(Audio(output_file))

### 4.2 Sound With Context

Audio influence by giving a context to the model.

In [14]:
conversation = []

# 1. context
# This code takes the first 4 examples from the dataset to provide context
for text, audio in zip(ds["train"][:4]["text"], ds["train"][:4]["audio"]):
    conversation.append(
        {
            "role": f"0",
            "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
        }
    )

# 2. text prompt
conversation.append({"role": "0", "content": [{"type": "text", "text": ds["train"][4]["text"]}]})

inputs = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
).to(device)

# infer the model
audio = model.generate(**inputs, output_audio=True)

# Move tensor from GPU to CPU before converting to numpy
audio_cpu = audio[0].to(torch.float32).cpu().numpy()  # This fixes the TypeError

# Write to file and play
output_file = "inference/output_wc.wav"
sf.write(output_file, audio_cpu, 24000)
display(Audio(output_file))

### 4.3 Batch Inference

Different audio characteristics:
- Conversation 1: Audio influenced by the voice context from ds[0]["audio"]
- Conversation 2: Audio generated without voice context (using default model voice)

In [15]:
conversation = [
    [
        {
            "role": f"0",
            "content": [
                {"type": "text", "text": ds["train"][0]["text"]},
                {"type": "audio", "path": ds["train"][0]["audio"]["array"]},
            ],
        },
        {
            "role": f"0",
            "content": [
                {"type": "text", "text": ds["train"][1]["text"]},
            ],
        },
    ],
    [
        {
            "role": f"0",
            "content": [
                {"type": "text", "text": ds["train"][0]["text"]},
            ],
        }
    ],
]

inputs = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
).to(device)

# infer the model
audio = model.generate(**inputs, output_audio=True)

# Process all batch outputs
for i, audio_output in enumerate(audio):
    audio_cpu = audio_output.to(torch.float32).cpu().numpy()
    output_file = f"inference/output_batch_{i}.wav"
    sf.write(output_file, audio_cpu, 24000)
    print(f"Saved: {output_file}")
    display(Audio(output_file))

Saved: inference/output_batch_0.wav


Saved: inference/output_batch_1.wav


In [16]:
# use static cache, enabling automatically torch compile with fullgraph and reduce-overhead
model.generation_config.max_length = 512 # big enough to avoid recompilation
model.generation_config.max_new_tokens = None # would take precedence over max_length
model.generation_config.cache_implementation = "static"
model.depth_decoder.generation_config.cache_implementation = "static"

# generation kwargs
gen_kwargs = {
    "do_sample": False,
    "depth_decoder_do_sample": False,
    "temperature": 1.0,
    "depth_decoder_temperature": 1.0,
}

# Define a timing decorator
class TimerContext:
    def __init__(self, name="Execution"):
        self.name = name
        self.start_event = None
        self.end_event = None

    def __enter__(self):
        # Use CUDA events for more accurate GPU timing
        self.start_event = torch.cuda.Event(enable_timing=True)
        self.end_event = torch.cuda.Event(enable_timing=True)
        self.start_event.record()
        return self

    def __exit__(self, *args):
        self.end_event.record()
        torch.cuda.synchronize()
        elapsed_time = self.start_event.elapsed_time(self.end_event) / 1000.0
        print(f"{self.name} time: {elapsed_time:.4f} seconds")

conversation = [
    {
        "role": f"0",
        "content": [
            {"type": "text", "text": ds["train"][0]["text"]},
            {"type": "audio", "path": ds["train"][0]["audio"]["array"]},
        ],
    },
    {
        "role": f"0",
        "content": [
            {"type": "text", "text": ds["train"][1]["text"]},
            {"type": "audio", "path": ds["train"][1]["audio"]["array"]},
        ],
    },
    {
        "role": f"0",
        "content": [
            {"type": "text", "text": ds["train"][2]["text"]},
        ],
    },
]

padded_inputs_1 = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
).to(device)

print("\n" + "="*50)
print("First generation - compiling and recording CUDA graphs...")
with TimerContext("First generation"):
    _ = model.generate(**padded_inputs_1, **gen_kwargs)
print("="*50)

print("\n" + "="*50)
print("Second generation - fast !!!")
with TimerContext("Second generation"):
    _ = model.generate(**padded_inputs_1, **gen_kwargs)
print("="*50)

# now with different inputs
conversation = [
    {
        "role": f"{0}",
        "content": [
            {"type": "text", "text": ds["train"][2]["text"]},
            {"type": "audio", "path": ds["train"][2]["audio"]["array"]},
        ],
    },
    {
        "role": f"{0}",
        "content": [
            {"type": "text", "text": ds["train"][3]["text"]},
            {"type": "audio", "path": ds["train"][3]["audio"]["array"]},
        ],
    },
    {
        "role": f"{0}",
        "content": [
            {"type": "text", "text": ds["train"][4]["text"]},
        ],
    },
]
padded_inputs_2 = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
).to(device)

print("\n" + "="*50)
print("Generation with other inputs!")
with TimerContext("Generation with different inputs"):
    _ = model.generate(**padded_inputs_2, **gen_kwargs)
print("="*50)


First generation - compiling and recording CUDA graphs...
First generation time: 36.5081 seconds

Second generation - fast !!!
Second generation time: 12.6307 seconds

Generation with other inputs!
Generation with different inputs time: 18.2411 seconds


## Save To Hugging Face Hub

Finally, save the model to the Hugging Face Hub for future use or sharing.

In [18]:
from huggingface_hub import HfApi

# Define your model repository name (create this on the HF website first)
model_name = "keanteng/sesame-csm-elise-lora"  # Replace with your desired repo name

# Alternatively, push an already saved model directory
api = HfApi()
api.upload_folder(
    folder_path=f"{output_dir}/final",
    repo_id=model_name,
    repo_type="model"
)

adapter_model.safetensors:   0%|          | 0.00/58.1M [00:00<?, ?B/s]

Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/keanteng/sesame-csm-elise-lora/commit/b4282241e60e4b850b50097f936c69b09cd27123', commit_message='Upload folder using huggingface_hub', commit_description='', oid='b4282241e60e4b850b50097f936c69b09cd27123', pr_url=None, repo_url=RepoUrl('https://huggingface.co/keanteng/sesame-csm-elise-lora', endpoint='https://huggingface.co', repo_type='model', repo_id='keanteng/sesame-csm-elise-lora'), pr_revision=None, pr_num=None)

In [19]:
# Add a model card with description
with open(f"{output_dir}/README.md", "w") as f:
    f.write("""
---
license: agpl-3.0
datasets:
- MrDragonFox/Elise
language:
- en
base_model:
- sesame/csm-1b
pipeline_tag: text-to-speech
library_name: transformers
tags:
- generative-ai
---
# CSM Elise Voice Model LoRA

This model is a fine-tuned version of [sesame/csm-1b](https://huggingface.co/sesame/csm-1b) using the [Elise dataset](https://huggingface.co/datasets/MrDragonFox/Elise) with LoRA. There are sample outputs files in the repository.

The sound quality seems to be better than tuning on full-parameters. However, more tweaking would be needed to ensure consistent performance.

## Model Details
- **Base Model**: sesame/csm-1b
- **Training Data**: MrDragonFox/Elise dataset
- **Fine-tuning Approach**: Voice cloning through conditional speech generation using LoRA
- **Voice Characteristics**: [Describe voice qualities]
- **Training Parameters**:
  - Learning Rate: 1e-5
  - Epochs: 4
  - Batch Size: 1 with gradient accumulation steps of 4
""")

In [20]:
api.upload_file(
    path_or_fileobj=f"{output_dir}/README.md",
    path_in_repo="README.md",
    repo_id=model_name,
    repo_type="model"
)

CommitInfo(commit_url='https://huggingface.co/keanteng/sesame-csm-elise-lora/commit/9f8ada4841d94b1e7fc832ecbc54b4b2ab38f3e6', commit_message='Upload README.md with huggingface_hub', commit_description='', oid='9f8ada4841d94b1e7fc832ecbc54b4b2ab38f3e6', pr_url=None, repo_url=RepoUrl('https://huggingface.co/keanteng/sesame-csm-elise-lora', endpoint='https://huggingface.co', repo_type='model', repo_id='keanteng/sesame-csm-elise-lora'), pr_revision=None, pr_num=None)

In [21]:
api.upload_folder(
    folder_path=f"inference",
    repo_id=f"{model_name}",
    repo_type="model",
    path_in_repo="inference"
)

output_2.wav:   0%|          | 0.00/480k [00:00<?, ?B/s]

output_batch_1.wav:   0%|          | 0.00/480k [00:00<?, ?B/s]

output_wc.wav:   0%|          | 0.00/265k [00:00<?, ?B/s]

output_batch_0.wav:   0%|          | 0.00/469k [00:00<?, ?B/s]

output_nc.wav:   0%|          | 0.00/480k [00:00<?, ?B/s]

Upload 5 LFS files:   0%|          | 0/5 [00:00<?, ?it/s]

CommitInfo(commit_url='https://huggingface.co/keanteng/sesame-csm-elise-lora/commit/039fef9c408e708c12e9163b7f68e0c74b88e7d7', commit_message='Upload folder using huggingface_hub', commit_description='', oid='039fef9c408e708c12e9163b7f68e0c74b88e7d7', pr_url=None, repo_url=RepoUrl('https://huggingface.co/keanteng/sesame-csm-elise-lora', endpoint='https://huggingface.co', repo_type='model', repo_id='keanteng/sesame-csm-elise-lora'), pr_revision=None, pr_num=None)

## 5. Other

In [22]:
# @title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

GPU = NVIDIA A100-SXM4-40GB. Max memory = 39.557 GB.
15.564 GB of memory reserved.
