In [None]:
%%capture
# Install dependencies: csm-mlx straight from GitHub, plus mlx-lm (provides the
# sampler), the Hugging Face Hub client (weight download) and scipy (WAV writing).
# NOTE(review): torch/torchaudio are not referenced in the cells shown here;
# presumably they are needed by csm-mlx itself — confirm before removing.
# NOTE(review): %%capture hides this cell's output, including pip errors —
# drop it temporarily when an install fails.
# NOTE(review): versions are unpinned (--upgrade), so reruns may resolve to
# different releases; pin them for reproducible results.
%pip install --upgrade git+https://github.com/senstella/csm-mlx
%pip install --upgrade mlx-lm huggingface_hub torch torchaudio scipy

In [None]:
from pathlib import Path

import mlx.core as mx
import numpy as np
from huggingface_hub import hf_hub_download
from mlx import nn
from mlx_lm.sample_utils import make_sampler
from scipy.io import wavfile

from csm_mlx import CSM, csm_1b, generate

In [None]:
# Initialize the model: csm_1b() is the configuration for the CSM model, and the
# MLX-converted checkpoint is downloaded from the Hugging Face Hub.
csm = CSM(csm_1b())
weight = hf_hub_download(repo_id="senstella/csm-1b-mlx", filename="ckpt.safetensors")
csm.load_weights(weight)

# Optional 4-bit quantization of the loaded model. Toggle the flag rather than
# un-commenting code, so the notebook still runs top to bottom either way.
QUANTIZE = False
if QUANTIZE:
    nn.quantize(csm, bits=4)

# Sampling below is stochastic (temp=0.8), so seed MLX's global random state to
# make reruns reproducible.
# NOTE(review): assumes the sampler draws from MLX's global PRNG — confirm.
SEED = 42
mx.random.seed(SEED)

TEXT = (
    "In the heart of a bustling city stood an ancient stone wall that most people hurried past without a second glance. "
    "Behind this wall, hidden from view, lay a forgotten garden that had once been the pride of the neighborhood. "
    "One rainy Tuesday morning, Emma, a young librarian with a passion for history, noticed a small wooden door in the wall that she'd somehow never seen before. "
    "Curiosity getting the better of her, she pushed against it gently, surprised when it creaked open."
)

# Generate audio from text
audio = generate(
    csm,
    text=TEXT,
    speaker=0,
    context=[],  # no prior context segments
    max_audio_length_ms=60_000,  # upper bound on generated length (60 s)
    sampler=make_sampler(temp=0.8, top_k=50),  # Put mlx_lm's sampler here! Supports: temp, top_p, min_p, min_tokens_to_keep, top_k.
    # Additionally, you can provide `stream` argument to specify what device to use for generation.
    # https://ml-explore.github.io/mlx/build/html/usage/using_streams.html
)

# Save the generated waveform as a 16-bit PCM WAV file.
# NOTE(review): 24 kHz must match the rate csm_mlx generates at — verify against the library.
SAMPLE_RATE = 24_000  # Hz
OUTPUT_PATH = Path("./output/audio.wav")

audio_float = np.asarray(audio)
# Clip before scaling: a sample even slightly outside [-1, 1] would overflow int16
# and wrap around to the opposite extreme, producing a loud click.
audio_int16 = (np.clip(audio_float, -1.0, 1.0) * 32767).astype(np.int16)

# wavfile.write does not create missing directories, so make sure ./output exists.
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
wavfile.write(str(OUTPUT_PATH), SAMPLE_RATE, audio_int16)