<a href="https://colab.research.google.com/github/Matthew-diehl/phoneguytts/blob/main/nb/Sesame_CSM_(1B)-TTS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

To run this, press "*Runtime*" and press "*Run all*" on a **free** Tesla T4 Google Colab instance!
<div class="align-center">
<a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
<a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
<a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐
</div>

To install Unsloth on your local device, follow [our guide](https://docs.unsloth.ai/get-started/install-and-update). This notebook is licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).

You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), and [how to save it](#Save).


### News


Introducing FP8 precision training for faster RL inference. [Read Blog](https://docs.unsloth.ai/new/fp8-reinforcement-learning).

Unsloth's [Docker image](https://hub.docker.com/r/unsloth/unsloth) is here! Start training with no setup & environment issues. [Read our Guide](https://docs.unsloth.ai/new/how-to-train-llms-with-unsloth-and-docker).

[gpt-oss RL](https://docs.unsloth.ai/new/gpt-oss-reinforcement-learning) is now supported with the fastest inference & lowest VRAM. Try our [new notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/gpt-oss-(20B)-GRPO.ipynb) which creates kernels!

Introducing [Vision](https://docs.unsloth.ai/new/vision-reinforcement-learning-vlm-rl) and [Standby](https://docs.unsloth.ai/basics/memory-efficient-rl) for RL! Train Qwen, Gemma etc. VLMs with GSPO - even faster with less VRAM.

Visit our docs for all our [model uploads](https://docs.unsloth.ai/get-started/all-our-models) and [notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks).


### Installation

In [1]:
%%capture
import os, re
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    import torch; v = re.match(r"[0-9]{1,}\.[0-9]{1,}", str(torch.__version__)).group(0)
    xformers = "xformers==" + ("0.0.33.post1" if v=="2.9" else "0.0.32.post2" if v=="2.8" else "0.0.29.post3")
    !pip install --no-deps bitsandbytes accelerate {xformers} peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets==4.3.0" "huggingface_hub>=0.34.0" hf_transfer
    !pip install --no-deps unsloth
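# Pin the versions this CSM notebook uses (the datasets pin replaces any datasets>=4.0 installed above)
!pip install transformers==4.52.3
!pip install --no-deps trl==0.22.2
!pip install torchcodec "datasets>=3.4.1,<4.0.0"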
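The Colab branch of the cell above chooses an `xformers` pin from the installed PyTorch `major.minor` version. The following is a standalone sketch of that selection, so it can be checked without running `pip`; the function name `pick_xformers` is hypothetical and the version pins are copied from the cell.

```python
import re

def pick_xformers(torch_version: str) -> str:
    """Mirror the install cell: map a torch version string to an xformers pin."""
    # Extract "major.minor" from strings such as "2.9.0+cu126"
    match = re.match(r"[0-9]{1,}\.[0-9]{1,}", torch_version)
    if match is None:
        raise ValueError(f"Unrecognised torch version: {torch_version!r}")
    v = match.group(0)
    if v == "2.9":
        return "xformers==0.0.33.post1"
    if v == "2.8":
        return "xformers==0.0.32.post2"
    # Every other version falls back to the last pin in the cell
    return "xformers==0.0.29.post3"

print(pick_xformers("2.9.0+cu126"))  # xformers==0.0.33.post1
```

Any version other than 2.8 or 2.9 falls through to the last pin, so a newer PyTorch release (for example 2.10, which the regex reads as `"2.10"`) would also get `0.0.29.post3`, which may not be compatible with it.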

### Unsloth

`FastModel` supports loading nearly any model now! This includes Vision and Text models!

In [2]:
from unsloth import FastModel
from transformers import CsmForConditionalGeneration
import torch

model, processor = FastModel.from_pretrained(
    model_name = "unsloth/csm-1b",
    max_seq_length = 2048, # Choose any for long context!
    dtype = None, # Leave as None for auto-detection
    auto_model = CsmForConditionalGeneration,
    load_in_4bit = False, # Select True for 4bit - reduces memory usage
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2026.1.2: Fast Csm patching. Transformers: 4.52.3.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.5.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.33.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.


unsloth/csm-1b does not have a padding token! Will use pad_token = <|PAD_TOKEN|>.


We now add LoRA adapters so we only need to update 1 to 10% of all parameters!

In [3]:
model = FastModel.get_peft_model(
    model,
    r = 32, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 32,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

Unsloth: Making `model.base_model.model.backbone_model` require gradients
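LoRA keeps each targeted weight frozen and learns a low-rank update `B @ A`, so a `d_out x d_in` projection gains `r * (d_in + d_out)` trainable parameters, and the update is scaled by `lora_alpha / r` (or `lora_alpha / sqrt(r)` with `use_rslora = True`). A small arithmetic sketch of that bookkeeping; the 2048 x 2048 layer size is an illustrative assumption, not read from the CSM config, while the last line uses the counts printed in the training log further down.

```python
def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    # A is r x d_in and B is d_out x r; only these two matrices are trained
    return r * d_in + d_out * r

def lora_scaling(lora_alpha: float, r: int, use_rslora: bool = False) -> float:
    # Standard LoRA scales the adapter output by alpha / r;
    # rank-stabilised LoRA uses alpha / sqrt(r) instead
    return lora_alpha / (r ** 0.5) if use_rslora else lora_alpha / r

# Illustrative 2048 x 2048 projection (assumed size, not taken from the model)
full = 2048 * 2048
added = lora_param_count(2048, 2048, r=32)
print(added, added / full)  # 131072 0.03125

print(lora_scaling(32, 32))  # 1.0 with the cell's r = lora_alpha = 32

# Trainable fraction reported by the training log
print(f"{29_032_448 / 1_661_132_609:.2%}")  # 1.75%
```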


<a name="Data"></a>
### Data Prep  

We will use the `tubcent/PhoneGuy` dataset, loaded in the cell below. Ensure that your dataset follows the required format: **text, audio** for single-speaker models or **source, text, audio** for multi-speaker models. You can modify this section to accommodate your own dataset, but maintaining the correct structure is essential for optimal training.
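The cell below decides which column identifies the speaker: it uses `source` if present, falls back to `speaker_id`, and otherwise adds a constant `source` of `"0"` to every row. A standalone sketch of that decision on a plain list of column names, handy for checking your own dataset first (the helper name `resolve_speaker_column` is hypothetical):

```python
def resolve_speaker_column(column_names):
    """Return (speaker_key, needs_default_column) following the prep cell's rules."""
    if "source" in column_names:
        return "source", False
    if "speaker_id" in column_names:
        return "speaker_id", False
    # Neither column exists: the notebook adds "source" = "0" for every row
    return "source", True

print(resolve_speaker_column(["audio", "text"]))                # ('source', True)
print(resolve_speaker_column(["audio", "text", "speaker_id"]))  # ('speaker_id', False)
```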

In [4]:
#@title Dataset Prep functions
from datasets import load_dataset, Audio, Dataset
import os
from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained("unsloth/csm-1b")

raw_ds = load_dataset("tubcent/PhoneGuy", split="train")

# Getting the speaker id is important for multi-speaker models and speaker consistency
speaker_key = "source"
if "source" not in raw_ds.column_names and "speaker_id" not in raw_ds.column_names:
    print("Unsloth: No speaker found, adding default \"source\" of 0 for all examples")
    new_column = ["0"] * len(raw_ds)
    raw_ds = raw_ds.add_column("source", new_column)
elif "source" not in raw_ds.column_names and "speaker_id" in raw_ds.column_names:
    speaker_key = "speaker_id"

target_sampling_rate = 24000
raw_ds = raw_ds.cast_column("audio", Audio(sampling_rate=target_sampling_rate))

def preprocess_example(example):
    conversation = [
        {
            "role": str(example[speaker_key]),
            "content": [
                {"type": "text", "text": example["text"]},
                {"type": "audio", "path": example["audio"]["array"]},
            ],
        }
    ]

    try:
        model_inputs = processor.apply_chat_template(
            conversation,
            tokenize=True,
            return_dict=True,
            output_labels=True,
            text_kwargs = {
                "padding": "max_length", # pad to the max_length
                "max_length": 256, # must be long enough for the text plus the audio-frame tokens
                "pad_to_multiple_of": 8,
                "padding_side": "right",
            },
            audio_kwargs = {
                "sampling_rate": 24_000,
                "max_length": 240001, # max input_values length of the whole dataset
                "padding": "max_length",
            },
            common_kwargs = {"return_tensors": "pt"},
        )
    except Exception as e:
        print(f"Error processing example with text '{example['text'][:50]}...': {e}")
        return None

    required_keys = ["input_ids", "attention_mask", "labels", "input_values", "input_values_cutoffs"]
    processed_example = {}
    # print(model_inputs.keys())
    for key in required_keys:
        if key not in model_inputs:
            print(f"Warning: Required key '{key}' not found in processor output for example.")
            return None

        value = model_inputs[key][0]
        processed_example[key] = value


    # Final check (optional): every value should be a tensor
    if not all(isinstance(processed_example[key], torch.Tensor) for key in processed_example):
        print(f"Error: Not all required keys are tensors in final processed example. Keys: {list(processed_example.keys())}")
        return None

    return processed_example

processed_ds = raw_ds.map(
    preprocess_example,
    remove_columns=raw_ds.column_names,
    desc="Preprocessing dataset",
)

Unsloth: No speaker found, adding default "source" of 0 for all examples
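In the prep cell, every clip is resampled to 24 kHz and `audio_kwargs["max_length"] = 240001` caps the padded waveform in samples, which is about ten seconds. A small sketch for converting sample counts to seconds and flagging clips longer than that cap before training, since clips above it may not pad to a uniform length (helper names are hypothetical; the two constants are copied from the cell):

```python
SAMPLING_RATE = 24_000   # target_sampling_rate in the prep cell
MAX_SAMPLES = 240_001    # audio_kwargs["max_length"] in the prep cell

def samples_to_seconds(n_samples: int, sampling_rate: int = SAMPLING_RATE) -> float:
    return n_samples / sampling_rate

def clips_over_cap(lengths, max_samples: int = MAX_SAMPLES):
    """Return indices of clips whose sample count exceeds the padding cap."""
    return [i for i, n in enumerate(lengths) if n > max_samples]

print(round(samples_to_seconds(MAX_SAMPLES), 3))     # 10.0
print(clips_over_cap([120_000, 240_001, 250_000]))   # [2]
```

On the real data, the lengths could be collected with `[len(ex["audio"]["array"]) for ex in raw_ds]`, which decodes every clip once.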



<a name="Train"></a>
### Train the model
Now let's use the Hugging Face `Trainer`! More docs here: [Transformers docs](https://huggingface.co/docs/transformers/main_classes/trainer). We run 60 steps to speed things up; for a full run, set `num_train_epochs=1` and remove `max_steps` (or set it to `-1`, its default).
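With 245 examples, `per_device_train_batch_size = 2` and `gradient_accumulation_steps = 4`, each optimizer step sees 8 examples. A back-of-the-envelope sketch, assuming the Trainer counts one epoch as `floor(batches / gradient_accumulation_steps)` optimizer steps, reproduces the "Num Epochs = 2" shown in the training log:

```python
import math

num_examples = 245
per_device_batch = 2
grad_accum = 4
num_gpus = 1
max_steps = 60

total_batch = per_device_batch * grad_accum * num_gpus          # examples per optimizer step
batches_per_epoch = math.ceil(num_examples / per_device_batch)  # 123 mini-batches
steps_per_epoch = max(batches_per_epoch // grad_accum, 1)       # 30 optimizer steps
epochs = max_steps // steps_per_epoch + int(max_steps % steps_per_epoch > 0)

print(total_batch, steps_per_epoch, epochs)  # 8 30 2
```

So `max_steps = 60` amounts to roughly two passes over this small dataset, while `num_train_epochs = 1` without `max_steps` would be about 30 steps.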

In [5]:
from transformers import TrainingArguments, Trainer
from unsloth import is_bfloat16_supported

trainer = Trainer(
    model = model,
    train_dataset = processed_ds,
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        max_steps = 60,
        learning_rate = 2e-4,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.001, # Increase this if overfitting
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none", # Use TrackIO/WandB etc
    ),
)
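`warmup_steps = 5` with `lr_scheduler_type = "linear"` ramps the learning rate up from 0 to `learning_rate`, then decays it linearly to 0 at the last step. A pure-Python sketch of that shape, written to follow the linear-with-warmup rule as we understand it from `transformers` (`get_linear_schedule_with_warmup`); treat it as an illustration, not the library code:

```python
def linear_warmup_lr(step: int, base_lr: float = 2e-4,
                     warmup_steps: int = 5, total_steps: int = 60) -> float:
    """Learning rate at an optimizer step: linear warmup, then linear decay to 0."""
    if step < warmup_steps:
        # Warmup: scale up linearly from 0 towards base_lr
        return base_lr * step / max(1, warmup_steps)
    # Decay: scale down linearly, reaching 0 at total_steps
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)

for s in (0, 1, 5, 30, 60):
    print(s, linear_warmup_lr(s))
```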

In [None]:
# @title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

GPU = Tesla T4. Max memory = 14.741 GB.
6.719 GB of memory reserved.


In [6]:
trainer_stats = trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 245 | Num Epochs = 2 | Total steps = 60
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
 "-____-"     Trainable parameters = 29,032,448 of 1,661,132,609 (1.75% trained)


Unsloth: Will smartly offload gradients to save VRAM!


| Step | Training Loss |
|---:|---:|
| 1 | 6.9583 |
| 2 | 6.9011 |
| 3 | 6.7593 |
| 4 | 7.0772 |
| 5 | 7.1258 |
| 6 | 6.5823 |
| 7 | 7.1208 |
| 8 | 6.6467 |
| 9 | 6.4742 |
| 10 | 6.3182 |


In [None]:
# @title Show final memory and time stats
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(
    f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training."
)
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

313.8963 seconds used for training.
5.23 minutes used for training.
Peak reserved memory = 6.719 GB.
Peak reserved memory for training = 0.0 GB.
Peak reserved memory % of max memory = 45.58 %.
Peak reserved memory for training % of max memory = 0.0 %.


<a name="Inference"></a>
### Inference
Let's run the model! You can change the `text` prompt in the cells below to anything you like.

In [7]:
from IPython.display import Audio, display
import soundfile as sf

text = "We just finished fine tuning a text to speech model... and it's pretty good!"
speaker_id = 0
inputs = processor(f"[{speaker_id}]{text}", add_special_tokens=True).to("cuda")
audio_values = model.generate(
    **inputs,
    max_new_tokens=125, # 125 tokens is 10 seconds of audio, for longer speech increase this
    # play with these parameters to tweak results
    # depth_decoder_top_k=0,
    # depth_decoder_top_p=0.9,
    # depth_decoder_do_sample=True,
    # depth_decoder_temperature=0.9,
    # top_k=0,
    # top_p=1.0,
    # temperature=0.9,
    # do_sample=True,
    #########################################################
    output_audio=True
)
audio = audio_values[0].to(torch.float32).cpu().numpy()
sf.write("example_without_context.wav", audio, 24000)
display(Audio(audio, rate=24000))
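The comment on `max_new_tokens` says 125 tokens give 10 seconds of audio, i.e. 12.5 audio frames per second. A small sketch for sizing `max_new_tokens` from a target duration and for checking the length of the generated waveform at the 24 kHz rate passed to `sf.write` (helper names are hypothetical; the rates come from the cell):

```python
import math

FRAMES_PER_SECOND = 125 / 10  # from the notebook comment: 125 tokens ~ 10 s
SAMPLING_RATE = 24_000        # rate passed to sf.write and Audio(...)

def max_new_tokens_for(seconds: float) -> int:
    # Round up so the token budget never cuts speech short
    return math.ceil(seconds * FRAMES_PER_SECOND)

def waveform_seconds(num_samples: int) -> float:
    return num_samples / SAMPLING_RATE

print(max_new_tokens_for(10))   # 125
print(max_new_tokens_for(4.3))  # 54
```

After generation, `waveform_seconds(len(audio))` gives the clip length; if it sits at the token budget, the speech may have been cut off and `max_new_tokens` can be raised.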

In [None]:
text = "Sesame is a super cool TTS model which can be fine tuned with Unsloth."

speaker_id = 0
# Another equivalent way to prepare the inputs
conversation = [
    {"role": str(speaker_id), "content": [{"type": "text", "text": text}]},
]
audio_values = model.generate(
    **processor.apply_chat_template(
        conversation,
        tokenize=True,
        return_dict=True,
    ).to("cuda"),
    max_new_tokens=125, # 125 tokens is 10 seconds of audio, for longer speech increase this
    # play with these parameters to tweak results
    # depth_decoder_top_k=0,
    # depth_decoder_top_p=0.9,
    # depth_decoder_do_sample=True,
    # depth_decoder_temperature=0.9,
    # top_k=0,
    # top_p=1.0,
    # temperature=0.9,
    # do_sample=True,
    #########################################################
    output_audio=True
)
audio = audio_values[0].to(torch.float32).cpu().numpy()
sf.write("example_without_context_2.wav", audio, 24000)  # new name so the first example is not overwritten
display(Audio(audio, rate=24000))
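The commented-out `temperature`, `top_k` and `top_p` arguments control how the next token is sampled; the `depth_decoder_*` variants, as their names suggest, apply the same knobs to CSM's depth decoder. As a library-independent illustration of what these knobs do to a probability distribution (not the `transformers` implementation), here is temperature scaling followed by top-k and nucleus (top-p) filtering:

```python
import math

def filter_probs(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Return renormalised sampling probabilities after temperature, top-k and top-p."""
    # Temperature: divide logits before the softmax (lower = sharper distribution)
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Rank token indices from most to least likely
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    # top_k = 0 is treated as "disabled" (assumed to match the commented top_k=0)
    if top_k > 0:
        order = order[:top_k]

    # Nucleus: keep the smallest prefix whose cumulative probability reaches top_p
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    mass = sum(probs[i] for i in kept)
    return [probs[i] / mass if i in kept else 0.0 for i in range(len(probs))]

print(filter_probs([2.0, 1.0, 0.1], top_p=0.5))  # [1.0, 0.0, 0.0]
```

Lower `temperature` or `top_p` makes output more deterministic; values near 1.0 keep more variety in the voice.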

#### Voice and style consistency

Sesame CSM's power comes from providing audio context for each speaker. Let's pass a sample utterance from our dataset to ground speaker identity and style.

In [8]:
speaker_id = 0

utterance = raw_ds[3]["audio"]["array"]
utterance_text = raw_ds[3]["text"]
text = "Sesame is a super cool TTS model which can be fine tuned with Unsloth."

# CSM will fill in the audio for the last text.
# You can even provide a conversation history back in as you generate new audio

conversation = [
    {"role": str(speaker_id), "content": [{"type": "text", "text": utterance_text},{"type": "audio", "path": utterance}]},
    {"role": str(speaker_id), "content": [{"type": "text", "text": text}]},
]

inputs = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
)
audio_values = model.generate(
    **inputs.to("cuda"),
    max_new_tokens=125, # 125 tokens is 10 seconds of audio, for longer text increase this
    # play with these parameters to tweak results
    # depth_decoder_top_k=0,
    # depth_decoder_top_p=0.9,
    # depth_decoder_do_sample=True,
    # depth_decoder_temperature=0.9,
    # top_k=0,
    # top_p=1.0,
    # temperature=0.9,
    # do_sample=True,
    #########################################################
    output_audio=True
)
audio = audio_values[0].to(torch.float32).cpu().numpy()
sf.write("example_with_context.wav", audio, 24000)
display(Audio(audio, rate=24000))
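The comment in the cell above notes that a conversation history can be fed back in as you generate. A minimal sketch of a helper that assembles such a history in the same message format as the cell: every context turn carries text and audio, and the final turn carries only the text to voice (the helper name and the `(text, audio)` pair layout are assumptions for illustration).

```python
def build_conversation(context, text, speaker_id=0):
    """Build a CSM-style conversation: context turns with audio, then the text to voice.

    context: list of (utterance_text, audio_array) pairs from the same speaker.
    """
    role = str(speaker_id)  # roles are speaker ids as strings, as in the cells above
    conversation = [
        {"role": role, "content": [
            {"type": "text", "text": utt_text},
            {"type": "audio", "path": audio},
        ]}
        for utt_text, audio in context
    ]
    # Final turn has no audio: CSM generates the audio for this text
    conversation.append({"role": role, "content": [{"type": "text", "text": text}]})
    return conversation
```

For example, `build_conversation([(raw_ds[i]["text"], raw_ds[i]["audio"]["array"]) for i in (3, 4)], text)` would use two dataset clips as context, and each newly generated `audio` could be appended to `context` with its text before the next call.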

<a name="Save"></a>
### Saving, loading finetuned models
To save the final model as LoRA adapters, use Hugging Face's `push_to_hub` for an online save or `save_pretrained` for a local save.

**[NOTE]** This ONLY saves the LoRA adapters, and not the full model. To save to 16bit or GGUF, scroll down!

In [9]:
model.save_pretrained("lora_model")  # Local saving
processor.save_pretrained("lora_model")
# model.push_to_hub("your_name/lora_model", token = "...") # Online saving
# processor.push_to_hub("your_name/lora_model", token = "...") # Online saving


### Saving to float16

We also support saving to `float16` directly. Select `merged_16bit` for float16 or `merged_4bit` for int4. We also allow `lora` adapters as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens.

In [10]:
# Merge to 16bit
if False: model.save_pretrained_merged("model", processor, save_method = "merged_16bit",)
if False: model.push_to_hub_merged("hf/model", processor, save_method = "merged_16bit", token = "")

# Merge to 4bit
if False: model.save_pretrained_merged("model", processor, save_method = "merged_4bit",)
if False: model.push_to_hub_merged("hf/model", processor, save_method = "merged_4bit", token = "")

# Just LoRA adapters
if False:
    model.save_pretrained("model")
    processor.save_pretrained("model")
if False:
    model.push_to_hub("hf/model", token = "")
    processor.push_to_hub("hf/model", token = "")


And we're done! If you have questions about Unsloth, find a bug, want to keep up with the latest LLM news, or need help with your projects, join our [Discord](https://discord.gg/unsloth) channel!

Some other links:
1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)
2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)
4. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://docs.unsloth.ai/get-started/unsloth-notebooks)!

