# ==================================================

# Task 4

In [1]:
### import os
import wer
import openfst_python as fst
import glob
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
from vetrbi import MyViterbiDecoder
from utils import parse_lexicon, generate_symbol_tables
from hmm import generate_word_sequence_recognition_wfst
from hmm import generate_word_sequence_recognition_wfst_with_silance
from hmm import generate_bigram_wfst

from utils import draw
import numpy as np
import pickle


def create_wfst():
    """
    Build the word-sequence recognition WFST (built with argument 3) and
    attach the module-level symbol tables to it.

    The argument 3 is presumably the number of HMM states per phone — confirm
    in the ``hmm`` module. Input symbols come from ``state_table``, output
    symbols from ``phone_table``.
    """
    # NOTE(review): the output side is labelled with phone_table, but the
    # decoder's backtrace yields words; if this WFST emits word labels,
    # word_table may be the intended table — confirm.
    wfst = generate_word_sequence_recognition_wfst(3)
    wfst.set_input_symbols(state_table)
    wfst.set_output_symbols(phone_table)
    return wfst

def create_wfst_with_silance():
    """
    Build the "with silance" word-sequence WFST variant and attach the
    module-level symbol tables to it.

    Calls ``generate_word_sequence_recognition_wfst_with_silance`` with 3 as
    its first argument and ``use_unigram_probs=False``. Input symbols come
    from ``state_table``, output symbols from ``phone_table``.
    """
    wfst = generate_word_sequence_recognition_wfst_with_silance(
        3,
        use_unigram_probs=False,
    )
    wfst.set_input_symbols(state_table)
    wfst.set_output_symbols(phone_table)
    return wfst




def read_transcription(wav_file):
    """
    Return the reference transcription paired with ``wav_file``.

    The transcription lives next to the audio, with the extension swapped for
    ``.txt`` (``dir/utt.wav`` -> ``dir/utt.txt``). Only its first line is
    read, and surrounding whitespace is stripped.
    """
    stem, _extension = os.path.splitext(wav_file)
    with open(stem + '.txt', 'r') as handle:
        first_line = handle.readline()
    return first_line.strip()

def memory_of_wfst(f):
    '''
    Compute a measure of the memory required for your decoder by providing counts
    of number of states and arcs in the WFST.

    Args:
        f: A WFST exposing ``states()`` and ``arcs(state)`` iterators.

    Returns:
        ``(num_states, num_arcs)`` as plain ints.

    States and arcs are counted while iterating rather than collected into
    lists. Measuring a large WFST therefore no longer holds a copy of every
    arc object in memory.
    '''
    num_states = 0
    num_arcs = 0
    for state in f.states():
        num_states += 1
        num_arcs += sum(1 for _ in f.arcs(state))
    return num_states, num_arcs
    
def get_avg_wer(all_losses, verbose=False):
    """
    Return the mean per-utterance word error rate.

    Each entry of ``all_losses`` is ``(error_counts, word_count)``. The WER
    of one utterance is ``sum(error_counts) / word_count``. The result is
    the unweighted mean of those ratios (a per-utterance average), not the
    corpus-level ``total errors / total words``. When ``verbose`` is true the
    average is also printed as a percentage.
    """
    avg = np.mean([sum(errors) / n_words for errors, n_words in all_losses])
    if verbose:
        print(f'The average WER is {avg:.2%}')
    return avg

def get_avg_effciency(efficancy_measures, verbose=False):
    """
    Average each of the first three series in ``efficancy_measures``.

    The series are read positionally as (decode times, backtrace times,
    computation counts). Their means are returned in that same order.
    When ``verbose`` is true the three means are printed first.
    """
    decoding_time, backtrace_time, number_of_computions = (
        np.mean(efficancy_measures[idx]) for idx in range(3)
    )
    if verbose:
        print(f'The average decoding time is {decoding_time:.2f} seconds')
        print(f'The average backtrace time is {backtrace_time:.2f} seconds')
        print(f'The average number of computations is {number_of_computions:.2f}')
    return decoding_time, backtrace_time, number_of_computions


def decoding_loop(f, train_set=True, train_split=0.5, use_pruning=False, determinized=False, verbose=False, prune_threshold=None, bigram=False, wav_glob='/group/teaching/asr/labs/recordings/*.wav'):
    """
    Decode one split of the recordings with ``f`` and collect WER/efficiency data.

    Args:
        f: The WFST handed to ``MyViterbiDecoder``.
        train_set: If True, decode the first ``train_split`` fraction of the
            (sorted) recordings; otherwise decode the remainder.
        train_split: Fraction of recordings in the "train" split.
        use_pruning, determinized, bigram: Forwarded to ``MyViterbiDecoder``.
        verbose: Forwarded to ``MyViterbiDecoder``; also prints a
            per-utterance comparison line.
        prune_threshold: When given and ``use_pruning`` is set, overrides the
            decoder's ``prune_threshold`` attribute.
        wav_glob: Glob pattern that locates the ``.wav`` recordings.

    Returns:
        ``(all_losses, efficancy_measures)``:
        - ``all_losses`` is a list of ``(error_counts, word_count)``, one per
          file, where ``error_counts`` is the ``(nsub, ndel, nin)`` result of
          ``wer.compute_alignment_errors``.
        - ``efficancy_measures`` is a tuple of three per-file lists:
          ``(decode_times, backtrace_times, computation_counts)``.

    Raises:
        FileNotFoundError: If ``wav_glob`` matches no files.
    """
    # glob's result order depends on the filesystem. Sort it so the
    # train/test split is identical across machines and runs.
    all_files = sorted(glob.glob(wav_glob))
    if not all_files:
        raise FileNotFoundError(f'No recordings match {wav_glob!r}')

    split_at = int(train_split * len(all_files))
    files = all_files[:split_at] if train_set else all_files[split_at:]

    all_losses = []
    decoding_time = []
    backtrace_time = []
    number_of_computations = []

    for wav_file in tqdm(files):
        decoder = MyViterbiDecoder(f, wav_file, verbose=verbose, use_pruning=use_pruning, determinized=determinized, bigram=bigram)
        if use_pruning and prune_threshold is not None:
            decoder.prune_threshold = prune_threshold
        decoder.decode()
        _state_path, words = decoder.backtrace()
        transcription = read_transcription(wav_file)
        error_counts = wer.compute_alignment_errors(transcription, words)
        word_count = len(transcription.split())

        all_losses.append((error_counts, word_count))
        decoding_time.append(decoder.decode_time)
        backtrace_time.append(decoder.backtrace_time)
        # Attribute name (including its spelling) is defined by MyViterbiDecoder.
        number_of_computations.append(decoder.number_of_computiations)
        if verbose:
            print(f'Transcription: {transcription} || Prediction: {words} || (nsub, ndel, nin) :{error_counts}')

    efficancy_measures = (decoding_time, backtrace_time, number_of_computations)
    return all_losses, efficancy_measures


# Parse the lexicon and build the symbol tables. create_wfst(),
# create_wfst_with_silance() and create_bigram_lexical() read
# state_table / phone_table from these module-level globals, so this must run
# before any of them is called.
lex = parse_lexicon('lexicon.txt')
word_table, phone_table, state_table = generate_symbol_tables(lex)
# Baseline WFST, the variant built by create_wfst_with_silance(), and a
# determinized copy of the baseline.
f = create_wfst()
f_silence = create_wfst_with_silance()
f_det = fst.determinize(f)

 85%|████████████████████████████████▏     | 269/318 [00:00<00:00, 27026.32it/s]


In [None]:
# Experiment record, pickled after each run of the bigram loop.
# The list-valued entries collect one item per run. 'all_states' and
# 'all_arcs' start as empty lists, but the loop assigns them plain ints
# (the WFST size).
exp_dict = {
    key: []
    for key in ('loss', 'efficancy', 'acc', 'm1', 'm2', 'm3', 'all_states', 'all_arcs')
}
exp_dict.update(det=False, n=0)


def create_bigram_lexical(n):
    """
    Build ``generate_bigram_wfst(n)`` and attach the module-level symbol
    tables: ``state_table`` on the input side and ``phone_table`` on the
    output side.
    """
    bigram_wfst = generate_bigram_wfst(n)
    bigram_wfst.set_input_symbols(state_table)
    bigram_wfst.set_output_symbols(phone_table)
    return bigram_wfst

# Bigram experiments. For each n: build the bigram WFST, optionally
# determinize it, record its size, decode the train split, and pickle
# exp_dict to a per-n file.
ns = [1,2,3]

for n in ns:
    print(f'N = {n}')
    f_bigram = create_bigram_lexical(n)
    exp_dict['n'] = n
    # Determinization is hard-coded on for this run.
    det = True
    if det:
        f_bigram = fst.determinize(f_bigram)
    exp_dict['det'] = det
    # Size of the (possibly determinized) WFST as (state count, arc count).
    exp_dict['all_states'], exp_dict['all_arcs'] = memory_of_wfst(f_bigram)
    verbose = False
    print(f'det = {det}')
    print(f'All states: {exp_dict["all_states"]}, all arcs: {exp_dict["all_arcs"]}')
    # Decode the first 50% of the recordings (the "train" split).
    all_losses, efficancy_measures = decoding_loop(f_bigram, train_set=True, train_split=0.5, determinized=det, verbose=verbose, bigram=True)
    avg_wer = get_avg_wer(all_losses, verbose=True)
    m1,m2,m3 = get_avg_effciency(efficancy_measures, verbose=verbose)
    # NOTE(review): exp_dict is shared across iterations. The list entries
    # ('loss', 'efficancy', 'acc', 'm1'..'m3') accumulate results for every n
    # so far, while 'n', 'det', 'all_states' and 'all_arcs' are overwritten
    # with the current n only. The file for n=3 therefore holds WERs for
    # n=1..3 but the WFST size for n=3 alone — confirm this is intended.
    exp_dict['loss'].append(all_losses)
    exp_dict['efficancy'].append(efficancy_measures)
    exp_dict['acc'].append(avg_wer)
    exp_dict['m1'].append(m1)
    exp_dict['m2'].append(m2)
    exp_dict['m3'].append(m3)

    print('\n\n\n')
    file_name = f'exp_dict_baseline_det_{det}_bigram_{n}.pkl'
    with open(file_name, 'wb') as handler:
        pickle.dump(exp_dict, handler)
    print(f'saved to {file_name}')



N = 1
det = True
All states: 1915, all arcs: 4034


 21%|████████▌                               | 34/159 [28:38<1:24:27, 40.54s/it]

# ======================================================