In [None]:
!pip install torch torchaudio onnx onnxruntime nemo_toolkit[all]



In [None]:
import torch
import nemo.collections.asr as nemo_asr

# üîπ Carregar modelo pr√©-treinado
asr_model = nemo_asr.models.EncDecCTCModel.from_pretrained(
    model_name="neongeckocom/stt_pt_citrinet_512_gamma_0_25"
)

# üîπ Combinar encoder + decoder manualmente
class CitrinetCore(torch.nn.Module):
    def __init__(self, nemo_model):
        super().__init__()
        self.encoder = nemo_model.encoder
        self.decoder = nemo_model.decoder

    def forward(self, features, features_len):
        encoded, encoded_len = self.encoder(audio_signal=features, length=features_len)
        logits = self.decoder(encoder_output=encoded)
        return logits, encoded_len

core_model = CitrinetCore(asr_model)
core_model.eval()

# üîπ Dummy input (como o featurizer produziria)
dummy_feat = torch.randn(1, 80, 200, dtype=torch.float32)   # [B, mel_bins, frames]
dummy_len = torch.tensor([200], dtype=torch.int64)

# üîπ Exportar para ONNX
torch.onnx.export(
    core_model,
    (dummy_feat, dummy_len),
    "citrinet_encoder_decoder.onnx",
    input_names=["features", "features_len"],
    output_names=["logits", "encoded_len"],
    dynamic_axes={
        "features": {2: "n_frames"},
        "logits": {1: "n_frames"},
    },
    opset_version=14,
)

print("‚úÖ Exportado: citrinet_encoder_decoder.onnx")


W1025 18:45:04.139000 22790 torch/utils/cpp_extension.py:118] No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'


[NeMo I 2025-10-25 18:45:24 nemo_logging:393] Tokenizer SentencePieceTokenizer initialized with 256 tokens


[NeMo W 2025-10-25 18:45:24 nemo_logging:405] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    manifest_filepath: datasets/cv_neon/manifests/commonvoice_train_manifest_processed.json
    sample_rate: 16000
    batch_size: 32
    trim_silence: false
    max_duration: 9.0
    min_duration: 1.0
    shuffle: true
    use_start_end_token: false
    num_workers: 8
    pin_memory: true
    is_tarred: false
    tarred_audio_filepaths: null
    shuffle_n: 2048
    bucketing_strategy: synced_randomized
    bucketing_batch_size: null
    
[NeMo W 2025-10-25 18:45:24 nemo_logging:405] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s). 
    Validation config : 
    manifest_filepath: datasets/cv_neon/

[NeMo I 2025-10-25 18:45:24 nemo_logging:393] PADDING: 16
[NeMo I 2025-10-25 18:45:26 nemo_logging:393] Model EncDecCTCModelBPE was successfully restored from /root/.cache/huggingface/hub/models--neongeckocom--stt_pt_citrinet_512_gamma_0_25/snapshots/ea95a18b0eaa1ccaf86faa209dc5c72a4325df51/stt_pt_citrinet_512_gamma_0_25.nemo.
‚úÖ Exportado: citrinet_encoder_decoder.onnx


In [None]:
# Extrair tokens do modelo NeMo
tokens = asr_model.decoder.vocabulary  # lista de strings, cada token

# Salvar em tokens.txt
with open("tokens.txt", "w", encoding="utf-8") as f:
    for token in tokens:
        f.write(token + "\n")

print("‚úÖ tokens.txt criado com sucesso!")

‚úÖ tokens.txt criado com sucesso!


In [4]:
from omegaconf import OmegaConf
import yaml

# Extrair a estrutura como dicion√°rio nativo do Python
config_dict = OmegaConf.to_container(asr_model.cfg, resolve=True)

# Salvar como YAML
with open("model_config.yaml", "w", encoding="utf-8") as f:
    yaml.dump(config_dict, f, allow_unicode=True)

print("‚úÖ Arquivo 'model_config.yaml' exportado com sucesso!")


‚úÖ Arquivo 'model_config.yaml' exportado com sucesso!


In [None]:
import onnx
import numpy as np
import torchaudio
import torch

# Caminho para o modelo exportado
onnx_path = "citrinet_encoder_decoder.onnx"

# 1Ô∏è‚É£ Carregar o modelo ONNX
model = onnx.load(onnx_path)

print("=== ENTRADAS DO MODELO ===")
for inp in model.graph.input:
    dims = [d.dim_value for d in inp.type.tensor_type.shape.dim]
    dtype = inp.type.tensor_type.elem_type
    print(f"- {inp.name} | shape: {dims} | dtype: {dtype}")

print("\n=== SA√çDAS DO MODELO ===")
for out in model.graph.output:
    dims = [d.dim_value for d in out.type.tensor_type.shape.dim]
    dtype = out.type.tensor_type.elem_type
    print(f"- {out.name} | shape: {dims} | dtype: {dtype}")

# 2Ô∏è‚É£ Verificar se o modelo cont√©m o decoder
# Dica: modelos com decoder (CTC) costumam ter "logits" ou "probs" na sa√≠da
output_names = [out.name.lower() for out in model.graph.output]
if any("logits" in n or "probs" in n or "ctc" in n for n in output_names):
    print("\n‚úÖ Este modelo cont√©m o DECODER (CTC head).")
else:
    print("\n‚ö†Ô∏è Este modelo provavelmente cont√©m apenas o ENCODER.")

# 3Ô∏è‚É£ Pr√©-processamento esperado

def preprocess_audio(filepath, sample_rate=16000):
    # Carregar √°udio
    waveform, sr = torchaudio.load(filepath)
    if sr != sample_rate:
        waveform = torchaudio.functional.resample(waveform, sr, sample_rate)

    # Converter para mono
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Normalizar (amplitude entre -1 e 1)
    waveform = waveform / torch.abs(waveform).max()

    # Gerar espectrograma mel
    mel_spec = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=512,
        win_length=400,
        hop_length=160,
        n_mels=80
    )(waveform)

    # Converter para log-mel
    log_mel_spec = torch.log(mel_spec + 1e-6)

    print("\n=== Pr√©-processamento conclu√≠do ===")
    print(f"Waveform shape: {waveform.shape}")
    print(f"Log-mel shape: {log_mel_spec.shape}")
    print(f"Exemplo de entrada esperado: (batch=1, n_mels=80, time={log_mel_spec.shape[-1]})")

    return log_mel_spec

# Exemplo de uso:
# log_mel = preprocess_audio("teste.wav")


=== ENTRADAS DO MODELO ===
- features | shape: [1, 80, 0] | dtype: 1
- features_len | shape: [1] | dtype: 7

=== SA√çDAS DO MODELO ===
- logits | shape: [1, 0, 257] | dtype: 1
- encoded_len | shape: [1] | dtype: 7

‚úÖ Este modelo cont√©m o DECODER (CTC head).


In [None]:
import onnxruntime as ort

session = ort.InferenceSession("citrinet_encoder_decoder.onnx", providers=["CPUExecutionProvider"])

input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

print("Input name:", input_name)
print("Output name:", output_name)


Input name: features
Output name: logits
