<a href="https://colab.research.google.com/github/OhadRubin/audio/blob/master/fine_tune_whisper.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-Tune Whisper For Werewolf

<img src="https://raw.githubusercontent.com/sanchit-gandhi/notebooks/main/whisper_architecture.svg" alt="Trulli" style="width:100%">

## Prepare Environment

In [29]:
!pip install --upgrade --quiet pip
!pip install --upgrade --quiet datasets[audio] transformers accelerate evaluate jiwer tensorboard gradio

In [None]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

## Load Dataset

In [30]:
SAMPLING_RATE = 16000
MAX_DURATION = 30

In [32]:
from datasets import load_dataset, Audio, DatasetDict

werewolf_data = DatasetDict()

werewolf_data["train"] = load_dataset("iohadrubin/werewolf_dialogue_data_10sec", split="train+validation")
werewolf_data["test"] = load_dataset("iohadrubin/werewolf_dialogue_data_10sec", split="test")

werewolf_data = werewolf_data.cast_column("audio", Audio(sampling_rate=SAMPLING_RATE))
werewolf_data = werewolf_data.filter(lambda x: x["end"] - x["start"] <= MAX_DURATION)

print(werewolf_data)

(…)-00000-of-00014-c6df01c2da117cd1.parquet:   0%|          | 0.00/441M [00:00<?, ?B/s]

(…)-00001-of-00014-ccd91d2ed48e4e87.parquet:   0%|          | 0.00/435M [00:00<?, ?B/s]

(…)-00002-of-00014-e0e6b0000eedceb4.parquet:   0%|          | 0.00/402M [00:00<?, ?B/s]

(…)-00003-of-00014-85d4a47ac309edc9.parquet:   0%|          | 0.00/427M [00:00<?, ?B/s]

(…)-00004-of-00014-ac10f74d0c33a6e0.parquet:   0%|          | 0.00/437M [00:00<?, ?B/s]

(…)-00005-of-00014-be20ab528f0f3002.parquet:   0%|          | 0.00/453M [00:00<?, ?B/s]

(…)-00006-of-00014-69034c722d260ec0.parquet:   0%|          | 0.00/443M [00:00<?, ?B/s]

(…)-00007-of-00014-0d411041c6fb4863.parquet:   0%|          | 0.00/445M [00:00<?, ?B/s]

(…)-00008-of-00014-7b142f54a9bb142f.parquet:   0%|          | 0.00/438M [00:00<?, ?B/s]

(…)-00009-of-00014-603088eeb6352a9b.parquet:   0%|          | 0.00/427M [00:00<?, ?B/s]

(…)-00010-of-00014-97c2e82d30fa973f.parquet:   0%|          | 0.00/447M [00:00<?, ?B/s]

(…)-00011-of-00014-4bf6d60b503a6606.parquet:   0%|          | 0.00/461M [00:00<?, ?B/s]

(…)-00012-of-00014-27f5cc18d436c30e.parquet:   0%|          | 0.00/419M [00:00<?, ?B/s]

(…)-00013-of-00014-f90ad44683420d6c.parquet:   0%|          | 0.00/437M [00:00<?, ?B/s]

(…)-00000-of-00004-6d15a1b1bed2ac20.parquet:   0%|          | 0.00/445M [00:00<?, ?B/s]

(…)-00001-of-00004-98ccb7b58f2a251f.parquet:   0%|          | 0.00/431M [00:00<?, ?B/s]

(…)-00002-of-00004-f5ac33ea5fd1c3e7.parquet:   0%|          | 0.00/434M [00:00<?, ?B/s]

(…)-00003-of-00004-c09e5a9e71ffcfa7.parquet:   0%|          | 0.00/419M [00:00<?, ?B/s]

(…)-00000-of-00003-1081b98bc3c75256.parquet:   0%|          | 0.00/397M [00:00<?, ?B/s]

(…)-00001-of-00003-d989cdf59ecdbeca.parquet:   0%|          | 0.00/326M [00:00<?, ?B/s]

(…)-00002-of-00003-1beb6058afe441df.parquet:   0%|          | 0.00/389M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/12746 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3602 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/2473 [00:00<?, ? examples/s]

Filter:   0%|          | 0/15219 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3602 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['audio', 'dialogue', 'start', 'end', 'idx', 'Game_ID', 'file_name', 'video_name', 'startRoles', 'startTime', 'endRoles', 'playerNames'],
        num_rows: 15098
    })
    test: Dataset({
        features: ['audio', 'dialogue', 'start', 'end', 'idx', 'Game_ID', 'file_name', 'video_name', 'startRoles', 'startTime', 'endRoles', 'playerNames'],
        num_rows: 3579
    })
})


In [33]:
print(werewolf_data)

DatasetDict({
    train: Dataset({
        features: ['audio', 'dialogue', 'start', 'end', 'idx', 'Game_ID', 'file_name', 'video_name', 'startRoles', 'startTime', 'endRoles', 'playerNames'],
        num_rows: 15098
    })
    test: Dataset({
        features: ['audio', 'dialogue', 'start', 'end', 'idx', 'Game_ID', 'file_name', 'video_name', 'startRoles', 'startTime', 'endRoles', 'playerNames'],
        num_rows: 3579
    })
})


## Prepare Feature Extractor, Tokenizer and Data

### Load WhisperFeatureExtractor

In [34]:
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")

preprocessor_config.json:   0%|          | 0.00/185k [00:00<?, ?B/s]

### Load WhisperTokenizer

In [35]:
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small")

### Combine To Create A WhisperProcessor

In [36]:
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-small")

# Prepare Data

In [94]:
tokenizer.bos_token

'<|endoftext|>'

In [143]:
import numpy as np

def _process_sample(sample, seq_length, tokenizer):
    tokens = tokenizer.encode(sample['prompt'] + sample['completion'])
    truncated = False
    if len(tokens) > seq_length:
        tokens = tokens[:seq_length]
        truncated = True
    tokens = [tokenizer.bos_token_id] + tokens + [tokenizer.eos_token_id]
    # tokens = tokens + [tokenizer.eos_token_id]
    prompt_len = len(tokenizer.encode(sample['prompt'])) + 1  # add bos token
    loss_masks = ([0.0] * prompt_len) + ([1.0] * (len(tokens) - prompt_len))
    # trunacte and pad everything out
    if len(tokens) > seq_length:
        tokens = tokens[:seq_length]
        loss_masks = loss_masks[:seq_length]
    # before padding, account for shifting
    input_tokens = tokens[:-1]
    loss_masks = loss_masks[1:]
    target_tokens = tokens[1:]
    attention_mask = [1] * len(input_tokens) + [0] * (seq_length - len(input_tokens))
    input_tokens = input_tokens + [tokenizer.pad_token_id] * (seq_length - len(input_tokens))
    target_tokens = target_tokens + [tokenizer.pad_token_id] * (seq_length - len(target_tokens))
    loss_masks = loss_masks + [0.0] * (seq_length - len(loss_masks))
    return {
            "input_tokens": np.array(input_tokens, dtype=np.int32),
            "target_tokens": np.array(target_tokens, dtype=np.int32),
            "loss_masks": np.array(loss_masks, dtype=np.float32),
            "attention_mask": np.array(attention_mask, dtype=np.int32),
            "truncated": truncated,
        }


In [None]:
dataset = werewolf_data["train"]

In [None]:
  # prev_before = '<|startofprev|>'
  # current_before = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
  # current_after = "<|endoftext|>"

  # prev_text = 'Hello and'
  # current_text = 'welcome. My name is...'

  # prompt_and_text = prev_before + prev_text + current_before + current_text + current_after
  # prompt_and_text

In [73]:
sample = dataset.select(range(10))

In [129]:
tokenizer.decode(tokenizer.encode(target)[-3:])

'No Strategy<|endoftext|>'

In [156]:

all_text_len, promot_len

(543, 519)

In [186]:


def get_prompt(batch):
  dialogue = "\n".join(f"{x['speaker']}: {x['utterance']}" for x in batch['dialogue'])
  prompt = f"""Given the following dialogue and audio, assign the last utterance one or more of the following tags (delimited by commas):

'Accusation', 'Defense', 'Evidence', 'Identity Declaration', 'Interrogation', 'No Strategy'

```
{dialogue}
```

Reminder: Assign one or more of the following tags (delimited by commas): 'Accusation', 'Defense', 'Evidence', 'Identity Declaration', 'Interrogation', 'No Strategy'.

Assignment:
"""
  return prompt

def tokenize_prompt(batch):
  text=get_prompt(batch)
  target = batch['dialogue'][-1]["target"]

  input_ids = np.array(tokenizer.encode(text + target))
  all_text = tokenizer.decode(input_ids)
  print(f"{all_text[-100:]=}")
  print("--"*10)
  all_text_len = len(input_ids)

  prompt_input_ids = tokenizer.encode(text)[:-1]
  prompt = tokenizer.decode(prompt_input_ids)
  print(f"{prompt[-100:]=}")
  print("--"*10)
  promot_len = len(prompt_input_ids)
  # tokenizer.encode(target)
  labels = np.array(input_ids)
  labels[:promot_len] = 0
  print(tokenizer.decode(labels[promot_len:]))
  print("--"*10)
  return input_ids, labels
  # tokenizer.decode(labels)

In [187]:
tokenizer.decode([198, 20892, 41134,    25,   883, 40915, 50257])

'\nAssignment: No Strategy<|endoftext|>'

In [195]:
i=55

In [202]:
_ = tokenize_prompt(dataset[i])
i=i+1

all_text[-100:]="ence', 'Identity Declaration', 'Interrogation', 'No Strategy'.\n\nAssignment:\nNo Strategy<|endoftext|>"
--------------------
prompt[-100:]="ation', 'Defense', 'Evidence', 'Identity Declaration', 'Interrogation', 'No Strategy'.\n\nAssignment:\n"
--------------------
No Strategy<|endoftext|>
--------------------


In [None]:
prompt_input_ids

In [86]:
text

"Given the following dialogue and audio, assign the last utterance one or more of the following tags (delimited by commas):\n\n'Accusation', 'Defense', 'Evidence', 'Identity Declaration', 'Interrogation', 'No Strategy'\n\n```\nMitchell: What is that?\nJames: Can't find my ginger ale with my eyes closed.\n```\n\nReminder: Assign one or more of the following tags (delimited by commas): 'Accusation', 'Defense', 'Evidence', 'Identity Declaration', 'Interrogation', 'No Strategy'.\n\nAssignment:"

In [145]:
tok_batch = tokenize_prompt(sample[0], seq_length=1000)

In [146]:

tokenizer.decode(tok_batch["target_tokens"][tok_batch["loss_masks"].astype(bool)])


' Strategy<|endoftext|><|endoftext|>'

In [139]:
# tokenizer.convert_ids_to_tokens(tokenizer.encode(text))
tokenizer.decode(tok_batch["input_tokens"])

"<|startoftranscript|><|notimestamps|>Given the following dialogue and audio, assign the last utterance one or more of the following tags (delimited by commas):\n\n'Accusation', 'Defense', 'Evidence', 'Identity Declaration', 'Interrogation', 'No Strategy'\n\n```\nMitchell: What is that?\nJames: Can't find my ginger ale with my eyes closed.\n```\n\nReminder: Assign one or more of the following tags (delimited by commas): 'Accusation', 'Defense', 'Evidence', 'Identity Declaration', 'Interrogation', 'No Strategy'.\n\nAssignment:No Strategy<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|e

In [None]:
tok_batch

In [125]:

messages = [

          {"role":"user","content":text},
          {"role":"assistant","content":target},
            ]
el = dict(messages=messages)
process_sample_chatml(el, tokenizer, "messages", 1000, chat_template)

[50363, 50258, 50363, 38, 5709, 264, 3480, 10221, 293, 6278, 11, 6269, 264, 1036, 17567, 719, 472, 420, 544, 295, 264, 3480, 18632, 522, 18105, 332, 1226, 538, 800, 296, 4507, 198, 198, 6, 43705, 1149, 399, 6098, 922, 35, 5666, 1288, 6098, 922, 36, 85, 2778, 6098, 922, 42739, 317, 507, 40844, 6098, 922, 13406, 6675, 399, 6098, 922, 4540, 40915, 6, 198, 198, 63, 63, 63, 198, 44, 1549, 898, 25, 708, 307, 300, 30, 198, 32263, 25, 1664, 380, 915, 452, 14966, 6775, 365, 452, 2575, 5395, 13, 198, 63, 63, 63, 198, 198, 39170, 5669, 25, 6281, 788, 472, 420, 544, 295, 264, 3480, 18632, 522, 18105, 332, 1226, 538, 800, 296, 4507, 922, 43705, 1149, 399, 6098, 922, 35, 5666, 1288, 6098, 922, 36, 85, 2778, 6098, 922, 42739, 317, 507, 40844, 6098, 922, 13406, 6675, 399, 6098, 922, 4540, 40915, 5004, 198, 198, 20892, 41134, 25, 4540, 40915, 50257, 50257]
[50363, 50258, 50363, 38, 5709, 264, 3480, 10221, 293, 6278, 11, 6269, 264, 1036, 17567, 719, 472, 420, 544, 295, 264, 3480, 18632, 522, 18105, 332,

UnboundLocalError: cannot access local variable 'input_tokens' where it is not associated with a value

In [82]:
_process_sample()

'No Strategy'

In [37]:


def prepare_dataset(batch):
  audio = batch["audio"]

  batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

  batch["prompt"] = get_prompt(batch)
  batch["target"] = batch['dialogue'][-1]["target"]
  batch["sentence"] = batch["prompt"] + batch["target"]

  # prompt_mask = [0 for _ in tokenizer(batch["prompt"]).input_ids]
  # batch["labels"] = prompt_mask + tokenizer(batch["target"]).input_ids

  batch["labels"] = tokenizer(batch["sentence"]).input_ids

  return batch

In [38]:
werewolf_data = werewolf_data.map(prepare_dataset, num_proc=4)

Map (num_proc=4):   0%|          | 0/15098 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1143 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1052 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1051 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1043 > 1024). Running this sequence through the model will result in indexing errors


TypeError: can only concatenate str (not "NoneType") to str

## Training and Evaluation

### Load a Pre-Trained Checkpoint

In [None]:
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

### Define a Data Collator

In [None]:
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

In [None]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

### Evaluation Metrics

In [None]:
import evaluate

metric = evaluate.load("wer")

In [None]:
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

### Define the Training Configuration

In [None]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-small-werewolf",
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=4000,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=True,
)

**Note**: if one does not want to upload the model checkpoints to the Hub,
set `push_to_hub=False`.

We can forward the training arguments to the 🤗 Trainer along with our model,
dataset, data collator and `compute_metrics` function:

In [None]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=common_voice["train"],
    eval_dataset=common_voice["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

We'll save the processor object once before starting training. Since the processor is not trainable, it won't change over the course of training:

In [None]:
processor.save_pretrained(training_args.output_dir)

### Training

Training will take approximately 5-10 hours depending on your GPU or the one
allocated to this Google Colab. If using this Google Colab directly to
fine-tune a Whisper model, you should make sure that training isn't
interrupted due to inactivity. A simple workaround to prevent this is
to paste the following code into the console of this tab (_right mouse click_
-> _inspect_ -> _Console tab_ -> _insert code_).

```javascript
function ConnectButton(){
    console.log("Connect pushed");
    document.querySelector("#top-toolbar > colab-connect-button").shadowRoot.querySelector("#connect").click()
}
setInterval(ConnectButton, 60000);
```

The peak GPU memory for the given training configuration is approximately 15.8GB.
Depending on the GPU allocated to the Google Colab, it is possible that you will encounter a CUDA `"out-of-memory"` error when you launch training.
In this case, you can reduce the `per_device_train_batch_size` incrementally by factors of 2
and employ [`gradient_accumulation_steps`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments.gradient_accumulation_steps)
to compensate.

To launch training, simply execute:

In [None]:
trainer.train()

Our best WER is 32.0% - not bad for 8h of training data! We can make our model more accessible on the Hub with appropriate tags and README information.
You can change these values to match your dataset, language and model
name accordingly:

In [None]:
kwargs = {
    "dataset_tags": "mozilla-foundation/common_voice_11_0",
    "dataset": "Common Voice 11.0",  # a 'pretty' name for the training dataset
    "dataset_args": "config: hi, split: test",
    "language": "hi",
    "model_name": "Whisper Small Hi - Sanchit Gandhi",  # a 'pretty' name for our model
    "finetuned_from": "openai/whisper-small",
    "tasks": "automatic-speech-recognition",
}

The training results can now be uploaded to the Hub. To do so, execute the `push_to_hub` command and save the preprocessor object we created:

In [None]:
trainer.push_to_hub(**kwargs)

## Building a Demo

Now that we've fine-tuned our model we can build a demo to show
off its ASR capabilities! We'll make use of 🤗 Transformers
`pipeline`, which will take care of the entire ASR pipeline,
right from pre-processing the audio inputs to decoding the
model predictions.

Running the example below will generate a Gradio demo where we
can record speech through the microphone of our computer and input it to
our fine-tuned Whisper model to transcribe the corresponding text:

In [None]:
from transformers import pipeline
import gradio as gr

pipe = pipeline(model="sanchit-gandhi/whisper-small-hi")  # change to "your-username/the-name-you-picked"

def transcribe(audio):
    text = pipe(audio)["text"]
    return text

iface = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(source="microphone", type="filepath"),
    outputs="text",
    title="Whisper Small Hindi",
    description="Realtime demo for Hindi speech recognition using a fine-tuned Whisper small model.",
)

iface.launch()

## Closing Remarks

In this blog, we covered a step-by-step guide on fine-tuning Whisper for multilingual ASR
using 🤗 Datasets, Transformers and the Hugging Face Hub. For more details on the Whisper model, the Common Voice dataset and the theory behind fine-tuning, refere to the accompanying [blog post](https://huggingface.co/blog/fine-tune-whisper). If you're interested in fine-tuning other
Transformers models, both for English and multilingual ASR, be sure to check out the
examples scripts at [examples/pytorch/speech-recognition](https://github.com/huggingface/transformers/tree/main/examples/pytorch/speech-recognition).