# Seseme CSM Fine-Tuning

This notebook demonstrates how to fine-tune the Seseme CSM model using the Hugging Face `transformers` library.

## 1. Setup

Connect to the Hugging Face Hub and install the dataset. The dataset used is in `parquet` format, which is efficient for large datasets.

In [6]:
from huggingface_hub import login
import os

# initialize login token
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [7]:
!git clone https://huggingface.co/datasets/MrDragonFox/Elise

Cloning into 'Elise'...
remote: Enumerating objects: 17, done.[K
remote: Total 17 (delta 0), reused 0 (delta 0), pack-reused 17 (from 1)[K
Unpacking objects: 100% (17/17), 5.95 KiB | 1.98 MiB/s, done.


## 2. Preprocessing

### 2.1 Load The Model

In [8]:
from transformers import CsmForConditionalGeneration, AutoProcessor, Trainer, TrainingArguments
from datasets import load_dataset, Audio
import torch
import numpy as np
from tqdm import tqdm
import os

model_id = "sesame/csm-1b"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model and processor
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)
model.train()
model.codec_model.eval()  # Keep codec model in eval mode during training

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/271 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/50.6k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/449 [00:00<?, ?B/s]

chat_template.jinja:   0%|          | 0.00/2.00k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/3.28k [00:00<?, ?B/s]

transformers.safetensors.index.json:   0%|          | 0.00/59.7k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

transformers-00001-of-00002.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

transformers-00002-of-00002.safetensors:   0%|          | 0.00/2.19G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/264 [00:00<?, ?B/s]

MimiModel(
  (encoder): MimiEncoder(
    (layers): ModuleList(
      (0): MimiConv1d(
        (conv): Conv1d(1, 64, kernel_size=(7,), stride=(1,))
      )
      (1): MimiResnetBlock(
        (block): ModuleList(
          (0): ELU(alpha=1.0)
          (1): MimiConv1d(
            (conv): Conv1d(64, 32, kernel_size=(3,), stride=(1,))
          )
          (2): ELU(alpha=1.0)
          (3): MimiConv1d(
            (conv): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
          )
        )
        (shortcut): Identity()
      )
      (2): ELU(alpha=1.0)
      (3): MimiConv1d(
        (conv): Conv1d(64, 128, kernel_size=(8,), stride=(4,))
      )
      (4): MimiResnetBlock(
        (block): ModuleList(
          (0): ELU(alpha=1.0)
          (1): MimiConv1d(
            (conv): Conv1d(128, 64, kernel_size=(3,), stride=(1,))
          )
          (2): ELU(alpha=1.0)
          (3): MimiConv1d(
            (conv): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
          )
        )
        (s

### 2.2 Load The Data

In [10]:
import pandas as pd
from datasets import Dataset

# Path to the dataset after git clone
dataset_path = os.path.join("Elise", "data", "train-00000-of-00001.parquet")

df = pd.read_parquet(dataset_path)
ds = {"train": Dataset.from_pandas(df)}
print(f"Dataset structure: {ds}")

# Ensure the audio is 24kHz (CSM requirement)
# Check if 'audio' column exists, otherwise look for the correct audio column
audio_column = "audio"
if audio_column in ds["train"].column_names:
    ds["train"] = ds["train"].cast_column(audio_column, Audio(sampling_rate=24000))
else:
    print(f"Warning: Column '{audio_column}' not found. Available columns: {ds['train'].column_names}")
    # Try to identify the audio column if it has a different name
    # You might need to adjust this based on your dataset structure

Dataset structure: {'train': Dataset({
    features: ['audio', 'text'],
    num_rows: 1195
})}


In [11]:
def prepare_conversation_batch(batch_size=4, offset=0):
    """Prepare a batch of conversations from the dataset"""
    batch_indices = list(range(offset, min(offset + batch_size, len(ds["train"]))))
    conversation = []

    # Use a stringified integer for the speaker ID (e.g., "0")
    speaker_id = "0"  # Changed from "Elise" to "0"

    for idx in batch_indices:
        example = ds["train"][idx]

        # Extract text and audio from the dataset
        text = example.get("text", "")

        # Handle audio data correctly based on dataset structure
        if "audio" in example and isinstance(example["audio"], dict) and "array" in example["audio"]:
            audio_data = example["audio"]["array"]
        elif "audio" in example:
            # If audio is directly accessible
            audio_data = example["audio"]
        else:
            # Try to find audio under a different key
            print(f"Warning: Audio not found in example. Available keys: {example.keys()}")
            continue

        conversation.append({
            "role": speaker_id,  # Now using a stringified integer
            "content": [
                {"type": "text", "text": text},
                {"type": "audio", "path": audio_data}
            ],
        })

    return conversation

## 3. Modeling

In [None]:
# Add these imports at the top
from peft import LoraConfig, get_peft_model, TaskType
from peft import PeftModel, PeftConfig

# Replace the model loading section
model_id = "sesame/csm-1b"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model and processor
processor = AutoProcessor.from_pretrained(model_id)
base_model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)

# Configure LoRA
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,  # or TaskType.CAUSAL_LM depending on the model
    inference_mode=False,
    r=16,  # rank - lower values (8-32) for more conservative fine-tuning
    lora_alpha=32,  # scaling parameter
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],  # target attention and MLP layers
    bias="none",
)

# Apply LoRA to the model
model = get_peft_model(base_model, lora_config)
model.train()
model.base_model.codec_model.eval()  # Keep codec model in eval mode

# Print trainable parameters
model.print_trainable_parameters()

In [None]:
# Training configurations
output_dir = "./csm_elise_lora_model"
num_train_epochs = 5
per_device_train_batch_size = 2  # CSM models can be memory intensive
gradient_accumulation_steps = 2
learning_rate = 1e-4
warmup_steps = 100
logging_steps = 10
save_steps = 500
max_steps = 1000  # Adjust based on dataset size
weight_decay = 0.01  # Add weight decay
max_grad_norm = 1.0  # Add gradient clipping

# Initialize training arguments
training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    learning_rate=learning_rate,
    warmup_steps=warmup_steps,
    logging_steps=logging_steps,
    save_steps=save_steps,
    max_steps=max_steps,
    fp16=True,  # Use mixed precision training
    remove_unused_columns=False,
    report_to="tensorboard",
)

# Custom training loop (as alternative to using Trainer)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer,
                                             start_factor=1.0,
                                             end_factor=0.0,
                                             total_iters=max_steps)

# Training loop
model.train()
total_batches = min(max_steps, len(ds["train"]) // per_device_train_batch_size)

for step in tqdm(range(total_batches)):
    batch_offset = step * per_device_train_batch_size
    conversation = prepare_conversation_batch(per_device_train_batch_size, batch_offset)

    # Skip empty conversations
    if not conversation:
        continue

    # Process the conversation batch
    try:
        inputs = processor.apply_chat_template(
            conversation,
            tokenize=True,
            return_dict=True,
            output_labels=True,
        ).to(device)

        # Forward pass
        outputs = model(**inputs)
        loss = outputs.loss

        # Backward pass
        loss.backward()

        # Add gradient clipping here
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        # Optimizer step with gradient accumulation
        if (step + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        # Log progress
        if step % logging_steps == 0:
            print(f"Step {step}, Loss: {loss.item()}")

        # Save checkpoint
        if step % save_steps == 0 and step > 0:
            model.save_pretrained(f"{output_dir}/checkpoint-{step}")
            processor.save_pretrained(f"{output_dir}/checkpoint-{step}")

    except Exception as e:
        print(f"Error in batch at offset {batch_offset}: {e}")
        continue

  0%|          | 1/1195 [00:16<5:24:44, 16.32s/it]

Step 0, Loss: 5.28347635269165


  1%|          | 11/1195 [00:19<09:07,  2.16it/s]

Step 10, Loss: 6.476818561553955


  2%|▏         | 21/1195 [00:22<06:44,  2.90it/s]

Step 20, Loss: 5.606250762939453


  3%|▎         | 31/1195 [00:26<06:37,  2.93it/s]

Step 30, Loss: 4.317570686340332


  3%|▎         | 41/1195 [00:29<06:54,  2.78it/s]

Step 40, Loss: 5.271435260772705


  4%|▍         | 51/1195 [00:33<06:41,  2.85it/s]

Step 50, Loss: 5.881283760070801


  5%|▌         | 61/1195 [00:36<06:15,  3.02it/s]

Step 60, Loss: 5.251502990722656


  6%|▌         | 71/1195 [00:39<06:12,  3.02it/s]

Step 70, Loss: 4.708804130554199


  7%|▋         | 81/1195 [00:43<06:23,  2.91it/s]

Step 80, Loss: 5.438336372375488


  8%|▊         | 91/1195 [00:46<06:34,  2.80it/s]

Step 90, Loss: 4.88773250579834


  8%|▊         | 101/1195 [00:49<06:24,  2.84it/s]

Step 100, Loss: 5.323968887329102


  9%|▉         | 111/1195 [00:53<05:54,  3.06it/s]

Step 110, Loss: 5.825194358825684


 10%|█         | 121/1195 [00:56<05:54,  3.03it/s]

Step 120, Loss: 6.067062854766846


 11%|█         | 131/1195 [00:59<05:30,  3.22it/s]

Step 130, Loss: 5.846033096313477


 12%|█▏        | 141/1195 [01:02<06:07,  2.87it/s]

Step 140, Loss: 4.559338092803955


 13%|█▎        | 151/1195 [01:05<05:25,  3.21it/s]

Step 150, Loss: 6.940877914428711


 13%|█▎        | 161/1195 [01:09<05:31,  3.12it/s]

Step 160, Loss: 5.581697463989258


 14%|█▍        | 171/1195 [01:12<05:49,  2.93it/s]

Step 170, Loss: 5.092883586883545


 15%|█▌        | 181/1195 [01:15<05:55,  2.85it/s]

Step 180, Loss: 6.318438529968262


 16%|█▌        | 191/1195 [01:19<05:39,  2.96it/s]

Step 190, Loss: 5.866942405700684


 17%|█▋        | 201/1195 [01:22<05:33,  2.98it/s]

Step 200, Loss: 5.414992332458496


 18%|█▊        | 211/1195 [01:25<04:56,  3.32it/s]

Step 210, Loss: 5.862943172454834


 18%|█▊        | 221/1195 [01:28<05:06,  3.18it/s]

Step 220, Loss: 5.737164497375488


 19%|█▉        | 231/1195 [01:31<04:56,  3.25it/s]

Step 230, Loss: 4.967309474945068


 20%|██        | 241/1195 [01:34<05:27,  2.91it/s]

Step 240, Loss: 6.606116771697998


 21%|██        | 251/1195 [01:38<05:03,  3.11it/s]

Step 250, Loss: 5.059942245483398


 22%|██▏       | 261/1195 [01:41<04:58,  3.13it/s]

Step 260, Loss: 5.17879581451416


 23%|██▎       | 271/1195 [01:44<05:24,  2.85it/s]

Step 270, Loss: 5.672780990600586


 24%|██▎       | 281/1195 [01:48<05:32,  2.75it/s]

Step 280, Loss: 6.706904888153076


 24%|██▍       | 291/1195 [01:51<05:10,  2.91it/s]

Step 290, Loss: 5.15388298034668


 25%|██▌       | 301/1195 [01:54<04:49,  3.09it/s]

Step 300, Loss: 5.625014305114746


 26%|██▌       | 311/1195 [01:58<05:02,  2.92it/s]

Step 310, Loss: 6.568816184997559


 27%|██▋       | 321/1195 [02:01<04:38,  3.14it/s]

Step 320, Loss: 4.090836524963379


 28%|██▊       | 331/1195 [02:04<04:26,  3.25it/s]

Step 330, Loss: 6.187995910644531


 29%|██▊       | 341/1195 [02:07<05:11,  2.74it/s]

Step 340, Loss: 4.892672061920166


 29%|██▉       | 351/1195 [02:10<04:11,  3.35it/s]

Step 350, Loss: 4.67225456237793


 30%|███       | 361/1195 [02:14<04:54,  2.83it/s]

Step 360, Loss: 4.341012001037598


 31%|███       | 371/1195 [02:17<04:11,  3.27it/s]

Step 370, Loss: 4.848402500152588


 32%|███▏      | 381/1195 [02:20<04:31,  3.00it/s]

Step 380, Loss: 5.971383094787598


 33%|███▎      | 391/1195 [02:23<04:24,  3.04it/s]

Step 390, Loss: 3.987562417984009


 34%|███▎      | 401/1195 [02:27<04:29,  2.94it/s]

Step 400, Loss: 5.369024753570557


 34%|███▍      | 411/1195 [02:30<04:07,  3.17it/s]

Step 410, Loss: 5.155890464782715


 35%|███▌      | 421/1195 [02:33<04:32,  2.84it/s]

Step 420, Loss: 5.086121082305908


 36%|███▌      | 431/1195 [02:37<04:10,  3.05it/s]

Step 430, Loss: 4.73309326171875


 37%|███▋      | 441/1195 [02:40<04:17,  2.93it/s]

Step 440, Loss: 4.128088474273682


 38%|███▊      | 451/1195 [02:43<04:01,  3.08it/s]

Step 450, Loss: 5.206631660461426


 39%|███▊      | 461/1195 [02:46<04:07,  2.97it/s]

Step 460, Loss: 5.5451507568359375


 39%|███▉      | 471/1195 [02:50<04:11,  2.88it/s]

Step 470, Loss: 5.087254047393799


 40%|████      | 481/1195 [02:53<03:36,  3.30it/s]

Step 480, Loss: 5.402521133422852


 41%|████      | 491/1195 [02:56<03:24,  3.44it/s]

Step 490, Loss: 3.9647278785705566


 42%|████▏     | 500/1195 [02:59<03:53,  2.98it/s]

Step 500, Loss: 4.371260643005371


 43%|████▎     | 511/1195 [03:25<05:58,  1.91it/s]

Step 510, Loss: 5.3083577156066895


 44%|████▎     | 521/1195 [03:28<03:48,  2.95it/s]

Step 520, Loss: 5.97493839263916


 44%|████▍     | 531/1195 [03:32<03:25,  3.23it/s]

Step 530, Loss: 5.594260215759277


 45%|████▌     | 541/1195 [03:35<03:33,  3.06it/s]

Step 540, Loss: 5.933420181274414


 46%|████▌     | 551/1195 [03:38<03:45,  2.85it/s]

Step 550, Loss: 5.378222465515137


 47%|████▋     | 561/1195 [03:42<03:32,  2.99it/s]

Step 560, Loss: 6.183438301086426


 48%|████▊     | 571/1195 [03:45<03:43,  2.80it/s]

Step 570, Loss: 5.67246150970459


 49%|████▊     | 581/1195 [03:48<03:32,  2.88it/s]

Step 580, Loss: 4.091733455657959


 49%|████▉     | 591/1195 [03:52<03:06,  3.23it/s]

Step 590, Loss: 6.071618556976318


 50%|█████     | 601/1195 [03:55<03:37,  2.74it/s]

Step 600, Loss: 5.501613616943359


 51%|█████     | 611/1195 [03:58<03:08,  3.09it/s]

Step 610, Loss: 4.877788543701172


 52%|█████▏    | 621/1195 [04:02<03:20,  2.86it/s]

Step 620, Loss: 5.339558124542236


 53%|█████▎    | 631/1195 [04:05<02:57,  3.17it/s]

Step 630, Loss: 5.349918365478516


 54%|█████▎    | 641/1195 [04:08<03:07,  2.96it/s]

Step 640, Loss: 5.087094306945801


 54%|█████▍    | 651/1195 [04:11<02:37,  3.46it/s]

Step 650, Loss: 6.101537704467773


 55%|█████▌    | 661/1195 [04:14<03:15,  2.73it/s]

Step 660, Loss: 5.44284200668335


 56%|█████▌    | 671/1195 [04:17<02:40,  3.26it/s]

Step 670, Loss: 5.197319984436035


 57%|█████▋    | 681/1195 [04:21<02:59,  2.86it/s]

Step 680, Loss: 3.326293468475342


 58%|█████▊    | 691/1195 [04:24<02:52,  2.91it/s]

Step 690, Loss: 6.184392929077148


 59%|█████▊    | 701/1195 [04:27<02:53,  2.85it/s]

Step 700, Loss: 5.148952960968018


 59%|█████▉    | 711/1195 [04:31<02:47,  2.88it/s]

Step 710, Loss: 6.14536190032959


 60%|██████    | 721/1195 [04:34<02:42,  2.92it/s]

Step 720, Loss: 5.030327320098877


 61%|██████    | 731/1195 [04:37<02:28,  3.13it/s]

Step 730, Loss: 4.348564624786377


 62%|██████▏   | 741/1195 [04:41<02:24,  3.14it/s]

Step 740, Loss: 4.789517402648926


 63%|██████▎   | 751/1195 [04:44<02:28,  2.99it/s]

Step 750, Loss: 5.26044225692749


 64%|██████▎   | 761/1195 [04:47<02:29,  2.91it/s]

Step 760, Loss: 4.580572605133057


 65%|██████▍   | 771/1195 [04:50<02:13,  3.17it/s]

Step 770, Loss: 5.888899326324463


 65%|██████▌   | 781/1195 [04:54<02:23,  2.89it/s]

Step 780, Loss: 3.9898791313171387


 66%|██████▌   | 791/1195 [04:57<02:15,  2.98it/s]

Step 790, Loss: 5.3995866775512695


 67%|██████▋   | 801/1195 [05:00<02:07,  3.08it/s]

Step 800, Loss: 4.753012657165527


 68%|██████▊   | 811/1195 [05:03<02:04,  3.09it/s]

Step 810, Loss: 3.6708102226257324


 69%|██████▊   | 821/1195 [05:06<01:55,  3.23it/s]

Step 820, Loss: 0.9964476823806763


 70%|██████▉   | 831/1195 [05:10<01:57,  3.09it/s]

Step 830, Loss: 4.613960266113281


 70%|███████   | 841/1195 [05:13<02:05,  2.82it/s]

Step 840, Loss: 5.138670444488525


 71%|███████   | 851/1195 [05:16<01:40,  3.41it/s]

Step 850, Loss: 4.167598724365234


 72%|███████▏  | 861/1195 [05:19<01:45,  3.15it/s]

Step 860, Loss: 3.6851160526275635


 73%|███████▎  | 871/1195 [05:22<01:38,  3.30it/s]

Step 870, Loss: 5.804029941558838


 74%|███████▎  | 881/1195 [05:26<01:54,  2.75it/s]

Step 880, Loss: 6.351037979125977


 75%|███████▍  | 891/1195 [05:29<01:46,  2.87it/s]

Step 890, Loss: 5.559860706329346


 75%|███████▌  | 901/1195 [05:32<01:43,  2.85it/s]

Step 900, Loss: 4.431290149688721


 76%|███████▌  | 911/1195 [05:36<01:32,  3.08it/s]

Step 910, Loss: 5.410870552062988


 77%|███████▋  | 921/1195 [05:39<01:37,  2.81it/s]

Step 920, Loss: 5.590543746948242


 78%|███████▊  | 931/1195 [05:42<01:25,  3.11it/s]

Step 930, Loss: 4.68243408203125


 79%|███████▊  | 941/1195 [05:45<01:28,  2.87it/s]

Step 940, Loss: 4.965478420257568


 80%|███████▉  | 951/1195 [05:49<01:16,  3.18it/s]

Step 950, Loss: 5.968648910522461


 80%|████████  | 961/1195 [05:52<01:23,  2.82it/s]

Step 960, Loss: 5.542519569396973


 81%|████████▏ | 971/1195 [05:55<01:15,  2.97it/s]

Step 970, Loss: 5.217014312744141


 82%|████████▏ | 981/1195 [05:59<01:11,  2.98it/s]

Step 980, Loss: 5.022400856018066


 83%|████████▎ | 991/1195 [06:02<01:02,  3.27it/s]

Step 990, Loss: 5.57076358795166


 84%|████████▎ | 1000/1195 [06:05<00:59,  3.28it/s]

Step 1000, Loss: 4.524172782897949


 85%|████████▍ | 1011/1195 [06:36<01:41,  1.82it/s]

Step 1010, Loss: 4.051425933837891


 85%|████████▌ | 1021/1195 [06:39<01:00,  2.87it/s]

Step 1020, Loss: 4.962061882019043


 86%|████████▋ | 1031/1195 [06:43<00:53,  3.08it/s]

Step 1030, Loss: 5.035196304321289


 87%|████████▋ | 1041/1195 [06:46<00:52,  2.93it/s]

Step 1040, Loss: 5.16644811630249


 88%|████████▊ | 1051/1195 [06:49<00:44,  3.25it/s]

Step 1050, Loss: 4.918061256408691


 89%|████████▉ | 1061/1195 [06:53<00:49,  2.70it/s]

Step 1060, Loss: 4.690139293670654


 90%|████████▉ | 1071/1195 [06:56<00:42,  2.94it/s]

Step 1070, Loss: 5.528160095214844


 90%|█████████ | 1081/1195 [06:59<00:35,  3.25it/s]

Step 1080, Loss: 4.858752727508545


 91%|█████████▏| 1091/1195 [07:02<00:34,  3.00it/s]

Step 1090, Loss: 5.367155075073242


 92%|█████████▏| 1101/1195 [07:06<00:32,  2.86it/s]

Step 1100, Loss: 5.820680141448975


 93%|█████████▎| 1111/1195 [07:09<00:27,  3.07it/s]

Step 1110, Loss: 5.355015754699707


 94%|█████████▍| 1121/1195 [07:12<00:25,  2.95it/s]

Step 1120, Loss: 5.401312351226807


 95%|█████████▍| 1131/1195 [07:15<00:20,  3.20it/s]

Step 1130, Loss: 5.992394924163818


 95%|█████████▌| 1141/1195 [07:18<00:18,  2.96it/s]

Step 1140, Loss: 3.966352701187134


 96%|█████████▋| 1151/1195 [07:21<00:12,  3.46it/s]

Step 1150, Loss: 6.1534271240234375


 97%|█████████▋| 1161/1195 [07:24<00:11,  3.05it/s]

Step 1160, Loss: 6.0506672859191895


 98%|█████████▊| 1171/1195 [07:27<00:07,  3.29it/s]

Step 1170, Loss: 2.818480968475342


 99%|█████████▉| 1181/1195 [07:31<00:04,  3.05it/s]

Step 1180, Loss: 4.829275131225586


100%|█████████▉| 1191/1195 [07:34<00:01,  3.05it/s]

Step 1190, Loss: 5.5672287940979


100%|██████████| 1195/1195 [07:35<00:00,  2.62it/s]


Now save the trained model in local directory.

In [None]:
# Update the model saving section
# Save the LoRA adapter (much smaller than full model)
model.save_pretrained(f"{output_dir}/final")
processor.save_pretrained(f"{output_dir}/final")

# Also save the base model reference for easier loading later
with open(f"{output_dir}/final/base_model_name.txt", "w") as f:
    f.write(model_id)

[]

## 4. Inference

Try the model with some sample inputs to see how it performs.

In [None]:
# Modified load_peft_model_for_inference function
def load_peft_model_for_inference(adapter_path, base_model_id=None):
    if base_model_id is None:
        try:
            with open(f"{adapter_path}/base_model_name.txt", "r") as f:
                base_model_id = f.read().strip()
        except:
            base_model_id = "sesame/csm-1b"  # fallback
    
    # Load base model
    base_model = CsmForConditionalGeneration.from_pretrained(base_model_id, device_map=device)
    
    # Load and merge PEFT model
    model = PeftModel.from_pretrained(base_model, adapter_path)
    # This is the key step - merge weights to remove adapter overhead
    merged_model = model.merge_and_unload()
    return merged_model

In [None]:
from transformers import BitsAndBytesConfig

def load_optimized_peft_model(adapter_path, base_model_id=None):
    # Similar setup as before
    base_model = CsmForConditionalGeneration.from_pretrained(
        base_model_id, 
        device_map=device,
        torch_dtype=torch.float16,  # Use FP16 for faster inference
    )
    
    model = PeftModel.from_pretrained(base_model, adapter_path)
    merged_model = model.merge_and_unload()
    
    # Now convert to half precision after merging
    merged_model = merged_model.half()
    return merged_model

In [None]:
# Set these optimization parameters for generation
def optimize_generation_config(model):
    model.generation_config.max_length = 256  # Or appropriate value for your needs
    model.generation_config.use_cache = True
    model.generation_config.cache_implementation = "static"
    
    # If model has a depth decoder
    if hasattr(model, "depth_decoder"):
        model.depth_decoder.generation_config.cache_implementation = "static"
    
    return model

# Use after loading the model
model = load_peft_model_for_inference(f"{output_dir}/final")
model = optimize_generation_config(model)

### 4.1 Generate a Sentence

In [15]:
from IPython.display import Audio, display
import soundfile as sf

os.makedirs("inference", exist_ok=True)

# another equivalent way to prepare the inputs
conversation = [
    {"role": "0", "content": [{"type": "text", "text": "The past is just a story we tell ourselves. You're making me blush <giggles>. I love you! <giggles>"}]},
]

inputs = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
).to(device)

# infer the model
audio = model.generate(**inputs, output_audio=True)

# Move tensor from GPU to CPU before converting to numpy
audio_cpu = audio[0].to(torch.float32).cpu().numpy()  # This fixes the TypeError

# Write to file and play
output_file = "inference/output_nc.wav"
sf.write(output_file, audio_cpu, 24000)
display(Audio(output_file))

In [19]:
conversation = [
    {"role": "0", "content": [{"type": "text", "text": "You're making me blush! I'm climaxing! Coming! Ohh! That's so good! More. Ohh. More. Go harder!"}]},
]

inputs = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
).to(device)

# infer the model
audio = model.generate(**inputs, output_audio=True)

# Move tensor from GPU to CPU before converting to numpy
audio_cpu = audio[0].to(torch.float32).cpu().numpy()  # This fixes the TypeError

# Write to file and play
output_file = "inference/output_2.wav"
sf.write(output_file, audio_cpu, 24000)
display(Audio(output_file))

### 4.2 Sound With Context

Audio influence by giving a context to the model.

In [23]:
conversation = []

# 1. context
# This code takes the first 4 examples from the dataset to provide context
for text, audio in zip(ds["train"][:4]["text"], ds["train"][:4]["audio"]):
    conversation.append(
        {
            "role": f"0",
            "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
        }
    )

# 2. text prompt
conversation.append({"role": "0", "content": [{"type": "text", "text": ds["train"][4]["text"]}]})

inputs = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
).to(device)

# infer the model
audio = model.generate(**inputs, output_audio=True)

# Move tensor from GPU to CPU before converting to numpy
audio_cpu = audio[0].to(torch.float32).cpu().numpy()  # This fixes the TypeError

# Write to file and play
output_file = "inference/output_wc.wav"
sf.write(output_file, audio_cpu, 24000)
display(Audio(output_file))

### 4.3 Batch Inference

Different audio characteristics:
- Conversation 1: Audio influenced by the voice context from ds[0]["audio"]
- Conversation 2: Audio generated without voice context (using default model voice)

In [25]:
conversation = [
    [
        {
            "role": f"0",
            "content": [
                {"type": "text", "text": ds["train"][0]["text"]},
                {"type": "audio", "path": ds["train"][0]["audio"]["array"]},
            ],
        },
        {
            "role": f"0",
            "content": [
                {"type": "text", "text": ds["train"][1]["text"]},
            ],
        },
    ],
    [
        {
            "role": f"0",
            "content": [
                {"type": "text", "text": ds["train"][0]["text"]},
            ],
        }
    ],
]

inputs = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
).to(device)

# infer the model
audio = model.generate(**inputs, output_audio=True)

# Process all batch outputs
for i, audio_output in enumerate(audio):
    audio_cpu = audio_output.to(torch.float32).cpu().numpy()
    output_file = f"inference/output_batch_{i}.wav"
    sf.write(output_file, audio_cpu, 24000)
    print(f"Saved: {output_file}")
    display(Audio(output_file))

Saved: inference/output_batch_0.wav


Saved: inference/output_batch_1.wav


In [28]:
# use static cache, enabling automatically torch compile with fullgraph and reduce-overhead
model.generation_config.max_length = 512 # big enough to avoid recompilation
model.generation_config.max_new_tokens = None # would take precedence over max_length
model.generation_config.cache_implementation = "static"
model.depth_decoder.generation_config.cache_implementation = "static"

# generation kwargs
gen_kwargs = {
    "do_sample": False,
    "depth_decoder_do_sample": False,
    "temperature": 1.0,
    "depth_decoder_temperature": 1.0,
}

# Define a timing decorator
class TimerContext:
    def __init__(self, name="Execution"):
        self.name = name
        self.start_event = None
        self.end_event = None

    def __enter__(self):
        # Use CUDA events for more accurate GPU timing
        self.start_event = torch.cuda.Event(enable_timing=True)
        self.end_event = torch.cuda.Event(enable_timing=True)
        self.start_event.record()
        return self

    def __exit__(self, *args):
        self.end_event.record()
        torch.cuda.synchronize()
        elapsed_time = self.start_event.elapsed_time(self.end_event) / 1000.0
        print(f"{self.name} time: {elapsed_time:.4f} seconds")

conversation = [
    {
        "role": f"0",
        "content": [
            {"type": "text", "text": ds["train"][0]["text"]},
            {"type": "audio", "path": ds["train"][0]["audio"]["array"]},
        ],
    },
    {
        "role": f"0",
        "content": [
            {"type": "text", "text": ds["train"][1]["text"]},
            {"type": "audio", "path": ds["train"][1]["audio"]["array"]},
        ],
    },
    {
        "role": f"0",
        "content": [
            {"type": "text", "text": ds["train"][2]["text"]},
        ],
    },
]

padded_inputs_1 = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
).to(device)

print("\n" + "="*50)
print("First generation - compiling and recording CUDA graphs...")
with TimerContext("First generation"):
    _ = model.generate(**padded_inputs_1, **gen_kwargs)
print("="*50)

print("\n" + "="*50)
print("Second generation - fast !!!")
with TimerContext("Second generation"):
    _ = model.generate(**padded_inputs_1, **gen_kwargs)
print("="*50)

# now with different inputs
conversation = [
    {
        "role": f"{0}",
        "content": [
            {"type": "text", "text": ds["train"][2]["text"]},
            {"type": "audio", "path": ds["train"][2]["audio"]["array"]},
        ],
    },
    {
        "role": f"{0}",
        "content": [
            {"type": "text", "text": ds["train"][3]["text"]},
            {"type": "audio", "path": ds["train"][3]["audio"]["array"]},
        ],
    },
    {
        "role": f"{0}",
        "content": [
            {"type": "text", "text": ds["train"][4]["text"]},
        ],
    },
]
padded_inputs_2 = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
).to(device)

print("\n" + "="*50)
print("Generation with other inputs!")
with TimerContext("Generation with different inputs"):
    _ = model.generate(**padded_inputs_2, **gen_kwargs)
print("="*50)


First generation - compiling and recording CUDA graphs...
First generation time: 12.0582 seconds

Second generation - fast !!!
Second generation time: 11.9195 seconds

Generation with other inputs!
Generation with different inputs time: 4.5732 seconds


## Save To Hugging Face Hub

Finally, save the model to the Hugging Face Hub for future use or sharing.

In [29]:
from huggingface_hub import HfApi

# Define your model repository name (create this on the HF website first)
model_name = "keanteng/sesame-csm-elise"  # Replace with your desired repo name

# Alternatively, push an already saved model directory
api = HfApi()
api.upload_folder(
    folder_path=f"{output_dir}/final",
    repo_id=model_name,
    repo_type="model"
)

Upload 3 LFS files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/1.64G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/keanteng/sesame-csm-elise/commit/4b2e536dffe18bd570116040723c40134a1fc66d', commit_message='Upload folder using huggingface_hub', commit_description='', oid='4b2e536dffe18bd570116040723c40134a1fc66d', pr_url=None, repo_url=RepoUrl('https://huggingface.co/keanteng/sesame-csm-elise', endpoint='https://huggingface.co', repo_type='model', repo_id='keanteng/sesame-csm-elise'), pr_revision=None, pr_num=None)

In [30]:
# Add a model card with description
with open(f"{output_dir}/README.md", "w") as f:
    f.write("""
---
license: agpl-3.0
datasets:
- MrDragonFox/Elise
language:
- en
base_model:
- sesame/csm-1b
pipeline_tag: text-to-speech
library_name: transformers
tags:
- generative-ai
---
# CSM Elise Voice Model

This model is a fine-tuned version of [sesame/csm-1b](https://huggingface.co/sesame/csm-1b) using the [Elise dataset](https://huggingface.co/datasets/MrDragonFox/Elise). There are sample outputs files in the repository.

## Model Details
- **Base Model**: sesame/csm-1b
- **Training Data**: MrDragonFox/Elise dataset
- **Fine-tuning Approach**: Voice cloning through conditional speech generation
- **Voice Characteristics**: [Describe voice qualities]
- **Training Parameters**:
  - Learning Rate: 5e-5
  - Epochs: 3
  - Batch Size: 1 with gradient accumulation steps of 4
""")

In [31]:
api.upload_file(
    path_or_fileobj=f"{output_dir}/README.md",
    path_in_repo="README.md",
    repo_id=model_name,
    repo_type="model"
)

CommitInfo(commit_url='https://huggingface.co/keanteng/sesame-csm-elise/commit/62bf83ae63c89cab0b16b27a7ef3ff944d7454f4', commit_message='Upload README.md with huggingface_hub', commit_description='', oid='62bf83ae63c89cab0b16b27a7ef3ff944d7454f4', pr_url=None, repo_url=RepoUrl('https://huggingface.co/keanteng/sesame-csm-elise', endpoint='https://huggingface.co', repo_type='model', repo_id='keanteng/sesame-csm-elise'), pr_revision=None, pr_num=None)

In [36]:
api.upload_folder(
    folder_path=f"inference",
    repo_id=f"{model_name}",
    repo_type="model",
    path_in_repo="inference"
)

output_nc.wav:   0%|          | 0.00/388k [00:00<?, ?B/s]

output_batch_0.wav:   0%|          | 0.00/480k [00:00<?, ?B/s]

output_batch_1.wav:   0%|          | 0.00/480k [00:00<?, ?B/s]

output_wc.wav:   0%|          | 0.00/219k [00:00<?, ?B/s]

output_2.wav:   0%|          | 0.00/480k [00:00<?, ?B/s]

Upload 5 LFS files:   0%|          | 0/5 [00:00<?, ?it/s]

CommitInfo(commit_url='https://huggingface.co/keanteng/sesame-csm-elise/commit/52909e34ac4c8070186015d104dec45935628ed5', commit_message='Upload folder using huggingface_hub', commit_description='', oid='52909e34ac4c8070186015d104dec45935628ed5', pr_url=None, repo_url=RepoUrl('https://huggingface.co/keanteng/sesame-csm-elise', endpoint='https://huggingface.co', repo_type='model', repo_id='keanteng/sesame-csm-elise'), pr_revision=None, pr_num=None)

## 5. Other

In [37]:
# @title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

GPU = NVIDIA A100-SXM4-40GB. Max memory = 39.557 GB.
31.252 GB of memory reserved.
