# 1. Load the Zonos model

In [None]:
from AudioInfo import Ses, SesFromArray
import torch
from zonos.model import Zonos
from zonos.conditioning import make_cond_dict
from zonos.codebook_pattern import apply_delay_pattern, revert_delay_pattern
from huggingface_hub import hf_hub_download
from zonos.sampling import sample_from_logits
from tqdm import tqdm

# Hugging Face repo id of the Zonos checkpoint to load.
model_path = "Zyphra/Zonos-v0.1-transformer"
device = "cuda"

model = Zonos.from_pretrained(model_path, device=device)

# Scale handed to the model's prefill/decode calls
# (presumably the classifier-free-guidance strength).
cfg_scale = 2
# Extra keyword arguments forwarded to `sample_from_logits`.
sampling_params = {"min_p": 0.1}
# Whether `_decode_one_token` may use CUDA graphs, as reported by the model.
cg = model.can_use_cudagraphs()

Initing zonos for ft


# 2. Functions to load and preprocess input audio

In [None]:
from silero_vad import load_silero_vad, get_speech_timestamps
# Silero voice-activity-detection model, loaded once and reused by `get_stamps`.
silero_model = load_silero_vad()

def audio_to_prefix_code(arr, sr):
    '''Encode a raw waveform into Zonos prefix codes.

    Args:
        arr: waveform samples. A 1-D array is treated as mono. A 2-D array is
            downmixed to mono by averaging over its channel axis; the shorter
            axis is taken to be the channel axis (presumably channels-first
            like ``torchaudio.load`` -- TODO confirm against ``Ses.arr``).
        sr: sampling rate of ``arr`` in Hz, passed to
            ``model.autoencoder.preprocess``.

    Returns:
        The result of ``model.autoencoder.encode`` for a ``(1, 1, T)`` float32
        waveform on ``model.device``.

    Raises:
        ValueError: if ``arr`` has more than two dimensions.
    '''
    wav = torch.as_tensor(arr, dtype=torch.float32)
    if wav.ndim == 1:
        wav = wav.unsqueeze(0)  # (T,) -> (1, T)
    elif wav.ndim == 2:
        # The previous `unsqueeze(0).mean(0)` averaged over a singleton axis,
        # so multi-channel audio was never actually downmixed.
        if wav.shape[0] > wav.shape[-1]:
            wav = wav.T  # looks like (T, C); make it channels-first
        wav = wav.mean(0, keepdim=True)  # (C, T) -> (1, T)
    else:
        raise ValueError(f"expected 1-D or 2-D audio, got shape {tuple(wav.shape)}")
    wav = model.autoencoder.preprocess(wav, sr)
    wav = wav.to(model.device, dtype=torch.float32)
    return model.autoencoder.encode(wav.unsqueeze(0))

def ses_to_prefix_code(ses: Ses):
    '''Generate prefix codes from a ``Ses`` (audio).

    A ``Ses`` with no samples, or whose samples are all zero, is treated as an
    empty prefix. In that case "passed empty prefix" is printed and a
    ``(1, 9, 0)`` int64 tensor on ``model.device`` is returned without running
    the autoencoder. Otherwise the audio is encoded with
    ``audio_to_prefix_code``.
    '''
    samples = torch.as_tensor(ses.arr)
    # Vectorised emptiness check. The former `sum(torch.tensor(...)) != 0`
    # looped over every sample in Python (very slow for real audio), raised
    # on multi-channel arrays, and treated non-silent audio whose samples
    # happened to cancel out to 0 as "empty".
    if samples.numel() == 0 or not bool(samples.any()):
        print("passed empty prefix")
        # (batch=1, 9 rows -- presumably one per codebook, 0 frames)
        return torch.zeros((1, 9, 0), dtype=torch.long, device=model.device)
    return audio_to_prefix_code(ses.arr, ses.sr)

def get_stamps(ses: Ses):
    '''Get the span of ``ses`` that contains human voice.

    Runs Silero VAD on a 16 kHz version of the audio.

    Returns:
        ``(start_seconds, end_seconds)``, from the start of the first detected
        speech segment to the end of the last one, or ``False`` when no speech
        is detected.
    '''
    vad_sr = 16000
    if ses.sr != vad_sr:
        ses = ses.resampled(vad_sr)
    # Request sample indices and convert to seconds here. With
    # return_seconds=True Silero rounds the timestamps (to one decimal place
    # by default), which is too coarse for precise trimming.
    speech_timestamps = get_speech_timestamps(
        torch.tensor(ses.arr, dtype=torch.float32),
        silero_model,
        sampling_rate=vad_sr,
    )
    if not speech_timestamps:
        return False
    start = speech_timestamps[0]['start'] / vad_sr
    end = speech_timestamps[-1]['end'] / vad_sr
    return start, end

# 3. Prefill with the input audio codes

In [None]:
def prepare_prefix(codes: torch.Tensor | Ses, prefix_conditioning=None, trim: None | int = None,
                   max_length: int = 86 * 30):
    '''Prefill the model's cache with conditioning + audio-prefix codes.

    Args:
        codes: prefix audio codes (e.g. from ``ses_to_prefix_code``), or a
            ``Ses``, which is encoded first.
        prefix_conditioning: output of ``model.prepare_conditioning``; required.
        trim: if given, prefill only up to code frame ``trim`` instead of the
            whole prefix.
        max_length: number of code frames reserved after the prefix for
            generation. The default 86 * 30 is about 30 s at 86 frames/s.

    Returns:
        ``(p_delayed, inference_params, offset, logit_bias, to_compare)``:
        - ``p_delayed``: the delay-patterned code buffer, with -1 at
          still-unknown positions.
        - ``inference_params``: the prefilled cache parameters.
        - ``offset``: the index of the frame just filled.
        - ``logit_bias``: a bias that forbids EOS on codebook rows 1+.
        - ``to_compare``: the frame at ``offset + 1``.
    '''
    assert isinstance(prefix_conditioning, torch.Tensor), "prefix_conditioning is required"
    if isinstance(codes, Ses):
        codes = ses_to_prefix_code(codes)  # raw audio -> codes
    assert isinstance(codes, torch.Tensor)
    prefix_codes = codes
    p_len = prefix_codes.size(-1)
    unknown_token = -1

    # The cache stores the conditioning positions as well as the audio frames
    # (see the seqlen_offset update below), so its length must include
    # prefix_conditioning.shape[1]; omitting it could overflow the cache near
    # max_length. The trailing +9 presumably covers the delay-pattern tail.
    seq_len = prefix_conditioning.shape[1] + p_len + max_length + 9
    with torch.device(model.device):
        batch_size = 1
        # x2: presumably conditional + unconditional rows for guidance.
        inference_params = model.setup_cache(batch_size=batch_size * 2, max_seqlen=seq_len)

    with torch.no_grad():
        # Reserve exactly max_length unknown frames after the prefix, matching
        # the cache size above (the old padding added an extra p_len frames
        # that the cache could not hold).
        p_padded = torch.nn.functional.pad(prefix_codes, (0, max_length), value=unknown_token)
        p_delayed = apply_delay_pattern(p_padded, mask_token=model.masked_token_id)
        pred_idx = p_len if trim is None else trim
        logits = model._prefill(prefix_conditioning,
                                p_delayed[..., :pred_idx + 1],
                                inference_params, cfg_scale)
        next_token = sample_from_logits(logits, **sampling_params)

        # `frame` is a view into p_delayed: only still-unknown slots of the
        # next frame receive the sampled token.
        frame = p_delayed[..., pred_idx + 1:pred_idx + 2]
        frame.masked_scatter_(frame == unknown_token, next_token)

    offset = p_delayed[..., :pred_idx + 1].size(-1)
    logit_bias = torch.zeros_like(logits)
    logit_bias[:, 1:, model.eos_token_id] = -torch.inf
    # Advance the cache past everything that was just prefilled.
    prefix_length = prefix_conditioning.shape[1] + pred_idx + 1
    inference_params.seqlen_offset += prefix_length
    inference_params.lengths_per_sample[:] += prefix_length
    to_compare = p_delayed[..., offset + 1:offset + 2]
    return p_delayed, inference_params, offset, logit_bias, to_compare

# 4. Inference

- Prepare prefix for inference

In [None]:
# Text/language conditioning for the utterance to synthesise.
new_condt = make_cond_dict(
        text= "Greetings, lad. What would your name be?",
        language='en-us',
    )

# Audio completion: encode `sample.wav` into prefix codes.
# NOTE(review): `prefix_codes` is computed here but is NOT what is passed to
# `prepare_prefix` below (the empty prefix is). To continue `sample.wav`,
# pass `ses_` (or `prefix_codes`) as the first argument of `prepare_prefix`.
ses_ = Ses('sample.wav')
prefix_codes = ses_to_prefix_code(ses_)
prefix_conditioning = model.prepare_conditioning(new_condt)

# Empty input, for generating audio with no prefix.
# If you have no input audio to start from and want to generate from scratch,
# use this empty Ses. `ses_to_prefix_code` does not read the sample rate (6)
# for an empty array.
empty_ses = SesFromArray(torch.tensor(()).to(torch.float64).numpy(), 6)
with torch.no_grad():
    delayed_codes, inference_params, offset, logit_bias, to_compare = prepare_prefix(empty_ses, prefix_conditioning)

passed empty prefix


- Generation loop (one second of audio = 86 token frames)

In [None]:
# Autoregressively fill `delayed_codes` one frame at a time, 86 frames per
# second of audio. Continues from the `offset` / `inference_params` left by
# `prepare_prefix` (or by a previous run of this cell).
# NOTE(review): no EOS handling and no bounds check -- the loop always runs
# the full step count; keep it within the frames reserved by prepare_prefix.
SECONDS_TO_GENERATE = 3

with torch.no_grad():
    for _ in tqdm(range(86*SECONDS_TO_GENERATE)):
        # Advance to the frame being predicted
        offset += 1

        # Compute logits for the next frame from the previous frame
        input_ids = delayed_codes[..., offset - 1 : offset] # author's observed example: tensor([ 698, 1025, ..., 1025], device='cuda:0'), shape [9, 1]; 698 was element 0 of the previously sampled next_token.
        logits = model._decode_one_token(input_ids, inference_params, cfg_scale, allow_cudagraphs=cg) # observed shape: torch.Size([1, 9, 1026])
        logits += logit_bias # author's note: in _decode_one_token's output the last entries [the 1025s] already had -inf probability.
        next_token = sample_from_logits(logits, generated_tokens=delayed_codes[..., :offset], **sampling_params)

        # Write the new token into the still-unknown (-1) slots of this frame (frame is a view into delayed_codes)
        frame = delayed_codes[..., offset : offset + 1]
        frame.masked_scatter_(frame == -1, next_token)

        # Advance the cache positions by one step
        inference_params.seqlen_offset += 1
        inference_params.lengths_per_sample[:] += 1

100%|██████████| 258/258 [00:21<00:00, 12.22it/s]


In [None]:
def decode_and_save(delayed_codes, offset, path='demo.wav', sample_rate=44100):
    '''Decode a delay-patterned code buffer to audio and write it to ``path``.

    Args:
        delayed_codes: the buffer filled by the generation loop.
        offset: index of the last frame written; the un-delayed codes are cut
            to ``offset - 9`` frames (presumably dropping the delay-pattern tail).
        path: output audio file.
        sample_rate: sample rate given to ``SesFromArray`` when writing.

    Returns:
        ``(out_codes, decodedarr)``: the codes that were decoded, and the
        decoded waveform as a float64 CPU tensor.
    '''
    with torch.no_grad():
        out_codes = revert_delay_pattern(delayed_codes)
        # Values >= 1024 are not regular audio codes (presumably EOS/mask ids);
        # map them to 0 before decoding.
        out_codes.masked_fill_(out_codes >= 1024, 0)
        out_codes = out_codes[..., : offset - 9]
        print(out_codes.shape)
        decodedarr = model.autoencoder.decode(out_codes).squeeze().to(torch.float64).cpu()
        SesFromArray(decodedarr.numpy(), sample_rate).write(path)
    return out_codes, decodedarr


# Keep the notebook globals the original cell produced.
out_codes, decodedarr = decode_and_save(delayed_codes, offset)

torch.Size([1, 9, 250])
