In [1]:
import os
import platform
from typing import Optional

import pytorch_lightning as pl
import soundfile as sf
import torch
import torchaudio
from datasets import Dataset as HFDataset, Audio
from torch.utils.data import DataLoader, Dataset
from torchmetrics.text import WordErrorRate, CharErrorRate
from transformers import WhisperProcessor, WhisperForConditionalGeneration


In [2]:
class Config:
    MODEL = "openai/whisper-small"
    SAMPLE_RATE = 16000
    BATCH_TRAIN = 8
    BATCH_VAL = 4
    EPOCHS = 6
    LR = 1e-5
    EXCLUDE_IDS = {
        'toronto_7', 'toronto_9', 'toronto_21', 'toronto_27', 'toronto_37', 'toronto_42', 'toronto_43',
        'toronto_46', 'toronto_54', 'toronto_58', 'toronto_62', 'toronto_67', 'toronto_81', 'toronto_89',
        'toronto_123', 'toronto_134', 'toronto_135', 'toronto_148', 'toronto_156', 'toronto_157', 'toronto_166'
    }
    WAV_DIR = r"C:\\Users\\Samurai\\dataset\\data"

In [3]:
# Utils
def parse_transcript_from_filename(filename: str) -> str:
    base = os.path.splitext(os.path.basename(filename))[0]
    parts = base.split("_")
    return f"Say the word {parts[1]}" if len(parts) > 1 else ""
def extract_transcript(file_name):
    parts = os.path.splitext(os.path.basename(file_name))[0].split("_")
    return f"Say the word {parts[1]}" if len(parts) > 1 else ""

def is_valid_audio(path):
    try:
        torchaudio.load(path)
        return True
    except Exception as e:
        print(f"⛔ torchaudio failed on {path}: {e}")
        return False

def load_data_entries(audio_dir, test_ids, sample_rate):
    data = []
    for root, _, files in os.walk(audio_dir):
        print(f"Found files in folder: {files}")
        for file in files:
            if file.endswith(".wav"):
                path = os.path.join(root, file)
                if os.path.exists(path) and is_valid_audio(path):
                    uid = os.path.splitext(file)[0]
                    data.append({"id": uid, "path": path, "text": extract_transcript(file)})
    train, val, test = [], [], []
    for i, d in enumerate(data):
        if d["id"] in test_ids:
            test.append(d)
        elif i % 10 == 0:
            val.append(d)
        else:
            train.append(d)
    return (
        HFDataset.from_list(train),
        HFDataset.from_list(val),
        HFDataset.from_list(test),
    )

In [4]:
def get_dataloader(hf_dataset, processor, batch_size, sample_rate, shuffle=False):
    def preprocess(batch):
        waveform, sr = torchaudio.load(batch["path"])
        if sr != sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, sample_rate)
        waveform = waveform.mean(0).numpy()
        inputs = processor.feature_extractor(waveform, sampling_rate=sample_rate, return_tensors="pt").input_features.squeeze(0)
        labels = processor.tokenizer(batch["text"], return_tensors="pt").input_ids.squeeze(0)
        return {"input_features": inputs, "labels": labels}

    class CustomDataset(Dataset):
        def __init__(self):
            self.data = hf_dataset
        def __len__(self):
            return len(self.data)
        def __getitem__(self, idx):
            return preprocess(self.data[idx])

    def collate_fn(batch):
        inputs = [b["input_features"] for b in batch]
        labels = [b["labels"] for b in batch]
        x = torch.nn.utils.rnn.pad_sequence(inputs, batch_first=True)
        y = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=processor.tokenizer.pad_token_id)
        y[y == processor.tokenizer.pad_token_id] = -100
        return {"input_features": x, "labels": y}

    return DataLoader(CustomDataset(), batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)

In [5]:
class WhisperTrainer(pl.LightningModule):
    def __init__(self, model, processor, lr):
        super().__init__()
        self.model = model
        self.processor = processor
        self.lr = lr
        self.metrics = torch.nn.ModuleDict({
            "wer": WordErrorRate(),
            "cer": CharErrorRate()
        })

    def forward(self, input_features, labels=None):
        return self.model(input_features=input_features, labels=labels)

    def training_step(self, batch, _):
        out = self(**batch)
        self.log("train_loss", out.loss, prog_bar=True)
        return out.loss

    def validation_step(self, batch, _):
        self.model.generate(batch["input_features"], max_new_tokens=100)
        ref = batch["labels"].clone()
        ref[ref == -100] = self.processor.tokenizer.pad_token_id

    def test_step(self, batch, _):
        self.model.generate(batch["input_features"], max_new_tokens=100)
        ref = batch["labels"].clone()
        ref[ref == -100] = self.processor.tokenizer.pad_token_id

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

In [None]:
def main():
    cfg = Config()
    torch.set_float32_matmul_precision("high")

    processor = WhisperProcessor.from_pretrained(cfg.MODEL)
    model = WhisperForConditionalGeneration.from_pretrained(cfg.MODEL)
    model.gradient_checkpointing_enable()

    train_set, val_set, test_set = load_data_entries(cfg.WAV_DIR, cfg.EXCLUDE_IDS, cfg.SAMPLE_RATE)
    print(f"Train: {len(train_set)}, Val: {len(val_set)}, Test: {len(test_set)}")
    print(cfg.WAV_DIR)
    print(os.path.exists(cfg.WAV_DIR))

    train_loader = get_dataloader(train_set, processor, cfg.BATCH_TRAIN, cfg.SAMPLE_RATE, shuffle=True)
    val_loader = get_dataloader(val_set, processor, cfg.BATCH_VAL, cfg.SAMPLE_RATE)
    test_loader = get_dataloader(test_set, processor, cfg.BATCH_VAL, cfg.SAMPLE_RATE)

    pl_model = WhisperTrainer(model=model, processor=processor, lr=cfg.LR)

    trainer = pl.Trainer(
        accelerator="gpu",
        devices=1,
        precision=32,
        max_epochs=cfg.EPOCHS,
        gradient_clip_val=1.0,
        log_every_n_steps=10
    )

    trainer.fit(pl_model, train_dataloaders=train_loader, val_dataloaders=val_loader)
    trainer.test(pl_model, dataloaders=test_loader)

if __name__ == "__main__":
    main()

Found files in folder: ['toronto_11_0.wav', 'toronto_11_1.wav', 'toronto_11_10.wav', 'toronto_11_100.wav', 'toronto_11_101.wav', 'toronto_11_102.wav', 'toronto_11_103.wav', 'toronto_11_104.wav', 'toronto_11_105.wav', 'toronto_11_106.wav', 'toronto_11_107.wav', 'toronto_11_108.wav', 'toronto_11_109.wav', 'toronto_11_11.wav', 'toronto_11_110.wav', 'toronto_11_111.wav', 'toronto_11_112.wav', 'toronto_11_113.wav', 'toronto_11_114.wav', 'toronto_11_115.wav', 'toronto_11_116.wav', 'toronto_11_117.wav', 'toronto_11_118.wav', 'toronto_11_119.wav', 'toronto_11_12.wav', 'toronto_11_120.wav', 'toronto_11_121.wav', 'toronto_11_122.wav', 'toronto_11_123.wav', 'toronto_11_124.wav', 'toronto_11_125.wav', 'toronto_11_126.wav', 'toronto_11_127.wav', 'toronto_11_128.wav', 'toronto_11_129.wav', 'toronto_11_13.wav', 'toronto_11_130.wav', 'toronto_11_131.wav', 'toronto_11_132.wav', 'toronto_11_133.wav', 'toronto_11_134.wav', 'toronto_11_135.wav', 'toronto_11_136.wav', 'toronto_11_137.wav', 'toronto_11_138.

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Train: 4959, Val: 552, Test: 0
C:\\Users\\Samurai\\dataset\\data
True



  | Name    | Type                            | Params | Mode 
--------------------------------------------------------------------
0 | model   | WhisperForConditionalGeneration | 241 M  | eval 
1 | metrics | ModuleDict                      | 0      | train
--------------------------------------------------------------------
240 M     Trainable params
1.2 M     Non-trainable params
241 M     Total params
966.940   Total estimated model params size (MB)
3         Modules in train mode
350       Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

C:\Users\Samurai\miniconda3\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English.This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
C:\Users\Samurai\miniconda3\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:42

Training: |          | 0/? [00:00<?, ?it/s]

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
