In [115]:
import torch
import os
import torchaudio
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor

In [116]:
audio_path = '/home/cslab03/Documents/ASRColab/ASRColab/wav_separate_sounds_mono/'

In [117]:
train_splits = ['ov1s1_wav', 'ov1s2_wav']
val_split = 'ov1s3_wav'

X_train = []
X_val = []
X_test = []

for split in train_splits:
  for file_ in os.listdir(audio_path+split):
      if 'tra' in file_:
          X_train.append(file_)
      elif 'tst' in file_:
          X_test.append(file_)

for file_ in os.listdir(audio_path+val_split):
    if 'tra' in file_:
        X_val.append(file_)
    elif 'tst' in file_:
        X_test.append(file_)

In [118]:
print(len(X_train), len(X_val), len(X_test))

2987 1654 1033


In [119]:
X_train

['tra_075_5_ov1_s1_00.wav',
 'tra_238_8_ov1_s1_04.wav',
 'tra_177_12_ov1_s1_09.wav',
 'tra_085_8_ov1_s1_09.wav',
 'tra_228_5_ov1_s1_04.wav',
 'tra_049_12_ov1_s1_00.wav',
 'tra_196_5_ov1_s1_03.wav',
 'tra_195_0_ov1_s1_09.wav',
 'tra_215_8_ov1_s1_02.wav',
 'tra_139_3_ov1_s1_01.wav',
 'tra_061_7_ov1_s1_04.wav',
 'tra_090_10_ov1_s1_03.wav',
 'tra_154_1_ov1_s1_02.wav',
 'tra_029_14_ov1_s1_09.wav',
 'tra_095_10_ov1_s1_01.wav',
 'tra_099_13_ov1_s1_03.wav',
 'tra_051_9_ov1_s1_09.wav',
 'tra_108_8_ov1_s1_02.wav',
 'tra_156_13_ov1_s1_04.wav',
 'tra_104_0_ov1_s1_07.wav',
 'tra_196_10_ov1_s1_04.wav',
 'tra_076_2_ov1_s1_08.wav',
 'tra_101_4_ov1_s1_07.wav',
 'tra_006_5_ov1_s1_04.wav',
 'tra_157_3_ov1_s1_08.wav',
 'tra_140_15_ov1_s1_08.wav',
 'tra_154_2_ov1_s1_09.wav',
 'tra_159_5_ov1_s1_00.wav',
 'tra_156_8_ov1_s1_08.wav',
 'tra_146_7_ov1_s1_06.wav',
 'tra_132_4_ov1_s1_01.wav',
 'tra_071_6_ov1_s1_00.wav',
 'tra_010_3_ov1_s1_05.wav',
 'tra_137_10_ov1_s1_04.wav',
 'tra_213_0_ov1_s1_07.wav',
 'tra_198_

In [120]:
audio_paths1 = '/home/cslab03/Documents/ASRColab/ASRColab/wav_separate_sounds_mono/ov1s1_wav/'
audio_paths2 = '/home/cslab03/Documents/ASRColab/ASRColab/wav_separate_sounds_mono/ov1s2_wav/'
audio_paths3 = '/home/cslab03/Documents/ASRColab/ASRColab/wav_separate_sounds_mono/ov1s3_wav/'

In [121]:

def load_data(X_train):
    x_train = []
    for index in range(len(X_train)):
        if "ov1_s1" in X_train[index]:
            feats = audio_paths1 + X_train[index]
        elif "ov1_s2" in X_train[index]:
            feats = audio_paths2 + X_train[index]
        elif "ov1_s3" in X_train[index]:
            feats = audio_paths3 + X_train[index]

        target = torch.tensor(int(X_train[index][-6:-4])).item()
        
        x_train.append({'paths':feats, 'labels':target})
    
    return x_train

In [122]:
x_train = load_data(X_train)

x_val = load_data(X_val)

x_test = load_data(X_test)

In [123]:
import pandas as pd

df_train = pd.DataFrame(x_train)
df_train.head()

Unnamed: 0,paths,labels
0,/home/cslab03/Documents/ASRColab/ASRColab/wav_...,0
1,/home/cslab03/Documents/ASRColab/ASRColab/wav_...,4
2,/home/cslab03/Documents/ASRColab/ASRColab/wav_...,9
3,/home/cslab03/Documents/ASRColab/ASRColab/wav_...,9
4,/home/cslab03/Documents/ASRColab/ASRColab/wav_...,4


In [124]:
print("LABELS: ", df_train['labels'].unique())

label_list = df_train['labels'].unique().tolist()
label_list.sort()
num_labels = len(label_list)

print(len(df_train['labels']))

df_train.groupby('labels').count()


LABELS:  [ 0  4  9  3  2  1  7  8  6  5 10]
2987


Unnamed: 0_level_0,paths
labels,Unnamed: 1_level_1
0,250
1,250
2,280
3,245
4,418
5,207
6,266
7,262
8,257
9,321


In [125]:
df_val = pd.DataFrame(x_val)
df_val.head()

Unnamed: 0,paths,labels
0,/home/cslab03/Documents/ASRColab/ASRColab/wav_...,2
1,/home/cslab03/Documents/ASRColab/ASRColab/wav_...,9
2,/home/cslab03/Documents/ASRColab/ASRColab/wav_...,4
3,/home/cslab03/Documents/ASRColab/ASRColab/wav_...,0
4,/home/cslab03/Documents/ASRColab/ASRColab/wav_...,10


In [126]:
from transformers import AutoConfig, Wav2Vec2Processor

In [127]:
model = "facebook/wav2vec2-base-960h"

In [128]:
config = AutoConfig.from_pretrained(model, num_labels=num_labels, label2id = {label: i for i, label in enumerate(label_list)}, id2label = {i: label for i, label in enumerate(label_list)}, finetuning_task="wav2vec2_clf", cache_dir=None)

In [129]:
processor = Wav2Vec2Processor.from_pretrained(model,)

In [130]:
targer_sampling = processor.feature_extractor.sampling_rate
targer_sampling

16000

In [131]:
def audio_file_to_array_fn(path):
    audio_array, sampling_rate = torchaudio.load(path)
    resampler = torchaudio.transforms.Resample(sampling_rate, targer_sampling)
    audio = resampler(audio_array).squeeze().numpy()
    return audio

def label_to_id(label):
    if len(label_list) > 0:
        return label_list.index(label) if label in label_list else -1
    return label
    
def preprocess_data(datalist):
    input_list = [audio_file_to_array_fn(input) for input in datalist['paths']]
    label_list = [label_to_id(label) for label in datalist['labels']]
    
    result = processor(input_list, sampling_rate=targer_sampling, return_tensors="pt", padding=True)
    result['labels'] = torch.tensor(label_list)
    return result

In [132]:
train_dataset = preprocess_data(df_train)
train_dataset

{'input_values': tensor([[-2.0865e-02, -7.0939e-02, -3.3383e-02,  ...,  1.6690e-03,
          1.6690e-03,  1.6690e-03],
        [ 1.5131e-01,  2.3052e-02, -1.2658e-01,  ...,  1.6760e-03,
          1.6760e-03,  1.6760e-03],
        [-6.2875e-03, -1.8330e-02, -5.4458e-02,  ..., -2.2732e-03,
         -2.2732e-03, -2.2732e-03],
        ...,
        [-6.6569e-02, -3.6330e-02, -6.6569e-02,  ...,  9.0271e-03,
          9.0271e-03,  9.0271e-03],
        [ 3.6107e-02, -1.5254e-01, -1.7948e-01,  ...,  1.7473e-04,
          1.7473e-04,  1.7473e-04],
        [-6.7423e-03, -9.2959e-03,  9.1852e-04,  ...,  9.1852e-04,
          9.1852e-04,  9.1852e-04]]), 'labels': tensor([ 0,  4,  9,  ...,  6,  7, 10])}

In [133]:
eval_dataset = preprocess_data(df_val)
eval_dataset

{'input_values': tensor([[ 0.0289, -0.0138,  0.0253,  ...,  0.0004,  0.0004,  0.0004],
        [ 0.0186,  0.0044,  0.0398,  ...,  0.0009,  0.0009,  0.0009],
        [ 0.0083,  0.0323,  0.0643,  ...,  0.0002,  0.0002,  0.0002],
        ...,
        [-0.0329,  0.0036, -0.0132,  ...,  0.0008,  0.0008,  0.0008],
        [-0.0700, -0.0460,  0.2290,  ...,  0.0018,  0.0018,  0.0018],
        [-0.0274, -0.0325, -0.0687,  ...,  0.0036,  0.0036,  0.0036]]), 'labels': tensor([ 2,  9,  4,  ...,  2,  4, 10])}