# ASR Assignment 2022-23

This notebook has been provided as a template to get you started on the assignment.  Feel free to use it for your development, or do your development directly in Python.

You can find a full description of the assignment [here](http://www.inf.ed.ac.uk/teaching/courses/asr/2022-23/coursework.pdf).

You are provided with two Python modules `observation_model.py` and `wer.py`.  The first was described in [Lab 3](https://github.com/ZhaoZeyu1995/asr_labs/blob/master/asr_lab3_4.ipynb).  The second can be used to compute the number of substitution, deletion and insertion errors between ASR output and a reference text.

It can be used as follows:

```python
import wer

my_reference = 'A B C'
my_output = 'A C C D'

wer.compute_alignment_errors(my_reference, my_output)
```

This produces a tuple $(s,d,i)$ giving the counts of substitution,
deletion and insertion errors respectively; in this example, (1, 0, 1).  The function accepts either two strings, as in the example above, or two lists.  Matching is case sensitive.
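To make the three counts concrete, the sketch below computes a minimum-edit-distance alignment and returns an $(s,d,i)$ tuple in the same shape. It is a hypothetical illustration (the name `count_alignment_errors` is made up here), not the code inside `wer.py`. When several alignments have the same total number of errors, it prefers fewer substitutions, then fewer deletions, so in such tie cases its split between error types could differ from `wer.compute_alignment_errors`.

```python
def count_alignment_errors(reference, hypothesis):
    """Return (substitutions, deletions, insertions) for a minimum-edit alignment.

    Hypothetical illustration only; this is not the code inside wer.py.
    Accepts two strings (split on whitespace) or two lists of tokens.
    Comparison is case sensitive.
    """
    ref = reference.split() if isinstance(reference, str) else list(reference)
    hyp = hypothesis.split() if isinstance(hypothesis, str) else list(hypothesis)

    # cell[i][j] = (total, subs, dels, ins) for aligning ref[:i] with hyp[:j]
    cell = [[None] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    cell[0][0] = (0, 0, 0, 0)
    for i in range(1, len(ref) + 1):
        cell[i][0] = (i, 0, i, 0)  # every reference word deleted
    for j in range(1, len(hyp) + 1):
        cell[0][j] = (j, 0, 0, j)  # every hypothesis word inserted

    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            total, subs, dels, ins = cell[i - 1][j - 1]
            if ref[i - 1] == hyp[j - 1]:
                diagonal = (total, subs, dels, ins)  # match, no error
            else:
                diagonal = (total + 1, subs + 1, dels, ins)  # substitution
            total, subs, dels, ins = cell[i - 1][j]
            deletion = (total + 1, subs, dels + 1, ins)
            total, subs, dels, ins = cell[i][j - 1]
            insertion = (total + 1, subs, dels, ins + 1)
            # tuples compare by total errors first, then subs, then dels
            cell[i][j] = min(diagonal, deletion, insertion)

    _, subs, dels, ins = cell[len(ref)][len(hyp)]
    return subs, dels, ins
```

For example, `count_alignment_errors('A B C', 'A C C D')` returns `(1, 0, 1)`, matching the example above. The word error rate is then $(s+d+i)/N$, where $N$ is the number of reference words.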

## Template code

Assuming that you have already made a function to generate a WFST, `create_wfst()`, and a decoder class, `MyViterbiDecoder`, you can perform recognition on all the audio files as follows:


In [None]:
# s2282172
# s2473164

In [None]:
import sys 
sys.path.append('/afs/inf.ed.ac.uk/user/s22/s2282172/.conda/envs/asr_env/lib/python3.7/site-packages') 

In [None]:
import math
import openfst_python as fst
import glob
import time
import os
import numpy as np
from copy import deepcopy
from collections import defaultdict

import observation_model  # provided with the assignment (described in Lab 3)

In [None]:
class MyViterbiDecoder:
    
    NLL_ZERO = 1e10  # define a constant representing -log(0).  This is really infinite, but approximate
                     # it here with a very large number
    
    def __init__(self, f, audio_file_name, beam_width=None, beam_size=float('inf'), lm_look_ahead=False):
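        """Set up the decoder class with an audio file and WFST f.

        beam_width: if given, a probability ratio in (0, 1]; at each frame, states whose
            probability is below beam_width times the best state's probability are pruned.
        beam_size: if finite, the maximum number of lowest-cost states kept at each frame.
        lm_look_ahead: if True, apply the unigram look-ahead probabilities stored in the
            global dict unigram_lm_look_ahead_prob.
        Only one of beam_width and beam_size may be used.
        """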
        self.om = observation_model.ObservationModel()
        self.f = f
        self.beam_width = beam_width 
        self.beam_size = beam_size
        self.lm_look_ahead = lm_look_ahead 
        
        if self.beam_width is not None and self.beam_size != float('inf'):
            raise ValueError('Only one of beam width and beam size can be used.') 
        
        if audio_file_name:
            self.om.load_audio(audio_file_name)
        else:
            self.om.load_dummy_audio()
        
        self.initialise_decoding()

        
    def initialise_decoding(self):
        """set up the values for V_j(0) (as negative log-likelihoods)
        
        """
        
        self.V = []   # stores likelihood along best path reaching state j
        self.B = []   # stores identity of best previous state reaching state j
        self.W = []   # stores output labels sequence along arc reaching j - this removes need for 
                      # extra code to read the output sequence along the best path
        self.computations = 0 
        
        for t in range(self.om.observation_length()+1):
            self.V.append([self.NLL_ZERO]*self.f.num_states())
            self.B.append([-1]*self.f.num_states())
            self.W.append([[] for i in range(self.f.num_states())])  # [[]] * n would repeat one shared list, so build separate ones
        
        # The above code means that self.V[t][j] for t = 0, ... T gives the Viterbi cost
        # of state j, time t (in negative log-likelihood form)
        # Initialising the costs to NLL_ZERO effectively means zero probability    
        
        # give the WFST start state a probability of 1.0   (NLL = 0.0)
        self.V[0][self.f.start()] = 0.0
        
        # some WFSTs might have arcs with epsilon on the input (you might have already created 
        # examples of these in earlier labs) these correspond to non-emitting states, 
        # which means that we need to process them without stepping forward in time.  
        # Don't worry too much about this!  
        self.traverse_epsilon_arcs(0)        
        
    def traverse_epsilon_arcs(self, t):
        """Traverse arcs with <eps> on the input at time t
        
        These correspond to transitions that don't emit an observation
        
        We've implemented this function for you as it's slightly trickier than
        the normal case.  You might like to look at it to see what's going on, but
        don't worry if you can't fully follow it.
        
        """
        
        states_to_traverse = list(self.f.states()) # traverse all states
        while states_to_traverse:
            
            # Set i to the ID of the current state, the first 
            # item in the list (and remove it from the list)
            i = states_to_traverse.pop(0)   
        
            # don't bother traversing states which have zero probability
            if self.V[t][i] == self.NLL_ZERO:
                continue
        
            for arc in self.f.arcs(i):
                
                if arc.ilabel == 0:     # if <eps> transition
                  
                    j = arc.nextstate   # ID of next state  
                
                    if self.V[t][j] > self.V[t][i] + float(arc.weight):
                        
                        # this means we've found a lower-cost path to
                        # state j at time t.  We might need to add it
                        # back to the processing queue.
                        self.V[t][j] = self.V[t][i] + float(arc.weight)
                        
                        # save backtrace information.  In the case of an epsilon transition, 
                        # we save the identity of the best state at t-1.  This means we may not
                        # be able to fully recover the best path, but to do otherwise would
                        # require a more complicated way of storing backtrace information
                        self.B[t][j] = self.B[t][i] 
                        
                        # and save the output labels encountered - this is a list, because
                        # there could be multiple output labels (in the case of <eps> arcs)
                        if arc.olabel != 0:
                            self.W[t][j] = self.W[t][i] + [arc.olabel]
                        else:
                            self.W[t][j] = self.W[t][i]
                        
                        if j not in states_to_traverse:
                            states_to_traverse.append(j)

    
    def forward_step(self, t):
          
        for i in self.f.states():
            
            if not self.V[t-1][i] == self.NLL_ZERO:   # no point in propagating states with zero probability
                
                for arc in self.f.arcs(i):
                    
                    if arc.ilabel != 0: # <eps> transitions don't emit an observation
                        j = arc.nextstate
                        tp = float(arc.weight)  # transition prob
                        ep = -self.om.log_observation_probability(self.f.input_symbols().find(arc.ilabel), t)  # emission negative log prob
                        prob = tp + ep + self.V[t-1][i] # they're logs 
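                        if self.lm_look_ahead:
                            # unigram_lm_look_ahead_prob is a global dict (built in a later cell)
                            # mapping each tree-WFST state to its unigram look-ahead probability
                            prob = prob - math.log(unigram_lm_look_ahead_prob[j])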
                        self.computations += 1 
                        if prob < self.V[t][j]:
                            self.V[t][j] = prob
                            self.B[t][j] = i
                            
                            # store the output labels encountered too
                            if arc.olabel !=0:
                                self.W[t][j] = [arc.olabel]
                            else:
                                self.W[t][j] = []
        
        if self.beam_width is not None:
            # beam pruning: drop states whose probability is below beam_width times the best one
            min_vt_value = min(self.V[t])
            for j in range(len(self.V[t])):
                if self.V[t][j] > -math.log(self.beam_width) + min_vt_value:
                    self.V[t][j] = self.NLL_ZERO

        if self.beam_size != float('inf'):
            # histogram pruning: keep only the beam_size lowest-cost states (ties at the threshold are kept)
            vt_value = sorted(self.V[t])
            beam_size = int(min(self.beam_size, len(self.V[t])))
            threshold = vt_value[beam_size - 1]
            for j in range(len(self.V[t])):
                if self.V[t][j] > threshold:
                    self.V[t][j] = self.NLL_ZERO
                            
    
    def finalise_decoding(self):
        """ this incorporates the probability of terminating at each state
        """
        
        for state in self.f.states():
            final_weight = float(self.f.final(state))
            if self.V[-1][state] != self.NLL_ZERO:
                if final_weight == math.inf:
                    self.V[-1][state] = self.NLL_ZERO  # effectively says that we can't end in this state
                else:
                    self.V[-1][state] += final_weight
                    
        # get a list of all states where there was a path ending with non-zero probability
        finished = [x for x in self.V[-1] if x < self.NLL_ZERO]
        if not finished:  # if empty
            print("No path got to the end of the observations.")
        
        
    def decode(self):
        self.initialise_decoding()
        t = 1
        while t <= self.om.observation_length():
            self.forward_step(t)
            self.traverse_epsilon_arcs(t)
            t += 1
        self.finalise_decoding()
        return self.computations 
    
    def backtrace(self):
        
        best_final_state = self.V[-1].index(min(self.V[-1])) # argmin
        best_state_sequence = [best_final_state]
        best_out_sequence = []
        
        t = self.om.observation_length()   # ie T
        j = best_final_state
        
        while t >= 0:
            i = self.B[t][j]
            best_state_sequence.append(i)
            # computer scientists might like to make this more efficient!
            best_out_sequence = self.W[t][j] + best_out_sequence

            # continue the backtrace at state i, time t-1
            j = i  
            t-=1
            
        best_state_sequence.reverse()
        
        # convert the best output sequence from FST integer labels into strings
        # best_out_sequence = ' '.join([ self.f.output_symbols().find(label) for label in best_out_sequence]) 
        out_sequence = []
        for label in best_out_sequence:
            out_sequence.append(self.f.output_symbols().find(label))
            
        
        # return (best_state_sequence, best_out_sequence)
        return (best_state_sequence, out_sequence)
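The decoding logic above can be exercised without OpenFST or audio. The following is a hypothetical, self-contained sketch: arcs are plain `(src, label, weight_nll, dst)` tuples, there are no epsilon arcs or final weights, and `emission_nll` stands in for `-self.om.log_observation_probability(...)`. It mirrors `forward_step` and `backtrace` in negative log-likelihood form, and reproduces the two pruning rules (`beam_width` as a probability ratio, `beam_size` as a maximum number of surviving states) as standalone functions applied to one frame's list of costs.

```python
import math

NLL_ZERO = 1e10  # stand-in for -log(0), as in MyViterbiDecoder


def prune_beam_width(costs, beam_width):
    """Beam pruning: drop states whose probability is below beam_width
    (a ratio in (0, 1]) times the best probability at this frame."""
    threshold = min(costs) - math.log(beam_width)
    return [c if c <= threshold else NLL_ZERO for c in costs]


def prune_beam_size(costs, beam_size):
    """Histogram pruning: keep the beam_size lowest-cost states (plus ties)."""
    k = int(min(beam_size, len(costs)))
    threshold = sorted(costs)[k - 1]
    return [c if c <= threshold else NLL_ZERO for c in costs]


def viterbi(arcs, start, num_states, emission_nll, num_frames):
    """Toy Viterbi in negative log-likelihood form (no epsilon arcs).

    arcs: list of (src, label, weight_nll, dst) tuples.
    emission_nll(label, t): -log p(x_t | label) for frame t = 1..num_frames.
    Returns (cost of best path, state sequence from time 0 to num_frames).
    """
    V = [[NLL_ZERO] * num_states for _ in range(num_frames + 1)]
    B = [[-1] * num_states for _ in range(num_frames + 1)]
    V[0][start] = 0.0
    for t in range(1, num_frames + 1):
        for src, label, weight, dst in arcs:
            if V[t - 1][src] == NLL_ZERO:
                continue  # state unreachable at t-1
            cost = V[t - 1][src] + weight + emission_nll(label, t)
            if cost < V[t][dst]:
                V[t][dst] = cost
                B[t][dst] = src
    # no final weights in this toy: every state may end the path
    best_state = min(range(num_states), key=lambda s: V[num_frames][s])
    path = [best_state]
    for t in range(num_frames, 0, -1):
        path.append(B[t][path[-1]])
    return V[num_frames][best_state], path[::-1]
```

With `observations = ['a', 'b', 'b']`, arcs `[(0, 'a', math.log(2), 0), (0, 'b', math.log(2), 1), (1, 'b', 0.0, 1)]` and an emission cost of 0 for a matching label and 5 otherwise, `viterbi(arcs, 0, 2, emit, 3)` returns a cost of $2\log 2$ and the path `[0, 0, 1, 1]`.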

In [None]:
def parse_lexicon(lex_file):
    """
    Parse the lexicon file and return it in dictionary form.
    
    Args:
        lex_file (str): filename of lexicon file with structure '<word> <phone1> <phone2>...'
                        eg. peppers p eh p er z

    Returns:
        lex (dict): dictionary mapping words to list of phones
    """
    
    lex = {}  # create a dictionary for the lexicon entries (this could be a problem with larger lexica)
    with open(lex_file, 'r') as f:
        for line in f:
            line = line.split()  # split at each space
            lex[line[0]] = line[1:]  # first field the word, the rest is the phones
    return lex

def generate_symbol_tables(lexicon, n=3):
    '''
    Return word, phone and state symbol tables based on the supplied lexicon
        
    Args:
        lexicon (dict): lexicon to use, created from the parse_lexicon() function
        n (int): number of states for each phone HMM
        
    Returns:
        word_table (fst.SymbolTable): table of words
        phone_table (fst.SymbolTable): table of phones
        state_table (fst.SymbolTable): table of HMM phone-state IDs
    '''
    
    state_table = fst.SymbolTable()
    phone_table = fst.SymbolTable()
    word_table = fst.SymbolTable()
    
    # add empty <eps> symbol to all tables
    state_table.add_symbol('<eps>')
    phone_table.add_symbol('<eps>')
    word_table.add_symbol('<eps>')
    
    for word, phones in lexicon.items():
        
        word_table.add_symbol(word)
        
        for p in phones: # for each phone
            
            phone_table.add_symbol(p)
            for i in range(1,n+1): # for each state 1 to n
                state_table.add_symbol('{}_{}'.format(p, i))
            
    return word_table, phone_table, state_table


# call these two functions
lex = parse_lexicon('lexicon.txt')
word_table, phone_table, state_table = generate_symbol_tables(lex)
'''
phones_words_dict = {}
for word, phones in lex.items():
    phones_words_dict[tuple(phones)] = word 
'''

def generate_phone_wfst(f, start_state, phone, n, output_phone=False):
    """
    Generate a WFST representing an n-state left-to-right phone HMM
    
    Args:
        f (fst.Fst()): an FST object, assumed to exist already
        start_state (int): the index of the first state, assumed to exist already
        phone (str): the phone label 
        n (int): number of states for each phone HMM
        output_phone (bool): if True, output the phone label on the arc leaving the
                             final HMM state; otherwise all output labels are <eps>
        
    Returns:
        the final state of the FST
    """
    
    current_state = start_state
    
    for i in range(1, n+1):
        
        in_label = state_table.find('{}_{}'.format(phone, i))
        
        sl_weight = fst.Weight('log', -math.log(0.1))  # weight for self-loop
        # self-loop back to current state
        f.add_arc(current_state, fst.Arc(in_label, 0, sl_weight, current_state))
        # f.add_arc(current_state, fst.Arc(in_label, 0, None, current_state))
        
        # transition to next state
        
        # we want to output the phone label on the final state
        # note: if outputting words instead this code should be modified
        
        if output_phone: 
            if i == n:
                out_label = phone_table.find(phone)
            else:
                out_label = 0   # output empty <eps> label
        
        else: 
            out_label = 0 
            
        next_state = f.add_state()
        next_weight = fst.Weight('log', -math.log(0.9)) # weight to next state
        f.add_arc(current_state, fst.Arc(in_label, out_label, next_weight, next_state)) 
        # f.add_arc(current_state, fst.Arc(in_label, out_label, None, next_state))
       
        current_state = next_state
        
    return current_state
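Each HMM state in `generate_phone_wfst` has two outgoing arcs: a self-loop with weight $-\log 0.1 \approx 2.303$ and a forward arc with weight $-\log 0.9 \approx 0.105$, so the outgoing probabilities of every state sum to one. The hypothetical helper below lists the same arcs as plain tuples (assuming `output_phone=False` and consecutively numbered states) so the weights can be checked without OpenFST.

```python
import math


def phone_hmm_arcs(phone, n=3, self_loop_prob=0.1, start_state=0):
    """Arcs of an n-state left-to-right phone HMM as plain tuples
    (src, in_label, out_label, nll_weight, dst).

    Hypothetical pure-Python stand-in for generate_phone_wfst with
    output_phone=False; assumes new states are numbered consecutively.
    """
    arcs = []
    for i in range(1, n + 1):
        src = start_state + i - 1
        in_label = '{}_{}'.format(phone, i)
        # self-loop: stay in HMM state i
        arcs.append((src, in_label, '<eps>', -math.log(self_loop_prob), src))
        # forward transition to the next HMM state
        arcs.append((src, in_label, '<eps>', -math.log(1 - self_loop_prob), src + 1))
    return arcs


arcs = phone_hmm_arcs('p')
for s in range(3):
    # probabilities leaving each state should sum to one
    outgoing = sum(math.exp(-w) for src, _, _, w, _ in arcs if src == s)
    assert math.isclose(outgoing, 1.0)
```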

def generate_word_wfst(word):
    """ Generate a WFST for any word in the lexicon, composed of 3-state phone WFSTs.
        As written, generate_phone_wfst is called with output_phone=False, so every
        arc carries the <eps> output label.
        Exercise: could you modify this function and the one above to output a single phone label instead?
    
    Args:
        word (str): the word to generate
        
    Returns:
        the constructed WFST
    
    """
    f = fst.Fst('log')
    
    # create the start state
    start_state = f.add_state()
    f.set_start(start_state)
    
    current_state = start_state
    
    # iterate over all the phones in the word
    for phone in lex[word]:   # will raise an exception if word is not in the lexicon
        
        current_state = generate_phone_wfst(f, current_state, phone, 3)
    
        # note: new current_state is now set to the final state of the previous phone WFST
        
    f.set_final(current_state)
    
    return f


In [None]:
lex

In [None]:
word_state_dict = defaultdict(list)
# word_state_dict maps each word to its HMM state names (three per phone), followed by five silence states.
for word, phones in lex.items():
    for phone in phones:
        for i in range(1, 4):
            word_state_dict[word].append('{}_{}'.format(phone, i))
    for j in range(1, 6):
        word_state_dict[word].append('sil_{}'.format(j))

In [None]:
def generate_word_sequence_wfst(n=3):
    """ Generate an HMM (as a WFST) that recognises any sequence of words in the lexicon
    
    Args:
        n (int): states per phone HMM

    Returns:
        the constructed WFST
    
    """
    
    f = fst.Fst('log')
    
    # create a single start state
    start_state = f.add_state()
    f.set_start(start_state)
    
    for word, phones in lex.items():
        current_state = f.add_state()
        # The transition probability to each word is equal (1 / number of words in the lexicon).
        f.add_arc(start_state, fst.Arc(0, word_table.find(word),\
                                       fst.Weight('log', -math.log(1 / len(lex.keys()))), current_state)) 
        
        for phone in phones: 
            current_state = generate_phone_wfst(f, current_state, phone, n)
        # note: new current_state is now set to the final state of the previous phone WFST
        
        f.set_final(current_state)
        f.add_arc(current_state, fst.Arc(0, 0, fst.Weight('log', -math.log(1)), start_state))
        
    return f 

In [None]:
phone_index_dict = {}
max_word_phone_length = 0
for phones in lex.values():
    max_word_phone_length = max(max_word_phone_length, len(phones))
    for phone in phones:
        phone_index_dict[phone] = phone_table.find(phone)

In [None]:
state_sil_table = deepcopy(state_table)
# state_sil_table contains all the states from the state_table and silence states. 
for i in range(1, 6):
    state_sil_table.add_symbol('sil_{}'.format(i)) 

In [None]:
state_index_dict = {}
max_word_state_length = 0
for states in word_state_dict.values():
    max_word_state_length = max(max_word_state_length, len(states))
    for state in states:
        state_index_dict[state] = state_sil_table.find(state)

In [None]:
def generate_tree_lexicon(lex):
    
    '''
    Generate tree structure lexicon. 
    '''
    
    # integer dtype so that entries can be passed directly as FST labels and state IDs
    words_phones_array = np.zeros((len(lex.keys()), max_word_phone_length + 1), dtype=int)
    # words_phones_array contains the indexes of all the phones for each word in the lexicon, each row 
    # represents one word in the lexicon. 
    states_array = np.zeros(words_phones_array.shape, dtype=int)
    # states_array contains the indexes of the states we have added. 
    word_list = sorted(lex.keys())
    for i in range(len(word_list)):
        for j in range(len(lex[word_list[i]])):
            words_phones_array[i, j] = phone_index_dict[lex[word_list[i]][j]]
    final_states = []
    
    f = fst.Fst('log')
    start_state = f.add_state()
    f.set_start(start_state) 
    
    for i in range(words_phones_array.shape[0]):
        add_state = True 
        for index, phones in enumerate(words_phones_array[0:i, 0]):
            if words_phones_array[i, 0] == phones:
                # If the previous words have the same phone at this time step. 
                add_state = False
                states_array[i, 0] = states_array[index, 0]
                break 
        
        if add_state:
            if words_phones_array[i, 1] == 0:
                # If the word does not have phone at next time step. 
                word_final_state = f.add_state()
                f.set_final(word_final_state) 
                final_states.append(word_final_state)
                states_array[i, 0] = word_final_state 
                f.add_arc(start_state, fst.Arc(words_phones_array[i, 0], \
                              word_table.find(word_list[i]), fst.Weight('log', -math.log(1)), word_final_state)) 
            else:
                state = f.add_state()
                states_array[i, 0] = state
                f.add_arc(start_state, fst.Arc(words_phones_array[i, 0], 0, \
                                                 fst.Weight('log', -math.log(1)), state)) 
    
    for j in range(1, words_phones_array.shape[1] - 1):
        for i in range(words_phones_array.shape[0]):
            if words_phones_array[i, j] != 0:
                # If the word has phone at this time step. 
                add_state = True 
                for index, array in enumerate(words_phones_array[0:i, 0:j+1]):
                    if np.allclose(words_phones_array[i, 0:j+1], array):
                        # If the previous words have the same phones at all the previous time steps 
                        # and this time step. 
                        add_state = False
                        states_array[i, 0:j+1] = states_array[index, 0:j+1]
                        break
                        
                if add_state: 
                    if words_phones_array[i, j+1] == 0:
                        # If the word has phone at next time step. 
                        word_final_state = f.add_state()
                        f.set_final(word_final_state)
                        final_states.append(word_final_state)
                        states_array[i, j] = word_final_state
                        f.add_arc(states_array[i, j-1], fst.Arc(words_phones_array[i, j], \
                                word_table.find(word_list[i]), fst.Weight('log', -math.log(1)), word_final_state))
                    else:
                        state = f.add_state()
                        states_array[i, j] = state
                        f.add_arc(states_array[i, j-1], fst.Arc(words_phones_array[i, j], 0, \
                                                               fst.Weight('log', -math.log(1)), state)) 
    for state in final_states:
        f.add_arc(state, fst.Arc(0, 0, fst.Weight('log', -math.log(1)), start_state)) 
    
    return f, words_phones_array, states_array 
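The idea behind `generate_tree_lexicon` is that words sharing a phone prefix share the states for that prefix. A dictionary-based trie shows the same idea more compactly; this is a hypothetical sketch, not a drop-in replacement (it returns plain tuples rather than an `fst.Fst`, and it handles homophones by listing several words on one end node). The pronunciations for `peck` and `picked` below are invented for illustration; only `peppers p eh p er z` comes from the docstring of `parse_lexicon`.

```python
def build_prefix_tree(lexicon):
    """Build a phone prefix tree (tree lexicon) from a word -> phones dict.

    Returns (arcs, word_ends): arcs is a list of (src, phone, dst) with node 0
    as the root; word_ends maps a node to the words whose pronunciation ends there.
    Words sharing a phone prefix share the nodes for that prefix.
    """
    child = {}  # (node, phone) -> child node
    arcs = []
    word_ends = {}
    next_node = 1
    for word in sorted(lexicon):
        node = 0
        for phone in lexicon[word]:
            if (node, phone) not in child:
                child[(node, phone)] = next_node
                arcs.append((node, phone, next_node))
                next_node += 1
            node = child[(node, phone)]
        word_ends.setdefault(node, []).append(word)  # homophones share an end node
    return arcs, word_ends


toy_lex = {'peppers': ['p', 'eh', 'p', 'er', 'z'],
           'peck': ['p', 'eh', 'k'],        # hypothetical pronunciation
           'picked': ['p', 'ih', 'k', 't']}  # hypothetical pronunciation
arcs, word_ends = build_prefix_tree(toy_lex)
```

For these three toy words the tree needs 9 arcs instead of the 12 a flat lexicon would use, because `peck` and `peppers` share `p eh`, and all three share `p`.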

In [None]:
tree_lexicon, words_phones_array, states_array = generate_tree_lexicon(lex)
words_phones_array, states_array

In [None]:
tree_lexicon, _, _ = generate_tree_lexicon(lex)
tree_lexicon.set_input_symbols(phone_table)
tree_lexicon.set_output_symbols(word_table)
from subprocess import check_call
from IPython.display import Image 
tree_lexicon.draw('tmp.dot', portrait=True)
check_call(['dot','-Tpng','-Gdpi=200','tmp.dot','-o','tmp.png'])
Image(filename='tmp.png')

In [None]:
def generate_tree_wfst(word_state_dict):
    
    '''
    Generate tree structure wfst to recognize a sequence of words. 
    '''
    
    # integer dtype so that entries can be passed directly as FST labels and state IDs
    words_states_array = np.zeros((len(word_state_dict.keys()), max_word_state_length + 1), dtype=int)
    # words_states_array contains the indexes of all the states for each word in the lexicon, each row 
    # represents one word in the lexicon. 
    states_array = np.zeros(words_states_array.shape, dtype=int)
    # states_array contains the indexes of the states we have added. 
    word_list = sorted(word_state_dict.keys())
    for i in range(len(word_list)):
        for j in range(len(word_state_dict[word_list[i]])):
            # words_phones_array[i, j] = phone_index_dict[lex[word_list[i]][j]]
            words_states_array[i, j] = state_index_dict[word_state_dict[word_list[i]][j]]
    final_states = []
    word_final_list = []
    
    f = fst.Fst('log')
    start_state = f.add_state()
    f.set_start(start_state) 
    
    for i in range(words_states_array.shape[0]):
        add_state = True 
        for index, states in enumerate(words_states_array[0:i, 0]):
            if words_states_array[i, 0] == states:
                # If the previous words have the same state at this time step. 
                add_state = False
                states_array[i, 0] = states_array[index, 0]
                break 
        
        if add_state:
            if words_states_array[i, 1] == 0:
                # If the word does not have state at next time step. 
                word_final_state = f.add_state()
                f.set_final(word_final_state) 
                final_states.append(word_final_state)
                word_final_list.append(i)
                states_array[i, 0] = word_final_state 
                f.add_arc(start_state, fst.Arc(words_states_array[i, 0], \
                              word_table.find(word_list[i]), fst.Weight('log', \
                                                 -math.log(1)), word_final_state))
            else:
                state = f.add_state()
                states_array[i, 0] = state
                f.add_arc(start_state, fst.Arc(words_states_array[i, 0], 0, \
                                                 fst.Weight('log', \
                                                      -math.log(1)), state)) 
                f.add_arc(state, fst.Arc(words_states_array[i, 0], 0, \
                                        fst.Weight('log', -math.log(0.1)), state))
    
    for j in range(1, words_states_array.shape[1] - 1):
        for i in range(words_states_array.shape[0]):
            if words_states_array[i, j] != 0:
                # If the word has state at this time step. 
                add_state = True 
                for index, array in enumerate(words_states_array[0:i, 0:j+1]):
                    if np.allclose(words_states_array[i, 0:j+1], array):
                        # If the previous words have the same states at all the previous time steps 
                        # and this time step. 
                        add_state = False
                        states_array[i, 0:j+1] = states_array[index, 0:j+1]
                        break

                if add_state: 
                    if words_states_array[i, j+1] == 0:
                        # If the word has no state at the next time step, this is its final state.
                        word_final_state = f.add_state() 
                        f.set_final(word_final_state)
                        final_states.append(word_final_state)
                        word_final_list.append(i)
                        states_array[i, j] = word_final_state
                        f.add_arc(states_array[i, j-1], fst.Arc(words_states_array[i, j], \
                                word_table.find(word_list[i]), fst.Weight('log', -math.log(1)), word_final_state))
                    else:
                        state = f.add_state()
                        states_array[i, j] = state
                        f.add_arc(states_array[i, j-1], fst.Arc(words_states_array[i, j], 0, \
                                                                fst.Weight('log', -math.log(0.9)), state)) 
                        f.add_arc(state, fst.Arc(words_states_array[i, j], 0, \
                                                fst.Weight('log', -math.log(0.1)), state))
    for state in final_states:
        f.add_arc(state, fst.Arc(0, 0, fst.Weight('log', -math.log(1)), start_state))
    
    return f, words_states_array, states_array, word_list 

In [None]:
tree_wfst, words_states_array, states_array, word_list = generate_tree_wfst(word_state_dict)
tree_wfst.set_input_symbols(state_sil_table)
tree_wfst.set_output_symbols(word_table)
from subprocess import check_call
from IPython.display import Image 
tree_wfst.draw('tmp.dot', portrait=True)
check_call(['dot','-Tpng','-Gdpi=200','tmp.dot','-o','tmp.png'])
Image(filename='tmp.png') 

In [None]:
import glob
from collections import defaultdict
def read_transcription(wav_file):
    """
    Get the transcription corresponding to wav_file.
    """
    
    transcription_file = os.path.splitext(wav_file)[0] + '.txt'
    
    with open(transcription_file, 'r') as f:
        transcription = f.readline().strip()
    
    return transcription
word_prob = defaultdict(float)  # word_prob contains the unigram probability for each word. 
total_word = 0 
for wav_file in glob.glob('/group/teaching/asr/labs/recordings/*.wav'):
    transcription = read_transcription(wav_file)
    for word in transcription.split():
        word_prob[word] += 1
        total_word += 1
for word in word_prob.keys():
    word_prob[word] = word_prob[word] / total_word 
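The counting above computes maximum-likelihood unigram probabilities, $P(w) = c(w)/\sum_{w'} c(w')$, over all transcriptions. An equivalent sketch with `collections.Counter`, assuming the transcriptions are already loaded as a list of strings (the function name is hypothetical):

```python
from collections import Counter


def unigram_probabilities(transcriptions):
    """Maximum-likelihood unigram probabilities from a list of transcription strings."""
    counts = Counter(word for line in transcriptions for word in line.split())
    total = sum(counts.values())
    return {word: count / total for word, count in counts.items()}


probs = unigram_probabilities(['peter piper', 'piper picked'])
```

One thing worth checking: any lexicon word that never appears in the transcriptions gets probability 0 from the `defaultdict`, and `math.log(0)` in the WFST builders below raises `ValueError: math domain error`.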

In [None]:
unigram_lm_look_ahead_prob = {}
for i in range(states_array.shape[0]):
    for j in range(states_array.shape[1]):
        if states_array[i, j] not in unigram_lm_look_ahead_prob.keys():
            # unigram_lm_look_ahead_prob[int(states_array[i, j])] = 0
            unigram_lm_look_ahead_prob[int(states_array[i, j])] = word_prob[word_list[i]]
        else:
            unigram_lm_look_ahead_prob[int(states_array[i, j])] = \
                                          max(unigram_lm_look_ahead_prob[int(states_array[i, j])],
                                                                word_prob[word_list[i]]) 
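The cell above gives every state of `tree_wfst` the largest unigram probability of any word whose path passes through it, and `forward_step` adds the corresponding $-\log$ value when `lm_look_ahead=True`. The same computation, keyed by phone prefix instead of by FST state ID, can be sketched as follows (hypothetical helper, toy probabilities and pronunciations):

```python
import math


def unigram_look_ahead(lexicon, word_prob):
    """Map every phone prefix (i.e. every node of the prefix tree, with the
    empty tuple as the root) to the largest unigram probability among the
    words whose pronunciation starts with that prefix."""
    look_ahead = {}
    for word, phones in lexicon.items():
        p = word_prob.get(word, 0.0)
        for k in range(len(phones) + 1):
            prefix = tuple(phones[:k])
            look_ahead[prefix] = max(look_ahead.get(prefix, 0.0), p)
    return look_ahead


toy_lex = {'peppers': ['p', 'eh', 'p', 'er', 'z'],
           'peck': ['p', 'eh', 'k'],        # hypothetical pronunciation
           'picked': ['p', 'ih', 'k', 't']}  # hypothetical pronunciation
la = unigram_look_ahead(toy_lex, {'peppers': 0.5, 'peck': 0.2, 'picked': 0.3})
penalty = -math.log(la[('p', 'ih')])  # cost added on entering the 'p ih' branch
```

Here the prefix `p eh` gets look-ahead probability 0.5 (from `peppers`), while `p ih` gets 0.3, i.e. a penalty of $-\log 0.3 \approx 1.204$.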

In [None]:
word_sequence_wfst = generate_word_sequence_wfst(3)
word_sequence_wfst.set_input_symbols(state_table)
word_sequence_wfst.set_output_symbols(word_table)
from subprocess import check_call
from IPython.display import Image 
word_sequence_wfst.draw('tmp.dot', portrait=True)
check_call(['dot','-Tpng','-Gdpi=200','tmp.dot','-o','tmp.png'])
Image(filename='tmp.png')

In [None]:
def generate_unigram_word_wfst(n=3):
    
    '''
    Generate a WFST containing a unigram grammar to recognise word sequences. 
    '''
    
    f = fst.Fst('log')
    
    # create a single start state
    start_state = f.add_state()
    f.set_start(start_state)
    
    for word, phones in lex.items():
        current_state = f.add_state()
        # The transition probability to each word is unigram probability. 
        f.add_arc(start_state, fst.Arc(0, word_table.find(word),\
                                       fst.Weight('log', -math.log(word_prob[word])), current_state)) 
        
        for phone in phones: 
            current_state = generate_phone_wfst(f, current_state, phone, n)
        # note: new current_state is now set to the final state of the previous phone WFST
        
        f.set_final(current_state)
        f.add_arc(current_state, fst.Arc(0, 0, fst.Weight('log', -math.log(1)), start_state))
        
    return f 

In [None]:
unigram_sequence_wfst = generate_unigram_word_wfst()
unigram_sequence_wfst.set_input_symbols(state_table)
unigram_sequence_wfst.set_output_symbols(word_table)
from subprocess import check_call
from IPython.display import Image 
unigram_sequence_wfst.draw('tmp.dot', portrait=True)
check_call(['dot','-Tpng','-Gdpi=200','tmp.dot','-o','tmp.png'])
Image(filename='tmp.png') 

In [None]:
def generate_word_sil_wfst(n=3):
    
    '''
    Generate word_sil_wfst, which contains five silence states at the end of each word. 
    During decoding, the decoder has to go through the silence states. 
    '''
    
    f = fst.Fst('log')
    start_state = f.add_state()
    f.set_start(start_state)
    
    for word, phones in lex.items():
        current_state = f.add_state()
        f.add_arc(start_state, fst.Arc(0, word_table.find(word), fst.Weight('log', -math.log(word_prob[word])), \
                                      current_state))
        for phone in phones:
            current_state = generate_phone_wfst(f, current_state, phone, n)
            
        phone_final_state = current_state 
        last_state = phone_final_state 
        
        for i in range(1, 6):
            current_state = f.add_state()
            f.add_arc(last_state, fst.Arc(state_sil_table.find('sil_{}'.format(i)), 0, \
                                         fst.Weight('log', -math.log(0.9)), current_state))
            f.add_arc(last_state, fst.Arc(state_sil_table.find('sil_{}'.format(i)), 0, \
                                            fst.Weight('log', -math.log(0.1)), last_state))
            last_state = current_state
            
        word_final_state = last_state
        f.set_final(word_final_state)
        
        f.add_arc(word_final_state, fst.Arc(0, 0, fst.Weight('log', -math.log(1)), start_state)) 
        
    return f 

In [None]:
word_sil_wfst = generate_word_sil_wfst()
word_sil_wfst.set_input_symbols(state_sil_table) 
word_sil_wfst.set_output_symbols(word_table)
from subprocess import check_call
from IPython.display import Image 
word_sil_wfst.draw('tmp.dot', portrait=True)
check_call(['dot','-Tpng','-Gdpi=200','tmp.dot','-o','tmp.png'])
Image(filename='tmp.png') 

In [None]:
def generate_sequence_sil_wfst(n=3):
    
    '''
    Generate sequence_sil_wfst, which contains five silence states at the end of each word. 
    During decoding, the decoder may either pass through the silence states or skip them, according to the given probabilities. 
    '''
    
    f = fst.Fst('log')
    start_state = f.add_state()
    f.set_start(start_state)
    
    for word, phones in lex.items():
        current_state = f.add_state()
        f.add_arc(start_state, fst.Arc(0, word_table.find(word), fst.Weight('log', -math.log(word_prob[word])), \
                                      current_state))
        for phone in phones:
            current_state = generate_phone_wfst(f, current_state, phone, n)
            
        phone_final_state = current_state 
        current_state = f.add_state()
        
        f.add_arc(phone_final_state, fst.Arc(state_sil_table.find('sil_1'), 0, \
                                            fst.Weight('log', -math.log(0.9*0.8)), current_state))
        f.add_arc(phone_final_state, fst.Arc(state_sil_table.find('sil_1'), 0, \
                                        fst.Weight('log', -math.log(0.1)), phone_final_state))
        last_state = current_state 
        
        for i in range(2, 6):
            current_state = f.add_state()
            f.add_arc(last_state, fst.Arc(state_sil_table.find('sil_{}'.format(i)), 0, \
                                         fst.Weight('log', -math.log(0.9)), current_state))
            f.add_arc(last_state, fst.Arc(state_sil_table.find('sil_{}'.format(i)), 0, \
                                            fst.Weight('log', -math.log(0.1)), last_state))
            last_state = current_state
            
            
        word_final_state = last_state
        f.set_final(word_final_state)
        f.add_arc(phone_final_state, fst.Arc(state_table.find('{}_{}'.format(phone, n)), 0, \
                                      fst.Weight('log', -math.log(0.9*0.2)), word_final_state))
        f.add_arc(word_final_state, fst.Arc(0, 0, fst.Weight('log', -math.log(1)), start_state)) 
        
    return f  
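In this variant, the three arcs added out of `phone_final_state` carry probabilities 0.9 * 0.8 (enter the silence chain), 0.9 * 0.2 (skip it) and 0.1 (self-loop), which sum to one, and each later silence stage keeps the 0.9 / 0.1 split. A minimal sketch of a check that the outgoing probabilities of each state form a distribution; it works on a plain dictionary of probabilities rather than an OpenFst object (that dictionary layout is an assumption for illustration):

```python
import math


def check_outgoing(arc_probs, tol=1e-9):
    """Return the states whose outgoing arc probabilities do not sum to one.

    arc_probs maps a state name to the list of probabilities on its outgoing arcs.
    """
    return [state for state, probs in arc_probs.items()
            if not math.isclose(sum(probs), 1.0, abs_tol=tol)]


# Outgoing arcs of phone_final_state and of one intermediate silence stage,
# using the probabilities from generate_sequence_sil_wfst.
arcs = {
    'phone_final_state': [0.9 * 0.8, 0.9 * 0.2, 0.1],
    'silence_stage': [0.9, 0.1],
}
print(check_outgoing(arcs))   # an empty list means every state is normalised
```

Since the WFST stores weights as -log p, the same check on a real FST would first convert each arc weight back with `math.exp(-cost)`.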

In [None]:
sequence_sil_wfst = generate_sequence_sil_wfst()
sequence_sil_wfst.set_input_symbols(state_sil_table)
sequence_sil_wfst.set_output_symbols(word_table)
from subprocess import check_call
from IPython.display import Image 
sequence_sil_wfst.draw('tmp.dot', portrait=True)
check_call(['dot','-Tpng','-Gdpi=200','tmp.dot','-o','tmp.png'])
Image(filename='tmp.png')  

In [None]:
def generate_unigram_wfst():
    
    '''
    Generate unigram_wfst, a single-state acceptor whose self-loops carry the unigram probability of each word.
    '''
    
    f = fst.Fst('log')
    # f = fst.Fst()
    state = f.add_state()
    f.set_start(state)
    for word in lex.keys():
        f.add_arc(state, fst.Arc(word_table.find(word), word_table.find(word),\
                                fst.Weight('log', -math.log(word_prob[word])), state))
        
    f.set_final(state)
    return f 
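The unigram acceptor has a single state with one self-loop per word, so the weight of an accepted word sequence is the sum of the per-word costs -log P(w). Because each word sequence has exactly one path through this acceptor, the log-semiring weight and the tropical (Viterbi) weight coincide. A sketch with a hypothetical vocabulary and made-up probabilities (not the `word_prob` computed in this notebook; the `toy_` names avoid overwriting the notebook's variables):

```python
import math

# Hypothetical unigram probabilities, for illustration only.
toy_word_prob = {'cat': 0.2, 'dog': 0.2, 'sat': 0.6}


def unigram_cost(words, word_prob):
    """Sum of -log P(w): the path weight through a one-state unigram WFST."""
    return sum(-math.log(word_prob[w]) for w in words)


cost = unigram_cost('cat dog sat'.split(), toy_word_prob)
print(cost, math.exp(-cost))   # exp(-cost) recovers the product of the probabilities
```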

In [None]:
unigram_wfst = generate_unigram_wfst()
unigram_wfst.set_input_symbols(word_table)
unigram_wfst.set_output_symbols(word_table)

In [None]:
def generate_hmm_transducer(n=3):
    """Generate an HMM transducer that maps HMM state sequences to any sequence of phones in the lexicon.
    
    Args:
        n (int): states per phone HMM

    Returns:
        the constructed WFST
    
    """
    
    f = fst.Fst('log')
    
    # create a single start state
    start_state = f.add_state()
    f.set_start(start_state)
    
    phone_set = set()
    
    for pronunciation in lex.values():
        phone_set = phone_set.union(pronunciation)
        
    for phone in phone_set:
        current_state = f.add_state()
        f.add_arc(start_state, fst.Arc(0, 0, fst.Weight('log', -math.log(1)), current_state))
    
        end_state = generate_phone_wfst(f, current_state, phone, n, output_phone=True)
        
        f.add_arc(end_state, fst.Arc(0,0, fst.Weight('log', -math.log(1)), start_state))
        f.set_final(end_state)
            
    return f 
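`generate_hmm_transducer` iterates over a Python `set`, and the iteration order of a set of strings can change between interpreter runs (string hashing is randomised unless `PYTHONHASHSEED` is fixed), so the state numbering of the drawn transducer may differ from run to run. A sketch of a deterministic alternative, using a hypothetical toy lexicon (named `toy_lex` so it does not overwrite the notebook's `lex`):

```python
# Hypothetical toy lexicon: word -> list of phones.
toy_lex = {
    'peck': ['p', 'eh', 'k'],
    'pick': ['p', 'ih', 'k'],
    'a': ['ax'],
}


def phone_inventory(lexicon):
    """Return the lexicon's phones in a fixed (sorted) order."""
    return sorted(set().union(*lexicon.values()))


print(phone_inventory(toy_lex))
```

Looping over `phone_inventory(lex)` instead of `phone_set` would keep the state IDs stable across runs.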

In [None]:
hmm_transducer = generate_hmm_transducer() 
hmm_transducer.set_input_symbols(state_table)
hmm_transducer.set_output_symbols(phone_table)
from subprocess import check_call
from IPython.display import Image 
hmm_transducer.draw('tmp.dot', portrait=True)
check_call(['dot','-Tpng','-Gdpi=200','tmp.dot','-o','tmp.png'])
Image(filename='tmp.png') 

In [None]:
# Alternative (not used): determinize and minimize after each composition, e.g.
# tree_sequence_wfst = fst.determinize(fst.compose(hmm_transducer, fst.determinize(fst.compose(tree_lexicon, unigram_wfst)).minimize())).minimize()
tree_sequence_wfst = fst.compose(tree_lexicon, unigram_wfst)
tree_sequence_wfst = fst.compose(hmm_transducer, tree_sequence_wfst)
tree_sequence_wfst.set_input_symbols(state_table) 
tree_sequence_wfst.set_output_symbols(word_table) 
from subprocess import check_call 
from IPython.display import Image 
tree_sequence_wfst.draw('tmp.dot', portrait=True) 
check_call(['dot','-Tpng','-Gdpi=200','tmp.dot','-o','tmp.png'])
Image(filename='tmp.png') 

In [None]:
word_index_dict = {}
for word in lex.keys():
    word_index_dict[word] = word_table.find(word)

In [None]:
word_index_dict['#'] = 0 

In [None]:
def compute_bigram():
    bigrams = defaultdict(int)
    unigrams = defaultdict(int)
    bigram_prob = {}
    # read_transcription() is defined in a later cell, which must be run first.
    for wav_file in glob.glob('/group/teaching/asr/labs/recordings/*.wav'):
        transcription = read_transcription(wav_file)
        # Add a start symbol '#' so that the first word also gets a bigram probability, and an end
        # symbol '#' so that the probabilities of all sentences sum to one.
        transcription = '# ' + transcription + ' #'
        transcription = transcription.split()
        for i in range(len(transcription) - 1):
            bigrams[tuple(transcription[i:i+2])] += 1
            unigrams[transcription[i]] += 1 
    
    for bigram in bigrams.keys():
        bigram_prob[bigram] = bigrams[bigram] / unigrams[bigram[0]]
        
    return bigram_prob 
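The counting in `compute_bigram` can be checked on toy data without the recordings. A sketch of the same maximum-likelihood estimate that takes a list of transcription strings instead of reading `.txt` files (the sentences below are made up):

```python
from collections import defaultdict


def bigram_mle(transcriptions):
    """Maximum-likelihood bigram probabilities with '#' as the sentence boundary."""
    bigrams = defaultdict(int)
    histories = defaultdict(int)
    for text in transcriptions:
        tokens = ['#'] + text.split() + ['#']
        for prev, cur in zip(tokens[:-1], tokens[1:]):
            bigrams[(prev, cur)] += 1
            histories[prev] += 1   # count of prev as a history, like `unigrams` in compute_bigram
    return {bg: n / histories[bg[0]] for bg, n in bigrams.items()}


probs = bigram_mle(['a b', 'a c', 'b'])
print(probs)
```

The probabilities leaving each history (including the start symbol `#`) sum to one, which is what the boundary symbols are for.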

In [None]:
bigram_prob = compute_bigram()

In [None]:
bigram_prob_array = np.ones([len(lex.keys()) + 1, len(lex.keys()) + 1])
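# bigram_prob_array[i, j] holds p(word_j | word_i), where i and j are word_table IDs
# and ID 0 stands for the sentence boundary '#'.
# Bigrams that never occur in the transcriptions keep the initial value 1.0 (arc cost -log(1) = 0).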
for bigram in bigram_prob.keys():
    bigram_prob_array[word_index_dict[bigram[0]]][word_index_dict[bigram[1]]] = bigram_prob[bigram] 
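Because `bigram_prob_array` starts from `np.ones`, any word pair never seen in the transcriptions gets probability 1 (cost 0), which makes unseen transitions at least as cheap as any observed one. If unseen bigrams should instead be disallowed, one option (a swapped-in alternative, not what this notebook does) is to start from zeros, convert to costs with infinite cost for zero probability, and skip those arcs when building the WFST. A sketch on toy data (`bigram_cost_matrix` and the `toy_` names are hypothetical):

```python
import numpy as np


def bigram_cost_matrix(bigram_prob, word_index, size):
    """Dense -log P(w_j | w_i) matrix; unseen bigrams get +inf (the arc should be omitted)."""
    probs = np.zeros((size, size))
    for (prev, cur), p in bigram_prob.items():
        probs[word_index[prev], word_index[cur]] = p
    with np.errstate(divide='ignore'):      # -log(0) -> inf without a warning
        return -np.log(probs)


# Toy example: index 0 is the boundary symbol '#', as in word_index_dict.
toy_index = {'#': 0, 'a': 1, 'b': 2}
toy_probs = {('#', 'a'): 1.0, ('a', 'b'): 1.0, ('b', '#'): 1.0}
costs = bigram_cost_matrix(toy_probs, toy_index, 3)
print(costs)
# When adding arcs, one would skip the pairs where np.isinf(costs[i, j]).
```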

In [None]:
def generate_bigram_word_wfst(n=3):
    
    '''
    Generate bigram_sequence_wfst, a bigram word-loop WFST without silence states.
    '''
    
    f = fst.Fst('log')
    start_state = f.add_state()
    f.set_start(start_state)
    word_start_states = []
    word_final_states = []
    word_index_list = []
    for word, phones in lex.items(): 
        word_index_list.append(word_index_dict[word])
        word_start_state = f.add_state()
        
        word_start_states.append(word_start_state)
        f.add_arc(start_state, fst.Arc(0, 0,\
                                       fst.Weight('log', -math.log(bigram_prob_array[0, word_index_dict[word]])),\
                                       word_start_state))
        current_state = f.add_state()
        f.add_arc(word_start_state, fst.Arc(0, word_table.find(word),\
                                            fst.Weight('log', -math.log(1)), current_state))
        
        for phone in phones:
            current_state = generate_phone_wfst(f, current_state, phone, n)
        word_final_state = current_state
        word_final_states.append(word_final_state) 
        f.set_final(word_final_state)
        
    final_state = f.add_state()
    f.set_final(final_state)
     
    for i in range(len(word_final_states)): 
        f.add_arc(word_final_states[i], fst.Arc(0, 0, \
                     fst.Weight('log', -math.log(bigram_prob_array[word_index_list[i], 0])), final_state))
        for j in range(len(word_start_states)):
            f.add_arc(word_final_states[i], fst.Arc(0, 0, \
                  fst.Weight('log', -math.log(bigram_prob_array[word_index_list[i]][word_index_list[j]])),\
                                                  word_start_states[j]))
     
    return f 
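In `generate_bigram_word_wfst`, a path that ends in `final_state` starts with a `#` to w_1 arc, takes one w_(i-1) to w_i arc between consecutive words, and leaves the last word through a w_n to `#` arc. Ignoring the HMM arcs, its language-model cost is therefore the bigram sentence cost sketched below (the probabilities are hypothetical):

```python
import math


def bigram_sentence_cost(words, bigram_prob):
    """-log P(sentence) under a bigram model with '#' at both ends."""
    tokens = ['#'] + list(words) + ['#']
    return sum(-math.log(bigram_prob[(prev, cur)])
               for prev, cur in zip(tokens[:-1], tokens[1:]))


# Hypothetical bigram probabilities.
toy_bigrams = {('#', 'a'): 0.5, ('a', 'b'): 0.25, ('b', '#'): 1.0}
cost = bigram_sentence_cost(['a', 'b'], toy_bigrams)
print(cost, math.exp(-cost))
```

Note that the word-final states are also marked final, so a path may end there without paying the end-of-sentence cost.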

In [None]:
def generate_bigram_sil_wfst(n=3):
    
    '''
    Generate bigram_sil_wfst, which appends five silence states to the end of each word.
    '''
    
    f = fst.Fst('log')
    start_state = f.add_state()
    f.set_start(start_state)
    word_start_states = []
    word_final_states = []
    word_index_list = []
    for word, phones in lex.items(): 
        word_index_list.append(word_index_dict[word])
        word_start_state = f.add_state()
        
        word_start_states.append(word_start_state)
        f.add_arc(start_state, fst.Arc(0, 0,\
                                       fst.Weight('log', -math.log(bigram_prob_array[0, word_index_dict[word]])),\
                                       word_start_state))
        current_state = f.add_state()
        f.add_arc(word_start_state, fst.Arc(0, word_table.find(word),\
                                            fst.Weight('log', -math.log(1)), current_state))
        
        for phone in phones:
            current_state = generate_phone_wfst(f, current_state, phone, n)
        
        phone_final_state = current_state 
        last_state = phone_final_state 
        
        for i in range(1, 6):
            current_state = f.add_state()
            f.add_arc(last_state, fst.Arc(state_sil_table.find('sil_{}'.format(i)), 0, \
                                         fst.Weight('log', -math.log(0.9)), current_state))
            f.add_arc(last_state, fst.Arc(state_sil_table.find('sil_{}'.format(i)), 0, \
                                            fst.Weight('log', -math.log(0.1)), last_state))
            last_state = current_state
            
            
        word_final_state = last_state
        
        
        word_final_states.append(word_final_state) 
        f.set_final(word_final_state)
        
    final_state = f.add_state()
    f.set_final(final_state)
     
    for i in range(len(word_final_states)): 
        f.add_arc(word_final_states[i], fst.Arc(0, 0, \
                     fst.Weight('log', -math.log(bigram_prob_array[word_index_list[i], 0])), final_state))
        for j in range(len(word_start_states)):
            f.add_arc(word_final_states[i], fst.Arc(0, 0, \
                  fst.Weight('log', -math.log(bigram_prob_array[word_index_list[i]][word_index_list[j]])),\
                                                  word_start_states[j]))
     
    return f  

In [None]:
bigram_sequence_wfst = generate_bigram_word_wfst()
bigram_sequence_wfst.set_input_symbols(state_table)
bigram_sequence_wfst.set_output_symbols(word_table)
from subprocess import check_call
from IPython.display import Image 
bigram_sequence_wfst.draw('tmp.dot', portrait=True)
check_call(['dot','-Tpng','-Gdpi=200','tmp.dot','-o','tmp.png'])
Image(filename='tmp.png') 

In [None]:
bigram_sil_wfst = generate_bigram_sil_wfst()
bigram_sil_wfst.set_input_symbols(state_sil_table)
bigram_sil_wfst.set_output_symbols(word_table)
from subprocess import check_call
from IPython.display import Image 
bigram_sil_wfst.draw('tmp.dot', portrait=True) 
check_call(['dot','-Tpng','-Gdpi=200','tmp.dot','-o','tmp.png'])
Image(filename='tmp.png') 

In [None]:
unigram_word_wfst = fst.compose(word_sequence_wfst, unigram_wfst) 
from subprocess import check_call
from IPython.display import Image 
unigram_word_wfst.draw('tmp.dot', portrait=True)
check_call(['dot','-Tpng','-Gdpi=200','tmp.dot','-o','tmp.png'])
Image(filename='tmp.png') 

In [None]:
def compute_states_arcs(f):
    
    '''
    Compute num states and num arcs in the transducer. 
    '''
    
    total_states = 0
    total_arcs = 0
    for state in f.states():
        total_states += 1 
        for arc in f.arcs(state):
            total_arcs += 1 
    return total_states, total_arcs 

In [None]:
import glob
import os
import wer
import observation_model
import openfst_python as fst
import time

# ... (add your code to create WFSTs and Viterbi Decoder)

def read_transcription(wav_file):
    """
    Get the transcription corresponding to wav_file.
    """
    
    transcription_file = os.path.splitext(wav_file)[0] + '.txt'
    
    with open(transcription_file, 'r') as f:
        transcription = f.readline().strip()
    
    return transcription


# f = create_wfst()
f = word_sequence_wfst 
f.set_input_symbols(state_table) 
f.set_output_symbols(word_table)
# fdet = fst.determinize(f)

total_decode_time = 0 
total_backtrace_time = 0 
total_errors = 0 
total_words = 0 
total_computations = 0 


for wav_file in glob.glob('/group/teaching/asr/labs/recordings/*.wav'):    # replace path if using your own
                                                                           # audio files
    
    decoder = MyViterbiDecoder(f, wav_file)
    # decoder = MyViterbiDecoder(fdet, wav_file)
    
    decode_start_time = time.time()
    computations = decoder.decode()
    decode_end_time = time.time()
    decode_time = decode_end_time - decode_start_time
    total_decode_time += decode_time 
    print(decode_time) 
    total_computations += computations
    print(computations) 
    # backtrace() is modified from Lab 4 to return the words along the best path
    backtrace_start_time = time.time()
    # (state_path, phones) = decoder.backtrace()
    (state_path, words) = decoder.backtrace()
    backtrace_end_time = time.time()
    backtrace_time = backtrace_end_time - backtrace_start_time 
    total_backtrace_time += backtrace_time 
    print(backtrace_time)
    # print(phones) 
    
    
    
    print(words)
    transcription = read_transcription(wav_file)
    # print(transcription)
    error_counts = wer.compute_alignment_errors(transcription, words)
    word_count = len(transcription.split()) 
    total_errors += sum(error_counts)
    total_words += word_count 
        
    print(error_counts, word_count)

total_wer = total_errors / total_words 
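The decode, backtrace and scoring loop is repeated almost verbatim for every WFST below. A sketch of a helper that runs one experiment and returns the totals; the decoder class, error function and transcription reader are passed in so the sketch stays self-contained. In this notebook they would be `MyViterbiDecoder`, `wer.compute_alignment_errors` and `read_transcription`, and `decoder_kwargs` could carry e.g. `beam_size=10`. The helper name `run_experiment` is hypothetical.

```python
import time


def run_experiment(f, wav_files, decoder_cls, compute_errors, read_transcription,
                   **decoder_kwargs):
    """Decode every file with decoder_cls and accumulate timing and error totals."""
    totals = {'decode_time': 0.0, 'backtrace_time': 0.0, 'computations': 0,
              'errors': 0, 'words': 0}
    for wav_file in wav_files:
        decoder = decoder_cls(f, wav_file, **decoder_kwargs)

        start = time.time()
        totals['computations'] += decoder.decode()   # decode() returns a computation count
        totals['decode_time'] += time.time() - start

        start = time.time()
        _, words = decoder.backtrace()               # (state_path, words)
        totals['backtrace_time'] += time.time() - start

        reference = read_transcription(wav_file)
        totals['errors'] += sum(compute_errors(reference, words))
        totals['words'] += len(reference.split())
    totals['wer'] = totals['errors'] / totals['words']
    return totals
```

For example, `run_experiment(word_sil_wfst, glob.glob('/group/teaching/asr/labs/recordings/*.wav'), MyViterbiDecoder, wer.compute_alignment_errors, read_transcription, beam_size=10)` would reproduce one of the beam-search cells below in a single call.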


In [None]:
print('word_sequence_wfst')
print('num states: {}, num arcs: {}'.format(*compute_states_arcs(word_sequence_wfst)))
print('total decode time: {}'.format(total_decode_time))
print('total backtrace time: {}'.format(total_backtrace_time))
print('total forward computations: {}'.format(total_computations))
print('total word error rate: {}'.format(total_wer)) 

word_sequence_wfst
num states: 116, num arcs: 230
total decode time: 974.6350224018097
total backtrace time: 0.16611146926879883
total forward computations: 32513904
total word error rate: 2.064502875924404 
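The word error rate here is (S + D + I) / N, with N the number of reference words. Since S + D can never exceed N, a value above 1 such as the 2.06 above is only possible when the decoder inserts many extra words. A minimal Levenshtein sketch of the counting that `wer.compute_alignment_errors` performs; the module's own implementation is not shown, so its tie-breaking between equal-cost alignments may differ:

```python
def alignment_errors(ref, hyp):
    """Return (substitutions, deletions, insertions) of a minimum-edit alignment."""
    r = ref.split() if isinstance(ref, str) else list(ref)
    h = hyp.split() if isinstance(hyp, str) else list(hyp)
    # table[i][j] = (cost, subs, dels, ins) for aligning r[:i] with h[:j]
    table = [[None] * (len(h) + 1) for _ in range(len(r) + 1)]
    for i in range(len(r) + 1):
        table[i][0] = (i, 0, i, 0)            # delete all reference words
    for j in range(len(h) + 1):
        table[0][j] = (j, 0, 0, j)            # insert all hypothesis words
    for i in range(1, len(r) + 1):
        for j in range(1, len(h) + 1):
            mismatch = int(r[i - 1] != h[j - 1])
            c, s, d, n = table[i - 1][j - 1]
            diag = (c + mismatch, s + mismatch, d, n)   # match or substitution
            c, s, d, n = table[i - 1][j]
            up = (c + 1, s, d + 1, n)                   # deletion
            c, s, d, n = table[i][j - 1]
            left = (c + 1, s, d, n + 1)                 # insertion
            table[i][j] = min(diag, up, left)  # lowest cost; ties broken by the other fields
    _, subs, dels, ins = table[len(r)][len(h)]
    return subs, dels, ins


print(alignment_errors('A B C', 'A C C D'))   # (1, 0, 1), as in the example at the top
```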

In [None]:
f = word_sequence_wfst 
f.set_input_symbols(state_table) 
f.set_output_symbols(word_table)
# fdet = fst.determinize(f)

total_decode_time = 0 
total_backtrace_time = 0 
total_errors = 0 
total_words = 0 
total_computations = 0 


for wav_file in glob.glob('/group/teaching/asr/labs/recordings/*.wav'):    # replace path if using your own
                                                                           # audio files
    
    decoder = MyViterbiDecoder(f, wav_file)
    # decoder = MyViterbiDecoder(fdet, wav_file)
    
    decode_start_time = time.time()
    computations = decoder.decode()
    decode_end_time = time.time()
    decode_time = decode_end_time - decode_start_time
    total_decode_time += decode_time 
    print(decode_time) 
    total_computations += computations
    print(computations) 
    # backtrace() is modified from Lab 4 to return the words along the best path
    backtrace_start_time = time.time()
    # (state_path, phones) = decoder.backtrace()
    (state_path, words) = decoder.backtrace()
    backtrace_end_time = time.time()
    backtrace_time = backtrace_end_time - backtrace_start_time 
    total_backtrace_time += backtrace_time 
    print(backtrace_time) 
    print(words)
    transcription = read_transcription(wav_file)
    print(transcription)
    error_counts = wer.compute_alignment_errors(transcription, words)
    word_count = len(transcription.split()) 
    total_errors += sum(error_counts)
    total_words += word_count 
        
    print(error_counts, word_count)

total_wer = total_errors / total_words 

In [None]:
print('word_sequence_wfst')
print('num states: {}, num arcs: {}'.format(*compute_states_arcs(word_sequence_wfst)))
print('total decode time: {}'.format(total_decode_time))
print('total backtrace time: {}'.format(total_backtrace_time))
print('total forward computations: {}'.format(total_computations))
print('total word error rate: {}'.format(total_wer)) 

In [None]:
total_wer

In [None]:
f = unigram_word_wfst 
f.set_input_symbols(state_table) 
f.set_output_symbols(word_table)
# fdet = fst.determinize(f)

total_decode_time = 0 
total_backtrace_time = 0 
total_errors = 0 
total_words = 0 
total_computations = 0 


for wav_file in glob.glob('/group/teaching/asr/labs/recordings/*.wav'):    # replace path if using your own
                                                                           # audio files
    
    decoder = MyViterbiDecoder(f, wav_file)
    # decoder = MyViterbiDecoder(fdet, wav_file)
    
    decode_start_time = time.time()
    computations = decoder.decode()
    decode_end_time = time.time()
    decode_time = decode_end_time - decode_start_time
    total_decode_time += decode_time 
    print(decode_time) 
    total_computations += computations
    print(computations) 
    # backtrace() is modified from Lab 4 to return the words along the best path
    backtrace_start_time = time.time()
    # (state_path, phones) = decoder.backtrace()
    (state_path, words) = decoder.backtrace()
    backtrace_end_time = time.time()
    backtrace_time = backtrace_end_time - backtrace_start_time 
    total_backtrace_time += backtrace_time 
    print(backtrace_time)
    
    print(words)
    transcription = read_transcription(wav_file)
    print(transcription)
    error_counts = wer.compute_alignment_errors(transcription, words)
    word_count = len(transcription.split()) 
    total_errors += sum(error_counts)
    total_words += word_count 
        
    print(error_counts, word_count)

total_wer = total_errors / total_words 

In [None]:
print('unigram_word_wfst')
print('num states: {}, num arcs: {}'.format(*compute_states_arcs(unigram_word_wfst)))
print('total backtrace time: {}'.format(total_backtrace_time))
print('total decode time: {}'.format(total_decode_time))
print('total forward computations: {}'.format(total_computations))
print('total word error rate: {}'.format(total_wer))

unigram_word_wfst
num states: 116, num arcs: 230
total backtrace time: 0.14623546600341797
total decode time: 969.2153050899506
total forward computations: 32513904
total word error rate: 1.3919474116680361

In [None]:
f = unigram_sequence_wfst 
f.set_input_symbols(state_table) 
f.set_output_symbols(word_table)
# fdet = fst.determinize(f)

total_decode_time = 0 
total_backtrace_time = 0 
total_errors = 0 
total_words = 0 
total_computations = 0 


for wav_file in glob.glob('/group/teaching/asr/labs/recordings/*.wav'):    # replace path if using your own
                                                                           # audio files
    
    decoder = MyViterbiDecoder(f, wav_file)
    # decoder = MyViterbiDecoder(fdet, wav_file)
    
    decode_start_time = time.time()
    computations = decoder.decode()
    decode_end_time = time.time()
    decode_time = decode_end_time - decode_start_time
    total_decode_time += decode_time 
    print(decode_time) 
    total_computations += computations
    print(computations) 
    # backtrace() is modified from Lab 4 to return the words along the best path
    backtrace_start_time = time.time()
    # (state_path, phones) = decoder.backtrace()
    (state_path, words) = decoder.backtrace()
    backtrace_end_time = time.time()
    backtrace_time = backtrace_end_time - backtrace_start_time 
    total_backtrace_time += backtrace_time 
    print(backtrace_time) 
    print(words)
    transcription = read_transcription(wav_file)
    print(transcription)
    error_counts = wer.compute_alignment_errors(transcription, words)
    word_count = len(transcription.split()) 
    total_errors += sum(error_counts)
    total_words += word_count 
        
    print(error_counts, word_count)

total_wer = total_errors / total_words 

In [None]:
print('unigram sequence with transition probability 0.5')
print('num states: {}, num arcs: {}'.format(*compute_states_arcs(unigram_sequence_wfst)))
print('total backtrace time: {}'.format(total_backtrace_time))
print('total decode time: {}'.format(total_decode_time))
print('total forward computations: {}'.format(total_computations))
print('total word error rate: {}'.format(total_wer)) 

In [None]:
f = unigram_sequence_wfst 
f.set_input_symbols(state_table) 
f.set_output_symbols(word_table)
# fdet = fst.determinize(f)

total_decode_time = 0 
total_backtrace_time = 0 
total_errors = 0 
total_words = 0 
total_computations = 0 


for wav_file in glob.glob('/group/teaching/asr/labs/recordings/*.wav'):    # replace path if using your own
                                                                           # audio files
    
    decoder = MyViterbiDecoder(f, wav_file)
    # decoder = MyViterbiDecoder(fdet, wav_file)
    
    decode_start_time = time.time()
    computations = decoder.decode()
    decode_end_time = time.time()
    decode_time = decode_end_time - decode_start_time
    total_decode_time += decode_time 
    print(decode_time) 
    total_computations += computations
    print(computations) 
    # backtrace() is modified from Lab 4 to return the words along the best path
    backtrace_start_time = time.time()
    # (state_path, phones) = decoder.backtrace()
    (state_path, words) = decoder.backtrace()
    backtrace_end_time = time.time()
    backtrace_time = backtrace_end_time - backtrace_start_time 
    total_backtrace_time += backtrace_time 
    print(backtrace_time) 
    print(words)
    transcription = read_transcription(wav_file)
    print(transcription)
    error_counts = wer.compute_alignment_errors(transcription, words)
    word_count = len(transcription.split()) 
    total_errors += sum(error_counts)
    total_words += word_count 
        
    print(error_counts, word_count)

total_wer = total_errors / total_words 

In [None]:
print('unigram sequence with transition probability 0.8')
print('num states: {}, num arcs: {}'.format(*compute_states_arcs(unigram_sequence_wfst)))
print('total backtrace time: {}'.format(total_backtrace_time))
print('total decode time: {}'.format(total_decode_time))
print('total forward computations: {}'.format(total_computations))
print('total word error rate: {}'.format(total_wer)) 

In [None]:
print('tree wfst')
print('num states: {}, num arcs: {}'.format(*compute_states_arcs(tree_wfst)))
print('total backtrace time: {}'.format(total_backtrace_time))
print('total decode time: {}'.format(total_decode_time))
print('total forward computations: {}'.format(total_computations))
print('total word error rate: {}'.format(total_wer)) 

In [None]:
print('tree wfst with beam width 0.1.')
print('num states: {}, num arcs: {}'.format(*compute_states_arcs(tree_wfst)))
print('total backtrace time: {}'.format(total_backtrace_time))
print('total decode time: {}'.format(total_decode_time))
print('total forward computations: {}'.format(total_computations))
print('total word error rate: {}'.format(total_wer)) 

unigram word wfst with beam width 10.
num states: 116, num arcs: 230
total backtrace time: 0.13636112213134766
total decode time: 165.58465719223022
total forward computations: 4413984
total word error rate: 2.058340180772391

In [None]:
f = bigram_sequence_wfst 
f.set_input_symbols(state_table) 
f.set_output_symbols(word_table) 

total_decode_time = 0 
total_backtrace_time = 0 
total_errors = 0 
total_words = 0 
total_computations = 0 


for wav_file in glob.glob('/group/teaching/asr/labs/recordings/*.wav'):    # replace path if using your own
                                                                           # audio files
    
    decoder = MyViterbiDecoder(f, wav_file) 
    # decoder = MyViterbiDecoder(fdet, wav_file)
    
    decode_start_time = time.time()
    computations = decoder.decode()
    decode_end_time = time.time()
    decode_time = decode_end_time - decode_start_time
    total_decode_time += decode_time 
    print(decode_time) 
    total_computations += computations
    print(computations) 
    # backtrace() is modified from Lab 4 to return the words along the best path
    backtrace_start_time = time.time()
    # (state_path, phones) = decoder.backtrace()
    (state_path, words) = decoder.backtrace()
    backtrace_end_time = time.time()
    backtrace_time = backtrace_end_time - backtrace_start_time 
    total_backtrace_time += backtrace_time 
    print(backtrace_time) 
    print(words)
    transcription = read_transcription(wav_file)
    print(transcription)
    error_counts = wer.compute_alignment_errors(transcription, words)
    word_count = len(transcription.split()) 
    total_errors += sum(error_counts)
    total_words += word_count 
        
    print(error_counts, word_count)

total_wer = total_errors / total_words 

In [None]:
print('bigram_sequence_wfst')
print('num states: {}, num arcs: {}'.format(*compute_states_arcs(bigram_sequence_wfst)))
print('total backtrace time: {}'.format(total_backtrace_time))
print('total decode time: {}'.format(total_decode_time))
print('total forward computations: {}'.format(total_computations))
print('total word error rate: {}'.format(total_wer))

bigram_sequence_wfst
num states: 127, num arcs: 340
total backtrace time: 0.17756962776184082
total decode time: 1132.2548327445984
total forward computations: 32513904
total word error rate: 1.97493837304848 

unigram sequence wfst with beam width 10
num states: 116, num arcs: 230
total backtrace time: 0.1216437816619873
total decode time: 168.6038625240326
total forward computations: 4413984
total word error rate: 1.9231717337715695 

In [None]:
f = word_sil_wfst 
f.set_input_symbols(state_sil_table) 
f.set_output_symbols(word_table) 

total_decode_time = 0 
total_backtrace_time = 0 
total_errors = 0 
total_words = 0 
total_computations = 0 


for wav_file in glob.glob('/group/teaching/asr/labs/recordings/*.wav'):    # replace path if using your own
                                                                           # audio files
    
    decoder = MyViterbiDecoder(f, wav_file) 
    # decoder = MyViterbiDecoder(fdet, wav_file)
    
    decode_start_time = time.time()
    computations = decoder.decode()
    decode_end_time = time.time()
    decode_time = decode_end_time - decode_start_time
    total_decode_time += decode_time 
    print(decode_time) 
    total_computations += computations
    print(computations) 
    # backtrace() is modified from Lab 4 to return the words along the best path
    backtrace_start_time = time.time()
    # (state_path, phones) = decoder.backtrace()
    (state_path, words) = decoder.backtrace()
    backtrace_end_time = time.time()
    backtrace_time = backtrace_end_time - backtrace_start_time 
    total_backtrace_time += backtrace_time 
    print(backtrace_time)  
    print(words)
    transcription = read_transcription(wav_file)
    print(transcription)
    error_counts = wer.compute_alignment_errors(transcription, words)
    word_count = len(transcription.split()) 
    total_errors += sum(error_counts)
    total_words += word_count 
        
    print(error_counts, word_count)

total_wer = total_errors / total_words 

In [None]:
print('word sil wfst')
print('num states: {}, num arcs: {}'.format(*compute_states_arcs(word_sil_wfst)))
print('total backtrace time: {}'.format(total_backtrace_time))
print('total decode time: {}'.format(total_decode_time))
print('total forward computations: {}'.format(total_computations))
print('total word error rate: {}'.format(total_wer))

word sil wfst
num states: 166, num arcs: 330
total backtrace time: 0.1423349380493164
total decode time: 1387.7171621322632
total forward computations: 47778204
total word error rate: 0.5357436318816763 

In [None]:
f = word_sil_wfst 
f.set_input_symbols(state_sil_table) 
f.set_output_symbols(word_table) 

total_decode_time = 0 
total_backtrace_time = 0 
total_errors = 0 
total_words = 0 
total_computations = 0 


for wav_file in glob.glob('/group/teaching/asr/labs/recordings/*.wav'):    # replace path if using your own
                                                                           # audio files
    
    decoder = MyViterbiDecoder(f, wav_file, beam_size=10) 
    # decoder = MyViterbiDecoder(fdet, wav_file)
    
    decode_start_time = time.time()
    computations = decoder.decode()
    decode_end_time = time.time()
    decode_time = decode_end_time - decode_start_time
    total_decode_time += decode_time 
    print(decode_time) 
    total_computations += computations
    print(computations) 
    # backtrace() is modified from Lab 4 to return the words along the best path
    backtrace_start_time = time.time()
    # (state_path, phones) = decoder.backtrace()
    (state_path, words) = decoder.backtrace()
    backtrace_end_time = time.time()
    backtrace_time = backtrace_end_time - backtrace_start_time 
    total_backtrace_time += backtrace_time 
    print(backtrace_time) 
    print(words)
    transcription = read_transcription(wav_file)
    print(transcription)
    error_counts = wer.compute_alignment_errors(transcription, words)
    word_count = len(transcription.split()) 
    total_errors += sum(error_counts)
    total_words += word_count 
        
    print(error_counts, word_count)

total_wer = total_errors / total_words 

In [None]:
print('word sil wfst with beam size 10.')
print('num states: {}, num arcs: {}'.format(*compute_states_arcs(word_sil_wfst)))
print('total backtrace time: {}'.format(total_backtrace_time))
print('total decode time: {}'.format(total_decode_time))
print('total forward computations: {}'.format(total_computations))
print('total word error rate: {}'.format(total_wer)) 

word sil wfst with beam size 10.
num states: 166, num arcs: 330
total backtrace time: 0.13271570205688477
total decode time: 193.65572690963745
total forward computations: 4826666
total word error rate: 0.8870172555464256 
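The decoder's pruning code is not part of this section. As an assumption about what the `beam_width` and `beam_size` arguments usually refer to, here is a sketch of the two common schemes on a dictionary of active states and their -log costs: threshold pruning keeps states within `beam_width` of the best cost, and histogram pruning keeps at most `beam_size` cheapest states. Either way, pruning trades accuracy for speed, consistent with the roughly tenfold drop in forward computations and the higher WER reported above.

```python
import heapq


def prune(active, beam_width=None, beam_size=None):
    """Return the subset of active {state: cost} that survives pruning.

    Costs are negative log probabilities, so lower is better.
    """
    if not active:
        return {}
    kept = dict(active)
    if beam_width is not None:
        best = min(kept.values())
        kept = {s: c for s, c in kept.items() if c <= best + beam_width}
    if beam_size is not None and len(kept) > beam_size:
        kept = dict(heapq.nsmallest(int(beam_size), kept.items(), key=lambda item: item[1]))
    return kept


active = {0: 12.0, 1: 3.5, 2: 4.0, 3: 30.0}
print(prune(active, beam_width=10.0))   # drops states costing more than 13.5
print(prune(active, beam_size=2))       # keeps the two cheapest states
```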

In [None]:
0.13, 0.15, 0.15, 0.15, 
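# Decode time and forward computations relative to the unpruned word_sil_wfst run
# (1387.72 s, 47778204 computations); the first entries are the beam size 10 run above.
print(np.array([193.66, 346.09, 507.74, 704.56]) / 1387.72)
print(np.array([4826666, 9777744, 15555670, 21399724]) / 47778204)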

In [None]:
f = word_sil_wfst 
f.set_input_symbols(state_sil_table) 
f.set_output_symbols(word_table) 

total_decode_time = 0 
total_backtrace_time = 0 
total_errors = 0 
total_words = 0 
total_computations = 0 


for wav_file in glob.glob('/group/teaching/asr/labs/recordings/*.wav'):    # replace path if using your own
                                                                           # audio files
    
    decoder = MyViterbiDecoder(f, wav_file, beam_size=30) 
    # decoder = MyViterbiDecoder(fdet, wav_file)
    
    decode_start_time = time.time()
    computations = decoder.decode()
    decode_end_time = time.time()
    decode_time = decode_end_time - decode_start_time
    total_decode_time += decode_time 
    print(decode_time) 
    total_computations += computations
    print(computations) 
    # (state_path, words) = decoder.backtrace()  # you'll need to modify the backtrace() from Lab 4
                                               # to return the words along the best path
    backtrace_start_time = time.time()
    # (state_path, phones) = decoder.backtrace()
    (state_path, words) = decoder.backtrace()
    backtrace_end_time = time.time()
    backtrace_time = backtrace_end_time - backtrace_start_time 
    total_backtrace_time += backtrace_time 
    print(backtrace_time) 
    print(words)
    transcription = read_transcription(wav_file)
    print(transcription)
    error_counts = wer.compute_alignment_errors(transcription, words)
    word_count = len(transcription.split()) 
    total_errors += sum(error_counts)
    total_words += word_count 
        
    print(error_counts, word_count)     # you'll need to accumulate these to produce an overall Word Error Rate

total_wer = total_errors / total_words 
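This section does not show how `MyViterbiDecoder` applies `beam_size`. Assume it means histogram pruning, i.e. keeping at most `beam_size` active states per frame, ranked by negative log likelihood. Under that assumption, a minimal sketch of the pruning step is below. `histogram_prune` and the `{state: nll}` dictionary layout are hypothetical.

```python
import heapq


def histogram_prune(active, beam_size):
    """Keep at most beam_size states with the lowest negative log likelihood.

    active maps a WFST state id to its NLL score at the current frame.
    beam_size may be float('inf'), matching MyViterbiDecoder's default,
    in which case nothing is pruned.
    """
    if len(active) <= beam_size:
        return dict(active)
    # nsmallest needs an int count; beam_size is finite on this branch
    best = heapq.nsmallest(int(beam_size), active.items(), key=lambda item: item[1])
    return dict(best)
```

A larger beam keeps more hypotheses alive. This is consistent with the totals reported above: computation counts rise (4826666 at size 10, 9777744 at size 30) and WER falls.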

In [None]:
print('word sil wfst with beam size 30.')
print('num states: {}, num arcs: {}'.format(*compute_states_arcs(word_sil_wfst)))
print('total backtrace time: {}'.format(total_backtrace_time))
print('total decode time: {}'.format(total_decode_time))
print('total forward computations: {}'.format(total_computations))
print('total word error rate: {}'.format(total_wer)) 

word sil wfst with beam size 30.
num states: 166, num arcs: 330
total backtrace time: 0.1462109088897705
total decode time: 346.09289479255676
total forward computations: 9777744
total word error rate: 0.6322925225965489 

In [None]:
f = word_sil_wfst 
f.set_input_symbols(state_sil_table) 
f.set_output_symbols(word_table) 

total_decode_time = 0 
total_backtrace_time = 0 
total_errors = 0 
total_words = 0 
total_computations = 0 


for wav_file in glob.glob('/group/teaching/asr/labs/recordings/*.wav'):    # replace path if using your own
                                                                           # audio files
    
    decoder = MyViterbiDecoder(f, wav_file, beam_size=50) 
    # decoder = MyViterbiDecoder(fdet, wav_file)
    
    decode_start_time = time.time()
    computations = decoder.decode()
    decode_end_time = time.time()
    decode_time = decode_end_time - decode_start_time
    total_decode_time += decode_time 
    print(decode_time) 
    total_computations += computations
    print(computations) 
    # (state_path, words) = decoder.backtrace()  # you'll need to modify the backtrace() from Lab 4
                                               # to return the words along the best path
    backtrace_start_time = time.time()
    # (state_path, phones) = decoder.backtrace()
    (state_path, words) = decoder.backtrace()
    backtrace_end_time = time.time()
    backtrace_time = backtrace_end_time - backtrace_start_time 
    total_backtrace_time += backtrace_time 
    print(backtrace_time) 
    print(words)
    transcription = read_transcription(wav_file)
    print(transcription)
    error_counts = wer.compute_alignment_errors(transcription, words)
    word_count = len(transcription.split()) 
    total_errors += sum(error_counts)
    total_words += word_count 
        
    print(error_counts, word_count)     # you'll need to accumulate these to produce an overall Word Error Rate

total_wer = total_errors / total_words 

In [None]:
print('word sil wfst with beam size 50.')
print('num states: {}, num arcs: {}'.format(*compute_states_arcs(word_sil_wfst)))
print('total backtrace time: {}'.format(total_backtrace_time))
print('total decode time: {}'.format(total_decode_time))
print('total forward computations: {}'.format(total_computations))
print('total word error rate: {}'.format(total_wer)) 

word sil wfst with beam size 50.
num states: 166, num arcs: 330
total backtrace time: 0.1522197723388672
total decode time: 507.74493885040283
total forward computations: 15555670
total word error rate: 0.5579293344289236 

In [None]:
f = word_sil_wfst 
f.set_input_symbols(state_sil_table) 
f.set_output_symbols(word_table) 

total_decode_time = 0 
total_backtrace_time = 0 
total_errors = 0 
total_words = 0 
total_computations = 0 


for wav_file in glob.glob('/group/teaching/asr/labs/recordings/*.wav'):    # replace path if using your own
                                                                           # audio files
    
    decoder = MyViterbiDecoder(f, wav_file, beam_size=70) 
    # decoder = MyViterbiDecoder(fdet, wav_file)
    
    decode_start_time = time.time()
    computations = decoder.decode()
    decode_end_time = time.time()
    decode_time = decode_end_time - decode_start_time
    total_decode_time += decode_time 
    print(decode_time) 
    total_computations += computations
    print(computations) 
    # (state_path, words) = decoder.backtrace()  # you'll need to modify the backtrace() from Lab 4
                                               # to return the words along the best path
    backtrace_start_time = time.time()
    # (state_path, phones) = decoder.backtrace()
    (state_path, words) = decoder.backtrace()
    backtrace_end_time = time.time()
    backtrace_time = backtrace_end_time - backtrace_start_time 
    total_backtrace_time += backtrace_time 
    print(backtrace_time) 
    print(words)
    transcription = read_transcription(wav_file)
    print(transcription)
    error_counts = wer.compute_alignment_errors(transcription, words)
    word_count = len(transcription.split()) 
    total_errors += sum(error_counts)
    total_words += word_count 
        
    print(error_counts, word_count)     # you'll need to accumulate these to produce an overall Word Error Rate

total_wer = total_errors / total_words 

In [None]:
print('word sil wfst with beam size 70.')
print('num states: {}, num arcs: {}'.format(*compute_states_arcs(word_sil_wfst)))
print('total backtrace time: {}'.format(total_backtrace_time))
print('total decode time: {}'.format(total_decode_time))
print('total forward computations: {}'.format(total_computations))
print('total word error rate: {}'.format(total_wer)) 

word sil wfst with beam width 0.1.
num states: 166, num arcs: 330
total backtrace time: 0.12229657173156738
total decode time: 118.51288843154907
total forward computations: 2015340
total word error rate: 1.1092851273623665 

In [None]:
f = word_sil_wfst 
f.set_input_symbols(state_sil_table) 
f.set_output_symbols(word_table) 

total_decode_time = 0 
total_backtrace_time = 0 
total_errors = 0 
total_words = 0 
total_computations = 0 


for wav_file in glob.glob('/group/teaching/asr/labs/recordings/*.wav'):    # replace path if using your own
                                                                           # audio files
    
    decoder = MyViterbiDecoder(f, wav_file, beam_width=0.01) 
    # decoder = MyViterbiDecoder(fdet, wav_file)
    
    decode_start_time = time.time()
    computations = decoder.decode()
    decode_end_time = time.time()
    decode_time = decode_end_time - decode_start_time
    total_decode_time += decode_time 
    print(decode_time) 
    total_computations += computations
    print(computations) 
    # (state_path, words) = decoder.backtrace()  # you'll need to modify the backtrace() from Lab 4
                                               # to return the words along the best path
    backtrace_start_time = time.time()
    # (state_path, phones) = decoder.backtrace()
    (state_path, words) = decoder.backtrace()
    backtrace_end_time = time.time()
    backtrace_time = backtrace_end_time - backtrace_start_time 
    total_backtrace_time += backtrace_time 
    print(backtrace_time) 
    print(words)
    transcription = read_transcription(wav_file)
    print(transcription)
    error_counts = wer.compute_alignment_errors(transcription, words)
    word_count = len(transcription.split()) 
    total_errors += sum(error_counts)
    total_words += word_count 
        
    print(error_counts, word_count)     # you'll need to accumulate these to produce an overall Word Error Rate

total_wer = total_errors / total_words 
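The `beam_width` runs use a different pruning rule. Assume `beam_width` is a probability ratio relative to the best active state, so that 0.01 keeps states at least 1% as likely as the best. The equivalent test in negative log space is then `nll <= best_nll - log(beam_width)`. A minimal sketch under that assumption, with the hypothetical name `threshold_prune`:

```python
import math


def threshold_prune(active, beam_width):
    """Keep states at least beam_width times as probable as the best one.

    active maps a WFST state id to its NLL (negative natural log probability).
    In NLL terms a state survives if its score is within -log(beam_width) of
    the best score. beam_width=None (MyViterbiDecoder's default) disables pruning.
    """
    if beam_width is None or not active:
        return dict(active)
    best = min(active.values())
    margin = -math.log(beam_width)
    return {state: nll for state, nll in active.items() if nll <= best + margin}
```

Under this reading, smaller widths give wider margins (about 2.3 for 0.1 and about 27.6 for 1e-12), so fewer states are pruned.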

In [None]:
print('word sil wfst with beam width 0.01.')
print('num states: {}, num arcs: {}'.format(*compute_states_arcs(word_sil_wfst)))
print('total backtrace time: {}'.format(total_backtrace_time))
print('total decode time: {}'.format(total_decode_time))
print('total forward computations: {}'.format(total_computations))
print('total word error rate: {}'.format(total_wer)) 

In [None]:
print('word sil wfst with beam width 1e-6.')
print('num states: {}, num arcs: {}'.format(*compute_states_arcs(word_sil_wfst)))
print('total backtrace time: {}'.format(total_backtrace_time))
print('total decode time: {}'.format(total_decode_time))
print('total forward computations: {}'.format(total_computations))
print('total word error rate: {}'.format(total_wer)) 

In [None]:
print('word sil wfst with beam width 1e-8.')
print('num states: {}, num arcs: {}'.format(*compute_states_arcs(word_sil_wfst)))
print('total backtrace time: {}'.format(total_backtrace_time))
print('total decode time: {}'.format(total_decode_time))
print('total forward computations: {}'.format(total_computations))
print('total word error rate: {}'.format(total_wer)) 

In [None]:
print('word sil wfst with beam width 1e-10.')
print('num states: {}, num arcs: {}'.format(*compute_states_arcs(word_sil_wfst)))
print('total backtrace time: {}'.format(total_backtrace_time))
print('total decode time: {}'.format(total_decode_time))
print('total forward computations: {}'.format(total_computations))
print('total word error rate: {}'.format(total_wer)) 

In [None]:
f = word_sil_wfst 
f.set_input_symbols(state_sil_table) 
f.set_output_symbols(word_table) 

total_decode_time = 0 
total_backtrace_time = 0 
total_errors = 0 
total_words = 0 
total_computations = 0 


for wav_file in glob.glob('/group/teaching/asr/labs/recordings/*.wav'):    # replace path if using your own
                                                                           # audio files
    
    decoder = MyViterbiDecoder(f, wav_file, beam_width=1e-12) 
    # decoder = MyViterbiDecoder(fdet, wav_file)
    
    decode_start_time = time.time()
    computations = decoder.decode()
    decode_end_time = time.time()
    decode_time = decode_end_time - decode_start_time
    total_decode_time += decode_time 
    print(decode_time) 
    total_computations += computations
    print(computations) 
    # (state_path, words) = decoder.backtrace()  # you'll need to modify the backtrace() from Lab 4
                                               # to return the words along the best path
    backtrace_start_time = time.time()
    # (state_path, phones) = decoder.backtrace()
    (state_path, words) = decoder.backtrace()
    backtrace_end_time = time.time()
    backtrace_time = backtrace_end_time - backtrace_start_time 
    total_backtrace_time += backtrace_time 
    print(backtrace_time) 
    print(words)
    transcription = read_transcription(wav_file)
    print(transcription)
    error_counts = wer.compute_alignment_errors(transcription, words)
    word_count = len(transcription.split()) 
    total_errors += sum(error_counts)
    total_words += word_count 
        
    print(error_counts, word_count)     # you'll need to accumulate these to produce an overall Word Error Rate

total_wer = total_errors / total_words 

In [None]:
print('word sil wfst with beam width 1e-12.')
print('num states: {}, num arcs: {}'.format(*compute_states_arcs(word_sil_wfst)))
print('total backtrace time: {}'.format(total_backtrace_time))
print('total decode time: {}'.format(total_decode_time))
print('total forward computations: {}'.format(total_computations))
print('total word error rate: {}'.format(total_wer)) 

In [None]:
# beam widths: 0.1, 0.01, 0.001, 0.0001, 0.00001, 1e-6
# decode time and forward computations relative to the same reference run
print(np.array([118.51, 131.78, 149.83, 172.63, 188.85, 208.86, 261.79, 300.51, 345.89]) / 1387.72)
print(np.array([2015340, 2641318, 3196708, 3726574, 4328928, 5061720, 6840482, 8167990, 9714586]) / 47778204)

In [None]:
print('word sil wfst with beam width 0.0001.')
print('num states: {}, num arcs: {}'.format(*compute_states_arcs(word_sil_wfst)))
print('total backtrace time: {}'.format(total_backtrace_time))
print('total decode time: {}'.format(total_decode_time))
print('total forward computations: {}'.format(total_computations))
print('total word error rate: {}'.format(total_wer)) 

In [None]:
print('word sil wfst with beam width 0.00001.')
print('num states: {}, num arcs: {}'.format(*compute_states_arcs(word_sil_wfst)))
print('total backtrace time: {}'.format(total_backtrace_time))
print('total decode time: {}'.format(total_decode_time))
print('total forward computations: {}'.format(total_computations))
print('total word error rate: {}'.format(total_wer)) 

In [None]:
f = tree_sequence_wfst 
f.set_input_symbols(state_sil_table) 
f.set_output_symbols(word_table) 

total_decode_time = 0 
total_backtrace_time = 0 
total_errors = 0 
total_words = 0 
total_computations = 0 


for wav_file in glob.glob('/group/teaching/asr/labs/recordings/*.wav'):    # replace path if using your own
                                                                           # audio files
    
    decoder = MyViterbiDecoder(f, wav_file) 
    # decoder = MyViterbiDecoder(fdet, wav_file)
    
    decode_start_time = time.time()
    computations = decoder.decode()
    decode_end_time = time.time()
    decode_time = decode_end_time - decode_start_time
    total_decode_time += decode_time 
    print(decode_time) 
    total_computations += computations
    print(computations) 
    # (state_path, words) = decoder.backtrace()  # you'll need to modify the backtrace() from Lab 4
                                               # to return the words along the best path
    backtrace_start_time = time.time()
    # (state_path, phones) = decoder.backtrace()
    (state_path, words) = decoder.backtrace()
    backtrace_end_time = time.time()
    backtrace_time = backtrace_end_time - backtrace_start_time 
    total_backtrace_time += backtrace_time 
    print(backtrace_time) 
    print(words)
    transcription = read_transcription(wav_file)
    print(transcription)
    error_counts = wer.compute_alignment_errors(transcription, words)
    word_count = len(transcription.split()) 
    total_errors += sum(error_counts)
    total_words += word_count 
        
    print(error_counts, word_count)     # you'll need to accumulate these to produce an overall Word Error Rate

total_wer = total_errors / total_words 

In [None]:
print('tree sequence wfst')
print('num states: {}, num arcs: {}'.format(*compute_states_arcs(tree_sequence_wfst)))
print('total backtrace time: {}'.format(total_backtrace_time))
print('total decode time: {}'.format(total_decode_time))
print('total forward computations: {}'.format(total_computations))
print('total word error rate: {}'.format(total_wer)) 

In [None]:
f = bigram_sil_wfst 
f.set_input_symbols(state_sil_table) 
f.set_output_symbols(word_table) 

total_decode_time = 0 
total_backtrace_time = 0 
total_errors = 0 
total_words = 0 
total_computations = 0 


for wav_file in glob.glob('/group/teaching/asr/labs/recordings/*.wav'):    # replace path if using your own
                                                                           # audio files
    
    decoder = MyViterbiDecoder(f, wav_file) 
    # decoder = MyViterbiDecoder(fdet, wav_file)
    
    decode_start_time = time.time()
    computations = decoder.decode()
    decode_end_time = time.time()
    decode_time = decode_end_time - decode_start_time
    total_decode_time += decode_time 
    print(decode_time) 
    total_computations += computations
    print(computations) 
    # (state_path, words) = decoder.backtrace()  # you'll need to modify the backtrace() from Lab 4
                                               # to return the words along the best path
    backtrace_start_time = time.time()
    # (state_path, phones) = decoder.backtrace()
    (state_path, words) = decoder.backtrace()
    backtrace_end_time = time.time()
    backtrace_time = backtrace_end_time - backtrace_start_time 
    total_backtrace_time += backtrace_time 
    print(backtrace_time) 
    print(words)
    transcription = read_transcription(wav_file)
    print(transcription)
    error_counts = wer.compute_alignment_errors(transcription, words)
    word_count = len(transcription.split()) 
    total_errors += sum(error_counts)
    total_words += word_count 
        
    print(error_counts, word_count)     # you'll need to accumulate these to produce an overall Word Error Rate

total_wer = total_errors / total_words 

In [None]:
print('bigram sil wfst')
print('num states: {}, num arcs: {}'.format(*compute_states_arcs(bigram_sil_wfst)))
print('total backtrace time: {}'.format(total_backtrace_time))
print('total decode time: {}'.format(total_decode_time))
print('total forward computations: {}'.format(total_computations))
print('total word error rate: {}'.format(total_wer)) 

In [None]:
compute_states_arcs(bigram_sil_wfst)

In [None]:
f = sequence_sil_wfst 
f.set_input_symbols(state_sil_table) 
f.set_output_symbols(word_table) 

total_decode_time = 0 
total_backtrace_time = 0 
total_errors = 0 
total_words = 0 
total_computations = 0 


for wav_file in glob.glob('/group/teaching/asr/labs/recordings/*.wav'):    # replace path if using your own
                                                                           # audio files
    
    decoder = MyViterbiDecoder(f, wav_file) 
    # decoder = MyViterbiDecoder(fdet, wav_file)
    
    decode_start_time = time.time()
    computations = decoder.decode()
    decode_end_time = time.time()
    decode_time = decode_end_time - decode_start_time
    total_decode_time += decode_time 
    print(decode_time) 
    total_computations += computations
    print(computations) 
    # (state_path, words) = decoder.backtrace()  # you'll need to modify the backtrace() from Lab 4
                                               # to return the words along the best path
    backtrace_start_time = time.time()
    # (state_path, phones) = decoder.backtrace()
    (state_path, words) = decoder.backtrace()
    backtrace_end_time = time.time()
    backtrace_time = backtrace_end_time - backtrace_start_time 
    total_backtrace_time += backtrace_time 
    print(backtrace_time) 
    print(words)
    transcription = read_transcription(wav_file)
    print(transcription)
    error_counts = wer.compute_alignment_errors(transcription, words)
    word_count = len(transcription.split()) 
    total_errors += sum(error_counts)
    total_words += word_count 
        
    print(error_counts, word_count)     # you'll need to accumulate these to produce an overall Word Error Rate

total_wer = total_errors / total_words 

In [None]:
print('''sequence sil wfst, the transition probability from the final phone state of the word to the silence 
state: 0.9 * 0.5, to the final state of the word: 0.9 * 0.5. ''')
print('num states: {}, num arcs: {}'.format(*compute_states_arcs(sequence_sil_wfst)))
print('total backtrace time: {}'.format(total_backtrace_time))
print('total decode time: {}'.format(total_decode_time))
print('total forward computations: {}'.format(total_computations))
print('total word error rate: {}'.format(total_wer))

sequence sil wfst, the transition probability from the final phone state of the word to the silence 
state: 0.9 * 0.5, to the final state of the word: 0.9 * 0.5. 
num states: 166, num arcs: 340
total backtrace time: 0.1382451057434082
total decode time: 1413.3285076618195
total forward computations: 49310994
total word error rate: 0.7140509449465899 

In [None]:
f = sequence_sil_wfst 
f.set_input_symbols(state_sil_table) 
f.set_output_symbols(word_table) 

total_decode_time = 0 
total_backtrace_time = 0 
total_errors = 0 
total_words = 0 
total_computations = 0 


for wav_file in glob.glob('/group/teaching/asr/labs/recordings/*.wav'):    # replace path if using your own
                                                                           # audio files
    
    decoder = MyViterbiDecoder(f, wav_file) 
    # decoder = MyViterbiDecoder(fdet, wav_file)
    
    decode_start_time = time.time()
    computations = decoder.decode()
    decode_end_time = time.time()
    decode_time = decode_end_time - decode_start_time
    total_decode_time += decode_time 
    print(decode_time) 
    total_computations += computations
    print(computations) 
    # (state_path, words) = decoder.backtrace()  # you'll need to modify the backtrace() from Lab 4
                                               # to return the words along the best path
    backtrace_start_time = time.time()
    # (state_path, phones) = decoder.backtrace()
    (state_path, words) = decoder.backtrace()
    backtrace_end_time = time.time()
    backtrace_time = backtrace_end_time - backtrace_start_time 
    total_backtrace_time += backtrace_time 
    print(backtrace_time)
    print(words)
    transcription = read_transcription(wav_file)
    print(transcription)
    error_counts = wer.compute_alignment_errors(transcription, words)
    word_count = len(transcription.split()) 
    total_errors += sum(error_counts)
    total_words += word_count 
        
    print(error_counts, word_count)     # you'll need to accumulate these to produce an overall Word Error Rate

total_wer = total_errors / total_words 

In [None]:
print('''sequence sil wfst, the transition probability from the final phone state of the word to the silence 
state: 0.9 * 0.8, to the final word state: 0.9 * 0.2''')
print('num state: {}, num arcs: {}'.format(*compute_states_arcs(sequence_sil_wfst)))
print('total backtrace time: {}'.format(total_backtrace_time))
print('total decode time: {}'.format(total_decode_time))
print('total forward computations: {}'.format(total_computations))
print('total word error rate: {}'.format(total_wer))

sequence sil wfst, the transition probability from the final phone state of the word to the silence 
state: 0.9 * 0.8, to the final word state: 0.9 * 0.2
num state: 166, num arcs: 340
total backtrace time: 0.1392526626586914
total decode time: 1413.9680876731873
total forward computations: 49310994
total word error rate: 0.6972062448644207 
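The two sequence sil runs differ only in how the 0.9 exit probability of a word's final phone state is split between the silence state and the word-final state. Both runs report the same 49310994 forward computations, but their WERs differ (0.714 versus 0.697). The sketch below computes the corresponding arc weights, assuming they are stored as negative natural log probabilities. `exit_arc_weights` is a hypothetical helper.

```python
import math

EXIT_PROB = 0.9  # total probability of leaving a word's final phone state (from the labels above)


def exit_arc_weights(sil_share):
    """Return (-log P(to silence), -log P(to word-final state)).

    sil_share is the fraction of EXIT_PROB routed to the silence state.
    """
    p_sil = EXIT_PROB * sil_share
    p_word_end = EXIT_PROB * (1 - sil_share)
    return -math.log(p_sil), -math.log(p_word_end)


for share in (0.5, 0.8):
    to_sil, to_end = exit_arc_weights(share)
    print(f'silence share {share}: to silence {to_sil:.4f}, to word end {to_end:.4f}')
```

Routing 0.8 of the exit mass to silence makes the silence arc cheaper and the direct word-end arc more expensive than in the 0.5/0.5 split.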

In [None]:
import glob
import os
import wer
import observation_model
import openfst_python as fst
import time

# ... (add your code to create WFSTs and Viterbi Decoder)

def read_transcription(wav_file):
    """
    Get the transcription corresponding to wav_file.
    """
    
    transcription_file = os.path.splitext(wav_file)[0] + '.txt'
    
    with open(transcription_file, 'r') as f:
        transcription = f.readline().strip()
    
    return transcription

 
f = create_wfst()
f.set_input_symbols(state_table)
f.set_output_symbols(word_table)
fdet = fst.determinize(f)
fmin = fdet.minimize()

total_decode_time = 0 
total_backtrace_time = 0 
total_errors = 0 
total_words = 0 


for wav_file in glob.glob('/group/teaching/asr/labs/recordings/*.wav'):    # replace path if using your own
                                                                           # audio files
    decoder = MyViterbiDecoder(f, wav_file)
    # decoder = MyViterbiDecoder(fdet, wav_file)
    
    decode_start_time = time.time()
    decoder.decode()
    decode_end_time = time.time()
    decode_time = decode_end_time - decode_start_time
    total_decode_time += decode_time 
    print(decode_time)
    # (state_path, words) = decoder.backtrace()  # you'll need to modify the backtrace() from Lab 4
                                               # to return the words along the best path
    backtrace_start_time = time.time()
    (state_path, words) = decoder.backtrace()
    backtrace_end_time = time.time()
    backtrace_time = backtrace_end_time - backtrace_start_time 
    total_backtrace_time += backtrace_time 
    print(backtrace_time)  
    
    
    print(words)
    transcription = read_transcription(wav_file)
    # print(transcription)
    error_counts = wer.compute_alignment_errors(transcription, words)
    word_count = len(transcription.split()) 
    total_errors += sum(error_counts)
    total_words += word_count 
        
    print(error_counts, word_count)     # you'll need to accumulate these to produce an overall Word Error Rate

total_wer = total_errors / total_words 
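`wer.compute_alignment_errors` returns substitution, deletion and insertion counts; for example, it gives (1, 0, 1) for 'A B C' against 'A C C D'. The sketch below is a stand-in written for illustration, not the course module's code. It is a standard Levenshtein dynamic program that carries the three counts along the cheapest path. Ties are broken towards fewer substitutions; this is an assumption, and the course module may break ties differently. A different tie-break changes the split but not the total.

```python
def alignment_errors(reference, hypothesis):
    """Return (substitutions, deletions, insertions) for a minimum-edit alignment.

    Illustrative stand-in for wer.compute_alignment_errors: accepts two strings
    (split on whitespace) or two lists, and matching is case sensitive.
    """
    ref = reference.split() if isinstance(reference, str) else list(reference)
    hyp = hypothesis.split() if isinstance(hypothesis, str) else list(hypothesis)

    # best[i][j] = (total errors, subs, dels, ins) aligning ref[:i] with hyp[:j];
    # tuples compare lexicographically, so ties in the total prefer fewer subs.
    best = [[None] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        best[i][0] = (i, 0, i, 0)          # delete every reference word
    for j in range(len(hyp) + 1):
        best[0][j] = (j, 0, 0, j)          # insert every hypothesis word

    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            tot, sub, dele, ins = best[i - 1][j - 1]
            miss = 0 if ref[i - 1] == hyp[j - 1] else 1
            diagonal = (tot + miss, sub + miss, dele, ins)
            tot, sub, dele, ins = best[i - 1][j]
            deletion = (tot + 1, sub, dele + 1, ins)
            tot, sub, dele, ins = best[i][j - 1]
            insertion = (tot + 1, sub, dele, ins + 1)
            best[i][j] = min(diagonal, deletion, insertion)

    _, sub, dele, ins = best[len(ref)][len(hyp)]
    return sub, dele, ins
```

The notebook's overall WER divides total errors by total reference words across all files. This weights longer utterances more heavily than averaging per-file WERs would. Because insertions are counted, the result can exceed 1, as in the 1.109 printed for the beam width 0.1 run above.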

In [None]:
# draw f with Graphviz (switch to the commented fdet.draw line to draw the determinized WFST)
from subprocess import check_call
from IPython.display import Image
f.draw('tmp.dot', portrait=True)
# fdet.draw('tmp.dot', portrait=True)
check_call(['dot','-Tpng','-Gdpi=200','tmp.dot','-o','tmp.png'])
Image(filename='tmp.png')