# Phoneme-dependent transformation model: feature visualisation and STOI/SRMR results

In [None]:
import pandas as pd
import os
import numpy as np
import matplotlib.pyplot as plt
from feature_extraction import read_feat_file
import os
from conf import read_conf

In [None]:
# Test sets to evaluate, plus the configuration file and model whose results are analysed.
test_sets = [
    "test_hint_office_0_1_3",
    "test_hint_lecture_0_1_3",
    "test_hint_stairway_0_1_3_90",
]
conf_file_filename = "LSTMModelTransform_presentation2.txt"
conf_file = "conf/fft_mask_dp_fft/{}".format(conf_file_filename)
conf_dict = read_conf(conf_file)
model_name = "model1"


# Visualize features before and after transformation

In [None]:
try:
    from importlib.metadata import *
except ImportError:  # Python < 3.10 (backport)
    from importlib_metadata import *
import umap.umap_ as umap
import json

In [None]:
# from train import get_device
import pickle
import torch
import torch.nn.functional as F
from tqdm import tqdm
from feature_extraction import TorchStandardScaler
import numpy as np
from file_loader_model_transform import read_feat_list,encode_phones
from pathlib import Path
import logging
from net import get_model_type
# import loss_fn
# Labels
from phone_mapping import get_label_encoder

In [None]:
def _colour_group_indices(phonemes_from, phoneme_to_number, colour_map):
    '''
    Map each phoneme number in phonemes_from to the index of its group in colour_map.

    The index is the position of the first group (in colour_map's key order) that contains the
    phoneme; phonemes in no group get index len(colour_map), i.e. the "other" group.
    Raises KeyError if a number in phonemes_from is not a value of phoneme_to_number.
    '''
    number_to_phoneme = {number: phoneme for phoneme, number in phoneme_to_number.items()}
    other_index = len(colour_map)
    group_index_of = {}
    for group_index, members in enumerate(colour_map.values()):
        for phoneme in members:
            # setdefault keeps the FIRST group that lists a phoneme
            group_index_of.setdefault(phoneme, group_index)
    return [group_index_of.get(number_to_phoneme[number], other_index) for number in phonemes_from]


def plot_umap(features, phonemes_from, phoneme_to_number, diagram_name, colour_map=None, n_feature_columns=65):
    '''
    Plot a 2-D UMAP embedding of feature vectors, coloured by phoneme number or phoneme group.

    This is a modification of plot_umap in ../../visualizations/plot_umap.py. The figure is saved
    to the first unused "<n>_.png" file (n = 1, 2, ...) in the current working directory and
    then shown.

    input:
        features (2D array-like): the features extracted using feature_extraction/extractTrainingData
            or feature_extraction/extractTestingData. Each row is one feature sample. NaN and +/-inf
            are replaced by 0 before embedding.
        phonemes_from (list or 1D np.array): the phoneme number each row of features corresponds to;
            must have exactly one entry per row.
        phoneme_to_number (dict): maps phoneme names to the numbers used in phonemes_from.
        diagram_name (str): title of the plot.
        colour_map (dict, optional): maps a phoneme-group label to a set/list of the phonemes in that
            group. For example, the manner-of-articulation grouping
            (from https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=4100697) is
            {"vowels": {"iy", "ih", "eh", ...}, "dipthongs": {"ey", "aw", ...}, "stops": {"b", "d", ...}, ...}.
            Phonemes that belong to no group are shown as "other". If omitted, points are coloured
            by phoneme number with a continuous colorbar.
        n_feature_columns (int): number of leading columns to embed. Defaults to 65 because, per
            kevin's feature extraction code, only the first 65 columns are the actual features and
            the rest are ideal mask values.

    raises:
        ValueError: if phonemes_from does not have one entry per row of features.
    '''
    features = np.nan_to_num(np.array(features, dtype=float), copy=True, posinf=0, neginf=0)
    features = features[:, :n_feature_columns]
    if len(phonemes_from) != len(features):
        raise ValueError(
            "phonemes_from has {} entries but features has {} rows; they must match".format(
                len(phonemes_from), len(features)))

    reducer = umap.UMAP(random_state=42, low_memory=True)
    embedding = reducer.fit_transform(features)

    # Smaller markers for larger point clouds so dense plots stay readable.
    n_points = len(phonemes_from)
    if n_points < 5000:
        marker_size = 10
    elif n_points < 100000:
        marker_size = 3
    else:
        marker_size = 1

    fig, ax = plt.subplots(figsize=(12, 10))

    if colour_map:
        group_labels = list(colour_map.keys()) + ["other"]
        n_groups = len(group_labels)
        colors = _colour_group_indices(phonemes_from, phoneme_to_number, colour_map)
        # Pin the colour scale to the group indices so a group keeps the same colour (and matches
        # its colorbar bin) even when some groups are absent from this sample.
        scatter = ax.scatter(embedding[:, 0], embedding[:, 1], c=colors, cmap="Spectral",
                             s=marker_size, vmin=-0.5, vmax=n_groups - 0.5)
        colorbar = fig.colorbar(scatter, ax=ax, boundaries=np.arange(n_groups + 1) - 0.5)
        colorbar.set_ticks(np.arange(n_groups))
        colorbar.set_ticklabels(group_labels)
    else:
        scatter = ax.scatter(embedding[:, 0], embedding[:, 1], c=phonemes_from, cmap="Spectral", s=marker_size)
        fig.colorbar(scatter, ax=ax).ax.tick_params(labelsize=10)

    ax.set(xticks=[], yticks=[])
    ax.set_title(diagram_name, fontsize=18)

    # Save under the first free "<n>_.png" name so earlier plots are never overwritten.
    output_image_num = 1
    while os.path.isfile(str(output_image_num) + "_.png"):
        output_image_num += 1
    fig.savefig(str(output_image_num) + "_.png")
    plt.show()


In [None]:
def transform_features(features, phonemes, phoneme_scales_map, phoneme_shifts_map):
    '''
    Apply each feature row's phoneme-specific affine transform: feature * scale + shift.

    input:
        features (sequence of np.array): the feature rows to transform.
        phonemes (sequence): phoneme label for each row; row i uses phonemes[i]. Entries beyond
            len(features) are ignored.
        phoneme_scales_map (dict): phoneme -> scale (scalar or array broadcastable to a row).
        phoneme_shifts_map (dict): phoneme -> shift (scalar or array broadcastable to a row).

    returns:
        np.array with one transformed row per row of features.

    raises:
        ValueError: if there are fewer phoneme labels than feature rows.
        KeyError: if a row's phoneme has no scale or shift entry.
    '''
    if len(phonemes) < len(features):
        raise ValueError(
            "got {} phoneme labels for {} feature rows".format(len(phonemes), len(features)))

    transformed_features = [
        (feature * phoneme_scales_map[phoneme]) + phoneme_shifts_map[phoneme]
        for feature, phoneme in zip(features, phonemes)
    ]
    return np.array(transformed_features)

In [None]:
# get testing data features files list
file_list = []

for test_set in test_sets:
    test_feat_list = "data/" + conf_dict["mask"] + "/" + test_set + "/" + conf_dict["feature_type"] + ".txt"
    file_list.extend(read_feat_list(test_feat_list))

# The feature lists contain paths on the SMB share as mounted via gvfs (presumably on Linux).
# On Windows that share is reached through the Y: drive, so rewrite the prefix there only;
# rewriting unconditionally would break the original paths on the machine they came from.
SMB_SHARE_PREFIX = '/run/user/1000/gvfs/smb-share:server=pse-fs-01.egr.duke.edu,share=lcollins-00/data/data/personal'
WINDOWS_SHARE_PREFIX = "Y:/personal"
if os.name == "nt":
    file_list = [x.replace(SMB_SHARE_PREFIX, WINDOWS_SHARE_PREFIX) for x in file_list]

In [None]:
# Label encoder for the configured phone set, plus the encoding of every class it knows.
bpg_name = conf_dict["bpg"]
le_bpg = get_label_encoder(bpg_name)
bpg_classes = le_bpg.classes_
phonemes_encoded = encode_phones(bpg_classes, le_bpg)

In [None]:
# colour_maps
# Phoneme groups by manner of articulation
# (https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=4100697).
# Key order matters: a group's position is its colour index in plot_umap.
manner_of_articulation_map = {
    "vowels": {"iy", "ih", "eh", "ae", "aa", "ah", "ao", "uh", "uw", "ux", "ax", "ax-h", "ix"},
    "dipthongs": {"ey", "aw", "ay", "oy", "ow"},
    "semi-vowels": {"l", "el", "r", "w", "y", "er", "axr"},
    "stops": {"b", "d", "g", "p", "t", "k", "jh", "ch"},
    "fricatives": {"s", "sh", "z", "zh", "f", "th", "v", "dh", "hh", "hv"},
    "nasals": {"m", "em", "n", "nx", "ng", "eng", "en"},
    "silence": {"dx", "bcl", "dcl", "gcl", "pcl", "tcl", "kcl", "h", "pau", "epi", "q"},
    "h#": {"h#"},
}

# Coarser grouping built from the one above: diphthongs fold into vowels, and stops,
# fricatives, nasals and silence fold into consonants. Every set is a fresh copy.
vowels_and_consonants_map = {
    "vowels": manner_of_articulation_map["vowels"] | manner_of_articulation_map["dipthongs"],
    "semi-vowels": set(manner_of_articulation_map["semi-vowels"]),
    "consonants": (manner_of_articulation_map["stops"]
                   | manner_of_articulation_map["fricatives"]
                   | manner_of_articulation_map["nasals"]
                   | manner_of_articulation_map["silence"]),
    "h#": set(manner_of_articulation_map["h#"]),
}

In [None]:
# run this if using ARPAbet phonemes (43) instead of timit phonemes (61)
# Reduce every colour group to the phonemes the label encoder actually knows, so groups only
# list phonemes that can occur in the data. Re-running this cell is harmless: filtering an
# already-filtered map leaves it unchanged.
def _restrict_to_known_phonemes(colour_map, known_phonemes):
    '''Return a new colour map whose groups keep only members found in known_phonemes (same key order).'''
    known = set(known_phonemes)
    return {phoneme_class: {phoneme for phoneme in phoneme_set if phoneme in known}
            for phoneme_class, phoneme_set in colour_map.items()}


vowels_and_consonants_map = _restrict_to_known_phonemes(vowels_and_consonants_map, le_bpg.classes_)
manner_of_articulation_map = _restrict_to_known_phonemes(manner_of_articulation_map, le_bpg.classes_)

In [None]:
# get the data
# Number of feature files to read; set to None to read every file in file_list.
n_files_to_read = 5  # small sample for quick testing

features = []
phonemes = []
for file in file_list[:n_files_to_read]:
    _features, _, _phonemes = read_feat_file(file, conf_dict)
    features.extend(_features)
    phonemes.extend(_phonemes)

features = np.array(features)
phonemes = np.array(phonemes)

# Number phonemes 1, 2, 3, ... in order of first appearance; phonemes_from[i] is row i's number.
# setdefault evaluates len(...) + 1 before inserting, so a new phoneme gets the next number.
phoneme_to_number = {}
phonemes_from = []
for phoneme in phonemes:
    phonemes_from.append(phoneme_to_number.setdefault(phoneme, len(phoneme_to_number) + 1))

In [None]:
# TODO: get and save the transformation scale and shift values to
# ../../saved_variables/model_scales_and_weights if not already done. Steps:
#   1. load the model.
#   2. run each phoneme encoding through model.scale_model and model.shift_model.
#   3. save the outputs with np.savetxt into that folder as {phoneme}_scale.txt
#      and {phoneme}_shift.txt (the names the loading cell looks for).
pass

In [None]:
# get the transformation scale and shift values if saved to ../../saved_variables/model_scales_and_weights
# Each "{phoneme}_scale.txt" in the directory must have a matching "{phoneme}_shift.txt".
dir_with_scale_and_shift_vectors = "../../saved_variables/model_scales_and_weights"
scale_vectors = {}
shift_vectors = {}
scale_suffix = "_scale.txt"
end_indx = -len(scale_suffix)

for filename in sorted(os.listdir(dir_with_scale_and_shift_vectors)):
    # Match the exact suffix: a plain `"scale" in filename` test also accepts unrelated files,
    # which would then be sliced into bogus phoneme names.
    if not filename.endswith(scale_suffix):
        continue
    phoneme = filename[:end_indx]
    scale_file = os.path.join(dir_with_scale_and_shift_vectors, filename)
    shift_file = os.path.join(dir_with_scale_and_shift_vectors, phoneme + "_shift.txt")
    scale_vectors[phoneme] = np.loadtxt(scale_file)
    shift_vectors[phoneme] = np.loadtxt(shift_file)

In [None]:
# get transformed features
sample_num = 500
# transform_features pairs row i with phonemes[i], so hand it the matching slice of labels.
transformed_features = transform_features(
    features[:sample_num],
    phonemes[:sample_num],
    scale_vectors,
    shift_vectors,
)

In [None]:
# plot umap of features
# Only the first sample_num rows are embedded, so slice the colour labels to the same rows.
plot_umap(features[:sample_num], phonemes_from[:sample_num], phoneme_to_number, "Phonemes Before Transformation", manner_of_articulation_map)

In [None]:
# plot umap of transformed features
# transformed_features covers only the first sample_num rows, so slice the colour labels to match.
plot_umap(transformed_features, phonemes_from[:sample_num], phoneme_to_number, "Phonemes After Transformation", manner_of_articulation_map)

# Show STOI and SRMR performance of model

In [None]:
def read_objective_intelligibility_results(filename):
    '''
    Read one score per line from a results text file.

    input:
        filename (str): path to a text file with one number per line.

    returns:
        list of float, in file order. Blank or whitespace-only lines (e.g. a trailing newline)
        are skipped instead of crashing float().

    raises:
        ValueError: if a non-blank line is not a number.
    '''
    # The `with` block closes the file; no explicit close() is needed.
    with open(filename, "r") as f:
        return [float(line) for line in f if line.strip()]

In [None]:
# Objective intelligibility metric to report: "srmr" or "stoi".
metric = "srmr"

# Y-axis label shown on the bar plot for each metric.
metric_to_barplot_label = dict(
    srmr="SRMR-CI Intelligibility Score",
    stoi="STOI Score",
)

In [None]:

objective_intelligibility_dp_scores = []
objective_intelligibility_ideal_scores = []
objective_intelligibility_simplified_phoneme_dependent_scores = []
objective_intelligibility_reverb_scores = []

# Per metric: an ordered list of (filename fragment, destination list). The FIRST fragment found
# in a result filename decides where that file's scores go.
# NOTE(review): for stoi, "*ideal.txt" feeds the direct-path list and "*ibm.txt" the ideal-mask
# list, unlike srmr. Confirm this mapping is intended.
result_file_routes = {
    "srmr": [
        ("_dp.txt", objective_intelligibility_dp_scores),
        ("estimated.txt", objective_intelligibility_simplified_phoneme_dependent_scores),
        ("reverb.txt", objective_intelligibility_reverb_scores),
        ("ideal.txt", objective_intelligibility_ideal_scores),
    ],
    "stoi": [
        ("ideal.txt", objective_intelligibility_dp_scores),
        ("estimated.txt", objective_intelligibility_simplified_phoneme_dependent_scores),
        ("reverb.txt", objective_intelligibility_reverb_scores),
        ("ibm.txt", objective_intelligibility_ideal_scores),
    ],
}

for test_set in test_sets:
    results_dir = "exp/" + conf_dict["mask"] + "/" + conf_file_filename[:-4] + "/" + model_name + "/results/" + test_set
    for filename in os.listdir(results_dir):
        file = results_dir + "/" + filename
        if metric not in filename:
            continue
        for name_fragment, destination in result_file_routes.get(metric, []):
            if name_fragment in filename:
                destination.extend(read_objective_intelligibility_results(file))
                break

In [None]:
# Sanity check (should display True): every result category holds the same number of scores.
len({
    len(objective_intelligibility_dp_scores),
    len(objective_intelligibility_ideal_scores),
    len(objective_intelligibility_simplified_phoneme_dependent_scores),
    len(objective_intelligibility_reverb_scores),
}) == 1

In [None]:
def average(list_of_values):
    '''
    Return the arithmetic mean of a non-empty collection of numbers.

    input:
        list_of_values (iterable of numbers): the values to average. Any iterable is accepted,
            including generators.

    returns:
        sum(values) / len(values).

    raises:
        ValueError: if list_of_values is empty (e.g. no result files matched), instead of an
            unexplained ZeroDivisionError.
    '''
    values = list(list_of_values)
    if not values:
        raise ValueError("average() needs at least one value; got an empty collection")
    return sum(values) / len(values)

In [None]:
# Mean score per condition. The keys (line breaks included) become the bar labels, in this order.
bars_to_values = {}
bars_to_values["REV"] = average(objective_intelligibility_reverb_scores)
bars_to_values["Simplified\nPhoneme-\nSpecific Model"] = average(objective_intelligibility_simplified_phoneme_dependent_scores)
bars_to_values["Ideal Mask"] = average(objective_intelligibility_ideal_scores)
bars_to_values["Direct Path"] = average(objective_intelligibility_dp_scores)

In [None]:
# get the outputs from kevin's phoneme independent models
objective_intelligibility_phoneme_independent_scores = []

for test_set in test_sets:
    results_dir = "exp/" + conf_dict["mask"] + "/LSTM_1layer_but_rev_log_fft_8kutts_batch16_sigapprox/model1/results/" + test_set
    for filename in os.listdir(results_dir):
        file = results_dir + "/" + filename
        if metric not in filename:
            continue
        # SRMR was recomputed with an updated method; only those "updated" files are used.
        if metric == "srmr" and "updated" not in filename:
            continue
        if "estimated.txt" in filename:
            objective_intelligibility_phoneme_independent_scores.extend(read_objective_intelligibility_results(file))

bars_to_values["Phoneme-\nIndependent\nModel"] = average(objective_intelligibility_phoneme_independent_scores)

In [None]:
# TODO: get the output from kevin's MoE model; this list is not populated yet.
objective_intelligibility_moe_phoneme_dependent_scores = list()

In [None]:
# Bar plot comparing the mean objective intelligibility score of each condition.
fig, ax = plt.subplots(figsize=(10, 5))

# Pass plain lists rather than dict views; the string labels become categorical x positions.
ax.bar(list(bars_to_values.keys()), list(bars_to_values.values()), color="maroon", width=0.4)

ax.set_ylabel(metric_to_barplot_label[metric])
ax.set_title("Mean " + metric_to_barplot_label[metric] + " by condition")
plt.show()