In [2]:
import pandas as pd
import torch
import torchaudio
from pathlib import Path
import transformers
device = torch.device("cuda")

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from transformers import AutoFeatureExtractor, WhisperForAudioClassification
import os

token = os.getenv('HF_TOKEN')
model_id = "openai/whisper-tiny"

feature_extractor = AutoFeatureExtractor.from_pretrained(model_id, token=token)
model = WhisperForAudioClassification.from_pretrained(model_id, token=token, num_labels=7)

feature_extractor
model.to(device)

Some weights of WhisperForAudioClassification were not initialized from the model checkpoint at openai/whisper-tiny and are newly initialized: ['model.classifier.bias', 'model.classifier.weight', 'model.projector.bias', 'model.projector.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


WhisperForAudioClassification(
  (encoder): WhisperEncoder(
    (conv1): Conv1d(80, 384, kernel_size=(3,), stride=(1,), padding=(1,))
    (conv2): Conv1d(384, 384, kernel_size=(3,), stride=(2,), padding=(1,))
    (embed_positions): Embedding(1500, 384)
    (layers): ModuleList(
      (0-3): 4 x WhisperEncoderLayer(
        (self_attn): WhisperSdpaAttention(
          (k_proj): Linear(in_features=384, out_features=384, bias=False)
          (v_proj): Linear(in_features=384, out_features=384, bias=True)
          (q_proj): Linear(in_features=384, out_features=384, bias=True)
          (out_proj): Linear(in_features=384, out_features=384, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (activation_fn): GELUActivation()
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (final_layer_norm): LayerNorm((384,), eps=1e-05, elementwise_

In [5]:
data_path = ""
df = pd.read_csv("data/ground_truth.csv")

In [6]:
import os
from functools import lru_cache
from subprocess import CalledProcessError, run
from typing import Optional, Union

import numpy as np
import torch
import torch.nn.functional as F

def exact_div(x, y):
    assert x % y == 0
    return x // y


class AudioUtil():
  # hard-coded audio hyperparameters
  SAMPLE_RATE = 16000
  N_FFT = 400
  HOP_LENGTH = 160
  CHUNK_LENGTH = 30
  N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
  N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000 frames in a mel spectrogram input

  N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions has stride 2
  FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
  TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token


  def load_audio(file: str, sr: int = SAMPLE_RATE):
      """
      Open an audio file and read as mono waveform, resampling as necessary

      Parameters
      ----------
      file: str
          The audio file to open

      sr: int
          The sample rate to resample the audio if necessary

      Returns
      -------
      A NumPy array containing the audio waveform, in float32 dtype.
      """

      # This launches a subprocess to decode audio while down-mixing
      # and resampling as necessary.  Requires the ffmpeg CLI in PATH.
      # fmt: off
      cmd = [
          "ffmpeg",
          "-nostdin",
          "-threads", "0",
          "-i", file,
          "-f", "s16le",
          "-ac", "1",
          "-acodec", "pcm_s16le",
          "-ar", str(sr),
          "-"
      ]
      # fmt: on
      try:
          out = run(cmd, capture_output=True, check=True).stdout
      except CalledProcessError as e:
          raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

      return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0


  def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
      """
      Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
      """
      if torch.is_tensor(array):
          if array.shape[axis] > length:
              array = array.index_select(
                  dim=axis, index=torch.arange(length, device=array.device)
              )

          if array.shape[axis] < length:
              pad_widths = [(0, 0)] * array.ndim
              pad_widths[axis] = (0, length - array.shape[axis])
              array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
      else:
          if array.shape[axis] > length:
              array = array.take(indices=range(length), axis=axis)

          if array.shape[axis] < length:
              pad_widths = [(0, 0)] * array.ndim
              pad_widths[axis] = (0, length - array.shape[axis])
              array = np.pad(array, pad_widths)

      return array
  
  @lru_cache(maxsize=None)
  def mel_filters(device, n_mels: int) -> torch.Tensor:
        """
        load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
        Allows decoupling librosa dependency; saved using:

            np.savez_compressed(
                "mel_filters.npz",
                mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
                mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
            )
        """
        assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"

        with np.load("data/mel_filters.npz", allow_pickle=False) as f:
            return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)

  def log_mel_spectrogram(
        audio: Union[str, np.ndarray, torch.Tensor],
        n_mels: int = 80,
        padding: int = 0,
        device: Optional[Union[str, torch.device]] = None,
    ):
        if not torch.is_tensor(audio):
            if isinstance(audio, str):
                audio = AudioUtil.load_audio(audio)
            audio = torch.from_numpy(audio)

        if device is not None:
            audio = audio.to(device)
        if padding > 0:
            audio = F.pad(audio, (0, padding))
        window = torch.hann_window(AudioUtil.N_FFT).to(audio.device)
        stft = torch.stft(audio, AudioUtil.N_FFT, AudioUtil.HOP_LENGTH, window=window, return_complex=True)
        magnitudes = stft[..., :-1].abs() ** 2

        filters = AudioUtil.mel_filters(audio.device, n_mels)
        mel_spec = filters @ magnitudes

        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        return log_spec

In [7]:
from torch.utils.data import DataLoader, Dataset
import torchaudio

# ----------------------------
# Sound Dataset
# ----------------------------
class SoundDS(Dataset):
  def __init__(self, df, data_path):
    self.df = df
    self.data_path = str(data_path)
            
  # ----------------------------
  # Number of items in dataset
  # ----------------------------
  def __len__(self):
    return len(self.df)    
    
  # ----------------------------
  # Get i'th item in dataset
  # ----------------------------
  def __getitem__(self, idx):
    # Absolute file path of the audio file - concatenate the audio directory with
    # the relative path
    audio_file = self.data_path + self.df.loc[idx, 'relative_path']
    # Get the Class ID
    class_id = self.df.loc[idx, 'classID']

    audio = AudioUtil.load_audio(audio_file)
    audio = AudioUtil.pad_or_trim(audio)

    # make log-Mel spectrogram and move to the same device as the model
    mel = AudioUtil.log_mel_spectrogram(audio).to(model.device)

    return {
        "input_features": mel,  # Adjust this key according to your model's input name
        "labels": torch.tensor(class_id)
    }

In [8]:
from torch.utils.data import random_split

myds = SoundDS(df, data_path)

# Split the dataset into training, validation, and test sets
num_items = len(myds)
num_train = round(num_items * 0.7)
num_val = round(num_items * 0.15)
num_test = num_items - num_train - num_val
train_ds, val_ds, test_ds = random_split(myds, [num_train, num_val, num_test])


In [9]:
from transformers import TrainingArguments, Trainer

batch_size = 32

args = TrainingArguments(
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=5,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    output_dir="./"
)



In [10]:
import numpy as np
from datasets import load_metric
metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    """Computes accuracy on a batch of predictions"""
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return metric.compute(predictions=predictions, references=eval_pred.label_ids)

  metric = load_metric("accuracy")
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
Downloading builder script: 4.21kB [00:00, 8.50MB/s]                   


In [11]:
trainer = Trainer(
    model,
    args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics
)

In [12]:
trainer.train()

  attn_output = torch.nn.functional.scaled_dot_product_attention(
  8%|▊         | 10/130 [01:55<22:32, 11.27s/it]

{'loss': 1.8446, 'grad_norm': 4.654596328735352, 'learning_rate': 2.307692307692308e-05, 'epoch': 0.38}


 15%|█▌        | 20/130 [03:47<20:37, 11.25s/it]

{'loss': 1.4343, 'grad_norm': 3.2920188903808594, 'learning_rate': 2.8205128205128207e-05, 'epoch': 0.77}


                                                
Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50358, 50359, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}


{'eval_loss': 1.1777087450027466, 'eval_accuracy': 0.653954802259887, 'eval_runtime': 55.1039, 'eval_samples_per_second': 12.848, 'eval_steps_per_second': 0.417, 'epoch': 1.0}


 23%|██▎       | 30/130 [06:27<26:22, 15.83s/it]

{'loss': 1.1554, 'grad_norm': 1.2386064529418945, 'learning_rate': 2.564102564102564e-05, 'epoch': 1.15}


 31%|███       | 40/130 [08:06<14:35,  9.72s/it]

{'loss': 1.0806, 'grad_norm': 2.7539684772491455, 'learning_rate': 2.307692307692308e-05, 'epoch': 1.54}


 38%|███▊      | 50/130 [09:45<13:12,  9.91s/it]

{'loss': 1.0777, 'grad_norm': 4.981454849243164, 'learning_rate': 2.0512820512820515e-05, 'epoch': 1.92}


                                                
Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50358, 50359, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}


{'eval_loss': 1.0701020956039429, 'eval_accuracy': 0.6581920903954802, 'eval_runtime': 46.5453, 'eval_samples_per_second': 15.211, 'eval_steps_per_second': 0.494, 'epoch': 2.0}


 46%|████▌     | 60/130 [12:09<12:55, 11.09s/it]

{'loss': 1.035, 'grad_norm': 3.3870556354522705, 'learning_rate': 1.7948717948717948e-05, 'epoch': 2.31}


 54%|█████▍    | 70/130 [13:47<09:54,  9.92s/it]

{'loss': 1.0112, 'grad_norm': 6.031352519989014, 'learning_rate': 1.5384615384615384e-05, 'epoch': 2.69}


                                                
Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50358, 50359, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}


{'eval_loss': 1.033631443977356, 'eval_accuracy': 0.672316384180791, 'eval_runtime': 47.0122, 'eval_samples_per_second': 15.06, 'eval_steps_per_second': 0.489, 'epoch': 3.0}


 62%|██████▏   | 80/130 [16:10<16:13, 19.47s/it]

{'loss': 0.9697, 'grad_norm': 9.146903991699219, 'learning_rate': 1.282051282051282e-05, 'epoch': 3.08}


 69%|██████▉   | 90/130 [17:50<06:53, 10.33s/it]

{'loss': 0.9158, 'grad_norm': 7.488309860229492, 'learning_rate': 1.0256410256410258e-05, 'epoch': 3.46}


 77%|███████▋  | 100/130 [19:33<05:10, 10.35s/it]

{'loss': 0.9247, 'grad_norm': 4.136931419372559, 'learning_rate': 7.692307692307692e-06, 'epoch': 3.85}


                                                 
Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50358, 50359, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}


{'eval_loss': 1.007013201713562, 'eval_accuracy': 0.6638418079096046, 'eval_runtime': 48.1773, 'eval_samples_per_second': 14.696, 'eval_steps_per_second': 0.477, 'epoch': 4.0}


 85%|████████▍ | 110/130 [21:46<04:02, 12.11s/it]

{'loss': 0.8603, 'grad_norm': 4.615314483642578, 'learning_rate': 5.128205128205129e-06, 'epoch': 4.23}


 92%|█████████▏| 120/130 [23:36<01:49, 10.97s/it]

{'loss': 0.8675, 'grad_norm': 4.85215425491333, 'learning_rate': 2.5641025641025644e-06, 'epoch': 4.62}


100%|██████████| 130/130 [25:15<00:00,  8.39s/it]

{'loss': 0.8817, 'grad_norm': 6.624349594116211, 'learning_rate': 0.0, 'epoch': 5.0}


                                                 
Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50358, 50359, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}
100%|██████████| 130/130 [25:53<00:00, 11.95s/it]

{'eval_loss': 0.9969074726104736, 'eval_accuracy': 0.6709039548022598, 'eval_runtime': 38.4079, 'eval_samples_per_second': 18.434, 'eval_steps_per_second': 0.599, 'epoch': 5.0}
{'train_runtime': 1553.5563, 'train_samples_per_second': 10.637, 'train_steps_per_second': 0.084, 'train_loss': 1.0814221895658054, 'epoch': 5.0}





TrainOutput(global_step=130, training_loss=1.0814221895658054, metrics={'train_runtime': 1553.5563, 'train_samples_per_second': 10.637, 'train_steps_per_second': 0.084, 'total_flos': 1.84008352428e+17, 'train_loss': 1.0814221895658054, 'epoch': 5.0})

In [13]:
trainer.evaluate()

100%|██████████| 23/23 [00:40<00:00,  1.75s/it]


{'eval_loss': 1.033631443977356,
 'eval_accuracy': 0.672316384180791,
 'eval_runtime': 43.2531,
 'eval_samples_per_second': 16.369,
 'eval_steps_per_second': 0.532,
 'epoch': 5.0}