# MARS6 Turbo Inference Demo

This notebook demonstrates how to run text-to-speech (TTS) inference using **MARS6** in two modes:
- **Shallow clone**: only clones the speaker's voice timbre.
- **Deep clone**: also uses the reference transcript and the SNAC tokens of the reference audio to match the reference more closely, particularly in prosody.

The notebook is broken into steps:
1. **Import libraries**
2. **Set device**
3. **(Opt.1) Load model from local directory**
4. **(Opt.2) Load model from Torch Hub**
5. **Load embeddings / SNAC / config**
6. **Generate reference embeddings** (one-time; re-run with a different audio path to switch to a different reference)
7. **Define the final inference logic** (assumes globally accessible reference embeddings)
8. **Shallow clone**
9. **Deep clone**

## 1) Import libraries

In [2]:
import tempfile
import time
from dataclasses import dataclass
from typing import Optional, List

import torch
import torchaudio
import librosa

from mars6_turbo.ar_model import Mars6_Turbo, SNACTokenizerInfo, RASConfig
from mars6_turbo.minbpe.regex import RegexTokenizer
from mars6_turbo.utils import RegexSplitter
from snac import SNAC
from transformers import Wav2Vec2FeatureExtractor, WavLMForXVector
from msclap import CLAP

## 2) Set device

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.half if device == 'cuda' else torch.float
print(f"Using device={device}, dtype={dtype}")

## 3) (Opt.1) Load model from local directory
If you have a local checkpoint (e.g. `model/model-2000100.pt`), load it here.

In [None]:
# Example local paths:
ckpt_path = "model/model-2000100.pt"
tokenizer_path = "model/eng-tok-512.model"

# Load text tokenizer
texttok = RegexTokenizer()
texttok.load(tokenizer_path)
print("Local tokenizer loaded.")

# Load checkpoint from disk
ckpt = torch.load(ckpt_path, map_location='cpu')
model_cfg = ckpt['cfg']
old_sd = ckpt['model']
# Strip the 'module.' prefix that DataParallel/DDP adds to parameter names, if present
new_sd = {k.removeprefix('module.'): v for k, v in old_sd.items()}

# Build MARS6 model
n_text_vocab = len(texttok.vocab)
n_speech_vocab = SNACTokenizerInfo.codebook_size*3 + SNACTokenizerInfo.n_snac_special

model = Mars6_Turbo(
    n_input_vocab=n_text_vocab,
    n_output_vocab=n_speech_vocab,
    emb_dim=model_cfg.get('dim', 512),
    n_layers=model_cfg.get('n_layers', 8),
    fast_n_layers=model_cfg.get('fast_n_layers', 4),
    n_langs=len(model_cfg.get('languages', ['en-us']))
)
model.load_state_dict(new_sd)
model = model.to(device=device, dtype=dtype)
model.eval()
print("Local MARS6 model loaded.")
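Checkpoints saved from a model wrapped in `torch.nn.DataParallel` or `DistributedDataParallel` store every parameter name with a `module.` prefix, which is why the cell above strips it before calling `load_state_dict`. The sketch below shows the same idea on a plain dictionary, using `str.removeprefix` (Python 3.9+) so that only a *leading* `module.` is removed. The key names are made up for illustration.

```python
def strip_module_prefix(state_dict: dict, prefix: str = "module.") -> dict:
    """Return a copy of state_dict with a leading `prefix` removed from each key."""
    return {key.removeprefix(prefix): value for key, value in state_dict.items()}


# Hypothetical keys, as a DataParallel/DDP-wrapped model would save them.
saved = {
    "module.encoder.weight": 1,
    "module.decoder.module.bias": 2,  # the inner "module." is part of the real name
    "head.weight": 3,                 # already unprefixed
}
clean = strip_module_prefix(saved)
print(sorted(clean))
# ['decoder.module.bias', 'encoder.weight', 'head.weight']
```

`str.replace('module.', '')` would also rewrite the inner `module.` in the second key; `removeprefix` leaves it alone.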

## 4) (Opt.2) Load model from Torch Hub

In [None]:
model, texttok = torch.hub.load(
    repo_or_dir="Camb-ai/mars6-turbo",
    model="mars6_turbo",
    ckpt_format='pt',
    device=device,
    dtype=dtype,
    force_reload=False,
)

## 5) Load Embeddings, SNAC Codec, and Default Config
Here we load the speaker embedding models (CLAP + WavLM-based) and the SNAC codec. We also define a default config dictionary.

In [None]:
# Load embedding models
clap_model = CLAP(use_cuda=(device=='cuda'))
wavlm_feat_ext = Wav2Vec2FeatureExtractor.from_pretrained('microsoft/wavlm-base-plus-sv')
spk_emb_model = WavLMForXVector.from_pretrained('microsoft/wavlm-base-plus-sv').to(device).eval()
print("CLAP + WavLM speaker models loaded.")

# SNAC codec
snac_codec = SNAC.from_pretrained('hubertsiuzdak/snac_24khz').eval().to(device, dtype)
print("SNAC codec loaded.")

# Basic config
config = {
    'sr': 24000,
    'ras_K': 10,
    'ras_t_r': 0.09,
    'top_p': 0.2,
    'sil_trim_db': 33,
    'backoff_top_p_increment': 0.2,
    'chars_per_second_upper_bound': 32,
    'min_valid_audio_volume': -52,
    'prefix': '48000',
    # can be 'none', 'per-chunk', or 'fixed-ref'
    'deep_clone_mode': 'none'
}
print("Default config set.")
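`make_predictions` (step 7) converts this dictionary into an `EvalConfig` dataclass with `EvalConfig(**cfg)`, so every key must match a dataclass field and a misspelled key fails with a `TypeError`. A minimal sketch of that conversion with explicit checks follows; `SketchConfig` and `load_config` are hypothetical names, and the dataclass mirrors only a subset of the notebook's fields.

```python
from dataclasses import dataclass, fields

VALID_DEEP_CLONE_MODES = ("none", "per-chunk", "fixed-ref")


@dataclass
class SketchConfig:
    # Hypothetical subset of the notebook's EvalConfig fields, with the same defaults.
    sr: int = 24000
    top_p: float = 0.2
    backoff_top_p_increment: float = 0.2
    deep_clone_mode: str = "per-chunk"


def load_config(cfg: dict) -> SketchConfig:
    """Build a SketchConfig from a dict, rejecting unknown keys and unknown modes."""
    known = {f.name for f in fields(SketchConfig)}
    unknown = set(cfg) - known
    if unknown:
        raise ValueError(f"unknown config keys: {sorted(unknown)}")
    config = SketchConfig(**cfg)
    if config.deep_clone_mode not in VALID_DEEP_CLONE_MODES:
        raise ValueError(f"deep_clone_mode must be one of {VALID_DEEP_CLONE_MODES}")
    return config


print(load_config({"top_p": 0.3, "deep_clone_mode": "fixed-ref"}))
# SketchConfig(sr=24000, top_p=0.3, backoff_top_p_increment=0.2, deep_clone_mode='fixed-ref')
```

Note that `EvalConfig` defaults `deep_clone_mode` to `'per-chunk'`, while the dictionary above sets `'none'`; the dictionary value wins when both are present.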

## 6) Generate Reference Embeddings (One-time)
Here we load and embed the reference audio just once. This yields:
- A **CLAP** embedding: `clap_emb_global`
- A **WavLM** speaker embedding: `spk_emb_global`
- Optionally, **SNAC tokens** if we want to do *deep clone* by referencing the exact audio tokens: `ref_tokens_global`

The inference function defined in step 7, `make_predictions(...)`, reads these three variables as globals, so both the shallow and deep clone steps reuse them without recomputation.

In [None]:
def tokenize_speech(
    speechtok: SNACTokenizerInfo,
    codes: List[torch.Tensor],
    add_special: bool = True
) -> torch.Tensor:
    """
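    Interleave the SNAC codebooks into rows of 7 tokens (1 from L0, 2 from L1, 4 from L2),
    offsetting level i by i * codebook_size so all levels share one vocabulary; shape (sl,7).
    (Used to turn reference audio into decoder-prefix tokens for deep clone.)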
    """
    tokens = []
    # offset each codebook by i*4096 for i in {0,1,2}:
    codes = [(c + (i * speechtok.codebook_size)).tolist() for i, c in enumerate(codes)]
    quant_levels = [0, 1, 2]

    while any(len(c) > 0 for c in codes):
        for i in quant_levels:
            if i == 0:
                tokens.append(codes[0][0])
                codes[0] = codes[0][1:]
            elif i == 1:
                tokens.extend(codes[1][:2])
                codes[1] = codes[1][2:]
            elif i == 2:
                tokens.extend(codes[2][:4])
                codes[2] = codes[2][4:]
    if add_special:
        # Append 7 eos tokens
        tokens += [speechtok.eos_tok] * 7

    tokens = torch.tensor(tokens, dtype=torch.long)
    tokens = tokens.view(-1, 7)  # shape (sl,7)
    return tokens


def snac_encode_reference(
    snac_codec: SNAC,
    wav: torch.Tensor,
    device='cuda',
    dtype=torch.float16
) -> torch.Tensor:
    """
    Encode a waveform using SNAC, then flatten codebooks into (sl,7).
    """
    with torch.no_grad():
        # wav: (1, T) mono; SNAC.encode expects (B, 1, T), so add a batch dim
        wav = wav.to(device=device, dtype=dtype)
        codes_list = snac_codec.encode(wav[None])
        codes_list = [t.squeeze(0) for t in codes_list]
        ref_tokens = tokenize_speech(SNACTokenizerInfo(), codes_list, add_special=True)
    return ref_tokens

@torch.no_grad()  # embeddings only; no autograd graph needed
def compute_ref_data(
    audio_path: str,
    sr_out=24000,
    do_tokens: bool = True,
    device: str = 'cuda',
    dtype: torch.dtype = torch.half
):
    """Load and process reference audio => CLAP embed, WavLM embed, optional SNAC tokens."""
    raw_wav, sr_in = torchaudio.load(audio_path)
    raw_wav = raw_wav.mean(dim=0, keepdim=True)
    if sr_in != sr_out:
        raw_wav = torchaudio.functional.resample(raw_wav, sr_in, sr_out)

    with tempfile.NamedTemporaryFile(suffix='.wav', delete=True) as tmp:
        torchaudio.save(tmp.name, raw_wav, sr_out)
        clap_emb = clap_model.get_audio_embeddings([tmp.name], resample=True)[0].unsqueeze(0).to(device=device, dtype=dtype)

    # WavLM spk embed at 16kHz
    wav_16 = torchaudio.functional.resample(raw_wav, sr_out, 16000)
    inp = wavlm_feat_ext(
        wav_16.squeeze(),
        padding=True,
        return_tensors='pt',
        sampling_rate=16000
    )
    for k in inp:
        inp[k] = inp[k].to(device)
    spk_emb = spk_emb_model(**inp).embeddings
    spk_emb = torch.nn.functional.normalize(spk_emb, dim=-1).to(device=device, dtype=dtype)

    ref_tokens = None
    if do_tokens:
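        ref_tokens = snac_encode_reference(snac_codec, raw_wav, device=device, dtype=dtype)
        # Drop the last two rows: the <eos> row appended by tokenize_speech and the last
        # audio frame, so the decoder prefix does not end with <eos>.
        ref_tokens = ref_tokens[:-2]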

    return spk_emb, clap_emb, ref_tokens

print("Function to compute reference data (embedding + tokens) loaded.")
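`tokenize_speech` lays the three SNAC codebooks out frame by frame: each row of 7 tokens holds 1 code from L0, 2 from L1 and 4 from L2, with L1 shifted by one codebook size and L2 by two. `detokenize_speech` in step 7 reverses this. The pure-Python sketch below reproduces both directions on toy codes to make the layout concrete; it assumes a codebook size of 4096 (as the comment in `tokenize_speech` states) and leaves out the `<eos>` handling.

```python
from typing import List, Tuple

CODEBOOK_SIZE = 4096  # assumed, per the comment in tokenize_speech


def interleave(l0: List[int], l1: List[int], l2: List[int]) -> List[List[int]]:
    """Pack SNAC levels into rows of 7 tokens: [L0, L1, L1, L2, L2, L2, L2]."""
    assert len(l1) == 2 * len(l0) and len(l2) == 4 * len(l0)
    rows = []
    for i in range(len(l0)):
        row = [l0[i]]
        row += [c + CODEBOOK_SIZE for c in l1[2 * i:2 * i + 2]]
        row += [c + 2 * CODEBOOK_SIZE for c in l2[4 * i:4 * i + 4]]
        rows.append(row)
    return rows


def deinterleave(rows: List[List[int]]) -> Tuple[List[int], List[int], List[int]]:
    """Inverse of interleave: remove the offsets and regroup tokens by level."""
    l0 = [r[0] for r in rows]
    l1 = [c - CODEBOOK_SIZE for r in rows for c in r[1:3]]
    l2 = [c - 2 * CODEBOOK_SIZE for r in rows for c in r[3:7]]
    return l0, l1, l2


l0, l1, l2 = [5, 6], [10, 11, 12, 13], [20, 21, 22, 23, 24, 25, 26, 27]
rows = interleave(l0, l1, l2)
print(rows[0])  # [5, 4106, 4107, 8212, 8213, 8214, 8215]
assert deinterleave(rows) == (l0, l1, l2)
```

Each row corresponds to one L0 frame, so the number of rows equals the number of coarse SNAC frames.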

In [None]:
ref_audio_path = 'example.wav'
spk_emb_global, clap_emb_global, ref_tokens_global = compute_ref_data(
    audio_path=ref_audio_path,
    sr_out=24000,
    do_tokens=True,
    device=device,
    dtype=dtype
)
print("Reference embeddings + tokens computed!")

## 7) Define the Inference Logic (Without Reference Embedding)
Below we define an inference function that **assumes** the speaker embedding (`spk_emb_global`), the CLAP embedding (`clap_emb_global`) and, for deep clone, the reference tokens (`ref_tokens_global`) already exist as globals.
The reference audio is **not** embedded inside this function; that happened in step 6.
This avoids recomputing the reference embeddings on every TTS call that uses the same reference.
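If you switch between several reference speakers, the same saving can be extended by memoizing the reference computation per file. The sketch below is hypothetical: `fake_compute_ref_data` stands in for `compute_ref_data` from step 6 and only counts how often it runs, to show that `functools.lru_cache` skips the expensive work on repeated paths.

```python
import functools

calls = {"count": 0}


@functools.lru_cache(maxsize=4)
def fake_compute_ref_data(audio_path: str) -> tuple:
    """Hypothetical stand-in for compute_ref_data: pretend to embed one file."""
    calls["count"] += 1
    return (f"spk_emb({audio_path})", f"clap_emb({audio_path})", f"ref_tokens({audio_path})")


for path in ["example.wav", "example.wav", "other.wav", "example.wav"]:
    spk_emb, clap_emb, ref_tokens = fake_compute_ref_data(path)

print(calls["count"])  # 2: each distinct path is processed once
```

The cache is keyed only on the arguments, so it will return stale results if the file on disk changes, and cached GPU tensors stay in memory; a small `maxsize` keeps that bounded.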

In [None]:
@dataclass
class EvalConfig:
    sr: int = 24000
    ras_K: int = 10
    ras_t_r: float = 0.09
    top_p: float = 0.2
    sil_trim_db: float = 33
    backoff_top_p_increment: float = 0.2
    chars_per_second_upper_bound: float = 32
    min_valid_audio_volume: float = -52
    prefix: str = "48000"
    # Options: 'none', 'per-chunk', or 'fixed-ref'
    deep_clone_mode: str = 'per-chunk'

# A simple punctuation mapping for unifying certain characters
punctuation_mapping = {
    "。": ".", "、": ",", "！": "!", "？": "?",
    "…": "...", '–': '—', '―': '—', '−': '—', '-': '—', '।': '.'
}
punctuation_trans_table = str.maketrans(punctuation_mapping)


def normalize_ref_volume(wav: torch.Tensor, sr: int, target_db: Optional[float]) -> tuple[torch.Tensor, Optional[float]]:
    """Normalize waveform loudness to a target dB using torchaudio."""
    if target_db is None:
        return wav, None
    ln = torchaudio.functional.loudness(wav, sr)
    wav = torchaudio.functional.gain(wav, target_db - ln)
    return wav, ln


def detokenize_speech(codes: torch.Tensor) -> List[torch.Tensor]:
    """
    Convert model output tokens (shape (sl,7)) back into separate lists
    of L0, L1, L2 token IDs, removing <eos>.
    """
    eos_inds = codes.max(dim=-1).values == SNACTokenizerInfo.eos_tok
    codes = codes[~eos_inds]
    # revert L1 offsets
    codes[:, 1:] -= SNACTokenizerInfo.codebook_size
    # revert L2 offsets
    codes[:, 3:] -= SNACTokenizerInfo.codebook_size
    l0 = codes[:, 0]
    l1 = codes[:, 1:3].flatten()
    l2 = codes[:, 3:].flatten()
    return [l0, l1, l2]


def codes2duration(codes: List[torch.Tensor]) -> float:
    """
    Approximate audio duration from hierarchical SNAC tokens.
    For the 24 kHz SNAC model the finest level (L2) runs at 24000 / 512 ≈ 46.9 Hz;
    48 Hz is used here as a round approximation (1/48 s per L2 token).
    """
    l2 = codes[2]
    return len(l2) / 48.0


@torch.inference_mode()
def make_predictions(
    model: Mars6_Turbo,
    texttok: RegexTokenizer,
    cfg: dict,
    snac_codec: SNAC,
    text: str,
    ref_text: str,
    device: str = 'cuda',
):
    """
    End-to-end TTS inference that supports:
      - Shallow clone (deep_clone_mode='none')
      - Fixed-ref deep clone (deep_clone_mode='fixed-ref')
      - Per-chunk deep clone (deep_clone_mode='per-chunk')

    Where "deep clone" means we incorporate code tokens from the reference
    (or from the previously generated chunk) as a prefix, to better match prosody.
    """
    # Create the config as an EvalConfig dataclass
    config = EvalConfig(**cfg)
    splitter = RegexSplitter()
    loudness_tfm = torchaudio.transforms.Loudness(config.sr)

    # 1) Prepare for potential deep clone
    deep_mode = config.deep_clone_mode
    user_ref_text = ref_text.strip().translate(punctuation_trans_table)

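    # A non-empty transcript shorter than 4 or longer than 300 characters is unusable:
    # drop it so the reference tokens are not used as a prefix. 'fixed-ref' falls back to
    # 'per-chunk', which then conditions only on previously generated chunks.
    if (0 < len(user_ref_text) < 4 or len(user_ref_text) > 300):
        print(f"Invalid reference transcript length={len(user_ref_text)}; ignoring the reference transcript and tokens.")
        if deep_mode == 'fixed-ref':
            deep_mode = 'per-chunk'
        user_ref_text = ""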

    # Prefix state for per-chunk mode: start from the reference only if a usable transcript exists
    prior_chunk_text = user_ref_text
    prior_chunk_tokens = ref_tokens_global if prior_chunk_text else None

    # 2) Prepare text to generate, then chunk it
    text_to_speak = text.translate(punctuation_trans_table)
    chunks = splitter.split(text_to_speak)
    if not chunks:
        chunks = [text_to_speak]  # fallback if no splits
    all_chunk_wavs = []
    chunk_counter = 0

    # 3) Generate chunk by chunk
    for chunk_text in chunks:
        chunk_text = chunk_text.strip()
        if not chunk_text:
            continue

        # Decide prefix tokens based on deep_clone_mode
        if deep_mode == 'fixed-ref':
            prefix_text = user_ref_text
            prefix_tokens = ref_tokens_global
        elif deep_mode == 'per-chunk':
            if chunk_counter == 0 and len(user_ref_text.strip())>0:
                prefix_text = user_ref_text
                prefix_tokens = ref_tokens_global
            else:
                prefix_text = prior_chunk_text
                prefix_tokens = prior_chunk_tokens
        else:
            prefix_text = None
            prefix_tokens = None

        # Build text for the encoder
        if prefix_text:
            # E.g. "[48000] reference_text + chunk_text"
            full_enc_str = f"<|startoftext|>[{config.prefix}]{prefix_text} {chunk_text}<|endoftext|>"
        else:
            full_enc_str = f"<|startoftext|>[{config.prefix}]{chunk_text}<|endoftext|>"

        text_ids = texttok.encode(full_enc_str, allowed_special="all")
        x = torch.tensor(text_ids, dtype=torch.long, device=device).unsqueeze(0)
        xlengths = torch.tensor([x.shape[1]], dtype=torch.long, device=device)

        language = torch.tensor([0], dtype=torch.long, device=device)

        # We'll ensure minimal chunk duration
        chunk_dur_min = len(chunk_text) / config.chars_per_second_upper_bound
        chunk_wav_final = None

        tries = 0
        while True:
            tries += 1
            ras_cfg = RASConfig(
                K=config.ras_K,
                t_r=config.ras_t_r,
                top_p=config.top_p,
                enabled=True,
                cfg_guidance=1.0
            )

            max_len = 30 + int(xlengths.item() * 2.6)

            result_tokens = model.inference(
                x,
                xlengths,
                clap_embs=clap_emb_global,
                spk_embs=spk_emb_global,
                language=language,
                max_len=max_len,
                fp16=(device == 'cuda'),
                ras_cfg=ras_cfg,
                cache=None,
                decoder_prefix=prefix_tokens,  # deep clone prefix
                lower_bound_dur=chunk_dur_min
            )

            # Detokenize => approximate duration
            rcodes = detokenize_speech(result_tokens)
            chunk_dur = codes2duration(rcodes)

            if chunk_dur < chunk_dur_min:
                # Too short: raise top_p (capped at 1.0) and retry
                new_top_p = round(min(config.top_p + config.backoff_top_p_increment, 1.0), 2)
                print(f"Chunk too short ({chunk_dur:.2f}s < {chunk_dur_min:.2f}s)."
                      f" Increasing top_p from {config.top_p} -> {new_top_p}. Retrying.")
                config.top_p = new_top_p
                if tries > 10:
                    print("Max tries reached for chunk; skipping it.")
                    break
                continue

            # Decode tokens to audio; SNAC returns (B, 1, T)
            chunk_audio = snac_codec.decode([r[None].to(device) for r in rcodes])
            # Drop the channel dimension => (B, T) with B=1
            chunk_audio = chunk_audio[:, 0]
            loudness_val = loudness_tfm(chunk_audio.cpu().float().contiguous())
            if loudness_val < config.min_valid_audio_volume:
                # Silent or too quiet: raise top_p and retry
                new_top_p = round(min(config.top_p + config.backoff_top_p_increment, 1.0), 2)
                print(f"Chunk silent or quiet (loudness={loudness_val:.2f}). "
                      f"Increasing top_p to {new_top_p}.")
                config.top_p = new_top_p
                if tries > 10:
                    print("Max tries reached for chunk; skipping it.")
                    break
                continue

            # We are done with this chunk
            chunk_wav_final = chunk_audio.cpu().squeeze()
            break

        if chunk_wav_final is None:
            continue  # no valid chunk generated
        all_chunk_wavs.append(chunk_wav_final)

        # For per-chunk mode, store newly generated tokens as prefix
        if deep_mode == 'per-chunk':
            prior_chunk_text = chunk_text
            prior_chunk_tokens = result_tokens

        chunk_counter += 1

    # 4) Concatenate all chunk waveforms
    if not all_chunk_wavs:
        print("No audio was generated.")
        final_wav = torch.zeros(16000)  # fallback
    else:
        final_wav = torch.cat(all_chunk_wavs, dim=-1)

    # 5) Trim leading/trailing silence; cast to float32 first since chunks may be fp16
    final_np, _ = librosa.effects.trim(final_wav.float().numpy(), top_db=config.sil_trim_db)
    final_wav = torch.from_numpy(final_np)

    return final_wav

print("Inference logic defined (excluding reference embedding).")
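The retry loop in `make_predictions` applies two checks per chunk: the decoded duration must be at least `len(chunk_text) / chars_per_second_upper_bound` seconds, and the loudness must exceed `min_valid_audio_volume`. Each failure raises `top_p` by `backoff_top_p_increment`, capped at 1.0, and the raised value carries over to later chunks in the same call because it is stored on `config`. The sketch below isolates the duration floor and the `top_p` schedule in plain Python; `min_chunk_seconds` and `top_p_schedule` are hypothetical helpers, not part of MARS6.

```python
def min_chunk_seconds(chunk_text: str, chars_per_second_upper_bound: float = 32) -> float:
    """Shortest acceptable duration: output implying faster speech is rejected."""
    return len(chunk_text) / chars_per_second_upper_bound


def top_p_schedule(top_p: float = 0.2, increment: float = 0.2, retries: int = 5) -> list:
    """top_p values used on successive retries, capped at 1.0 as in the notebook."""
    values = []
    for _ in range(retries):
        top_p = round(min(top_p + increment, 1.0), 2)
        values.append(top_p)
    return values


print(min_chunk_seconds("Hello from MARS six in shallow clone mode!"))  # 1.3125 (42 chars / 32)
print(top_p_schedule())  # [0.4, 0.6, 0.8, 1.0, 1.0]
```

With the default settings, `top_p` reaches 1.0 (plain nucleus sampling over the full distribution) after four failed attempts, so later retries only resample at that setting.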

## 8) Shallow Clone
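We already have `spk_emb_global` and `clap_emb_global` from step 6. For a shallow clone, leave `reference_text` empty so the reference tokens are never used as a decoder prefix. Setting `deep_clone_mode` to `'none'` gives a pure shallow clone; `'per-chunk'` (used below) also works without a reference transcript: the first chunk is a shallow clone and each later chunk is conditioned on the previously generated one.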
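The branching inside `make_predictions` that decides which tokens, if any, are prepended to the decoder can be restated as a small pure function. This is a hypothetical summary for readability, not MARS6 code: it returns labels instead of tensors (`'reference'` for the reference SNAC tokens, `'previous'` for the tokens of the last generated chunk, `None` for no prefix) and assumes the transcript has already passed the 4 to 300 character check.

```python
from typing import Optional


def select_prefix(deep_clone_mode: str, n_done: int, ref_text: str) -> Optional[str]:
    """Decoder prefix for the next chunk, given how many chunks were already generated.

    Returns 'reference' (reference SNAC tokens), 'previous' (tokens of the last
    generated chunk) or None (no prefix, i.e. a shallow clone for that chunk).
    """
    if deep_clone_mode == "fixed-ref":
        return "reference"
    if deep_clone_mode == "per-chunk":
        if n_done == 0:
            return "reference" if ref_text.strip() else None
        return "previous"
    return None  # 'none': shallow clone throughout


# Per-chunk mode without a transcript: first chunk shallow, later chunks chained.
print([select_prefix("per-chunk", n, "") for n in range(3)])
# [None, 'previous', 'previous']
```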

In [None]:
# config['deep_clone_mode'] = 'none'  # pure shallow clone

config['deep_clone_mode'] = 'per-chunk'  # also capable of operating without a reference transcript...
# ...typically has better flow over chunk boundaries (where chunk boundaries are at sentence splits)

target_text = "Hello from MARS six in shallow clone mode!"
reference_text = ""  # empty: reference tokens are not used as a prefix

t0 = time.time()

out_wav = make_predictions(
    model=model,
    texttok=texttok,
    cfg=config,
    snac_codec=snac_codec,
    text=target_text,
    ref_text=reference_text,
    device=device,
)
t_elapsed = time.time() - t0

print(f"Shallow clone done in {t_elapsed:.2f}s")
torchaudio.save("shallow_clone_output.wav", out_wav.unsqueeze(0).float(), 24000)
print("Saved to shallow_clone_output.wav")

## 9) Deep Clone
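For deep clone, `make_predictions` prepends the reference SNAC tokens from step 6 to the decoder input and the reference transcript to the encoder text. This requires `reference_text` to be the exact transcript of the reference audio (4 to 300 characters). It works in `'fixed-ref'` mode (every chunk uses the reference as its prefix) or `'per-chunk'` mode (only the first chunk uses the reference; later chunks use the previously generated chunk). In `'per-chunk'` mode, an empty `reference_text` means the reference tokens are not used at all.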
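The encoder input for a deep-cloned chunk is the prefix transcript followed by the text to speak, wrapped in the special tokens and the `[48000]` tag that `make_predictions` uses. The sketch below rebuilds that string and the transcript-length check; `build_encoder_text` and `transcript_is_rejected` are hypothetical helpers, and the special-token strings are copied from the notebook rather than from MARS6 documentation.

```python
from typing import Optional


def transcript_is_rejected(ref_text: str) -> bool:
    """Mirror of the notebook's check: non-empty but under 4 chars, or over 300 chars."""
    n = len(ref_text.strip())
    return 0 < n < 4 or n > 300


def build_encoder_text(chunk_text: str, prefix_text: Optional[str] = None,
                       prefix: str = "48000") -> str:
    """Encoder string as formatted in make_predictions, with an optional transcript prefix."""
    body = f"{prefix_text} {chunk_text}" if prefix_text else chunk_text
    return f"<|startoftext|>[{prefix}]{body}<|endoftext|>"


print(build_encoder_text("Now you can hear the same voice.", "This is the reference."))
# <|startoftext|>[48000]This is the reference. Now you can hear the same voice.<|endoftext|>
```

In `'per-chunk'` mode the reference transcript only appears in the first chunk's encoder string; later chunks are prefixed with the previous chunk's text instead.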

In [None]:
config['deep_clone_mode'] = 'per-chunk'
target_text = "Now you can hear the same voice, with matched prosody and better similarity!"
reference_text = ""  # set to the exact transcript of ref_audio_path; if empty, the reference tokens are not used

t0 = time.time()

out_wav = make_predictions(
    model=model,
    texttok=texttok,
    cfg=config,
    snac_codec=snac_codec,
    text=target_text,
    ref_text=reference_text,
    device=device,
)

t_elapsed = time.time() - t0
print(f"Deep clone done in {t_elapsed:.2f}s")
torchaudio.save("deep_clone_output.wav", out_wav.unsqueeze(0).float(), 24000)
print("Saved to deep_clone_output.wav")