In [1]:
""" 
Required module 
uv pip install transformers soundfile
uv pip install neucodec
uv pip install nemo_toolkit[all]
"""
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    Trainer,
    HfArgumentParser, 
    TrainingArguments, 
    AutoModelForSpeechSeq2Seq,
    GenerationConfig, 
    AutoModelForSpeechSeq2Seq, 
    AutoProcessor,
)
import argparse
from dataclasses import dataclass, field, fields
import torch 
import torch.nn as nn
import soundfile as sf
import torchaudio.transforms as T
from neucodec import NeuCodec

Failed to load /venv/main/lib/python3.12/site-packages/torchao/_C_cutlass_90a.abi3.so: Could not load this library: /venv/main/lib/python3.12/site-packages/torchao/_C_cutlass_90a.abi3.so
Failed to load /venv/main/lib/python3.12/site-packages/torchao/_C_mxfp8.cpython-310-x86_64-linux-gnu.so: Could not load this library: /venv/main/lib/python3.12/site-packages/torchao/_C_mxfp8.cpython-310-x86_64-linux-gnu.so


In [7]:
# Download the audio file 
!wget -O sample.wav http://thepodcastexchange.ca/s/Porsche-Macan-July-5-2018-1.mp3

# Resample to 16000
ref_audio, sr = sf.read("sample.wav")
ref_audio = torch.tensor(ref_audio, dtype=torch.float32) # shape (time, channel)
ref_audio = ref_audio.permute(1, 0) # (channel, time)
sampler = T.Resample(orig_freq=sr, new_freq=16000)
ref_audio, sr  = sampler(ref_audio), 16000
ref_audio = ref_audio[0:1, :16000*10]  # 10 seconds with mono
print(ref_audio.numpy())
sf.write("sample_16k.wav", ref_audio.permute(1,0).numpy(), sr)

--2026-02-26 10:00:04--  http://thepodcastexchange.ca/s/Porsche-Macan-July-5-2018-1.mp3
Resolving thepodcastexchange.ca (thepodcastexchange.ca)... 198.185.159.145, 198.49.23.144, 198.49.23.145, ...
Connecting to thepodcastexchange.ca (thepodcastexchange.ca)|198.185.159.145|:80... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://www.thepodcastexchange.ca/s/Porsche-Macan-July-5-2018-1.mp3 [following]
--2026-02-26 10:00:04--  https://www.thepodcastexchange.ca/s/Porsche-Macan-July-5-2018-1.mp3
Resolving www.thepodcastexchange.ca (www.thepodcastexchange.ca)... 198.185.159.144, 198.49.23.145, 198.185.159.145, ...
Connecting to www.thepodcastexchange.ca (www.thepodcastexchange.ca)|198.185.159.144|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://static1.squarespace.com/static/5b1724a7f2e6b17851239b45/t/5e90f949e0d581404b11fbc7/1758035906776/Porsche+Macan+July+5+2018+%281%29.mp3 [following]
--2026-02-26 10:00:05--  

In [2]:
# load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(
    "Scicom-intl/Multilingual-TTS-1.7B-Base"
)
model = AutoModelForCausalLM.from_pretrained(
    "Scicom-intl/Multilingual-TTS-1.7B-Base"
).to(device)
codec = NeuCodec.from_pretrained("neuphonic/neucodec").to(device)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  WeightNorm.apply(module, name, dim)


In [9]:
print(ref_audio.shape)

torch.Size([1, 160000])


In [3]:
# prepare input tokens

ref_audio, sr= sf.read("sample_16k.wav")
ref_audio = torch.tensor(ref_audio, dtype=torch.float32)
ref_audio = ref_audio[None, :].to(device) # (channels, samples)
print(ref_audio[None, :1, :].shape)

with torch.no_grad():
    ref_audio = codec.encode_code(ref_audio[None, :1, :])
    ref_audio = [ f"<|s_{int(codec)}|>" for codec in ref_audio[0,0].cpu().numpy()]
    ref_audio = "".join(ref_audio)
    ref_text = ("If the Porsche Macan has proven anything, it's that the days "
               "of sacrificing performance for practicality are gone, long gone. " )
    target_text = "Hi, what can I do for you?"

input_tokens = tokenizer(
    f"<|im_start|>{ref_text}<|speach_start|>{ref_audio}<|im_end|><|im_start|>{target_text}<|speech_start|>", 
    return_tensors="pt"
).to(device)

torch.Size([1, 1, 160000])


In [4]:
def generate_sequences(model, input_tokens, num_of_sequences=2, temperature=0.9): 
    generation_config = {
        "max_new_tokens": 512,
        "do_sample": True,
        "num_beams": 1,
        "num_return_sequences": num_of_sequences,
        "temperature": temperature,
        "return_dict_in_generate": True,
        "output_scores": True,
    }
    generation_config = GenerationConfig(**generation_config)
    # input_tokens [B, seq_len], with left padding (if B is present, but currently we proceed with single prompt)
    prompt_length = input_tokens["input_ids"].shape[1]
    prompt_completion_ids = model.generate(
        **input_tokens, 
        generation_config=generation_config,
    ) # (B * num_of_sequences, max_length)
    # completion_ids = prompt_completion_ids[:, prompt_length:] 
    return prompt_completion_ids

# generate sampling 
output = generate_sequences(model, input_tokens, num_of_sequences = 4)
completion_ids = output["sequences"][:, input_tokens["input_ids"].shape[1]:]

`generation_config` default values have been modified to match model-specific defaults: {'pad_token_id': 151643, 'eos_token_id': 151645}. If this is not desired, please set these values explicitly.


In [5]:
# generate score is only generated new tokens only
assert output["sequences"].shape[1] - input_tokens["input_ids"].shape[1] == len(output["scores"])

# get the π of the new policy 
print(completion_ids.shape)
new_per_token_log_probs = []
for idx, token_logits in enumerate(output["scores"]):
    token_probs = nn.functional.log_softmax(token_logits, dim=-1).gather(1, completion_ids[:, idx].unsqueeze(-1))
    new_per_token_log_probs.append(token_probs)

new_per_token_log_probs = torch.concat(new_per_token_log_probs, dim=-1)
print(new_per_token_log_probs.shape)

torch.Size([4, 131])
torch.Size([4, 131])


In [6]:
# get the π of the old policy (no grad flow)
with torch.no_grad():
    logits = model(input_ids=output["sequences"]).logits
    print(logits.shape)

    log_probs = nn.functional.log_softmax(logits[: , :-1, : ][: , input_tokens["input_ids"].shape[1]-1: , : ], dim=-1)
old_per_token_log_probs = torch.gather(log_probs , dim=2, index=completion_ids.unsqueeze(-1)).squeeze(-1)


torch.Size([4, 676, 217208])


In [7]:
def get_per_token_logps(
    model, 
    prompt_completion_ids, 
    completion_ids, 
    prompt_length
): 
    logits = model(input_ids=prompt_completion_ids).logits
    log_probs = nn.functional.log_softmax(logits[: , :-1, : ][: , prompt_length-1: , : ], dim=-1)
    per_token_log_probs = torch.gather(log_probs , dim=2, index=completion_ids).squeeze(-1)
    return per_token_log_probs

old_per_token_log_probs = get_per_token_logps(model, 
                                              output["sequences"],
                                              completion_ids.unsqueeze(-1),
                                              input_tokens["input_ids"].shape[1])

In [8]:
torch.exp(new_per_token_log_probs - old_per_token_log_probs)

tensor([[ 7.2119,  2.3171,  2.5029,  5.6480,  2.9633,  4.7772,  3.2172,  2.8953,
          3.0390,  1.4439,  4.9794,  2.1718,  2.6582,  2.2212,  1.6630,  1.6697,
          2.6936,  4.7104,  3.8436,  4.1031,  4.4592,  4.9507,  4.8182,  2.9187,
          3.7008,  2.0347,  3.9604,  3.0819,  5.4510,  4.5021,  2.5907,  2.8581,
          1.7759,  3.2053,  1.7853,  2.6363,  1.8540,  2.4071,  2.0251,  2.9509,
          1.7125,  5.6748,  7.2240,  7.5955,  4.8234,  2.7216,  4.2985,  3.6992,
          5.0761,  2.4303,  3.3964,  2.4699,  3.5842,  6.0286,  3.9399,  3.4583,
          2.7272,  1.7582,  1.9071,  1.9698,  3.0365,  2.6660,  1.4121,  1.9401,
          6.8881,  1.3196,  2.8019,  1.5946,  4.4088,  2.6816,  2.0188,  1.5493,
          1.3882,  1.1498,  1.4245,  2.0464,  3.3222,  2.4209,  2.3710,  3.5925,
          2.8309,  2.7452,  3.4416,  2.1920,  2.3811,  3.3902,  2.9311,  2.4210,
          3.9252,  2.5794,  4.8121,  1.9701,  4.6733,  3.2990,  2.6262,  2.4023,
          2.5189,  2.0174,  

In [8]:
# importance ratio (GSPO)
important_ratio_grpo = torch.exp(new_per_token_log_probs - old_per_token_log_probs)
important_ratio_gspo = torch.exp(new_per_token_log_probs - old_per_token_log_probs).sum(dim=-1)

In [9]:
from IPython.display import Audio, display

Audio("sample_16k.wav")

In [10]:
import re
generated_codecs = [ torch.tensor([int(token) for token in re.findall(r'<\|s_(\d+)\|>',seq)]) for seq in  tokenizer.batch_decode(completion_ids)]

In [11]:
# can't decode in batch for now, unless u know how the silent codecs work for Nuecodec. 
audio_waveforms = []
sampler = T.Resample(24000, 16000)

with torch.no_grad():
    for token in generated_codecs: 
        audio_waveform = codec.decode_code(torch.tensor(token, device=device)[None, None, :]).cpu()
        print(audio_waveform[0].shape)
        audio_waveforms.append(audio_waveform[0][0])
        # display(Audio(audio_waveform[0], rate=24000))

from torch.nn.utils.rnn import pad_sequence
audio_waveforms = pad_sequence(audio_waveforms, batch_first=True, padding_value = 0)[:, None, :]
audio_waveforms = sampler(audio_waveforms)
for batch in audio_waveforms:
   display(Audio(batch, rate=16000))

  audio_waveform = codec.decode_code(torch.tensor(token, device=device)[None, None, :]).cpu()


torch.Size([1, 62400])
torch.Size([1, 43200])
torch.Size([1, 45120])
torch.Size([1, 43680])


In [12]:
# Load the titanet
import nemo.collections.asr as nemo_asr
speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained("nvidia/speakerverification_en_titanet_large")

# sample_16k.wav
# output.wav
with torch.no_grad():
    speaker_model.eval()
    source_emb = speaker_model.get_embedding("sample_16k.wav")
    target_emb = speaker_model(input_signal=audio_waveforms[:,0,:].to(speaker_model.device), 
                        input_signal_length=torch.tensor([audio_waveforms.shape[-1]] * audio_waveforms.shape[0]).to(speaker_model.device))[1]

source_emb_repeat = source_emb.repeat_interleave(4, dim=0)
reward_embedding =  nn.functional.cosine_similarity(source_emb, target_emb, dim=-1)


      text = re.sub("\s+", " ", text)
    
      text = re.sub("\s+\.\s+", ".", text)
    
fused_indices_to_multihot has reached end of life. Please migrate to a non-experimental function.
OneLogger: Setting error_handling_strategy to DISABLE_QUIETLY_AND_REPORT_METRIC_ERROR for rank (rank=0) with OneLogger disabled. To override: explicitly set error_handling_strategy parameter.
No exporters were provided. This means that no telemetry data will be collected.
      m = re.match('([su]([0-9]{1,2})p?) \(([0-9]{1,2}) bit\)$', token)
    
      m2 = re.match('([su]([0-9]{1,2})p?)( \(default\))?$', token)
    
      elif re.match('(flt)p?( \(default\))?$', token):
    
      elif re.match('(dbl)p?( \(default\))?$', token):
    


speakerverification_en_titanet_large.nem(…):   0%|          | 0.00/102M [00:00<?, ?B/s]

[NeMo W 2026-02-26 10:04:33 modelPT:188] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    manifest_filepath: /manifests/combined_fisher_swbd_voxceleb12_librispeech/train.json
    sample_rate: 16000
    labels: null
    batch_size: 64
    shuffle: true
    is_tarred: false
    tarred_audio_filepaths: null
    tarred_shard_strategy: scatter
    augmentor:
      noise:
        manifest_path: /manifests/noise/rir_noise_manifest.json
        prob: 0.5
        min_snr_db: 0
        max_snr_db: 15
      speed:
        prob: 0.5
        sr: 16000
        resample_type: kaiser_fast
        min_speed_rate: 0.95
        max_speed_rate: 1.05
    num_workers: 15
    pin_memory: true
    
[NeMo W 2026-02-26 10:04:33 modelPT:195] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method 

[NeMo I 2026-02-26 10:04:33 save_restore_connector:285] Model EncDecSpeakerLabelModel was successfully restored from /root/.cache/huggingface/hub/models--nvidia--speakerverification_en_titanet_large/snapshots/0dc382f40121a5fbd34db10a2bb04d826c2be6a8/speakerverification_en_titanet_large.nemo.


In [13]:
whisper = AutoModelForSpeechSeq2Seq.from_pretrained(
    "openai/whisper-large-v3", 
    dtype=torch.bfloat16
).to(device)
processor = AutoProcessor.from_pretrained("openai/whisper-large-v3")

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/3.09G [00:00<?, ?B/s]

generation_config.json: 0.00B [00:00, ?B/s]

preprocessor_config.json:   0%|          | 0.00/340 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

normalizer.json: 0.00B [00:00, ?B/s]

added_tokens.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

In [14]:
from jiwer import cer
target_output = []
with torch.no_grad(), torch.autocast(device_type="cuda"):
    for batch in audio_waveforms:
        input_tokens = processor(batch[0, :], 
                             return_tensors="pt",
                             sampling_rate=16000,).to(device)
        predicted_ids = whisper.generate(input_tokens["input_features"])
        target_output.append(processor.tokenizer.decode(predicted_ids[0]))

# normalize the error rate and convert to accurcy for advantage calculation
reward_transcription = torch.tensor([
    1-min(1, cer(target_text, target)) for target in target_output
])
print(reward_transcription)

Using custom `forced_decoder_ids` from the (generation) config. This is deprecated in favor of the `task` and `language` flags/config options.
Transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English. This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`. See https://github.com/huggingface/transformers/pull/28687 for more details.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


tensor([1.0000, 1.0000, 0.9231, 1.0000])


In [15]:
reward_transcription

tensor([1.0000, 1.0000, 0.9231, 1.0000])

In [16]:
reward_embedding = reward_embedding.to('cpu')
reward_embedding

tensor([0.5000, 0.3315, 0.3223, 0.2469])

In [17]:
# GRPO loss 
# advantage 
alpha = 0.5
beta = 0.5
reward_total = alpha *  reward_transcription + beta * reward_embedding
reward_mean = reward_total.mean()
reward_std = reward_total.std()
advantages = (reward_total - reward_mean) / reward_std # shape ( G )
advantages = advantages.unsqueeze(dim=-1).to('cuda')

In [19]:
importance_ratio = torch.exp(new_per_token_log_probs - old_per_token_log_probs)
importance_ratio.shape # B * G , completion tokens
eps = 0.2 
pg_loss1 = -advantages * importance_ratio
pg_loss2 = -advantages * torch.clamp(
    importance_ratio, 1.0 - eps, 1.0 + eps
) 
pg_loss_max = torch.max(pg_loss1, pg_loss2)
pg_loss_max

tensor([[-1.6951e+00, -1.6951e+00, -1.6951e+00, -1.6951e+00, -1.6951e+00,
         -1.6951e+00, -1.6951e+00, -1.6951e+00, -1.6951e+00, -1.6951e+00,
         -1.6951e+00, -1.6951e+00, -1.6951e+00, -1.6951e+00, -1.6951e+00,
         -1.6951e+00, -1.6951e+00, -1.6951e+00, -1.6951e+00, -1.6951e+00,
         -1.6951e+00, -1.6951e+00, -1.6951e+00, -1.6951e+00, -1.6951e+00,
         -1.6951e+00, -1.6951e+00, -1.6951e+00, -1.6951e+00, -1.6951e+00,
         -1.6951e+00, -1.6951e+00, -1.6951e+00, -1.6951e+00, -1.6951e+00,
         -1.6951e+00, -1.6951e+00, -1.6951e+00, -1.6951e+00, -1.6951e+00,
         -1.6951e+00, -1.6951e+00, -1.6951e+00, -1.6951e+00, -1.6951e+00,
         -1.6951e+00, -1.6951e+00, -1.6951e+00, -1.6951e+00, -1.6951e+00,
         -1.6951e+00, -1.6951e+00, -1.6951e+00, -1.6951e+00, -1.6951e+00,
         -1.6951e+00, -1.6951e+00, -1.6951e+00, -1.6951e+00, -1.6951e+00,
         -1.6951e+00, -1.6951e+00, -1.6951e+00, -1.6951e+00, -1.6951e+00,
         -1.6951e+00, -1.6951e+00, -1.

In [20]:
pg_loss_max.mean().backward()

In [22]:
for p in model.parameters():
    print(p.grad)
    break

tensor([[ 4.0244e-10, -1.0103e-09, -1.2042e-09,  ..., -1.1146e-09,
         -1.4462e-09,  1.2131e-09],
        [-6.4211e-11, -1.1283e-10, -5.1102e-10,  ..., -2.1192e-10,
         -2.6248e-10,  1.7344e-10],
        [ 2.1385e-11, -1.4353e-10, -5.2901e-11,  ..., -2.2265e-10,
         -6.5091e-11,  1.5455e-10],
        ...,
        [-7.5713e-11, -2.1944e-10, -1.1003e-09,  ...,  2.3488e-10,
         -4.0670e-10,  4.0832e-10],
        [-1.3343e-11, -1.0683e-11, -4.2819e-11,  ..., -2.6044e-11,
         -2.8077e-11,  1.6469e-11],
        [-1.2141e-11, -1.1699e-11, -4.2174e-11,  ..., -2.1526e-11,
         -2.8543e-11,  1.6002e-11]], device='cuda:0')
