<a href="https://colab.research.google.com/github/alonbebchuk/audio/blob/master/fine_tune_whisper.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-Tune Whisper For Werewolf

<img src="https://raw.githubusercontent.com/sanchit-gandhi/notebooks/main/whisper_architecture.svg" alt="Trulli" style="width:100%">

## Prepare Environment

In [None]:
!pip install --quiet datasets evaluate jiwer

In [None]:
from huggingface_hub import notebook_login

notebook_login()

## Consts

In [32]:
BATCH_SIZE = 50
BOS_LEN = 2
EOS_LEN = 1
MAX_DURATION = 30
MASK_ID = -100
MAX_LENGTH = 446
SAMPLING_RATE = 16000
WORD_ERROR_PENALTY = 100

## Load Model

In [33]:
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor, WhisperForConditionalGeneration

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small")
processor = WhisperProcessor.from_pretrained("openai/whisper-small")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

## Load Dataset

In [34]:
from datasets import load_dataset, Audio, DatasetDict, Dataset

werewolf_data = DatasetDict()

werewolf_data["train"] = load_dataset("parquet", data_files=["https://huggingface.co/datasets/iohadrubin/werewolf_dialogue_data_10sec/resolve/main/data/train-00002-of-00014-e0e6b0000eedceb4.parquet"])["train"]
werewolf_data["test"] = load_dataset("parquet", data_files=["https://huggingface.co/datasets/iohadrubin/werewolf_dialogue_data_10sec/resolve/main/data/train-00009-of-00014-603088eeb6352a9b.parquet"])["train"]

werewolf_data = werewolf_data.filter(lambda x: x["end"] - x["start"] <= MAX_DURATION and x["dialogue"][-1]["target"] is not None)

## Prepare Dataset

In [35]:
prompt_prefix, prompt_suffix = """Given the following dialogue and audio, assign the last utterance one or more of the following tags (delimited by commas):
'Accusation', 'Defense', 'Evidence', 'Identity Declaration', 'Interrogation', 'No Strategy'

```
""", """
```

Reminder - Assign one or more of the following tags to the last utterance (delimited by commas):
'Accusation', 'Defense', 'Evidence', 'Identity Declaration', 'Interrogation', 'No Strategy'

Assignment:
"""

prompt_prefix_len = len(tokenizer.encode(prompt_prefix)) - EOS_LEN
max_prompt_len = MAX_LENGTH - prompt_prefix_len

In [36]:
import numpy as np

def prepare_decoder_input_ids_and_labels(sample):
  dialogue = "\n".join(f"{x['speaker']}: {x['utterance']}" for x in sample["dialogue"])
  target = sample["dialogue"][-1]["target"]
  text = prompt_prefix + dialogue + prompt_suffix + target

  input_ids = tokenizer.encode(text)
  if len(input_ids) > MAX_LENGTH:
    input_ids = input_ids[:prompt_prefix_len] + input_ids[-max_prompt_len:]

  decoder_input_ids = np.array(input_ids)
  sample["decoder_input_ids"] = decoder_input_ids

  labels = np.array(input_ids)
  target_len = len(tokenizer.encode(target)) - BOS_LEN
  labels[:-target_len] = MASK_ID
  sample["labels"] = labels

  return sample

def prepare_audio(batch):
  audio_arrays = [x["array"] for x in batch["audio"]]

  batch["input_features"] = feature_extractor(audio_arrays, sampling_rate=SAMPLING_RATE, return_tensors="pt").input_features

  return batch

In [37]:
werewolf_data = werewolf_data.map(prepare_decoder_input_ids_and_labels)
werewolf_data = werewolf_data.map(prepare_audio, batched=True, batch_size=BATCH_SIZE)

Map:   0%|          | 0/901 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1053 > 1024). Running this sequence through the model will result in indexing errors


Map:   0%|          | 0/896 [00:00<?, ? examples/s]

Map:   0%|          | 0/901 [00:00<?, ? examples/s]

Map:   0%|          | 0/896 [00:00<?, ? examples/s]

## Data Collator

In [38]:
from dataclasses import dataclass
from transformers import WhisperProcessor

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: WhisperProcessor
    decoder_start_token_id: int

    def __call__(self, features):
      batch = self.processor.feature_extractor.pad([{"input_features": feature["input_features"]} for feature in features], return_tensors="pt")

      decoder_input_ids_batch = self.processor.tokenizer.pad([{"input_ids": feature["decoder_input_ids"]} for feature in features], return_tensors="pt")
      decoder_input_ids = decoder_input_ids_batch["input_ids"].masked_fill(decoder_input_ids_batch.attention_mask.ne(1), MASK_ID)
      batch["decoder_input_ids"] = decoder_input_ids

      labels_batch = self.processor.tokenizer.pad([{"input_ids": feature["labels"]} for feature in features], return_tensors="pt")
      labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), MASK_ID)
      if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
          labels = labels[:, 1:]
      batch["labels"] = labels

      return batch

In [39]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

## Compute Metrics

In [40]:
import evaluate

metric = evaluate.load("wer")

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    label_ids[label_ids == MASK_ID] = tokenizer.pad_token_id

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = WORD_ERROR_PENALTY * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

## Training Arguments

In [41]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-small-werewolf",
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=4000,
    gradient_checkpointing=True,
    fp16=True,
    eval_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=True,
)

## Trainer

In [42]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=werewolf_data["train"],
    eval_dataset=werewolf_data["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

### Training

In [43]:
trainer.train()

IndexError: index out of range in self

## New Section

Our best WER is 32.0% - not bad for 8h of training data! We can make our model more accessible on the Hub with appropriate tags and README information.
You can change these values to match your dataset, language and model
name accordingly:

In [None]:
kwargs = {
    "dataset_tags": "mozilla-foundation/common_voice_11_0",
    "dataset": "Common Voice 11.0",  # a 'pretty' name for the training dataset
    "dataset_args": "config: hi, split: test",
    "language": "hi",
    "model_name": "Whisper Small Hi - Sanchit Gandhi",  # a 'pretty' name for our model
    "finetuned_from": "openai/whisper-small",
    "tasks": "automatic-speech-recognition",
}

The training results can now be uploaded to the Hub. To do so, execute the `push_to_hub` command and save the preprocessor object we created:

In [None]:
trainer.push_to_hub(**kwargs)

## Building a Demo

Now that we've fine-tuned our model we can build a demo to show
off its ASR capabilities! We'll make use of 🤗 Transformers
`pipeline`, which will take care of the entire ASR pipeline,
right from pre-processing the audio inputs to decoding the
model predictions.

Running the example below will generate a Gradio demo where we
can record speech through the microphone of our computer and input it to
our fine-tuned Whisper model to transcribe the corresponding text:

In [None]:
from transformers import pipeline
import gradio as gr

pipe = pipeline(model="sanchit-gandhi/whisper-small-hi")  # change to "your-username/the-name-you-picked"

def transcribe(audio):
    text = pipe(audio)["text"]
    return text

iface = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(source="microphone", type="filepath"),
    outputs="text",
    title="Whisper Small Hindi",
    description="Realtime demo for Hindi speech recognition using a fine-tuned Whisper small model.",
)

iface.launch()