# ASR Assignment 2022-23

This notebook has been provided as a template to get you started on the assignment.  Feel free to use it for your development, or do your development directly in Python.

You can find a full description of the assignment [here](http://www.inf.ed.ac.uk/teaching/courses/asr/2022-23/coursework.pdf).

You are provided with two Python modules `observation_model.py` and `wer.py`.  The first was described in [Lab 3](https://github.com/ZhaoZeyu1995/asr_labs/blob/master/asr_lab3_4.ipynb).  The second can be used to compute the number of substitution, deletion and insertion errors between ASR output and a reference text.

It can be used as follows:

```python
import wer

my_refence = 'A B C'
my_output = 'A C C D'

wer.compute_alignment_errors(my_reference, my_output)
```

This produces a tuple $(s,d,i)$ giving counts of substitution,
deletion and insertion errors respectively - in this example (1, 0, 1).  The function accepts either two strings, as in the example above, or two lists.  Matching is case sensitive.

## Template code

Assuming that you have already made a function to generate an WFST, `create_wfst()` and a decoder class, `MyViterbiDecoder`, you can perform recognition on all the audio files as follows:


In [1]:
import glob
import os
import wer
import observation_model
import openfst_python as fst
import math

In [2]:
def parse_lexicon(lex_file):
    """
    Parse the lexicon file and return it in dictionary form.
    
    Args:
        lex_file (str): filename of lexicon file with structure '<word> <phone1> <phone2>...'
                        eg. peppers p eh p er z

    Returns:
        lex (dict): dictionary mapping words to list of phones
    """
    
    lex = {}  # create a dictionary for the lexicon entries
    with open(lex_file, 'r') as f:
        for line in f:
            line = line.split()  # split at each space
            lex[line[0]] = line[1:]  # first field the word, the rest is the phones
    return lex

lex = parse_lexicon('lexicon.txt')

In [3]:
def generate_symbol_tables(lexicon, n=3):
    '''
    Return word, phone and state symbol tables based on the supplied lexicon
    
    Args:
        lexicon (dict): lexicon to use, created from the parse_lexicon() function
        n (int): number of states for each phone HMM
        
    Returns:
        word_table (fst.SymbolTable): table of words
        phone_table (fst.SymbolTable): table of phones
        state_table (fst.SymbolTable): table of HMM phone-state IDs
    '''
    state_table = fst.SymbolTable()
    phone_table = fst.SymbolTable()
    word_table = fst.SymbolTable()
    
    # your code here
    # .add_symbol('<eps>')
    word_table.add_symbol('<eps>')
    for word in lexicon:
        word_table.add_symbol(word)
        
    phone_table.add_symbol('<eps>')
    for word in list(lexicon):
        for phone in lexicon[word]:
            phone_table.add_symbol(phone)
    
    state_table.add_symbol('<eps>')
    for word in list(lexicon):
        for phone in lexicon[word]:
            for i in range(n):
                state_table.add_symbol(f"{phone}_{i+1}")
    

    return word_table, phone_table, state_table

word_table, phone_table, state_table = generate_symbol_tables(lex)

In [4]:
def generate_phone_wfst(f, start_state, phone, n):
    """
    Generate a WFST representing an n-state left-to-right phone HMM.
    
    Args:
        f (fst.Fst()): an FST object, assumed to exist already
        start_state (int): the index of the first state, assumed to exist already
        phone (str): the phone label 
        n (int): number of emitting states of the HMM
        
    Returns:
        the final state of the FST
    """
    
    current_state = start_state
    eps = phone_table.find('<eps>')
    out = phone_table.find(phone)
    
    for i in range(1, n+1):
    
        in_label = state_table.find('{}_{}'.format(phone, i))
        
        next_state = f.add_state()
        f.add_arc(current_state, fst.Arc(in_label, eps, None, current_state))
        if (i < n):
            f.add_arc(current_state, fst.Arc(in_label, eps, None, next_state))
        else:
            f.add_arc(current_state, fst.Arc(in_label, out, None, next_state))
        
        current_state = next_state
    
    return current_state

In [5]:
def generate_word_wfst(f, start_state, word, n):
    """ Generate a WFST for any word in the lexicon, composed of n-state phone WFSTs.
        This will currently output phone labels.  
    
    Args:
        f (fst.Fst()): an FST object, assumed to exist already
        start_state (int): the index of the first state, assumed to exist already
        word (str): the word to generate
        n (int): states per phone HMM
        
    Returns:
        the constructed WFST
    
    """

    current_state = start_state
    phone_list = lex[word]
    for phone in phone_list:
        current_state = generate_phone_wfst(f, current_state, phone, n)
    f.set_final(current_state)
    
    return current_state

In [6]:
def generate_word_sequence_recognition_wfst(n = 3):
    """ generate a HMM to recognise any sequence of words in the lexicon
    
    Args:
        n (int): states per phone HMM

    Returns:
        the constructed WFST
    
    """
    
    f = fst.Fst()
    
    # create a single start state
    start_state = f.add_state()
    f.set_start(start_state)
    
    for word in lex.keys():
        new_start_state = f.add_state()
        f.add_arc(start_state, fst.Arc(state_table.find("<eps>"), phone_table.find("<eps>"), None, new_start_state))
        last_state = generate_word_wfst(f, new_start_state, word, n)
        f.set_final(last_state)
        f.add_arc(last_state, fst.Arc(state_table.find("<eps>"), phone_table.find("<eps>"), None, start_state))
        
    return f


In [7]:
def create_wfst(n = 3):
    f = generate_word_sequence_recognition_wfst(n)
    f.set_input_symbols(state_table)
    f.set_output_symbols(phone_table)
    
    return f

In [8]:
class MyViterbiDecoder:
    
    NLL_ZERO = 1e10  # define a constant representing -log(0).  This is not really infinite, but approximate
                     # it here with a very large number
    
    def __init__(self, f, audio_file_name):
        """Set up the decoder class with an audio file and WFST f
        """
        self.om = observation_model.ObservationModel()
        self.f = f
        
        if audio_file_name:
            self.om.load_audio(audio_file_name)
        else:
            self.om.load_dummy_audio()
        
        self.initialise_decoding()
    
    def initialise_decoding(self):
        """set up the values for V_j(0) (as negative log-likelihoods)
        
        """
        
        self.V = []
        self.B = [] # B[t][i]
        self.W = []
        for t in range(self.om.observation_length()+1):
            self.V.append([self.NLL_ZERO]*self.f.num_states())
            self.B.append([-1]*self.f.num_states())
            self.W.append(['']*self.f.num_states())
        
        # The above code means that self.V[t][j] for t = 0, ... T gives the Viterbi cost
        # of state j, time t (in negative log-likelihood form)
        # Initialising the costs to NLL_ZERO effectively means zero probability    
        
        # give the WFST start state a probability of 1.0   (NLL = 0.0)
        self.V[0][f.start()] = 0.0
        
        # some WFSTs might have arcs with epsilon on the input (you might have already created 
        # examples of these in earlier labs) these correspond to non-emitting states, 
        # which means that we need to process them without stepping forward in time.  
        # Don't worry too much about this!  
        self.traverse_epsilon_arcs(0)
        
    def traverse_epsilon_arcs(self, t):
        """Traverse arcs with <eps> on the input at time t
        
        These correspond to transitions that don't emit an observation
        
        We've implemented this function for you as it's slightly trickier than
        the normal case.  You might like to look at it to see what's going on, but
        don't worry if you can't fully follow it.
        
        """
        
        states_to_traverse = list(range(self.f.num_states())) # traverse all states
        while states_to_traverse:
            
            # Set i to the ID of the current state, the first 
            # item in the list (and remove it from the list)
            i = states_to_traverse.pop(0)   
        
            # don't bother traversing states which have zero probability
            if self.V[t][i] == self.NLL_ZERO:
                    continue
        
            for arc in self.f.arcs(i):
                
                if arc.ilabel == 0:     # if <eps> transition
                  
                    j = arc.nextstate   # ID of next state  
                
                    if self.V[t][j] > self.V[t][i] + float(arc.weight):
                        
                        # this means we've found a lower-cost path to
                        # state j at time t.  We might need to add it
                        # back to the processing queue.
                        self.V[t][j] = self.V[t][i] + float(arc.weight)
                  
                        if j not in states_to_traverse:
                            states_to_traverse.append(j)

    
    def forward_step(self, t):
        for i in self.f.states():
            
            if not self.V[t-1][i] == self.NLL_ZERO:   # no point in propagating states with zero probability
                
                for arc in self.f.arcs(i):
                    
                    if arc.ilabel != 0: # <eps> transitions don't emit an observation
                        j = arc.nextstate
                        tp = float(arc.weight)  # transition prob
                        ep = -self.om.log_observation_probability(self.f.input_symbols().find(arc.ilabel), t)  # emission negative log prob
                        prob = tp + ep + self.V[t-1][i] # they're logs
                        if prob < self.V[t][j]:
                            self.V[t][j] = prob
                            self.B[t][j] = i
                            
                            # store the output labels encountered too
                            olbl = arc.olabel
                            olbl = self.f.output_symbols().find(olbl)
                            self.W[t][j] = olbl
    
    def finalise_decoding(self):
        
        # TODO - exercise
        states = list(range(self.f.num_states()))
        for i in states:
            prob_final = float(self.f.final(i)) # probablity of being the end state (0 for non finals, and upwards of 1 if one final state, or split between all final states)
            if (self.V[-1][i] < self.NLL_ZERO):
                if (prob_final == math.inf): # not a final state
                    self.V[-1][i] = self.NLL_ZERO
                else: # is a final state
                    self.V[-1][i] += prob_final # includes the weighting of ending at each of the final states in the path towards them in the last step
        
    def decode(self):
        
        self.initialise_decoding()
        t = 1
        while t <= self.om.observation_length():
            self.forward_step(t)
            self.traverse_epsilon_arcs(t)
            t += 1
        
        self.finalise_decoding()
    
    def backtrace(self):
        
        # TODO - exercise 
        
        # complete code to trace back through the
        # best state sequence
        
        # You'll need to create a structure B_j(t) to store the 
        # back-pointers (see lectures), and amend the functions above to fill it.
        
#         for w in self.W:
#             print(dict(enumerate(w)))

        T = self.om.observation_length()
        current = -1
        for i in range(self.f.num_states()):
            if (float(self.f.final(i)) != math.inf):
                if (self.B[T][i] != -1):
                    current = i
                    break
        if (current == -1):
            raise Exception('No valid path')
        seq = [current]
        currentStr = ''
        strSeq = ''
        
        for t in range(T,-1, -1):
            tmpStr = self.W[t][current]
            if not(tmpStr in ['', '<eps>',  currentStr]):
                currentStr = tmpStr
                strSeq = f'{currentStr} {strSeq}'
        
       
            current = self.B[t][current]
            seq.insert(0, current)
        
        
        strSeq = strSeq.strip()
        best_state_sequence = (seq, strSeq)
        
        return best_state_sequence

In [17]:
def read_transcription(wav_file):
    """
    Get the transcription corresponding to wav_file.
    """
    
    transcription_file = os.path.splitext(wav_file)[0] + '.txt'
    
    with open(transcription_file, 'r') as f:
        transcription = f.readline().strip()
    
    return transcription

def run():
    f = create_wfst()
    
    files = list(glob.glob('/group/teaching/asr/labs/recordings/*.wav'))
    for wav_file in files[:5]:    # replace path if using your own
                                                                               # audio files

        decoder = MyViterbiDecoder(f, wav_file)

        decoder.decode()
        (state_path, words) = decoder.backtrace()  # you'll need to modify the backtrace() from Lab 4
                                                   # to return the words along the best path


        transcription = read_transcription(wav_file)
        print(f"\n\n[{words}]\n[{transcription}]")
        error_counts = wer.compute_alignment_errors(transcription, words)
        word_count = len(transcription.split())

        print(error_counts, word_count)     # you'll need to accumulate these to produce an overall Word Error Rate

In [18]:
run()

[ey]
[peter piper picked a peck of pickled peppers where's the peck of pickled peppers peter piper picked]


(1, 16, 0) 17
[ey]
[peter picked a peck of pickled peppers]


(1, 6, 0) 7
[ey]
[a peck of pickled peppers picked]


(1, 5, 0) 6
[ey]
[peter piper peter piper]


(1, 3, 0) 4
[ey]
[where's the peck of pickled peppers picked]


(1, 6, 0) 7


In [15]:
wer_counts = """(1, 16, 0) 17
(1, 6, 0) 7
(1, 5, 0) 6
(1, 3, 0) 4
(1, 6, 0) 7
(1, 3, 0) 4
(1, 1, 0) 2
(1, 4, 0) 5
(1, 5, 0) 6
(1, 7, 0) 8
(1, 0, 0) 1
(1, 0, 0) 1
(1, 0, 0) 1
(1, 0, 0) 1
(1, 0, 0) 1
(1, 5, 0) 6
(1, 2, 0) 3
(1, 7, 0) 8
(1, 5, 0) 6
(1, 2, 0) 3
(1, 6, 0) 7
(1, 9, 0) 10
(1, 7, 0) 8
(1, 3, 0) 4
(1, 5, 0) 6
(1, 6, 0) 7
(1, 7, 0) 8
(1, 7, 0) 8
(1, 4, 0) 5
(1, 6, 0) 7
(1, 8, 0) 9
(1, 4, 0) 5
(1, 4, 0) 5
(1, 4, 0) 5
(1, 4, 0) 5
(1, 3, 0) 4
(1, 2, 0) 3
(1, 7, 0) 8
(1, 7, 0) 8
(1, 6, 0) 7
(1, 7, 0) 8
(1, 6, 0) 7
(1, 10, 0) 11
(1, 6, 0) 7
(1, 7, 0) 8
(1, 5, 0) 6
(1, 10, 0) 11
(1, 16, 0) 17
(1, 7, 0) 8
(1, 8, 0) 9
(1, 4, 0) 5
(1, 10, 0) 11
(1, 7, 0) 8
(1, 5, 0) 6
(1, 3, 0) 4
(1, 8, 0) 9
(1, 6, 0) 7
(1, 7, 0) 8
(1, 7, 0) 8
(1, 5, 0) 6
(1, 4, 0) 5
(1, 4, 0) 5
(1, 16, 0) 17
(1, 6, 0) 7
(1, 11, 0) 12
(1, 7, 0) 8
(1, 7, 0) 8
(1, 7, 0) 8
(1, 4, 0) 5
(1, 8, 0) 9
(1, 3, 0) 4
(1, 4, 0) 5
(1, 8, 0) 9
(1, 7, 0) 8
(1, 7, 0) 8
(1, 8, 0) 9
(1, 8, 0) 9
(1, 8, 0) 9
(1, 8, 0) 9
(1, 5, 0) 6
(1, 7, 0) 8
(1, 4, 0) 5
(1, 2, 0) 3
(1, 2, 0) 3
(1, 8, 0) 9
(1, 10, 0) 11
(1, 7, 0) 8
(1, 8, 0) 9
(1, 16, 0) 17
(1, 16, 0) 17
(1, 16, 0) 17
(1, 3, 0) 4
(1, 4, 0) 5
(1, 4, 0) 5
(1, 2, 0) 3
(1, 4, 0) 5
(1, 5, 0) 6
(1, 7, 0) 8
(1, 6, 0) 7
(1, 10, 0) 11
(1, 9, 0) 10
(1, 13, 0) 14
(1, 3, 0) 4
(1, 6, 0) 7
(1, 5, 0) 6
(1, 8, 0) 9
(1, 7, 0) 8
(1, 7, 0) 8
(1, 6, 0) 7
(1, 5, 0) 6
(1, 7, 0) 8
(1, 5, 0) 6
(1, 4, 0) 5
(1, 6, 0) 7
(1, 7, 0) 8
(1, 2, 0) 3
(1, 8, 0) 9
(1, 3, 0) 4
(1, 7, 0) 8
(1, 3, 0) 4
(1, 2, 0) 3
(1, 3, 0) 4
(1, 3, 0) 4
(1, 5, 0) 6
(1, 4, 0) 5
(1, 2, 0) 3
(1, 3, 0) 4
(1, 4, 0) 5
(1, 2, 0) 3
(1, 7, 0) 8
(1, 8, 0) 9
(1, 4, 0) 5
(1, 1, 0) 2
(1, 2, 0) 3
(1, 4, 0) 5
(1, 6, 0) 7
(1, 5, 0) 6
(1, 4, 0) 5
(1, 5, 0) 6
(1, 7, 0) 8
(1, 7, 0) 8
(1, 7, 0) 8
(1, 8, 0) 9
(1, 7, 0) 8
(1, 7, 0) 8
(1, 8, 0) 9
(1, 2, 0) 3
(1, 2, 0) 3
(1, 3, 0) 4
(1, 5, 0) 6
(1, 5, 0) 6
(1, 8, 0) 9
(1, 4, 0) 5
(1, 7, 0) 8
(1, 7, 0) 8
(1, 4, 0) 5
(1, 3, 0) 4
(1, 8, 0) 9
(1, 7, 0) 8
(1, 8, 0) 9
(1, 7, 0) 8
(1, 8, 0) 9
(1, 6, 0) 7
(1, 7, 0) 8
(1, 4, 0) 5
(1, 2, 0) 3
(1, 7, 0) 8
(1, 5, 0) 6
(1, 5, 0) 6
(1, 9, 0) 10
(1, 4, 0) 5
(1, 7, 0) 8
(1, 8, 0) 9
(1, 7, 0) 8
(1, 16, 0) 17
(1, 1, 0) 2
(1, 1, 0) 2
(1, 8, 0) 9
(1, 1, 0) 2
(1, 5, 0) 6
(1, 7, 0) 8
(1, 7, 0) 8
(1, 3, 0) 4
(1, 5, 0) 6
(1, 4, 0) 5
(1, 5, 0) 6
(1, 5, 0) 6
(1, 2, 0) 3
(1, 4, 0) 5
(1, 5, 0) 6
(1, 16, 0) 17
(1, 16, 0) 17
(1, 16, 0) 17
(1, 16, 0) 17
(1, 16, 0) 17
(1, 7, 0) 8
(1, 8, 0) 9
(1, 16, 0) 17
(1, 16, 0) 17
(1, 16, 0) 17
(1, 16, 0) 17
(1, 16, 0) 17
(1, 9, 0) 10
(1, 16, 0) 17
(1, 6, 0) 7
(1, 6, 0) 7
(1, 7, 0) 8
(1, 7, 0) 8
(1, 8, 0) 9
(1, 5, 0) 6
(1, 7, 0) 8
(1, 8, 0) 9
(1, 16, 0) 17
(1, 15, 0) 16
(1, 13, 0) 14
(1, 15, 0) 16
(1, 15, 0) 16
(1, 10, 0) 11
(1, 4, 0) 5
(1, 11, 0) 12
(1, 8, 0) 9
(1, 16, 0) 17
(1, 5, 0) 6
(1, 6, 0) 7
(1, 5, 0) 6
(1, 6, 0) 7
(1, 5, 0) 6
(1, 6, 0) 7
(1, 6, 0) 7
(1, 4, 0) 5
(1, 5, 0) 6
(1, 6, 0) 7
(1, 16, 0) 17
(1, 16, 0) 17
(1, 5, 0) 6
(1, 3, 0) 4
(1, 7, 0) 8
(1, 2, 0) 3
(1, 4, 0) 5
(1, 7, 0) 8
(1, 8, 0) 9
(1, 5, 0) 6
(1, 6, 0) 7
(1, 6, 0) 7
(1, 6, 0) 7
(1, 4, 0) 5
(1, 2, 0) 3
(1, 2, 0) 3
(1, 3, 0) 4
(1, 5, 0) 6
(1, 17, 0) 18
(1, 10, 0) 11
(1, 2, 0) 3
(1, 9, 0) 10
(1, 16, 0) 17
(1, 16, 0) 17
(1, 16, 0) 17
(1, 16, 0) 17
(1, 16, 0) 17
(1, 16, 0) 17
(1, 5, 0) 6
(1, 4, 0) 5
(1, 4, 0) 5
(1, 5, 0) 6
(1, 3, 0) 4
(1, 2, 0) 3
(1, 5, 0) 6
(1, 5, 0) 6
(1, 5, 0) 6
(1, 4, 0) 5
(1, 2, 0) 3
(1, 3, 0) 4
(1, 5, 0) 6
(1, 6, 0) 7
(1, 3, 0) 4
(1, 2, 0) 3
(1, 6, 0) 7
(1, 3, 0) 4
(1, 3, 0) 4
(1, 3, 0) 4
(1, 4, 0) 5
(1, 6, 0) 7
(1, 4, 0) 5
(1, 16, 0) 17
(1, 1, 0) 2
(1, 3, 0) 4
(1, 2, 0) 3
(1, 7, 0) 8
(1, 5, 0) 6
(1, 2, 0) 3
(1, 3, 0) 4
(1, 2, 0) 3
(1, 3, 0) 4
(1, 4, 0) 5
(1, 4, 0) 5
(1, 5, 0) 6
(1, 3, 0) 4
(1, 8, 0) 9
(1, 4, 0) 5
(1, 6, 0) 7
(1, 2, 0) 3
(1, 3, 0) 4
(1, 7, 0) 8
(1, 4, 0) 5
(1, 4, 0) 5
(1, 4, 0) 5
(1, 6, 0) 7
(1, 5, 0) 6
(1, 5, 0) 6
(1, 15, 0) 16
(1, 16, 0) 17
(1, 4, 0) 5
(1, 2, 0) 3
(1, 3, 0) 4
(1, 5, 0) 6
(1, 2, 0) 3
(1, 4, 0) 5
(1, 2, 0) 3
(1, 2, 0) 3
(1, 3, 0) 4
(1, 2, 0) 3
(1, 3, 0) 4
(1, 2, 0) 3
(1, 2, 0) 3
(1, 16, 0) 17
(1, 7, 0) 8
(1, 11, 0) 12
(1, 7, 0) 8
(1, 8, 0) 9"""

S, D, I, N = 0, 0, 0, 0
wer_counts = wer_counts.split("\n")
IndexCount = len(wer_counts)
for wer_elem in wer_counts:
    wer_elem = wer_elem.replace("(", "").replace(")", "").replace(",", "").split(" ")
    S += int(wer_elem[0])
    D += int(wer_elem[1])
    I += int(wer_elem[2])
    N += int(wer_elem[3])

WER = (S + D + I) / N
print(IndexCount, WER, S, D, I, N)

329 1.0 329 2093 0 2422
