# ==================================================

# Task 4

In [4]:
### import os
import wer
import openfst_python as fst
import glob
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
from vetrbi import MyViterbiDecoder
from utils import parse_lexicon, generate_symbol_tables
from hmm import generate_word_sequence_recognition_wfst
from hmm import generate_word_sequence_recognition_wfst_with_silance
from utils import draw
import numpy as np
import pickle


def create_wfst():
    """
    Build the word-sequence recognition WFST and attach its symbol tables.

    Reads the module-level ``state_table`` and ``phone_table`` (produced by
    ``generate_symbol_tables``), so those must exist before this is called.

    Returns:
        The WFST with ``state_table`` as input symbols and ``phone_table``
        as output symbols.
    """
    # The argument 3 is presumably the number of HMM states per phone -- TODO confirm in hmm.py
    recogniser = generate_word_sequence_recognition_wfst(3)
    recogniser.set_input_symbols(state_table)
    recogniser.set_output_symbols(phone_table)
    return recogniser

def create_wfst_with_silance():
    """
    Build the silence-aware word-sequence WFST and attach its symbol tables.

    Unigram probabilities are switched off (``use_unigram_probs=False``).
    Reads the module-level ``state_table`` and ``phone_table``.

    Returns:
        The WFST with ``state_table`` as input symbols and ``phone_table``
        as output symbols.
    """
    # The argument 3 is presumably the number of HMM states per phone -- TODO confirm in hmm.py
    recogniser = generate_word_sequence_recognition_wfst_with_silance(
        3,
        use_unigram_probs=False,
    )
    recogniser.set_input_symbols(state_table)
    recogniser.set_output_symbols(phone_table)
    return recogniser


def read_transcription(wav_file):
    """
    Return the first line of the transcript stored next to wav_file.

    The transcript path is wav_file with its extension replaced by '.txt'.
    Leading and trailing whitespace is stripped from the line that is read.
    """
    stem, _extension = os.path.splitext(wav_file)
    with open(stem + '.txt', 'r') as transcript:
        first_line = transcript.readline()
    return first_line.strip()

def memory_of_wfst(f):
    '''
    Compute a measure of the memory required for your decoder by providing counts
    of number of states and arcs in the WFST.

    Args:
        f: object exposing ``states()`` (iterable of state ids) and
           ``arcs(state)`` (iterable of the arcs leaving ``state``).

    Returns:
        Tuple ``(number_of_states, number_of_arcs)``.
    '''
    num_states = 0
    num_arcs = 0
    for state in f.states():
        num_states += 1
        # Count arcs as they stream past instead of collecting every arc
        # object in a list just to take its length.
        num_arcs += sum(1 for _ in f.arcs(state))
    return num_states, num_arcs
    
def get_avg_wer(all_losses, verbose=False):
    """
    Average the per-file word error rate.

    Args:
        all_losses: iterable of ``(error_counts, word_count)`` pairs; a file's
            WER is ``sum(error_counts) / word_count``.
        verbose: when true, also print the average as a percentage.

    Returns:
        The mean of the per-file WERs (unweighted by file length).
    """
    per_file_wer = [sum(errors) / n_words for errors, n_words in all_losses]
    mean_wer = np.mean(per_file_wer)
    if verbose:
        print(f'The average WER is {mean_wer:.2%}')
    return mean_wer

def get_avg_effciency(efficancy_measures, verbose=False):
    """
    Average the three efficiency series collected during decoding.

    Args:
        efficancy_measures: indexable whose items 0, 1 and 2 hold the
            per-file decoding times, backtrace times and computation counts.
        verbose: when true, print the three averages.

    Returns:
        Tuple ``(mean decoding time, mean backtrace time, mean computations)``.
    """
    averages = tuple(np.mean(efficancy_measures[idx]) for idx in range(3))
    if verbose:
        mean_decode, mean_backtrace, mean_computations = averages
        print(f'The average decoding time is {mean_decode:.2f} seconds')
        print(f'The average backtrace time is {mean_backtrace:.2f} seconds')
        print(f'The average number of computations is {mean_computations:.2f}')
    return averages


def decoding_loop(f, train_set=True, train_split=0.85, use_pruning=False, determinized=False, verbose=False, prune_threshold= None, recordings_glob='/group/teaching/asr/labs/recordings/*.wav'):
    """
    Decode one partition of the recordings with ``f`` and collect statistics.

    Args:
        f: WFST handed to ``MyViterbiDecoder``.
        train_set: if true, decode the first ``train_split`` fraction of the
            files; otherwise decode the remaining files.
        train_split: fraction of the (sorted) file list that forms the
            training partition.
        use_pruning: forwarded to ``MyViterbiDecoder``.
        determinized: forwarded to ``MyViterbiDecoder``.
        verbose: forwarded to ``MyViterbiDecoder``; also prints the reference,
            the prediction and the error counts for every file.
        prune_threshold: when pruning is on and this is not None, it replaces
            the decoder's ``prune_threshold`` attribute before decoding.
        recordings_glob: glob pattern selecting the wav files to evaluate.

    Returns:
        ``(all_losses, efficancy_measures)`` where ``all_losses`` is a list of
        ``(error_counts, word_count)`` per file and ``efficancy_measures`` is
        the tuple ``(decoding_time, backtrace_time, number_of_computations)``
        of per-file lists.
    """
    # glob.glob() returns paths in arbitrary order, so sort them: otherwise
    # the train/test partition can differ between runs and machines.
    all_files = sorted(glob.glob(recordings_glob))
    split_index = int(train_split * len(all_files))
    files = all_files[:split_index] if train_set else all_files[split_index:]

    all_losses = []
    decoding_time = []
    backtrace_time = []
    number_of_computations = []

    for wav_file in tqdm(files):
        decoder = MyViterbiDecoder(f, wav_file, verbose=verbose, use_pruning=use_pruning, determinized=determinized)
        # Identity test: only a missing threshold keeps the decoder's default.
        if use_pruning and prune_threshold is not None:
            decoder.prune_threshold = prune_threshold
        decoder.decode()
        _state_path, words = decoder.backtrace()

        transcription = read_transcription(wav_file)
        error_counts = wer.compute_alignment_errors(transcription, words)
        word_count = len(transcription.split())

        all_losses.append((error_counts, word_count))
        decoding_time.append(decoder.decode_time)
        backtrace_time.append(decoder.backtrace_time)
        number_of_computations.append(decoder.number_of_computiations)
        if verbose:
            print(f'Transcription: {transcription} || Prediction: {words} || (nsub, ndel, nin) :{error_counts}')

    efficancy_measures = (decoding_time, backtrace_time, number_of_computations)
    return all_losses, efficancy_measures


# Build the symbol tables from the lexicon first: create_wfst() and
# create_wfst_with_silance() read state_table / phone_table as globals.
lex = parse_lexicon('lexicon.txt')
word_table, phone_table, state_table = generate_symbol_tables(lex)
# Baseline WFST, the silence variant, and a determinized copy of the baseline.
f = create_wfst()
f_silence = create_wfst_with_silance()
f_det = fst.determinize(f)

In [6]:
# Compare (state count, arc count) of the baseline WFST and its determinized copy.
memory_of_wfst(f), memory_of_wfst(f_det)

((116, 230), (84, 183))

In [None]:
# Sweep the pruning threshold over 5, 10, ..., 100 on the first half of the
# recordings, and checkpoint the accumulated results after every setting.
prune_thresholds = list(range(5, 101, 5))
verbose = False
exp_dict = {metric: [] for metric in ('loss', 'efficancy', 'acc', 'm1', 'm2', 'm3')}

for prune_threshold in prune_thresholds:
    print(f'Threshold = {prune_threshold}')
    all_losses, efficancy_measures = decoding_loop(
        f,
        train_set=True,
        train_split=0.5,
        use_pruning=True,
        prune_threshold=prune_threshold,
        verbose=verbose,
    )
    avg_wer = get_avg_wer(all_losses, verbose=True)
    m1, m2, m3 = get_avg_effciency(efficancy_measures, verbose=verbose)

    # NOTE: the 'acc' entry stores the average WER (an error rate).
    results = {
        'loss': all_losses,
        'efficancy': efficancy_measures,
        'acc': avg_wer,
        'm1': m1,
        'm2': m2,
        'm3': m3,
    }
    for key, value in results.items():
        exp_dict[key].append(value)
    print('\n\n\n')

    # Each checkpoint file holds the results of every threshold run so far.
    file_name = f'exp_dict_baseline_ours_{prune_threshold}.pkl'
    with open(file_name, 'wb') as handler:
        pickle.dump(exp_dict, handler)
    print(f'saved to {file_name}')





# ======================================================

In [None]:
import os
import wer
import glob
from tqdm import tqdm
from vetrbi import MyViterbiDecoder
from utils import parse_lexicon, generate_symbol_tables


# Build the symbol tables from the lexicon; the create_* helpers in this
# section read state_table / phone_table as globals.
lex = parse_lexicon('lexicon.txt')
word_table, phone_table, state_table = generate_symbol_tables(lex)


def read_transcription(wav_file):
    """
    Return the first line of the transcript stored next to wav_file.

    The transcript path is wav_file with its extension replaced by '.txt'.
    Leading and trailing whitespace is stripped from the line that is read.
    """
    stem, _extension = os.path.splitext(wav_file)
    with open(stem + '.txt', 'r') as transcript:
        first_line = transcript.readline()
    return first_line.strip()


In [None]:
from hmm import generate_word_sequence_recognition_wfst
from hmm import generate_word_sequence_recognition_wfst_with_silance
from hmm import generate_lexical_hmm

def create_wfst():
    """
    Build the word-sequence recognition WFST and attach its symbol tables.

    Reads the module-level ``state_table`` and ``phone_table`` (produced by
    ``generate_symbol_tables``), so those must exist before this is called.

    Returns:
        The WFST with ``state_table`` as input symbols and ``phone_table``
        as output symbols.
    """
    # The argument 3 is presumably the number of HMM states per phone -- TODO confirm in hmm.py
    recogniser = generate_word_sequence_recognition_wfst(3)
    recogniser.set_input_symbols(state_table)
    recogniser.set_output_symbols(phone_table)
    return recogniser

def create_wfst_with_silance():
    """
    Build the silence-aware word-sequence WFST and attach its symbol tables.

    Unigram probabilities are switched off (``use_unigram_probs=False``).
    Reads the module-level ``state_table`` and ``phone_table``.

    Returns:
        The WFST with ``state_table`` as input symbols and ``phone_table``
        as output symbols.
    """
    # The argument 3 is presumably the number of HMM states per phone -- TODO confirm in hmm.py
    recogniser = generate_word_sequence_recognition_wfst_with_silance(
        3,
        use_unigram_probs=False,
    )
    recogniser.set_input_symbols(state_table)
    recogniser.set_output_symbols(phone_table)
    return recogniser

def create_lexical():
    """
    Build the lexical HMM WFST and attach its symbol tables.

    Unigram probabilities are switched off (``use_unigram_probs=False``).
    Reads the module-level ``state_table`` and ``phone_table``.

    Returns:
        The WFST with ``state_table`` as input symbols and ``phone_table``
        as output symbols.
    """
    # The argument 1 is presumably the number of HMM states per phone -- TODO confirm in hmm.py
    lexical_wfst = generate_lexical_hmm(
        1,
        use_unigram_probs=False,
    )
    lexical_wfst.set_input_symbols(state_table)
    lexical_wfst.set_output_symbols(phone_table)
    return lexical_wfst


In [None]:
# Instantiate the baseline WFST and the variant that models silence.
f = create_wfst()
f_silence = create_wfst_with_silance()

In [None]:
import openfst_python as fst
from subprocess import check_call
from IPython.display import Image
# Rebind f_silence to its determinized version; the decoding cell that
# follows passes this machine with determinized=True.
f_silence = fst.determinize(f_silence)
# Visualise it: write Graphviz source, render a PNG with the `dot`
# command-line tool, then show the image as the cell result.
f_silence.draw('tmp.dot', portrait=True)
check_call(['dot','-Tpng','-Gdpi=400','tmp.dot','-o','tmp.png'])
Image(filename='tmp.png')

In [None]:
# Decode at most the first 10 recordings with the determinized silence WFST,
# collecting per-file error counts, decoder cost statistics and the
# concatenated reference transcriptions.
wav_files = glob.glob('/group/teaching/asr/labs/recordings/*.wav')  # replace path if using your own audio files
train_split = int(0.85 * len(wav_files))
print(train_split)

all_losses, decoding_time, backtrace_time, number_of_computations = [], [], [], []
all_transcriptions = ''
i = 0  # stays 0 when no recordings are found
for i, wav_file in enumerate(tqdm(wav_files), start=1):
    decoder = MyViterbiDecoder(
        f_silence,
        wav_file,
        verbose=False,
        use_pruning=False,
        determinized=True,
    )
    decoder.decode()
    state_path, words = decoder.backtrace()

    transcription = read_transcription(wav_file)
    all_transcriptions += transcription + ' '
    error_counts = wer.compute_alignment_errors(transcription, words)
    word_count = len(transcription.split())

    all_losses.append((error_counts, word_count))
    decoding_time.append(decoder.decode_time)
    backtrace_time.append(decoder.backtrace_time)
    number_of_computations.append(decoder.number_of_computiations)

    # Stop after ten files: this cell is a quick check, not a full evaluation.
    if i == 10:
        break

In [None]:
# Per-file WER: total alignment errors divided by the number of reference words
import numpy as np
all_wer = [sum(error_counts) / word_count for error_counts, word_count in all_losses]

print(f'The average WER is {np.mean(all_wer):.2%}')

In [None]:
# Report the mean per-file decoding statistics
mean_decoding_time = np.mean(decoding_time)
mean_backtrace_time = np.mean(backtrace_time)
mean_number_of_computations = np.mean(number_of_computations)
print(f'The average decoding time is {mean_decoding_time:.2f} seconds')
print(f'The average backtrace time is {mean_backtrace_time:.2f} seconds')
print(f'The average number of computations is {mean_number_of_computations:.2f}')



In [None]:
# Estimate unigram word probabilities from all_transcriptions, then
# round-trip them through a pickle file.
import pickle
from collections import Counter

# dict(Counter(...)) keeps first-seen word order and prints like a plain dict.
unigram_counts = dict(Counter(all_transcriptions.split()))
print('Unigram counts:' ,unigram_counts)

# Compute the total once instead of re-summing all counts for every word.
total_words = sum(unigram_counts.values())
unigram_probs = {word: count / total_words for word, count in unigram_counts.items()}
print('Unigram probs:' ,unigram_probs)

# save unigram probs to pickle file
with open('unigram_probs.pickle', 'wb') as handle:
    pickle.dump(unigram_probs, handle)

# load unigram probs from pickle file (only safe because this file was
# written by this notebook; never unpickle untrusted data)
with open('unigram_probs.pickle', 'rb') as handle:
    unigram_probs = pickle.load(handle)

print('Unigram probs:' ,unigram_probs)

In [None]:
import os
import wer
import glob
from tqdm import tqdm
from vetrbi import MyViterbiDecoder
from utils import parse_lexicon, generate_symbol_tables, draw
import openfst_python as fst



# Build the symbol tables from the lexicon; the create_* helpers in this
# section read state_table / phone_table as globals.
lex = parse_lexicon('lexicon.txt')
word_table, phone_table, state_table = generate_symbol_tables(lex)


def read_transcription(wav_file):
    """
    Return the first line of the transcript stored next to wav_file.

    The transcript path is wav_file with its extension replaced by '.txt'.
    Leading and trailing whitespace is stripped from the line that is read.
    """
    stem, _extension = os.path.splitext(wav_file)
    with open(stem + '.txt', 'r') as transcript:
        first_line = transcript.readline()
    return first_line.strip()


In [None]:
from hmm import generate_word_sequence_recognition_wfst
from hmm import generate_word_sequence_recognition_wfst_with_silance
from hmm import generate_bigram_wfst

def create_wfst():
    """
    Build the word-sequence recognition WFST and attach its symbol tables.

    Reads the module-level ``state_table`` and ``phone_table`` (produced by
    ``generate_symbol_tables``), so those must exist before this is called.

    Returns:
        The WFST with ``state_table`` as input symbols and ``phone_table``
        as output symbols.
    """
    # The argument 3 is presumably the number of HMM states per phone -- TODO confirm in hmm.py
    recogniser = generate_word_sequence_recognition_wfst(3)
    recogniser.set_input_symbols(state_table)
    recogniser.set_output_symbols(phone_table)
    return recogniser

def create_wfst_with_silance():
    """
    Build the silence-aware word-sequence WFST and attach its symbol tables.

    Unigram probabilities are switched off (``use_unigram_probs=False``).
    Reads the module-level ``state_table`` and ``phone_table``.

    Returns:
        The WFST with ``state_table`` as input symbols and ``phone_table``
        as output symbols.
    """
    # The argument 3 is presumably the number of HMM states per phone -- TODO confirm in hmm.py
    recogniser = generate_word_sequence_recognition_wfst_with_silance(
        3,
        use_unigram_probs=False,
    )
    recogniser.set_input_symbols(state_table)
    recogniser.set_output_symbols(phone_table)
    return recogniser

def create_lexical():
    """
    Build the lexical HMM WFST and attach its symbol tables.

    Unigram probabilities are switched off (``use_unigram_probs=False``).
    Reads the module-level ``state_table`` and ``phone_table``.

    Returns:
        The WFST with ``state_table`` as input symbols and ``phone_table``
        as output symbols.
    """
    # Imported here because this section's import cell does not bring
    # generate_lexical_hmm into scope; without it the call below raises
    # NameError when the section is run in a fresh session.
    from hmm import generate_lexical_hmm

    # The argument 1 is presumably the number of HMM states per phone -- TODO confirm in hmm.py
    f = generate_lexical_hmm(1, use_unigram_probs=False)
    f.set_input_symbols(state_table)
    f.set_output_symbols(phone_table)
    return f

def create_bigram_lexical():
    """
    Build the bigram WFST and attach its symbol tables.

    Reads the module-level ``state_table`` and ``phone_table``.

    Returns:
        The WFST with ``state_table`` as input symbols and ``phone_table``
        as output symbols.
    """
    # The argument 1 is presumably the number of HMM states per phone -- TODO confirm in hmm.py
    bigram_wfst = generate_bigram_wfst(1)
    bigram_wfst.set_input_symbols(state_table)
    bigram_wfst.set_output_symbols(phone_table)
    return bigram_wfst


In [None]:
# Decode at most the first 10 recordings with the decoder's bigram flag on,
# printing each prediction next to its reference transcription.
# NOTE(review): `f` is whatever was last bound to that name; no visible cell
# assigns the result of create_bigram_lexical() to it -- confirm the intended WFST.
wav_files = glob.glob('/group/teaching/asr/labs/recordings/*.wav')  # replace path if using your own audio files
train_split = int(0.85 * len(wav_files))
print(train_split)

all_losses, decoding_time, backtrace_time, number_of_computations = [], [], [], []
all_transcriptions = ''
i = 0  # stays 0 when no recordings are found
for i, wav_file in enumerate(tqdm(wav_files), start=1):
    decoder = MyViterbiDecoder(
        f,
        wav_file,
        verbose=False,
        use_pruning=False,
        determinized=True,
        bigram=True,
    )
    decoder.decode()
    state_path, words = decoder.backtrace()

    transcription = read_transcription(wav_file)
    all_transcriptions += transcription + ' '
    error_counts = wer.compute_alignment_errors(transcription, words)
    word_count = len(transcription.split())

    all_losses.append((error_counts, word_count))
    decoding_time.append(decoder.decode_time)
    backtrace_time.append(decoder.backtrace_time)
    number_of_computations.append(decoder.number_of_computiations)
    print(words)
    print(transcription)

    # Stop after ten files: this cell is a quick check, not a full evaluation.
    if i == 10:
        break

In [None]:
# Per-file WER: total alignment errors divided by the number of reference words
import numpy as np
all_wer = [sum(error_counts) / word_count for error_counts, word_count in all_losses]

print(f'The average WER is {np.mean(all_wer):.2%}')

In [None]:
# Report the mean per-file decoding statistics
mean_decoding_time = np.mean(decoding_time)
mean_backtrace_time = np.mean(backtrace_time)
mean_number_of_computations = np.mean(number_of_computations)
print(f'The average decoding time is {mean_decoding_time:.2f} seconds')
print(f'The average backtrace time is {mean_backtrace_time:.2f} seconds')
print(f'The average number of computations is {mean_number_of_computations:.2f}')

