# New keywords

### Create dataset from current audio folder

In [1]:
import os
import tqdm
from pathlib import Path
import pandas as pd
import torchaudio
import librosa
import IPython.display as ipd
import numpy as np
from sklearn.model_selection import train_test_split
from datasets import load_dataset, load_metric

PATH_TO_AUDIO = "google_speech_recognition_v2"

  from .autonotebook import tqdm as notebook_tqdm


source: https://colab.research.google.com/github/m3hrdadfi/soxan/blob/main/notebooks/Emotion_recognition_in_Greek_speech_using_Wav2Vec2.ipynb#scrollTo=-gh7fQ1XEpC7

In [2]:
data = []

for subdir, dirs, files in os.walk(PATH_TO_AUDIO):
    for file in files:
        if file.endswith(".wav"):
            name = file.split(".")[0]
            label = subdir.split("/")[-1] # label = subdir.split("\\")[-1] on windows
            path = os.path.join(subdir, file)
            data.append({
                "name": name,
                "path": path,
                "keyword": label,
            })

In [3]:
df = pd.DataFrame(data)
print("Labels: ", df["keyword"].unique())
print()
#df.groupby("keyword").count()[["path"]]

Labels:  ['go' 'six' 'up' 'happy' 'sheila' 'follow' 'wow' 'four' 'learn' 'forward'
 'house' 'zero' 'on' 'left' 'no' 'backward' 'right' 'bird' 'eight'
 'visual' 'marvin' 'bed' 'stop' 'nine' 'seven' 'five' 'yes' 'one'
 '_background_noise_' 'off' 'dog' 'two' 'three' 'tree' 'down' 'cat']



In [4]:
# we only choose the keywords we want
desired_keywords = ["follow", "go", "happy", "marvin", "stop", "down"]
df = df[df["keyword"].isin(desired_keywords)]
print("Labels: ", df["keyword"].unique())
print()
df.groupby("keyword").count()[["path"]]

Labels:  ['go' 'happy' 'follow' 'marvin' 'stop' 'down']



Unnamed: 0_level_0,path
keyword,Unnamed: 1_level_1
down,3917
follow,1579
go,3880
happy,2054
marvin,2100
stop,3872


In [5]:
idx = np.random.randint(0, len(df))
sample = df.iloc[idx]
path = sample["path"]
label = sample["keyword"]


print(f"ID Location: {idx}")
print(f"      Label: {label}")
print()

speech, sr = torchaudio.load(path)
speech = speech[0].numpy().squeeze()
# speech = librosa.resample(np.asarray(speech), sr, 16_000) # audio is already at 16kHz
ipd.Audio(data=np.asarray(speech), autoplay=True, rate=16000)

ID Location: 11171
      Label: stop



In [6]:
# create csv files to be used to load data
save_path = "gsr_v2_cleaned"

train_df, test_df = train_test_split(df, test_size=0.2, random_state=101, stratify=df["keyword"])

train_df = train_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)

train_df.to_csv(f"{save_path}/train.csv", sep="\t", encoding="utf-8", index=False)
test_df.to_csv(f"{save_path}/test.csv", sep="\t", encoding="utf-8", index=False)


print(train_df.shape)
print(test_df.shape)

(13921, 3)
(3481, 3)


In [7]:
# Loading the created dataset using datasets

data_files = {
    "train": save_path+"/train.csv", 
    "validation": save_path+"/test.csv",
}

dataset = load_dataset("csv", data_files=data_files, delimiter="\t", )
train_dataset = dataset["train"]
eval_dataset = dataset["validation"]

print(train_dataset)
print(eval_dataset)

Downloading and preparing dataset csv/default to /home/guillaume/.cache/huggingface/datasets/csv/default-1847f1d03778b3e7/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1...


Downloading data files: 100%|██████████| 2/2 [00:00<00:00, 15307.68it/s]
Extracting data files: 100%|██████████| 2/2 [00:00<00:00, 2432.19it/s]
                                                             

Dataset csv downloaded and prepared to /home/guillaume/.cache/huggingface/datasets/csv/default-1847f1d03778b3e7/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1. Subsequent calls will reuse this data.


100%|██████████| 2/2 [00:00<00:00, 1244.05it/s]

Dataset({
    features: ['name', 'path', 'keyword'],
    num_rows: 13921
})
Dataset({
    features: ['name', 'path', 'keyword'],
    num_rows: 3481
})





In [8]:
labels = df["keyword"].unique()
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label

id2label["0"]

'go'

In [9]:
from transformers import AutoFeatureExtractor


model_checkpoint = "facebook/wav2vec2-base"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)



In [10]:
input_column = "path"
output_column = "keyword"
max_duration = 1.0  # seconds
target_sampling_rate = feature_extractor.sampling_rate

def speech_file_to_array(path):
    speech_array, sampling_rate = torchaudio.load(path)
    resampler = torchaudio.transforms.Resample(sampling_rate, target_sampling_rate)
    speech = resampler(speech_array).squeeze().numpy()
    return speech

def label_to_id(label, label_list):

    label_list = list(label_list)
    if len(label_list) > 0:
        return label_list.index(label) if label in label_list else -1
    return label

def preprocess_function(examples):
    audio_arrays = [speech_file_to_array(path) for path in examples[input_column]]
    target_list = [label_to_id(label, labels) for label in examples[output_column]]

    result = feature_extractor(
        audio_arrays, 
        sampling_rate=feature_extractor.sampling_rate, 
        max_length=int(feature_extractor.sampling_rate * max_duration), 
        truncation=True,
        padding=True 
    )

    # print(result)
    result["label"] = list(target_list)

    return result

In [11]:
encoded_dataset = dataset.map(preprocess_function, remove_columns=["path", "keyword", "name"], batched=True)
encoded_dataset

                                                                   

DatasetDict({
    train: Dataset({
        features: ['input_values', 'label'],
        num_rows: 13921
    })
    validation: Dataset({
        features: ['input_values', 'label'],
        num_rows: 3481
    })
})

In [12]:
from transformers import AutoModelForAudioClassification, TrainingArguments, Trainer

num_labels = len(id2label)
model = AutoModelForAudioClassification.from_pretrained(
    model_checkpoint, 
    num_labels=num_labels,
    label2id=label2id,
    id2label=id2label,
)


Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2ForSequenceClassification: ['project_hid.bias', 'quantizer.weight_proj.bias', 'quantizer.codevectors', 'project_q.bias', 'project_q.weight', 'quantizer.weight_proj.weight', 'project_hid.weight']
- This IS expected if you are initializing Wav2Vec2ForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['classifier.bias', 'projector.weight', 'projector.

In [13]:
batch_size = 16

In [14]:
model_name = model_checkpoint.split("/")[-1]

args = TrainingArguments(
    f"{model_name}-finetuned-ks",
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=5,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
)

  return torch._C._cuda_getDeviceCount() > 0


In [15]:
metric = load_metric("accuracy")

  metric = load_metric("accuracy")


In [16]:
import numpy as np

def compute_metrics(eval_pred):
    """Computes accuracy on a batch of predictions"""
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return metric.compute(predictions=predictions, references=eval_pred.label_ids)

In [17]:
trainer = Trainer(
    model,
    args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics
)

In [18]:
trainer.train()

  1%|          | 10/1085 [00:22<37:17,  2.08s/it] 

{'loss': 1.7921, 'learning_rate': 2.7522935779816517e-06, 'epoch': 0.05}


  2%|▏         | 20/1085 [00:42<36:31,  2.06s/it]

{'loss': 1.7822, 'learning_rate': 5.504587155963303e-06, 'epoch': 0.09}


  3%|▎         | 30/1085 [01:03<36:02,  2.05s/it]

{'loss': 1.7637, 'learning_rate': 8.256880733944954e-06, 'epoch': 0.14}


  4%|▎         | 40/1085 [01:23<35:41,  2.05s/it]

{'loss': 1.7361, 'learning_rate': 1.1009174311926607e-05, 'epoch': 0.18}


  5%|▍         | 50/1085 [01:44<35:16,  2.05s/it]

{'loss': 1.6832, 'learning_rate': 1.3761467889908258e-05, 'epoch': 0.23}


  6%|▌         | 60/1085 [02:04<35:00,  2.05s/it]

{'loss': 1.6506, 'learning_rate': 1.651376146788991e-05, 'epoch': 0.28}


  6%|▋         | 70/1085 [02:25<34:44,  2.05s/it]

{'loss': 1.4694, 'learning_rate': 1.9266055045871563e-05, 'epoch': 0.32}


  7%|▋         | 80/1085 [02:45<34:23,  2.05s/it]

{'loss': 1.2488, 'learning_rate': 2.2018348623853213e-05, 'epoch': 0.37}


  8%|▊         | 90/1085 [03:06<34:08,  2.06s/it]

{'loss': 1.0949, 'learning_rate': 2.4770642201834864e-05, 'epoch': 0.41}


  9%|▉         | 100/1085 [03:27<33:46,  2.06s/it]

{'loss': 0.9155, 'learning_rate': 2.7522935779816515e-05, 'epoch': 0.46}


 10%|█         | 110/1085 [03:47<33:26,  2.06s/it]

{'loss': 0.8014, 'learning_rate': 2.996926229508197e-05, 'epoch': 0.51}


 11%|█         | 120/1085 [04:08<33:03,  2.06s/it]

{'loss': 0.6882, 'learning_rate': 2.966188524590164e-05, 'epoch': 0.55}


 12%|█▏        | 130/1085 [04:28<32:35,  2.05s/it]

{'loss': 0.6526, 'learning_rate': 2.9354508196721315e-05, 'epoch': 0.6}


 13%|█▎        | 140/1085 [04:49<32:23,  2.06s/it]

{'loss': 0.6363, 'learning_rate': 2.9047131147540983e-05, 'epoch': 0.64}


 14%|█▍        | 150/1085 [05:10<32:10,  2.06s/it]

{'loss': 0.5695, 'learning_rate': 2.8739754098360657e-05, 'epoch': 0.69}


 15%|█▍        | 160/1085 [05:30<31:41,  2.06s/it]

{'loss': 0.5369, 'learning_rate': 2.8432377049180328e-05, 'epoch': 0.73}


 16%|█▌        | 170/1085 [05:51<31:16,  2.05s/it]

{'loss': 0.4756, 'learning_rate': 2.8125e-05, 'epoch': 0.78}


 17%|█▋        | 180/1085 [06:11<30:58,  2.05s/it]

{'loss': 0.4743, 'learning_rate': 2.7817622950819674e-05, 'epoch': 0.83}


 18%|█▊        | 190/1085 [06:32<30:47,  2.06s/it]

{'loss': 0.4438, 'learning_rate': 2.7510245901639345e-05, 'epoch': 0.87}


 18%|█▊        | 200/1085 [06:52<30:28,  2.07s/it]

{'loss': 0.4453, 'learning_rate': 2.720286885245902e-05, 'epoch': 0.92}


 19%|█▉        | 210/1085 [07:13<30:08,  2.07s/it]

{'loss': 0.3988, 'learning_rate': 2.6895491803278687e-05, 'epoch': 0.96}


                                                  
 20%|██        | 217/1085 [08:03<29:52,  2.07s/it]

{'eval_loss': 0.2934631407260895, 'eval_accuracy': 0.9721344441252514, 'eval_runtime': 34.5309, 'eval_samples_per_second': 100.808, 'eval_steps_per_second': 6.313, 'epoch': 1.0}


 20%|██        | 220/1085 [08:09<1:44:07,  7.22s/it]

{'loss': 0.3972, 'learning_rate': 2.658811475409836e-05, 'epoch': 1.01}


 21%|██        | 230/1085 [08:29<31:22,  2.20s/it]  

{'loss': 0.3834, 'learning_rate': 2.6280737704918032e-05, 'epoch': 1.06}


 22%|██▏       | 240/1085 [08:50<29:03,  2.06s/it]

{'loss': 0.338, 'learning_rate': 2.5973360655737707e-05, 'epoch': 1.1}


 23%|██▎       | 250/1085 [09:11<28:41,  2.06s/it]

{'loss': 0.3482, 'learning_rate': 2.5665983606557378e-05, 'epoch': 1.15}


 24%|██▍       | 260/1085 [09:31<28:23,  2.06s/it]

{'loss': 0.3042, 'learning_rate': 2.5358606557377052e-05, 'epoch': 1.19}


 25%|██▍       | 270/1085 [09:52<27:59,  2.06s/it]

{'loss': 0.3171, 'learning_rate': 2.5051229508196723e-05, 'epoch': 1.24}


 26%|██▌       | 280/1085 [10:12<27:34,  2.06s/it]

{'loss': 0.2699, 'learning_rate': 2.4743852459016394e-05, 'epoch': 1.29}


 27%|██▋       | 290/1085 [10:33<27:15,  2.06s/it]

{'loss': 0.279, 'learning_rate': 2.4436475409836065e-05, 'epoch': 1.33}


 28%|██▊       | 300/1085 [10:53<26:41,  2.04s/it]

{'loss': 0.2679, 'learning_rate': 2.4129098360655736e-05, 'epoch': 1.38}


 29%|██▊       | 310/1085 [11:14<26:15,  2.03s/it]

{'loss': 0.2259, 'learning_rate': 2.382172131147541e-05, 'epoch': 1.42}


 29%|██▉       | 320/1085 [11:34<25:56,  2.03s/it]

{'loss': 0.2558, 'learning_rate': 2.3514344262295082e-05, 'epoch': 1.47}


 30%|███       | 330/1085 [11:55<25:38,  2.04s/it]

{'loss': 0.2437, 'learning_rate': 2.3206967213114756e-05, 'epoch': 1.52}


 31%|███▏      | 340/1085 [12:15<25:16,  2.04s/it]

{'loss': 0.2555, 'learning_rate': 2.2899590163934424e-05, 'epoch': 1.56}


 32%|███▏      | 350/1085 [12:35<24:52,  2.03s/it]

{'loss': 0.2614, 'learning_rate': 2.25922131147541e-05, 'epoch': 1.61}


 33%|███▎      | 360/1085 [12:55<24:33,  2.03s/it]

{'loss': 0.2835, 'learning_rate': 2.228483606557377e-05, 'epoch': 1.65}


 34%|███▍      | 370/1085 [13:16<24:13,  2.03s/it]

{'loss': 0.2534, 'learning_rate': 2.1977459016393444e-05, 'epoch': 1.7}


 35%|███▌      | 380/1085 [13:36<23:52,  2.03s/it]

{'loss': 0.2376, 'learning_rate': 2.1670081967213115e-05, 'epoch': 1.75}


 36%|███▌      | 390/1085 [13:56<23:29,  2.03s/it]

{'loss': 0.2496, 'learning_rate': 2.136270491803279e-05, 'epoch': 1.79}


 37%|███▋      | 400/1085 [14:17<23:12,  2.03s/it]

{'loss': 0.2269, 'learning_rate': 2.105532786885246e-05, 'epoch': 1.84}


 38%|███▊      | 410/1085 [14:37<23:04,  2.05s/it]

{'loss': 0.2726, 'learning_rate': 2.074795081967213e-05, 'epoch': 1.88}


 39%|███▊      | 420/1085 [14:58<22:46,  2.06s/it]

{'loss': 0.2309, 'learning_rate': 2.0440573770491803e-05, 'epoch': 1.93}


 40%|███▉      | 430/1085 [15:18<22:28,  2.06s/it]

{'loss': 0.2065, 'learning_rate': 2.0133196721311477e-05, 'epoch': 1.97}


                                                  
 40%|████      | 435/1085 [16:04<22:18,  2.06s/it]

{'eval_loss': 0.12088477611541748, 'eval_accuracy': 0.9810399310542948, 'eval_runtime': 34.5447, 'eval_samples_per_second': 100.768, 'eval_steps_per_second': 6.311, 'epoch': 2.0}


 41%|████      | 440/1085 [16:14<49:15,  4.58s/it]  

{'loss': 0.1868, 'learning_rate': 1.9825819672131148e-05, 'epoch': 2.02}


 41%|████▏     | 450/1085 [16:35<22:28,  2.12s/it]

{'loss': 0.1559, 'learning_rate': 1.951844262295082e-05, 'epoch': 2.07}


 42%|████▏     | 460/1085 [16:55<21:22,  2.05s/it]

{'loss': 0.2069, 'learning_rate': 1.9211065573770493e-05, 'epoch': 2.11}


 43%|████▎     | 470/1085 [17:16<21:05,  2.06s/it]

{'loss': 0.2294, 'learning_rate': 1.8903688524590165e-05, 'epoch': 2.16}


 44%|████▍     | 480/1085 [17:36<20:48,  2.06s/it]

{'loss': 0.2095, 'learning_rate': 1.8596311475409836e-05, 'epoch': 2.2}


 45%|████▌     | 490/1085 [17:57<20:27,  2.06s/it]

{'loss': 0.2028, 'learning_rate': 1.8288934426229507e-05, 'epoch': 2.25}


 46%|████▌     | 500/1085 [18:18<20:09,  2.07s/it]

{'loss': 0.1699, 'learning_rate': 1.798155737704918e-05, 'epoch': 2.3}


 47%|████▋     | 510/1085 [18:38<19:46,  2.06s/it]

{'loss': 0.1634, 'learning_rate': 1.7674180327868852e-05, 'epoch': 2.34}


 48%|████▊     | 520/1085 [18:59<19:28,  2.07s/it]

{'loss': 0.2046, 'learning_rate': 1.7366803278688527e-05, 'epoch': 2.39}


 49%|████▉     | 530/1085 [19:19<19:04,  2.06s/it]

{'loss': 0.1619, 'learning_rate': 1.7059426229508198e-05, 'epoch': 2.43}


 50%|████▉     | 540/1085 [19:40<18:42,  2.06s/it]

{'loss': 0.1694, 'learning_rate': 1.6752049180327872e-05, 'epoch': 2.48}


 51%|█████     | 550/1085 [20:00<18:11,  2.04s/it]

{'loss': 0.1876, 'learning_rate': 1.644467213114754e-05, 'epoch': 2.53}


 52%|█████▏    | 560/1085 [20:21<17:50,  2.04s/it]

{'loss': 0.1919, 'learning_rate': 1.6137295081967214e-05, 'epoch': 2.57}


 53%|█████▎    | 570/1085 [20:41<17:30,  2.04s/it]

{'loss': 0.1483, 'learning_rate': 1.5829918032786885e-05, 'epoch': 2.62}


 53%|█████▎    | 580/1085 [21:02<17:09,  2.04s/it]

{'loss': 0.1896, 'learning_rate': 1.552254098360656e-05, 'epoch': 2.66}


 54%|█████▍    | 590/1085 [21:22<16:49,  2.04s/it]

{'loss': 0.1634, 'learning_rate': 1.521516393442623e-05, 'epoch': 2.71}


 55%|█████▌    | 600/1085 [21:42<16:29,  2.04s/it]

{'loss': 0.1412, 'learning_rate': 1.4907786885245902e-05, 'epoch': 2.76}


 56%|█████▌    | 610/1085 [22:03<16:08,  2.04s/it]

{'loss': 0.1966, 'learning_rate': 1.4600409836065574e-05, 'epoch': 2.8}


 57%|█████▋    | 620/1085 [22:23<15:48,  2.04s/it]

{'loss': 0.1856, 'learning_rate': 1.4293032786885247e-05, 'epoch': 2.85}


 58%|█████▊    | 630/1085 [22:44<15:28,  2.04s/it]

{'loss': 0.1769, 'learning_rate': 1.3985655737704918e-05, 'epoch': 2.89}


 59%|█████▉    | 640/1085 [23:04<15:07,  2.04s/it]

{'loss': 0.154, 'learning_rate': 1.3678278688524591e-05, 'epoch': 2.94}


 60%|█████▉    | 650/1085 [23:24<14:46,  2.04s/it]

{'loss': 0.1593, 'learning_rate': 1.3370901639344264e-05, 'epoch': 2.99}


                                                  
 60%|██████    | 653/1085 [24:05<14:40,  2.04s/it]

{'eval_loss': 0.08418291807174683, 'eval_accuracy': 0.9827635736857225, 'eval_runtime': 33.9037, 'eval_samples_per_second': 102.673, 'eval_steps_per_second': 6.43, 'epoch': 3.0}


 61%|██████    | 660/1085 [24:19<22:58,  3.24s/it]  

{'loss': 0.1639, 'learning_rate': 1.3063524590163935e-05, 'epoch': 3.03}


 62%|██████▏   | 670/1085 [24:40<14:25,  2.08s/it]

{'loss': 0.1141, 'learning_rate': 1.2756147540983606e-05, 'epoch': 3.08}


 63%|██████▎   | 680/1085 [25:00<13:54,  2.06s/it]

{'loss': 0.1742, 'learning_rate': 1.2448770491803279e-05, 'epoch': 3.12}


 64%|██████▎   | 690/1085 [25:21<13:36,  2.07s/it]

{'loss': 0.1588, 'learning_rate': 1.2141393442622951e-05, 'epoch': 3.17}


 65%|██████▍   | 700/1085 [25:41<13:13,  2.06s/it]

{'loss': 0.1284, 'learning_rate': 1.1834016393442622e-05, 'epoch': 3.21}


 65%|██████▌   | 710/1085 [26:02<12:49,  2.05s/it]

{'loss': 0.0954, 'learning_rate': 1.1526639344262295e-05, 'epoch': 3.26}


 66%|██████▋   | 720/1085 [26:23<12:29,  2.05s/it]

{'loss': 0.1176, 'learning_rate': 1.1219262295081968e-05, 'epoch': 3.31}


 67%|██████▋   | 730/1085 [26:43<12:10,  2.06s/it]

{'loss': 0.1436, 'learning_rate': 1.0911885245901639e-05, 'epoch': 3.35}


 68%|██████▊   | 740/1085 [27:04<11:49,  2.06s/it]

{'loss': 0.1427, 'learning_rate': 1.0604508196721312e-05, 'epoch': 3.4}


 69%|██████▉   | 750/1085 [27:24<11:31,  2.06s/it]

{'loss': 0.1446, 'learning_rate': 1.0297131147540984e-05, 'epoch': 3.44}


 70%|███████   | 760/1085 [27:45<11:13,  2.07s/it]

{'loss': 0.1284, 'learning_rate': 9.989754098360657e-06, 'epoch': 3.49}


 71%|███████   | 770/1085 [28:06<10:52,  2.07s/it]

{'loss': 0.1194, 'learning_rate': 9.682377049180328e-06, 'epoch': 3.54}


 72%|███████▏  | 780/1085 [28:27<10:34,  2.08s/it]

{'loss': 0.0956, 'learning_rate': 9.375000000000001e-06, 'epoch': 3.58}


 73%|███████▎  | 790/1085 [28:47<10:07,  2.06s/it]

{'loss': 0.1117, 'learning_rate': 9.067622950819674e-06, 'epoch': 3.63}


 74%|███████▎  | 800/1085 [29:08<09:48,  2.06s/it]

{'loss': 0.077, 'learning_rate': 8.760245901639343e-06, 'epoch': 3.67}


 75%|███████▍  | 810/1085 [29:28<09:28,  2.07s/it]

{'loss': 0.1392, 'learning_rate': 8.452868852459016e-06, 'epoch': 3.72}


 76%|███████▌  | 820/1085 [29:49<09:07,  2.07s/it]

{'loss': 0.131, 'learning_rate': 8.145491803278688e-06, 'epoch': 3.77}


 76%|███████▋  | 830/1085 [30:10<08:41,  2.04s/it]

{'loss': 0.1268, 'learning_rate': 7.838114754098361e-06, 'epoch': 3.81}


 77%|███████▋  | 840/1085 [30:30<08:20,  2.04s/it]

{'loss': 0.1152, 'learning_rate': 7.530737704918032e-06, 'epoch': 3.86}


 78%|███████▊  | 850/1085 [30:50<07:59,  2.04s/it]

{'loss': 0.1512, 'learning_rate': 7.223360655737705e-06, 'epoch': 3.9}


 79%|███████▉  | 860/1085 [31:11<07:38,  2.04s/it]

{'loss': 0.101, 'learning_rate': 6.915983606557377e-06, 'epoch': 3.95}


 80%|████████  | 870/1085 [31:31<07:18,  2.04s/it]

{'loss': 0.1183, 'learning_rate': 6.60860655737705e-06, 'epoch': 4.0}


                                                  
 80%|████████  | 871/1085 [32:07<06:48,  1.91s/it]

{'eval_loss': 0.06703688204288483, 'eval_accuracy': 0.9844872163171502, 'eval_runtime': 33.9042, 'eval_samples_per_second': 102.672, 'eval_steps_per_second': 6.43, 'epoch': 4.0}


 81%|████████  | 880/1085 [32:26<09:00,  2.64s/it]

{'loss': 0.0736, 'learning_rate': 6.3012295081967215e-06, 'epoch': 4.04}


 82%|████████▏ | 890/1085 [32:46<06:40,  2.06s/it]

{'loss': 0.0797, 'learning_rate': 5.993852459016394e-06, 'epoch': 4.09}


 83%|████████▎ | 900/1085 [33:07<06:17,  2.04s/it]

{'loss': 0.1486, 'learning_rate': 5.686475409836066e-06, 'epoch': 4.13}


 84%|████████▍ | 910/1085 [33:27<05:56,  2.04s/it]

{'loss': 0.1064, 'learning_rate': 5.379098360655737e-06, 'epoch': 4.18}


 85%|████████▍ | 920/1085 [33:48<05:36,  2.04s/it]

{'loss': 0.1065, 'learning_rate': 5.07172131147541e-06, 'epoch': 4.23}


 86%|████████▌ | 930/1085 [34:08<05:16,  2.04s/it]

{'loss': 0.086, 'learning_rate': 4.764344262295082e-06, 'epoch': 4.27}


 87%|████████▋ | 940/1085 [34:28<04:55,  2.04s/it]

{'loss': 0.1208, 'learning_rate': 4.4569672131147546e-06, 'epoch': 4.32}


 88%|████████▊ | 950/1085 [34:49<04:35,  2.04s/it]

{'loss': 0.1174, 'learning_rate': 4.1495901639344265e-06, 'epoch': 4.36}


 88%|████████▊ | 960/1085 [35:09<04:14,  2.04s/it]

{'loss': 0.0959, 'learning_rate': 3.842213114754098e-06, 'epoch': 4.41}


 89%|████████▉ | 970/1085 [35:30<03:54,  2.04s/it]

{'loss': 0.1031, 'learning_rate': 3.5348360655737707e-06, 'epoch': 4.45}


 90%|█████████ | 980/1085 [35:50<03:34,  2.04s/it]

{'loss': 0.1342, 'learning_rate': 3.2274590163934426e-06, 'epoch': 4.5}


 91%|█████████ | 990/1085 [36:10<03:13,  2.04s/it]

{'loss': 0.1451, 'learning_rate': 2.920081967213115e-06, 'epoch': 4.55}


 92%|█████████▏| 1000/1085 [36:31<02:53,  2.04s/it]

{'loss': 0.1189, 'learning_rate': 2.6127049180327868e-06, 'epoch': 4.59}


 93%|█████████▎| 1010/1085 [36:51<02:34,  2.06s/it]

{'loss': 0.0841, 'learning_rate': 2.305327868852459e-06, 'epoch': 4.64}


 94%|█████████▍| 1020/1085 [37:12<02:14,  2.06s/it]

{'loss': 0.115, 'learning_rate': 1.9979508196721314e-06, 'epoch': 4.68}


 95%|█████████▍| 1030/1085 [37:33<01:53,  2.06s/it]

{'loss': 0.1379, 'learning_rate': 1.6905737704918033e-06, 'epoch': 4.73}


 96%|█████████▌| 1040/1085 [37:53<01:32,  2.06s/it]

{'loss': 0.1244, 'learning_rate': 1.3831967213114756e-06, 'epoch': 4.78}


 97%|█████████▋| 1050/1085 [38:14<01:12,  2.06s/it]

{'loss': 0.141, 'learning_rate': 1.0758196721311475e-06, 'epoch': 4.82}


 98%|█████████▊| 1060/1085 [38:34<00:51,  2.06s/it]

{'loss': 0.0982, 'learning_rate': 7.684426229508197e-07, 'epoch': 4.87}


 99%|█████████▊| 1070/1085 [38:55<00:30,  2.06s/it]

{'loss': 0.1306, 'learning_rate': 4.610655737704918e-07, 'epoch': 4.91}


100%|█████████▉| 1080/1085 [39:16<00:10,  2.06s/it]

{'loss': 0.0733, 'learning_rate': 1.5368852459016392e-07, 'epoch': 4.96}


                                                   
100%|██████████| 1085/1085 [40:01<00:00,  2.06s/it]

{'eval_loss': 0.06069877743721008, 'eval_accuracy': 0.9864981327204826, 'eval_runtime': 34.9516, 'eval_samples_per_second': 99.595, 'eval_steps_per_second': 6.237, 'epoch': 4.98}


100%|██████████| 1085/1085 [40:02<00:00,  2.21s/it]

{'train_runtime': 2402.5691, 'train_samples_per_second': 28.971, 'train_steps_per_second': 0.452, 'train_loss': 0.3366053492242839, 'epoch': 4.98}





TrainOutput(global_step=1085, training_loss=0.3366053492242839, metrics={'train_runtime': 2402.5691, 'train_samples_per_second': 28.971, 'train_steps_per_second': 0.452, 'train_loss': 0.3366053492242839, 'epoch': 4.98})

In [32]:
trainer.evaluate()

100%|██████████| 218/218 [00:34<00:00,  6.29it/s]


{'eval_loss': 0.06069877743721008,
 'eval_accuracy': 0.9864981327204826,
 'eval_runtime': 34.8562,
 'eval_samples_per_second': 99.867,
 'eval_steps_per_second': 6.254,
 'epoch': 4.98}

In [None]:
from transformers import AutoModelForAudioClassification, AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained("anton-l/my-awesome-model")
model = AutoModelForAudioClassification.from_pretrained("anton-l/my-awesome-model")

### Prediction

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from transformers import AutoConfig, Wav2Vec2Processor

import librosa
import IPython.display as ipd
import numpy as np
import pandas as pd

In [21]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
feature_extractor_checkpoint = "facebook/wav2vec2-base"
audio_classification_checkpoint = "wav2vec2-base-finetuned-ks/checkpoint-1085"
feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_checkpoint)
config = AutoConfig.from_pretrained(feature_extractor_checkpoint)
sampling_rate = feature_extractor.sampling_rate
model = AutoModelForAudioClassification.from_pretrained(audio_classification_checkpoint).to(device)

In [40]:
def speech_file_to_array_fn(path, sampling_rate):
    speech_array, _sampling_rate = torchaudio.load(path)
    resampler = torchaudio.transforms.Resample(_sampling_rate)
    speech = resampler(speech_array).squeeze().numpy()
    return speech


def predict(path, sampling_rate):
    speech = speech_file_to_array_fn(path, sampling_rate)
    features = feature_extractor(speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True)

    input_values = features.input_values.to(device)
    # attention_mask = features.attention_mask.to(device)

    with torch.no_grad():
        logits = model(input_values).logits

    scores = F.softmax(logits, dim=1).detach().cpu().numpy()[0]
    outputs = scores
    # outputs = [{"Emotion": config.id2label[i], "Score": f"{round(score * 100, 3):.1f}%"} for i, score in enumerate(scores)]
    return outputs


STYLES = """
<style>
div.display_data {
    margin: 0 auto;
    max-width: 500px;
}
table.xxx {
    margin: 50px !important;
    float: right !important;
    clear: both !important;
}
table.xxx td {
    min-width: 300px !important;
    text-align: center !important;
}
</style>
""".strip()

def prediction(df_row):
    path, emotion = df_row["path"], df_row["emotion"]
    df = pd.DataFrame([{"Emotion": emotion, "Sentence": "    "}])
    setup = {
        'border': 2,
        'show_dimensions': True,
        'justify': 'center',
        'classes': 'xxx',
        'escape': False,
    }
    ipd.display(ipd.HTML(STYLES + df.to_html(**setup) + "<br />"))
    speech, sr = torchaudio.load(path)
    speech = speech[0].numpy().squeeze()
    speech = librosa.resample(np.asarray(speech), sr, sampling_rate)
    ipd.display(ipd.Audio(data=np.asarray(speech), autoplay=True, rate=sampling_rate))

    outputs = predict(path, sampling_rate)
    r = pd.DataFrame(outputs)
    ipd.display(ipd.HTML(STYLES + r.to_html(**setup) + "<br />"))

In [41]:
# with pyaudio
import pyaudio
import wave
import tempfile
import os

CHUNK = 320  # number of audio samples per frame
FORMAT = pyaudio.paInt16  # audio format
CHANNELS = 1  # mono audio
RATE = 48000  # sampling rate in Hz
RECORD_SECONDS = 1  # duration of each recording in seconds
FILE_NAME = f"temp.wav"

def record_audio():
    p = pyaudio.PyAudio()

    stream = p.open(format=FORMAT,
                    channels=CHANNELS,
                    rate=RATE,
                    input=True,
                    frames_per_buffer=CHUNK,
                    input_device_index=4)

    try:
        while True:
            frames = []  # to store audio frames

            for i in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
                data = stream.read(CHUNK)
                frames.append(data)

            # write frames to temporary WAV file
            
            wav_filename =  FILE_NAME
            wf = wave.open(wav_filename, 'wb')
            wf.setnchannels(CHANNELS)
            wf.setsampwidth(p.get_sample_size(FORMAT))
            wf.setframerate(RATE)
            wf.writeframes(b''.join(frames))
            wf.close()

            # read contents of WAV file a

            yield wav_filename

    except KeyboardInterrupt:
        pass

    stream.stop_stream()
    stream.close()
    p.terminate()

In [44]:
labels

array(['go', 'happy', 'follow', 'marvin', 'stop', 'down'], dtype=object)

In [48]:
for wav_data in record_audio():
    # pass the WAV data to your keyword spotter here
    label = predict(wav_data, 16000)
    max = np.argmax(label)
    if max > 0.99:
        print(id2label[str(max)])
    else:
        print("not confident enough")
    

not confident enough
down
not confident enough
happy
happy
happy
not confident enough
not confident enough
happy
not confident enough
follow
happy
happy
happy
happy
not confident enough
happy
happy
happy
not confident enough
happy
happy
down
stop
not confident enough
not confident enough
stop
happy
not confident enough
down
not confident enough
not confident enough
happy
happy
happy
happy
happy
happy
happy
happy
not confident enough
happy
not confident enough
happy
happy
happy
happy
happy
happy
happy
not confident enough
happy
happy
not confident enough
not confident enough
not confident enough
not confident enough
not confident enough
not confident enough
not confident enough
not confident enough
not confident enough
not confident enough
happy
not confident enough
happy
not confident enough
happy
not confident enough
not confident enough
happy
not confident enough
not confident enough
not confident enough
not confident enough
happy
down
marvin
down
marvin
not confident enough
marvin
n