# ASR Assignment 2020-21

This notebook has been provided as a template to get you started on the assignment.  Feel free to use it for your development, or do your development directly in Python.

You can find a full description of the assignment [here](http://www.inf.ed.ac.uk/teaching/courses/asr/2020-21/coursework.pdf).

You are provided with two Python modules `observation_model.py` and `wer.py`.  The first was described in [Lab 3](https://github.com/Ore-an/asr_labs/blob/master/asr_lab3_4.ipynb).  The second can be used to compute the number of substitution, deletion and insertion errors between ASR output and a reference text.

It can be used as follows:

```python
import wer

my_reference = 'A B C'
my_output = 'A C C D'

wer.compute_alignment_errors(my_reference, my_output)
```

This produces a tuple $(s,d,i)$ giving counts of substitution, deletion and insertion errors respectively; in this example, (1, 0, 1).  The function accepts either two strings, as in the example above, or two lists.  Matching is case sensitive.
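
The per-utterance counts can be accumulated into an overall word error rate, WER = (S + D + I) / N, where N is the total number of reference words. The following is a minimal sketch that assumes the tuples have already been obtained from `wer.compute_alignment_errors`; the helper name `overall_wer` and the second utterance's counts are made up for illustration.

```python
def overall_wer(error_counts, word_counts):
    """Corpus-level WER from per-utterance (s, d, i) tuples and reference lengths."""
    total_errors = sum(s + d + i for s, d, i in error_counts)
    total_words = sum(word_counts)
    return total_errors / total_words

# The first tuple is the example above ('A B C' has 3 reference words);
# the second is a hypothetical error-free 5-word utterance.
counts = [(1, 0, 1), (0, 0, 0)]
lengths = [3, 5]
print(overall_wer(counts, lengths))  # 2 errors over 8 reference words
```

Summing errors and words before dividing weights each utterance by its length, which is usually what is meant by an overall WER; averaging per-utterance rates would give a different number.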

## Template code

Assuming that you have already made a function to generate a WFST, `create_wfst()`, and a decoder class, `MyViterbiDecoder`, you can perform recognition on all the audio files as follows:


In [20]:
import openfst_python as fst
from subprocess import check_call
from IPython.display import Image
from helper_functions import *
import time
import pandas as pd
import numpy as np

In [3]:
# f, word_table = generate_word_sequence_recognition_wfst(3, 'lexicon.txt',original=False)
# print (f'word table :{list(word_table)}')
# f.draw('tmp.dot', portrait=True)
# check_call(['dot','-Tpng','-Gdpi=300','tmp.dot','-o','tmp.png'])
# Image(filename='tmp.png')

In [4]:
import logging
log = logging.getLogger('root')
FORMAT = "%(message)s"
logging.basicConfig(format=FORMAT)
log.setLevel(logging.INFO)
log.warning('logger is used')

logger is used


In [5]:
import observation_model
import math

class MyViterbiDecoder:
    
    NLL_ZERO = 1e10  # define a constant representing -log(0).  This is really infinite, but approximate
                     # it here with a very large number
    
    def __init__(self, f, audio_file_name, word_table):
        """Set up the decoder class with an audio file and WFST f
        """
        self.om = observation_model.ObservationModel()
        self.f = f
        self.word_table = word_table
        self.forward_computations = 0
        
        if audio_file_name:
            self.om.load_audio(audio_file_name)
        else:
            self.om.load_dummy_audio()
        
        self.initialise_decoding()

        
    def initialise_decoding(self):
        """set up the values for V_j(0) (as negative log-likelihoods)
        
        """
        
        self.V = []   # stores likelihood along best path reaching state j
        self.B = []   # stores identity of best previous state reaching state j
        self.W = []   # stores output labels sequence along arc reaching j - this removes need for 
                      # extra code to read the output sequence along the best path
        
        for t in range(self.om.observation_length()+1):
            self.V.append([self.NLL_ZERO]*self.f.num_states())
            self.B.append([-1]*self.f.num_states())
            self.W.append([[] for i in range(self.f.num_states())])  # a comprehension is needed: [[]] * n would repeat one shared list
        
        # The above code means that self.V[t][j] for t = 0, ... T gives the Viterbi cost
        # of state j, time t (in negative log-likelihood form)
        # Initialising the costs to NLL_ZERO effectively means zero probability    
        
        # give the WFST start state a probability of 1.0   (NLL = 0.0)
        self.V[0][self.f.start()] = 0.0
        
        # some WFSTs might have arcs with epsilon on the input (you might have already created 
        # examples of these in earlier labs) these correspond to non-emitting states, 
        # which means that we need to process them without stepping forward in time.  
        # Don't worry too much about this!  
        self.traverse_epsilon_arcs(0)        
        
    def traverse_epsilon_arcs(self, t):
        """Traverse arcs with <eps> on the input at time t
        
        These correspond to transitions that don't emit an observation
        
        We've implemented this function for you as it's slightly trickier than
        the normal case.  You might like to look at it to see what's going on, but
        don't worry if you can't fully follow it.
        
        """
        
        states_to_traverse = list(self.f.states()) # traverse all states
        while states_to_traverse:
            
            # Set i to the ID of the current state, the first 
            # item in the list (and remove it from the list)
            i = states_to_traverse.pop(0)   
        
            # don't bother traversing states which have zero probability
            if self.V[t][i] == self.NLL_ZERO:
                continue
        
            for arc in self.f.arcs(i):
                
                if arc.ilabel == 0:     # if <eps> transition
                  
                    j = arc.nextstate   # ID of next state  
                
                    if self.V[t][j] > self.V[t][i] + float(arc.weight):
                        
                        # this means we've found a lower-cost path to
                        # state j at time t.  We might need to add it
                        # back to the processing queue.
                        self.V[t][j] = self.V[t][i] + float(arc.weight)
                        
                        # save backtrace information.  In the case of an epsilon transition, 
                        # we save the identity of the best state at t-1.  This means we may not
                        # be able to fully recover the best path, but to do otherwise would
                        # require a more complicated way of storing backtrace information
                        self.B[t][j] = self.B[t][i] 
                        
                        # and save the output labels encountered - this is a list, because
                        # there could be multiple output labels (in the case of <eps> arcs)
                        if arc.olabel != 0:
                            self.W[t][j] = self.W[t][i] + [arc.olabel]
                        else:
                            self.W[t][j] = self.W[t][i]
                        
                        if j not in states_to_traverse:
                            states_to_traverse.append(j)

    
    def forward_step(self, t):
          
        for i in self.f.states():
            
            if not self.V[t-1][i] == self.NLL_ZERO:   # no point in propagating states with zero probability
                
                for arc in self.f.arcs(i):
                    
                    if arc.ilabel != 0: # <eps> transitions don't emit an observation
                        j = arc.nextstate
                        tp = float(arc.weight)  # transition cost (negative log probability)
                        ep = -self.om.log_observation_probability(self.f.input_symbols().find(arc.ilabel), t)  # emission negative log prob
                        prob = tp + ep + self.V[t-1][i] # a path cost: negative logs add where probabilities would multiply
                        if prob < self.V[t][j]:
                            self.V[t][j] = prob
                            self.B[t][j] = i
                            
                            # store the output labels encountered too
                            if arc.olabel !=0:
                                self.W[t][j] = [arc.olabel]
                            else:
                                self.W[t][j] = []
                                
                        # update number of forward computations
                        self.forward_computations += 1
                            
    
    def finalise_decoding(self):
        """ this incorporates the probability of terminating at each state
        """
        
        for state in self.f.states():
            final_weight = float(self.f.final(state))
            if self.V[-1][state] != self.NLL_ZERO:
                if final_weight == math.inf:
                    self.V[-1][state] = self.NLL_ZERO  # effectively says that we can't end in this state
                else:
                    self.V[-1][state] += final_weight
                    
        # get a list of all states where there was a path ending with non-zero probability
        finished = [x for x in self.V[-1] if x < self.NLL_ZERO]
        if not finished:  # if empty
            print("No path got to the end of the observations.")
        
        
    def decode(self):
        self.initialise_decoding()
        t = 1
        while t <= self.om.observation_length():
            self.forward_step(t)
            self.traverse_epsilon_arcs(t)
            t += 1
        self.finalise_decoding()
    
    def backtrace(self):
        
        best_final_state = self.V[-1].index(min(self.V[-1])) # argmin
        best_state_sequence = [best_final_state]
        best_out_sequence = []
        
        t = self.om.observation_length()   # ie T
        j = best_final_state
        
        while t >= 0:
            i = self.B[t][j]
            best_state_sequence.append(i)
            best_out_sequence = self.W[t][j] + best_out_sequence  # computer scientists might like
                                                                                # to make this more efficient!
            log.debug(f"W[t][j]: {self.W[t][j]}")
            # continue the backtrace at state i, time t-1
            j = i  
            t-=1
            
        best_state_sequence.reverse()
        
        # convert the best output sequence from FST integer labels into strings
#         best_out_sequence = ' '.join([ self.f.output_symbols().find(label) for label in best_out_sequence if self.word_table.find(label)])
        best_out_sequence_str = []
        word_strs = [x[1] for x in list(self.word_table)]
        log.debug(f"word_strs: {word_strs}")
        for label in best_out_sequence:
            label_str = self.f.output_symbols().find(label)
            log.debug(f"label_str: {label_str}")
            if (label_str in word_strs):
                best_out_sequence_str.append(label_str)
        best_out_sequence_str = ' '.join(best_out_sequence_str)
        return (best_state_sequence, best_out_sequence_str)
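
The recursion implemented by `forward_step` and `backtrace` can be seen in isolation on a toy model. The following is a minimal sketch, not the assignment's decoder: it uses a plain dictionary of arcs instead of an OpenFst object, attaches emission costs to the destination state rather than to arc input labels, and ignores epsilon arcs and final weights. The function name `viterbi_nll` and the toy probabilities are made up for illustration.

```python
import math

NLL_ZERO = 1e10  # stands in for -log(0), as in the decoder above

def viterbi_nll(trans, emit, start, num_states):
    """Toy Viterbi over negative log costs.

    trans: dict mapping (i, j) -> transition cost (-log probability)
    emit:  emit[t-1][j] is the emission cost of state j at time t
    Returns (best final cost, best state sequence, number of arc evaluations).
    """
    T = len(emit)
    V = [[NLL_ZERO] * num_states for _ in range(T + 1)]
    B = [[-1] * num_states for _ in range(T + 1)]
    V[0][start] = 0.0
    evaluations = 0
    for t in range(1, T + 1):
        for (i, j), cost in trans.items():
            if V[t - 1][i] == NLL_ZERO:
                continue  # unreachable source state, nothing to propagate
            evaluations += 1
            candidate = V[t - 1][i] + cost + emit[t - 1][j]
            if candidate < V[t][j]:
                V[t][j] = candidate
                B[t][j] = i
    # Backtrace from the cheapest state at the final time step.
    j = V[T].index(min(V[T]))
    best_cost = V[T][j]
    path = [j]
    for t in range(T, 0, -1):
        j = B[t][j]
        path.append(j)
    path.reverse()
    return best_cost, path, evaluations

nll = lambda p: -math.log(p)
toy_trans = {(0, 0): nll(0.5), (0, 1): nll(0.5), (1, 1): nll(1.0)}
toy_emit = [[nll(0.9), nll(0.1)], [nll(0.2), nll(0.8)]]
cost, path, n_evals = viterbi_nll(toy_trans, toy_emit, start=0, num_states=2)
print(path, math.exp(-cost), n_evals)
```

The evaluation counter mirrors `self.forward_computations`: it increases once per arc leaving a reachable state, per frame, whether or not the arc improves the destination's cost.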

In [6]:
import glob
import os
import wer
import observation_model
import openfst_python as fst

def read_transcription(wav_file):
    """
    Get the transcription corresponding to wav_file.
    """
    
    transcription_file = os.path.splitext(wav_file)[0] + '.txt'
    
    with open(transcription_file, 'r') as f:
        transcription = f.readline().strip()
    
    return transcription

In [7]:
import os
# folder = '/group/teaching/asr/labs/individual_recordings/s1645821'  # individual recordings; overridden by the line below
folder = '/group/teaching/asr/labs/recordings'
wavs_txt = [os.path.join(folder,x) for x in os.listdir(folder)]
wavs = [path for path in wavs_txt if path.endswith('.wav')]
txts = [path for path in wavs_txt if path.endswith('.txt')]
# wavs

In [8]:
# original_lex = True
# f, word_table = generate_word_sequence_recognition_wfst(3, 'lexicon.txt',False, False)
# print (f'word table :{list(word_table)}')
# f.draw('tmp.dot', portrait=True)
# check_call(['dot','-Tpng','-Gdpi=300','tmp.dot','-o','tmp.png'])
# Image(filename='tmp.png')

In [9]:
# wav_file = wavs[0]
# decoder = MyViterbiDecoder(f, wav_file, word_table)
# decoder.decode()
# (state_path, words) = decoder.backtrace()  # you'll need to modify the backtrace() from Lab 4
#                                            # to return the words along the best path
# log.debug(f"State path: {state_path}\nWords: {words}")
# transcription = read_transcription(wav_file)
# log.info(f"Words: {words}")
# log.info(f'Transcription: {transcription}')
# error_counts = wer.compute_alignment_errors(transcription, words) # num_subs, num_del, num_ins
# word_count = len(transcription.split())

# log.info(f"Error_counts: {error_counts}, word_count: {word_count}")     # you'll need to accumulate these to produce an overall Word Error Rate

In [10]:
# WFST that uses 0.1 for self transitions and 0.9 for next transitions, without different phones for the same word
WFST_1_9_O = generate_word_sequence_recognition_wfst(3, 'lexicon.txt',True, True)

# WFST that uses 0.1 for self transitions and 0.9 for next transitions, with different phones for the same word
WFST_1_9 = generate_word_sequence_recognition_wfst(3, 'lexicon.txt',False, True)

# WFST that doesn't use probabilities (all transition weights are None), original lexicon
WFST_O = generate_word_sequence_recognition_wfst(3, 'lexicon.txt',True, False)

# WFST that doesn't use probabilities (all transition weights are None), not the original lexicon
WFST = generate_word_sequence_recognition_wfst(3, 'lexicon.txt',False, False)

In [37]:
total_error_counts = []
total_word_counts = []

total_average_decoder_times = []
total_average_backtrace_times = []

total_forward_computations = []

different_fs_word_table = [WFST_1_9_O, WFST_1_9, WFST_O, WFST]

for f, word_table in different_fs_word_table[:]:
    f_error_counts = [0,0,0]
    f_word_counts = 0
    decoder_times = []
    backtrace_times = []
    forward_computations = []
    # -- Dataframe
    
    
    for wav_file in glob.glob('/group/teaching/asr/labs/recordings/*.wav')[:]:    # replace path if using your own
        
        decoder = MyViterbiDecoder(f, wav_file, word_table)
        
        decoder_start_time = time.time()
        decoder.decode()
        decoder_end_time  = time.time()
        (state_path, words) = decoder.backtrace()  # you'll need to modify the backtrace() from Lab 4
                                           # to return the words along the best path
        backtrace_end_time = time.time()
        
        transcription = read_transcription(wav_file)
        error_counts = wer.compute_alignment_errors(transcription, words) #num_subs, num_del, num_ins
        word_count = len(transcription.split())
        
        # add up output
        f_word_counts += word_count
        f_error_counts[0] += error_counts[0]
        f_error_counts[1] += error_counts[1]
        f_error_counts[2] += error_counts[2]
#         print(error_counts, word_count)     # you'll need to accumulate these to produce an overall Word Error Rate
        
        # add up times
        decoder_time = decoder_end_time - decoder_start_time
        backtrace_time = backtrace_end_time - decoder_end_time
        
        decoder_times.append(decoder_time)
        backtrace_times.append(backtrace_time)
        
        # add up computations
        forward_computations.append(decoder.forward_computations)
        
        # -- add to DataFrame
        wav_name = os.path.basename(wav_file)  # portable; splitting on '\\' does nothing for POSIX paths
        pd_row = []
        
    total_error_counts.append(f_error_counts)
    total_word_counts.append(f_word_counts)
    
    total_average_decoder_times.append(sum(decoder_times)/len(decoder_times))
    total_average_backtrace_times.append(sum(backtrace_times)/len(backtrace_times))
    
    total_forward_computations.append(sum(forward_computations)/len(forward_computations))

### Dataframe

In [43]:
columns=['WFST', 'S', 'D', 'I', 'WER (%)','Word Counts', 'Decoder Times', 'Backtrace Times', 'Forward Computations']
WFSTs = ['baseline log', 'multi word log', 'None weight original', 'None weight multi word']
total_error_counts = np.array(total_error_counts)
subs = total_error_counts[:,0]
deletions = total_error_counts[:,1]
insertions = total_error_counts[:,2]
# (S + D + I) / N is the word error rate, not an accuracy; insertions can push it above 100%
wer_percent = ((subs + deletions + insertions)/np.array(total_word_counts))*100
df = pd.DataFrame((WFSTs, subs, deletions, insertions, wer_percent, total_word_counts, total_average_decoder_times, total_average_backtrace_times, total_forward_computations),index=columns).T
df.to_excel('wfsts.xlsx')

### Accuracy

In [41]:
# word error rate (lower is better; insertions can push it above 100%)
for idx, (f, word_table) in enumerate(different_fs_word_table[:]):
    error = total_error_counts[idx]
    word_count = total_word_counts[idx]
    
    wer_percent = (error[0] + error[1] + error[2])/word_count
    wer_percent = wer_percent * 100
    
    print(f'---- {idx} ----')
    print(f"Errors: {error}, word count: {word_count}")
    print(f"WER for F ({idx}): {wer_percent}\n")

---- 0 ----
Errors: [ 5  0 17], word count: 21
WER for F (0): 104.76190476190477

---- 1 ----
Errors: [ 3  0 23], word count: 21
WER for F (1): 123.80952380952381

---- 2 ----
Errors: [4 1 4], word count: 21
WER for F (2): 42.857142857142854

---- 3 ----
Errors: [3 1 4], word count: 21
WER for F (3): 38.095238095238095
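
The percentages printed above are (S + D + I) / N × 100, that is, a word error rate, which is why two of them exceed 100%. As a rough check, the sketch below recomputes them from the reported counts and also shows what share of the errors are insertions; the dictionary simply copies the numbers from the output, and the variable names are illustrative.

```python
# Error counts (S, D, I) and the reference word count, copied from the output above.
results = {
    0: (5, 0, 17),
    1: (3, 0, 23),
    2: (4, 1, 4),
    3: (3, 1, 4),
}
word_count = 21

wers = {}
insertion_shares = {}
for idx, (s, d, i) in results.items():
    wers[idx] = 100 * (s + d + i) / word_count
    insertion_shares[idx] = 100 * i / (s + d + i)
    print(f"F ({idx}): WER = {wers[idx]:.2f}%, "
          f"insertions = {insertion_shares[idx]:.1f}% of errors")
```

For the two weighted WFSTs most of the errors are insertions, which suggests looking at how cheaply the decoder can start a new word before changing anything else.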



### Speed

In [36]:
# speed
for idx, (f, word_table) in enumerate(different_fs_word_table[:3]):
    decode_time = total_average_decoder_times[idx]
    backtrace_time = total_average_backtrace_times[idx]
    
    print(f'---- {idx} ----')
    print(f"Average Decode time: {decode_time}")
    print(f"Average backtrace time: {backtrace_time}\n")

---- 0 ----
Average Decode time: 1.9447619120279949
Average backtrace time: 0.0006755193074544271

---- 1 ----
Average Decode time: 2.135376214981079
Average backtrace time: 0.0008070468902587891

---- 2 ----
Average Decode time: 1.9157028992970784
Average backtrace time: 0.0006573994954427084
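
The timings above are differences of `time.time()` readings. For short durations such as the backtrace, `time.perf_counter()` may be a better fit, since it is a monotonic, high-resolution clock intended for interval measurement. A minimal sketch follows; the helper name `timed` is hypothetical, and `sum` stands in for the decoder calls.

```python
import time

def timed(fn, *args, **kwargs):
    """Call fn and return (result, elapsed seconds) using a monotonic clock."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start

# Stand-in workload; in the notebook this would be decoder.decode or decoder.backtrace.
total, seconds = timed(sum, range(100_000))
print(f"sum = {total}, took {seconds:.6f} s")
```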



### Memory 

In [46]:
# memory
for idx, (f, word_table) in enumerate(different_fs_word_table[:1]):
    num_states = f.num_states()
    # why 1 + f.num_arcs(s)
    num_arcs = sum([f.num_arcs(s) for s in f.states()])
    
    print(f'---- {idx} ----')
    print(f"Number of states: {num_states}")
    print(f"Number of arcs: {num_arcs}")

---- 0 ----
Number of states: 116
Number of arcs: 230
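
Besides the WFST itself, the decoder allocates the trellis tables `V`, `B` and `W`, each with one row per time step 0..T and one column per state. A back-of-the-envelope sketch of that size is given below; the helper name `trellis_cells` and the frame count are assumptions for illustration, and only the state count comes from the output above.

```python
def trellis_cells(num_frames, num_states):
    """Entries in each of V, B and W: rows for t = 0..T, one column per state."""
    return (num_frames + 1) * num_states

num_states = 116   # reported above for the first WFST
num_frames = 300   # hypothetical utterance length in frames
print(trellis_cells(num_frames, num_states))  # 301 * 116 = 34916
```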


### Forward Computations

In [47]:
# forward_computations
for idx, (f, word_table) in enumerate(different_fs_word_table[:1]):
    forward_computations = total_forward_computations[idx]
    
    print(f'---- {idx} ----')
    print(f"Number of forward computations: {forward_computations}")

---- 0 ----
Number of forward computations: [68328, 68328]


In [44]:
print(total_forward_computations)

[[68328], [68328], [68328, 68328], [68328, 68328]]


### Save results