# 1. Load the dataset (TODO: clean up this section)

In [None]:
from datasets import load_dataset
from AudioInfo import Ses, SesFromArray
from tqdm import tqdm
import json
import os

def load_dset(size = 10000, chunk_size: int = 1000):
    """Load the first ``size`` rows of Common Voice 17.0 (Turkish, train split).

    Returns a dict mapping the integer row index to
    ``{"path": <audio file path>, "sentence": <transcript>}``.

    Results are cached in ``0-{size}.json`` (``0-10000.json`` for the default
    size). If the cache exists it is read instead of downloading; otherwise the
    rows are downloaded and the cache is written for the next run. Keys are
    ints in both cases (JSON stores them as strings, so they are converted).

    Args:
        size: number of leading training rows to load.
        chunk_size: rows fetched per ``load_dataset`` call, to limit peak RAM.
    """
    pth = f'0-{size}.json'
    if os.path.exists(pth):
        with open(pth, 'r', encoding='utf-8') as file:
            return {int(k): v for k, v in json.load(file).items()}

    dset = {}
    # Save RAM space by loading in chunk_size-element chunks.
    for start in tqdm(range(0, size, chunk_size)):
        # Clamp so the final chunk never reads past `size`.
        end = min(start + chunk_size, size)
        dataset = load_dataset("mozilla-foundation/common_voice_17_0", "tr", split=f"train[{start}:{end}]")
        for offset, (audio, sentence) in enumerate(zip(dataset['audio'], dataset['sentence'])):
            dset[start + offset] = {"path": audio['path'], "sentence": sentence}

    # Persist the cache so later runs skip the download (the existence check
    # above previously had nothing writing this file).
    with open(pth, 'w', encoding='utf-8') as file:
        json.dump(dset, file, ensure_ascii=False)
    return dset

# Load 10k rows and normalize keys to int (a JSON-cached dataset has str keys).
dset = {int(key): entry for key, entry in load_dset(10000).items()}

# 2. Load the Zonos model

In [2]:
from AudioInfo import Ses, SesFromArray
import torch
from zonos.model import Zonos
from zonos.conditioning import make_cond_dict
from zonos.codebook_pattern import apply_delay_pattern, revert_delay_pattern
from huggingface_hub import hf_hub_download
from zonos.sampling import sample_from_logits
from tqdm import tqdm
# Where to run and which checkpoint/dataset to use.
device = 'cuda'
model_path = "Zyphra/Zonos-v0.1-transformer"
# NOTE(review): not referenced by load_dset, which hard-codes the same name.
dataset_name = "mozilla-foundation/common_voice_17_0"

model = Zonos.from_pretrained(model_path, device=device)

# Guidance scale and sampling options shared by training and inference cells.
cfg_scale = 2
sampling_params = {"min_p": 0.1}
# Passed as `allow_cudagraphs` to model._decode_one_token.
cg = model.can_use_cudagraphs()

Initing zonos for ft


# 3. Functions to load and preprocess input audio

In [None]:
from silero_vad import load_silero_vad, get_speech_timestamps
# Silero voice-activity-detection model; get_stamps() passes it to
# get_speech_timestamps to find where speech starts and ends in a clip.
silero_model = load_silero_vad()

def audio_to_prefix_code(arr, sr):
    """Encode a raw waveform into Zonos autoencoder codes for use as a prefix.

    Args:
        arr: waveform samples (presumably a 1-D mono array -- TODO confirm).
        sr: sample rate of ``arr`` in Hz, forwarded to the autoencoder's
            ``preprocess``.

    Returns:
        The result of ``model.autoencoder.encode`` on the batched waveform.
    """
    # Prepend a channel axis: (T,) -> (1, T) for 1-D input.
    waveform = torch.tensor(arr, dtype=torch.float32).unsqueeze(0)
    # NOTE(review): the leading axis was just added with size 1, so this mean
    # is a no-op and does not mix multi-channel audio down to mono.
    waveform = waveform.mean(0, keepdim=True)
    waveform = model.autoencoder.preprocess(waveform, sr)
    waveform = waveform.to(device, dtype=torch.float32)
    # Prepend a batch axis before encoding.
    batched = waveform.unsqueeze(0)
    return model.autoencoder.encode(batched)

def ses_to_prefix_code(ses: Ses):
    """Encode a ``Ses`` clip into prefix codes, or return an empty prefix.

    A clip with no samples, or whose samples are all zero, is treated as an
    empty prefix: a message is printed and a zero-frame code tensor is
    returned so generation can start from scratch.

    Returns:
        Codes from ``audio_to_prefix_code(ses.arr, ses.sr)``, or an empty int64
        tensor of shape (1, 9, 0) on ``model.device``.
    """
    samples = torch.as_tensor(ses.arr)
    # One vectorized reduction. The previous builtin `sum(tensor)` iterated
    # sample by sample in Python (very slow for audio), raised on
    # multi-dimensional arrays, and misfired on non-silent clips summing to 0.
    if torch.any(samples != 0):
        return audio_to_prefix_code(ses.arr, ses.sr)
    print("passed empty prefix")
    # (batch=1, 9 code rows -- presumably the codebook count, 0 frames).
    return torch.zeros((1, 9, 0), dtype=torch.long, device=model.device)

def get_stamps(ses:Ses):
    """Return the (start, end) time in seconds spanning all detected speech.

    The clip is resampled to 16 kHz for Silero VAD if needed. The start of the
    first speech segment and the end of the last one are returned, so any
    silence between segments is kept.

    Returns:
        ``(start_seconds, end_seconds)``, or ``False`` if no speech is found.
    """
    vad_sr = 16000
    if ses.sr != vad_sr:
        ses = ses.resampled(vad_sr)
    # Ask for sample indices and convert here: with return_seconds=True,
    # silero-vad rounds the timestamps (0.1 s by default), which loses precision.
    speech_timestamps = get_speech_timestamps(
        torch.tensor(ses.arr, dtype=torch.float32),
        silero_model,
        sampling_rate=vad_sr,
    )
    if not speech_timestamps:
        return False
    return speech_timestamps[0]['start'] / vad_sr, speech_timestamps[-1]['end'] / vad_sr

def idx_to_condition(idx: int, limit=15, language='tr'):
    """Build the audio prefix codes and text/speaker conditioning for a row.

    Steps:
    1. Load ``dset[idx]['path']`` and resample it to 44.1 kHz.
    2. Truncate the clip to at most ``limit`` seconds.
    3. Trim it to the span of detected speech. If VAD finds no speech, the
       whole (truncated) clip is kept.
    4. Build the Zonos conditioning from ``dset[idx]['sentence']`` and a
       speaker embedding of the trimmed clip.

    Args:
        idx: key into the global ``dset``.
        limit: maximum clip length in seconds.
        language: language code passed to ``make_cond_dict``.

    Returns:
        ``(prefix_codes, prefix_conditioning)``.
    """
    ses = Ses(dset[idx]['path']).resampled(44100)
    if ses.duration_ > limit:
        ses = ses.trimmed(0, limit)

    stamps = get_stamps(ses)
    # get_stamps returns False when no speech is detected; previously the
    # tuple-unpack crashed here. Keep the whole clip in that case.
    if stamps:
        start, end = stamps
        # trimmed() takes (start, duration), not (start, end).
        ses = ses.trimmed(start, end - start)
    sentence = dset[idx]['sentence']

    waveform = torch.tensor(ses.arr, dtype=torch.float32).squeeze()
    new_condt = make_cond_dict(
        text=sentence,
        speaker=model.make_speaker_embedding(waveform, ses.sr),
        language=language,
    )

    prefix_codes = ses_to_prefix_code(ses)
    prefix_conditioning = model.prepare_conditioning(new_condt)

    return prefix_codes, prefix_conditioning

# 4. Prefill with the input audio codes

In [None]:
def prepare_prefix(codes: "torch.Tensor | Ses", prefix_conditioning=None, trim:None|int=None):
    """Prefill the model's KV cache with conditioning plus audio-prefix codes.

    Also samples a token into the first unknown frame after the prefilled part.

    Args:
        codes: prefix audio codes, or a ``Ses`` clip that is first encoded with
            ``ses_to_prefix_code``.
        prefix_conditioning: output of ``model.prepare_conditioning``. It is
            required despite the ``None`` default.
        trim: if given, prefill delayed positions ``[0, trim]`` instead of
            ``[0, len(codes)]``. The training loop uses this to choose the
            prediction point.

    Returns:
        A tuple ``(p_delayed, inference_params, offset, logit_bias, to_compare)``:
        - ``p_delayed``: padded, delay-patterned codes.
        - ``inference_params``: the cache, advanced past the prefilled tokens.
        - ``offset``: number of prefilled delayed positions.
        - ``logit_bias``: ``-inf`` on EOS for code rows 1 and up.
        - ``to_compare``: delayed frame ``offset + 1``.

    Raises:
        TypeError: if ``prefix_conditioning`` or the (encoded) ``codes`` is not
            a ``torch.Tensor``.
    """
    if not isinstance(prefix_conditioning, torch.Tensor):
        raise TypeError("prefix_conditioning must be a torch.Tensor")
    if isinstance(codes, Ses):
        codes = ses_to_prefix_code(codes) # if a `Ses` instance, process
    if not isinstance(codes, torch.Tensor):
        raise TypeError("codes must be a torch.Tensor or Ses")
    prefix_codes = codes
    p_len = prefix_codes.size(-1)
    # Room for 30 s of frames (86 per second) after the prefix, plus 9 extra
    # positions for the delay pattern.
    max_length = 86 * 30
    seq_len = p_len + max_length + 9
    unknown_token = -1
    with torch.device(model.device):
        batch_size__ = 1
        # Doubled batch -- presumably conditional + unconditional rows for
        # classifier-free guidance (TODO confirm against Zonos.setup_cache).
        inference_params = model.setup_cache(batch_size=batch_size__ * 2, max_seqlen=seq_len)
    # Inference mode (no gradients are needed)
    with torch.no_grad():
        # Pad with unknown tokens so generated frames have somewhere to go.
        p_padded = torch.nn.functional.pad(prefix_codes, (0, max_length + p_len), value=unknown_token)
        p_delayed = apply_delay_pattern(p_padded, mask_token=model.masked_token_id)
        pred_idx = p_len if trim is None else trim
        # Use the shared cfg_scale (previously a hard-coded 2) so prefill and
        # the later _decode_one_token calls always agree.
        logits = model._prefill(prefix_conditioning,
                                p_delayed[..., :pred_idx + 1],
                                inference_params, cfg_scale)
        next_token = sample_from_logits(logits, **sampling_params)

        # Fill only the still-unknown slots of the next frame; this writes
        # through the view into p_delayed.
        frame = p_delayed[..., pred_idx + 1:pred_idx + 2]
        frame.masked_scatter_(frame == unknown_token, next_token)
    offset = p_delayed[..., :pred_idx + 1].size(-1)
    # Forbid EOS on every code row except the first.
    logit_bias = torch.zeros_like(logits)
    logit_bias[:, 1:, model.eos_token_id] = -torch.inf
    # Advance the cache past the conditioning tokens and the prefilled frames.
    prefices_length = prefix_conditioning.shape[1] + pred_idx + 1
    inference_params.seqlen_offset += prefices_length
    inference_params.lengths_per_sample[:] += prefices_length
    to_compare = p_delayed[..., offset + 1:offset + 2]
    return p_delayed, inference_params, offset, logit_bias, to_compare

# 5. Training loop

In [None]:
from torch.optim import AdamW
import random

# Params
loss_list = []
total_loss = 0
total_loss_ctr = 0

# loss_per_second: how many loss/optimizer steps to take per 86 frames
# (one second) of input codes.
loss_per_second = 1
# Frame stride between consecutive loss steps within one clip.
step_stride = int(86 / loss_per_second)
counter = 0
desc = ''
progress = tqdm(dset.keys(), desc=f"desc: {desc}")
STAYED_AT = 0  # first dataset key to train on; set > 0 to resume after a restart

# Optimizer & scheduler
learning_rate = 1e-4
optimizer = AdamW(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=learning_rate,
    epochs=1,
    # Assumes each input is ~5 seconds on average -- TODO: derive from the data.
    steps_per_epoch=(len(dset) - STAYED_AT) * loss_per_second * 5,
    pct_start=0.1,
)

# Training loop
for k in progress:
    if k < STAYED_AT:
        continue
    # Grab input audio: prefix codes and conditioning for this clip.
    codes, cond = idx_to_condition(k, limit=25)
    duration = codes.size(-1)

    batch_loss = 0
    batch_loss_ctr = 0
    batch_size = len(range(0, duration, step_stride))

    for duration_idx in range(0, duration, step_stride):
        # Don't take the first and last few frames into account for now.
        if duration_idx < 10 or duration - duration_idx < 5:
            continue
        # Random prediction point in [9, duration - 9]. random.randint raises
        # ValueError on an empty range (clips under 18 frames), in which case
        # fall back to the current stride position.
        if duration >= 18:
            random_index = random.randint(9, duration - 9)
        else:
            random_index = duration_idx

        optimizer.zero_grad()  # reset gradients for each step

        # 1. Prefill the cache with conditioning + codes up to random_index.
        with torch.no_grad():
            delayed_codes, inference_params, offset, logit_bias, to_compare = prepare_prefix(codes, cond, random_index)
            offset += 1

        # 2. Decode one step with gradients enabled.
        input_ids = delayed_codes[..., offset - 1 : offset]
        # Plain CUDA-graph replay builds no autograd graph, so always take the
        # eager path while training (cg may be False here anyway).
        logits = model._decode_one_token(input_ids, inference_params, cfg_scale, allow_cudagraphs=False)

        # 3. Compute loss against the next delayed frame.
        loss = torch.nn.functional.cross_entropy(logits.squeeze(), to_compare.squeeze())
        loss_value = loss.item()

        # 4. Bookkeeping / progress display
        total_loss += loss_value
        total_loss_ctr += 1
        batch_loss += loss_value
        batch_loss_ctr += 1

        loss_list.append(loss_value)
        # set_description refreshes the bar. The old manual progress.update()
        # double-counted, because iterating `progress` already advances it once per key.
        progress.set_description(f"now @ key {k} & {batch_loss_ctr+1}/{batch_size} | idx @ {random_index}/{duration-9} | total processed: {total_loss_ctr} | loss: {loss_value} | batch avg loss: {batch_loss/batch_loss_ctr} |total avg loss: {total_loss/total_loss_ctr}")

        # 5. Backward & update
        loss.backward()
        del loss
        # Memory was observed to accumulate; note prepare_prefix allocates a
        # fresh KV cache every step. TODO: dig into this.
        torch.cuda.empty_cache()
        optimizer.step()
        # OneCycleLR raises ValueError once stepped past total_steps, and
        # total_steps is only an estimate (see above): stop scheduling instead
        # of crashing mid-training.
        if scheduler.last_epoch < scheduler.total_steps:
            scheduler.step()

    # Save loss history (best-effort: a failed write must not stop training).
    try:
        with open('loss_data.json', 'w', encoding='utf-8') as file:
            json.dump(loss_list, file)
    except OSError as err:
        print(f"could not save loss_data.json: {err}")

### Inference

In [None]:
# Text conditioning for the utterance to synthesize.
new_condt = make_cond_dict(
    language='tr',
    text="Merhaba beyefendi, adınız nedir?",
)

# Audio-completion path: encode a reference clip as a prefix.
# NOTE(review): prefix_codes is computed here but not passed to prepare_prefix below.
ses_ = Ses('sample.wav')
prefix_codes = ses_to_prefix_code(ses_)
prefix_conditioning = model.prepare_conditioning(new_condt)

# Generate without an audio prefix: a zero-length clip encodes to an empty prefix.
empty_ses = SesFromArray(torch.empty(0, dtype=torch.float64).numpy(), 6)
with torch.no_grad():
    delayed_codes, inference_params, offset, logit_bias, to_compare = prepare_prefix(
        empty_ses,
        prefix_conditioning,
    )

passed empty prefix


In [None]:
# Seconds of audio to generate (the loop runs 86 decoding steps per second).
SECONDS_TO_GENERATE = 3

# Autoregressive decoding that continues from the state returned by
# prepare_prefix (delayed_codes, inference_params, offset, logit_bias).
# Statement order matters: offset, the frame write and the cache counters
# must all advance by one per step.
with torch.no_grad():
    for _ in tqdm(range(86*SECONDS_TO_GENERATE)):
        # Increase offset: the frame at `offset` is the one to fill this step.
        offset += 1

        # Calculate the next logits from the previous frame.
        input_ids = delayed_codes[..., offset - 1 : offset] # previous frame; one observed value: tensor([ 698, 1025, ..., 1025], device='cuda:0'), where 698 was element 0 of the previously sampled next_token
        logits = model._decode_one_token(input_ids, inference_params, cfg_scale, allow_cudagraphs=cg) # observed shape: torch.Size([1, 9, 1026])
        logits += logit_bias # forbids EOS on code rows 1+; per the original note, the trailing [1025] entries already had -inf from _decode_one_token
        next_token = sample_from_logits(logits, generated_tokens=delayed_codes[..., :offset], **sampling_params)

        # Append the new token: fill only the still-unknown (-1) slots of this frame.
        frame = delayed_codes[..., offset : offset + 1]
        frame.masked_scatter_(frame == -1, next_token)

        # Advance the inference cache position by one step.
        inference_params.seqlen_offset += 1
        inference_params.lengths_per_sample[:] += 1

100%|██████████| 258/258 [00:21<00:00, 12.22it/s]


In [None]:
def codes_to_waveform(delayed_codes: torch.Tensor, offset: int) -> torch.Tensor:
    """Undo the delay pattern on generated codes and decode them to audio.

    Args:
        delayed_codes: delay-patterned code tensor filled by the generation loop.
        offset: last written delayed position in ``delayed_codes``.

    Returns:
        The squeezed float64 CPU waveform from ``model.autoencoder.decode``.
    """
    with torch.no_grad():
        out_codes = revert_delay_pattern(delayed_codes)
        # Replace ids >= 1024 (presumably mask/special tokens outside the
        # codebook) with 0 so the autoencoder can decode them.
        out_codes.masked_fill_(out_codes >= 1024, 0)
        # Drop the trailing positions -- presumably the 9-step delay tail.
        out_codes = out_codes[..., : offset - 9]
        print(out_codes.shape)
        return model.autoencoder.decode(out_codes).squeeze().to(torch.float64).cpu()


decodedarr = codes_to_waveform(delayed_codes, offset)
# Write at the autoencoder's native sample rate instead of a hard-coded 44100.
SesFromArray(decodedarr.numpy(), model.autoencoder.sampling_rate).write('turkish demo.wav')

torch.Size([1, 9, 250])
