In [None]:
import tqdm.notebook as tqdm
from unsloth import FastModel
from transformers import CsmForConditionalGeneration
import torch
import datasets
from datasets import load_dataset, Audio, Dataset
from IPython.display import Audio, display
import soundfile as sf

In [None]:
# Options for loading the CSM-1B checkpoint (and its processor) through Unsloth.
load_options = dict(
    model_name = "unsloth/csm-1b",
    max_seq_length = 2048,                       # can be raised for longer contexts
    dtype = None,                                # None lets the loader auto-detect
    auto_model = CsmForConditionalGeneration,
    load_in_4bit = False,                        # True selects 4-bit loading (less memory)
)
model, processor = FastModel.from_pretrained(**load_options)

In [None]:
import os
from getpass import getpass

import mlflow

# Prompt for the tracking credentials (input is not echoed) and export each one
# as an environment variable, so no secret is ever written into the notebook.
for _credential in ("MLFLOW_TRACKING_USERNAME", "MLFLOW_TRACKING_PASSWORD"):
    os.environ[_credential] = getpass(f'Enter the {_credential}: ')

# Keep the module-level names available for later cells.
MLFLOW_TRACKING_USERNAME = os.environ['MLFLOW_TRACKING_USERNAME']
MLFLOW_TRACKING_PASSWORD = os.environ['MLFLOW_TRACKING_PASSWORD']

# Tracking server and experiment used for this run.
os.environ.update({
    "MLFLOW_TRACKING_URI": "https://mlflow-sunbird-ce0ecfc14244.herokuapp.com",
    "MLFLOW_EXPERIMENT_NAME": "tts-csm-1b",
})

In [None]:
# Names of the modules that receive LoRA adapters.
lora_target_modules = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]

# Attach LoRA adapters to the loaded model.
# NOTE: this rebinds `model`, so re-running the cell wraps the already-wrapped model.
model = FastModel.get_peft_model(
    model,
    r = 128,                                  # LoRA rank (any number > 0)
    lora_alpha = 128,
    lora_dropout = 0,                         # 0 is the optimized setting
    bias = "none",                            # "none" is the optimized setting
    target_modules = lora_target_modules,
    use_gradient_checkpointing = "unsloth",   # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,                       # rank-stabilized LoRA is off
    loftq_config = None,                      # LoftQ is off
)

In [None]:
def load_studio_split(config_name, speaker_id):
    """Load the train split of one SALT studio config and tag it with a speaker id.

    Args:
        config_name: Name of the ``Sunbird/salt`` configuration, e.g. ``"studio-lug"``.
        speaker_id: Integer written to a ``speaker_id`` column on every row.

    Returns:
        The loaded dataset with the additional ``speaker_id`` column.
    """
    split = load_dataset("Sunbird/salt", config_name, split="train")
    return split.map(lambda example: {"speaker_id": speaker_id})


# One dataset per studio recording config.
# NOTE(review): studio-lug and studio-eng both use speaker_id 1 — presumably the
# same voice recorded in two languages; confirm this is intentional.
ds_lug = load_studio_split("studio-lug", 1)
ds_eng = load_studio_split("studio-eng", 1)
ds_ach = load_studio_split("studio-ach", 2)
ds_swa = load_studio_split("studio-swa", 3)
ds_lgg = load_studio_split("studio-lgg", 4)
ds_nyn = load_studio_split("studio-nyn", 5)
ds_teo = load_studio_split("studio-teo", 6)

# The concatenation order is kept fixed so the seeded shuffle stays reproducible.
raw_ds = datasets.concatenate_datasets(
    [ds_ach, ds_lug, ds_eng, ds_swa, ds_lgg, ds_nyn, ds_teo]).shuffle(seed=42)

sampling_rate = 24000
# The bare name `Audio` is bound to IPython.display.Audio (the later import in the
# imports cell wins), which does not accept `sampling_rate`. Reference the
# datasets feature through the module so the column is actually resampled.
raw_ds = raw_ds.cast_column("audio", datasets.Audio(sampling_rate=sampling_rate))

# Keep only clips strictly between 0.5 and 8 seconds long (measured in samples).
raw_ds = raw_ds.filter(
    lambda example: (0.5 * sampling_rate) < len(example["audio"]["array"]) < (8 * sampling_rate),
    num_proc=20,
)

In [None]:
# Number of examples left after the duration filter.
raw_ds.num_rows

In [None]:
if False:
    # Disabled diagnostic: flip the guard to True on a new dataset to collect
    # the per-example audio length (in samples) and text length (in characters).
    audio_lengths, text_lengths = [], []
    for row in tqdm.tqdm(raw_ds):
        n_samples = len(row['audio']['array'])
        n_chars = len(row['text'])
        audio_lengths.append(n_samples)
        text_lengths.append(n_chars)

In [None]:
import os
from transformers import AutoProcessor

# Load the processor straight from the base checkpoint.
# NOTE: this rebinds `processor`, replacing the one returned when the model was loaded.
PROCESSOR_CHECKPOINT = "unsloth/csm-1b"
processor = AutoProcessor.from_pretrained(PROCESSOR_CHECKPOINT)

def preprocess_example(example):
    """Turn one raw dataset row into fixed-size tensor inputs for training.

    A single-turn conversation is built (role = speaker id as a string,
    content = the text plus the audio array) and passed through the
    module-level ``processor`` chat template with max-length padding. The
    first item of every required output is kept.

    Args:
        example: Row providing ``"speaker_id"``, ``"text"`` and
            ``"audio"["array"]``.

    Returns:
        A dict mapping each required key to a ``torch.Tensor``, or ``None``
        (after printing a message) when the processor raises, a required key
        is missing, or a value is not a tensor.

    NOTE(review): how ``Dataset.map`` treats a ``None`` return together with
    ``remove_columns`` should be confirmed for the installed ``datasets``
    version.
    """
    turn = {
        "role": str(example["speaker_id"]),
        "content": [
            {"type": "text", "text": example["text"]},
            {"type": "audio", "path": example["audio"]["array"]},
        ],
    }

    text_options = {
        "padding": "max_length",       # always pad up to max_length
        "max_length": 256,             # fixed padded length of the text side (not the audio)
        "pad_to_multiple_of": 8,
        "padding_side": "right",
    }
    audio_options = {
        "sampling_rate": 24_000,
        "max_length": 8 * 24_000,      # 8 seconds of samples: the longest clip kept in the dataset
        "padding": "max_length",
    }

    try:
        model_inputs = processor.apply_chat_template(
            [turn],
            tokenize=True,
            return_dict=True,
            output_labels=True,
            text_kwargs=text_options,
            audio_kwargs=audio_options,
            common_kwargs={"return_tensors": "pt"},
        )
    except Exception as e:
        print(f"Error processing example with text '{example['text'][:50]}...': {e}")
        return None

    processed_example = {}
    for name in ("input_ids", "attention_mask", "labels", "input_values", "input_values_cutoffs"):
        if name not in model_inputs:
            print(f"Warning: Required key '{name}' not found in processor output for example.")
            return None
        # The processor output carries a leading batch axis of one item; drop it.
        processed_example[name] = model_inputs[name][0]

    # Sanity check: every collected value has to be a tensor.
    if any(not isinstance(item, torch.Tensor) for item in processed_example.values()):
        print(f"Error: Not all required keys are tensors in final processed example. Keys: {list(processed_example.keys())}")
        return None

    return processed_example

# Preprocess only the first MAX_PREPROCESSED_EXAMPLES rows of the shuffled data;
# the raw columns are dropped so only the processor outputs remain.
MAX_PREPROCESSED_EXAMPLES = 12_000
raw_subset = raw_ds.take(MAX_PREPROCESSED_EXAMPLES)

processed_ds = raw_subset.map(
    preprocess_example,
    num_proc=20,
    desc="Preprocessing dataset",
    remove_columns=raw_ds.column_names,
)

In [None]:
# Hold out the first N_eval_samples processed rows for evaluation and train on the rest.
N_eval_samples = 128
eval_dataset = processed_ds.take(N_eval_samples)
train_dataset = processed_ds.skip(N_eval_samples)

In [None]:
from transformers import TrainingArguments, Trainer
from unsloth import is_bfloat16_supported

# Train in bf16 when the helper reports support, otherwise fall back to fp16.
use_bf16 = is_bfloat16_supported()

training_args = TrainingArguments(
    output_dir = "csm-1b-lora-bs8",
    seed = 42,
    # Batching and data loading
    per_device_train_batch_size = 8,
    per_device_eval_batch_size = 8,
    gradient_accumulation_steps = 1,
    dataloader_num_workers = 8,
    # Schedule and optimizer
    num_train_epochs = 1,
    learning_rate = 2e-4,
    warmup_steps = 5,
    lr_scheduler_type = "linear",
    optim = "adamw_8bit",
    weight_decay = 0.01,
    # Precision
    bf16 = use_bf16,
    fp16 = not use_bf16,
    # Logging and evaluation every 100 steps, reported to MLflow
    logging_steps = 100,
    eval_strategy = "steps",
    eval_steps = 100,
    report_to = "mlflow",
)

trainer = Trainer(
    model = model,
    args = training_args,
    train_dataset = train_dataset,
    eval_dataset = eval_dataset,
)

In [None]:
# Run the fine-tuning loop and keep whatever train() returns in `trainer_stats`.
trainer_stats = trainer.train()

In [None]:
# Sample sentence and the speaker id to synthesize it with.
text = "Nsobola okwogera Oluganda n'ennimi endala."
speaker_id = 1

# The prompt is the speaker id in square brackets followed directly by the text.
prompt = f"[{speaker_id}]{text}"
inputs = processor(prompt, add_special_tokens=True).to("cuda")

# Sampling settings for generation.
generation_options = dict(
    max_new_tokens=125,  # 125 tokens is 10 seconds of audio, for longer speech increase this
    temperature=0.8,
    top_k=50,
    top_p=1.0,
    depth_decoder_temperature=0.6,
    depth_decoder_top_k=0,
    depth_decoder_top_p=0.9,
    output_audio=True,
)
audio_values = model.generate(**inputs, **generation_options)

# Take the first generated clip, cast it to float32 and move it to a NumPy array on the CPU.
waveform = audio_values[0].to(torch.float32)
audio = waveform.cpu().numpy()
# Optional: save to disk with sf.write("example.wav", audio, 24000)
display(Audio(audio, rate=24000))

In [None]:
# Upload the trained model to the Hub repository 'csm-1b-salt'.
# NOTE(review): presumably this needs an authenticated Hugging Face session and, for a
# LoRA-wrapped model, uploads the adapter weights only — confirm before running.
model.push_to_hub('csm-1b-salt')

In [None]:
# Upload the processor to the same Hub repository name used for the model ('csm-1b-salt').
processor.push_to_hub('csm-1b-salt')