In [1]:
import torch
import torchaudio
from model import CNN_BiLSTM_CTC, BiLSTM_CTC, BiGRU_CTC, LSTM_CTC
import os
import json
from IPython.display import Audio

In [18]:
# General settings
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")  # GPU when available
MODEL_PATH = "best_model_BiGRU_CTC.pt"  # checkpoint file whose weights are loaded below
N_MELS = 80  # number of Mel bands; also used as the model input_dim
SAMPLE_RATE = 16_000  # audio is resampled to this rate before feature extraction
HIDDEN_DIM, NUM_LAYERS, DROPOUT = 256, 3, 0.3  # constructor arguments of the network

In [19]:
# Choose the model class and the vocabulary file from the checkpoint name.
#
# The branches are tested in order and the first match wins, so a marker that
# contains another marker has to come first:
#   * "GRU_CTC2" before "GRU_CTC" (otherwise the vocab4.json branch can never
#     be reached, because every name containing "GRU_CTC2" contains "GRU_CTC");
#   * "CNN" before "LSTM" (a "CNN_BiLSTM_CTC" name also contains "LSTM").
# Only the file name is inspected, so characters in a directory name cannot
# select the wrong model.
# NOTE(review): "1" is a very loose marker; confirm no other checkpoint name
# contains that digit.
checkpoint_name = os.path.basename(MODEL_PATH)
if "1" in checkpoint_name:
    model_class = BiLSTM_CTC
    vocab_path = "vocab1.json"
elif "GRU_CTC2" in checkpoint_name:
    model_class = BiGRU_CTC
    vocab_path = "vocab4.json"
elif "GRU_CTC" in checkpoint_name:
    model_class = BiGRU_CTC
    vocab_path = "vocab2.json"
elif "CNN" in checkpoint_name:
    model_class = CNN_BiLSTM_CTC
    vocab_path = "vocab.json"
elif "LSTM" in checkpoint_name:
    model_class = LSTM_CTC
    vocab_path = "vocab3.json"
else:
    # Default architecture when no marker is recognised.
    model_class = CNN_BiLSTM_CTC
    vocab_path = "vocab.json"


In [20]:
# Detect which model to load.
#
# Each known checkpoint file name maps to the class that must be instantiated
# and to the vocabulary file to read. The classes are the ones imported from
# `model` in the first cell, so they are not imported again here.
# Entries are tried in insertion order and matched as substrings of
# MODEL_PATH, which keeps paths such as "some_dir/best_model1.pt" working.
_KNOWN_CHECKPOINTS = {
    "best_model1.pt": (BiLSTM_CTC, "vocab1.json"),
    "best_model_BiGRU_CTC.pt": (BiGRU_CTC, "vocab2.json"),
    "best_model_BiGRU_CTC2.pt": (BiGRU_CTC, "vocab4.json"),
    "best_model_LSTM_CTC.pt": (LSTM_CTC, "vocab3.json"),
}

# Any other checkpoint falls back to the CNN + BiLSTM model and vocab.json.
model_class, vocab_path = next(
    (
        entry
        for file_name, entry in _KNOWN_CHECKPOINTS.items()
        if file_name in MODEL_PATH
    ),
    (CNN_BiLSTM_CTC, "vocab.json"),
)
    

# Load the vocabulary: `vocab` maps each token to its integer index and
# `vocab_inv` is the reverse mapping (index -> token) used when decoding.
with open(vocab_path, "r", encoding="utf-8") as vocab_file:
    vocab = {token: int(index) for token, index in json.load(vocab_file).items()}
vocab_inv = dict((index, token) for token, index in vocab.items())

# Instantiate the selected architecture; the output layer is sized from the
# number of vocabulary entries.
model = model_class(
    input_dim=N_MELS,
    hidden_dim=HIDDEN_DIM,
    vocab_size=len(vocab),
    num_layers=NUM_LAYERS,
    dropout=DROPOUT,
)

# Load the trained weights into the freshly built model.
# NOTE(review): torch.load unpickles the file; only use checkpoints from a
# trusted source.
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)  # nn.Module.to moves the module in place
model.eval()


BiGRU_CTC(
  (gru): GRU(80, 256, num_layers=3, batch_first=True, dropout=0.3, bidirectional=True)
  (classifier): Linear(in_features=512, out_features=315, bias=True)
)

In [21]:
# Audio pre-processing
def preprocess_audio(file_path):
    """Load an audio file and turn it into a dB-scaled Mel spectrogram.

    The signal is averaged down to a single channel when it has several,
    resampled to SAMPLE_RATE when its own rate differs, converted to a Mel
    spectrogram with N_MELS bands (n_fft=1024, hop_length=256) and passed
    through AmplitudeToDB.

    Args:
        file_path: path of the audio file, forwarded to ``torchaudio.load``.

    Returns:
        The spectrogram after ``squeeze(0).transpose(0, 1)``, i.e. with the
        Mel-band axis last.

    Note:
        A function with the same name is defined again in a later cell and
        replaces this one once that cell has been executed.
    """
    signal, native_rate = torchaudio.load(file_path)

    # Down-mix multi-channel recordings to mono.
    if signal.shape[0] > 1:
        signal = signal.mean(dim=0, keepdim=True)

    # Bring the signal to the expected sampling rate.
    if native_rate != SAMPLE_RATE:
        resample = torchaudio.transforms.Resample(native_rate, SAMPLE_RATE)
        signal = resample(signal)

    # Mel spectrogram followed by conversion to decibels.
    to_mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_fft=1024,
        hop_length=256,
        n_mels=N_MELS
    )
    to_db = torchaudio.transforms.AmplitudeToDB()
    log_mel = to_db(to_mel(signal))

    return log_mel.squeeze(0).transpose(0, 1)

def predict_cleaned(audio_path, model):
    """Return the prediction for ``audio_path`` without the tone annotations.

    The raw output of ``predict`` is split on whitespace. For every token only
    the part before the first "|" is kept (the whole token when it has no
    "|"), tokens whose kept part is "∅" are dropped, and the remaining parts
    are concatenated without any separator.

    Args:
        audio_path: path of the audio file, forwarded to ``predict``.
        model: model forwarded to ``predict``.

    Returns:
        The concatenated string described above.
    """
    raw_tokens = predict(audio_path, model).split()
    # partition("|")[0] yields the text before the first "|", or the full
    # token when there is no "|".
    bases = (raw_token.partition("|")[0] for raw_token in raw_tokens)
    return "".join(base for base in bases if base != "∅")


In [22]:
def preprocess_audio(file_path, sample_rate=SAMPLE_RATE, n_mels=N_MELS):
    """Load an audio file and compute its dB-scaled Mel spectrogram.

    Args:
        file_path: path of the audio file, forwarded to ``torchaudio.load``.
        sample_rate: rate the signal is resampled to when its own rate differs.
        n_mels: number of Mel bands of the spectrogram.

    Returns:
        The spectrogram after ``squeeze(0).transpose(0, 1)``, whose second
        dimension equals ``n_mels``; or ``None`` when anything goes wrong, in
        which case an error message is printed instead of raising.
    """
    try:
        signal, source_rate = torchaudio.load(file_path)

        # Average the channels when the recording is not mono.
        signal = signal.mean(dim=0, keepdim=True) if signal.shape[0] > 1 else signal

        # Resample only when the file's rate differs from the requested one.
        if source_rate != sample_rate:
            signal = torchaudio.transforms.Resample(
                orig_freq=source_rate, new_freq=sample_rate
            )(signal)

        # Mel spectrogram, then amplitude -> decibels, then Mel axis last.
        to_mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=1024,
            hop_length=256,
            n_mels=n_mels
        )
        to_db = torchaudio.transforms.AmplitudeToDB()
        features = to_db(to_mel(signal)).squeeze(0).transpose(0, 1)

        # Sanity check on the feature dimension before handing it to a model.
        if features.shape[1] == n_mels:
            return features
        raise ValueError(f"Mel spectrogram avec shape inattendu : {features.shape}")

    except Exception as e:
        # Best effort: report the problem and let the caller handle None.
        print(f"[ERREUR] Impossible de traiter {file_path} : {e}")
        return None


In [23]:
# === Prediction function ===
def predict(file_path, model, blank_id=0):
    """Transcribe one audio file with greedy (best-path) CTC decoding.

    Args:
        file_path: path of the audio file, forwarded to ``preprocess_audio``.
        model: network called on a batch containing the single spectrogram;
            it is switched to evaluation mode first.
        blank_id: index of the CTC blank symbol, removed from the output.
            Defaults to 0, the value that used to be hard-coded.

    Returns:
        The decoded tokens joined by single spaces. An empty string is
        returned when ``preprocess_audio`` yields ``None`` (the file could not
        be processed) instead of failing on ``None.unsqueeze``.
    """
    model.eval()
    mel_spec = preprocess_audio(file_path)
    if mel_spec is None:
        # preprocess_audio may report a failure by returning None.
        return ""
    input_tensor = mel_spec.unsqueeze(0).to(DEVICE)  # add the batch dimension

    with torch.no_grad():
        logits = model(input_tensor)
        # softmax is monotonic, so argmax over the raw scores selects the same
        # index; a single tolist() replaces the repeated .item() calls.
        best_path = logits.argmax(dim=2)[0].tolist()

    # Collapse consecutive repeats, then drop the blank symbol.
    output_tokens = []
    previous = None
    for index in best_path:
        if index != previous and index != blank_id:
            output_tokens.append(index)
        previous = index

    # Indices missing from the vocabulary decode to an empty string.
    return " ".join(vocab_inv.get(index, "") for index in output_tokens)


In [24]:
def play_audio(file_path):
    """Build an IPython audio player for the given file.

    Args:
        file_path: path of the audio file, forwarded to ``torchaudio.load``.

    Returns:
        An ``IPython.display.Audio`` object created from the loaded samples
        and the file's own sampling rate.
    """
    samples, rate = torchaudio.load(file_path)
    player = Audio(samples.numpy(), rate=rate)
    return player

# === Test on one audio file ===
if __name__ == "__main__":
    # NOTE(review): absolute, machine-specific path; adjust it to the local
    # copy of the dataset before running.
    audio_file = r"C:\Users\Christian\Desktop\YembaTones\YembaTones An Annotated Dataset for Tonal and Syllabic Analysis of the Yemba Language\Yemba_Dataset\audios\speaker_6\group_8\spkr_6_group_8_statement_1.wav"

    if not os.path.exists(audio_file):
        print("Fichier introuvable :", audio_file)
    else:
        print("Fichier trouvé :", audio_file)

        # Load the file to report its duration and sampling rate.
        waveform, sr = torchaudio.load(audio_file)
        print(f"Durée : {waveform.shape[1] / sr:.2f} s | Fréquence d'échantillonnage : {sr} Hz")

        # Playback relies on `display`, which only exists inside
        # IPython/Jupyter; `Audio` is already imported in the first cell.
        try:
            display(Audio(waveform.numpy(), rate=sr))
        except (ImportError, NameError):
            print("Lecture audio indisponible en dehors de Jupyter.")
        except Exception as exc:
            # Any other playback problem must not prevent the transcription,
            # but it is reported instead of being silently swallowed.
            print(f"Lecture audio impossible : {exc}")

        # Run the prediction: once with the tone annotations, once cleaned.
        raw = predict(audio_file, model)
        transcription = predict_cleaned(audio_file, model)
        print("\nTranscription prédite :", transcription)
        print("\nTranscription (avec tonalité) :", raw)


Fichier trouvé : C:\Users\Christian\Desktop\YembaTones\YembaTones An Annotated Dataset for Tonal and Syllabic Analysis of the Yemba Language\Yemba_Dataset\audios\speaker_6\group_8\spkr_6_group_8_statement_1.wav
Durée : 0.83 s | Fréquence d'échantillonnage : 44100 Hz



Transcription prédite : Lekyɛ̄t

Transcription (avec tonalité) : Le|bas kyɛ̄t|moyen ∅|∅
