<a href="https://colab.research.google.com/github/TitanSage02/AI-Agents-browser/blob/main/IWSLT_2025_ASR_Fon_Fr.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ASR & MT Project for Fon

This notebook documents the development of an automatic speech recognition (ASR) system for Fon, coupled with a Fon→French machine translation (MT) module.

The project builds on pretrained models and data augmentation techniques to improve performance in a low-resource setting.

---

## Objectives:
- **ASR for Fon**: Build a strong model by fine-tuning a pretrained one (Wav2Vec).
- **Fon-FR MT**: Integrate and adapt an existing model to translate Fon transcriptions into French.
- **Integrated pipeline**: Set up an end-to-end system that goes from audio to translation, documenting each step methodically for a scientific paper.

---

## Roadmap:

1. **Data preparation**
   - Check and clean the dataset (audio and transcriptions for ASR, Fon-FR pairs for MT)
   - Normalize and preprocess the data
   - Apply augmentation techniques (SpecAugment, speed variations, added noise)

2. **Designing the ASR model for Fon**
   - Select a pretrained model (Wav2Vec)
   - Add an external language model to improve transcription coherence (if needed)

3. **Training and fine-tuning**
   - Adapt the pretrained model to the Fon dataset through fine-tuning
   - Use regularization techniques (dropout, label smoothing)
   - Iterate with data augmentation to make the model more robust

4. **Evaluation and optimization**
   - Measure performance with standard metrics (WER, CER)
   - Run cross-validation and a detailed error analysis
   - Adjust hyperparameters and architecture based on the results

5. **Integrating the Fon-FR MT system**
   - Use a pretrained translation model (available on Hugging Face, or Google's)
   - Build a cascaded pipeline: ASR → MT
   - Add correction strategies to limit the impact of transcription errors

6. **Documentation and publication**
   - Write a detailed report on the methodology, technical choices, and results
   - Include the key scientific references (e.g. work on Wav2Vec 2.0, HuBERT, Transformers)
   - Discuss limitations and directions for improvement

7. **Deployment and real-world testing**
   - Implement an end-to-end pipeline for field tests
   - Collect feedback and iterate to optimize the final system





*This notebook is designed as a step-by-step guide to developing and integrating our ASR-MT system in a low-resource setting, carefully documenting each phase for in-depth analysis and a later scientific publication.*


# Loading the data

In [None]:
import os

def detect_environment():
    """Detect the execution environment (Kaggle or Google Colab)."""
    if os.path.exists("/kaggle/input/"):
        return "Kaggle"
    else:
        return "Google Colab"

    # try:
    #     import google.colab
    #     return "Google Colab"
    # except ImportError:
    #     pass

    # return "Local"

# Use the detection function instead of a hardcoded value
env = detect_environment()
print(f"Execution environment: {env}")

if env == "Google Colab":
    # Connect the Kaggle account
    import kagglehub
    # kagglehub.login()

Execution environment: Google Colab


In [None]:
if env == "Google Colab":
  # Download the dataset
  source_path = kagglehub.dataset_download('esperance01/ffstc-2025')
  print('Dataset downloaded...')

else :
    source_path = "/kaggle/input/ffstc-2025"

Downloading from https://www.kaggle.com/api/v1/datasets/download/esperance01/ffstc-2025?dataset_version_number=2...


100%|██████████| 4.52G/4.52G [00:41<00:00, 116MB/s]

Extracting files...





Dataset downloaded...


In [None]:
import os
import shutil
from tqdm import tqdm

def move_all_contents(source_path, dataset_name="dataset"):
    """Recursively move the dataset files, showing a progress bar."""

    # Environment setup
    if env == "Kaggle":
        print("The files are already directly accessible.")
        return


    base_path = {
        "Google Colab": "/content/",
        "Kaggle" : "/kaggle/working/"
    }

    dest_root = os.path.join(base_path[env], dataset_name)
    os.makedirs(dest_root, exist_ok=True)

    # Count the files up front for the progress bar
    total_files = sum(len(files) for _, _, files in os.walk(source_path))

    # Main progress bar
    with tqdm(
        total=total_files,
        desc="Moving files",
        unit=" file(s)",
        bar_format="{l_bar}{bar:30}{r_bar}",
        colour="#00ff00"
    ) as pbar:

        for root, dirs, files in os.walk(source_path):
            relative_path = os.path.relpath(root, source_path)
            dest_dir = os.path.join(dest_root, relative_path)

            # Create the destination directory if needed
            if not os.path.exists(dest_dir):
                # tqdm.write(f" Creating directory: {dest_dir}")
                os.makedirs(dest_dir, exist_ok=True)

            for file in files:
                src_file = os.path.join(root, file)
                dest_file = os.path.join(dest_dir, file)

                # Handle name conflicts by appending a counter
                base, ext = os.path.splitext(file)
                counter = 1
                while os.path.exists(dest_file):
                    new_name = f"{base}_{counter}{ext}"
                    dest_file = os.path.join(dest_dir, new_name)
                    counter += 1
                    tqdm.write(f"Renaming: {file} => {new_name}")

                shutil.move(src_file, dest_file)
                pbar.update(1)
                pbar.set_postfix(file=file[:20] + "..." if len(file) > 20 else file)

    print(f"\nAll files were moved to: {dest_root}")
    return dest_root

dataset_path = move_all_contents(source_path, "dataset")

Moving files: 100%|██████████████████████████████| 34376/34376 [00:24<00:00, 1381.76 file(s)/s, file=26392.wav]


All files were moved to: /content/dataset





# 1. Data preparation

**Tasks:**
- Check and clean the dataset (audio and transcriptions for ASR, Fon-FR pairs for MT)
- Normalize and preprocess the data
- Apply augmentation techniques (SpecAugment, speed variations, added noise)

In [None]:
MODE_DEBUG = False  # Keep False to skip the checks when running the whole notebook

In [None]:
# Check and clean the dataset (audio and transcriptions for ASR, Fon-FR pairs for MT)

import os
import numpy as np
import pandas as pd
import librosa

In [None]:
# Path configuration
if env == "Google Colab":
  BASE_PATH = "/content/dataset"
else :
  BASE_PATH = "/kaggle/input/ffstc-2025"

TRAIN_PATH = os.path.join(BASE_PATH, "train_ok.csv")
TRAIN_AUDIO_DIR = os.path.join(BASE_PATH, "train")

VALID_PATH = os.path.join(BASE_PATH, "valid_ok.csv")
VALID_AUDIO_DIR = os.path.join(BASE_PATH, "valid")

In [None]:
# LOAD THE DATAFRAMES
train_df = pd.read_csv(TRAIN_PATH)
valid_df = pd.read_csv(VALID_PATH)

train_df.info()

print("\n\n")

valid_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 450 entries, 0 to 449
Data columns (total 5 columns):
 #   Column     Non-Null Count  Dtype  
---  ------     --------------  -----  
 0   ID         450 non-null    object 
 1   utterance  450 non-null    object 
 2   filename   450 non-null    object 
 3   duration   450 non-null    float64
 4   translate  450 non-null    object 
dtypes: float64(1), object(4)
memory usage: 17.7+ KB



<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1129 entries, 0 to 1128
Data columns (total 5 columns):
 #   Column     Non-Null Count  Dtype  
---  ------     --------------  -----  
 0   ID         1129 non-null   object 
 1   utterance  1129 non-null   object 
 2   filename   1129 non-null   object 
 3   duration   1129 non-null   float64
 4   translate  1129 non-null   object 
dtypes: float64(1), object(4)
memory usage: 44.2+ KB


In [None]:
if MODE_DEBUG:
  train_df.sample(3)

In [None]:
if MODE_DEBUG:
  valid_df.sample(5)

## Cleaning the dataset

In [None]:
# Add the full path to the audio files

def add_audio_path(df, DIR, col_name="filename"):
    """
    Replace the 'filename' column with the full audio path, built by joining
    the directory (DIR) with the file name found in the column col_name.
    """
    df['filename'] = df[col_name].apply(lambda filename: os.path.join(DIR, filename))
    return df

train_df = add_audio_path(train_df, TRAIN_AUDIO_DIR)
valid_df = add_audio_path(valid_df, VALID_AUDIO_DIR)

In [None]:
if MODE_DEBUG:
  train_df.sample(2)

In [None]:
if MODE_DEBUG:
  valid_df.sample(2)

In [None]:
# Strip any leading or trailing whitespace
train_df['utterance'] = train_df['utterance'].str.strip()
train_df['translate'] = train_df['translate'].str.strip()

valid_df['utterance'] = valid_df['utterance'].str.strip()
valid_df['translate'] = valid_df['translate'].str.strip()


if MODE_DEBUG:
  train_df.head()

In [None]:
if MODE_DEBUG:
  train_df.info()

  print("\n\n")

  valid_df.info()

In [None]:
if MODE_DEBUG:

  # Check that the audio files exist
  def check_file_exists(row):
      """Check whether the audio file exists at the full path stored in 'filename'."""
      file_path = row['filename']
      return os.path.exists(file_path)

  print("Checking that the audio files exist...")

  for df in [train_df, valid_df]:
    df['files_exists'] = df.apply(check_file_exists, axis=1)
    missing_files = df[~df['files_exists']]

    if not missing_files.empty:
        print(f"Warning: {len(missing_files)} audio files were not found.")
        # df = df[df['files_exists']] # Drop the rows whose audio file is missing
        # print(f"Number of rows after dropping missing files: {len(df)}")
    else:
        print("All audio files are accessible.")

In [None]:
if MODE_DEBUG:
  # Check that the audio durations are consistent
  def verify_duration(row, tolerance=0.1):
      """
      Check that the actual duration of the audio file matches the one given in the CSV.
      The tolerance (in seconds) absorbs small variations.
      """
      # The full path is stored in 'filename' (see add_audio_path)
      file_path = row['filename']

      try:
          # Load the audio without resampling
          audio, sr = librosa.load(file_path, sr=None)
          real_duration = librosa.get_duration(y=audio, sr=sr)
          expected_duration = float(row['duration'])
          return abs(real_duration - expected_duration) <= tolerance
      except Exception as e:
          print(f"Error while loading file {file_path}: {e}")
          return False

  print("Checking that the audio durations are consistent...")

  for df in [train_df, valid_df]:
    df['duration_match'] = df.apply(verify_duration, axis=1)

    mismatched = df[~df['duration_match']]

    if not mismatched.empty:
        print(f"Warning: {len(mismatched)} files have an inconsistent duration.")
    else:
        print("The audio durations match the expected values.")

In [None]:
if MODE_DEBUG:
  train_df = train_df.drop(columns=['files_exists', 'duration_match'])
  valid_df = valid_df.drop(columns=['files_exists', 'duration_match'])
  # The dataset is clean. We can start working with it 😁😁😁

## Preparing the dataset

**Audio data augmentation:**
The goal here is to make the model more robust.

We use TorchAudio to apply:
- Speed perturbation (±10%)
- Added white/ambient noise
- SpecAugment (frequency and time masking)
- Simulated reverberation
- Filtering (low-pass/high-pass)
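
To make the noise step concrete, here is a minimal sketch of white-noise injection at a target signal-to-noise ratio. It is a variant for illustration, not the notebook's method: the notebook's own helper scales the noise by a fixed factor, whereas this version derives the noise level from the signal power. It works on a plain list of samples, and the function name and the `snr_db` and `seed` parameters are assumptions.

```python
import math
import random


def add_white_noise_at_snr(samples, snr_db, seed=None):
    """Return a copy of `samples` with Gaussian white noise added at roughly `snr_db` dB."""
    if not samples:
        return []
    rng = random.Random(seed)
    # Mean power of the clean signal
    signal_power = sum(s * s for s in samples) / len(samples)
    # SNR(dB) = 10 * log10(signal_power / noise_power), solved for noise_power
    noise_power = signal_power / (10 ** (snr_db / 10))
    noise_std = math.sqrt(noise_power)
    return [s + rng.gauss(0.0, noise_std) for s in samples]
```

A lower `snr_db` gives a noisier signal; the same formula carries over to tensors by replacing the list operations with their tensor equivalents.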


In [None]:
# # Helper functions for manipulating audio

# import torch
# import torchaudio
# import random

# def apply_speed_perturbation(waveform, sample_rate, min_speed=0.9, max_speed=1.1):
#     """
#     Apply a speed perturbation to the audio using sox effects.
#     """
#     speed_factor = random.uniform(min_speed, max_speed)
#     effects = [
#         ["speed", f"{speed_factor:.2f}"],
#         ["rate", f"{sample_rate}"]  # Resample to keep the original sampling rate
#     ]
#     augmented_waveform, _ = torchaudio.sox_effects.apply_effects_tensor(waveform, sample_rate, effects)
#     return augmented_waveform

# def add_white_noise(waveform, noise_factor=0.005):
#     """
#     Add white noise to the audio.
#     """
#     noise = torch.randn_like(waveform) * noise_factor
#     return waveform + noise

# def apply_reverb(waveform, sample_rate, reverberance=50, hf_damping=50, room_scale=100, stereo_depth=100, pre_delay=0.5):
#     """
#     Apply simulated reverberation to the audio.
#     The parameters can be tuned to simulate different acoustic conditions.
#     """
#     effects = [
#         ["reverb", f"{reverberance}", f"{hf_damping}", f"{room_scale}", f"{stereo_depth}", f"{pre_delay}"]
#     ]
#     augmented_waveform, _ = torchaudio.sox_effects.apply_effects_tensor(waveform, sample_rate, effects)
#     return augmented_waveform

# def apply_lowpass_filter(waveform, sample_rate, cutoff_freq=4000):
#     """
#     Apply a low-pass filter.
#     """
#     return torchaudio.functional.lowpass_biquad(waveform, sample_rate, cutoff_freq)

# def apply_highpass_filter(waveform, sample_rate, cutoff_freq=200):
#     """
#     Apply a high-pass filter.
#     """
#     return torchaudio.functional.highpass_biquad(waveform, sample_rate, cutoff_freq)

In [None]:
# # Class for the audio augmentation chain
# class AudioAugmentor:
#     def __init__(self,
#                  noise_factor=0.005,
#                  prob_speed=0.65,
#                  prob_noise=0.55,
#                  prob_reverb=0.65,
#                  prob_filter=0.5):
#         """
#         Initialize the augmentor with the probability of applying each transformation.
#         """
#         self.noise_factor = noise_factor
#         self.prob_speed = prob_speed
#         self.prob_noise = prob_noise
#         self.prob_reverb = prob_reverb
#         self.prob_filter = prob_filter

#     def __call__(self, filename):
#         waveform, sr = torchaudio.load(filename)

#         augmented = waveform.clone()

#         # Speed perturbation
#         if random.random() < self.prob_speed:
#             augmented = apply_speed_perturbation(augmented, sr)

#         # Added white noise
#         if random.random() < self.prob_noise:
#             augmented = add_white_noise(augmented, self.noise_factor)

#         # Reverberation
#         if random.random() < self.prob_reverb:
#             augmented = apply_reverb(augmented, sr)

#         # Filtering (random choice between low-pass and high-pass)
#         if random.random() < self.prob_filter:
#             if random.random() < 0.35:
#                 augmented = apply_lowpass_filter(augmented, sr)
#             else:
#                 augmented = apply_highpass_filter(augmented, sr)

#         return augmented

# # SpecAugment (applied to the spectrogram)
# class SpecAugment:
#     def __init__(self, freq_mask_param=15, time_mask_param=35):
#         """
#         Initialize SpecAugment with the frequency and time masking parameters.
#         """
#         self.freq_masking = torchaudio.transforms.FrequencyMasking(freq_mask_param)
#         self.time_masking = torchaudio.transforms.TimeMasking(time_mask_param)

#     def __call__(self, spectrogram):
#         spec = self.freq_masking(spectrogram)
#         spec = self.time_masking(spec)
#         return spec


In [None]:
train_df.head()

| | ID | utterance | filename | duration | translate |
|---|---|---|---|---|---|
| 0 | 5_fongbe_3c382b5a455a42febe63400a2cd25998_vali... | elles jonchaient le sol. | /content/dataset/train/5_fongbe_3c382b5a455a42... | 4.78 | Yě lí kɔ́. |
| 1 | 13_fongbe_04fa004f62da4374b53ba928b462c2ae_val... | cela suffit, dit le père. | /content/dataset/train/13_fongbe_04fa004f62da4... | 23.68 | Enɛ ko kpé, wɛ tɔ́ ɔ ɖɔ. |
| 2 | 5_fongbe_0bed6e14b6ec43b387e6fce3763f39b9_vali... | tu ne t’en es pas laissé imposer par toutes le... | /content/dataset/train/5_fongbe_0bed6e14b6ec43... | 21.18 | Nǔ e azětɔ́ lɛ́ɛ ɖɔ ɖ'ayǐ lɛ́ɛ bǐ wɛ sɔ́ we dó... |
| 3 | 7_fongbe_0385bf0bfa824fdba6661b055703a19b_vali... | c’est bien. | /content/dataset/train/7_fongbe_0385bf0bfa824f... | 1.81 | É nyɔ́. |
| 4 | 13_fongbe_02ff49ef6ab7453b87f883c6d12bcb5d_val... | combien de jours, de lunes, d’années était-il ... | /content/dataset/train/13_fongbe_02ff49ef6ab74... | 6.13 | Azǎn nabi wɛ é ka nɔ ɖò finɛ? |


In [None]:
# Install libsox.so
# !apt-get update && apt-get install sox libsox-dev libsox-fmt-all

In [None]:
# # Apply to our dataset

# augmentor = AudioAugmentor()

# def augment_audio_for_dataframe(df, aug_folder, augmentor, prob_threshold=0.4):
#     """
#     For each row of the dataframe, if random.random() > prob_threshold,
#     load the audio, apply the transformations through the augmentor,
#     save the generated file in aug_folder and add a new row
#     to the dataframe with the new path, the recomputed duration, and a modified ID.
#     """
#     new_rows = []

#     for idx, row in df.iterrows():
#         if random.random() > prob_threshold:
#             # Load the audio
#             original_file = row['filename']

#             try:
#                 waveform, sr = torchaudio.load(original_file)
#             except Exception as e:
#                 print(f"Error while loading file {original_file}: {e}")
#                 continue

#             # Apply the transformations through the augmentor
#             augmented_waveform = augmentor(original_file)

#             # Build a new file name
#             base_name = os.path.basename(original_file)
#             name, ext = os.path.splitext(base_name)
#             new_file_name = f"aug_{name}{ext}"
#             new_file_path = os.path.join(aug_folder, new_file_name)

#             # Save the generated audio file
#             try:
#                 torchaudio.save(new_file_path, augmented_waveform, sr)
#             except Exception as e:
#                 print(f"Error while saving file {new_file_path}: {e}")
#                 continue

#             # Compute the new duration
#             new_duration = augmented_waveform.shape[1] / sr

#             # Create a new row with the same transcriptions
#             new_row = {
#                 'ID':  'aug_' + str(row['ID']),
#                 'utterance': row['utterance'],
#                 'filename': new_file_path,
#                 'duration': new_duration,
#                 'translate': row['translate']
#             }
#             new_rows.append(new_row)

#     # If any augmentation was performed, append the new rows to the original DataFrame
#     if new_rows:
#         df_augmented = pd.DataFrame(new_rows)
#         df_updated = pd.concat([df, df_augmented], ignore_index=True)
#     else:
#         df_updated = df.copy()

#     return df_updated

# # Folders for the augmented files
# if env == "Google Colab":
#     train_aug_folder = TRAIN_AUDIO_DIR
#     valid_aug_folder = VALID_AUDIO_DIR
#     # else :
#     #     train_aug_folder = "/kaggle/working/train_augment"
#     #     valid_aug_folder = "/kaggle/working/valid_augment"

#     # Apply the augmentation to train_df and valid_df
#     train_df = augment_audio_for_dataframe(train_df, train_aug_folder, augmentor, prob_threshold=0.6)
#     # valid_df = augment_audio_for_dataframe(valid_df, valid_aug_folder, augmentor, prob_threshold=0.4) # Keep the validation set intact

#     print("Augmentation done. Total number of rows in train_df:", len(train_df))
#     # print("Augmentation done. Total number of rows in valid_df:", len(valid_df))

In [None]:
train_df.info()

valid_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 450 entries, 0 to 449
Data columns (total 5 columns):
 #   Column     Non-Null Count  Dtype  
---  ------     --------------  -----  
 0   ID         450 non-null    object 
 1   utterance  450 non-null    object 
 2   filename   450 non-null    object 
 3   duration   450 non-null    float64
 4   translate  450 non-null    object 
dtypes: float64(1), object(4)
memory usage: 17.7+ KB
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1129 entries, 0 to 1128
Data columns (total 5 columns):
 #   Column     Non-Null Count  Dtype  
---  ------     --------------  -----  
 0   ID         1129 non-null   object 
 1   utterance  1129 non-null   object 
 2   filename   1129 non-null   object 
 3   duration   1129 non-null   float64
 4   translate  1129 non-null   object 
dtypes: float64(1), object(4)
memory usage: 44.2+ KB


# 2. Designing the ASR model for Fon

Tasks:

- Select a pretrained model (Wav2Vec)
- Add an external language model to improve transcription coherence (if needed)

In [None]:
# Initialize wandb and configure the run
!pip -q install wandb

import os
import wandb

# Read the API key from the WANDB_API_KEY environment variable instead of hardcoding it
wandb.login(key=os.environ.get("WANDB_API_KEY"))

wandb.init(project="ASR_Fon", name="Fon-ASR", reinit=True)

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: No netrc file found, creating one.
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
wandb: Currently logged in as: titansage02 (titansage02-university-of-abomey-calavi) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


In [None]:
# Install requirements
# !pip install -q transformers datasets soundfile librosa wandb torch-audiomentations audiomentations evaluate accelerate
!pip install -q -U datasets evaluate

In [None]:
# !pip install -U -q transformers

In [None]:
import json
import re
import os
import torch
import soundfile as sf
import numpy as np
import pandas as pd
from pathlib import Path
from dataclasses import dataclass, field
import wandb

from transformers import (
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2BertProcessor,
    Wav2Vec2ForCTC,
    TrainingArguments,
    Trainer
)
from datasets import Dataset, load_dataset

import evaluate

from torch import nn
# from torch_audiomentations import Compose, Gain
# from audiomentations import (
#     Compose,
#     AddGaussianNoise,
#     AddGaussianSNR,
#     ClippingDistortion,
#     Gain,
#     LoudnessNormalization,
#     Normalize,
#     PitchShift,
#     PolarityInversion,
#     Shift,
#     TimeMask,
#     TimeStretch,
# )

In [None]:
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataArgs:
    chars_to_ignore: list = field(default_factory=lambda: [",", "?", ".", "!", "-", ";", ":", "'", "\""])  # Use default_factory for mutable defaults
    min_duration: float = 1.0
    max_duration: float = 35.0

data_args = DataArgs()
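
`min_duration` and `max_duration` are declared here but are not applied in the cells shown in this section. Here is a minimal sketch of how such bounds might be enforced, assuming the rows are available as a list of dictionaries with a `duration` key in seconds (for instance from `train_df.to_dict("records")`). The helper name `filter_by_duration` is hypothetical.

```python
def filter_by_duration(records, min_duration=1.0, max_duration=35.0):
    """Keep only the records whose 'duration' (seconds) lies within the given bounds."""
    return [r for r in records if min_duration <= r["duration"] <= max_duration]


# Small illustration with made-up durations
examples = [{"duration": 0.4}, {"duration": 4.78}, {"duration": 41.0}]
kept = filter_by_duration(examples)
print(len(kept))  # 1
```

Bounding the duration is a common way to avoid both near-empty clips and very long ones that inflate memory use during batching.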

In [None]:
# df = train_df.sample(frac=2/3, random_state=42).copy()
# Convert to a Hugging Face Dataset
dataset = Dataset.from_pandas(train_df.copy())

# Split train/validation
dataset = dataset.train_test_split(test_size=0.2)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]
valid_dataset = Dataset.from_pandas(valid_df.copy())

In [None]:
def create_vocab():
    vocab = [" ", "a", "b", "c", "d", "e", "f", "g", "h", "i",
             "j", "k", "l", "m", "n", "o", "p", "r", "s", "t",
             "u", "v", "w", "x", "y", "z", "à", "á", "è", "é",
             "ì", "í", "î", "ï", "ó", "ù", "ú", "ā", "ă", "ē",
             "ĕ", "ŏ", "ū", "ŭ", "ɔ", "ɖ", "ò", "ε", "έ", "ɔ̀",
             "ɔ̆", "ĭ", "ɛ̆", "ɛ̃"]

    vocab = [v for v in vocab if v not in data_args.chars_to_ignore]

    vocab = sorted(vocab)
    # Start at 1 so that no character shares id 0 with the blank/padding token
    vocab_dict = {v: k for k, v in enumerate(vocab, start=1)}

    # The blank token should be 0
    vocab_dict["[PAD]"] = 0
    vocab_dict["|"] = len(vocab_dict)  # Word delimiter
    vocab_dict["[UNK]"] = len(vocab_dict)

    with open("vocab.json", "w", encoding="utf-8") as f:
        json.dump(vocab_dict, f, ensure_ascii=False, indent=4)

    print("The vocabulary is ready")

    return True

if create_vocab():
  tokenizer = Wav2Vec2CTCTokenizer("vocab.json",
                                unk_token="[UNK]",
                                pad_token="[PAD]",
                                word_delimiter_token="|")

The vocabulary is ready
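
A hand-written character list can miss symbols that occur in the transcriptions; the sample rows shown earlier appear to contain characters such as `ɛ` and `ě` that the list does not include, and these would be mapped to `[UNK]`. As an alternative, here is a minimal sketch that derives the vocabulary from the text itself. It is an assumption-laden illustration, not the notebook's method: `build_char_vocab` is a hypothetical helper, and text is normalized to Unicode NFC, so a character without a precomposed form stays as a base letter plus a combining mark, each with its own entry.

```python
import unicodedata


def build_char_vocab(transcripts, chars_to_ignore=()):
    """Build a CTC vocabulary dict from transcripts, one entry per NFC code point."""
    chars = set()
    for text in transcripts:
        normalized = unicodedata.normalize("NFC", text.lower())
        chars.update(c for c in normalized if c not in chars_to_ignore and c != " ")
    vocab = {"[PAD]": 0}  # id 0 is reserved for the CTC blank / padding token
    for char in sorted(chars):
        vocab[char] = len(vocab)
    vocab["|"] = len(vocab)  # word delimiter, stands in for the space
    vocab["[UNK]"] = len(vocab)
    return vocab
```

With this approach the same normalization would have to be applied to the text before tokenization, for example by calling `build_char_vocab(train_df["translate"], data_args.chars_to_ignore)` and normalizing inside `prepare_dataset`.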


In [None]:
# feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1,
#                                               sampling_rate=16000,
#                                               padding_value=0.0,
#                                               do_normalize=True,
#                                               return_attention_mask=True)

from transformers import SeamlessM4TFeatureExtractor

feature_extractor = SeamlessM4TFeatureExtractor(feature_size=80, num_mel_bins=80, sampling_rate=16000, padding_value=0.0)

processor = Wav2Vec2BertProcessor(feature_extractor=feature_extractor,
                              tokenizer=tokenizer)

In [None]:
# def prepare_dataset(batch):
#     audio = batch["audio"]
#     batch["input_features"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
#     batch["input_length"] = len(batch["input_features"])

#     batch["labels"] = processor(text=batch["sentence"]).input_ids
#     return batch

In [None]:
def prepare_dataset(batch):
    # Load the audio
    speech, sr = sf.read(batch["filename"])
    if sr != 16000:
        # Recent librosa versions require keyword arguments here
        speech = librosa.resample(speech, orig_sr=sr, target_sr=16000)

    input_features = processor(speech, sampling_rate=16000).input_features[0]

    batch["input_features"] = input_features
    batch["input_length"] = len(input_features)

    # Clean the text (the 'translate' column holds the Fon transcription)
    text = batch["translate"].lower()
    text = re.sub(f"[{re.escape(''.join(data_args.chars_to_ignore))}]", "", text)

    # Tokenize the text with the tokenizer held by the processor
    batch["labels"] = processor.tokenizer(text=text).input_ids

    return batch

In [None]:
# Apply the preprocessing
train_dataset = train_dataset.map(
    prepare_dataset,
    remove_columns=train_dataset.column_names
)

eval_dataset = eval_dataset.map(
    prepare_dataset,
    remove_columns=eval_dataset.column_names
)

valid_dataset = valid_dataset.map(
    prepare_dataset,
    remove_columns=valid_dataset.column_names
)

Map:   0%|          | 0/360 [00:00<?, ? examples/s]

Map:   0%|          | 0/90 [00:00<?, ? examples/s]

Map:   0%|          | 0/1129 [00:00<?, ? examples/s]

In [None]:
train_dataset

Dataset({
    features: ['input_features', 'input_length', 'labels'],
    num_rows: 360
})

In [None]:
import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollator:
    processor: Wav2Vec2BertProcessor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )

        labels_batch = self.processor.pad(
            labels=label_features,
            padding=self.padding,
            return_tensors="pt",
        )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch
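
To illustrate what the label handling in the collator achieves, here is a minimal pure-Python sketch: labels are right-padded to a common length and the padded positions are then replaced with `-100` so the loss ignores them. The helper name `pad_labels` and its parameters are hypothetical; the collator itself relies on `processor.pad` and `masked_fill`.

```python
def pad_labels(label_sequences, pad_id=0, ignore_index=-100):
    """Right-pad label id lists to equal length, then mask the padding with `ignore_index`."""
    max_len = max(len(seq) for seq in label_sequences)
    masked = []
    for seq in label_sequences:
        n_pad = max_len - len(seq)
        padded = list(seq) + [pad_id] * n_pad
        attention = [1] * len(seq) + [0] * n_pad
        # Mask by position, not by value, so a genuine label equal to pad_id is preserved
        masked.append([tok if keep else ignore_index for tok, keep in zip(padded, attention)])
    return masked


print(pad_labels([[5, 6, 7], [8]]))  # [[5, 6, 7], [8, -100, -100]]
```

Masking by attention position rather than by token value matters when the padding id can also appear as a real label.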

In [None]:
data_collator = DataCollator(
    processor=processor,
    padding=True
)

In [None]:
!pip -q install jiwer

# Metric computation
import logging
import numpy as np
import evaluate

def compute_metrics(pred, compute_result=False):
    """Calculates WER and CER without any logging output."""
    # Silence logging while the metrics are computed
    logging.disable(logging.CRITICAL)

    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)
    pred_str = processor.batch_decode(pred_ids)

    label_ids = pred.label_ids
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    # Labels are not CTC output, so repeated characters must not be collapsed
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    wer = evaluate.load("wer")
    cer = evaluate.load("cer")

    wer_score = wer.compute(predictions=pred_str, references=label_str)
    cer_score = cer.compute(predictions=pred_str, references=label_str)

    # Re-enable logging before returning
    logging.disable(logging.NOTSET)

    return {"wer": wer_score, "cer": cer_score}
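
For reference, here is a minimal sketch of what WER and CER measure for a single reference/hypothesis pair, using a plain Levenshtein distance. The function names are hypothetical, and this is an illustration rather than the implementation behind `evaluate`, which (to my understanding) aggregates edits over the whole evaluation set instead of averaging per-sentence scores.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences (substitutions, insertions, deletions)."""
    previous = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        current = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            current.append(min(previous[j] + 1,          # deletion
                               current[j - 1] + 1,       # insertion
                               previous[j - 1] + cost))  # substitution or match
        previous = current
    return previous[-1]


def word_error_rate(reference, hypothesis):
    """Word-level edit distance divided by the number of reference words."""
    ref_words = reference.split()
    return edit_distance(ref_words, hypothesis.split()) / len(ref_words)


def char_error_rate(reference, hypothesis):
    """Character-level edit distance divided by the number of reference characters."""
    return edit_distance(reference, hypothesis) / len(reference)
```

An empty hypothesis gives a WER of 1.0 (every reference word is a deletion), which is the value the training log reports when the model outputs nothing usable.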

In [None]:
# Model configuration
from transformers import Wav2Vec2BertForCTC

model = Wav2Vec2BertForCTC.from_pretrained(
    "facebook/w2v-bert-2.0",
    cache_dir="cache",
    attention_dropout=0.2,
    hidden_dropout=0.0,
    feat_proj_dropout=0.0,
    mask_time_prob=0.0,
    layerdrop=0.0,
    ctc_loss_reduction="mean",
    add_adapter=True,
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
)

# model.config.pad_token_id = processor.tokenizer.pad_token_id
# model.config.vocab_size = len(processor.tokenizer)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)

# if model_args.freeze_feature_extractor:
#     model.freeze_feature_extractor()

Some weights of Wav2Vec2BertForCTC were not initialized from the model checkpoint at facebook/w2v-bert-2.0 and are newly initialized: ['adapter.layers.0.ffn.intermediate_dense.bias', 'adapter.layers.0.ffn.intermediate_dense.weight', 'adapter.layers.0.ffn.output_dense.bias', 'adapter.layers.0.ffn.output_dense.weight', 'adapter.layers.0.ffn_layer_norm.bias', 'adapter.layers.0.ffn_layer_norm.weight', 'adapter.layers.0.residual_conv.bias', 'adapter.layers.0.residual_conv.weight', 'adapter.layers.0.residual_layer_norm.bias', 'adapter.layers.0.residual_layer_norm.weight', 'adapter.layers.0.self_attn.linear_k.bias', 'adapter.layers.0.self_attn.linear_k.weight', 'adapter.layers.0.self_attn.linear_out.bias', 'adapter.layers.0.self_attn.linear_out.weight', 'adapter.layers.0.self_attn.linear_q.bias', 'adapter.layers.0.self_attn.linear_q.weight', 'adapter.layers.0.self_attn.linear_v.bias', 'adapter.layers.0.self_attn.linear_v.weight', 'adapter.layers.0.self_attn_conv.bias', 'adapter.layers.0.self_

In [None]:
from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir="model_ASR_Fon",
  group_by_length=True,
  per_device_train_batch_size=16,
  gradient_accumulation_steps=2,
  eval_strategy="steps",
  num_train_epochs=25,
  gradient_checkpointing=True,
  fp16=True,
  save_steps=600,
  eval_steps=50,
  logging_steps=50,
  learning_rate=5e-5,
  warmup_steps=500,
  save_total_limit=2,
  push_to_hub=False,
)

In [None]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    processing_class=processor.feature_extractor,
)

In [None]:
# import gc
# gc.collect()
# torch.cuda.empty_cache()

In [None]:
# Start training
wandb.init(project="wav2vec2-Good")
trainer.train()


Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.

Step,Training Loss,Validation Loss,Wer,Cer
50,10.5572,inf,1.0,1.0
100,3.52,inf,1.0,1.0
150,3.2763,inf,1.0,1.0
200,3.234,inf,1.0,1.0
250,3.3306,inf,1.0,1.0
300,3.2955,inf,0.997796,0.98926
350,3.2221,inf,0.974284,0.923329
400,3.1813,inf,0.980162,0.80355
450,2.957,inf,0.995591,0.707488
500,2.8533,inf,0.980162,0.703461



TrainOutput(global_step=1125, training_loss=2.528677678426107, metrics={'train_runtime': 1534.5601, 'train_samples_per_second': 5.865, 'train_steps_per_second': 0.733, 'total_flos': 2.4108430948804567e+18, 'train_loss': 2.528677678426107, 'epoch': 25.0})
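The `TrainOutput` above can be cross-checked against the batch settings with a little arithmetic. The sketch below is only an illustration: the helper name `summarize_run` is hypothetical, the numbers are copied from the log, and the interpretation assumes a single GPU and that the throughput figures are averages over the whole run.

```python
def summarize_run(samples_per_second, steps_per_second, runtime_s, global_step, epochs):
    """Derive per-step and per-epoch quantities from Trainer throughput metrics."""
    total_samples = samples_per_second * runtime_s
    return {
        "samples_per_step": samples_per_second / steps_per_second,
        "steps_per_epoch": global_step / epochs,
        "samples_per_epoch": total_samples / epochs,
    }

# Values copied from the TrainOutput logged above
run = summarize_run(
    samples_per_second=5.865,
    steps_per_second=0.733,
    runtime_s=1534.5601,
    global_step=1125,
    epochs=25,
)

# Effective batch size implied by the TrainingArguments cell (single device assumed)
configured_effective_batch = 16 * 2

print(run)
print("configured effective batch:", configured_effective_batch)
```

The log implies roughly 8 samples per optimizer step and about 360 samples per epoch, while the arguments shown give an effective batch of 32. One possible reading is that this output was recorded with a different batch configuration than the cell currently displays.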

In [None]:
!pip install huggingface_hub

from huggingface_hub import notebook_login

notebook_login()

In [None]:
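# Note: the first positional argument of push_to_hub is the commit message, not the repo id.
# The target repo name is derived from output_dir unless hub_model_id is set in TrainingArguments.
trainer.push_to_hub("TitanSage02/fongbeASRv0")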


CommitInfo(commit_url='https://huggingface.co/TitanSage02/model_ASR_Fon/commit/b0aef4b85589333dcf862e6f93ddb6295863322d', commit_message='TitanSage02/fongbeASRv0', commit_description='', oid='b0aef4b85589333dcf862e6f93ddb6295863322d', pr_url=None, repo_url=RepoUrl('https://huggingface.co/TitanSage02/model_ASR_Fon', endpoint='https://huggingface.co', repo_type='model', repo_id='TitanSage02/model_ASR_Fon'), pr_revision=None, pr_num=None)



Step,Training Loss,Validation Loss,Wer,Cer
300,4.4395,3.30893,0.997914,0.997521
600,3.1886,3.405535,1.051095,0.695104
900,2.2837,3.772542,1.080292,0.693245
1200,0.8893,4.006538,1.062565,0.705433



TrainOutput(global_step=1200, training_loss=2.7002648671468097, metrics={'train_runtime': 925.3266, 'train_samples_per_second': 2.594, 'train_steps_per_second': 1.297, 'total_flos': 5.596214240797632e+17, 'train_loss': 2.7002648671468097, 'epoch': 10.0})

In [None]:
# Evaluation and saving

# Evaluate on the validation set
metrics = trainer.evaluate()
print(f"WER: {metrics['eval_wer']} | CER: {metrics['eval_cer']}")

# Save the model and the processor
trainer.save_model(training_args.output_dir)
processor.save_pretrained(training_args.output_dir)



WER: 1.0668626010286555 | CER: 0.7480608591885441


[]
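A WER above 1.0, as reported above, is possible because insertions count as errors while the denominator is only the number of reference words. The following is a minimal pure-Python sketch of both metrics, not the implementation used by `evaluate`; the function names and the toy strings are assumptions made for illustration.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences (substitutions, insertions, deletions)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1]

def wer(ref, hyp):
    """Word error rate: word-level edit distance divided by the reference word count."""
    ref_words = ref.split()
    return edit_distance(ref_words, hyp.split()) / len(ref_words)

def cer(ref, hyp):
    """Character error rate: character-level edit distance divided by the reference length."""
    return edit_distance(ref, hyp) / len(ref)

# Toy example: 2 reference words, 3 wrong hypothesis words
# -> 2 substitutions + 1 insertion = 3 errors over 2 reference words
print(wer("a b", "x y z"))  # 1.5
print(cer("a b", "x y z"))
```

A model that emits more words than the reference contains, most of them wrong, therefore scores above 100% WER even when the CER is lower.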

In [None]:
# Using the model for inference
import soundfile as sf
import torch

def transcribe_audio(file_path):
    # Assumes the file is already sampled at 16 kHz (the returned sample rate is ignored)
    speech, _ = sf.read(file_path)
    inputs = processor(speech, sampling_rate=16000, return_tensors="pt")

    # Wav2Vec2-BERT consumes `input_features`; move them to the model's device
    input_features = inputs.input_features.to(model.device)

    model.eval()
    with torch.no_grad():
        logits = model(input_features).logits

    # Keep the batch dimension: batch_decode on a 1-D tensor of ids would decode
    # every frame on its own and return one token per frame, [PAD] included
    pred_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(pred_ids)[0]

# Test on one sample
sample = valid_df.sample(1).iloc[0]
audio_path = sample["filename"]
reference_text = sample["translate"]

predicted_text = transcribe_audio(audio_path)

# Display the result
print(f"Texte Référence : {reference_text}")
print(f"Texte Prédit : {predicted_text}")

Texte Référence : Jigbezán ce jɛ août mɛ ɖò Avivɔsun ko mɛ.
Texte Prédit : ['e', '[PAD]', '[PAD]', '[PAD]', '[PAD]', 'é', '[PAD]', 'j', '[PAD]', '[PAD]', '[PAD]', 'i', '', '', 'j', '[PAD]', '[PAD]', 'e', '[PAD]', 'z', '[PAD]', '[PAD]', 'ɔ', '[PAD]', '', 'n', 'z', '[PAD]', '[PAD]', '[PAD]', '[PAD]', 'o', '[PAD]', 'w', 'o', '', '', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '', '[PAD]', '[PAD]', '[PAD]', '[PAD]', 'ɔ', '', '', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', 'o', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '', 'é', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', 'i', 'n', '', 'd', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '', '', 'k', 'l', 'ɔ', '', '', 'ɖ', '[PAD]', '[UNK]', '[PAD]', '', 'n', '[PAD]', '[PAD]', '[PAD]', 'n', '', 'n', 'd', '[PAD
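The output above is a list with one token per audio frame, most of them `[PAD]`, which serves as the CTC blank here. Greedy CTC decoding normally merges consecutive repeats and then drops the blanks to obtain a single string. Below is a minimal sketch of that collapse step in plain Python; the function name and the frame sequence are hypothetical, and word-delimiter handling is left out.

```python
def ctc_greedy_collapse(frame_tokens, blank="[PAD]"):
    """Merge consecutive duplicate tokens, then drop the blank token."""
    out = []
    prev = None
    for tok in frame_tokens:
        # A token is emitted only when it differs from the previous frame
        # and is not the blank symbol
        if tok != prev and tok != blank:
            out.append(tok)
        prev = tok
    return "".join(out)

# Hypothetical per-frame predictions (not taken from the model output above)
frames = ["j", "j", "[PAD]", "i", "[PAD]", "[PAD]", "g", "g", "[PAD]", "g"]
print(ctc_greedy_collapse(frames))  # jigg
```

The blank between the two runs of `g` is what allows a genuinely doubled letter to survive the merge.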

In [None]:
# # Fon vocabulary

# import json

# # Manual definition of the vocabulary
# vocab_dict = {
#     "[PAD]": 0,
#     "[UNK]": 1,
#     " ": 2,
#     "a": 3, "b": 4, "c": 5, "d": 6, "e": 7, "f": 8, "g": 9, "h": 10,
#     "i": 11, "j": 12, "k": 13, "l": 14, "m": 15, "n": 16, "o": 17, "p": 18,
#     "r": 19, "s": 20, "t": 21, "u": 22, "v": 23, "w": 24, "x": 25, "y": 26, "z": 27,
#     "à": 28, "á": 29, "è": 30, "é": 31, "ì": 32, "í": 33, "î": 34, "ï": 35,
#     "ó": 36, "ù": 37, "ú": 38, "ā": 39, "ă": 40, "ē": 41, "ĕ": 42, "ŏ": 43, "ū": 44, "ŭ": 45,
#     "ɔ": 46, "ɖ": 47, "ò": 48, "ε": 49, "έ": 50, "ɔ̀": 51, "ɔ̆": 52, "ὲ": 53, "ɔ́": 54,
#     "ĭ": 55, "ɛ̆": 56, "ɛ̃": 57, ".": 58, ",": 59
# }

# # Save the vocabulary
# with open("vocab.json", "w", encoding="utf-8") as f:
#     json.dump(vocab_dict, f, ensure_ascii=False, indent=2)

# print("Vocabulary saved to vocab.json")
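Several entries in the vocabulary above, such as `ɔ̀` or `ɛ̃`, are a base letter followed by a combining mark. As far as I can tell, a character-level CTC tokenizer splits text into individual code points, so such entries would not be matched as single tokens. The sketch below uses only the standard `unicodedata` module to show which sequences collapse to one code point under NFC normalization; the helper name `codepoints` is hypothetical.

```python
import unicodedata

def codepoints(token):
    """Return the code points of a token after NFC normalization."""
    normalized = unicodedata.normalize("NFC", token)
    return [f"U+{ord(ch):04X}" for ch in normalized]

# "e" + combining grave has a precomposed form; "ɔ" + grave and "ɛ" + tilde do not
for token in ["e\u0300", "\u0254\u0300", "\u025b\u0303"]:
    print(repr(token), codepoints(token))
```

Under that assumption, it may be safer to list the combining marks themselves as vocabulary entries, or to normalize the transcriptions consistently before building the vocabulary.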

In [None]:
# import os
# import pandas as pd
# import torch
# from torch.utils.data import Dataset, DataLoader
# import torchaudio

# from transformers import (
#     Wav2Vec2FeatureExtractor,
#     Wav2Vec2CTCTokenizer,
#     Wav2Vec2Processor,
#     Wav2Vec2ForCTC,
#     Trainer,
#     TrainingArguments
# )


# # Initialize the tokenizer, feature extractor and processor

# tokenizer = Wav2Vec2CTCTokenizer(
#     "vocab.json",
#     unk_token="[UNK]",
#     pad_token="[PAD]",
#     word_delimiter_token=" "
#   )

# # Feature extractor configuration
# feature_extractor = Wav2Vec2FeatureExtractor(
#     feature_size=1, # mono audio
#     sampling_rate=16000,
#     padding_value=0.0,
#     do_normalize=True,
#     return_attention_mask=True
# )

# # Combine the tokenizer and the feature extractor into a processor
# processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

In [None]:
# # Custom Dataset
# class FonASRDataset(Dataset):
#     """
#     Dataset for Fon ASR.
#     Each sample consists of an audio file and its Fon transcription ('translate' column).
#     """
#     def __init__(self, dataframe, processor, sampling_rate=16000):
#         self.data = dataframe
#         self.processor = processor
#         self.sampling_rate = sampling_rate

#     def __len__(self):
#         return len(self.data)

#     def __getitem__(self, idx):
#         row = self.data.iloc[idx]
#         audio_path = row["filename"]

#         waveform, sr = torchaudio.load(audio_path)
#         if sr != self.sampling_rate:
#             waveform = torchaudio.transforms.Resample(sr, self.sampling_rate)(waveform)

#         # If the signal is multichannel, average the channels to obtain a mono signal.
#         if waveform.ndim == 2:
#             if waveform.size(0) > 1:
#                 waveform = waveform.mean(dim=0)
#             else:
#                 waveform = waveform.squeeze(0)

#         waveform = waveform.contiguous().float()

#         labels = row["translate"].lower().strip()

#         return {"input_values": waveform, "labels": labels}


# # Data collator definition
# def data_collator(batch):
#     # Collect the 1-D audio tensors and pad them manually
#     input_values = [example["input_values"] for example in batch]

#     input_values = torch.nn.utils.rnn.pad_sequence(input_values, batch_first=True, padding_value=0.0)

#     # Tokenize the labels
#     texts = [example["labels"] for example in batch]
#     tokenized_labels = [torch.tensor(processor.tokenizer(text).input_ids, dtype=torch.long) for text in texts]

#     labels = torch.nn.utils.rnn.pad_sequence(
#         tokenized_labels, batch_first=True, padding_value=tokenizer.pad_token_id
#     )
#     labels[labels == tokenizer.pad_token_id] = -100

#     # print(f"Labels in the batch: {labels}")  # Debug

#     return {"input_values": input_values, "labels": labels}

In [None]:
# # Load and adapt the pretrained model
# model_name = "facebook/wav2vec2-base-960h"
# model = Wav2Vec2ForCTC.from_pretrained(model_name,
#                                        gradient_checkpointing=True,
#                                        ctc_loss_reduction="mean")

# # Resize the output layer to match the new tokenizer's vocabulary
# model.config.vocab_size = len(tokenizer)
# model.lm_head = torch.nn.Linear(
#     in_features=model.lm_head.in_features,
#     out_features=len(tokenizer),
#     bias=True
# )

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)

In [None]:
# # Instantiate the Dataset and DataLoader
# train_dataset = FonASRDataset(train_df, processor)
# eval_dataset = FonASRDataset(valid_df, processor)

# # train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=data_collator)
# # eval_loader = DataLoader(eval_dataset, batch_size=2, shuffle=False, collate_fn=data_collator)

In [None]:
# # Add metrics
# !pip -q install jiwer

# from jiwer import wer, cer
# import numpy as np

# def compute_metrics(pred):
#     pred_logits = pred.predictions
#     pred_ids = np.argmax(pred_logits, axis=-1)
#     pred_str = processor.batch_decode(pred_ids)

#     label_ids = pred.label_ids
#     label_ids[label_ids == -100] = tokenizer.pad_token_id
#     label_str = processor.batch_decode(label_ids, group_tokens=False)

#     # Word-level error rate
#     wer_score = wer(label_str, pred_str)

#     # Character-level error rate (the earlier version split on words here,
#     # which only recomputed the WER)
#     cer_score = cer(label_str, pred_str)

#     return {"wer": wer_score, "cer": cer_score}

In [None]:
# # Visual check that the tokenizer works as expected
# sample_text = "É nyɔ́".lower()
# tokens = processor.tokenizer(sample_text).input_ids
# print(f"Text: {sample_text}")
# print(f"Tokens: {tokens}")
# print(f"Tokenizer vocabulary: {processor.tokenizer.get_vocab()}")

In [None]:
# EPOCHS = 20

# # Training arguments configuration
# training_args = TrainingArguments(
#     output_dir="asr_fon_model",
#     eval_strategy="epoch",
#     save_strategy="epoch",
#     learning_rate=3e-4,
#     per_device_train_batch_size=4,
#     per_device_eval_batch_size=4,
#     num_train_epochs=EPOCHS,
#     weight_decay=0.005,
#     save_total_limit=2,
#     fp16=True,
#     run_name="Fon-Wav2Vec2-Training",
#     gradient_checkpointing=True,
#     report_to="wandb",
#     push_to_hub=False
# )

# # Instantiate the Trainer
# trainer = Trainer(
#     model=model,
#     args=training_args,
#     data_collator=data_collator,
#     train_dataset=train_dataset,
#     eval_dataset=eval_dataset,
#     compute_metrics=compute_metrics
# )

# # Launch training
# trainer.train()

In [None]:
# import torch
# import torchaudio
# from transformers import Wav2Vec2Processor

# # Test on one sample
# sample = valid_df.sample(1).iloc[0]
# audio_path = sample["filename"]
# reference_text = sample["translate"]

# # Load the audio
# waveform, sample_rate = torchaudio.load(audio_path)
# if sample_rate != 16000:
#     waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)

# # Convert to mono if needed
# if waveform.ndim == 2:
#     waveform = waveform.mean(dim=0)

# waveform = waveform.float()

# # Prepare the audio for the model
# input_values = processor(waveform, sampling_rate=16000, return_tensors="pt").input_values.to(device)

# model.eval()
# with torch.no_grad():
#     logits = model(input_values).logits

# # Decode the model output into text
# predicted_ids = torch.argmax(logits, dim=-1)
# predicted_text = processor.batch_decode(predicted_ids)[0]

# # Display the result
# print(f"Reference text: {reference_text}")
# print(f"Predicted text: {predicted_text}")