In [1]:
#!pip install evaluate
#!pip install peft==0.10.0

from transformers import WhisperFeatureExtractor, WhisperProcessor, WhisperTokenizer, WhisperForConditionalGeneration
import evaluate

import tqdm
import peft

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Sat Jan 18 19:25:12 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.94                 Driver Version: 560.94         CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 2070 ...  WDDM  |   00000000:01:00.0  On |                  N/A |
| 32%   51C    P0             49W /  215W |     760MiB /   8192MiB |      7%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [3]:
import json
import os
import re
import unicodedata

from datasets import Dataset, Audio
from datasets import Dataset, Audio, DatasetDict
from tqdm import tqdm

# Typographic apostrophes would otherwise be replaced by spaces and split words
# such as "person’s" into "person s"; map them to ASCII first.
_APOSTROPHE_MAP = str.maketrans({"\u2018": "'", "\u2019": "'", "\u02bc": "'"})


def clean_text(text):
    """Normalise a transcription to plain, single-spaced ASCII.

    Curly apostrophes become ``'``, accented letters are folded to their base
    letter (``café`` -> ``cafe``), any remaining run of non-ASCII characters
    becomes a single space, and whitespace is collapsed and trimmed.

    Args:
        text: Raw transcription string.

    Returns:
        The cleaned string (possibly empty).
    """
    # NFKD splits "é" into "e" + a combining accent (and expands compatibility
    # characters such as the "ﬁ" ligature); dropping the combining marks keeps the
    # base letter inside the word instead of turning it into a word break.
    text = unicodedata.normalize("NFKD", text.translate(_APOSTROPHE_MAP))
    text = "".join(ch for ch in text if not unicodedata.combining(ch))
    # Remove remaining non-ASCII characters and extra spaces
    text = re.sub(r'[^\x00-\x7F]+', ' ', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text

def create_manifest(audio_dir, text_dir, manifest_path, max_files=5000, min_words=7):
    """Pair ``.wav`` clips with same-named ``.txt`` transcripts and write a JSON manifest.

    Args:
        audio_dir: Directory containing the ``.wav`` clips.
        text_dir: Directory containing ``<clip stem>.txt`` transcripts.
        manifest_path: Output path for a JSON list of
            ``{"audio": <wav path>, "transcription": <cleaned text>}`` records.
        max_files: Cap on how many ``.wav`` files (in sorted name order) are
            considered, to keep the dataset small; ``None`` means no cap.
        min_words: Clips whose cleaned transcript has fewer words are skipped.

    Returns:
        The list of manifest records that was written.
    """
    # os.listdir order is arbitrary; sort so the same subset is picked on every run/OS.
    wav_names = sorted(name for name in os.listdir(audio_dir) if name.endswith(".wav"))
    if max_files is not None:
        wav_names = wav_names[:max_files]  # too much data, so use less

    manifest = []
    for file_name in tqdm(wav_names):
        transcription_path = os.path.join(text_dir, os.path.splitext(file_name)[0] + ".txt")
        if not os.path.exists(transcription_path):
            continue
        # Explicit UTF-8: the platform default (e.g. cp1252 on Windows) can raise
        # on, or garble, non-ASCII bytes before clean_text sees them.
        with open(transcription_path, "r", encoding="utf-8") as f:
            transcription = clean_text(f.read().strip())
        # Skip short transcripts (fewer than `min_words` words)
        if len(transcription.split()) < min_words:
            continue
        manifest.append({
            "audio": os.path.join(audio_dir, file_name),
            "transcription": transcription,
        })

    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump(manifest, f, indent=4)
    return manifest

def load_dataset(manifest_path, test_size=0.2, seed=42, sampling_rate=16000):
    """Load a JSON manifest into a train/validation ``DatasetDict``.

    Note: despite the name, this is unrelated to ``datasets.load_dataset``.

    Args:
        manifest_path: JSON file holding a list of ``{"audio", "transcription"}`` records.
        test_size: Fraction (float) or count (int) of examples held out for validation.
        seed: Seed for the shuffle performed by the split, so the split is
            reproducible across runs; ``None`` gives a different split each call.
        sampling_rate: Rate at which the ``audio`` column is decoded (resampled if needed).

    Returns:
        ``DatasetDict`` with ``"train"`` and ``"validation"`` splits.

    Raises:
        ValueError: If the manifest contains no records.
    """
    with open(manifest_path, "r", encoding="utf-8") as f:
        records = json.load(f)
    # Fail with a clear message instead of an obscure error from the split below
    if not records:
        raise ValueError(f"No examples found in manifest {manifest_path!r}")

    dataset = Dataset.from_list(records)
    dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))

    # Split the dataset into train and validation sets
    split = dataset.train_test_split(test_size=test_size, seed=seed)
    return DatasetDict({
        "train": split["train"],
        "validation": split["test"],
    })

# Paths to the paired clip/transcript folders and the manifest to (re)generate
train_audio_dir, train_text_dir = "data/split_audio", "data/split_text"
manifest_path = "train_manifest.json"

# Write the manifest, then read it back as train/validation splits
create_manifest(train_audio_dir, train_text_dir, manifest_path)
dataset = load_dataset(manifest_path, test_size=0.2)
print(dataset)

100%|██████████| 5000/5000 [00:01<00:00, 4306.03it/s]

DatasetDict({
    train: Dataset({
        features: ['audio', 'transcription'],
        num_rows: 2403
    })
    validation: Dataset({
        features: ['audio', 'transcription'],
        num_rows: 601
    })
})





In [4]:
# Inspect one decoded training example (audio array + transcription)
first_example = dataset["train"][0]
print(first_example)

{'audio': {'path': 'data/split_audio\\20190918_FOMC_333.wav', 'array': array([ 1.22070312e-04,  1.22070312e-04,  1.22070312e-04, ...,
       -3.05175781e-05,  0.00000000e+00,  3.96728516e-04]), 'sampling_rate': 16000}, 'transcription': "the appropriate path for monetary policy in that individual person's"}


In [5]:
from transformers import WhisperFeatureExtractor, WhisperProcessor, WhisperTokenizer, WhisperForConditionalGeneration

# Checkpoint and decoding setup shared by the components below
model_name = "openai/whisper-small"
language = "english"
whisper_text_kwargs = dict(language=language, task="transcribe")

feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name)
# NOTE(review): `tokenizer` and `processor.tokenizer` are two separately loaded
# objects; the collator uses the latter, compute_metrics the former — confirm
# they are configured identically.
tokenizer = WhisperTokenizer.from_pretrained(model_name, **whisper_text_kwargs)
processor = WhisperProcessor.from_pretrained(model_name, **whisper_text_kwargs)
model = WhisperForConditionalGeneration.from_pretrained(model_name)


def prepare_dataset(batch):
    """Turn one raw example into Whisper model inputs, in place.

    Adds ``input_features`` (the first item of the feature extractor's
    ``input_features`` output for the audio array) and ``labels`` (the
    tokenizer's ``input_ids`` for the transcription) to ``batch`` and returns it.
    """
    waveform = batch["audio"]
    extracted = feature_extractor(waveform["array"], sampling_rate=waveform["sampling_rate"])
    batch["input_features"] = extracted.input_features[0]

    batch["labels"] = tokenizer(batch["transcription"]).input_ids
    return batch



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
# Replace raw audio/transcription columns with model inputs. The guard keeps the
# cell idempotent: on a second run the "audio" column is already gone and
# prepare_dataset would fail looking it up.
if "audio" in dataset.column_names["train"]:
    dataset = dataset.map(prepare_dataset, remove_columns=dataset.column_names["train"])
dataset["train"], dataset["validation"]

Map: 100%|██████████| 2403/2403 [00:28<00:00, 84.48 examples/s] 
Map: 100%|██████████| 601/601 [00:06<00:00, 92.06 examples/s] 


(Dataset({
     features: ['input_features', 'labels'],
     num_rows: 2403
 }),
 Dataset({
     features: ['input_features', 'labels'],
     num_rows: 601
 }))

In [7]:
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate Whisper examples into one padded batch.

    Audio features are padded with the processor's feature extractor; label ids
    are padded with its tokenizer, and padded label positions are set to -100 so
    the loss ignores them. A leading decoder-start token is removed from the
    labels when every row has it, as it is added again later.
    """

    processor: Any  # provides .feature_extractor.pad and .tokenizer.pad
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Inputs and labels have different lengths and need different padding,
        # so split each example into its two halves first.
        audio_inputs = []
        label_inputs = []
        for example in features:
            audio_inputs.append({"input_features": example["input_features"]})
            label_inputs.append({"input_ids": example["labels"]})

        batch = self.processor.feature_extractor.pad(audio_inputs, return_tensors="pt")
        padded_labels = self.processor.tokenizer.pad(label_inputs, return_tensors="pt")

        # -100 is ignored by the loss, so padding does not contribute to training
        is_padding = padded_labels.attention_mask.ne(1)
        labels = padded_labels["input_ids"].masked_fill(is_padding, -100)

        starts_with_bos = (labels[:, 0] == self.decoder_start_token_id).all().cpu().item()
        if starts_with_bos:
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

In [8]:
# The collator strips a leading decoder-start token from labels when present
decoder_start_id = model.config.decoder_start_token_id
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor, decoder_start_id)

In [9]:
import evaluate

# Word error rate (WER) metric loaded through the `evaluate` library
metric = evaluate.load(path="wer")

In [10]:
def compute_metrics(pred):
    """Compute word error rate (as a percentage) for a generation-based evaluation.

    Args:
        pred: Trainer prediction object exposing ``predictions`` (generated token
            ids) and ``label_ids`` (reference token ids, ``-100`` where padded).

    Returns:
        ``{"wer": <WER * 100>}``.
    """
    # Work on copies: these arrays belong to the Trainer's prediction object and
    # should not be modified as a side effect of computing a metric.
    pred_ids = pred.predictions.copy()
    label_ids = pred.label_ids.copy()

    # -100 marks ignored positions. It pads the labels, and the Trainer may also use
    # it to pad generated predictions of different lengths across eval batches.
    # It is not a valid token id, so decode it as padding in both arrays.
    pad_id = tokenizer.pad_token_id
    pred_ids[pred_ids == -100] = pad_id
    label_ids[label_ids == -100] = pad_id

    # skip_special_tokens drops the padding and other special tokens before scoring
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}

In [11]:
# Drive decoding through generation_config (language + task) instead of the
# checkpoint's forced decoder ids, and do not suppress any extra tokens.
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
model.generation_config.forced_decoder_ids = None
model.generation_config.language = language
model.generation_config.task = "transcribe"

In [12]:
from peft import LoraConfig, PeftModel, LoraModel, LoraConfig, get_peft_model

# LoRA adapters on the attention query/value projections
# NOTE(review): lora_dropout=0.5 is unusually high for LoRA (PEFT's Whisper
# example uses 0.05) — confirm this is intended and not a typo.
config = LoraConfig(r=32, lora_alpha=64, target_modules=["q_proj", "v_proj"], lora_dropout=0.5, bias="none")

# Wrap only once: re-running this cell would otherwise pass an already-wrapped
# PeftModel back into get_peft_model.
if not isinstance(model, PeftModel):
    model = get_peft_model(model, config)

model.print_trainable_parameters()

trainable params: 3,538,944 || all params: 245,273,856 || trainable%: 1.442854145857274


We train only ~1.4% of the model's parameters (the LoRA adapters); the base Whisper weights stay frozen.

In [13]:
from transformers import Seq2SeqTrainingArguments, EarlyStoppingCallback
#!pip install accelerate==0.28.0

training_args = Seq2SeqTrainingArguments(
    output_dir="output_dir",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    learning_rate=2e-5,
    gradient_checkpointing=False,
    fp16=True,
    evaluation_strategy="epoch",
    # Checkpoint on the same schedule as evaluation: load_best_model_at_end requires
    # matching strategies. (The previous save_steps=0.2 of 1505 steps was also one
    # save per 301-step epoch.)
    save_strategy="epoch",
    # NOTE(review): eval took ~1h+ per epoch in the logged run; a larger eval batch
    # may help if GPU memory allows.
    per_device_eval_batch_size=2,
    predict_with_generate=True,
    generation_max_length=128,
    num_train_epochs=5,
    logging_steps=25,
    report_to=["tensorboard"],
    # Track eval loss to select the best checkpoint and reload it when training
    # ends; EarlyStoppingCallback also requires both settings.
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    greater_is_better=False,
    push_to_hub=True,
    remove_unused_columns=False,  # keep input_features/labels for the custom collator
    label_names=["labels"]
)

In [14]:
from transformers import Seq2SeqTrainer

# Early stopping needs the best checkpoint to be tracked and reloaded
# (EarlyStoppingCallback requires load_best_model_at_end), so only attach it then.
callbacks = (
    [EarlyStoppingCallback(early_stopping_patience=2)]
    if training_args.load_best_model_at_end
    else None
)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,  # saved alongside each checkpoint
    callbacks=callbacks,
)


In [15]:
train_output = trainer.train()
train_output

  2%|▏         | 25/1505 [17:40<15:34:27, 37.88s/it]

{'loss': 0.9117, 'grad_norm': 1.068368673324585, 'learning_rate': 1.966777408637874e-05, 'epoch': 0.08}


  3%|▎         | 50/1505 [30:43<12:20:26, 30.53s/it]

{'loss': 0.8347, 'grad_norm': 1.4269933700561523, 'learning_rate': 1.933554817275748e-05, 'epoch': 0.17}


  5%|▍         | 75/1505 [43:46<12:26:10, 31.31s/it]

{'loss': 0.7224, 'grad_norm': 0.924464225769043, 'learning_rate': 1.9003322259136213e-05, 'epoch': 0.25}


  7%|▋         | 100/1505 [57:50<12:37:39, 32.36s/it]

{'loss': 0.6855, 'grad_norm': 1.2925922870635986, 'learning_rate': 1.867109634551495e-05, 'epoch': 0.33}


  8%|▊         | 125/1505 [1:08:55<8:04:33, 21.07s/it] 

{'loss': 0.567, 'grad_norm': 1.1075271368026733, 'learning_rate': 1.833887043189369e-05, 'epoch': 0.42}


 10%|▉         | 150/1505 [1:18:17<8:46:46, 23.33s/it]

{'loss': 0.446, 'grad_norm': 1.0153590440750122, 'learning_rate': 1.8006644518272428e-05, 'epoch': 0.5}


 12%|█▏        | 175/1505 [1:26:59<7:41:59, 20.84s/it]

{'loss': 0.3379, 'grad_norm': 1.3115565776824951, 'learning_rate': 1.7674418604651163e-05, 'epoch': 0.58}


 13%|█▎        | 200/1505 [1:35:45<8:07:15, 22.40s/it]

{'loss': 0.2471, 'grad_norm': 1.0001295804977417, 'learning_rate': 1.73421926910299e-05, 'epoch': 0.66}


 15%|█▍        | 225/1505 [1:44:38<7:45:42, 21.83s/it]

{'loss': 0.1694, 'grad_norm': 0.8210864663124084, 'learning_rate': 1.700996677740864e-05, 'epoch': 0.75}


 17%|█▋        | 250/1505 [1:53:25<7:05:12, 20.33s/it]

{'loss': 0.2246, 'grad_norm': 0.5600518584251404, 'learning_rate': 1.6677740863787378e-05, 'epoch': 0.83}


 18%|█▊        | 275/1505 [2:06:59<13:59:00, 40.93s/it]

{'loss': 0.1978, 'grad_norm': 1.1114120483398438, 'learning_rate': 1.6345514950166113e-05, 'epoch': 0.91}


 20%|█▉        | 300/1505 [2:22:04<11:43:09, 35.01s/it]

{'loss': 0.2146, 'grad_norm': 1.0748969316482544, 'learning_rate': 1.601328903654485e-05, 'epoch': 1.0}


                                                       
 20%|██        | 301/1505 [3:20:00<10:02:22, 30.02s/it]

{'eval_loss': 0.17721404135227203, 'eval_wer': 6.862084456424079, 'eval_runtime': 3457.2495, 'eval_samples_per_second': 0.174, 'eval_steps_per_second': 0.087, 'epoch': 1.0}


 22%|██▏       | 325/1505 [3:34:17<11:44:55, 35.84s/it]   

{'loss': 0.2502, 'grad_norm': 0.4973965883255005, 'learning_rate': 1.568106312292359e-05, 'epoch': 1.08}


 23%|██▎       | 350/1505 [3:49:07<11:21:43, 35.41s/it]

{'loss': 0.2332, 'grad_norm': 1.4782487154006958, 'learning_rate': 1.5348837209302328e-05, 'epoch': 1.16}


 25%|██▍       | 375/1505 [4:02:59<10:16:37, 32.74s/it]

{'loss': 0.1852, 'grad_norm': 0.43938878178596497, 'learning_rate': 1.5016611295681065e-05, 'epoch': 1.25}


 27%|██▋       | 400/1505 [4:18:03<11:01:36, 35.92s/it]

{'loss': 0.1738, 'grad_norm': 0.9841451048851013, 'learning_rate': 1.4684385382059803e-05, 'epoch': 1.33}


 28%|██▊       | 425/1505 [4:31:59<9:55:06, 33.06s/it] 

{'loss': 0.168, 'grad_norm': 0.5903456807136536, 'learning_rate': 1.435215946843854e-05, 'epoch': 1.41}


 30%|██▉       | 450/1505 [4:48:00<11:00:16, 37.55s/it]

{'loss': 0.1836, 'grad_norm': 0.5778101086616516, 'learning_rate': 1.4019933554817278e-05, 'epoch': 1.5}


 32%|███▏      | 475/1505 [5:03:43<10:34:07, 36.94s/it]

{'loss': 0.1692, 'grad_norm': 0.417302668094635, 'learning_rate': 1.3687707641196015e-05, 'epoch': 1.58}


 33%|███▎      | 500/1505 [5:18:21<10:02:37, 35.98s/it]

{'loss': 0.181, 'grad_norm': 0.751311719417572, 'learning_rate': 1.3355481727574753e-05, 'epoch': 1.66}


 35%|███▍      | 525/1505 [5:30:02<8:08:21, 29.90s/it] 

{'loss': 0.1556, 'grad_norm': 1.1488633155822754, 'learning_rate': 1.302325581395349e-05, 'epoch': 1.74}


 37%|███▋      | 550/1505 [5:50:44<12:54:11, 48.64s/it]

{'loss': 0.1903, 'grad_norm': 1.1881126165390015, 'learning_rate': 1.2691029900332228e-05, 'epoch': 1.83}


 38%|███▊      | 575/1505 [6:13:00<13:45:18, 53.25s/it]

{'loss': 0.176, 'grad_norm': 1.3428459167480469, 'learning_rate': 1.2358803986710965e-05, 'epoch': 1.91}


 40%|███▉      | 600/1505 [6:34:44<12:52:51, 51.24s/it]

{'loss': 0.2118, 'grad_norm': 1.015849232673645, 'learning_rate': 1.2026578073089703e-05, 'epoch': 1.99}


                                                       
 40%|████      | 602/1505 [7:57:48<11:07:56, 44.38s/it]

{'eval_loss': 0.16012196242809296, 'eval_wer': 6.435309973045822, 'eval_runtime': 4902.3784, 'eval_samples_per_second': 0.123, 'eval_steps_per_second': 0.061, 'epoch': 2.0}


 42%|████▏     | 625/1505 [8:17:54<12:35:37, 51.52s/it]   

{'loss': 0.1681, 'grad_norm': 0.3739056885242462, 'learning_rate': 1.1694352159468441e-05, 'epoch': 2.08}


 43%|████▎     | 650/1505 [8:40:07<12:39:41, 53.31s/it]

{'loss': 0.1688, 'grad_norm': 1.5053025484085083, 'learning_rate': 1.1362126245847176e-05, 'epoch': 2.16}


 45%|████▍     | 675/1505 [9:01:57<12:06:59, 52.55s/it]

{'loss': 0.1731, 'grad_norm': 1.1402802467346191, 'learning_rate': 1.1043189368770765e-05, 'epoch': 2.24}


 47%|████▋     | 700/1505 [9:24:30<12:11:21, 54.51s/it]

{'loss': 0.1973, 'grad_norm': 1.325312614440918, 'learning_rate': 1.0710963455149504e-05, 'epoch': 2.33}


 48%|████▊     | 725/1505 [9:46:28<11:10:33, 51.58s/it]

{'loss': 0.177, 'grad_norm': 0.7787802219390869, 'learning_rate': 1.037873754152824e-05, 'epoch': 2.41}


 50%|████▉     | 750/1505 [10:08:49<11:12:47, 53.47s/it]

{'loss': 0.1462, 'grad_norm': 0.9179911017417908, 'learning_rate': 1.0046511627906979e-05, 'epoch': 2.49}


 51%|█████▏    | 775/1505 [10:31:12<10:46:21, 53.12s/it]

{'loss': 0.1826, 'grad_norm': 1.0206655263900757, 'learning_rate': 9.714285714285715e-06, 'epoch': 2.57}


 53%|█████▎    | 800/1505 [10:53:29<10:22:01, 52.94s/it]

{'loss': 0.1133, 'grad_norm': 0.8461409211158752, 'learning_rate': 9.382059800664452e-06, 'epoch': 2.66}


 55%|█████▍    | 825/1505 [11:15:16<9:50:57, 52.14s/it] 

{'loss': 0.18, 'grad_norm': 1.217873454093933, 'learning_rate': 9.04983388704319e-06, 'epoch': 2.74}


 56%|█████▋    | 850/1505 [11:37:41<9:41:27, 53.26s/it] 

{'loss': 0.1542, 'grad_norm': 1.7475782632827759, 'learning_rate': 8.717607973421928e-06, 'epoch': 2.82}


 58%|█████▊    | 875/1505 [11:58:22<8:35:09, 49.06s/it]

{'loss': 0.2162, 'grad_norm': 1.30645751953125, 'learning_rate': 8.385382059800665e-06, 'epoch': 2.91}


 60%|█████▉    | 900/1505 [12:19:31<8:26:27, 50.23s/it]

{'loss': 0.1279, 'grad_norm': 0.38487187027931213, 'learning_rate': 8.053156146179403e-06, 'epoch': 2.99}


                                                       
 60%|██████    | 903/1505 [13:49:24<7:20:27, 43.90s/it]

{'eval_loss': 0.15477199852466583, 'eval_wer': 6.210691823899371, 'eval_runtime': 5255.6277, 'eval_samples_per_second': 0.114, 'eval_steps_per_second': 0.057, 'epoch': 3.0}


 61%|██████▏   | 923/1505 [14:04:53<7:45:28, 47.99s/it]    

KeyboardInterrupt: 

In [None]:
# Save the adapter weights and the processor side by side for later reloading
run_dir = "output_dir/run4"
model.save_pretrained(run_dir)
processor.save_pretrained(run_dir)

[]

In [28]:
# Hugging Face token for uploading to the Hub, read from the environment so no
# credential is stored in the notebook. Empty string when HF_TOKEN is unset.
token = os.environ.get("HF_TOKEN", "")

In [None]:
# `token or None`: an empty token falls back to any cached Hugging Face login
# instead of sending an empty credential.
# NOTE(review): the saved output below shows a push to ashkab/output_dir, not
# output_dir2 — it is stale; re-run to confirm the target repo.
model.push_to_hub("ashkab/output_dir2", token=token or None)

adapter_model.safetensors: 100%|██████████| 14.2M/14.2M [00:02<00:00, 5.12MB/s]


CommitInfo(commit_url='https://huggingface.co/ashkab/output_dir/commit/cc0f31fc429747a7523674c6958a07035a161558', commit_message='Upload model', commit_description='', oid='cc0f31fc429747a7523674c6958a07035a161558', pr_url=None, repo_url=RepoUrl('https://huggingface.co/ashkab/output_dir', endpoint='https://huggingface.co', repo_type='model', repo_id='ashkab/output_dir'), pr_revision=None, pr_num=None)