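This notebook finds likely mislabelled data (samples whose predicted consonant disagrees with their label), lets you listen to each one, and then lets you relabel it yourself. The output file (`mislabelled_data_phase2.txt`) is used in `embed_wav.ipynb` to remove all unwanted files.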

In [2]:
import tensorflow as tf
from tensorflow.keras import layers
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
from glob import glob
from cv import RepeatedStratifiedGroupKFold
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, plot_confusion_matrix, ConfusionMatrixDisplay
import s3fs

from phonemes import get_phoneme_prediction
from embed import embed_from_local, embed_from_s3
import operator
from utils import *

from IPython.display import Audio

INFO:tensorflow:Saver not created because there are no variables in the graph to restore
Embedding model loaded, embedding shape: (None, None, 1, 96)


In [3]:
# Connect to S3 via s3fs.
# No credentials are passed here: AWS creds must first be entered in cmd/terminal.
# NOTE(review): `s3` is not referenced by any later cell shown here — confirm it is still needed.
s3 = s3fs.S3FileSystem()

In [4]:
# Consonants kept for classification; the others were excluded because of small sample sizes.
consonant_list = list('bpkgmntd')
NUM_CLASSES = len(consonant_list)       # 8 consonants
CONTEXT_SIZE = 2 * NUM_CLASSES          # number of dims to extract from the embeddings
RATE = 16000                            # sample rate passed to Audio() when playing clips back

In [5]:
# Load the flattened ("ravel") embeddings: unaltered recordings first, then the augmented copies.
org_path, aug_path = 'embeddings_org_ravel.txt', 'embeddings_aug_ravel.txt'
org_df, aug_df = (pd.read_csv(path) for path in (org_path, aug_path))

In [6]:
# Stack the original and augmented rows into one frame with a fresh 0..n-1 index
# (ignore_index=True). Later cells slice the embedding columns by position and
# look rows up by this index, so both inputs must carry the same columns —
# otherwise concat silently fills the missing ones with NaN.
assert set(org_df.columns) == set(aug_df.columns), 'org/aug embedding files have different columns'
all_df = pd.concat([org_df, aug_df], ignore_index=True)
all_df.shape

(5535, 1539)

In [19]:
# Run the phoneme classifier on every stored embedding and keep its top-scoring
# consonant per row. Each row stores a flattened embedding in its first
# EMB_FRAMES * EMB_DIM = 1536 columns; 96 matches the last axis of the
# (None, None, 1, 96) embedding shape reported when the embedding model loads.
EMB_FRAMES, EMB_DIM = 16, 96
n_emb_cols = EMB_FRAMES * EMB_DIM

# Convert the embedding columns to a float32 matrix once, instead of slicing a
# mixed-dtype row with .iloc (and converting it) on every iteration.
embeddings = all_df.iloc[:, :n_emb_cols].to_numpy(dtype=np.float32)

preds_list = []
nrows = all_df.shape[0]
for idx in range(nrows):
    wav_m = embeddings[idx].reshape(1, EMB_FRAMES, 1, EMB_DIM)
    wav_m_tens = tf.convert_to_tensor(wav_m, np.float32)
    ph_pred = get_phoneme_prediction(wav_m_tens)
    # get the consonant with the highest score as a str
    max_con = max(ph_pred.items(), key=operator.itemgetter(1))[0]
    preds_list.append(max_con)

KeyboardInterrupt: 

In [15]:
# Line up each prediction with the label of the same row and keep the rows where
# they disagree. Only the first len(preds_list) rows have a prediction (the loop
# above may have been stopped early), so compare just those rows: aligning on the
# full index would fill the missing predictions with NaN and report every
# unpredicted row as a mismatch. all_df has a 0..n-1 index (ignore_index=True),
# so the positional slice aligns with the prediction Series' index.
n_predicted = len(preds_list)
preds_vs_actuals_df = pd.concat(
    [pd.Series(preds_list, dtype=object), all_df['consonants'].iloc[:n_predicted]],
    axis=1,
)
# Column 0 holds the prediction, 'consonants' the original label.
mismatches_df = preds_vs_actuals_df[preds_vs_actuals_df[0] != preds_vs_actuals_df['consonants']]

In [16]:
# Walk through every mismatch: play the clip and ask whether the stored label is
# right. Every rejected label is recorded as (filepath, prev, correction).
records = []
num_mismatches = mismatches_df.shape[0]

try:
    for count, (idx, row) in enumerate(mismatches_df.iterrows(), start=1):
        # Report the position within the mismatches; idx is the row label in all_df.
        print(f'File {count}/{num_mismatches}')
        # mismatches_df is indexed like all_df (original + augmented rows), so the
        # path must come from all_df — org_df only covers the first len(org_df) rows.
        # NOTE(review): assumes the augmented rows also carry a 'filepath' value.
        filepath = all_df.at[idx, 'filepath']
        wav = load_wav(filepath)
        display(Audio(wav, rate=RATE))
        is_correct = input(f"Is this a {row['consonants']} (y/n)?")
        # Tolerate ' N', 'n ' etc. typed at the prompt.
        if is_correct.strip().lower() == 'n':
            consonant = input('What consonant is this?').strip()
            records.append([filepath, row['consonants'], consonant])
finally:
    # Build the frame once (DataFrame.append no longer exists in pandas 2), and do it
    # even when the session is interrupted, so corrections made so far can be saved.
    # Passing columns= keeps the named columns when there are no corrections at all.
    mislabel_df = pd.DataFrame(records, columns=['filepath', 'prev', 'correction'])

Is this a d (y/n)? y


Is this a t (y/n)? y


Is this a t (y/n)? y


Is this a t (y/n)? y


Is this a b (y/n)? y


Is this a b (y/n)? y


KeyboardInterrupt: Interrupted by user

In [None]:
# Save the manual corrections; embed_wav.ipynb uses this file to remove unwanted files.
mislabel_path = 'mislabelled_data_phase2.txt'
mislabel_df.to_csv(mislabel_path, index=False)