In [75]:
import torch
# Report the installed PyTorch version and whether CUDA is usable, one value per line.
print(torch.__version__, torch.cuda.is_available(), sep="\n")


2.2.2
False


In [76]:
from transformers import AutoProcessor, AutoModelForCTC, Wav2Vec2Processor
import librosa
import torch
from itertools import groupby
from datasets import load_dataset

def decode_phonemes(
    ids: torch.Tensor, processor: "Wav2Vec2Processor", ignore_stress: bool = False
) -> str:
    """CTC-like decoding of a 1-D sequence of predicted token ids.

    Consecutive duplicate ids are collapsed first, then special tokens (and the
    word delimiter, when the tokenizer defines one) are dropped, and the
    remaining ids are decoded one by one and joined with single spaces.

    Args:
        ids: 1-D sequence of token ids. Anything exposing ``tolist()`` (for
            example a ``torch.Tensor``) or a plain iterable of ints is accepted.
        processor: Object exposing ``tokenizer.all_special_ids``,
            ``tokenizer.word_delimiter_token_id`` and ``decode(id)``.
        ignore_stress: When true, the IPA stress marks "ˈ" and "ˌ" are removed
            from every phoneme; phonemes that become empty are dropped.

    Returns:
        The decoded phonemes separated by single spaces ("" for empty input).
    """
    # Work on plain Python ints so equality and hashing do not depend on tensor
    # semantics (tensors hash by identity, which would defeat set membership).
    if hasattr(ids, "tolist"):
        ids = ids.tolist()

    # removes consecutive duplicates
    collapsed_ids = [int(token_id) for token_id, _ in groupby(ids)]

    tokenizer = processor.tokenizer
    skipped_ids = set(tokenizer.all_special_ids)
    delimiter_id = tokenizer.word_delimiter_token_id
    # Some tokenizers have no word delimiter; do not add None to the skip set.
    if delimiter_id is not None:
        skipped_ids.add(delimiter_id)

    # converts id to token, skipping special tokens
    phonemes = [
        processor.decode(token_id)
        for token_id in collapsed_ids
        if token_id not in skipped_ids
    ]

    # whether to ignore IPA stress marks
    if ignore_stress:
        # Strip per phoneme (not on the joined string) so a token that consists
        # only of a stress mark cannot leave a doubled space behind.
        stripped = (p.replace("ˈ", "").replace("ˌ", "") for p in phonemes)
        phonemes = [p for p in stripped if p]

    # joins phonemes
    return " ".join(phonemes)

checkpoint = "bookbot/wav2vec2-ljspeech-gruut"

# Load the CTC model and its matching processor from the same checkpoint.
model = AutoModelForCTC.from_pretrained(checkpoint)
processor = AutoProcessor.from_pretrained(checkpoint)
# Sampling rate the feature extractor was configured with; audio is resampled to it on load.
sr = processor.feature_extractor.sampling_rate

# # load dummy dataset and read soundfiles
# ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
# audio_array = ds[0]["audio"]["array"]

# or, read a single audio file
audio_array, _ = librosa.load("data/test/wight.wav", sr=sr)

# Pass sampling_rate explicitly: omitting it triggers the "strongly recommended"
# warning and lets a rate mismatch between audio and model go unnoticed.
inputs = processor(audio_array, sampling_rate=sr, return_tensors="pt", padding=True)

# Inference only, so skip gradient tracking.
with torch.no_grad():
    logits = model(inputs["input_values"]).logits

# Greedy CTC decoding: most likely token per frame, then collapse/filter in decode_phonemes.
predicted_ids = torch.argmax(logits, dim=-1)
prediction = decode_phonemes(predicted_ids[0], processor, ignore_stress=True)
# NOTE(review): the expected output below presumably belongs to the dummy-dataset
# sample (commented out above), not to data/test/wight.wav — confirm.
# => should give 'b ɪ k ʌ z j u ɚ z s l i p ɪ ŋ ɪ n s t ɛ d ə v k ɔ ŋ k ɚ ɪ ŋ ð ə l ʌ v l i ɹ z p ɹ ɪ n s ə s h æ z b ɪ k ʌ m ə v f ɪ t ə l w ɪ θ n b oʊ p ɹ ə ʃ æ ɡ i s ɪ t s ð ɛ ɹ ə k u ɪ ŋ d ʌ v'

print(prediction)

It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


w aɪ t


In [77]:
# word bank: target word -> reference pronunciation as space-separated IPA phonemes
phoneme_bank = {
    "run": "ɹ ʌ n",
    "rope": "ɹ ə ʊ p",
    "rabbit": "ˈɹ æ b ɪ t",
    # key had a trailing space ("rocket ") and the entry used "r" where every
    # other entry uses "ɹ"; both normalized for consistency
    "rocket": "ˈɹ ɒ k ɪ t",
    "carrot": "ˈk æ ɹ ə t",
    "berry": "ˈb ɛ ɹ i",
    "pirate": "ˈp a ɪ ɹ ə t",
    "airplane": "ˈe ə p l e ɪ n",
    "car": "k ɑː",
    "door": "d ɔː",
    "tiger": "ˈt aɪ ɡ ə",
    "hammer": "ˈh æ m ə",
    "right": "ɹ aɪ t",
}

# NOTE(review): diphthongs are written as one token in some entries ("aɪ") and as
# two in others ("ə ʊ", "a ɪ", "e ɪ") — confirm which form the recognizer emits.

# splitting entire strings into individual phonemes for comparison later.
# Stress marks are stripped here because the prediction is decoded with
# ignore_stress=True; keeping them would make stressed words unmatchable.
phoneme_bank_split = {
    word: phonemes.replace("ˈ", "").replace("ˌ", "").split()
    for word, phonemes in phoneme_bank.items()
}

# Print the new phoneme bank to check the format
# print(phoneme_bank_split)

In [78]:
import Levenshtein 

def _phoneme_edit_distance(source, target):
    """Levenshtein distance between two phoneme lists, counting whole phonemes."""
    previous_row = list(range(len(target) + 1))
    for i, source_phoneme in enumerate(source, start=1):
        current_row = [i]
        for j, target_phoneme in enumerate(target, start=1):
            substitution = previous_row[j - 1] + (source_phoneme != target_phoneme)
            current_row.append(
                min(previous_row[j] + 1, current_row[j - 1] + 1, substitution)
            )
        previous_row = current_row
    return previous_row[-1]


def find_most_similar_word(prediction, phoneme_bank_split):
    """Find the bank word whose phoneme sequence is closest to the prediction.

    The distance is measured in phoneme edits (insert / delete / substitute one
    phoneme), not in characters, so a multi-character phoneme such as "aɪ" or
    "ɑː" costs the same as a single-character one and separators cost nothing.

    Args:
        prediction: Space-separated phoneme string.
        phoneme_bank_split: Mapping of word -> list of phonemes.

    Returns:
        ``(word, distance, prediction_phonemes)``. On ties the first word in the
        mapping's iteration order wins. For an empty bank the result is
        ``(None, float('inf'), prediction_phonemes)``.
    """
    prediction_phonemes = prediction.split()
    most_similar_word = None
    min_distance = float('inf')

    for word, phonemes in phoneme_bank_split.items():
        distance = _phoneme_edit_distance(prediction_phonemes, phonemes)
        # Strict "<" keeps the earliest candidate when distances tie.
        if distance < min_distance:
            min_distance = distance
            most_similar_word = word

    return most_similar_word, min_distance, prediction_phonemes

# Look up the bank entry closest to the decoded prediction and report it.
word, distance, prediction_phonemes = find_most_similar_word(
    prediction=prediction,
    phoneme_bank_split=phoneme_bank_split,
)
print(f"The most similar word is '{word}' with a distance of {distance}")

The most similar word is 'right' with a distance of 1


In [79]:
from difflib import SequenceMatcher

def align_and_compare(prediction, correct):
    """Align predicted phonemes against the reference and report the differences.

    The two sequences are aligned with ``difflib.SequenceMatcher``. Every
    non-matching span contributes its predicted phonemes to the "extra" list
    and its reference phonemes to the "missing" list. The aligned rows are
    printed with "-" marking a gap, and both rows always have the same number
    of entries so that columns line up position by position.

    Args:
        prediction: List of predicted phonemes.
        correct: List of reference phonemes.

    Returns:
        ``(extra_phonemes, missing_phonemes)`` as two lists.
    """
    # autojunk=False: never let the popularity heuristic discard phonemes.
    matcher = SequenceMatcher(None, prediction, correct, autojunk=False)
    extra_phonemes = []
    missing_phonemes = []
    aligned_prediction = []
    aligned_correct = []

    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        predicted_span = list(prediction[i1:i2])
        correct_span = list(correct[j1:j2])

        if tag != "equal":
            # "replace": both spans are non-empty (substitution).
            # "delete": only the prediction has phonemes here (extra).
            # "insert": only the reference has phonemes here (missing).
            extra_phonemes.extend(predicted_span)
            missing_phonemes.extend(correct_span)

        # Pad the shorter side with gaps. This covers delete/insert and also
        # "replace" spans of unequal length, which would otherwise shift every
        # later column of the alignment.
        width = max(len(predicted_span), len(correct_span))
        aligned_prediction.extend(predicted_span + ["-"] * (width - len(predicted_span)))
        aligned_correct.extend(correct_span + ["-"] * (width - len(correct_span)))

    print("Aligned Prediction: ", " ".join(aligned_prediction))
    print("Aligned Correct:    ", " ".join(aligned_correct))
    print("Extra Phonemes:     ", " ".join(extra_phonemes))
    print("Missing Phonemes:   ", " ".join(missing_phonemes))

    return extra_phonemes, missing_phonemes

# Reference phonemes of the best-matching word, then align the prediction against them.
correct_phonemes = phoneme_bank_split[word]
extra_phonemes, target_phonemes = align_and_compare(
    prediction=prediction_phonemes,
    correct=correct_phonemes,
)

Aligned Prediction:  w aɪ t
Aligned Correct:     ɹ aɪ t
Extra Phonemes:      w
Missing Phonemes:    ɹ


In [80]:
# feedback bank: phoneme produced by mistake -> articulation hint shown to the speaker
feedback_bank = {
    "w": (
        "Try starting with closed lips. "
        "Then lift your lips just slightly apart without making an 'o' shape. "
        "Focus on keeping your teeth close together with your tongue tip slightly hovering."
    ),
    "l": (
        "Try lowering the tip of your tongue and keeping it stationary. "
        "Make the sound only moving your jaw"
    ),
    "ʌ": "Raise the back of your tongue and produce more tension in the throat for /r/.",
    # Add more phoneme-specific feedback as needed
}

def generate_feedback(extra_phonemes, bank=None):
    """Build one feedback message per extra phoneme.

    Args:
        extra_phonemes: Iterable of phonemes that were produced but not expected.
        bank: Optional mapping of phoneme -> feedback text. When omitted, the
            module-level ``feedback_bank`` is used.

    Returns:
        A list of messages in the same order as ``extra_phonemes``; phonemes
        without an entry in the bank get a generic "no feedback" message.
    """
    if bank is None:
        bank = feedback_bank

    # Feedback for extra phonemes
    return [
        f"Extra phoneme '{phoneme}': {bank.get(phoneme, 'No specific feedback available.')}"
        for phoneme in extra_phonemes
    ]

# Show the feedback messages for the phonemes that should not have been produced.
generate_feedback(extra_phonemes=extra_phonemes)


["Extra phoneme 'w': Try starting with closed lips. Then lift your lips just slightly apart without making an 'o' shape. Focus on keeping your teeth close together with your tongue tip slightly hovering."]

In [81]:
# # ACCESS A SPECIFIC WORD IN BANK AND MATCH ENTIRE STRING OF BOTH (NO NUANCES)
# word = "right"
# phonemes = phoneme_bank[word]

# # Print the phoneme sequence
# print(f"The phonemes for '{word}' are: {phonemes}")

# if prediction == phonemes:
#     print("The prediction matches the word bank")
# else:
#     print("The prediction does NOT match the word")




# # HANDLE CASES WHERE LENGTHS DIFFER BUT ONLY BASED ON LENGTH (ONLY ENDING PHONEMES CONSIDERED, NO NUANCES)
# if len(prediction_phonemes) != len(correct_phonemes):
#     print("Prediction and target phonemes have different lengths.")
#     extra_phonemes = prediction_phonemes[len(correct_phonemes):]
#     if extra_phonemes:
#         print(f"Extra phonemes in prediction: {' '.join(extra_phonemes)}")
#     missing_phonemes = correct_phonemes[len(prediction_phonemes):]
#     if missing_phonemes:
#         print(f"Missing phonemes in prediction: {' '.join(missing_phonemes)}")




# # CHECKING FOR EXTRA / MISSING PHONEMES WHEN LENGTHS DIFFER (ANY POSITION MISMATCH) BUT DOESN'T CONSIDER ALIGNMENT OF CORRECT PHONEMES AFTER MISMATCH
# # Check for phonemes that are extra or missing within the sequence
# extra_phonemes = []
# missing_phonemes = []

# # Iterate over both lists up to the length of the shorter one to find mismatches
# for i in range(min(len(prediction_phonemes), len(correct_phonemes))):
#     if prediction_phonemes[i] != correct_phonemes[i]:
#         extra_phonemes.append(prediction_phonemes[i])
#         missing_phonemes.append(correct_phonemes[i])

# # If one list is longer than the other, add remaining phonemes
# if len(prediction_phonemes) > len(correct_phonemes):
#     extra_phonemes.extend(prediction_phonemes[len(correct_phonemes):])
# elif len(correct_phonemes) > len(prediction_phonemes):
#     missing_phonemes.extend(correct_phonemes[len(prediction_phonemes):])

# # Print results
# if extra_phonemes:
#     print(f"Extra phonemes in prediction: {' '.join(extra_phonemes)}")
# if missing_phonemes:
#     print(f"Missing phonemes in prediction: {' '.join(missing_phonemes)}")




# # FINDING MISMATCHES BETWEEN PREDICTION AND CORRECT PHONEMES (WORKS PERFECTLY FOR STRINGS OF THE SAME LENGTH)
# # Phoneme string for the word in the word bank
# correct_phonemes = phoneme_bank_split[word]

# # Compare phoneme by phoneme
# mismatches = []
# for i, (predicted, correct) in enumerate(zip(prediction_phonemes, correct_phonemes)):
#     if predicted != correct:
#         mismatches.append((i, predicted, correct))

# # Report mismatches
# if not mismatches:
#     print("The pronunciation matches exactly!")
# else:
#     print("Mismatch found!")
#     for i, predicted, correct in mismatches:
#         print(f"Position {i}: Predicted '{predicted}', Expected '{correct}'")