In [178]:
import IPython.display as ipd
import librosa
import torch
import torchaudio
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import (
    AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC
)
%matplotlib inline

In [179]:
DEVICE = "cuda"

In [180]:
import pandas as pd
import os

train_data = pd.read_csv("train_data.csv").dropna()
test_data = pd.read_csv("test_data.csv").dropna()


In [181]:
train_data.head()

Unnamed: 0,index,test_or_train,dialect_region,speaker_id,filename,path_from_data_dir,path_from_data_dir_windows,is_converted_audio,is_audio,is_word_file,is_phonetic_file,is_sentence_file
0,1.0,TRAIN,DR4,MMDM0,SI681.WAV.wav,TRAIN/DR4/MMDM0/SI681.WAV.wav,TRAIN\\DR4\\MMDM0\\SI681.WAV.wav,True,True,False,False,False
1,2.0,TRAIN,DR4,MMDM0,SI1311.PHN,TRAIN/DR4/MMDM0/SI1311.PHN,TRAIN\\DR4\\MMDM0\\SI1311.PHN,False,False,False,True,False
2,3.0,TRAIN,DR4,MMDM0,SI1311.WRD,TRAIN/DR4/MMDM0/SI1311.WRD,TRAIN\\DR4\\MMDM0\\SI1311.WRD,False,False,True,False,False
3,4.0,TRAIN,DR4,MMDM0,SX321.PHN,TRAIN/DR4/MMDM0/SX321.PHN,TRAIN\\DR4\\MMDM0\\SX321.PHN,False,False,False,True,False
4,5.0,TRAIN,DR4,MMDM0,SX321.WRD,TRAIN/DR4/MMDM0/SX321.WRD,TRAIN\\DR4\\MMDM0\\SX321.WRD,False,False,True,False,False


In [182]:
data_path = 'data/'

def combine_files(df):
    data = {}

    for idx, row in tqdm(df.iterrows()):
        path = row['path_from_data_dir']
        entry_id = path.split('.')[0]

        if entry_id not in data:
            data[entry_id] = {}

        if row['is_audio'] is True:
            data[entry_id]['audio_file'] = os.path.join(data_path, path)
        elif row['is_word_file'] is True:
            data[entry_id]['word_file'] = os.path.join(data_path, path)
        elif row['is_phonetic_file'] is True:
            data[entry_id]['phonetic_file'] = os.path.join(data_path, path)
    return data

In [183]:
combined_train_data = combine_files(train_data)
combined_test_data = combine_files(test_data)


0it [00:00, ?it/s][A
8400it [00:00, 58224.86it/s][A

0it [00:00, ?it/s][A
8400it [00:00, 59616.79it/s][A


In [184]:
from sklearn.model_selection import train_test_split
import pandas as pd

def dicts_to_splits(train_dict: dict, test_dict: dict, val_ratio: float = 0.5, seed: int = 42):
    train_df = pd.DataFrame.from_dict(train_dict, orient="index")
    test_full_df = pd.DataFrame.from_dict(test_dict, orient="index")
    val_df, test_df = train_test_split(test_full_df, test_size=val_ratio, random_state=seed, shuffle=True)
    return train_df.reset_index(drop=True), val_df.reset_index(drop=True), test_df.reset_index(drop=True)

train_df, val_df, test_df = dicts_to_splits(combined_train_data, combined_test_data)


In [185]:
train_df.head()

Unnamed: 0,audio_file,phonetic_file,word_file
0,data/TRAIN/DR4/MMDM0/SI681.WAV,data/TRAIN/DR4/MMDM0/SI681.PHN,data/TRAIN/DR4/MMDM0/SI681.WRD
1,data/TRAIN/DR4/MMDM0/SI1311.WAV,data/TRAIN/DR4/MMDM0/SI1311.PHN,data/TRAIN/DR4/MMDM0/SI1311.WRD
2,data/TRAIN/DR4/MMDM0/SX321.WAV,data/TRAIN/DR4/MMDM0/SX321.PHN,data/TRAIN/DR4/MMDM0/SX321.WRD
3,data/TRAIN/DR4/MMDM0/SX51.WAV,data/TRAIN/DR4/MMDM0/SX51.PHN,data/TRAIN/DR4/MMDM0/SX51.WRD
4,data/TRAIN/DR4/MMDM0/SX231.WAV.wav,data/TRAIN/DR4/MMDM0/SX231.PHN,data/TRAIN/DR4/MMDM0/SX231.WRD


In [186]:
phon61_map39 = {
    'iy':'iy',  'ih':'ih',   'eh':'eh',  'ae':'ae',    'ix':'ih',  'ax':'ah',   'ah':'ah',  'uw':'uw',
    'ux':'uw',  'uh':'uh',   'ao':'aa',  'aa':'aa',    'ey':'ey',  'ay':'ay',   'oy':'oy',  'aw':'aw',
    'ow':'ow',  'l':'l',     'el':'l',  'r':'r',      'y':'y',    'w':'w',     'er':'er',  'axr':'er',
    'm':'m',    'em':'m',     'n':'n',    'nx':'n',     'en':'n',  'ng':'ng',   'eng':'ng', 'ch':'ch',
    'jh':'jh',  'dh':'dh',   'b':'b',    'd':'d',      'dx':'dx',  'g':'g',     'p':'p',    't':'t',
    'k':'k',    'z':'z',     'zh':'sh',  'v':'v',      'f':'f',    'th':'th',   's':'s',    'sh':'sh',
    'hh':'hh',  'hv':'hh',   'pcl':'h#', 'tcl':'h#', 'kcl':'h#', 'qcl':'h#','bcl':'h#','dcl':'h#',
    'gcl':'h#','h#':'h#',  '#h':'h#',  'pau':'h#', 'epi': 'h#','nx':'n',   'ax-h':'ah','q':'h#' 
}

In [187]:
from tqdm import tqdm


def extract_file_data(df):
    data = []
    for i, row in tqdm(df.iterrows(), total=len(df)):
        try:
            waveform, _ = torchaudio.load(row["audio_file"])

            with open(row["phonetic_file"]) as f:
                phonetic_transcription = " ".join(
                    phon61_map39[line.split()[2].strip()] for line in f if line.strip()
                )

            with open(row["word_file"]) as f:
                word_transcription = " ".join(
                    line.split()[2].strip() for line in f if line.strip()
                )

            data.append([waveform, phonetic_transcription, word_transcription])
            print(f"[{i}] Success: {row['audio_file']}")
        except Exception as e:
            print(f"[{i}] Error: {row['audio_file']} — {e}")
    return data


In [188]:
train_data = extract_file_data(train_df)


  0%|          | 0/1680 [00:00<?, ?it/s][A

[0] Success: data/TRAIN/DR4/MMDM0/SI681.WAV
[1] Success: data/TRAIN/DR4/MMDM0/SI1311.WAV
[2] Success: data/TRAIN/DR4/MMDM0/SX321.WAV
[3] Success: data/TRAIN/DR4/MMDM0/SX51.WAV
[4] Success: data/TRAIN/DR4/MMDM0/SX231.WAV.wav
[5] Success: data/TRAIN/DR4/MMDM0/SX141.WAV
[6] Success: data/TRAIN/DR4/MMDM0/SI1941.WAV
[7] Success: data/TRAIN/DR4/MMDM0/SA1.WAV.wav
[8] Success: data/TRAIN/DR4/MMDM0/SA2.WAV.wav
[9] Success: data/TRAIN/DR4/MMDM0/SX411.WAV.wav
[10] Success: data/TRAIN/DR4/MCSS0/SX30.WAV
[11] Success: data/TRAIN/DR4/MCSS0/SI750.WAV
[12] Success: data/TRAIN/DR4/MCSS0/SI688.WAV
[13] Success: data/TRAIN/DR4/MCSS0/SX120.WAV
[14] Success: data/TRAIN/DR4/MCSS0/SX300.WAV
[15] Success: data/TRAIN/DR4/MCSS0/SX210.WAV.wav
[16] Success: data/TRAIN/DR4/MCSS0/SX390.WAV
[17] Success: data/TRAIN/DR4/MCSS0/SI1380.WAV.wav
[18] Success: data/TRAIN/DR4/MCSS0/SA1.WAV.wav
[19] Success: data/TRAIN/DR4/MCSS0/SA2.WAV.wav
[20] Success: data/TRAIN/DR4/MCDR0/SX164.WAV
[21] Success: data/TRAIN/DR4/MCDR0/SI524


 13%|█▎        | 217/1680 [00:00<00:00, 2154.61it/s][A

[160] Success: data/TRAIN/DR4/MJRH0/SI1125.WAV.wav
[161] Success: data/TRAIN/DR4/MJRH0/SX315.WAV.wav
[162] Success: data/TRAIN/DR4/MJRH0/SX225.WAV
[163] Success: data/TRAIN/DR4/MJRH0/SX135.WAV
[164] Success: data/TRAIN/DR4/MJRH0/SI1840.WAV
[165] Success: data/TRAIN/DR4/MJRH0/SX45.WAV.wav
[166] Success: data/TRAIN/DR4/MJRH0/SI1755.WAV.wav
[167] Success: data/TRAIN/DR4/MJRH0/SA1.WAV.wav
[168] Success: data/TRAIN/DR4/MJRH0/SA2.WAV.wav
[169] Success: data/TRAIN/DR4/MJRH0/SX405.WAV.wav
[170] Success: data/TRAIN/DR4/MLSH0/SX247.WAV
[171] Success: data/TRAIN/DR4/MLSH0/SX337.WAV
[172] Success: data/TRAIN/DR4/MLSH0/SI1417.WAV.wav
[173] Success: data/TRAIN/DR4/MLSH0/SX67.WAV
[174] Success: data/TRAIN/DR4/MLSH0/SI787.WAV.wav
[175] Success: data/TRAIN/DR4/MLSH0/SX427.WAV
[176] Success: data/TRAIN/DR4/MLSH0/SA1.WAV.wav
[177] Success: data/TRAIN/DR4/MLSH0/SA2.WAV.wav
[178] Success: data/TRAIN/DR4/MLSH0/SI2047.WAV
[179] Success: data/TRAIN/DR4/MLSH0/SX157.WAV.wav
[180] Success: data/TRAIN/DR4/MTAS0/S


 26%|██▌       | 436/1680 [00:00<00:00, 2169.70it/s][A

[433] Success: data/TRAIN/DR4/MBMA0/SX412.WAV
[434] Success: data/TRAIN/DR4/MBMA0/SX322.WAV
[435] Success: data/TRAIN/DR4/MBMA0/SI592.WAV.wav
[436] Success: data/TRAIN/DR4/MBMA0/SX232.WAV.wav
[437] Success: data/TRAIN/DR4/MBMA0/SA1.WAV.wav
[438] Success: data/TRAIN/DR4/MBMA0/SA2.WAV.wav
[439] Success: data/TRAIN/DR4/MBMA0/SI1222.WAV.wav
[440] Success: data/TRAIN/DR4/MRGM0/SI532.WAV.wav
[441] Success: data/TRAIN/DR4/MRGM0/SX442.WAV.wav
[442] Success: data/TRAIN/DR4/MRGM0/SX262.WAV.wav
[443] Success: data/TRAIN/DR4/MRGM0/SX172.WAV.wav
[444] Success: data/TRAIN/DR4/MRGM0/SX416.WAV.wav
[445] Success: data/TRAIN/DR4/MRGM0/SI1792.WAV
[446] Success: data/TRAIN/DR4/MRGM0/SA1.WAV.wav
[447] Success: data/TRAIN/DR4/MRGM0/SA2.WAV.wav
[448] Success: data/TRAIN/DR4/MRGM0/SX82.WAV
[449] Success: data/TRAIN/DR4/MRGM0/SI1162.WAV
[450] Success: data/TRAIN/DR4/FDKN0/SX271.WAV.wav
[451] Success: data/TRAIN/DR4/FDKN0/SI1711.WAV
[452] Success: data/TRAIN/DR4/FDKN0/SI1081.WAV.wav
[453] Success: data/TRAIN/DR


 39%|███▉      | 660/1680 [00:00<00:00, 2195.98it/s][A

[607] Success: data/TRAIN/DR4/MLBC0/SA2.WAV.wav
[608] Success: data/TRAIN/DR4/MLBC0/SI609.WAV.wav
[609] Success: data/TRAIN/DR4/MLBC0/SX69.WAV.wav
[610] Success: data/TRAIN/DR4/MPRT0/SX310.WAV.wav
[611] Success: data/TRAIN/DR4/MPRT0/SX130.WAV
[612] Success: data/TRAIN/DR4/MPRT0/SI495.WAV.wav
[613] Success: data/TRAIN/DR4/MPRT0/SI1210.WAV
[614] Success: data/TRAIN/DR4/MPRT0/SX400.WAV
[615] Success: data/TRAIN/DR4/MPRT0/SX40.WAV
[616] Success: data/TRAIN/DR4/MPRT0/SA1.WAV.wav
[617] Success: data/TRAIN/DR4/MPRT0/SI580.WAV.wav
[618] Success: data/TRAIN/DR4/MPRT0/SA2.WAV.wav
[619] Success: data/TRAIN/DR4/MPRT0/SX220.WAV.wav
[620] Success: data/TRAIN/DR4/MNET0/SI1446.WAV
[621] Success: data/TRAIN/DR4/MNET0/SX276.WAV.wav
[622] Success: data/TRAIN/DR4/MNET0/SI816.WAV.wav
[623] Success: data/TRAIN/DR4/MNET0/SI2076.WAV.wav
[624] Success: data/TRAIN/DR4/MNET0/SX6.WAV.wav
[625] Success: data/TRAIN/DR4/MNET0/SX366.WAV.wav
[626] Success: data/TRAIN/DR4/MNET0/SX186.WAV.wav
[627] Success: data/TRAIN/D




[818] Success: data/TRAIN/DR3/FSLS0/SA2.WAV.wav
[819] Success: data/TRAIN/DR3/FSLS0/SX156.WAV.wav
[820] Success: data/TRAIN/DR3/MRTJ0/SX52.WAV
[821] Success: data/TRAIN/DR3/MRTJ0/SX142.WAV
[822] Success: data/TRAIN/DR3/MRTJ0/SI2032.WAV.wav
[823] Success: data/TRAIN/DR3/MRTJ0/SX412.WAV
[824] Success: data/TRAIN/DR3/MRTJ0/SX322.WAV
[825] Success: data/TRAIN/DR3/MRTJ0/SI772.WAV.wav
[826] Success: data/TRAIN/DR3/MRTJ0/SX232.WAV.wav
[827] Success: data/TRAIN/DR3/MRTJ0/SI1551.WAV.wav
[828] Success: data/TRAIN/DR3/MRTJ0/SA1.WAV.wav
[829] Success: data/TRAIN/DR3/MRTJ0/SA2.WAV.wav
[830] Success: data/TRAIN/DR3/MTLB0/SI504.WAV
[831] Success: data/TRAIN/DR3/MTLB0/SI1764.WAV
[832] Success: data/TRAIN/DR3/MTLB0/SX324.WAV
[833] Success: data/TRAIN/DR3/MTLB0/SX234.WAV
[834] Success: data/TRAIN/DR3/MTLB0/SI1134.WAV.wav
[835] Success: data/TRAIN/DR3/MTLB0/SX414.WAV.wav
[836] Success: data/TRAIN/DR3/MTLB0/SX144.WAV.wav
[837] Success: data/TRAIN/DR3/MTLB0/SA1.WAV.wav
[838] Success: data/TRAIN/DR3/MTLB0/S

 52%|█████▏    | 881/1680 [00:00<00:00, 2188.77it/s][A

[881] Success: data/TRAIN/DR3/FLTM0/SI1070.WAV
[882] Success: data/TRAIN/DR3/FLTM0/SX440.WAV
[883] Success: data/TRAIN/DR3/FLTM0/SI1700.WAV
[884] Success: data/TRAIN/DR3/FLTM0/SX80.WAV
[885] Success: data/TRAIN/DR3/FLTM0/SX350.WAV.wav
[886] Success: data/TRAIN/DR3/FLTM0/SX170.WAV.wav
[887] Success: data/TRAIN/DR3/FLTM0/SI2330.WAV
[888] Success: data/TRAIN/DR3/FLTM0/SA1.WAV.wav
[889] Success: data/TRAIN/DR3/FLTM0/SA2.WAV.wav
[890] Success: data/TRAIN/DR3/MCDD0/SX253.WAV.wav
[891] Success: data/TRAIN/DR3/MCDD0/SI883.WAV
[892] Success: data/TRAIN/DR3/MCDD0/SI2143.WAV.wav
[893] Success: data/TRAIN/DR3/MCDD0/SI1513.WAV
[894] Success: data/TRAIN/DR3/MCDD0/SX73.WAV.wav
[895] Success: data/TRAIN/DR3/MCDD0/SX163.WAV.wav
[896] Success: data/TRAIN/DR3/MCDD0/SX343.WAV
[897] Success: data/TRAIN/DR3/MCDD0/SA1.WAV.wav
[898] Success: data/TRAIN/DR3/MCDD0/SA2.WAV.wav
[899] Success: data/TRAIN/DR3/MCDD0/SX433.WAV
[900] Success: data/TRAIN/DR3/MDWM0/SI2176.WAV.wav
[901] Success: data/TRAIN/DR3/MDWM0/SX10


 65%|██████▌   | 1100/1680 [00:00<00:00, 2184.59it/s][A

[1058] Success: data/TRAIN/DR3/MVJH0/SA1.WAV.wav
[1059] Success: data/TRAIN/DR3/MVJH0/SA2.WAV.wav
[1060] Success: data/TRAIN/DR3/FGRW0/SX252.WAV.wav
[1061] Success: data/TRAIN/DR3/FGRW0/SI1152.WAV
[1062] Success: data/TRAIN/DR3/FGRW0/SI1990.WAV.wav
[1063] Success: data/TRAIN/DR3/FGRW0/SX72.WAV.wav
[1064] Success: data/TRAIN/DR3/FGRW0/SI1782.WAV.wav
[1065] Success: data/TRAIN/DR3/FGRW0/SX342.WAV
[1066] Success: data/TRAIN/DR3/FGRW0/SA1.WAV.wav
[1067] Success: data/TRAIN/DR3/FGRW0/SA2.WAV.wav
[1068] Success: data/TRAIN/DR3/FGRW0/SX162.WAV.wav
[1069] Success: data/TRAIN/DR3/FGRW0/SX432.WAV
[1070] Success: data/TRAIN/DR3/MRJB1/SI2021.WAV
[1071] Success: data/TRAIN/DR3/MRJB1/SX30.WAV
[1072] Success: data/TRAIN/DR3/MRJB1/SX120.WAV
[1073] Success: data/TRAIN/DR3/MRJB1/SX300.WAV
[1074] Success: data/TRAIN/DR3/MRJB1/SI1413.WAV
[1075] Success: data/TRAIN/DR3/MRJB1/SX210.WAV.wav
[1076] Success: data/TRAIN/DR3/MRJB1/SX390.WAV
[1077] Success: data/TRAIN/DR3/MRJB1/SA1.WAV.wav
[1078] Success: data/TR


 79%|███████▊  | 1319/1680 [00:00<00:00, 2178.67it/s][A

[1295] Success: data/TRAIN/DR3/MWGR0/SI1606.WAV
[1296] Success: data/TRAIN/DR3/MWGR0/SX346.WAV
[1297] Success: data/TRAIN/DR3/MWGR0/SA1.WAV.wav
[1298] Success: data/TRAIN/DR3/MWGR0/SA2.WAV.wav
[1299] Success: data/TRAIN/DR3/MWGR0/SX76.WAV.wav
[1300] Success: data/TRAIN/DR3/FCMG0/SX252.WAV.wav
[1301] Success: data/TRAIN/DR3/FCMG0/SI1872.WAV.wav
[1302] Success: data/TRAIN/DR3/FCMG0/SI1242.WAV.wav
[1303] Success: data/TRAIN/DR3/FCMG0/SI1142.WAV
[1304] Success: data/TRAIN/DR3/FCMG0/SX72.WAV.wav
[1305] Success: data/TRAIN/DR3/FCMG0/SX342.WAV
[1306] Success: data/TRAIN/DR3/FCMG0/SA1.WAV.wav
[1307] Success: data/TRAIN/DR3/FCMG0/SA2.WAV.wav
[1308] Success: data/TRAIN/DR3/FCMG0/SX162.WAV.wav
[1309] Success: data/TRAIN/DR3/FCMG0/SX432.WAV
[1310] Success: data/TRAIN/DR3/MAKB0/SX26.WAV
[1311] Success: data/TRAIN/DR3/MAKB0/SX116.WAV
[1312] Success: data/TRAIN/DR3/MAKB0/SX296.WAV
[1313] Success: data/TRAIN/DR3/MAKB0/SI1646.WAV.wav
[1314] Success: data/TRAIN/DR3/MAKB0/SX386.WAV
[1315] Success: data/T


100%|██████████| 1680/1680 [00:00<00:00, 2182.93it/s][A

[1502] Success: data/TRAIN/DR2/MEFG0/SI465.WAV
[1503] Success: data/TRAIN/DR2/MEFG0/SI491.WAV.wav
[1504] Success: data/TRAIN/DR2/MEFG0/SX15.WAV.wav
[1505] Success: data/TRAIN/DR2/MEFG0/SI598.WAV.wav
[1506] Success: data/TRAIN/DR2/MEFG0/SX375.WAV.wav
[1507] Success: data/TRAIN/DR2/MEFG0/SA1.WAV.wav
[1508] Success: data/TRAIN/DR2/MEFG0/SA2.WAV.wav
[1509] Success: data/TRAIN/DR2/MEFG0/SX195.WAV.wav
[1510] Success: data/TRAIN/DR2/FMMH0/SI1537.WAV
[1511] Success: data/TRAIN/DR2/FMMH0/SI2167.WAV.wav
[1512] Success: data/TRAIN/DR2/FMMH0/SI907.WAV.wav
[1513] Success: data/TRAIN/DR2/FMMH0/SX367.WAV.wav
[1514] Success: data/TRAIN/DR2/FMMH0/SX7.WAV.wav
[1515] Success: data/TRAIN/DR2/FMMH0/SX187.WAV.wav
[1516] Success: data/TRAIN/DR2/FMMH0/SX420.WAV.wav
[1517] Success: data/TRAIN/DR2/FMMH0/SA1.WAV.wav
[1518] Success: data/TRAIN/DR2/FMMH0/SA2.WAV.wav
[1519] Success: data/TRAIN/DR2/FMMH0/SX97.WAV
[1520] Success: data/TRAIN/DR2/FPJF0/SI1676.WAV
[1521] Success: data/TRAIN/DR2/FPJF0/SX326.WAV
[1522] Su




In [189]:
test_data = extract_file_data(test_df)


  0%|          | 0/840 [00:00<?, ?it/s][A

[0] Success: data/TEST/DR8/MJLN0/SX189.WAV
[1] Success: data/TEST/DR3/MHPG0/SX10.WAV.wav
[2] Success: data/TEST/DR4/MLJB0/SX230.WAV.wav
[3] Success: data/TEST/DR4/MPLB0/SA2.WAV.wav
[4] Success: data/TEST/DR5/FMAH0/SA1.WAV.wav
[5] Success: data/TEST/DR3/MHPG0/SX190.WAV
[6] Success: data/TEST/DR8/MAJC0/SA1.WAV.wav
[7] Success: data/TEST/DR7/MNLS0/SA1.WAV.wav
[8] Success: data/TEST/DR8/FJSJ0/SX44.WAV.wav
[9] Success: data/TEST/DR3/MRTK0/SA1.WAV.wav
[10] Success: data/TEST/DR7/MDLF0/SA1.WAV.wav
[11] Success: data/TEST/DR5/FJSA0/SA1.WAV.wav
[12] Success: data/TEST/DR3/MBDG0/SX383.WAV.wav
[13] Success: data/TEST/DR3/MGJF0/SI641.WAV.wav
[14] Success: data/TEST/DR3/MBWM0/SX314.WAV.wav
[15] Success: data/TEST/DR4/MROA0/SX137.WAV
[16] Success: data/TEST/DR4/MJRF0/SX11.WAV.wav
[17] Success: data/TEST/DR3/MJMP0/SX275.WAV.wav
[18] Success: data/TEST/DR3/MLNT0/SX192.WAV
[19] Success: data/TEST/DR3/MBDG0/SI833.WAV
[20] Success: data/TEST/DR2/MGWT0/SX9.WAV.wav
[21] Success: data/TEST/DR6/MJFC0/SX403.W


 25%|██▌       | 214/840 [00:00<00:00, 2124.26it/s][A

[184] Success: data/TEST/DR2/MCEM0/SA1.WAV.wav
[185] Success: data/TEST/DR2/FSLB1/SX14.WAV.wav
[186] Success: data/TEST/DR4/MROA0/SI1307.WAV
[187] Success: data/TEST/DR7/FSXA0/SI1108.WAV.wav
[188] Success: data/TEST/DR6/FLNH0/SX404.WAV.wav
[189] Success: data/TEST/DR7/MGRT0/SX190.WAV
[190] Success: data/TEST/DR3/MKCH0/SX28.WAV.wav
[191] Success: data/TEST/DR2/FSLB1/SI644.WAV.wav
[192] Success: data/TEST/DR6/MJDH0/SX4.WAV
[193] Success: data/TEST/DR7/MKJL0/SX290.WAV
[194] Success: data/TEST/DR2/MDLD0/SI1543.WAV.wav
[195] Success: data/TEST/DR2/FJAS0/SX320.WAV
[196] Success: data/TEST/DR4/FREW0/SX380.WAV
[197] Success: data/TEST/DR3/MGLB0/SA1.WAV.wav
[198] Success: data/TEST/DR4/FMAF0/SX379.WAV
[199] Success: data/TEST/DR4/FMCM0/SX10.WAV.wav
[200] Success: data/TEST/DR3/MCSH0/SI919.WAV
[201] Success: data/TEST/DR2/MPDF0/SI1542.WAV.wav
[202] Success: data/TEST/DR7/FCAU0/SX137.WAV
[203] Success: data/TEST/DR8/MAJC0/SI2095.WAV
[204] Success: data/TEST/DR1/MREB0/SI1375.WAV
[205] Success: dat


 52%|█████▏    | 433/840 [00:00<00:00, 2151.25it/s][A


[419] Success: data/TEST/DR4/MPCS0/SX189.WAV
[420] Success: data/TEST/DR2/MMDB1/SX275.WAV.wav
[421] Success: data/TEST/DR7/FISB0/SA1.WAV.wav
[422] Success: data/TEST/DR2/FSLB1/SX374.WAV.wav
[423] Success: data/TEST/DR5/FCAL1/SA1.WAV.wav
[424] Success: data/TEST/DR6/MCMJ0/SI1094.WAV.wav
[425] Success: data/TEST/DR3/FKMS0/SX50.WAV
[426] Success: data/TEST/DR4/FLBW0/SI1849.WAV.wav
[427] Success: data/TEST/DR2/MRCZ0/SI1541.WAV.wav
[428] Success: data/TEST/DR6/FLNH0/SI1214.WAV
[429] Success: data/TEST/DR4/FGJD0/SI549.WAV
[430] Success: data/TEST/DR7/FDHC0/SI929.WAV.wav
[431] Success: data/TEST/DR8/FCMH1/SX233.WAV
[432] Success: data/TEST/DR6/MESD0/SX102.WAV.wav
[433] Success: data/TEST/DR4/FCRH0/SX278.WAV
[434] Success: data/TEST/DR8/MJTC0/SX20.WAV
[435] Success: data/TEST/DR6/FDRW0/SX293.WAV
[436] Success: data/TEST/DR7/FISB0/SX49.WAV
[437] Success: data/TEST/DR5/MCRC0/SX372.WAV
[438] Success: data/TEST/DR3/FCMH0/SI1454.WAV.wav
[439] Success: data/TEST/DR4/MKCL0/SX11.WAV.wav
[440] Success:

100%|██████████| 840/840 [00:00<00:00, 2126.07it/s][A

[649] Success: data/TEST/DR2/MWVW0/SX36.WAV
[650] Success: data/TEST/DR4/MPCS0/SI1359.WAV.wav
[651] Success: data/TEST/DR2/FJAS0/SA2.WAV.wav
[652] Success: data/TEST/DR8/FCMH1/SX143.WAV
[653] Success: data/TEST/DR4/FCRH0/SA1.WAV.wav
[654] Success: data/TEST/DR4/MPLB0/SX314.WAV.wav
[655] Success: data/TEST/DR8/FJSJ0/SI854.WAV
[656] Success: data/TEST/DR4/MTEB0/SI503.WAV
[657] Success: data/TEST/DR2/MPGL0/SX199.WAV
[658] Success: data/TEST/DR2/FCMR0/SX25.WAV.wav
[659] Success: data/TEST/DR7/MNJM0/SX320.WAV
[660] Success: data/TEST/DR1/MDAB0/SX319.WAV
[661] Success: data/TEST/DR5/FHES0/SA1.WAV.wav
[662] Success: data/TEST/DR4/MDRM0/SX203.WAV
[663] Success: data/TEST/DR4/MPCS0/SX99.WAV
[664] Success: data/TEST/DR8/FMLD0/SX205.WAV.wav
[665] Success: data/TEST/DR5/FJCS0/SI1833.WAV.wav
[666] Success: data/TEST/DR3/MKCH0/SA2.WAV.wav
[667] Success: data/TEST/DR3/MTDT0/SI1994.WAV
[668] Success: data/TEST/DR8/FCMH1/SI1493.WAV.wav
[669] Success: data/TEST/DR8/MDAW1/SX283.WAV
[670] Success: data/TE




In [190]:
val_data = extract_file_data(val_df)


  0%|          | 0/840 [00:00<?, ?it/s][A

[0] Success: data/TEST/DR4/FJMG0/SX11.WAV.wav
[1] Success: data/TEST/DR1/MREB0/SX385.WAV
[2] Success: data/TEST/DR5/MRJM3/SX98.WAV
[3] Success: data/TEST/DR2/MWEW0/SX281.WAV
[4] Success: data/TEST/DR4/MJRF0/SX281.WAV
[5] Success: data/TEST/DR2/MBJK0/SX275.WAV.wav
[6] Success: data/TEST/DR6/MRJR0/SI2313.WAV.wav
[7] Success: data/TEST/DR5/MDAC2/SX369.WAV
[8] Success: data/TEST/DR5/MDWK0/SA1.WAV.wav
[9] Success: data/TEST/DR4/MPLB0/SI764.WAV
[10] Success: data/TEST/DR6/MESD0/SA1.WAV.wav
[11] Success: data/TEST/DR3/MTHC0/SI1015.WAV
[12] Success: data/TEST/DR4/FLBW0/SI2253.WAV
[13] Success: data/TEST/DR5/FHES0/SX389.WAV.wav
[14] Success: data/TEST/DR7/MPAB0/SX318.WAV
[15] Success: data/TEST/DR1/MREB0/SX295.WAV
[16] Success: data/TEST/DR5/MRWS1/SI500.WAV
[17] Success: data/TEST/DR3/MJJG0/SX283.WAV
[18] Success: data/TEST/DR5/MDWK0/SA2.WAV.wav
[19] Success: data/TEST/DR8/MDAW1/SX373.WAV
[20] Success: data/TEST/DR2/MTMR0/SX133.WAV
[21] Success: data/TEST/DR3/MBWM0/SX134.WAV
[22] Success: data/


 24%|██▍       | 201/840 [00:00<00:00, 1995.03it/s][A

[140] Success: data/TEST/DR2/MGWT0/SX369.WAV
[141] Success: data/TEST/DR8/MRES0/SX227.WAV
[142] Success: data/TEST/DR7/MTWH0/SA1.WAV.wav
[143] Success: data/TEST/DR3/MBDG0/SA1.WAV.wav
[144] Success: data/TEST/DR2/MDBB0/SX25.WAV.wav
[145] Success: data/TEST/DR5/MCRC0/SA2.WAV.wav
[146] Success: data/TEST/DR2/FJWB0/SX365.WAV.wav
[147] Success: data/TEST/DR5/MCRC0/SI1092.WAV.wav
[148] Success: data/TEST/DR5/FGMD0/SA1.WAV.wav
[149] Success: data/TEST/DR5/MDAC2/SI2259.WAV
[150] Success: data/TEST/DR8/MSLB0/SX23.WAV
[151] Success: data/TEST/DR7/FTLH0/SA2.WAV.wav
[152] Success: data/TEST/DR6/MPAM1/SX36.WAV
[153] Success: data/TEST/DR3/MMJR0/SI2166.WAV.wav
[154] Success: data/TEST/DR6/MPAM1/SX396.WAV
[155] Success: data/TEST/DR5/FHEW0/SX133.WAV
[156] Success: data/TEST/DR7/FTLH0/SA1.WAV.wav
[157] Success: data/TEST/DR4/FADG0/SI649.WAV
[158] Success: data/TEST/DR2/MTMR0/SA2.WAV.wav
[159] Success: data/TEST/DR5/MSFH1/SX280.WAV
[160] Success: data/TEST/DR5/FJCS0/SX409.WAV
[161] Success: data/TEST/


 48%|████▊     | 403/840 [00:00<00:00, 2007.77it/s][A

[390] Success: data/TEST/DR3/MRTK0/SX13.WAV.wav
[391] Success: data/TEST/DR2/FDRD1/SX14.WAV.wav
[392] Success: data/TEST/DR8/MAJC0/SA2.WAV.wav
[393] Success: data/TEST/DR1/MWBT0/SA2.WAV.wav
[394] Success: data/TEST/DR7/FGWR0/SX318.WAV
[395] Success: data/TEST/DR2/MWEW0/SX11.WAV.wav
[396] Success: data/TEST/DR4/MKCL0/SX191.WAV
[397] Success: data/TEST/DR2/MMDB1/SX5.WAV
[398] Success: data/TEST/DR5/MLIH0/SX373.WAV
[399] Success: data/TEST/DR2/FJRE0/SX306.WAV.wav
[400] Success: data/TEST/DR4/MKCL0/SA2.WAV.wav
[401] Success: data/TEST/DR2/MTAS1/SI2098.WAV.wav
[402] Success: data/TEST/DR8/FCMH1/SX323.WAV
[403] Success: data/TEST/DR2/FDRD1/SA1.WAV.wav
[404] Success: data/TEST/DR3/MCTW0/SX383.WAV.wav
[405] Success: data/TEST/DR4/MBNS0/SA1.WAV.wav
[406] Success: data/TEST/DR2/MTAS1/SI838.WAV.wav
[407] Success: data/TEST/DR5/FMAH0/SI1289.WAV
[408] Success: data/TEST/DR2/FRAM1/SX10.WAV.wav
[409] Success: data/TEST/DR5/MCRC0/SX102.WAV.wav
[410] Success: data/TEST/DR4/MPLB0/SI1394.WAV.wav
[411] Su


 72%|███████▏  | 604/840 [00:00<00:00, 1993.48it/s][A

[574] Success: data/TEST/DR8/MAJC0/SX115.WAV.wav
[575] Success: data/TEST/DR8/MJTC0/SA1.WAV.wav
[576] Success: data/TEST/DR3/FKMS0/SI860.WAV.wav
[577] Success: data/TEST/DR5/FASW0/SA2.WAV.wav
[578] Success: data/TEST/DR3/MHPG0/SA1.WAV.wav
[579] Success: data/TEST/DR1/FELC0/SA2.WAV.wav
[580] Success: data/TEST/DR7/MRPC0/SX313.WAV.wav
[581] Success: data/TEST/DR4/FMCM0/SX370.WAV.wav
[582] Success: data/TEST/DR8/MRES0/SX47.WAV.wav
[583] Success: data/TEST/DR3/MMJR0/SI1648.WAV
[584] Success: data/TEST/DR5/MCMB0/SX188.WAV
[585] Success: data/TEST/DR8/MJTC0/SI1460.WAV
[586] Success: data/TEST/DR2/MTAS1/SX388.WAV.wav
[587] Success: data/TEST/DR6/MRJS0/SX274.WAV.wav
[588] Success: data/TEST/DR3/MBWM0/SI674.WAV
[589] Success: data/TEST/DR7/MNJM0/SA2.WAV.wav
[590] Success: data/TEST/DR4/FEDW0/SX184.WAV.wav
[591] Success: data/TEST/DR7/MRCS0/SI1223.WAV.wav
[592] Success: data/TEST/DR7/MRJM4/SX409.WAV
[593] Success: data/TEST/DR2/MTMR0/SA1.WAV.wav
[594] Success: data/TEST/DR4/FMCM0/SI1180.WAV
[595


100%|██████████| 840/840 [00:00<00:00, 2006.54it/s][A

[795] Success: data/TEST/DR2/MJAR0/SX368.WAV
[796] Success: data/TEST/DR5/MCMB0/SX368.WAV
[797] Success: data/TEST/DR4/FEDW0/SA2.WAV.wav
[798] Success: data/TEST/DR5/MDAC2/SX9.WAV.wav
[799] Success: data/TEST/DR2/FPAS0/SA2.WAV.wav
[800] Success: data/TEST/DR1/MWBT0/SX383.WAV.wav
[801] Success: data/TEST/DR3/MRTK0/SI1093.WAV.wav
[802] Success: data/TEST/DR5/MLIH0/SX283.WAV
[803] Success: data/TEST/DR3/MJES0/SX394.WAV
[804] Success: data/TEST/DR8/FMLD0/SX295.WAV
[805] Success: data/TEST/DR5/MDRB0/SI2109.WAV
[806] Success: data/TEST/DR2/MMDB1/SA1.WAV.wav
[807] Success: data/TEST/DR4/FCFT0/SX188.WAV
[808] Success: data/TEST/DR4/FNMR0/SI2029.WAV.wav
[809] Success: data/TEST/DR7/MDVC0/SX216.WAV.wav
[810] Success: data/TEST/DR3/MMDH0/SI2286.WAV
[811] Success: data/TEST/DR7/MCHH0/SX14.WAV.wav
[812] Success: data/TEST/DR4/MRKO0/SX47.WAV.wav
[813] Success: data/TEST/DR5/MCMB0/SX98.WAV
[814] Success: data/TEST/DR7/FCAU0/SX47.WAV.wav
[815] Success: data/TEST/DR3/FCMH0/SI2084.WAV
[816] Success: dat




In [191]:
train_data[5]

[tensor([[ 1.2207e-04,  6.1035e-05,  6.1035e-05,  ..., -6.1035e-05,
           3.0518e-05, -6.1035e-05]]),
 'h# p l eh sh h# t ah h# p er h# t ih s h# p ey dx ih n ah v aa dx ih z ih h# k w aa dx ih h# k ah m h# p h# t ih sh ih n h#',
 "pledge to participate in nevada's aquatic competition"]

In [192]:
def build_phoneme_vocab(dataset):
    phonemes = set()
    for sample in dataset:
        if isinstance(sample[1], str):
            phonemes.update(sample[1].split())
        else:
            phonemes.update(sample[1])  # already a list
    phoneme_list = sorted(phonemes)
    phone_to_index = {ph: i for i, ph in enumerate(phoneme_list)}
    index_to_phone = {i: ph for ph, i in phone_to_index.items()}
    return phone_to_index, index_to_phone


def map_phonemes(dataset, phone_to_index):
    mapped = []
    for waveform, phonemes, words in dataset:
        if isinstance(phonemes, str):
            phoneme_tokens = phonemes.split()
        else:
            phoneme_tokens = phonemes
        phoneme_indices = [phone_to_index[p] for p in phoneme_tokens]
        mapped.append([waveform, phoneme_indices, words])
    return mapped


phone_to_index, index_to_phone = build_phoneme_vocab(train_data)
train_data = map_phonemes(train_data, phone_to_index)
val_data = map_phonemes(val_data, phone_to_index)
test_data = map_phonemes(test_data, phone_to_index)


In [193]:
from torch.utils.data import Dataset
import torch

class PhoneticDataset(Dataset):
    def __init__(self, data: list, num_phoneme_classes: int, name: str = ""):
        self.data = data
        self.num_phoneme_classes = num_phoneme_classes
        self.name = name

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, list[int], str]:
        waveform, phoneme_indices, word_text = self.data[idx]
        return waveform, phoneme_indices, word_text


In [194]:
train_ds = PhoneticDataset(train_data, len(phonetics))
test_ds = PhoneticDataset(test_data, len(phonetics))
val_ds = PhoneticDataset(val_data, len(phonetics))

In [195]:
from torch.nn.utils.rnn import pad_sequence


def collate_batch(batch):
    audios, phonemes, words = zip(*batch)
    max_len = max(w.shape[1] for w in audios)

    audio_padded = [torch.nn.functional.pad(w, (0, max_len - w.shape[1])) for w in audios]
    mask = [torch.nn.functional.pad(torch.ones_like(w, dtype=torch.long), (0, max_len - w.shape[1])) for w in audios]

    audio_tensor = torch.stack(audio_padded)
    mask_tensor = torch.stack(mask)
    label_lengths = torch.tensor([len(seq) for seq in phonemes])
    phoneme_tensor = pad_sequence([torch.tensor(seq) for seq in phonemes], batch_first=True, padding_value=-100)

    return audio_tensor, mask_tensor, phoneme_tensor, label_lengths, words


In [196]:
BATCH_SIZE = 4

train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_batch,
)

val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    collate_fn=collate_batch,
)

test_loader = DataLoader(
    test_ds,
    batch_size=BATCH_SIZE,
    collate_fn=collate_batch,
)

In [197]:
model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [198]:
from torchaudio.transforms import Resample
import torchaudio
import torch

def load_audio_with_torchaudio(paths, target_sr=16000):
    audio_list = []
    masks = []
    max_len = 0

    for path in paths:
        waveform, sr = torchaudio.load(path)
        if sr != target_sr:
            waveform = Resample(orig_freq=sr, new_freq=target_sr)(waveform)

        waveform = waveform.squeeze(0)
        audio_list.append(waveform)
        max_len = max(max_len, waveform.size(0))

    padded_waveforms = []
    padded_masks = []

    for waveform in audio_list:
        length = waveform.size(0)
        pad_len = max_len - length
        padded_waveforms.append(torch.nn.functional.pad(waveform, (0, pad_len)))
        padded_masks.append(torch.cat([torch.ones(length), torch.zeros(pad_len)]))

    return torch.stack(padded_waveforms), torch.stack(padded_masks)


In [199]:

example_path = train_df.iloc[0]['audio_file']
speaker_audio, speaker_pad_mask = load_audio_with_torchaudio([example_path])
ipd.Audio(speaker_audio, rate = 16000)


In [200]:
input_values = feature_extractor(speaker_audio.numpy().tolist(), return_tensors="pt").input_values

It is strongly recommended to pass the `sampling_rate` argument to `Wav2Vec2FeatureExtractor()`. Failing to do so can result in silent errors that might be hard to debug.


In [201]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

input_values = input_values.to(device)
model = model.to(device)

with torch.no_grad():
    model_output = model(input_values)

In [202]:
pred_ids = torch.argmax(model_output.logits[0], axis=-1)
outputs = tokenizer.decode(pred_ids, output_word_offsets=True)
outputs.text

## EXAMPLE

'WOULD SUCH AN ACT OF REFUSAL BE USEFUL'

In [203]:

import torch.nn as nn
from transformers import Wav2Vec2Model

class PhoneticRecognition(nn.Module):
    def __init__(self, output_dim, model_name='facebook/wav2vec2-base-960h', use_rnn=False, rnn_units=512, layers=2):
        super().__init__()

        self.extractor = Wav2Vec2Model.from_pretrained(model_name, output_hidden_states=True)
        for param in self.extractor.parameters():
            param.requires_grad = False

        feature_size = self.extractor.config.hidden_size
        self.use_rnn = use_rnn

        if use_rnn:
            self.rnn = nn.LSTM(
                input_size=feature_size,
                hidden_size=rnn_units,
                num_layers=layers,
                batch_first=False,
                dropout=0.5
            )
            self.output_layer = nn.Linear(rnn_units, output_dim)
        else:
            self.output_layer = nn.Sequential(
                nn.Linear(feature_size, feature_size // 2),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(feature_size // 2, output_dim)
            )

    def forward(self, inputs, mask=None):
        inputs = inputs.squeeze()
        mask = mask.squeeze() if mask is not None else None

        result = self.extractor(inputs, attention_mask=mask)
        x = result.hidden_states[9]

        if self.use_rnn:
            x, _ = self.rnn(x)

        logits = self.output_layer(x)
        probs = torch.log_softmax(logits, dim=-1)
        return probs.permute(1, 0, 2)



In [204]:
import torch
import pytorch_lightning as pl


class LightningPhonemeTrainer(pl.LightningModule):
    def __init__(self, backbone, lr=1e-3):
        super().__init__()
        self.backbone = backbone
        self.lr = lr
        self.loss_fn = torch.nn.CTCLoss(blank=0)

    def forward(self, x, mask):
        return self.backbone(x, mask)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def shared_step(self, batch):
        x, mask, targets, target_lens, _ = batch
        preds = self(x, mask)
        bs = x.size(0)
        pred_lens = torch.full((bs,), preds.size(0), dtype=torch.long)
        return self.loss_fn(preds, targets, pred_lens, target_lens)

    def training_step(self, batch, idx):
        loss = self.shared_step(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, idx):
        loss = self.shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True)
        return loss

    def test_step(self, batch, idx):
        loss = self.shared_step(batch)
        self.log("test_loss", loss)
        return loss


In [206]:
pr_model_mlp = PhoneticRecognition(len(phonetics)).to(DEVICE)

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [207]:
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

lightning_model_mlp = LightningPhonemeTrainer(pr_model_mlp)

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='checkpoints/',
    filename='phonetic-recognition-{epoch:02d}-{val_loss:.4f}',
    save_top_k=1,
    mode='min'
)

logger = TensorBoardLogger("lightning_logs", name="phonetic_recognition")

trainer = pl.Trainer(
    max_epochs=10,
    callbacks=[checkpoint_callback],
    logger=logger,
    accelerator='cuda' if torch.cuda.is_available() else 'cpu',
    devices=1
)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [209]:
trainer.fit(lightning_model_mlp, train_dataloader, val_dataloader)

C:\Users\Samurai\miniconda3\Lib\site-packages\pytorch_lightning\callbacks\model_checkpoint.py:654: Checkpoint directory C:\Users\Samurai\PycharmProjects\Lab2\checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type                | Params | Mode 
---------------------------------------------------------
0 | backbone | PhoneticRecognition | 94.7 M | train
1 | loss_fn  | CTCLoss             | 0      | train
---------------------------------------------------------
310 K     Trainable params
94.4 M    Non-trainable params
94.7 M    Total params
378.728   Total estimated model params size (MB)
7         Modules in train mode
220       Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

C:\Users\Samurai\miniconda3\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
C:\Users\Samurai\miniconda3\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


In [211]:
from typing import List

def edit_distance(a: List[int], b: List[int]) -> int:
    m, n = len(a), len(b)
    dp = [[i + j if i * j == 0 else 0 for j in range(n + 1)] for i in range(m + 1)]  # init borders
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            cost = 0 if a[i - 1] == b[j - 1] else 1
            dp[i][j] = min(dp[i - 1][j] + 1,         # delete
                           dp[i][j - 1] + 1,         # insert
                           dp[i - 1][j - 1] + cost)  # substitute
    return dp[m][n]

def calculate_per(hyp: List[int], ref: List[int]) -> float:
    return 0.0 if not ref else edit_distance(hyp, ref) / len(ref)

def greedy_decode(log_probs: torch.Tensor, blank: int = 0) -> List[List[int]]:
    decoded = []
    for row in log_probs.transpose(0, 1):
        seq = row.argmax(-1)
        tokens = []
        prev = None
        for t in seq:
            if t != blank and t != prev:
                tokens.append(t.item())
            prev = t
        decoded.append(tokens)
    return decoded


def evaluate_model(model, dataloader) -> float:
    model.eval()
    total_per, total_samples = 0.0, 0

    with torch.no_grad():
        for x, m, y, y_lens, _ in dataloader:
            x, m = x.to(DEVICE), m.to(DEVICE)

            log_probs = model(x, m)
            hyps = greedy_decode(log_probs)

            refs = [seq[:l] for seq, l in zip(y.tolist(), y_lens.tolist())]  # cut padding

            total_per += sum(calculate_per(h, r) for h, r in zip(hyps, refs))
            total_samples += len(refs)

    avg = total_per / total_samples if total_samples > 0 else 0.0
    print(f"PER on test dataset: {avg*100:.2f}%")
    return avg


In [212]:
per = evaluate_model(pr_model_mlp.to('cuda'), test_dataloader)

PER on test dataset: 14.70%


In [220]:
phonetic_recognition = PhoneticRecognition(len(phonetics)).to(DEVICE)
lightning_model = LightningPhonemeTrainer(phonetic_recognition)

trainer = pl.Trainer(
    max_epochs=10,
    callbacks=[checkpoint_callback],
    logger=logger,
    accelerator='cuda' if torch.cuda.is_available() else 'cpu',
    devices=1
)

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [221]:
trainer.fit(lightning_model, train_dataloader, val_dataloader)

C:\Users\Samurai\miniconda3\Lib\site-packages\pytorch_lightning\callbacks\model_checkpoint.py:654: Checkpoint directory C:\Users\Samurai\PycharmProjects\Lab2\checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type                | Params | Mode 
---------------------------------------------------------
0 | backbone | PhoneticRecognition | 94.7 M | train
1 | loss_fn  | CTCLoss             | 0      | train
---------------------------------------------------------
310 K     Trainable params
94.4 M    Non-trainable params
94.7 M    Total params
378.728   Total estimated model params size (MB)
7         Modules in train mode
220       Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

C:\Users\Samurai\miniconda3\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
C:\Users\Samurai\miniconda3\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


In [222]:
per = evaluate_model(phonetic_recognition.to(DEVICE), test_dataloader)

PER on test dataset: 14.36%


We tested two versions of models: one using a simple MLP head and another with an LSTM. Both gave very similar results on the test set, so the extra complexity of the LSTM didn’t really help in this case.

They both provide more or less the same PER 14% and I did only 10 epochs cause after they start to overfit a bit