# Guide Complet Wav2Vec2: Entraînement, Conversion et Inférence

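Ce notebook regroupe les étapes suivantes :
1. Installation des dépendances
2. Entraînement du modèle (non couvert ici : le notebook part d'un modèle Wav2Vec2 déjà affiné pour la CTC)
3. Conversion en ONNX (et quantification dynamique)
4. Inférence avec PyTorch et ONNX (l'inférence en temps réel depuis le micro n'est pas encore implémentée)
5. Évaluation des performances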

## 1. Installation des dépendances

In [None]:
!pip install soundfile torch torchaudio transformers pyaudio webrtcvad rx halo onnx onnxruntime wheel pyctcdecode
!pip install https://github.com/kpu/kenlm/archive/master.zip

## 2. Importation des bibliothèques nécessaires

In [None]:
import time
import timeit
from pathlib import Path

import numpy as np
import onnxruntime as rt
import pyaudio
import soundfile as sf
import torch
import torchaudio
import webrtcvad
from onnxruntime.quantization import quantize_dynamic, QuantType
from rx.subject import BehaviorSubject
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, AutoProcessor, AutoModelForCTC

## 3. Classe d'inférence Wav2Vec2

In [None]:
class Wave2Vec2Inference:
    """Speech-to-text inference with a Hugging Face CTC model (PyTorch backend).

    If ``use_lm_if_possible`` is true and the loaded processor exposes a
    ``decoder`` attribute, decoding goes through ``processor.decode`` with
    hotwords. Otherwise greedy (argmax) CTC decoding is used.
    """

    def __init__(self, model_name, hotwords=None, use_lm_if_possible=True, use_gpu=True):
        """Load the processor and the CTC model.

        Args:
            model_name: Hugging Face model id or local path of a CTC checkpoint.
            hotwords: optional list of words passed to the LM decoder. Defaults
                to no hotwords. ``None`` avoids a shared mutable default list.
            use_lm_if_possible: load the processor through ``AutoProcessor`` so
                that an LM-backed processor is used when the checkpoint has one.
            use_gpu: run on CUDA when it is available, otherwise on CPU.
        """
        self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
        if use_lm_if_possible:
            # AutoProcessor may resolve to an LM-backed processor (one exposing `.decoder`).
            self.processor = AutoProcessor.from_pretrained(model_name)
        else:
            self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.model = AutoModelForCTC.from_pretrained(model_name)
        self.model.to(self.device)
        self.hotwords = hotwords if hotwords is not None else []
        self.use_lm_if_possible = use_lm_if_possible

    def buffer_to_text(self, audio_buffer):
        """Transcribe an audio buffer.

        Args:
            audio_buffer: 1-D sequence of samples, assumed mono at 16 kHz.

        Returns:
            ``(transcription, confidence)``. An empty buffer yields
            ``("", 0.0)`` so callers can always unpack or index the result.
        """
        if len(audio_buffer) == 0:
            return "", 0.0

        inputs = self.processor(torch.tensor(audio_buffer), sampling_rate=16_000,
                                return_tensors="pt", padding=True)

        # Feature extractors configured with return_attention_mask=False
        # (e.g. the wav2vec2-base family) do not return an attention mask.
        model_kwargs = {}
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            model_kwargs["attention_mask"] = attention_mask.to(self.device)

        with torch.no_grad():
            logits = self.model(inputs.input_values.to(self.device), **model_kwargs).logits

        if hasattr(self.processor, "decoder") and self.use_lm_if_possible:
            decoded = self.processor.decode(logits[0].cpu().numpy(),
                                            hotwords=self.hotwords,
                                            output_word_offsets=True)
            # The LM score averaged per word serves as a rough confidence value.
            confidence = decoded.lm_score / len(decoded.text.split(" "))
            transcription = decoded.text
        else:
            predicted_ids = torch.argmax(logits, dim=-1)
            transcription = self.processor.batch_decode(predicted_ids)[0]
            confidence = self.confidence_score(logits, predicted_ids)

        return transcription, confidence

    def confidence_score(self, logits, predicted_ids):
        """Mean softmax probability of the greedily chosen character tokens.

        Args:
            logits: model logits, assumed shaped (batch, frames, vocab).
            predicted_ids: ``argmax`` of ``logits`` over the last axis.

        Returns:
            A Python float. Returns 0.0 when every frame is a pad or
            word-delimiter token.
        """
        probs = torch.nn.functional.softmax(logits, dim=-1)
        chosen_probs = probs.gather(-1, predicted_ids.unsqueeze(-1)).squeeze(-1)
        tokenizer = self.processor.tokenizer
        # Pad (used as the CTC blank) and word-delimiter frames dominate the
        # sequence and are near-certain, so they would inflate the average.
        is_character = ((predicted_ids != tokenizer.word_delimiter_token_id)
                        & (predicted_ids != tokenizer.pad_token_id))
        character_probs = chosen_probs.masked_select(is_character)
        if character_probs.numel() == 0:
            return 0.0
        return character_probs.mean().item()

## 4. Conversion en ONNX

In [None]:
def convert_to_onnx(model_id_or_path, onnx_model_name, audio_len=250_000, opset_version=11):
    """Export a Wav2Vec2 CTC checkpoint to an ONNX file.

    Args:
        model_id_or_path: Hugging Face model id or local path. The checkpoint
            must contain a CTC head (``Wav2Vec2ForCTC``).
        onnx_model_name: destination path of the ``.onnx`` file.
        audio_len: number of samples in the random dummy waveform used for
            tracing. The time axes are exported as dynamic.
        opset_version: ONNX opset to target.
    """
    print(f"Converting {model_id_or_path} to ONNX")
    model = Wav2Vec2ForCTC.from_pretrained(model_id_or_path)
    model.eval()  # explicit: dropout must be disabled while tracing

    # Dummy (batch=1, samples) waveform; tracing needs no gradients.
    dummy_waveform = torch.randn(1, audio_len)

    torch.onnx.export(model,
                      dummy_waveform,
                      onnx_model_name,
                      export_params=True,
                      opset_version=opset_version,
                      do_constant_folding=True,
                      input_names=['input'],
                      output_names=['output'],
                      # The output time axis counts encoder frames, not input
                      # samples, so it gets its own symbolic name rather than
                      # claiming to equal 'audio_len'.
                      dynamic_axes={'input': {1: 'audio_len'},
                                    'output': {1: 'frames'}})

def quantize_onnx_model(onnx_model_path, quantized_model_path):
    """Write a dynamically quantized copy of an ONNX model.

    Weights are stored as unsigned 8-bit integers (``QuantType.QUInt8``).

    Args:
        onnx_model_path: path of the floating-point ONNX model to read.
        quantized_model_path: destination path of the quantized model.
    """
    print("Starting quantization...")
    weight_format = QuantType.QUInt8
    quantize_dynamic(
        onnx_model_path,
        quantized_model_path,
        weight_type=weight_format,
    )
    print(f"Quantized model saved to: {quantized_model_path}")

## 5. Classe d'inférence ONNX

In [None]:
class Wave2Vec2ONNXInference:
    """Greedy CTC speech-to-text running an exported Wav2Vec2 ONNX graph."""

    def __init__(self, model_name, onnx_path, providers=None):
        """Load the processor and create the ONNX Runtime session.

        Args:
            model_name: Hugging Face model id or path providing the processor
                (feature extractor + tokenizer).
            onnx_path: path of the exported ONNX model.
            providers: ONNX Runtime execution providers in priority order.
                Defaults to ``["CPUExecutionProvider"]``. Since ORT 1.9,
                providers must be given explicitly on builds that ship
                several of them.
        """
        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        options = rt.SessionOptions()
        options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
        if providers is None:
            providers = ["CPUExecutionProvider"]
        self.model = rt.InferenceSession(onnx_path, options, providers=providers)

    def buffer_to_text(self, audio_buffer):
        """Transcribe an audio buffer (assumed mono, 16 kHz).

        Returns:
            The lower-cased greedy transcription, or ``""`` for an empty buffer.
        """
        if len(audio_buffer) == 0:
            return ""

        inputs = self.processor(torch.tensor(audio_buffer), sampling_rate=16_000,
                                return_tensors="np", padding=True)
        input_name = self.model.get_inputs()[0].name
        logits = self.model.run(None, {input_name: inputs.input_values})[0]
        # Take the single batch row instead of squeeze(). squeeze() would also
        # drop the time axis for a one-frame output and hand decode() a bare int.
        predicted_ids = np.argmax(logits, axis=-1)[0].tolist()
        transcription = self.processor.decode(predicted_ids)
        return transcription.lower()

## 6. Évaluation des performances

In [None]:
def evaluate_performance(audio_file, base_model, iterations=100, num_threads=16,
                         onnx_path=None, quant_onnx_path=None):
    """Compare transcriptions and latency of PyTorch, ONNX and quantized ONNX.

    Args:
        audio_file: path to an audio file sampled at 16 kHz. Multi-channel
            files are downmixed to mono.
        base_model: Hugging Face model id or path, used for the processors and
            the PyTorch model.
        iterations: number of timed calls per backend; must be >= 1.
        num_threads: passed to ``torch.set_num_threads`` (a process-wide setting).
        onnx_path: float ONNX model; defaults to ``<model basename>.onnx``.
        quant_onnx_path: quantized ONNX model; defaults to
            ``<model basename>.quant.onnx``.

    Note:
        The PyTorch backend uses CUDA when available (``Wave2Vec2Inference``
        default), so its timing is not necessarily a CPU-vs-CPU comparison.
    """
    if iterations < 1:
        raise ValueError(f"iterations must be >= 1, got {iterations}")
    torch.set_num_threads(num_threads)
    audio_input, samplerate = sf.read(audio_file)
    assert samplerate == 16000, "L'audio doit être échantillonné à 16kHz"
    if audio_input.ndim > 1:
        # sf.read returns (frames, channels) for multi-channel files; the models expect mono.
        audio_input = audio_input.mean(axis=1)

    model_stem = base_model.split('/')[-1]
    if onnx_path is None:
        onnx_path = f"{model_stem}.onnx"
    if quant_onnx_path is None:
        quant_onnx_path = f"{model_stem}.quant.onnx"

    # Build the three backends.
    asr = Wave2Vec2Inference(base_model)
    asr_onnx = Wave2Vec2ONNXInference(base_model, onnx_path)
    asr_onnx_quant = Wave2Vec2ONNXInference(base_model, quant_onnx_path)

    # Transcription check
    print("Test de transcription:")
    print(f"PyTorch: {asr.buffer_to_text(audio_input)[0]}")
    print(f"ONNX: {asr_onnx.buffer_to_text(audio_input)}")
    print(f"ONNX quantifié: {asr_onnx_quant.buffer_to_text(audio_input)}")

    # Latency benchmark
    print(f"\nTest de performance sur {iterations} itérations:")
    backends = [("PyTorch", asr), ("ONNX", asr_onnx), ("ONNX quantifié", asr_onnx_quant)]
    for label, engine in backends:
        # timeit runs the lambda immediately, so the loop variable is bound correctly.
        seconds = timeit.timeit(lambda: engine.buffer_to_text(audio_input), number=iterations)
        print(f"{label}: {(seconds/iterations)*1000:.2f} ms/iter")

## 7. Exemple d'utilisation

In [None]:
# Base checkpoint. It must ship a CTC head and a tokenizer: the bare pretrained
# "facebook/wav2vec2-base" has neither, so it cannot be used for transcription.
base_model = "facebook/wav2vec2-base-960h"

# Same naming scheme as evaluate_performance's default model paths.
model_stem = base_model.split("/")[-1]
onnx_path = f"{model_stem}.onnx"
quant_onnx_path = f"{model_stem}.quant.onnx"
audio_path = Path("test.wav")

# 1. Export to ONNX (slow). Skipped if the file exists; delete it to re-export.
if not Path(onnx_path).exists():
    convert_to_onnx(base_model, onnx_path)

# 2. Quantize the ONNX model.
if not Path(quant_onnx_path).exists():
    quantize_onnx_model(onnx_path, quant_onnx_path)

# 3. Benchmark. Requires a 16 kHz test.wav next to the notebook.
if audio_path.exists():
    evaluate_performance(str(audio_path), base_model)
else:
    print(f"{audio_path} introuvable : évaluation des performances ignorée.")