In [1]:
import torchaudio
import os
import matplotlib.pyplot as plt
import pandas as pd

In [3]:
metadata = pd.read_csv("../data/darpa_timit_dataset/train_data.csv")

In [4]:
metadata = metadata[~metadata["filename"].isnull()]
metadata = metadata[metadata["filename"].str.contains("WAV.wav")]

In [5]:
metadata

Unnamed: 0,index,test_or_train,dialect_region,speaker_id,filename,path_from_data_dir,path_from_data_dir_windows,is_converted_audio,is_audio,is_word_file,is_phonetic_file,is_sentence_file
0,1.0,TRAIN,DR4,MMDM0,SI681.WAV.wav,TRAIN/DR4/MMDM0/SI681.WAV.wav,TRAIN\\DR4\\MMDM0\\SI681.WAV.wav,True,True,False,False,False
6,7.0,TRAIN,DR4,MMDM0,SI1311.WAV.wav,TRAIN/DR4/MMDM0/SI1311.WAV.wav,TRAIN\\DR4\\MMDM0\\SI1311.WAV.wav,True,True,False,False,False
10,11.0,TRAIN,DR4,MMDM0,SX141.WAV.wav,TRAIN/DR4/MMDM0/SX141.WAV.wav,TRAIN\\DR4\\MMDM0\\SX141.WAV.wav,True,True,False,False,False
13,14.0,TRAIN,DR4,MMDM0,SX51.WAV.wav,TRAIN/DR4/MMDM0/SX51.WAV.wav,TRAIN\\DR4\\MMDM0\\SX51.WAV.wav,True,True,False,False,False
23,24.0,TRAIN,DR4,MMDM0,SX411.WAV.wav,TRAIN/DR4/MMDM0/SX411.WAV.wav,TRAIN\\DR4\\MMDM0\\SX411.WAV.wav,True,True,False,False,False
...,...,...,...,...,...,...,...,...,...,...,...,...
23077,23078.0,TRAIN,DR8,MRDM0,SX155.WAV.wav,TRAIN/DR8/MRDM0/SX155.WAV.wav,TRAIN\\DR8\\MRDM0\\SX155.WAV.wav,True,True,False,,False
23080,23081.0,TRAIN,DR8,MRDM0,SI965.WAV.wav,TRAIN/DR8/MRDM0/SI965.WAV.wav,TRAIN\\DR8\\MRDM0\\SI965.WAV.wav,True,True,False,,False
23088,23089.0,TRAIN,DR8,MRDM0,SA1.WAV.wav,TRAIN/DR8/MRDM0/SA1.WAV.wav,TRAIN\\DR8\\MRDM0\\SA1.WAV.wav,True,True,False,,False
23095,23096.0,TRAIN,DR8,MRDM0,SI1044.WAV.wav,TRAIN/DR8/MRDM0/SI1044.WAV.wav,TRAIN\\DR8\\MRDM0\\SI1044.WAV.wav,True,True,False,,False


In [9]:
def load_text(file_path):
    file_path = os.path.join("../data/darpa_timit_dataset/data", file_path)
    file_path = file_path.replace("WAV.wav", "WRD")
    with open(file_path, 'r') as f:
        text = f.read().strip()
        
    text = text.split('\n')
    text = [x.strip().split(' ', 2) for x in text if len(x.split(' ', 2)) == 3]
    return [int(x[0]) for x in text], [int(x[1]) for x in text], [x[2] for x in text]

In [10]:
def dur(x):
    start, end, _ = load_text(x)
    return -1 * (start[0] - end[-1])

metadata["durations"] = metadata["path_from_data_dir"].apply(dur)

In [11]:
metadata["sex"] = metadata.speaker_id.str[0]

In [12]:
list_obj = metadata.groupby(["sex", "dialect_region", "speaker_id"]).agg({
    "filename": list,
    "path_from_data_dir": list,
    "durations": "sum"
})

In [13]:
list_obj = list_obj.reset_index()

In [14]:
list_obj["count"] = list_obj["filename"].apply(len)
list_obj = list_obj.sort_values(by="count", ascending=False)

stats = list_obj[["sex", "dialect_region", "speaker_id", "durations"]]

In [15]:
stats.to_csv("timit_speaker_stats.csv", index=False)

In [16]:
speaker_id = set()

for s in ["M", "F"]:
    for dr in range(1, 9):
        dr = f"DR{dr}"
        subset = list_obj[(list_obj["sex"] == s) & (list_obj["dialect_region"] == dr)]
        subset = subset.sort_values(by="durations", ascending=False)
        speaker_id.add(subset.iloc[0, 2])

In [17]:
speaker_id

{'FAPB0',
 'FECD0',
 'FGRW0',
 'FJEN0',
 'FKLH0',
 'FLKM0',
 'FSCN0',
 'FSJG0',
 'MCAE0',
 'MJDE0',
 'MPMB0',
 'MRRE0',
 'MRSO0',
 'MRTJ0',
 'MTAB0',
 'MTQC0'}

In [18]:
full_data = metadata[metadata["speaker_id"].isin(speaker_id)].reset_index(drop=True)

In [19]:
full_data["unq_id"] = full_data["speaker_id"] + "_" + full_data["filename"].str.replace(".WAV.wav", "")

In [20]:
speaker_file = full_data.groupby("speaker_id")['unq_id'].apply(list).reset_index()


In [21]:
import random

train_ = []
val_ = []
for speaker in speaker_file.to_dict(orient='records'):
    speaker_id = speaker['speaker_id']
    filenames = speaker['unq_id']
    _t = random.sample(filenames, k=int(len(filenames)*0.8))
    train_.extend(_t)
    val_.extend([f for f in filenames if f not in _t])

In [22]:
len(train_), len(val_)

(128, 32)

In [23]:
val_

['FAPB0_SI1693',
 'FAPB0_SX343',
 'FECD0_SX158',
 'FECD0_SX338',
 'FGRW0_SI1782',
 'FGRW0_SX72',
 'FJEN0_SI1677',
 'FJEN0_SI2307',
 'FKLH0_SX87',
 'FKLH0_SX267',
 'FLKM0_SI686',
 'FLKM0_SX440',
 'FSCN0_SI705',
 'FSCN0_SA1',
 'FSJG0_SI2200',
 'FSJG0_SA1',
 'MCAE0_SX97',
 'MCAE0_SI2077',
 'MJDE0_SX310',
 'MJDE0_SA1',
 'MPMB0_SI1501',
 'MPMB0_SX151',
 'MRRE0_SA1',
 'MRRE0_SA2',
 'MRSO0_SX39',
 'MRSO0_SI1206',
 'MRTJ0_SI772',
 'MRTJ0_SX322',
 'MTAB0_SX222',
 'MTAB0_SA2',
 'MTQC0_SI2071',
 'MTQC0_SA1']

In [24]:
import random

val_data = full_data[full_data['unq_id'].isin(val_)].reset_index(drop=True)
train_data = full_data[full_data['unq_id'].isin(train_)].reset_index(drop=True)

In [25]:
val_data.shape, train_data.shape

((32, 15), (128, 15))

In [26]:
train_data.to_csv("../data/darpa_timit_dataset/train_data_top20.csv", index=False)
val_data.to_csv("../data/darpa_timit_dataset/val_data_top20.csv", index=False)

In [27]:
sinc_net_trained = pd.read_csv(
    "../data/darpa_timit_dataset/train_data_unseen.csv"
)

FileNotFoundError: [Errno 2] No such file or directory: '../data/darpa_timit_dataset/train_data_unseen.csv'

In [124]:
set1 = set(sinc_net_trained["speaker_id"])
set2 = set(train_data["speaker_id"])

In [125]:
set1.intersection(set2)

set()