In [1]:
import itertools
import json
import os

import torch
import torchaudio
from IPython.display import Audio

from model import GRUSeq2Seq


In [2]:
# === Parameters ===
MODEL_PATH = "best_model_GRUseq2seq.pt"  # state_dict checkpoint for GRUSeq2Seq
# JSON vocabulary file (token -> id). It must be the vocabulary the checkpoint
# was trained with. The vocabulary cell loads "vocab4.json", and the checkpoint
# loads with it, so this constant points to that same file. It previously named
# "vocab_GRU.json", which nothing in the notebook read.
VOCAB_PATH = "vocab4.json"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SAMPLE_RATE = 16000  # audio is resampled to this rate (Hz) before feature extraction
N_MELS = 80          # mel bands per frame; also the model's input_dim
HIDDEN_DIM = 256     # GRU hidden size
ENC_LAYERS = 2       # number of encoder GRU layers
DEC_LAYERS = 2       # number of decoder GRU layers

In [3]:
# === Load the vocabulary for the GRUSeq2Seq model ===
# NOTE(review): the configuration cell also defines VOCAB_PATH; keep both
# pointing to the same file. This is the file the checkpoint is loaded with.
vocab_path = "vocab4.json"
with open(vocab_path, "r", encoding="utf-8") as f:
    vocab = json.load(f)

# Normalize ids to int (JSON values may be stored as strings or numbers),
# then build the inverse id -> token table used for decoding.
vocab = {k: int(v) for k, v in vocab.items()}
vocab_inv = {v: k for k, v in vocab.items()}

# If two tokens share an id, vocab_inv silently keeps only one of them and
# every prediction of that id would decode to the wrong token.
if len(vocab_inv) != len(vocab):
    raise ValueError(
        f"Vocabulaire invalide ({vocab_path}) : "
        f"{len(vocab) - len(vocab_inv)} identifiant(s) de token en double"
    )

# Index of the padding token (falls back to 0 when "<pad>" is absent).
PAD_IDX = vocab.get("<pad>", 0)


In [4]:
# === Load the model ===
# The hyper-parameters must match the checkpoint; otherwise load_state_dict
# raises on missing, unexpected or mis-shaped parameters.
model = GRUSeq2Seq(input_dim=N_MELS, hidden_dim=HIDDEN_DIM,
                   vocab_size=len(vocab),
                   encoder_layers=ENC_LAYERS,
                   decoder_layers=DEC_LAYERS).to(DEVICE)

# The checkpoint is a plain state_dict (it goes straight into load_state_dict).
# weights_only=True limits unpickling to tensors and basic containers, so a
# tampered .pt file cannot execute arbitrary code when it is loaded.
state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)
# Switch to inference mode (disables dropout); the repr is shown as cell output.
model.eval()

GRUSeq2Seq(
  (bridge): Linear(in_features=512, out_features=256, bias=True)
  (encoder): GRUEncoder(
    (gru): GRU(80, 256, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)
  )
  (attention): Attention(
    (attn): Linear(in_features=768, out_features=1, bias=True)
  )
  (decoder): GRUDecoder(
    (gru): GRU(512, 256, num_layers=2, batch_first=True, dropout=0.3)
    (out): Linear(in_features=256, out_features=315, bias=True)
  )
)

In [5]:
def preprocess_audio(file_path, sample_rate=SAMPLE_RATE, n_mels=N_MELS,
                     n_fft=1024, hop_length=256):
    """Load an audio file and convert it to a log-mel spectrogram.

    Multichannel audio is downmixed to mono by averaging channels. It is then
    resampled to ``sample_rate`` when needed, turned into a mel spectrogram,
    and converted to decibels with ``AmplitudeToDB``.

    Args:
        file_path: path to an audio file readable by ``torchaudio.load``.
        sample_rate: target sampling rate in Hz.
        n_mels: number of mel bands per frame.
        n_fft: STFT size. The default is the value previously hard-coded here.
            NOTE(review): it must match the features used at training time.
        hop_length: number of samples between successive STFT frames
            (the default is the previously hard-coded value).

    Returns:
        A tensor of shape ``(n_frames, n_mels)``.

    Raises:
        ValueError: if the clip is too short to be framed, or if the
            spectrogram does not have ``n_mels`` features per frame.
    """
    waveform, sr = torchaudio.load(file_path)

    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    if sr != sample_rate:
        waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)(waveform)

    # MelSpectrogram centers frames with reflect padding of n_fft // 2 samples,
    # which fails with an obscure error unless the signal is longer than that.
    # Fail early with a clear message instead.
    num_samples = waveform.shape[-1]
    if num_samples <= n_fft // 2:
        raise ValueError(
            f"Audio trop court ({num_samples} échantillons) : {file_path}"
        )

    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels
    )
    mel_spec = mel_transform(waveform)
    mel_spec = torchaudio.transforms.AmplitudeToDB()(mel_spec)
    # (1, n_mels, n_frames) -> (n_frames, n_mels): one feature vector per frame.
    mel_spec = mel_spec.squeeze(0).transpose(0, 1)

    if mel_spec.shape[1] != n_mels:
        raise ValueError(f"Mel spectrogram inattendu : {mel_spec.shape}")

    return mel_spec


In [6]:
def predict(file_path, model, blank_idx=0):
    """Transcribe an audio file with greedy (best-path) decoding.

    The file is converted to a log-mel spectrogram by ``preprocess_audio`` and
    run through ``model`` as a batch of one. The most likely id is taken at each
    output step. Runs of identical ids are collapsed and ``blank_idx`` is
    dropped (CTC-style best-path decoding). The remaining ids are mapped back
    to tokens through the global ``vocab_inv``.

    Args:
        file_path: path of the audio file to transcribe.
        model: the network; it is switched to eval mode in place.
        blank_idx: id treated as blank and removed from the output. The
            default of 0 is the value previously hard-coded.
            NOTE(review): PAD_IDX is derived from "<pad>" in the vocab cell;
            confirm which id the model was trained to emit as blank.

    Returns:
        The decoded tokens joined by single spaces. Ids missing from
        ``vocab_inv`` decode to an empty string.
    """
    model.eval()
    mel_spec = preprocess_audio(file_path)
    # Add a batch dimension of 1 and move to the model's device.
    input_tensor = mel_spec.unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        # NOTE(review): assumes logits are (batch, steps, vocab), i.e. axis 2 is
        # the class axis reduced below; confirm against GRUSeq2Seq.forward.
        logits = model(input_tensor)

    # softmax is monotonic, so argmax over the raw logits picks the same class.
    # .tolist() copies to the host once, instead of one device sync per .item().
    best_path = logits.argmax(dim=2)[0].tolist()

    # Collapse consecutive repeats first, then drop blanks. A token repeated
    # across a blank (e.g. [a, blank, a]) is therefore still emitted twice.
    output_tokens = [idx for idx, _ in itertools.groupby(best_path) if idx != blank_idx]

    decoded = [vocab_inv.get(idx, "") for idx in output_tokens]
    return " ".join(decoded)


In [7]:
def predict_cleaned(file_path, model):
    """Transcribe ``file_path`` and keep only the syllables, concatenated.

    Starting from the space-separated output of ``predict``:
      * tokens starting with "∅" are discarded;
      * the tone label after the first "|" is removed
        (e.g. "fū|moyen" -> "fū");
      * straight apostrophes are deleted.
    The remaining syllables are joined with no separator.
    """
    def _syllable_only(token):
        # Keep what precedes the first "|" (drops the tone) and remove apostrophes.
        head, _sep, _tone = token.partition("|")
        return head.replace("'", "")

    relevant_tokens = (
        token
        for token in predict(file_path, model).split()
        if token.strip() and not token.startswith("∅")
    )
    return "".join(map(_syllable_only, relevant_tokens))


In [8]:
def play_audio(file_path):
    """Return an IPython ``Audio`` player for ``file_path``.

    The file is read with ``torchaudio.load`` and played back at its own
    sampling rate (no resampling is applied).
    """
    signal, native_rate = torchaudio.load(file_path)
    player = Audio(signal.numpy(), rate=native_rate)
    return player

In [9]:
if __name__ == "__main__":
    # Demo on a single clip. The default is a machine-specific absolute path;
    # set the YEMBA_AUDIO_FILE environment variable to use another file
    # without editing the notebook.
    audio_file = os.environ.get(
        "YEMBA_AUDIO_FILE",
        r"C:\Users\Christian\Desktop\YembaTones\YembaTones An Annotated Dataset for Tonal and Syllabic Analysis of the Yemba Language\Yemba_Dataset\audios\speaker_5\group_79\spkr_5_group_79_statement_3.wav",
    )

    if not os.path.exists(audio_file):
        print("Fichier audio introuvable.")
    else:
        print(f"Fichier trouvé : {audio_file}")
        waveform, sr = torchaudio.load(audio_file)
        print(f"Durée : {waveform.shape[1] / sr:.2f} s | Fréquence : {sr} Hz")

        try:
            # `display` is provided by the IPython kernel and is undefined when
            # this code runs as a plain script. Catch only that case, so real
            # errors and Ctrl-C are no longer swallowed.
            display(Audio(waveform.numpy(), rate=sr))
        except NameError:
            print("Lecture non disponible (hors notebook)")

        print("\nTranscription brute :", predict(audio_file, model))
        print("\nTranscription nettoyée :", predict_cleaned(audio_file, model))

Fichier trouvé : C:\Users\Christian\Desktop\YembaTones\YembaTones An Annotated Dataset for Tonal and Syllabic Analysis of the Yemba Language\Yemba_Dataset\audios\speaker_5\group_79\spkr_5_group_79_statement_3.wav
Durée : 0.58 s | Fréquence : 44100 Hz



Transcription brute : N|haut tí |haut ∅|∅

Transcription nettoyée : Ntí
