# Phoneme Classifier Stability Train and Test
### Author: Cathal Ó Faoláin
### 15:51, 02/08/2024

The goal of this work is to understand how we can use predicted inner hair cell (IHC) potentials, such as those predicted by WavIHC, which was introduced in the paper "WaveNet-based approximation of a cochlear filtering and hair cell transduction model". Feature encoders designed to use these predicted IHC potentials are evaluated against other state-of-the-art feature encoders to understand how discriminating they are over a range of Signal-to-Noise Ratios (SNRs).

This notebook checks how stable the feature encoders are, and what kind of variation we should expect to see when we evaluate them. We have 9 feature encoders:

- Contrastive Predictive Coding (CPC) 
- Wav2vec2.0
- Autoregressive Predictive Coding (APC)
- IHC CPC
- IHC CPC 80
- IHC Wav2vec2
- IHC Wav2vec2 80
- IHC Extract
- IHC Extract 512 

The first three feature encoders, CPC, Wav2vec2.0 and APC, are based on the designs used in their respective papers. Any context encoders that try to model longer-term dependencies have been removed, so there are no transformers or Recurrent Neural Networks (RNNs). This allows us to evaluate how discriminating the features themselves are.

IHC CPC and IHC Wav2vec2 are adapted feature encoders that take predicted IHC potentials as input rather than the signal alone. Each is inspired by its namesake model.

## Imports

In [1]:
import torch
from torch import nn
import librosa
import time
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from torch.utils.data import DataLoader, Dataset, IterableDataset
import torchaudio
import pandas as pd
import numpy as np
import sys
import yaml
import math
import pathlib as Path
import scipy.signal as signal
from dataclasses import dataclass, field
from typing import List, Tuple
import torch.nn.functional as F
import pickle

In [2]:
sys.path.append('./IHCApproxNH/')
from classes import WaveNet
from utils import utils
from Encoders import FeatureEncoders 
from TIMIT_utils import TIMIT_utils
from Train_TestFunctions import TrainEvalFunctions

## Set Global Learning Settings

In [3]:
#Empty prior cuda cache
torch.cuda.empty_cache()

EPOCHS=100
learning_rate=0.01

#And save location 
dir_results=Path.Path('Results/k-Fold Stability')
dir_results.mkdir(parents=True, exist_ok=True)

## Combined Train and Test For Original Model Classifiers

In [4]:
original_models=["Wav2vec2", "Wav2vec2_80", "CPC_80", "CPC", "MelSimple",  "MelSimple_MLP", "SIG_Extract", "SIG_Extract_512", "SIG_Extract_2.0", "Whisper", "Whisper_80", "SIG_Extract_3.0"]

#Reload any old results so that we can continue training if required
if Path.Path('Results/k-Fold Stability/original_models.pkl').is_file():
    with open('Results/k-Fold Stability/original_models.pkl', 'rb') as f:
        test_accuracies = pickle.load(f)
        print(test_accuracies)
else:
    test_accuracies={}


for model in original_models:
    print("=============================")
    print("Starting k-Validation Testing on: %s" %model)
    for k in range(5):
        # Store the returned duration as `elapsed` so it does not shadow the imported `time` module
        test_accu, test_loss, unique_phonemes, elapsed = TrainEvalFunctions.train_epochs(model, EPOCHS, learning_rate=learning_rate, distributed=False, Kfold_eval=True, kInt=k)

        test_accuracies["{}-kFold-{}".format(model, k+1)]=test_accu
        #test_loss["{}-kFold-{}".format(model, k+1)]=np.average(test_loss)

        #When a model and kfold is trained, save it
        with open('Results/k-Fold Stability/original_models.pkl', 'wb') as f:
            pickle.dump(test_accuracies, f)
    
    print("==============================")
    print("")
    print("")

with open('Results/k-Fold Stability/original_models.pkl', 'wb') as f:
    pickle.dump(test_accuracies, f)

{'Wav2vec2-kFold-1': 0.5742884322943068, 'Wav2vec2-kFold-2': 0.5788991271621312, 'Wav2vec2-kFold-3': 0.5795726774666176, 'Wav2vec2-kFold-4': 0.5775589703707305, 'Wav2vec2-kFold-5': 0.5810655982446029, 'Wav2vec2_80-kFold-1': 0.5915646504135044, 'Wav2vec2_80-kFold-2': 0.5933353238943706, 'Wav2vec2_80-kFold-3': 0.5933353238943706, 'Wav2vec2_80-kFold-4': 0.5908633248387298, 'Wav2vec2_80-kFold-5': 0.588960718823995, 'CPC_80-kFold-1': 0.5955615221197296, 'CPC_80-kFold-2': 0.5973710979938965, 'CPC_80-kFold-3': 0.5944370246835837, 'CPC_80-kFold-4': 0.5948833636659309, 'CPC_80-kFold-5': 0.5950113833740459, 'CPC-kFold-1': 0.5846452470088368, 'CPC-kFold-2': 0.559650263997398, 'CPC-kFold-3': 0.5808842355839429, 'CPC-kFold-4': 0.5885446581181795, 'CPC-kFold-5': 0.5662623089219356, 'MelSimple-kFold-1': 0.4609932991450633, 'MelSimple-kFold-2': 0.4616209654335208, 'MelSimple-kFold-3': 0.4609691581339688, 'MelSimple-kFold-4': 0.46089673510068524, 'MelSimple-kFold-5': 0.46078292747695393, 'MelSimple_MLP

In [5]:
with open('Results/k-Fold Stability/original_models.pkl', 'rb') as f:
    loaded_test = pickle.load(f)

print(loaded_test)
#print(test_accuracies)

{'Wav2vec2-kFold-1': 0.5742884322943068, 'Wav2vec2-kFold-2': 0.5788991271621312, 'Wav2vec2-kFold-3': 0.5795726774666176, 'Wav2vec2-kFold-4': 0.5775589703707305, 'Wav2vec2-kFold-5': 0.5810655982446029, 'Wav2vec2_80-kFold-1': 0.5915646504135044, 'Wav2vec2_80-kFold-2': 0.5933353238943706, 'Wav2vec2_80-kFold-3': 0.5933353238943706, 'Wav2vec2_80-kFold-4': 0.5908633248387298, 'Wav2vec2_80-kFold-5': 0.588960718823995, 'CPC_80-kFold-1': 0.5955615221197296, 'CPC_80-kFold-2': 0.5973710979938965, 'CPC_80-kFold-3': 0.5944370246835837, 'CPC_80-kFold-4': 0.5948833636659309, 'CPC_80-kFold-5': 0.5950113833740459, 'CPC-kFold-1': 0.5846452470088368, 'CPC-kFold-2': 0.559650263997398, 'CPC-kFold-3': 0.5808842355839429, 'CPC-kFold-4': 0.5885446581181795, 'CPC-kFold-5': 0.5662623089219356, 'MelSimple-kFold-1': 0.4609932991450633, 'MelSimple-kFold-2': 0.4616209654335208, 'MelSimple-kFold-3': 0.4609691581339688, 'MelSimple-kFold-4': 0.46089673510068524, 'MelSimple-kFold-5': 0.46078292747695393, 'MelSimple_MLP
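
The training cell above reloads earlier results, but its loop still retrains every model and fold. A minimal sketch of how already completed folds could be skipped is given here. It is an illustration only: `run_kfold` and `train_fold` are hypothetical names, with `train_fold` standing in for a call such as `TrainEvalFunctions.train_epochs` that returns a single test accuracy.

```python
import pickle
from pathlib import Path


def run_kfold(models, results_path, train_fold, n_folds=5):
    """Train each (model, fold) pair once, skipping pairs already saved.

    `train_fold(model, k)` is a hypothetical callable that trains one fold
    and returns its test accuracy.
    """
    results_path = Path(results_path)

    # Reload any earlier results so an interrupted run can continue
    results = {}
    if results_path.is_file():
        with open(results_path, "rb") as f:
            results = pickle.load(f)

    for model in models:
        for k in range(n_folds):
            key = "{}-kFold-{}".format(model, k + 1)
            if key in results:
                # This fold was finished in a previous run
                continue
            results[key] = train_fold(model, k)
            # Save after every fold so progress survives a crash
            with open(results_path, "wb") as f:
                pickle.dump(results, f)

    return results
```

With this pattern, re-running the cell after an interruption only trains the folds whose keys are missing from the pickle.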

## Combined Train and Test IHC Feature Encoder Models

In [6]:
IHC_models=["IHC_Cpc", "IHC_Cpc_80","IHC_Wav2vec2_80", "IHC_Wav2vec2","IHC_Extract", "IHC_Extract_512", "IHC_Extract_2.0", "IHC_Extract_3.0"]

#Reload any old results so that we can continue training if required
if Path.Path('Results/k-Fold Stability/IHC_models.pkl').is_file():
    with open('Results/k-Fold Stability/IHC_models.pkl', 'rb') as f:
        IHC_test_accuracies = pickle.load(f)
        print(IHC_test_accuracies)
else:
    IHC_test_accuracies={}


for model in IHC_models:
    print("=============================")
    print("Starting k-Validation Testing on: %s" %model)
    for k in range(5):
        # Store the returned duration as `elapsed` so it does not shadow the imported `time` module
        test_accu, test_loss, unique_phonemes, elapsed = TrainEvalFunctions.train_epochs(model, EPOCHS, learning_rate=learning_rate, distributed=False, Kfold_eval=True, kInt=k)

        IHC_test_accuracies["{}-kFold-{}".format(model, k+1)]=test_accu
        #Update the results pickle for each one
        with open('Results/k-Fold Stability/IHC_models.pkl', 'wb') as f:
            pickle.dump(IHC_test_accuracies, f)

    print("==============================")
    print("")
    print("")
    
#Update the results with the final results
with open('Results/k-Fold Stability/IHC_models.pkl', 'wb') as f:
    pickle.dump(IHC_test_accuracies, f)

{'IHC_Extract_2.0-kFold-1': 0.63725447561064, 'IHC_Extract_2.0-kFold-2': 0.6341740958179314, 'IHC_Extract_2.0-kFold-3': 0.6381532217148655, 'IHC_Extract_2.0-kFold-4': 0.6319924621294485, 'IHC_Extract_2.0-kFold-5': 0.6350945857795173, 'IHC_Cpc-kFold-1': 0.581603284715025, 'IHC_Cpc-kFold-2': 0.5776707268190338, 'IHC_Cpc-kFold-3': 0.5741317858290691, 'IHC_Cpc-kFold-4': 0.574503735749443, 'IHC_Cpc-kFold-5': 0.5780751772178868, 'IHC_Cpc_80-kFold-1': 0.5780137874252037, 'IHC_Cpc_80-kFold-2': 0.5777501724330941, 'IHC_Cpc_80-kFold-3': 0.5782737912530379, 'IHC_Cpc_80-kFold-4': 0.5794257526569141, 'IHC_Cpc_80-kFold-5': 0.5755004170894739, 'IHC_Wav2vec2_80-kFold-1': 0.5954410379067914, 'IHC_Wav2vec2_80-kFold-2': 0.5963470319634703, 'IHC_Wav2vec2_80-kFold-3': 0.5955932449083134, 'IHC_Wav2vec2_80-kFold-4': 0.5955135174313256, 'IHC_Wav2vec2_80-kFold-5': 0.6000362397622672, 'IHC_Wav2vec2-kFold-1': 0.5890700877002247, 'IHC_Wav2vec2-kFold-2': 0.525592520113068, 'IHC_Wav2vec2-kFold-3': 0.583315213452199

In [7]:
with open('Results/k-Fold Stability/IHC_models.pkl', 'rb') as f:
    loaded_test = pickle.load(f)

print(loaded_test)

{'IHC_Extract_2.0-kFold-1': 0.63725447561064, 'IHC_Extract_2.0-kFold-2': 0.6341740958179314, 'IHC_Extract_2.0-kFold-3': 0.6381532217148655, 'IHC_Extract_2.0-kFold-4': 0.6319924621294485, 'IHC_Extract_2.0-kFold-5': 0.6350945857795173, 'IHC_Cpc-kFold-1': 0.581603284715025, 'IHC_Cpc-kFold-2': 0.5776707268190338, 'IHC_Cpc-kFold-3': 0.5741317858290691, 'IHC_Cpc-kFold-4': 0.574503735749443, 'IHC_Cpc-kFold-5': 0.5780751772178868, 'IHC_Cpc_80-kFold-1': 0.5780137874252037, 'IHC_Cpc_80-kFold-2': 0.5777501724330941, 'IHC_Cpc_80-kFold-3': 0.5782737912530379, 'IHC_Cpc_80-kFold-4': 0.5794257526569141, 'IHC_Cpc_80-kFold-5': 0.5755004170894739, 'IHC_Wav2vec2_80-kFold-1': 0.5954410379067914, 'IHC_Wav2vec2_80-kFold-2': 0.5963470319634703, 'IHC_Wav2vec2_80-kFold-3': 0.5955932449083134, 'IHC_Wav2vec2_80-kFold-4': 0.5955135174313256, 'IHC_Wav2vec2_80-kFold-5': 0.6000362397622672, 'IHC_Wav2vec2-kFold-1': 0.5890700877002247, 'IHC_Wav2vec2-kFold-2': 0.525592520113068, 'IHC_Wav2vec2-kFold-3': 0.583315213452199

## K-Fold Stability Test 
This section is for testing alone. Training takes by far the majority of the time, so this loads the already-trained k-fold models and tests them on the clean TIMIT test dataset, which allows rapid re-testing of saved models. The test dataset is the same for every fold, but each k-fold model was optimised to perform best on a different validation set, so performance can differ from fold to fold. This should show how robust the models are to different training data and regimes. Ideally, stability should be high, with a small standard deviation in performance across folds.
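
The spread across folds can be summarised per model as a mean and a sample standard deviation. The sketch below assumes the results dictionary uses the `"<model>-kFold-<k>"` key format seen in the saved pickles; `summarise_folds` is a hypothetical helper, not part of `TrainEvalFunctions`. The example values are the Wav2vec2 and CPC accuracies printed earlier in this notebook.

```python
from collections import defaultdict
from statistics import mean, stdev


def summarise_folds(accuracies):
    """Group '<model>-kFold-<k>' entries by model.

    Returns {model: (mean accuracy, sample standard deviation, number of folds)}.
    """
    per_model = defaultdict(list)
    for key, accuracy in accuracies.items():
        # Split on the last '-kFold-' so model names containing '-' or '.' stay intact
        model, _, _fold = key.rpartition("-kFold-")
        per_model[model].append(accuracy)
    return {
        model: (mean(values), stdev(values) if len(values) > 1 else 0.0, len(values))
        for model, values in per_model.items()
    }


# Accuracies copied from the results printed earlier in this notebook
example = {
    "Wav2vec2-kFold-1": 0.5742884322943068,
    "Wav2vec2-kFold-2": 0.5788991271621312,
    "Wav2vec2-kFold-3": 0.5795726774666176,
    "Wav2vec2-kFold-4": 0.5775589703707305,
    "Wav2vec2-kFold-5": 0.5810655982446029,
    "CPC-kFold-1": 0.5846452470088368,
    "CPC-kFold-2": 0.559650263997398,
    "CPC-kFold-3": 0.5808842355839429,
    "CPC-kFold-4": 0.5885446581181795,
    "CPC-kFold-5": 0.5662623089219356,
}

summary = summarise_folds(example)
for model, (avg, std, n) in summary.items():
    print(f"{model}: mean={avg:.4f}, std={std:.4f}, folds={n}")
```

A lower standard deviation indicates an encoder whose test accuracy depends less on which validation fold was used for model selection.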

In [8]:
# Define the list that this cell iterates over
IHC_models=["IHC_Cpc", "IHC_Cpc_80","IHC_Wav2vec2_80", "IHC_Wav2vec2", "IHC_Extract", "IHC_Extract_512", "IHC_Extract_2.0", "IHC_Extract_3.0"]

for model in IHC_models:
    print("=============================")
    print("Starting k-Validation Testing on: %s" %model)
    for k in range(5):
        # Store the returned duration as `elapsed` so it does not shadow the imported `time` module
        test_accu, test_loss, unique_phonemes, elapsed = TrainEvalFunctions.test_best(model, distributed=False, Kfold_eval=True, kInt=k)

        IHC_test_accuracies["{}-kFold-{}".format(model, k+1)]=test_accu
        #Update the results pickle for each one
        with open('Results/k-Fold Stability/IHC_models.pkl', 'wb') as f:
            pickle.dump(IHC_test_accuracies, f)

    print("==============================")
    print("")
    print("")
    
#Update the results with the final results
with open('Results/k-Fold Stability/IHC_models.pkl', 'wb') as f:
    pickle.dump(IHC_test_accuracies, f)

Starting k-Validation Testing on: IHC_Extract_3.0
> Initialising model: IHC_Extract_3.0
**************************************************************
Testing Best Model found at: Model Checkpoints/IHC_Extract_3.0 Checkpoints/kFold Eval 0/best_IHC_Extract_3.0_checkpoint.pth.tar
**************************************************************
> Initialising model: IHC_Extract_3.0
> Setting: Test Mode
Loading model checkpoint
+---------------------------------------------+
Using Test Data. Test Accuracy :
Testing For: | Batchsize 4 | Steps: 237
Evaluation accuracy:  0.5205, Phoneme Error Rate:  0.4795, Loss :  1.2447, Time:  122.8010s, Time per sample:  0.5181s
+---------------------------------------------+
> Initialising model: IHC_Extract_3.0
**************************************************************
Testing Best Model found at: Model Checkpoints/IHC_Extract_3.0 Checkpoints/kFold Eval 1/best_IHC_Extract_3.0_checkpoint.pth.tar
*********************************************************

In [9]:
# Define the list that this cell iterates over
original_models=["Wav2vec2", "Wav2vec2_80", "CPC_80", "CPC", "MelSimple", "MelSimple_MLP",  "SIG_Extract", "SIG_Extract_512", "Whisper", "Whisper_80", "SIG_Extract_3.0"]

for model in original_models:
    print("=============================")
    print("Starting k-Validation Testing on: %s" %model)
    for k in range(5):
        # Store the returned duration as `elapsed` so it does not shadow the imported `time` module
        test_accu, test_loss, unique_phonemes, elapsed = TrainEvalFunctions.test_best(model, distributed=False, Kfold_eval=True, kInt=k)

        test_accuracies["{}-kFold-{}".format(model, k+1)]=test_accu
        #test_loss["{}-kFold-{}".format(model, k+1)]=np.average(test_loss)

    print("==============================")
    print("")
    print("")

with open('Results/k-Fold Stability/original_models.pkl', 'wb') as f:
    pickle.dump(test_accuracies, f)

Starting k-Validation Testing on: Wav2vec2
> Initialising model: Wav2Vec2.0
**************************************************************
Testing Best Model found at: Model Checkpoints/Wav2Vec2.0_Encoder Checkpoints/kFold Eval 0/best_Wav2Vec2.0_Encoder_checkpoint.pth.tar
**************************************************************
> Initialising model: Wav2Vec2.0
> Setting: Test Mode
Loading model checkpoint
+---------------------------------------------+
Using Test Data. Test Accuracy :
Testing For: | Batchsize 4 | Steps: 237
Evaluation accuracy:  0.5743, Phoneme Error Rate:  0.4257, Loss :  1.1171, Time:  34.8878s, Time per sample:  0.1472s
+---------------------------------------------+
> Initialising model: Wav2Vec2.0
**************************************************************
Testing Best Model found at: Model Checkpoints/Wav2Vec2.0_Encoder Checkpoints/kFold Eval 1/best_Wav2Vec2.0_Encoder_checkpoint.pth.tar
**************************************************************
> Ini