# ASR Assignment 2022-23

This notebook has been provided as a template to get you started on the assignment.  Feel free to use it for your development, or do your development directly in Python.

You can find a full description of the assignment [here](http://www.inf.ed.ac.uk/teaching/courses/asr/2022-23/coursework.pdf).

You are provided with two Python modules `observation_model.py` and `wer.py`.  The first was described in [Lab 3](https://github.com/ZhaoZeyu1995/asr_labs/blob/master/asr_lab3_4.ipynb).  The second can be used to compute the number of substitution, deletion and insertion errors between ASR output and a reference text.

It can be used as follows:

```python
import wer

my_reference = 'A B C'
my_output = 'A C C D'

wer.compute_alignment_errors(my_reference, my_output)
```

This produces a tuple $(s,d,i)$ giving the counts of substitution, deletion and insertion errors respectively; in this example, (1, 0, 1). The reference is passed first and the ASR output second. The function accepts either two strings, as in the example above, or two lists. Matching is case-sensitive.
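For intuition about where these counts come from, here is a minimal sketch of a minimum-edit-distance alignment that returns the same kind of $(s,d,i)$ tuple. It is an illustration only, not the code inside `wer.py`: the function name `alignment_errors` is hypothetical, and when several alignments share the same minimum cost its tie-breaking may split the errors differently from `wer.py`. The word error rate is then $(s+d+i)/N$, where $N$ is the number of words in the reference.

```python
def alignment_errors(reference, hypothesis):
    """Count (substitutions, deletions, insertions) on one minimum-edit alignment.

    Accepts two strings (split on whitespace) or two lists of words.
    """
    ref = reference.split() if isinstance(reference, str) else list(reference)
    hyp = hypothesis.split() if isinstance(hypothesis, str) else list(hypothesis)
    R, H = len(ref), len(hyp)

    # cost[r][h] = minimum number of edits aligning ref[:r] with hyp[:h]
    cost = [[0] * (H + 1) for _ in range(R + 1)]
    for r in range(1, R + 1):
        cost[r][0] = r                      # delete every reference word
    for h in range(1, H + 1):
        cost[0][h] = h                      # insert every hypothesis word
    for r in range(1, R + 1):
        for h in range(1, H + 1):
            mismatch = int(ref[r - 1] != hyp[h - 1])
            cost[r][h] = min(cost[r - 1][h - 1] + mismatch,  # match / substitution
                             cost[r - 1][h] + 1,             # deletion
                             cost[r][h - 1] + 1)             # insertion

    # trace one optimal alignment back from the end, counting each error type
    s = d = i = 0
    r, h = R, H
    while r > 0 or h > 0:
        if r > 0 and h > 0 and cost[r][h] == cost[r - 1][h - 1] + int(ref[r - 1] != hyp[h - 1]):
            s += int(ref[r - 1] != hyp[h - 1])
            r, h = r - 1, h - 1
        elif r > 0 and cost[r][h] == cost[r - 1][h] + 1:
            d += 1
            r -= 1
        else:
            i += 1
            h -= 1
    return s, d, i


print(alignment_errors('A B C', 'A C C D'))   # (1, 0, 1)
```

On the example above this gives `(1, 0, 1)`: `B` is substituted by `C`, and `D` is inserted.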

## Template code

Assuming that you have already written a function to generate a WFST, `create_wfst()`, and a decoder class, `MyViterbiDecoder`, you can perform recognition on all the audio files as follows:


In [1]:
import observation_model
import math
import openfst_python as fst

from subprocess import check_call
from IPython.display import Image

import glob
import os
import wer


In [93]:
class MyViterbiDecoder:
    
    NLL_ZERO = 1e10  # define a constant representing -log(0).  This is really infinite, but approximate
                     # it here with a very large number
    
    def __init__(self, f, audio_file_name):
        """Set up the decoder class with an audio file and WFST f
        """
        self.om = observation_model.ObservationModel()
        self.f = f
        
        if audio_file_name:
            self.om.load_audio(audio_file_name)
        else:
            self.om.load_dummy_audio()
        
        self.initialise_decoding()

        
    def initialise_decoding(self):
        """set up the values for V_j(0) (as negative log-likelihoods)
        
        """
        
        self.V = []   # stores likelihood along best path reaching state j
        self.B = []   # stores identity of best previous state reaching state j
        self.W = []   # stores output labels sequence along arc reaching j - this removes need for 
                      # extra code to read the output sequence along the best path
        
        for t in range(self.om.observation_length()+1):
            self.V.append([self.NLL_ZERO]*self.f.num_states())
            self.B.append([-1]*self.f.num_states())
            self.W.append([[] for i in range(self.f.num_states())])  # [[]] * n would repeat one shared list, so build separate lists
        
        # The above code means that self.V[t][j] for t = 0, ... T gives the Viterbi cost
        # of state j, time t (in negative log-likelihood form)
        # Initialising the costs to NLL_ZERO effectively means zero probability    
        
        # give the WFST start state a probability of 1.0   (NLL = 0.0)
        self.V[0][self.f.start()] = 0.0
        
        # some WFSTs might have arcs with epsilon on the input (you might have already created 
        # examples of these in earlier labs) these correspond to non-emitting states, 
        # which means that we need to process them without stepping forward in time.  
        # Don't worry too much about this!  
        self.traverse_epsilon_arcs(0)        
        
    def traverse_epsilon_arcs(self, t):
        """Traverse arcs with <eps> on the input at time t
        
        These correspond to transitions that don't emit an observation
        
        We've implemented this function for you as it's slightly trickier than
        the normal case.  You might like to look at it to see what's going on, but
        don't worry if you can't fully follow it.
        
        """
        
        states_to_traverse = list(self.f.states()) # traverse all states
        while states_to_traverse:
            
            # Set i to the ID of the current state, the first 
            # item in the list (and remove it from the list)
            i = states_to_traverse.pop(0)   
        
            # don't bother traversing states which have zero probability
            if self.V[t][i] == self.NLL_ZERO:
                continue
        
            for arc in self.f.arcs(i):
                
                if arc.ilabel == 0:     # if <eps> transition
                  
                    j = arc.nextstate   # ID of next state  
                
                    if self.V[t][j] > self.V[t][i] + float(arc.weight):
                        
                        # this means we've found a lower-cost path to
                        # state j at time t.  We might need to add it
                        # back to the processing queue.
                        self.V[t][j] = self.V[t][i] + float(arc.weight)
                        
                        # save backtrace information.  In the case of an epsilon transition, 
                        # we save the identity of the best state at t-1.  This means we may not
                        # be able to fully recover the best path, but to do otherwise would
                        # require a more complicated way of storing backtrace information
                        self.B[t][j] = self.B[t][i] 
                        
                        # and save the output labels encountered - this is a list, because
                        # there could be multiple output labels (in the case of <eps> arcs)
                        if arc.olabel != 0:
                            self.W[t][j] = self.W[t][i] + [arc.olabel]
                        else:
                            self.W[t][j] = self.W[t][i]
                        
                        if j not in states_to_traverse:
                            states_to_traverse.append(j)

    
    def forward_step(self, t):
          
        for i in self.f.states():
            
            if not self.V[t-1][i] == self.NLL_ZERO:   # no point in propagating states with zero probability
                
                for arc in self.f.arcs(i):
                    
                    if arc.ilabel != 0: # <eps> transitions don't emit an observation
                        j = arc.nextstate
                        tp = float(arc.weight)  # transition prob
                        ep = -self.om.log_observation_probability(self.f.input_symbols().find(arc.ilabel), t)  # emission negative log prob
                        prob = tp + ep + self.V[t-1][i] # they're logs
                        if prob < self.V[t][j]:
                            self.V[t][j] = prob
                            self.B[t][j] = i
                            
                            # store the output labels encountered too
                            if arc.olabel !=0:
                                self.W[t][j] = [arc.olabel]
                            else:
                                self.W[t][j] = []
                            
    
    def finalise_decoding(self):
        """ this incorporates the probability of terminating at each state
        """
        
        for state in self.f.states():
            final_weight = float(self.f.final(state))
            if self.V[-1][state] != self.NLL_ZERO:
                if final_weight == math.inf:
                    self.V[-1][state] = self.NLL_ZERO  # effectively says that we can't end in this state
                else:
                    self.V[-1][state] += final_weight
                    
        # costs of the states reached by a path with non-zero probability
        finished = [x for x in self.V[-1] if x < self.NLL_ZERO]
        if not finished:  # if empty
            print("No path got to the end of the observations.")
        
        
    def decode(self):
        self.initialise_decoding()
        t = 1
        while t <= self.om.observation_length():
            self.forward_step(t)
            self.traverse_epsilon_arcs(t)
            t += 1
        self.finalise_decoding()
    
    def backtrace(self):
        
        best_final_state = self.V[-1].index(min(self.V[-1])) # argmin
        best_state_sequence = [best_final_state]
        best_out_sequence = []
        
        t = self.om.observation_length()   # ie T
        j = best_final_state
        
        while t >= 0:
            i = self.B[t][j]
            best_state_sequence.append(i)
            best_out_sequence = self.W[t][j] + best_out_sequence  # computer scientists might like to make this more efficient!

            # continue the backtrace at state i, time t-1
            j = i  
            t-=1
            
        best_state_sequence.reverse()
        
        # convert the best output sequence from FST integer labels into strings
        best_out_sequence = ' '.join([ self.f.output_symbols().find(label) for label in best_out_sequence])
        
        return (best_state_sequence, best_out_sequence)
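The recursion that `forward_step`, `finalise_decoding` and `backtrace` implement can be seen on a toy problem without OpenFST. This sketch assumes a plain list of `(source, label, weight, destination)` arcs with weights as negative log-probabilities, a table of per-frame emission costs standing in for `log_observation_probability`, and no `<eps>` arcs; the function `viterbi` and all numbers are hypothetical. Each step computes $V_j(t) = \min_i \big[V_i(t-1) + w_{i \to j} - \log p(x_t \mid \text{label}_{i \to j})\big]$ and records the best $i$ for the backtrace.

```python
NLL_ZERO = 1e10   # stands in for -log(0), as in MyViterbiDecoder


def viterbi(num_states, arcs, start, finals, emission_nll):
    """Viterbi decoding over a toy WFST in negative log-likelihood form.

    arcs: list of (src, label, weight, dst) tuples, all emitting.
    finals: dict mapping final states to their final weight.
    emission_nll: emission_nll[t][label] = -log p(x_t | label) for t = 1..T
                  (index 0 is unused, so that V[0] holds the start costs).
    Returns (best total cost, best state sequence for t = 0..T).
    """
    T = len(emission_nll) - 1
    V = [[NLL_ZERO] * num_states for _ in range(T + 1)]
    B = [[-1] * num_states for _ in range(T + 1)]
    V[0][start] = 0.0

    for t in range(1, T + 1):
        for src, label, weight, dst in arcs:
            if V[t - 1][src] >= NLL_ZERO:        # skip zero-probability states
                continue
            cost = V[t - 1][src] + weight + emission_nll[t][label]
            if cost < V[t][dst]:                 # keep the cheaper path into dst
                V[t][dst] = cost
                B[t][dst] = src

    # add final weights, then pick the cheapest final state
    best = min(finals, key=lambda s: V[T][s] + finals[s])
    path = [best]
    for t in range(T, 0, -1):                    # follow backpointers down to t = 0
        path.append(B[t][path[-1]])
    path.reverse()
    return V[T][best] + finals[best], path


# two-state left-to-right model: arcs leaving state 0 emit 'a', state 1's self-loop emits 'b'
arcs = [(0, 'a', 0.5, 0), (0, 'a', 1.0, 1), (1, 'b', 0.5, 1)]
emissions = [None,
             {'a': 0.1, 'b': 2.0},
             {'a': 1.0, 'b': 0.2},
             {'a': 2.0, 'b': 0.1}]
cost, path = viterbi(2, arcs, start=0, finals={1: 0.0}, emission_nll=emissions)
print(round(cost, 2), path)   # 2.4 [0, 1, 1, 1]
```

Unlike `MyViterbiDecoder.backtrace`, this sketch stops at $t = 0$ rather than also appending the `-1` backpointer stored for the start state.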
    


In [122]:
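def show_wfst(f):
    """Render WFST f with Graphviz and return the image so the notebook displays it."""
    f.draw('tmp.dot', portrait=True)
    check_call(['dot', '-Tpng', '-Gdpi=500', 'tmp.dot', '-o', 'tmp.png'])
    return Image(filename='tmp.png')  # the Image must be returned to appear as cell output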

In [217]:
def parse_lexicon(lex_file):
    """
    Parse the lexicon file and return it in dictionary form.
    
    Args:
        lex_file (str): filename of lexicon file with structure '<word> <phone1> <phone2>...'
                        eg. peppers p eh p er z

    Returns:
        lex (dict): dictionary mapping words to list of phones
    """
    
    lex = {}  # create a dictionary for the lexicon entries (this could be a problem with larger lexica)
    with open(lex_file, 'r') as f:
        for line in f:
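            line = line.split()  # split at each space
            if line[0] in lex:
                # a repeated word is an alternative pronunciation, stored under '<word>_'
                # (only one alternative per word is kept; a third entry would overwrite it)
                lex[line[0] + "_"] = line[1:]
            else:
                lex[line[0]] = line[1:]  # first field the word, the rest is the phones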
    return lex

def generate_symbol_tables(lexicon, n=3):
    '''
    Return word, phone and state symbol tables based on the supplied lexicon
        
    Args:
        lexicon (dict): lexicon to use, created from the parse_lexicon() function
        n (int): number of states for each phone HMM
        
    Returns:
        word_table (fst.SymbolTable): table of words
        phone_table (fst.SymbolTable): table of phones
        state_table (fst.SymbolTable): table of HMM phone-state IDs
    '''
    
    state_table = fst.SymbolTable()
    phone_table = fst.SymbolTable()
    word_table = fst.SymbolTable()
    
    # add empty <eps> symbol to all tables
    state_table.add_symbol('<eps>')
    phone_table.add_symbol('<eps>')
    word_table.add_symbol('<eps>')
    
    for word, phones  in lexicon.items():
        
        word_table.add_symbol(word)
        
        for p in phones: # for each phone
            
            phone_table.add_symbol(p)
            for i in range(1,n+1): # for each state 1 to n
                state_table.add_symbol('{}_{}'.format(p, i))
            
    return word_table, phone_table, state_table


# call these two functions
lex = parse_lexicon('lexicon.txt')
word_table, phone_table, state_table = generate_symbol_tables(lex)

def generate_phone_wfst(f, start_state, phone, n):
    """
    Generate a WFST representing an n-state left-to-right phone HMM
    
    Args:
        f (fst.Fst()): an FST object, assumed to exist already
        start_state (int): the index of the first state, assumed to exist already
        phone (str): the phone label 
        n (int): number of states for each phone HMM
        
    Returns:
        the final state of the FST
    """
    
    current_state = start_state
    sl_weight = fst.Weight('log', 10)  # weight for self-loop
    next_weight = fst.Weight('log', 15) # weight to next state
    
    for i in range(1, n+1):
        
        in_label = state_table.find('{}_{}'.format(phone, i))
        
        # self-loop back to current state
        f.add_arc(current_state, fst.Arc(in_label, 0, sl_weight, current_state))
        
        # transition to next state
        # Phone arcs output <eps> (label 0): the word label is emitted on a separate
        # <eps>-input arc added after the last phone of each word
        next_state = f.add_state()
        f.add_arc(current_state, fst.Arc(in_label, 0, next_weight, next_state))
       
        current_state = next_state
        
    return current_state

def generate_word_wfst(word):
    """ Generate a WFST for any word in the lexicon, composed of 3-state phone WFSTs.
        This will currently output word labels.  
        Exercise: could you modify this function and the one above to output a single phone label instead?
    
    Args:
        word (str): the word to generate
        
    Returns:
        the constructed WFST
    
    """
    f = fst.Fst('log')
    
    # create the start state
    start_state = f.add_state()
    f.set_start(start_state)
    certain_weight =  fst.Weight('log', -math.log(1))
    
    current_state = start_state
    
    # iterate over all the phones in the word
    for (i,phone) in enumerate(lex[word]):   # will raise an exception if word is not in the lexicon
        
        current_state = generate_phone_wfst(f, current_state, phone, 3)
    
        # after the last phone, emit the word label on an <eps>-input arc to a new state
        if i == len(lex[word]) - 1:
            next_state = f.add_state()
            f.add_arc(current_state, fst.Arc(0, word_table.find(word), certain_weight, next_state))
            current_state = next_state
            
        # note: new current_state is now set to the final state of the previous phone WFST
        
    f.set_final(current_state)
    
    return f

def generate_word_sequence_recognition_wfst(n, probs):
    """ Generate a WFST to recognise any sequence of words in the lexicon
    
    Args:
        n (int): states per phone HMM
        probs (dict): word probabilities used to weight the arc from the start state into each word

    Returns:
        the constructed WFST
    
    """
    
    f = fst.Fst('log')

    #even_weight = fst.Weight('log', -math.log(1/len(lex)))
    
    #reduced_weight = fst.Weight('log', -math.log(1/(5*len(lex))))
    next_weight = fst.Weight('log', -math.log(0.1))
    certain_weight =  fst.Weight('log', -math.log(1))
    
    # create a single start state
    start_state = f.add_state()
    #f.add_arc(start_state, fst.Arc(0, 0, fst.Weight('log', -math.log(0.3)), start_state))
    
    f.set_start(start_state)
    
    for word, phones in lex.items():
        current_state = f.add_state()

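        # look up the base word, so an alternative pronunciation ('<word>_') shares its probability
        f.add_arc(start_state, fst.Arc(0, 0, fst.Weight('log', -math.log(probs[word.replace("_", "")])), current_state))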
        
        for (i, phone) in enumerate(phones): 
            current_state = generate_phone_wfst(f, current_state, phone, n)
        # note: new current_state is now set to the final state of the previous phone WFST
            if i == len(lex[word]) - 1:

                next_state = f.add_state()
                f.add_arc(current_state, fst.Arc(0, word_table.find(word.replace("_", "")), certain_weight, next_state))
                current_state= next_state
                
        f.set_final(current_state)
        f.add_arc(current_state, fst.Arc(0, 0, next_weight, start_state))
        
    return f



In [218]:
show_wfst(f)

In [219]:
def create_wfst(n, state_table, phone_table, word_probabilities):
    # word probabilities: a dictionary, to adjust weights. 
    f = generate_word_sequence_recognition_wfst(n, word_probabilities)
    f.set_input_symbols(state_table)
    f.set_output_symbols(word_table)
    return f

In [220]:
even_dict = {}
for word in lex:
    even_dict[word] = 0.1   # the same probability for every lexicon entry

In [221]:
def read_transcription(wav_file):
    """
    Get the transcription corresponding to wav_file.
    """
    

    transcription_file = os.path.splitext(wav_file)[0] + '.txt'
    
    with open(transcription_file, 'r') as f:
        transcription = f.readline().strip()
    
    return transcription


c = 0

word_table, phone_table, state_table = generate_symbol_tables(lex)
f = create_wfst(3, state_table, phone_table, even_dict)
errors_sum = 0
for wav_file in glob.glob('/group/teaching/asr/labs/recordings/*.wav'):  # replace path if using your own audio files
    if c < 10:  # decode only the first 10 recordings
        c+=1
        decoder = MyViterbiDecoder(f, wav_file)
    
        
        decoder.decode()
        %time (state_path, words) = decoder.backtrace()  # backtrace() returns the words along the best path
        transcription = read_transcription(wav_file)
        print (words)
        print(transcription)
    
        error_counts = wer.compute_alignment_errors(transcription, words)
        word_count = len(transcription.split())
    
        print(error_counts, word_count)     # you'll need to accumulate these
        errors_sum += sum(error_counts)
print(errors_sum)

CPU times: user 402 µs, sys: 6 µs, total: 408 µs
Wall time: 413 µs
the pickled of of peter the
a pickled piper of peter
(2, 0, 1) 5
CPU times: user 264 µs, sys: 0 ns, total: 264 µs
Wall time: 273 µs
the where's peter the
where's peter
(0, 0, 2) 2
CPU times: user 440 µs, sys: 0 ns, total: 440 µs
Wall time: 448 µs
the peter peck peck the
peter picked a peck
(1, 1, 2) 4
CPU times: user 294 µs, sys: 5 µs, total: 299 µs
Wall time: 302 µs
the where's the peppers the
where's the peppers
(0, 0, 2) 3
CPU times: user 316 µs, sys: 5 µs, total: 321 µs
Wall time: 324 µs
the piper pickled peppers the
the piper pickled peppers
(0, 0, 1) 4
CPU times: user 0 ns, sys: 958 µs, total: 958 µs
Wall time: 965 µs
the a of of pickled peck of pickled a where's the
peter piper picked a peck of pickled peppers
(5, 0, 3) 8
CPU times: user 352 µs, sys: 5 µs, total: 357 µs
Wall time: 361 µs
the peck pickled picked the
a peck of pickled peppers peter piper picked
(1, 4, 1) 8
CPU times: user 395 µs, sys: 6 µs, total: 

In [132]:
f = create_wfst(3, state_table, phone_table, even_dict)
WER_estimate(f, 15)

(11.41127450980392, 15)
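`WER_estimate` is not defined in this section. As a hedged sketch of the accumulation the decoding loops above need, the hypothetical helper `corpus_wer` pools the `(s, d, i)` counts and reference word counts over all utterances before dividing, rather than averaging per-utterance error rates:

```python
def corpus_wer(results):
    """Pool (s, d, i) error counts and reference lengths over utterances.

    results: iterable of ((s, d, i), n_reference_words) pairs.
    Returns the corpus-level word error rate as a fraction.
    """
    total_errors = 0
    total_words = 0
    for (s, d, i), n_words in results:
        total_errors += s + d + i
        total_words += n_words
    return total_errors / total_words


# (error_counts, word_count) pairs printed for the first seven utterances of the even-probability run
results = [((2, 0, 1), 5), ((0, 0, 2), 2), ((1, 1, 2), 4), ((0, 0, 2), 3),
           ((0, 0, 1), 4), ((5, 0, 3), 8), ((1, 4, 1), 8)]
print(round(corpus_wer(results), 3))   # 26 errors / 34 words -> 0.765
```

Pooling weights each utterance by its length, so a short utterance with many insertions does not dominate the estimate.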

Number of states in the WFST, for assessing memory usage:

In [None]:
len(list(f.states()))

In [164]:
a = 1  # the shared start state
for word, phones in lex.items():
    a += 3*len(phones) + 2  # entry state + 3 states per phone + word-end state
a

139
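The state count can be tied to the decoder's memory use. A rough sketch, assuming the layout built by `generate_word_sequence_recognition_wfst` (one start state and, per lexicon entry, an entry state, `n` states per phone and a word-end state): since `initialise_decoding` allocates `V`, `B` and `W` with $(T+1)\times|S|$ entries each, trellis memory grows linearly in both the number of frames $T$ and the number of states $|S|$. The helper names, the toy lexicon and the frame count are hypothetical, and the result counts table entries, not bytes (each `W` entry is itself a list).

```python
def expected_num_states(lexicon, n=3):
    """States in the word-loop WFST: 1 start state plus, per entry,
    an entry state, n states per phone and a word-end state."""
    return 1 + sum(n * len(phones) + 2 for phones in lexicon.values())


def trellis_entries(num_states, num_frames):
    """Total entries across the V, B and W tables, each (T+1) x |S|."""
    return 3 * (num_frames + 1) * num_states


toy_lex = {'a': ['ax'], 'of': ['ah', 'v']}   # hypothetical two-word lexicon
print(expected_num_states(toy_lex))           # 1 + (3+2) + (6+2) = 14
print(trellis_entries(139, 100))              # 3 * 101 * 139 = 42117
```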

## Task 1

WFST with unigram word probabilities, estimated from word counts in the transcriptions, instead of equal probabilities for all words.

In [222]:
c = {}
for word in lex.keys():
    c[word] = 0
c["SUM"] = 0
for wav_file in glob.glob('/group/teaching/asr/labs/recordings/*.wav'):
    transcript = read_transcription(wav_file).split()   # whole-word tokens, so e.g. 'a' cannot match inside another word
    for word in lex.keys():
        count = transcript.count(word)
        c[word] += count
        c["SUM"] += count
        
        
unigram_probs = {}
for w, count in c.items():
    if "_" not in w:
        unigram_probs[w] = count /c["SUM"]
        
unigram_probs

{'a': 0.05834018077239113,
 'of': 0.10230073952341824,
 'peck': 0.11133935907970419,
 'peppers': 0.13475760065735415,
 'peter': 0.12900575184880855,
 'picked': 0.11298274445357437,
 'pickled': 0.11750205423171733,
 'piper': 0.11503697617091208,
 'the': 0.06409202958093672,
 "where's": 0.05464256368118324,
 'SUM': 1.0}
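A compact alternative to the counting cell above, sketched under the assumption that transcriptions are whitespace-separated words: `collections.Counter` tallies whole-word tokens, and each probability $P(w)$ becomes an arc weight $-\log P(w)$ in the log semiring, as used in `generate_word_sequence_recognition_wfst`. The function names and example transcriptions are hypothetical, and no `'SUM'` entry is mixed into the result.

```python
import math
from collections import Counter


def unigram_probs_from_transcripts(transcripts):
    """Maximum-likelihood unigram probabilities from whitespace-tokenised text."""
    counts = Counter(word for line in transcripts for word in line.split())
    total = sum(counts.values())
    return {word: count / total for word, count in counts.items()}


def arc_weights(probs):
    """Convert probabilities to negative log-probabilities (log-semiring weights)."""
    return {word: -math.log(p) for word, p in probs.items() if p > 0}


probs = unigram_probs_from_transcripts(['peter piper', 'peter'])
print(probs)                # peter: 2/3, piper: 1/3
print(arc_weights(probs))   # peter: -log(2/3), piper: -log(1/3)
```

Words that never occur in the transcriptions get no entry here, so any lexicon word missing from the counts would need a smoothed or floor probability before being used as an arc weight.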

In [223]:
def generate_word_sequence_recognition_wfst_unigram(n, probs):
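    """ Generate a WFST to recognise any sequence of words in the lexicon, weighted by unigram probabilities
    
    Args:
        n (int): states per phone HMM
        probs (dict): word probabilities used to weight the arc from the start state into each word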

    Returns:
        the constructed WFST
    
    """
    
    f = fst.Fst('log')

    #even_weight = fst.Weight('log', -math.log(1/len(lex)))
    
    #reduced_weight = fst.Weight('log', -math.log(1/(5*len(lex))))
    next_weight = fst.Weight('log', -math.log(0.05))
    certain_weight =  fst.Weight('log', -math.log(1))
    
    # create a single start state
    start_state = f.add_state()
    #f.add_arc(start_state, fst.Arc(0, 0, fst.Weight('log', -math.log(0.3)), start_state))
    
    f.set_start(start_state)
    
    for word, phones in lex.items():
        current_state = f.add_state()

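        # look up the base word, so an alternative pronunciation ('<word>_') shares its probability
        f.add_arc(start_state, fst.Arc(0, 0, fst.Weight('log', -math.log(probs[word.replace("_", "")])), current_state))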
        
        for (i, phone) in enumerate(phones): 
            current_state = generate_phone_wfst(f, current_state, phone, n)
        # note: new current_state is now set to the final state of the previous phone WFST
            if i == len(lex[word]) - 1:

                next_state = f.add_state()
                f.add_arc(current_state, fst.Arc(0, word_table.find(word.replace("_", "")), certain_weight, next_state))
                current_state= next_state
                
        f.set_final(current_state)
        f.add_arc(current_state, fst.Arc(0, 0, next_weight, start_state))
        
    return f


def create_wfst_unigram(n, state_table, phone_table, word_probabilities):
    # word probabilities: a dictionary, to adjust weights. 
    f = generate_word_sequence_recognition_wfst_unigram(n, word_probabilities)
    f.set_input_symbols(state_table)
    f.set_output_symbols(word_table)
    return f

In [224]:
c = 0
word_table, phone_table, state_table = generate_symbol_tables(lex)
f = create_wfst_unigram(3, state_table, phone_table, unigram_probs)
errors_sum = 0
for wav_file in glob.glob('/group/teaching/asr/labs/recordings/*.wav'):  # replace path if using your own audio files
    if c < 10:  # decode only the first 10 recordings
        c+=1
        decoder = MyViterbiDecoder(f, wav_file)
    
        
        decoder.decode()
        %time (state_path, words) = decoder.backtrace()  # backtrace() returns the words along the best path
        transcription = read_transcription(wav_file)
        print (words)
        print(transcription)
    
        error_counts = wer.compute_alignment_errors(transcription, words)
        word_count = len(transcription.split())
    
        print(error_counts, word_count)     # you'll need to accumulate these
        errors_sum += sum(error_counts)
print(errors_sum)

CPU times: user 0 ns, sys: 1.19 ms, total: 1.19 ms
Wall time: 1.2 ms
the pickled of of peter the
a pickled piper of peter
(2, 0, 1) 5
CPU times: user 242 µs, sys: 4 µs, total: 246 µs
Wall time: 248 µs
the where's peter the
where's peter
(0, 0, 2) 2
CPU times: user 326 µs, sys: 5 µs, total: 331 µs
Wall time: 337 µs
the peter peck peck the
peter picked a peck
(1, 1, 2) 4
CPU times: user 297 µs, sys: 5 µs, total: 302 µs
Wall time: 305 µs
the where's the peppers the
where's the peppers
(0, 0, 2) 3
CPU times: user 408 µs, sys: 6 µs, total: 414 µs
Wall time: 421 µs
the piper pickled peppers the
the piper pickled peppers
(0, 0, 1) 4
CPU times: user 1.84 ms, sys: 0 ns, total: 1.84 ms
Wall time: 1.85 ms
the a of of pickled peck of pickled a where's the
peter piper picked a peck of pickled peppers
(5, 0, 3) 8
CPU times: user 364 µs, sys: 5 µs, total: 369 µs
Wall time: 374 µs
the peck pickled picked the
a peck of pickled peppers peter piper picked
(1, 4, 1) 8
CPU times: user 692 µs, sys: 0 ns, to

Adding silence states at the start and between words.
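To make the silence topology built by `generate_sil_wfst` easier to inspect, here is a sketch that lists its emitting arcs as plain `(source, label, destination)` tuples, with hypothetical state names. It follows the arcs added in that function: `sil_1` enters the model, silence states 2 to 4 each have a self-loop, each is entered from the previous silence state or directly from the first one, and `sil_5` leaves any of them for the final silence state (whose `<eps>` arc back to the start state is omitted here). Each arc is listed once.

```python
def silence_arcs():
    """Emitting arcs of the silence model as (source, label, destination) tuples."""
    arcs = [('start', 'sil_1', 'sil_first')]
    previous = 'sil_first'
    for i in range(2, 5):
        state = f'sil_{i}_state'
        arcs.append((state, f'sil_{i}', state))            # self-loop
        arcs.append((previous, f'sil_{i}', state))         # step from the previous silence state
        if previous != 'sil_first':
            arcs.append(('sil_first', f'sil_{i}', state))  # skip directly from the first state
        arcs.append((state, 'sil_5', 'sil_last'))          # leave the silence model
        previous = state
    return arcs


for arc in silence_arcs():
    print(arc)
```

The `previous != 'sil_first'` check avoids listing the arc into state 2 twice, since for $i = 2$ the previous state and the first state coincide.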

In [236]:
state_table.add_symbol("sil_1")
state_table.add_symbol("sil_2")
state_table.add_symbol("sil_3")
state_table.add_symbol("sil_4")
state_table.add_symbol("sil_5")

56

In [239]:
def generate_sil_wfst(n, word_probabilities):
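    """ Generate a WFST to recognise any sequence of words in the lexicon, with an optional
        silence model entered from the start state, so silence can occur before, between and after words
    
    Args:
        n (int): states per phone HMM
        word_probabilities (dict): word probabilities used to weight the arc from the start state into each word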

    Returns:
        the constructed WFST
    
    """
    
    f = fst.Fst('log')

    #even_weight = fst.Weight('log', -math.log(1/len(lex)))
    
    #reduced_weight = fst.Weight('log', -math.log(1/(5*len(lex))))
    next_weight = fst.Weight('log', -math.log(0.1))
    certain_weight =  fst.Weight('log', -math.log(1))
    
    # create a single start state
    start_state = f.add_state()
    f.set_start(start_state)
    
    first_silent = f.add_state()
    last_silent = f.add_state()
    f.set_final(last_silent)
    
    
    f.add_arc(start_state, fst.Arc(state_table.find("sil_1"), 0, fst.Weight('log', -math.log(1)),first_silent))
    f.add_arc(last_silent, fst.Arc(0, 0, fst.Weight('log', -math.log(1)),start_state))
    curr_state = first_silent
    for i in range(2, 5):
        nxt_state = f.add_state()
        f.add_arc(nxt_state, fst.Arc(state_table.find(("sil_" + str(i))), 0, 
                                      fst.Weight('log', -math.log(1)),nxt_state))
        f.add_arc(curr_state, fst.Arc(state_table.find(("sil_" + str(i))), 0, 
                                      fst.Weight('log', -math.log(1)),nxt_state))
        f.add_arc(first_silent, fst.Arc(state_table.find(("sil_" + str(i))), 0, 
                                        fst.Weight('log', -math.log(1)),nxt_state))
        f.add_arc(nxt_state, fst.Arc(state_table.find(("sil_5" )), 0, 
                                      fst.Weight('log', -math.log(1)),last_silent))
        curr_state = nxt_state
    
    
    for word, phones in lex.items():
        current_state = f.add_state()

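        # look up the base word, so an alternative pronunciation ('<word>_') shares its probability
        f.add_arc(start_state, fst.Arc(0, 0, fst.Weight('log', -math.log(word_probabilities[word.replace("_", "")])), current_state))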
        
        for (i, phone) in enumerate(phones): 
            current_state = generate_phone_wfst(f, current_state, phone, n)
        # note: new current_state is now set to the final state of the previous phone WFST
            if i == len(lex[word]) - 1:

                next_state = f.add_state()
                f.add_arc(current_state, fst.Arc(0, word_table.find(word.replace("_", "")), certain_weight, next_state))
                current_state= next_state
                
        f.set_final(current_state)
        f.add_arc(current_state, fst.Arc(0, 0, next_weight, start_state))
        
    return f


In [240]:
s = generate_sil_wfst(3, even_dict)
s.set_input_symbols(state_table)
s.set_output_symbols(word_table)

errors_sum = 0
utterances = 0
words_no = 0
for wav_file in glob.glob('/group/teaching/asr/labs/recordings/*.wav'):  # replace path if using your own audio files
    if utterances < 10:  # decode only the first 10 recordings
        utterances += 1
        decoder = MyViterbiDecoder(s, wav_file)
    
        
        decoder.decode()
        (state_path, words) = decoder.backtrace()  # backtrace() returns the words along the best path
        transcription = read_transcription(wav_file)
        print (words)
        print(transcription)
    
        error_counts = wer.compute_alignment_errors(transcription, words)
        word_count = len(transcription.split())
    
        if utterances < 10:
            print(error_counts, word_count)     # you'll need to accumulate these
        errors_sum += sum(error_counts)
        words_no += word_count
print(errors_sum)
print(words_no)

CPU times: user 2.74 s, sys: 0 ns, total: 2.74 s
Wall time: 2.74 s
CPU times: user 285 µs, sys: 0 ns, total: 285 µs
Wall time: 288 µs

a pickled piper of peter
(0, 5, 0) 5
CPU times: user 1.79 s, sys: 0 ns, total: 1.79 s
Wall time: 1.79 s
CPU times: user 211 µs, sys: 0 ns, total: 211 µs
Wall time: 214 µs

where's peter
(0, 2, 0) 2
CPU times: user 2.2 s, sys: 0 ns, total: 2.2 s
Wall time: 2.2 s
CPU times: user 246 µs, sys: 0 ns, total: 246 µs
Wall time: 249 µs

peter picked a peck
(0, 4, 0) 4
CPU times: user 2.26 s, sys: 0 ns, total: 2.26 s
Wall time: 2.26 s
CPU times: user 250 µs, sys: 0 ns, total: 250 µs
Wall time: 252 µs

where's the peppers
(0, 3, 0) 3
CPU times: user 2.52 s, sys: 0 ns, total: 2.52 s
Wall time: 2.53 s
CPU times: user 847 µs, sys: 0 ns, total: 847 µs
Wall time: 850 µs

the piper pickled peppers
(0, 4, 0) 4
CPU times: user 4.7 s, sys: 0 ns, total: 4.7 s
Wall time: 4.7 s
CPU times: user 490 µs, sys: 0 ns, total: 490 µs
Wall time: 493 µs

peter piper picked a peck of pi
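`PruningViterbiDecoder` below keeps a state at time $t-1$ only if its cost is below `best * threshold`, a multiplicative threshold on negative log-likelihoods. A commonly used alternative is an additive beam: keep state $i$ if $V_i(t-1) \le \min_k V_k(t-1) + \text{beam}$, which in probability terms keeps states whose likelihood is at least $e^{-\text{beam}}$ times the best. The sketch below shows that alternative technique, not the one in the class; the function name `beam_prune` and the beam value are hypothetical.

```python
NLL_ZERO = 1e10   # same "zero probability" constant as the decoders


def beam_prune(costs, beam):
    """Indices of states whose NLL cost is within `beam` of the best (lowest) cost."""
    best = min(costs)
    if best >= NLL_ZERO:          # nothing reachable at this time step
        return []
    return [i for i, c in enumerate(costs) if c < NLL_ZERO and c <= best + beam]


print(beam_prune([5.0, 7.0, 20.0, NLL_ZERO], beam=3.0))   # [0, 1]
```

Inside `forward_step`, the outer loop could then iterate over `beam_prune(self.V[t-1], beam)` instead of over all states; an additive beam has the same meaning at every time step, whereas a multiplicative threshold loosens as the accumulated cost grows.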

In [230]:
class PruningViterbiDecoder:
    
    NLL_ZERO = 1e10  # define a constant representing -log(0).  This is really infinite, but approximate
                     # it here with a very large number
    
    def __init__(self, f, audio_file_name, pruning_threshold = 500):
        """Set up the decoder class with an audio file and WFST f
        """
        self.om = observation_model.ObservationModel()
        self.f = f
        
        if audio_file_name:
            self.om.load_audio(audio_file_name)
        else:
            self.om.load_dummy_audio()
        
        self.initialise_decoding()
        self.threshold = pruning_threshold

        
    def initialise_decoding(self):
        """set up the values for V_j(0) (as negative log-likelihoods)
        
        """
        
        self.V = []   # stores likelihood along best path reaching state j
        self.B = []   # stores identity of best previous state reaching state j
        self.W = []   # stores output labels sequence along arc reaching j - this removes need for 
                      # extra code to read the output sequence along the best path
        
        for t in range(self.om.observation_length()+1):
            self.V.append([self.NLL_ZERO]*self.f.num_states())
            self.B.append([-1]*self.f.num_states())
            self.W.append([[] for i in range(self.f.num_states())])  # [[]] * n would repeat one shared list, so build separate lists
        
        # The above code means that self.V[t][j] for t = 0, ... T gives the Viterbi cost
        # of state j, time t (in negative log-likelihood form)
        # Initialising the costs to NLL_ZERO effectively means zero probability    
        
        # give the WFST start state a probability of 1.0   (NLL = 0.0)
        self.V[0][self.f.start()] = 0.0
        
        # some WFSTs might have arcs with epsilon on the input (you might have already created 
        # examples of these in earlier labs) these correspond to non-emitting states, 
        # which means that we need to process them without stepping forward in time.  
        # Don't worry too much about this!  
        self.traverse_epsilon_arcs(0)        
        
    def traverse_epsilon_arcs(self, t):
        """Traverse arcs with <eps> on the input at time t
        
        These correspond to transitions that don't emit an observation
        
        We've implemented this function for you as it's slightly trickier than
        the normal case.  You might like to look at it to see what's going on, but
        don't worry if you can't fully follow it.
        
        """
        
        states_to_traverse = list(self.f.states()) # traverse all states
        while states_to_traverse:
            
            # Set i to the ID of the current state, the first 
            # item in the list (and remove it from the list)
            i = states_to_traverse.pop(0)   
        
            # don't bother traversing states which have zero probability
            if self.V[t][i] == self.NLL_ZERO:
                continue
        
            for arc in self.f.arcs(i):
                
                if arc.ilabel == 0:     # if <eps> transition
                  
                    j = arc.nextstate   # ID of next state  
                
                    if self.V[t][j] > self.V[t][i] + float(arc.weight):
                        
                        # this means we've found a lower-cost path to
                        # state j at time t.  We might need to add it
                        # back to the processing queue.
                        self.V[t][j] = self.V[t][i] + float(arc.weight)
                        
                        # save backtrace information.  In the case of an epsilon transition, 
                        # we save the identity of the best state at t-1.  This means we may not
                        # be able to fully recover the best path, but to do otherwise would
                        # require a more complicated way of storing backtrace information
                        self.B[t][j] = self.B[t][i] 
                        
                        # and save the output labels encountered - this is a list, because
                        # there could be multiple output labels (in the case of <eps> arcs)
                        if arc.olabel != 0:
                            self.W[t][j] = self.W[t][i] + [arc.olabel]
                        else:
                            self.W[t][j] = self.W[t][i]
                        
                        if j not in states_to_traverse:
                            states_to_traverse.append(j)

    
    def forward_step(self, t):
        """Propagate path costs from time t-1 to time t along emitting arcs,
        with beam pruning
        """
        # best (lowest) cost at t-1, floored at 0.1 so that the beam below
        # does not shrink to zero when the best cost is 0.0 (e.g. at t = 1)
        best = max(0.1, min(self.V[t-1]))
        for i in self.f.states():
            
            # costs are negative log probabilities, so a bigger value means a lower
            # probability: only expand states whose cost lies within best * threshold
            if self.V[t-1][i] < best * self.threshold:
                for arc in self.f.arcs(i):
                    
                    if arc.ilabel != 0:  # <eps> transitions don't emit an observation
                        j = arc.nextstate
                        tp = float(arc.weight)  # transition cost (negative log probability)
                        # emission cost: negative log probability of observation t given the arc's input label
                        ep = -self.om.log_observation_probability(self.f.input_symbols().find(arc.ilabel), t)
                        cost = tp + ep + self.V[t-1][i]  # negative log probabilities add along a path
                        if cost < self.V[t][j]:
                            self.V[t][j] = cost
                            self.B[t][j] = i
                            
                            # store the output labels encountered too
                            if arc.olabel != 0:
                                self.W[t][j] = [arc.olabel]
                            else:
                                self.W[t][j] = []
                            
    
    def finalise_decoding(self):
        """ this incorporates the probability of terminating at each state
        """
        
        for state in self.f.states():
            final_weight = float(self.f.final(state))
            if self.V[-1][state] != self.NLL_ZERO:
                if final_weight == math.inf:
                    self.V[-1][state] = self.NLL_ZERO  # effectively says that we can't end in this state
                else:
                    self.V[-1][state] += final_weight
                    
        # get a list of all states where there was a path ending with non-zero probability
        finished = [x for x in self.V[-1] if x < self.NLL_ZERO]
        if not finished:  # if empty
            print("No path got to the end of the observations.")
        
        
    def decode(self):
        self.initialise_decoding()
        t = 1
        while t <= self.om.observation_length():
            self.forward_step(t)
            self.traverse_epsilon_arcs(t)
            t += 1
        self.finalise_decoding()
    
    def backtrace(self):
        
        best_final_state = self.V[-1].index(min(self.V[-1])) # argmin
        best_state_sequence = [best_final_state]
        best_out_sequence = []
        
        t = self.om.observation_length()   # ie T
        j = best_final_state
        
        while t >= 0:
            i = self.B[t][j]
            best_state_sequence.append(i)
            # prepend this frame's output labels (computer scientists might like
            # to make this more efficient!)
            best_out_sequence = self.W[t][j] + best_out_sequence

            # continue the backtrace at state i, time t-1
            j = i  
            t-=1
            
        best_state_sequence.reverse()
        
        # convert the best output sequence from FST integer labels into strings
        best_out_sequence = ' '.join([ self.f.output_symbols().find(label) for label in best_out_sequence])
        
        return (best_state_sequence, best_out_sequence)
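The pruning test in `forward_step` can be tried in isolation. The sketch below is a standalone re-implementation on a plain list of costs, not the decoder itself; `states_in_beam` and `states_in_additive_beam` are hypothetical helper names, and the additive variant is a common alternative included only for comparison, not something the decoder above uses. The 0.1 floor matters at t = 1: the start state has cost 0.0, and without the floor the condition `cost < 0` would discard every state.

```python
NLL_ZERO = 1e10  # stand-in for -log(0), as in the decoder class


def states_in_beam(costs, threshold):
    """Indices of states that pass the multiplicative beam used in forward_step.

    costs are negative log probabilities (lower is better). A state survives if
    its cost is below best * threshold, where best is the lowest cost floored at 0.1.
    """
    best = max(0.1, min(costs))
    return [i for i, cost in enumerate(costs) if cost < best * threshold]


def states_in_additive_beam(costs, beam):
    """Hypothetical alternative for comparison: keep states within a fixed margin of the best."""
    best = min(costs)
    return [i for i, cost in enumerate(costs) if cost < best + beam]


costs = [120.0, 150.0, 2000.0, 4000.0, NLL_ZERO]
print(states_in_beam(costs, 25))              # edge at 120 * 25 = 3000 -> [0, 1, 2]
print(states_in_additive_beam(costs, 100.0))  # edge at 120 + 100 = 220 -> [0, 1]
print(states_in_beam([0.0, NLL_ZERO], 25))    # floor gives an edge of 0.1 * 25 = 2.5 -> [0]
```

Because accumulated path costs grow with t, the absolute width of the multiplicative beam, best * (threshold - 1), grows too, so pruning becomes looser in absolute terms later in an utterance; an additive beam keeps a constant margin throughout.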
    


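The evaluation loop in the next cell sums the (s, d, i) counts into `errors_sum` but does not accumulate the reference word counts, which a corpus-level WER also needs: total errors divided by total reference words. A minimal sketch, assuming per-file results are collected as `((s, d, i), word_count)` pairs; `corpus_wer` and `mean_file_wer` are hypothetical helpers, and the example values are the first five per-file results printed in the output below.

```python
def corpus_wer(results):
    """Corpus-level WER from per-file results.

    results: list of ((s, d, i), n_ref_words) pairs, one per audio file.
    Errors and reference words are summed over all files before dividing,
    so longer files carry proportionally more weight.
    """
    total_errors = sum(sum(counts) for counts, _ in results)
    total_words = sum(n_words for _, n_words in results)
    return total_errors / total_words


def mean_file_wer(results):
    """Unweighted mean of per-file WERs, shown only for comparison."""
    return sum(sum(counts) / n_words for counts, n_words in results) / len(results)


# (s, d, i) counts and reference lengths for the first five files in the output below
results = [((2, 0, 1), 5), ((0, 0, 2), 2), ((1, 1, 2), 4), ((0, 0, 2), 3), ((0, 0, 1), 4)]
print(f"corpus WER: {corpus_wer(results):.3f}")  # 12 errors / 18 reference words
print(f"mean per-file WER: {mean_file_wer(results):.3f}")
```

The two numbers differ because the unweighted mean gives a 2-word file the same influence as a 5-word one; the corpus-level figure is the one usually reported.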
In [244]:
f = generate_word_sequence_recognition_wfst(3, even_dict)
f.set_input_symbols(state_table)
f.set_output_symbols(word_table)

errors_sum = 0
# replace the path if using your own audio files; only the first 10 files are decoded
for wav_file in glob.glob('/group/teaching/asr/labs/recordings/*.wav')[:10]:
    decoder = PruningViterbiDecoder(f, wav_file, pruning_threshold=25)

    %time decoder.decode()
    # backtrace() is modified from Lab 4 to return the words along the best path
    %time (state_path, words) = decoder.backtrace()
    transcription = read_transcription(wav_file)
    print(words)
    print(transcription)

    error_counts = wer.compute_alignment_errors(transcription, words)  # (substitutions, deletions, insertions)
    word_count = len(transcription.split())  # number of words in the reference

    print(error_counts, word_count)
    errors_sum += sum(error_counts)
print(errors_sum)

CPU times: user 2.62 s, sys: 15 µs, total: 2.62 s
Wall time: 2.62 s
CPU times: user 335 µs, sys: 0 ns, total: 335 µs
Wall time: 338 µs
the pickled of of peter the
a pickled piper of peter
(2, 0, 1) 5
CPU times: user 1.71 s, sys: 4 µs, total: 1.71 s
Wall time: 1.71 s
CPU times: user 242 µs, sys: 4 µs, total: 246 µs
Wall time: 248 µs
the where's peter the
where's peter
(0, 0, 2) 2
CPU times: user 2.1 s, sys: 0 ns, total: 2.1 s
Wall time: 2.1 s
CPU times: user 303 µs, sys: 0 ns, total: 303 µs
Wall time: 307 µs
the peter peck peck the
peter picked a peck
(1, 1, 2) 4
CPU times: user 2.15 s, sys: 0 ns, total: 2.15 s
Wall time: 2.15 s
CPU times: user 291 µs, sys: 0 ns, total: 291 µs
Wall time: 296 µs
the where's the peppers the
where's the peppers
(0, 0, 2) 3
CPU times: user 2.42 s, sys: 0 ns, total: 2.42 s
Wall time: 2.42 s
CPU times: user 316 µs, sys: 4 µs, total: 320 µs
Wall time: 323 µs
the piper pickled peppers the
the piper pickled peppers
(0, 0, 1) 4
CPU times: user 4.49 s, sys: 0 ns, 