In [None]:
%%capture

!add-apt-repository -y ppa:jonathonf/ffmpeg-4
!apt update
!apt install -y ffmpeg

!pip uninstall -y transformers datasets
!pip install audiomentations
!pip install git+https://github.com/huggingface/datasets
!pip install git+https://github.com/huggingface/transformers
!pip install librosa soundfile
!pip install evaluate>=0.3.0
!pip install jiwer
!pip install gradio
!pip install more-itertools
!pip install wandb
!pip install bitsandbytes
!pip install accelerate -U
##more

In [None]:
# Environment variables for the Weights & Biases integration (all prefixed WANDB_).
# NOTE(review): presumably WANDB_LOG_MODEL uploads model checkpoints and
# WANDB_WATCH=all logs gradients and parameters — confirm against the wandb docs.
%set_env WANDB_LOG_MODEL=True
%set_env WANDB_WATCH=all
%set_env WANDB_NOTEBOOK_NAME=whisper_base_en_sb.ipynb

In [None]:
from datasets import Dataset, IterableDatasetDict, load_dataset, interleave_datasets, Audio
import evaluate

import torch
import string
from dataclasses import dataclass
from typing import Any, Dict, List, Union

from transformers import WhisperForConditionalGeneration
from transformers import WhisperProcessor
from transformers import Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainer
from transformers import WhisperTokenizer
from transformers import WhisperFeatureExtractor


import wandb
from IPython.display import clear_output
from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift, Shift
import numpy as np
from huggingface_hub import notebook_login
from transformers import TrainerCallback
from transformers.integrations import WandbCallback
from transformers.trainer_pt_utils import IterableDatasetShard
from torch.utils.data import IterableDataset
from datasets import load_dataset, Audio
from pathlib import Path
import numpy as np
import holoviews as hv
import panel as pn
import tempfile
from bokeh.resources import INLINE
hv.extension("bokeh", logo=False)

from io import StringIO
import pandas as pd
import warnings
import jiwer


# Silence every Python warning for the rest of the session.
warnings.filterwarnings('ignore')

# Clear whatever this cell has printed so far.
clear_output()
# Last expression of the cell: shows whether torch can see a CUDA device.
torch.cuda.is_available()

In [None]:
# Authenticate this session with Weights & Biases.
wandb.login()

In [None]:
from huggingface_hub import notebook_login

# Authenticate with the Hugging Face Hub; the training arguments below set
# push_to_hub=True and the final cell calls tokenizer.push_to_hub.
notebook_login()

In [None]:
# Fetch the Sunbird `leb` repository and install its requirements.
!git clone https://github.com/sunbirdai/leb.git
!pip install -r leb/requirements.txt
# NOTE(review): this pins datasets to 2.16.1 after the first cell installed
# datasets from GitHub and after `datasets` was already imported above; a kernel
# restart is presumably needed for the pinned version to take effect — confirm.
!pip install datasets==2.16.1
!pip install mlflow
# wandb was already installed by the first cell.
!pip install wandb

In [None]:
import yaml
import leb.dataset
from leb.utils import DataCollatorCTCWithPadding as dcwp
from datasets import Audio
from datasets import load_dataset, DatasetDict

In [None]:
# languages currently available in SALT multispeaker STT dataset
# NOTE(review): `languages` is not referenced anywhere else in the visible notebook.
languages = {
    "english": "eng"
}

# Dataset configuration as YAML text. The `common_source` / `common_target`
# anchors are reused by the `train` and `validation` sections through aliases.
yaml_config = '''
common_source: &common_source
  type: speech
  language: [eng]
  preprocessing:
    - set_sample_rate:
        rate: 16_000

common_target: &common_target
  type: text
  language: [eng]
  preprocessing:
    - lower_case
    - clean_and_remove_punctuation

Wav2Vec2ForCTC_args:
    attention_dropout: 0.0
    hidden_dropout: 0.0
    feat_proj_dropout: 0.0
    layerdrop: 0.0
    ctc_loss_reduction: mean
    ignore_mismatched_sizes: True

train:
    huggingface_load:
        # - path: mozilla-foundation/common_voice_13_0
        #   split: train
        #   name: lg
        #   trust_remote_code: True
        - path: Sunbird/salt
          name: multispeaker-eng
          split: train
    source: *common_source

    target: *common_target
    shuffle: True
validation:
    huggingface_load:
        - path: Sunbird/salt
          name: multispeaker-eng
          split: dev
    source: *common_source
    target: *common_target

'''

# Parse the YAML and build the train / validation datasets from their sections.
# NOTE(review): a later cell reassigns `config`, `train_ds` and `valid_ds`, so
# these values only matter if that cell is skipped.
config = yaml.safe_load(yaml_config)
train_ds = leb.dataset.create(config['train'])
valid_ds = leb.dataset.create(config['validation'])

In [None]:
# Language code interpolated into the YAML configuration built in the next cell.
language = 'eng'

In [None]:
# Full experiment configuration rendered as YAML; the `language` variable is
# substituted into the text by the f-string.
# NOTE(review): with language == 'eng' the validation section lists the
# multispeaker-eng dev split twice and the language lists render as [eng,eng] —
# confirm whether that duplication is intended.
# NOTE(review): only `pretrained_model`, `train` and `validation` are read by
# code visible in this notebook; `training_args`, `Wav2Vec2ForCTC_args`,
# `pretrained_adapter` and the mlflow_* keys are not.
yaml_config = f'''
pretrained_model: openai/whisper-base
pretrained_adapter: {language}
mlflow_experiment_name : stt-whisper-{language}
mlflow_run_name: {language}_from_pretrained

training_args:
    output_dir: stt
    per_device_train_batch_size: 24
    gradient_accumulation_steps: 2
    evaluation_strategy: steps
    max_steps: 1200
    gradient_checkpointing: True
    gradient_checkpointing_kwargs:
      use_reentrant: True
    fp16: True
    save_steps: 100
    eval_steps: 100
    logging_steps: 100
    learning_rate: 3.0e-4
    warmup_steps: 100
    save_total_limit: 2
    # push_to_hub: True
    load_best_model_at_end: True
    metric_for_best_model: loss
    greater_is_better: False
    weight_decay: 0.01

Wav2Vec2ForCTC_args:
    attention_dropout: 0.0
    hidden_dropout: 0.0
    feat_proj_dropout: 0.0
    layerdrop: 0.0
    ctc_loss_reduction: mean
    ignore_mismatched_sizes: True

train:
    huggingface_load:
        - path: Sunbird/salt
          name: multispeaker-{language}
          split: train
    source:
      type: speech
      language: [{language},eng]
      preprocessing:
        - set_sample_rate:
            rate: 16_000
        - augment_audio_noise:
            max_relative_amplitude: 0.5
    target:
      type: text
      language: [{language},eng]
      preprocessing:
        - lower_case
        - clean_and_remove_punctuation:
            allowed_punctuation: "'"
    shuffle: True
validation:
    huggingface_load:
        - path: Sunbird/salt
          name: multispeaker-{language}
          split: dev
        - path: Sunbird/salt
          name: multispeaker-eng
          split: dev
    source:
      type: speech
      language: [{language},eng]
      preprocessing:
        - set_sample_rate:
            rate: 16_000
    target:
      type: text
      language: [{language},eng]
      preprocessing:
        - lower_case
        - clean_and_remove_punctuation:
            allowed_punctuation: "'"
'''

# Parse the YAML and rebuild the datasets, replacing the ones created earlier.
config = yaml.safe_load(yaml_config)
train_ds = leb.dataset.create(config['train'])
valid_ds = leb.dataset.create(config['validation'])

In [None]:
# Display the checkpoint name parsed from the YAML configuration.
config['pretrained_model']

In [None]:
# Load the feature extractor and tokenizer from the configured checkpoint.
# The tokenizer is set up for English transcription.
feature_extractor = WhisperFeatureExtractor.from_pretrained(config['pretrained_model'])
tokenizer = WhisperTokenizer.from_pretrained(config['pretrained_model'], language="english", task="transcribe")

In [None]:
# Pull the first example out of the training dataset for inspection.
train_iterator = iter(train_ds)
example = next(train_iterator)

In [None]:
# Show which fields the sampled example carries.
example.keys()

In [None]:
# Target text of the sampled example, used for the tokenizer round-trip below.
input_str = example["target"]

In [None]:
# Round-trip check: encode the target text, decode it with and without special
# tokens, and report whether the plain decoding matches the original string.
labels = tokenizer(input_str).input_ids
decoded_with_special = tokenizer.decode(labels, skip_special_tokens=False)
decoded_str = tokenizer.decode(labels, skip_special_tokens=True)

print(f"Input:                 {input_str}")
print(f"Decoded w/ special:    {decoded_with_special}")
print(f"Decoded w/out special: {decoded_str}")
print(f"Are equal:             {input_str == decoded_str}")

In [None]:
# Processor bundling the feature extractor and tokenizer; the data collator and
# trainer below take their padding helpers from it.
# The language matches the standalone tokenizer above (language="english") and
# the "en" language forced at generation time, so both tokenizers agree.
processor = WhisperProcessor.from_pretrained(config['pretrained_model'], language="english", task="transcribe")


In [None]:
def prepare_dataset(example):
    """Turn one raw dataset example into the fields used for Whisper training.

    Reads the audio under ``"source"`` and the transcript under ``"target"``
    and relies on the module-level ``feature_extractor`` and ``tokenizer``.

    Args:
        example: mapping with ``"source"``, ``"target"``,
            ``"source.language"`` and ``"target.language"`` entries.

    Returns:
        dict with ``"input_features"`` (first entry of the feature extractor's
        ``input_features``), ``"labels"`` (token ids of the target text) and
        the two language fields copied through unchanged.
    """
    # Log-Mel features computed from the audio, declared as sampled at 16 kHz.
    mel_features = feature_extractor(
        example["source"], sampling_rate=16000
    ).input_features[0]

    # Token ids for the reference transcript.
    token_ids = tokenizer(example["target"]).input_ids

    return {
        "input_features": mel_features,
        "labels": token_ids,
        "source.language": example["source.language"],
        "target.language": example["target.language"],
    }

In [None]:
# Apply prepare_dataset to both splits, dropping the raw "source" and "target" columns.
train_data = train_ds.map(prepare_dataset, remove_columns=["source", "target"])
val_data = valid_ds.map(prepare_dataset, remove_columns=["source", "target"])

In [None]:
# Pull the first processed example; this rebinds `train_iterator` and `example`.
train_iterator = iter(train_data)
example = next(train_iterator)


In [None]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate processed examples into one padded batch.

    Audio features and label ids are padded separately, through the
    processor's feature extractor and tokenizer respectively.
    """

    # Object exposing `.feature_extractor.pad` and `.tokenizer.pad` / `.tokenizer.bos_token_id`.
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Build a batch from ``features``.

        Args:
            features: items each holding ``"input_features"`` and ``"labels"``.

        Returns:
            The padded audio batch with an added ``"labels"`` tensor in which
            padded positions are set to -100.
        """
        # Audio side: hand the features to the feature extractor for batching.
        audio_inputs = []
        for item in features:
            audio_inputs.append({"input_features": item["input_features"]})
        batch = self.processor.feature_extractor.pad(audio_inputs, return_tensors="pt")

        # Label side: pad the token id sequences with the tokenizer.
        tokenizer = self.processor.tokenizer
        label_inputs = []
        for item in features:
            label_inputs.append({"input_ids": item["labels"]})
        padded = tokenizer.pad(label_inputs, return_tensors="pt")

        # Positions where the attention mask is not 1 are padding; mark them
        # with -100 so the loss ignores them.
        is_padding = padded.attention_mask.ne(1)
        labels = padded["input_ids"].masked_fill(is_padding, -100)

        # When every sequence starts with the BOS token, drop that column
        # because the token gets added again later.
        every_row_starts_with_bos = (labels[:, 0] == tokenizer.bos_token_id).all().cpu().item()
        if every_row_starts_with_bos:
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch


In [None]:
# Collator instance handed to the trainer; it pads through the processor's components.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)


In [None]:

# Load the "wer" metric; compute_metrics below scales its result by 100.
metric = evaluate.load("wer")


In [None]:
def compute_metrics(pred):
    """Compute the word error rate, as a percentage, for one evaluation pass.

    Relies on the module-level ``tokenizer`` and ``metric``.

    Args:
        pred: object exposing ``predictions`` (predicted token ids) and
            ``label_ids`` (reference token ids, with -100 at ignored positions).

    Returns:
        dict with a single ``"wer"`` entry: 100 times the metric's result.
    """
    pred_ids = pred.predictions

    # Swap the -100 loss-masking value for the pad token id so the labels can
    # be decoded. np.where builds a new array, leaving pred.label_ids untouched
    # for any other consumer of the prediction object.
    label_ids = np.where(pred.label_ids == -100, tokenizer.pad_token_id, pred.label_ids)

    # Decode both sides to text, dropping special tokens (including padding).
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}


In [None]:

# Load the pretrained Whisper checkpoint named in the configuration.
model = WhisperForConditionalGeneration.from_pretrained(config['pretrained_model'])


In [None]:
# Clear the forced decoder ids and the suppressed-token list on the model config.
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
# Unused alternative kept for reference:
# forced_decoder_ids = processor.tokenizer.get_decoder_prompt_ids(language="Swahili", task="transcribe")


def custom_generate(self, *args, **kwargs):
    """Call WhisperForConditionalGeneration.generate with language set to "en".

    Any ``language`` keyword supplied by the caller is overwritten; all other
    positional and keyword arguments are forwarded unchanged.
    """
    kwargs["language"] = "en" # 'en', 'nl'

    return WhisperForConditionalGeneration.generate(self, *args, **kwargs)

# Bind the wrapper to this model instance so model.generate(...) goes through it.
model.generate = custom_generate.__get__(model, WhisperForConditionalGeneration)

In [None]:
from transformers import Seq2SeqTrainingArguments

# Training configuration: 1200 steps in total, with evaluation and checkpointing
# every 200 steps; the best checkpoint by lowest "wer" is loaded at the end.
# NOTE(review): the values here are independent of the `training_args` section
# in the YAML configuration above, which no visible code reads.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy`; confirm the installed version still accepts this keyword.
training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-base-sb-english",  # change to a repo name of your choice
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=1200,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    # Evaluate by generating text so compute_metrics receives token ids to decode.
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=200,
    eval_steps=200,
    logging_steps=25,
    report_to=["wandb"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    # Lower WER is better.
    greater_is_better=False,
    push_to_hub=True,
)


In [None]:
from transformers import Seq2SeqTrainer

# Assemble the trainer from the model, the processed datasets, the padding
# collator and the WER metric function.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    # The feature extractor (not the text tokenizer) is passed in the `tokenizer` slot.
    tokenizer=processor.feature_extractor,
)


In [None]:
# Run fine-tuning with the configuration above.
trainer.train()

In [None]:
# Upload the tokenizer files to the named Hub repository.
# NOTE(review): only the standalone tokenizer is pushed here, not `processor`;
# confirm the repository also receives the feature extractor configuration.
tokenizer.push_to_hub("akera/whisper-base-sb-english")