# Forced Alignment of Phonemes
Author: Oscar Friedman



This notebook performs forced alignment of English phonemes. Forced alignment requires an input speech or singing audio file and corresponding text transcript, and outputs start and end timestamps corresponding to each phoneme.

We will compare 3 approaches.

In [1]:
# first time only:
# !pip install -r requirements.txt

import torch
import torchaudio
from torchaudio.datasets import CMUDict
import torchaudio.transforms as T
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from dataclasses import dataclass
import IPython
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
import json
import re
import os
import string
import copy
import IPython
import tgt

from transformers import Wav2Vec2CTCTokenizer
from transformers import Wav2Vec2FeatureExtractor
from transformers import Wav2Vec2Processor
from transformers import Wav2Vec2ForCTC

In [2]:
def load_audio(audio_file):
    root, _ = os.path.splitext(audio_file)
    output_file = f"{root}.wav"
    os.system(f'ffmpeg -y -i {audio_file} -acodec pcm_s16le -ac 1 -ar 16000 {output_file}') # Convert test file to correct format for wav2vec2 (16khz mono wav 16-bit)
    audio_file = output_file
    waveform, sample_rate = torchaudio.load(audio_file)
    return waveform, sample_rate

def load_transcript(transcript_file):
    with open(transcript_file, "r") as file:
        transcript = file.read()
    return transcript

def format_transcript(transcript): # Remove punctuation, convert each letter to uppercase and join words with '|'
    import string
    translator = str.maketrans('', '', string.punctuation)
    cleaned_string = transcript.translate(translator)
    words = cleaned_string.split()   
    result_string = '|'.join(word.upper() for word in words)
    return result_string

def save_phonemes_JSON(phonemes, output_file):
    phoneme_segments = []
    for phoneme in phonemes:
        phoneme_segments.append({'phoneme': phoneme.label, 'start': phoneme.start, 'end': phoneme.end})
    with open(output_file, 'w') as outfile:
        json.dump(phoneme_segments, outfile)

def display_segment(segments,i, waveform, sample_rate=16000):
    phoneme = segments[i]
    x0 = int(phoneme.start * sample_rate)
    x1 = int(phoneme.end * sample_rate)
    print(f"{phoneme.label} ({phoneme.score:.2f}): {phoneme.start:.3f} - {phoneme.end:.3f} sec")
    segment = waveform[0][x0:x1]
    return IPython.display.display(IPython.display.Audio(segment, rate=sample_rate))

In [3]:
audio_file = r'C:\Users\oscar\Documents\470\JS_2023\Phoneme_Recognition\input\assessment_9.mp3'
audio_file = audio_file.replace('\\', '/')
waveform, sample_rate = load_audio(audio_file)
IPython.display.display(IPython.display.Audio(waveform, rate=sample_rate))

transcript_file = r'C:\Users\oscar\Documents\470\JS_2023\Phoneme_Recognition\input\assessment_9.txt'
transcript_file = transcript_file.replace('\\', '/')
transcript = load_transcript(transcript_file)
transcript = format_transcript(transcript)
print(f"Transcript: {transcript}")

Transcript: BUT|AFTER|ALL|THAT|COMMOTION|WAS|IT|ALL|WORTHWHILE|ABSOLUTELY|YES|THE|SET|DESIGN|WAS|BREATHTAKING|THE|ACTORS|WERE|INCREDIBLE|AND|THE|SONGS|WERE|MEMORABLE


# Method 1: CTC Segmentation using WAV2VEC2_ASR_BASE_960H
Adapted from https://pytorch.org/audio/stable/tutorials/forced_alignment_tutorial.html

This method uses a pretrained Wav2Vec2 English ASR model. Since the training dataset (Librespeech) contains standard English transcripts, we must first force-align each letter. 
Next, we map sequences of letters to the CMU Sphinx-40 phoneme set to get each phoneme start and end time.

In [4]:
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().to(device)
letters = bundle.get_labels()
print(f"Letters: {letters}") # | is used as a word separator
letters_dictionary = {c: i for i, c in enumerate(letters)} # {letter: index} corresponding to the pretrained wav2vec2 model
tokens = [letters_dictionary[c] for c in transcript] # [index of first letter in transcript, index of second letter in transcript, ...]
print(f"Transcript tokens: {tokens}")

with torch.inference_mode():
    emissions, _ = model(waveform.to(device))
    emissions = torch.log_softmax(emissions, dim=-1) # According to the tutorial, this avoids numerical instability -- should test whether it reduces performance
emission = emissions[0].cpu().detach() # Shape: (Frames, Letters)

Letters: ('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
Transcript tokens: [21, 13, 3, 1, 4, 17, 3, 2, 10, 1, 4, 12, 12, 1, 3, 8, 4, 3, 1, 16, 5, 14, 14, 5, 3, 7, 5, 6, 1, 15, 4, 9, 1, 7, 3, 1, 4, 12, 12, 1, 15, 5, 10, 3, 8, 15, 8, 7, 12, 2, 1, 4, 21, 9, 5, 12, 13, 3, 2, 12, 19, 1, 19, 2, 9, 1, 3, 8, 2, 1, 9, 2, 3, 1, 11, 2, 9, 7, 18, 6, 1, 15, 4, 9, 1, 21, 10, 2, 4, 3, 8, 3, 4, 23, 7, 6, 18, 1, 3, 8, 2, 1, 4, 16, 3, 5, 10, 9, 1, 15, 2, 10, 2, 1, 7, 6, 16, 10, 2, 11, 7, 21, 12, 2, 1, 4, 6, 11, 1, 3, 8, 2, 1, 9, 5, 6, 18, 9, 1, 15, 2, 10, 2, 1, 14, 2, 14, 5, 10, 4, 21, 12, 2]


In [5]:
# We use the CTC algorithm to find the most likely alignment timestamps between the wav2vec2 emission and the ground truth transcript.
# torch.nn.CTCLoss will give us the probability of the ground truth transcript matching the wav2vec2 output, but to get the most likely timestamps, we have to implement the CTC algorithm.

def get_trellis(emission, tokens, blank_id=0): 
    # Trellis is a matrix of (Frames, Transcript Letters) where each element is the probability of the transcript letter occuring at that frame. 
    # The general idea is that wav2vec2 outputs a probability for each letter in the vocab per frame, but we want to align a ground truth transcript who's letters may not perfectly match the maximum probabiliity of the wav2vec2 output at each frame.
    # Traversing the most likely path on the trellis will give us the most likely alignment timestamps between the wav2vec2 output and the ground truth transcript.
    # The following functions are adapted from https://pytorch.org/audio/stable/tutorials/forced_alignment_tutorial.html which has further details. 
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    # Trellis has extra dimensions for both time axis and tokens.
    # The extra dim for tokens represents <SoS> (start-of-sentence)
    # The extra dim for time axis is for simplification of the code.
    trellis = torch.empty((num_frame + 1, num_tokens + 1))
    trellis[0, 0] = 0
    trellis[1:, 0] = torch.cumsum(emission[:, 0], 0)
    trellis[0, -num_tokens:] = -float("inf")
    trellis[-num_tokens:, 0] = float("inf")

    for t in range(num_frame):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token
            trellis[t, :-1] + emission[t, tokens],
        )
    return trellis


@dataclass
class Point:
    token_index: int
    time_index: int
    score: float


def backtrack(trellis, emission, tokens, blank_id=0):
    # Note:
    # j and t are indices for trellis, which has extra dimensions
    # for time and tokens at the beginning.
    # When referring to time frame index `T` in trellis,
    # the corresponding index in emission is `T-1`.
    # Similarly, when referring to token index `J` in trellis,
    # the corresponding index in transcript is `J-1`.
    j = trellis.size(1) - 1
    t_start = torch.argmax(trellis[:, j]).item()

    path = []
    for t in range(t_start, 0, -1):
        # 1. Figure out if the current position was stay or change
        # Note (again):
        # `emission[J-1]` is the emission at time frame `J` of trellis dimension.
        # Score for token staying the same from time frame J-1 to T.
        stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
        # Score for token changing from C-1 at T-1 to J at T.
        changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]

        # 2. Store the path with frame-wise probability.
        prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
        # Return token index and time index in non-trellis coordinate.
        path.append(Point(j - 1, t - 1, prob))

        # 3. Update the token
        if changed > stayed:
            j -= 1
            if j == 0:
                break
    else:
        raise ValueError("Failed to align")
    return path[::-1]

trellis = get_trellis(emission, tokens)

path = backtrack(trellis, emission, tokens)
for p in path:
    print(f"Letter: {transcript[p.token_index]}, Frame: {p.time_index}, Probability: {p.score}")

Letter: B, Frame: 30, Probability: 0.999861478805542
Letter: B, Frame: 31, Probability: 0.9998544454574585
Letter: B, Frame: 32, Probability: 0.9939178228378296
Letter: U, Frame: 33, Probability: 0.999883770942688
Letter: T, Frame: 34, Probability: 0.9998774528503418
Letter: T, Frame: 35, Probability: 0.0041043516248464584
Letter: |, Frame: 36, Probability: 0.9996973276138306
Letter: |, Frame: 37, Probability: 0.9999879598617554
Letter: |, Frame: 38, Probability: 0.9999998807907104
Letter: A, Frame: 39, Probability: 0.9999967813491821
Letter: A, Frame: 40, Probability: 0.009566818363964558
Letter: F, Frame: 41, Probability: 0.9999899864196777
Letter: F, Frame: 42, Probability: 0.9999974966049194
Letter: F, Frame: 43, Probability: 0.043931275606155396
Letter: T, Frame: 44, Probability: 0.9999842643737793
Letter: T, Frame: 45, Probability: 0.9999840259552002
Letter: E, Frame: 46, Probability: 0.9999806880950928
Letter: R, Frame: 47, Probability: 0.9999549388885498
Letter: R, Frame: 48, P

In [6]:
# We merge repeated letters and average the probability scores
@dataclass
class Segment:
    label: str
    start: int
    end: int
    score: float

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"

    @property
    def length(self):
        return self.end - self.start


def merge_repeats(path, transcript):
    i1, i2 = 0, 0
    segments = []
    while i1 < len(path):
        while i2 < len(path) and path[i1].token_index == path[i2].token_index:
            i2 += 1
        score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
        segments.append(
            Segment(
                transcript[path[i1].token_index],
                path[i1].time_index,
                path[i2 - 1].time_index + 1,
                score,
            )
        )
        i1 = i2
    return segments


letter_segments = merge_repeats(path, transcript)
for seg in letter_segments:
    print(f"Letter: {seg.label}, Frame(s): {seg.start} - {seg.end}, Probability: {seg.score}")

Letter: B, Frame(s): 30 - 33, Probability: 0.9978779157002767
Letter: U, Frame(s): 33 - 34, Probability: 0.999883770942688
Letter: T, Frame(s): 34 - 36, Probability: 0.5019909022375941
Letter: |, Frame(s): 36 - 39, Probability: 0.9998950560887655
Letter: A, Frame(s): 39 - 41, Probability: 0.5047817998565733
Letter: F, Frame(s): 41 - 44, Probability: 0.6813062528769175
Letter: T, Frame(s): 44 - 46, Probability: 0.9999841451644897
Letter: E, Frame(s): 46 - 47, Probability: 0.9999806880950928
Letter: R, Frame(s): 47 - 50, Probability: 0.34963943806360476
Letter: |, Frame(s): 50 - 60, Probability: 0.8055566123227663
Letter: A, Frame(s): 60 - 61, Probability: 0.9999972581863403
Letter: L, Frame(s): 61 - 65, Probability: 0.7504934325697832
Letter: L, Frame(s): 65 - 67, Probability: 0.49972673904585463
Letter: |, Frame(s): 67 - 69, Probability: 0.49966502919778577
Letter: T, Frame(s): 69 - 71, Probability: 0.5490692183375359
Letter: H, Frame(s): 71 - 74, Probability: 0.999919056892395
Letter:

In [7]:
# Now, we merge words and average the probability scores
def merge_words(segments, separator="|"):
    words = []
    i1, i2 = 0, 0
    while i1 < len(segments):
        if i2 >= len(segments) or segments[i2].label == separator:
            if i1 != i2:
                segs = segments[i1:i2]
                word = "".join([seg.label for seg in segs])
                score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
                words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
            i1 = i2 + 1
            i2 = i1
        else:
            i2 += 1
    return words


word_segments = merge_words(letter_segments)
for word in word_segments:
    print(f"Word: {word.label}, Frame(s): {word.start} - {word.end}, Probability: {word.score}")

Word: BUT, Frame(s): 30 - 36, Probability: 0.8329165537531177
Word: AFTER, Frame(s): 39 - 50, Probability: 0.6456681500871624
Word: ALL, Frame(s): 60 - 67, Probability: 0.7144892095081689
Word: THAT, Frame(s): 69 - 78, Probability: 0.6820675389677086
Word: COMMOTION, Frame(s): 80 - 106, Probability: 0.9012833536055452
Word: WAS, Frame(s): 144 - 154, Probability: 0.7583912280867026
Word: IT, Frame(s): 156 - 158, Probability: 0.9993602633476257
Word: ALL, Frame(s): 162 - 168, Probability: 0.8379867176214854
Word: WORTHWHILE, Frame(s): 171 - 201, Probability: 0.7357348553647801
Word: ABSOLUTELY, Frame(s): 280 - 316, Probability: 0.7817917605252309
Word: YES, Frame(s): 318 - 331, Probability: 0.9199858714262239
Word: THE, Frame(s): 375 - 380, Probability: 0.6205049332231283
Word: SET, Frame(s): 383 - 394, Probability: 0.836859337054193
Word: DESIGN, Frame(s): 395 - 418, Probability: 0.7975587824653105
Word: WAS, Frame(s): 424 - 431, Probability: 0.8453983834811619
Word: BREATHTAKING, Frame

In [8]:
# If we wanted forced alignment of letters or words, we would be done. However, we want forced alignment of phonemes, so we need to convert the letters to phonemes.
# We make use of the CMU grapheme 2 phoneme alignment dictionary (aligned using Phonetisaurus) hosted here: https://github.com/ckw017/aligned-cmudict/blob/master/g2p.json

# Let's clean up the g2p dictionary a bit and make sure it's using the Sphinx 40 Phoneme set.
g2p_path = "./cmudict/g2p.json"
import json
import re
with open(g2p_path, "r") as f:
    g2p = json.load(f)

to_remove = []

unique_phonemes = set()
for key in g2p:
    if not any("foreign" in element for element in g2p[key]['phonemes']) and not any("#" in element for element in g2p[key]['phonemes']) and not any("old" in element for element in g2p[key]['phonemes']):
      for phoneme in g2p[key]['phonemes']:
        for individual_phoneme in phoneme.split('|'):
            unique_phonemes.add(re.sub(r'\d+', '', individual_phoneme))
    else:
      to_remove.append(key)

for key in to_remove:
    del g2p[key]

# Read CMU Phonemes into a list

CMU_phonemes_path = "./cmudict/SphinxPhones_40.txt"
CMU_phonemes = set()
with open(CMU_phonemes_path, "r") as f:
    for line in f:
        CMU_phonemes.add(line.strip())

print(f"The following phonemes are in CMU but not in g2p: {CMU_phonemes.difference(unique_phonemes)}")
print(f"The following phonemes are in g2p but not in CMU: {unique_phonemes.difference(CMU_phonemes)}")

# Note that 'SIL' and '_' refer to blank phonemes.

print(f"The word 'hello' in g2p looks like {g2p['hello']}")
print(f"The word 'memorable' in g2p looks like {g2p['memorable']}")

The following phonemes are in CMU but not in g2p: {'SIL'}
The following phonemes are in g2p but not in CMU: {'_'}
The word 'hello' in g2p looks like {'graphemes': ['h', 'e', 'l|l', 'o'], 'phonemes': ['HH', 'EH0', 'L', 'OW1']}
The word 'memorable' in g2p looks like {'graphemes': ['m', 'e', 'm', 'o|r', 'a', 'b', 'l', 'e'], 'phonemes': ['M', 'EH1', 'M', 'ER0', 'AH0', 'B', 'AH0|L', '_']}


In [9]:
# Converting from graphemes to phonemes is not trivial because it is not a one-to-one mapping (it is many-to-many). For example, the word "hello" has 5 letters but 4 phonemes. In the word "memorable", the letter "l" maps to 2 phonemes - AH and L.
# This algorithm merges the letters aligned using the CTC algorithm into phonemes using the g2p dictionary. It also adds phoneme separators between phonemes and word separators between words. Using this algorithm in production require lots of testing.

def find_phonemes(letter_segments, word_segments, g2p):
    # This function will modify letter_segments, so make a copy if you'd like to keep the original
    segment_index = 0
    letter_segments.append(Segment("|", letter_segments[-1].end, letter_segments[-1].end, 1.0)) # add a word separator at the end of the transcript
    for word in word_segments:
        # print(f"Scanning word {word.label}")
        phonemes_map = g2p[word.label.lower()]
        for idx, grapheme in enumerate(phonemes_map['graphemes']):
            if "|" not in grapheme:
                if grapheme != letter_segments[segment_index].label.lower():
                    # print(f"Error: {grapheme} != {letter_segments[segment_index].label}")
                    break
                letter_segments[segment_index].label = phonemes_map['phonemes'][idx]
                segment_index += 1

                if letter_segments[segment_index].label == "|":
                    # print(f"Found word separator after {word.label}")
                    segment_index += 1
                else:
                    # print(f"Adding phoneme separator after {letter_segments[segment_index-1].label}")
                    letter_segments.insert(segment_index, Segment("|", letter_segments[segment_index-1].end, letter_segments[segment_index].start, 1.0))
                    segment_index += 1
            else:
                # print(f"Found 2 letter phoneme {grapheme}")
                for idx2, letter in enumerate(grapheme.split("|")):               
                    letter_segments[segment_index].label = phonemes_map['phonemes'][idx]
                    # print(f"Letter {idx2}: {letter}. Segment: {letter_segments[segment_index].label}")
                    segment_index += 1
                if letter_segments[segment_index].label != "|":
                    # print(f"Adding phoneme separator after {letter_segments[segment_index-1].label}")
                    letter_segments.insert(segment_index, Segment("|", letter_segments[segment_index-1].end, letter_segments[segment_index].start, 1.0))
                    segment_index += 1
                else:
                    # print(f"Found word separator after {word.label}")
                    segment_index += 1

    # remove silent phonemes
    phoneme_segments = [segment for segment in letter_segments if segment.label != "_"]

    # mitigate double phonemes by splitting them evenly in time, this is an approximation which introduces small errors
    split_segments = []
    import numpy as np
    for segment in phoneme_segments:
        if segment.label == "|":
            split_segments.append(segment)
            continue
        phonemes = segment.label.split("|")
        num_parts = len(phonemes)
        if num_parts > 1:
            # print(f"Found {num_parts} phonemes in {segment.label}")
            parts = np.linspace(segment.start, segment.end, num_parts + 1, dtype=int)
            for idx, phoneme in enumerate(phonemes):
                split_segments.append(Segment(phoneme, parts[idx], parts[idx+1], segment.score))
                split_segments.append(Segment("|", parts[idx+1], parts[idx+1], 1.0))
        else:
            split_segments.append(segment)

    while split_segments[-1].label == "|":
        # print(f"Removing trailing word separator at {split_segments[-1].start}")
        split_segments = split_segments[:-1] # remove last "|" segment
    return split_segments

# Merge phonemes
def merge_repeat_phonemes(segments, separator="|"):
    # segments = phoneme_segments[:]
    phonemes = []
    i1, i2 = 0, 0
    while i1 < len(segments):
        if i2 >= len(segments) or segments[i2].label == separator:
            if i1 != i2:
                segs = segments[i1:i2]
                if all(seg.label == segs[0].label for seg in segs):
                    phoneme = segs[0].label
                else:
                    phoneme = "".join([seg.label for seg in segs])
                score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
                phonemes.append(Segment(phoneme, segments[i1].start, segments[i2 - 1].end, score))
            i1 = i2 + 1
            i2 = i1
        else:
            i2 += 1
    return phonemes

letter_segments_copy = copy.deepcopy(letter_segments) # find_phonemes will modify letter_segments, so make a copy if you'd like to keep the original
phoneme_segments = find_phonemes(letter_segments_copy, word_segments, g2p)
phoneme_segments = merge_repeat_phonemes(phoneme_segments)

ratio = waveform.size()[1] / (trellis.size(0) - 1)
for phoneme in phoneme_segments:
    phoneme.label = re.sub(r'\d', '', phoneme.label) # we can remove the stress markers, but keep in mind they may be useful for future work
    phoneme.start = ratio * phoneme.start / sample_rate # convert to seconds
    phoneme.end = ratio * phoneme.end / sample_rate
    print(f"Phoneme: {phoneme.label}, Time: {phoneme.start:.3f}s - {phoneme.end:.3f}s")

# Save phonemes to JSON
output_file = "output/method_1.json"
save_phonemes_JSON(phoneme_segments, output_file)

Phoneme: B, Time: 0.600s - 0.661s
Phoneme: AH, Time: 0.661s - 0.681s
Phoneme: T, Time: 0.681s - 0.721s
Phoneme: AE, Time: 0.781s - 0.821s
Phoneme: F, Time: 0.821s - 0.881s
Phoneme: T, Time: 0.881s - 0.921s
Phoneme: ER, Time: 0.921s - 1.001s
Phoneme: AO, Time: 1.201s - 1.221s
Phoneme: L, Time: 1.221s - 1.341s
Phoneme: DH, Time: 1.381s - 1.481s
Phoneme: AH, Time: 1.481s - 1.521s
Phoneme: T, Time: 1.521s - 1.561s
Phoneme: K, Time: 1.601s - 1.661s
Phoneme: AH, Time: 1.661s - 1.681s
Phoneme: M, Time: 1.681s - 1.821s
Phoneme: OW, Time: 1.821s - 1.902s
Phoneme: SH, Time: 1.902s - 2.022s
Phoneme: AH, Time: 2.022s - 2.042s
Phoneme: N, Time: 2.042s - 2.122s
Phoneme: W, Time: 2.882s - 2.982s
Phoneme: AH, Time: 2.982s - 3.022s
Phoneme: Z, Time: 3.022s - 3.082s
Phoneme: IH, Time: 3.122s - 3.142s
Phoneme: T, Time: 3.142s - 3.163s
Phoneme: AO, Time: 3.243s - 3.263s
Phoneme: L, Time: 3.263s - 3.363s
Phoneme: W, Time: 3.423s - 3.483s
Phoneme: ER, Time: 3.483s - 3.563s
Phoneme: TH, Time: 3.563s - 3.723s

In [10]:
# Play the first 10 phonemes

for i in range(10):
    display_segment(phoneme_segments, i, waveform, sample_rate)

B (1.00): 0.600 - 0.661 sec


AH (1.00): 0.661 - 0.681 sec


T (0.50): 0.681 - 0.721 sec


AE (0.50): 0.781 - 0.821 sec


F (0.68): 0.821 - 0.881 sec


T (1.00): 0.881 - 0.921 sec


ER (0.51): 0.921 - 1.001 sec


AO (1.00): 1.201 - 1.221 sec


L (0.67): 1.221 - 1.341 sec


DH (0.82): 1.381 - 1.481 sec


# Method 2: CTC Segmentation using wav2vec2-large-english-phoneme-v2

Rather than convert between graphemes and phonemes, it may be better to use an ASR model pre-trained to identify English phonemes. 
There is one available at https://huggingface.co/speech31/wav2vec2-large-english-phoneme-v2 but it doesn't use the ARPABET phoneme set.

In [11]:
# Load the pretrained phoneme recognition model
# Clone the repo from https://huggingface.co/speech31/wav2vec2-large-english-phoneme-v2/tree/main

letters_dictionary = json.load(open("./wav2vec2-large-english-phoneme-v2/vocab.json"))
tokenizer = Wav2Vec2CTCTokenizer("./wav2vec2-large-english-phoneme-v2/vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
model = Wav2Vec2ForCTC.from_pretrained("./wav2vec2-large-english-phoneme-v2")

In [12]:
# Run inference on sample audio file
model.to(device)
input_values = processor(waveform.squeeze(), sampling_rate=sample_rate, return_tensors="pt").input_values.to(device)

with torch.no_grad():
  emissions = model(input_values).logits
emission = emissions[0].cpu().detach()

# We only need the emissions to perform forced alignment, but we can print the transcript generate by the ASR model if we want.
pred_ids = torch.argmax(emissions, dim=-1)
print(f"Letters: {' '.join(processor.tokenizer.convert_ids_to_tokens(pred_ids[0].tolist()))}")
pred_str = processor.batch_decode(pred_ids)[0]
print(f"ASR Model Transcript (CTC Decoded): {pred_str}")

Letters: [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] b ə ə t | | ˈ ˈ [PAD] æ [PAD] f f t t t ə ə r r r | | | [PAD] [PAD] [PAD] ɔ ɔ ɔ l l l l l l | | | ð ð [PAD] ə [PAD] [PAD] [PAD] t | | [PAD] k [PAD] ə ə ˈ ˈ m m [PAD] o o ʊ ʊ [PAD] [PAD] [PAD] ʃ ʃ [PAD] [PAD] ə ə n n [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] | | | [PAD] [PAD] w w w ɑ ɑ [PAD] z z z | ɪ ɪ t | | [PAD] ɔ ɔ ɔ l l l l | | | [PAD] w w [PAD] ə ə r r [PAD] [PAD] θ θ | | | [PAD] w w w a a a a ɪ ɪ ɪ [PAD] l l l [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [

In [13]:
# The phonemes used by this pre-trained model are close, but not identical to the CMU Sphinx phonemes. I did my best to map them but it would be smarter to fine tune a model on the Sphinx phonemes.

sphinx_to_ipa = {
  "'": "'",
  "[PAD]": "[PAD]",
  "[UNK]": "[UNK]",
  "AY": "aɪ",
  "B": "b",
  "C": "c",
  "D": "d",
  "EY": "e",
  "F": "f",
  "G": "g",
  "HH": "h",
  "IY": "i",
  "ZH": "ʒ",
  "K": "k",
  "L": "l",
  "M": "m",
  "N": "n",
  "AW": "o",
  "OW": "oʊ",
  "P": "p",
  "Q": "q",
  "R": "r",
  "S": "s",
  "T": "t",
  "UW": "u",
  "V": "v",
  "W": "w",
  "Y": "y",
  "Z": "z",
  "|": "|",
  "AE": "æ",
  "DH": "ð",
  "NG": "ŋ",
  "AA": "ɑ",
  "AO": "ɔ",
  "AH": "ə",
  "EH": "ɛ",
  "IH": "ɪ",
  "SH": "ʃ",
  "UH": "ʊ",
  "JH": "ʤ",
  "CH": "ʧ",
  "'": "ˈ",
  "TH": "θ",
  "ER": "ər", # pretrained model has no ɝ
  "OY": "ɔɪ"
}

In [14]:
# Now, we convert the ground truth transcript to IPA

# First, convert the CMU dictionary to IPA
CMU = CMUDict(root='./cmudict', download=True)

CMU_dict = {}
for word in CMU:
	if word[0] == 'THE':
		CMU_dict['the'] = 'ðə' # Choose the other pronunciation of "the" to generate the correct IPA transcript for the test example
	else:
		CMU_dict[word[0].lower()] = [sphinx_to_ipa[re.sub(r'\d', '', phoneme.upper())] for phoneme in word[1]] # strip out stress markers

print(f"Hello: {CMU_dict['hello']}")

ipa_transcript = [] # transcript in modified IPA
for word in transcript.split('|'):
  ipa_transcript += CMU_dict[word.lower()]
  ipa_transcript += ['|']

print(f"IPA Transcript: {''.join(ipa_transcript)}")

Hello: ['h', 'ɛ', 'l', 'oʊ']
IPA Transcript: bət|æftər|ɔl|ðət|kəmoʊʃən|wɑz|ɪt|ɔl|wərθwaɪl|æbsəlutli|yɛs|ðə|sɛt|dɪzaɪn|wɑz|brɛθtekɪŋ|ðə|æktərz|wər|ɪnkrɛdəbəl|ænd|ðə|sɔŋz|wər|mɛmərəbəl|


In [15]:
# Next, we proceed as in method 1 using the CTC algorithm to find the most likely alignment timestamps between the wav2vec2 output and the ground truth transcript.
with open("./wav2vec2-large-english-phoneme-v2/vocab.json", 'r', encoding='utf-8') as file:
    letters_dictionary = json.load(file)

tokens = [letters_dictionary[c] for c in ''.join(ipa_transcript)]
# print(tokens)
trellis = get_trellis(emission, tokens)
# print(trellis.shape)

path = backtrack(trellis, emission, tokens)
phoneme_segments = merge_repeats(path, ''.join(ipa_transcript))
phoneme_segments = [segment for segment in phoneme_segments if segment.label != "|"]

ratio = waveform.size()[1] / (trellis.size(0) - 1)
for phoneme in phoneme_segments:
    phoneme.start = ratio * phoneme.start / sample_rate # convert to seconds
    phoneme.end = ratio * phoneme.end / sample_rate
    print(f"Phoneme: {phoneme.label}, Time: {phoneme.start:.3f}s - {phoneme.end:.3f}s")

Phoneme: b, Time: 0.600s - 0.620s
Phoneme: ə, Time: 0.620s - 0.661s
Phoneme: t, Time: 0.661s - 0.701s
Phoneme: æ, Time: 0.781s - 0.841s
Phoneme: f, Time: 0.841s - 0.881s
Phoneme: t, Time: 0.881s - 0.901s
Phoneme: ə, Time: 0.901s - 0.981s
Phoneme: r, Time: 0.981s - 1.061s
Phoneme: ɔ, Time: 1.161s - 1.301s
Phoneme: l, Time: 1.301s - 1.321s
Phoneme: ð, Time: 1.381s - 1.401s
Phoneme: ə, Time: 1.401s - 1.521s
Phoneme: t, Time: 1.521s - 1.541s
Phoneme: k, Time: 1.601s - 1.621s
Phoneme: ə, Time: 1.621s - 1.741s
Phoneme: m, Time: 1.741s - 1.801s
Phoneme: o, Time: 1.801s - 1.821s
Phoneme: ʊ, Time: 1.821s - 1.922s
Phoneme: ʃ, Time: 1.922s - 1.942s
Phoneme: ə, Time: 1.942s - 2.062s
Phoneme: n, Time: 2.062s - 2.822s
Phoneme: w, Time: 2.922s - 2.942s
Phoneme: ɑ, Time: 2.942s - 3.022s
Phoneme: z, Time: 3.022s - 3.062s
Phoneme: ɪ, Time: 3.102s - 3.122s
Phoneme: t, Time: 3.122s - 3.142s
Phoneme: ɔ, Time: 3.223s - 3.323s
Phoneme: l, Time: 3.323s - 3.383s
Phoneme: w, Time: 3.443s - 3.463s
Phoneme: ə, Ti

In [16]:
# Convert the phonemes to Arpabet

ipa_to_sphinx = {v: k for k, v in sphinx_to_ipa.items()}
index = 0
for grapheme in [letter for letter in ipa_transcript if letter != '|']:
    if len(grapheme) == 1:
        phoneme_segments[index].label = ipa_to_sphinx[grapheme]
        index += 1
    else:
        length = len(grapheme)
        phoneme_segments[index].label = ipa_to_sphinx[grapheme]
        phoneme_segments[index].end = phoneme_segments[index + length - 1].end
        del phoneme_segments[index + 1:index + length]
        index += 1

for phoneme in phoneme_segments:
    print(f"Phoneme: {phoneme.label}, Time: {phoneme.start:.3f}s - {phoneme.end:.3f}s")

# Save phonemes to JSON
output_file = "output/method_2.json"
save_phonemes_JSON(phoneme_segments, output_file)

Phoneme: B, Time: 0.600s - 0.620s
Phoneme: AH, Time: 0.620s - 0.661s
Phoneme: T, Time: 0.661s - 0.701s
Phoneme: AE, Time: 0.781s - 0.841s
Phoneme: F, Time: 0.841s - 0.881s
Phoneme: T, Time: 0.881s - 0.901s
Phoneme: ER, Time: 0.901s - 1.061s
Phoneme: AO, Time: 1.161s - 1.301s
Phoneme: L, Time: 1.301s - 1.321s
Phoneme: DH, Time: 1.381s - 1.401s
Phoneme: AH, Time: 1.401s - 1.521s
Phoneme: T, Time: 1.521s - 1.541s
Phoneme: K, Time: 1.601s - 1.621s
Phoneme: AH, Time: 1.621s - 1.741s
Phoneme: M, Time: 1.741s - 1.801s
Phoneme: OW, Time: 1.801s - 1.922s
Phoneme: SH, Time: 1.922s - 1.942s
Phoneme: AH, Time: 1.942s - 2.062s
Phoneme: N, Time: 2.062s - 2.822s
Phoneme: W, Time: 2.922s - 2.942s
Phoneme: AA, Time: 2.942s - 3.022s
Phoneme: Z, Time: 3.022s - 3.062s
Phoneme: IH, Time: 3.102s - 3.122s
Phoneme: T, Time: 3.122s - 3.142s
Phoneme: AO, Time: 3.223s - 3.323s
Phoneme: L, Time: 3.323s - 3.383s
Phoneme: W, Time: 3.443s - 3.463s
Phoneme: ER, Time: 3.463s - 3.603s
Phoneme: TH, Time: 3.603s - 3.743s

In [17]:
# Play the first 10 phonemes
for i in range(10):
    display_segment(phoneme_segments, i, waveform, sample_rate)

B (1401915.25): 0.600 - 0.620 sec


AH (1258843.84): 0.620 - 0.661 sec


T (829433.08): 0.661 - 0.701 sec


AE (612992.91): 0.781 - 0.841 sec


F (541718.73): 0.841 - 0.881 sec


T (3204450.00): 0.881 - 0.901 sec


ER (677875.56): 0.901 - 1.061 sec


AO (572728.64): 1.161 - 1.301 sec


L (3603248.25): 1.301 - 1.321 sec


DH (659742.44): 1.381 - 1.401 sec


# Method 3: Montreal Forced Aligner

MFA (https://montreal-forced-aligner.readthedocs.io/en/latest/getting_started.html) is the most popular and widely used forced aligner and includes phoneme alignment. However, it is not a CTC-based algorithm. I include it here as a possible benchmark to test against.

In [18]:
# MFA must be run in its own conda environment. The command to generate the alignment is:
# "mfa align --clean input/ english_us_arpa english_us_arpa output"
# This creates a textgrid file in the output folder.

# load textgrid file
textgrid_file = 'output/assessment_9.TextGrid'
textgrid = tgt.io.read_textgrid(textgrid_file)
phones = textgrid.get_tier_by_name('phones').intervals

phoneme_segments = []
for phone in phones:
    segment = Segment(re.sub(r'\d', '', phone.text), float(phone.start_time), float(phone.end_time), 1.0)
    phoneme_segments.append(segment)

# phoneme_segments
for phoneme in phoneme_segments:
    print(f"Phoneme: {phoneme.label}, Time: {phoneme.start:.3f}s - {phoneme.end:.3f}s") # note the precision is in hundredths of a second, not milliseconds

save_phonemes_JSON(phoneme_segments, 'output/method_3.json')

Phoneme: B, Time: 0.600s - 0.630s
Phoneme: AH, Time: 0.630s - 0.660s
Phoneme: T, Time: 0.660s - 0.690s
Phoneme: AE, Time: 0.690s - 0.820s
Phoneme: F, Time: 0.820s - 0.870s
Phoneme: T, Time: 0.870s - 0.930s
Phoneme: ER, Time: 0.930s - 1.010s
Phoneme: AO, Time: 1.010s - 1.270s
Phoneme: L, Time: 1.270s - 1.350s
Phoneme: DH, Time: 1.350s - 1.390s
Phoneme: AE, Time: 1.390s - 1.520s
Phoneme: T, Time: 1.520s - 1.600s
Phoneme: K, Time: 1.600s - 1.630s
Phoneme: AH, Time: 1.630s - 1.660s
Phoneme: M, Time: 1.660s - 1.740s
Phoneme: OW, Time: 1.740s - 1.860s
Phoneme: SH, Time: 1.860s - 1.990s
Phoneme: AH, Time: 1.990s - 2.060s
Phoneme: N, Time: 2.060s - 2.200s
Phoneme: W, Time: 2.830s - 2.920s
Phoneme: AH, Time: 2.920s - 3.010s
Phoneme: Z, Time: 3.010s - 3.070s
Phoneme: IH, Time: 3.070s - 3.110s
Phoneme: T, Time: 3.110s - 3.140s
Phoneme: AO, Time: 3.140s - 3.320s
Phoneme: L, Time: 3.320s - 3.390s
Phoneme: W, Time: 3.390s - 3.470s
Phoneme: ER, Time: 3.470s - 3.550s
Phoneme: TH, Time: 3.550s - 3.700s

In [19]:
# Play the first 10 phonemes

for i in range(10):
    display_segment(phoneme_segments, i, waveform, sample_rate)

B (1.00): 0.600 - 0.630 sec


AH (1.00): 0.630 - 0.660 sec


T (1.00): 0.660 - 0.690 sec


AE (1.00): 0.690 - 0.820 sec


F (1.00): 0.820 - 0.870 sec


T (1.00): 0.870 - 0.930 sec


ER (1.00): 0.930 - 1.010 sec


AO (1.00): 1.010 - 1.270 sec


L (1.00): 1.270 - 1.350 sec


DH (1.00): 1.350 - 1.390 sec


# Testing

These methods should be tested on the LJSpeech dataset using the alignment loss metric.
Finally, the best CTC-based method may be the 960-hr wav2vec2 model fine tuned on the LJSpeech dataset with Arpabet phonemes, but fine-tuning wav2vec2 is beyond the scope of this assignment.