# Prepare environment
Install dependencies, log in to the Hugging Face Hub, and import libraries.

In [5]:
%pip install -q transformers librosa datasets==2.14.6 evaluate jiwer gradio bitsandbytes==0.37 accelerate geomloss gradio torchaudio
%pip install -q git+https://github.com/huggingface/peft.git@main

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [None]:
from huggingface_hub import notebook_login

# Authenticate with the Hugging Face Hub. Later cells rely on the stored credential:
# load_dataset(..., token=True) and model.push_to_hub(...).
notebook_login() #huggingface-cli login workaround

In [2]:
import os
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Union

import torch
from datasets import load_dataset, DatasetDict
from geomloss import SamplesLoss
from peft import (prepare_model_for_int8_training,
                  LoraConfig, 
                  PeftModel, 
                  LoraModel, 
                  LoraConfig, 
                  TaskType,
                  get_peft_model)
from transformers import (WhisperFeatureExtractor, 
                          WhisperTokenizer, 
                          WhisperProcessor,
                          WhisperModel,
                          WhisperForConditionalGeneration, 
                          Seq2SeqTrainingArguments, 
                          Seq2SeqTrainer, 
                          TrainerCallback, 
                          TrainingArguments, 
                          TrainerState, 
                          TrainerControl)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR


# Load dataset

In [3]:
# Fetch both SD-QA splits from the Hub; token=True uses the stored login credential.
sd_qa = DatasetDict({
    split_name: load_dataset("WillHeld/SD-QA", split=split_name, token=True)
    for split_name in ("dev", "test")
})

  table = cls._concat_blocks(blocks, axis=0)


In [4]:
# Keep a handle on the full dataset (every dialect column, raw audio).
# This is an alias, not a copy. It stays intact only because later cells rebind
# sd_qa to the new objects returned by select_columns/map instead of mutating it.
saved_data = sd_qa
print(sd_qa)
print(sd_qa['dev'])

DatasetDict({
    dev: Dataset({
        features: ['id', 'aus', 'gbr', 'ind_n', 'ind_s', 'irl', 'kenya', 'nga', 'nzl', 'phl', 'usa', 'zaf', 'answers', 'question'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['id', 'aus', 'gbr', 'ind_n', 'ind_s', 'irl', 'kenya', 'nga', 'nzl', 'phl', 'usa', 'zaf', 'answers', 'question'],
        num_rows: 1031
    })
})
Dataset({
    features: ['id', 'aus', 'gbr', 'ind_n', 'ind_s', 'irl', 'kenya', 'nga', 'nzl', 'phl', 'usa', 'zaf', 'answers', 'question'],
    num_rows: 1000
})


# Prepare data

In [5]:
# select only target and source dialect
# Each remaining row pairs the source-dialect recording with the target-dialect
# recording of the same question (plus its id and question text).
target_dialect = 'usa'
source_dialect = 'ind_n'
sd_qa = sd_qa.select_columns(['id', source_dialect, target_dialect, 'question'])
print(sd_qa['dev'][0])

{'id': '-1008642825401516622', 'ind_n': {'path': None, 'array': array([ 0.00000000e+00, -3.05175781e-05, -3.05175781e-05, ...,
        3.96728516e-04,  2.13623047e-04,  6.10351562e-05]), 'sampling_rate': 16000}, 'usa': {'path': None, 'array': array([0.        , 0.        , 0.        , ..., 0.00201416, 0.00259399,
       0.00262451]), 'sampling_rate': 16000}, 'question': None}


In [6]:
# load whisper feature extractor, tokenizer, processor
# All three come from the same checkpoint. The standalone feature_extractor is used by
# prepare_dataset; the processor carries its own feature extractor, which the data
# collator uses for padding.
model_path = "openai/whisper-base"
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_path)
task = "transcribe"
tokenizer = WhisperTokenizer.from_pretrained(model_path, task=task)
processor = WhisperProcessor.from_pretrained(model_path, task=task)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [7]:
# define prepare dataset function to extract audio features
def prepare_dataset(batch):
    """Attach log-Mel input features for the source and target dialect recordings.

    Args:
        batch: a single dataset row whose `source_dialect` and `target_dialect`
            entries are audio dicts holding "array" and "sampling_rate".

    Returns:
        The same row, extended with "source_input_features" and
        "target_input_features" (the first item of the extractor's input_features).
    """
    for feature_key, dialect in (("source_input_features", source_dialect),
                                 ("target_input_features", target_dialect)):
        audio = batch[dialect]
        extracted = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"])
        batch[feature_key] = extracted.input_features[0]
    return batch

In [8]:
# map feature extractor to sd_qa
# Adds source_input_features / target_input_features and removes the raw audio
# columns, so after this cell raw audio is only reachable through saved_data.
# sd_qa = sd_qa.map(prepare_dataset, remove_columns=sd_qa.column_names["dev"], num_proc=2)
sd_qa = sd_qa.map(prepare_dataset, remove_columns=[source_dialect, target_dialect], num_proc=2)

Map (num_proc=2):   0%|          | 0/1000 [00:00<?, ? examples/s]

  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  table = cls._concat_blocks(blocks, axis=0)


Map (num_proc=2):   0%|          | 0/1031 [00:00<?, ? examples/s]

  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)


In [None]:
# Confirm the dev split now holds the extracted feature columns instead of raw audio.
print(sd_qa['dev'])

# Training setup (data collator and metrics)

In [None]:
# Define a data collator
@dataclass
class DataCollatorSpeechSeq2Seq:
    """Batch source/target log-Mel features into padded PyTorch tensors.

    Attributes:
        processor: object exposing ``feature_extractor.pad`` (a WhisperProcessor here).
    """

    processor: Any

    def _pad(self, features: List[Dict[str, Any]], key: str) -> torch.Tensor:
        """Pad the ``key`` entry of every feature dict and return it as one tensor."""
        # The feature extractor's pad() looks up its own model input name
        # ("input_features") and raises ValueError when it is missing, so the
        # custom source/target keys must be renamed before padding.
        renamed = [{"input_features": feature[key]} for feature in features]
        return self.processor.feature_extractor.pad(renamed, return_tensors="pt")["input_features"]

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Collate dataset rows into a batch.

        Args:
            features: rows each holding "source_input_features" and
                "target_input_features".

        Returns:
            dict with the padded "source_input_features" and
            "target_input_features" tensors.
        """
        return {
            "source_input_features": self._pad(features, "source_input_features"),
            "target_input_features": self._pad(features, "target_input_features"),
        }

# Initialize a data collator
# Padding goes through the feature extractor of the WhisperProcessor loaded above.
data_collator = DataCollatorSpeechSeq2Seq(processor=processor)

In [10]:
# Display the dev split's schema after feature extraction.
sd_qa['dev']

Dataset({
    features: ['id', 'question', 'source_input_features', 'target_input_features'],
    num_rows: 1000
})

In [None]:
# Define evaluation metrics
# Sinkhorn divergence between two point clouds; p=2 selects geomloss's
# squared-Euclidean ground cost.
sinkhorn_loss = SamplesLoss(loss="sinkhorn", p=2)

# Define a function to compute metrics
def compute_metrics(pred):
    """Score an evaluation pass with the Sinkhorn divergence.

    Args:
        pred: a transformers ``EvalPrediction``. Its ``predictions`` are compared
            against its ``label_ids``.

    Returns:
        dict with the mean Sinkhorn divergence as a plain float under "loss".
    """
    predictions = pred.predictions
    # Trainer hands over a tuple when the model returns several outputs; use the first.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    # EvalPrediction has no ``target`` attribute; the references live in ``label_ids``.
    # geomloss operates on torch tensors, while Trainer typically passes numpy arrays.
    predictions = torch.as_tensor(predictions, dtype=torch.float32)
    targets = torch.as_tensor(pred.label_ids, dtype=torch.float32)
    loss = sinkhorn_loss(predictions, targets)
    # Batched inputs give one value per batch element; reduce to a serialisable float.
    return {"loss": loss.mean().item()}

In [None]:
# Raw audio arrays of the first dev question in two dialects, read from saved_data
# (sd_qa no longer has the raw audio columns).
# NOTE(review): this uses 'aus' rather than source_dialect ('ind_n'), and the hidden-state
# example below reassigns `example` before any use — confirm this cell is still needed.
example = [saved_data['dev'][0]['aus']['array'], saved_data['dev'][0]['usa']['array']]


## Exploration and sanity checks

In [None]:
# Inspect one raw source-dialect recording. The feature-extraction map dropped the raw
# dialect columns from sd_qa, so read it from saved_data (the DatasetDict as loaded).
saved_data['dev'][0][source_dialect]

In [None]:
# example of generating last encoder hidden state
# Runs full-precision whisper-base on two recordings of the same dev question and
# prints the encoder's final hidden state for both.
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")

example = [saved_data['dev'][0]['aus']['array'], saved_data['dev'][0]['usa']['array']]
example = feature_extractor(example, return_tensors="pt", sampling_rate=16000)
# One decoder prompt per audio clip: the decoder batch has to match the encoder batch
# (a single (1, 2) prompt for two clips is a batch-size mismatch).
batch_size = example.input_features.shape[0]
decoder_input_ids = torch.full((batch_size, 2), model.config.decoder_start_token_id, dtype=torch.long)

with torch.no_grad():
  outputs = model(example.input_features, decoder_input_ids=decoder_input_ids, output_hidden_states=True)

# Last entry of encoder_hidden_states = output of the final encoder layer.
last_hidden_state = outputs.encoder_hidden_states[-1]
print(last_hidden_state)

# loss = sinkhorn_loss(last_hidden_state[0], last_hidden_state[1])
# print(loss)

# with torch.no_grad():
#   outputs = model.generate(example.input_features, output_hidden_states=True)

# decoded=processor.tokenizer.batch_decode(outputs, skip_special_tokens=True)
# print(decoded)


In [None]:
# Full output object of the forward pass above (includes encoder_hidden_states,
# since output_hidden_states=True was requested).
print(outputs)

## Training

In [None]:
import os

# Restrict this kernel to the first GPU. This only takes effect if CUDA has not been
# initialised yet in the current kernel.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# The 8-bit model is loaded once, in the next cell. Loading it here as well kept a
# first 8-bit copy alive while the second was being loaded.


In [None]:
# Load pre-trained checkpoint in 8b
# load_in_8bit quantises the linear weights (via bitsandbytes); device_map="auto"
# lets accelerate place the weights on the available device(s).
model = WhisperForConditionalGeneration.from_pretrained(model_path,load_in_8bit=True, device_map="auto")

# Post-processing steps on the model
# Readies the quantised model for adapter training, naming "proj_out" as the output
# embedding layer.
# NOTE(review): newer peft releases replace this helper with
# prepare_model_for_kbit_training — confirm against the installed peft version.
model = prepare_model_for_int8_training(model, output_embedding_layer_name="proj_out")

# # Make inputs trainable
# def make_inputs_require_grad(module, input, output):
#     output.requires_grad_(True)

# model.model.encoder.conv1.register_forward_hook(make_inputs_require_grad)

In [None]:
# Print the module tree, e.g. to choose LoRA target module names.
print(model)
def get_target_modules(model):
    """Return the distinct attribute names of every Linear layer in ``model``.

    A name is the last component of the module's dotted path (e.g. ``"q_proj"``,
    ``"fc1"``), which is the form LoRA's ``target_modules`` list takes. Any subclass
    of ``torch.nn.Linear`` counts (presumably including bitsandbytes 8-bit layers,
    which subclass it — confirm for the installed version).

    Args:
        model: any ``torch.nn.Module``.

    Returns:
        Sorted list of unique layer names; empty if the model has no nested Linear.
    """
    return sorted({
        qualified_name.rsplit(".", 1)[-1]
        for qualified_name, module in model.named_modules()
        # qualified_name is "" for the root module itself; skip it.
        if qualified_name and isinstance(module, torch.nn.Linear)
    })

In [None]:
# Apply LoRA to model, targeting the attention projections and feed-forward layers.
# NOTE(review): presumably these module names also occur in the decoder layers, so
# LoRA would be attached there too, not only to the encoder — confirm intent.
target_modules = ['k_proj', 'v_proj', 'q_proj', 'out_proj', 'fc1', 'fc2']
config = LoraConfig(r=32, # rank, adjust this
                    lora_alpha=64, 
                    target_modules = target_modules, 
                    lora_dropout=0.05, 
                    bias="none",
                    task_type=TaskType.FEATURE_EXTRACTION,
                    ) # NOTE(review): task_type still undecided; FEATURE_EXTRACTION is a placeholder

# Wrap the model with the LoRA adapters. Re-running this cell would presumably wrap
# the already-wrapped model again.
model = get_peft_model(model, config)
model.print_trainable_parameters()

# Define training configuration
training_args = Seq2SeqTrainingArguments(
    output_dir="reach-vb/test",  # placeholder path — change to a repo name of your choice
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-3,
    warmup_steps=50,
    num_train_epochs=3,
    evaluation_strategy="steps",
    fp16=True,
    per_device_eval_batch_size=8,
    generation_max_length=128,
    logging_steps=100,
#    max_steps=100, # only for testing purposes, remove this from your final run :)
    remove_unused_columns=False,  # required as the PeftModel forward doesn't have the signature of the wrapped model's forward; presumably also keeps the custom source/target feature columns
    label_names=["labels"],  # same reason as above. NOTE(review): DataCollatorSpeechSeq2Seq emits no "labels" key
)

In [None]:
# Additional PEFT things
# This callback helps to save only the adapter weights and remove the base model weights.
class SavePeftModelCallback(TrainerCallback):
    """Trainer callback that keeps only the LoRA adapter in each checkpoint.

    On every save it writes the adapter via ``save_pretrained`` to
    ``<checkpoint>/adapter_model`` and deletes full-model weight files the Trainer
    may have written into the same checkpoint folder.
    """

    # Full-model weight files to delete. Recent transformers releases save
    # safetensors by default (TrainingArguments.save_safetensors), so checking only
    # for the .bin file would leave the full weights in place.
    _BASE_WEIGHT_FILES = ("pytorch_model.bin", "model.safetensors")

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Save the adapter for this step's checkpoint and prune full-model weights.

        Returns:
            ``control``, unchanged.
        """
        checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")

        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)

        for weight_file in self._BASE_WEIGHT_FILES:
            weight_path = os.path.join(checkpoint_folder, weight_file)
            if os.path.exists(weight_path):
                os.remove(weight_path)
        return control


# Train on the dev split and evaluate on the test split.
# NOTE(review): compute_metrics (defined earlier) is not passed here, and the collator
# provides no labels, so the model's default loss presumably cannot be computed —
# confirm a custom loss (e.g. a compute_loss override) is planned.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=sd_qa["dev"],
    eval_dataset=sd_qa["test"],
    data_collator=data_collator,
    tokenizer=processor.feature_extractor,
    callbacks=[SavePeftModelCallback],  # the class itself; presumably Trainer instantiates it
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!

In [None]:
# Run train
# Checkpoints are written on each save; SavePeftModelCallback stores the adapter in them.
trainer.train()

In [None]:
# Save to hub
# Uploads the PEFT-wrapped model (presumably just the adapter weights) under this repo id.
# NOTE(review): the name says "100steps", but max_steps is commented out in
# training_args, so training runs num_train_epochs=3 — rename or restore max_steps.
peft_model_id = "azure-224n/whisper-base-100steps"
model.push_to_hub(peft_model_id)

# Evaluation

# Inference

## Demo

In [2]:
model_path = "openai/whisper-base"

In [1]:
import gradio as gr
from transformers import pipeline
import numpy as np

# ASR pipeline for the live demo. It uses the English-only "whisper-base.en" checkpoint,
# which is neither model_path nor the LoRA model trained above.
transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base.en")

def transcribe(stream, new_chunk):
    """Append a streamed microphone chunk to the running buffer and transcribe it.

    Args:
        stream: audio accumulated so far (1-D float32 array), or None on the first call.
        new_chunk: ``(sampling_rate, samples)`` tuple from ``gr.Audio(streaming=True)``.

    Returns:
        The updated buffer (fed back as Gradio state) and the transcription of the
        whole buffer.
    """
    sr, y = new_chunk

    # Down-mix multi-channel audio to mono (assumes a (samples, channels) layout);
    # the buffer and the ASR pipeline expect a 1-D signal.
    if y.ndim > 1:
        y = y.mean(axis=1)

    y = y.astype(np.float32)
    # Peak-normalise the chunk. An all-zero (silent) chunk would otherwise divide by
    # zero and fill the buffer with NaN; an empty chunk has no maximum at all.
    peak = float(np.max(np.abs(y))) if y.size else 0.0
    if peak > 0:
        y /= peak

    stream = y if stream is None else np.concatenate([stream, y])
    return stream, transcriber({"sampling_rate": sr, "raw": stream})["text"]


# Live microphone demo. The "state" input/output pair carries the accumulated audio
# buffer between calls; each streamed chunk triggers transcribe().
demo = gr.Interface(
    transcribe,
    ["state", gr.Audio(sources=["microphone"], streaming=True)],
    ["state", "text"],
    live=True,
)

demo.launch()

Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


