In [24]:
# Import libraries
import logging
import os
import sys
import json
import random
import numpy as np
from dataclasses import dataclass, field
from typing import Optional
from huggingface_hub import notebook_login

from datasets import load_dataset, DatasetDict
from transformers import (WhisperFeatureExtractor, 
                          WhisperTokenizer, 
                          WhisperProcessor,
                          WhisperModel,
                          WhisperForConditionalGeneration, 
                          Seq2SeqTrainingArguments, 
                          Seq2SeqTrainer, 
                          TrainerCallback, 
                          TrainingArguments, 
                          TrainerState, 
                          TrainerControl,
                          pipeline)
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from peft import (prepare_model_for_int8_training,  # NOTE(review): deprecated in favour of prepare_model_for_kbit_training and absent from newer peft releases; unused in the cells shown
                  LoraConfig,
                  PeftModel,
                  LoraModel,
                  LoraConfig,
                  TaskType,
                  get_peft_model,
                  PeftConfig)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils import check_min_version
from tqdm import tqdm
import re

from trainer_utils import AlignmentSeq2SeqTrainer
from data_utils import (DataCollatorSpeechSeq2SeqWithPadding, 
                        load_sd_qa_dataset, 
                        filter_data)
from eval_utils import (evaluate_asr,
                        get_mini_cv)
import csv
import pickle
import evaluate

In [5]:
# Setup 
!pip install -q transformers librosa datasets==2.14.6 evaluate jiwer gradio bitsandbytes==0.37 accelerate geomloss gradio torchaudio
!pip install -q git+https://github.com/huggingface/peft.git@main

In [6]:
class SavePeftCallback(TrainerCallback):
    """Trainer callback that stores only the PEFT adapter per checkpoint and logs the loss.

    On every checkpoint save it:
      1. writes the adapter with ``model.save_pretrained`` into
         ``<output_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>/adapter_model``;
      2. deletes ``pytorch_model.bin`` from that checkpoint folder if it exists;
      3. appends ``[global_step, loss]`` to ``<output_dir>/loss.csv``, using the most
         recent training loss found in ``state.log_history`` (no row is written if
         no loss has been logged yet).
    """

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")

        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)

        pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        if os.path.exists(pytorch_model_path):
            os.remove(pytorch_model_path)

        # state.log_history is a list of per-log dicts, not a dict keyed by metric name,
        # and only training-log entries carry "loss" -> take the latest such entry.
        latest_loss = next(
            (entry["loss"] for entry in reversed(state.log_history) if "loss" in entry),
            None,
        )
        if latest_loss is not None:
            loss_file = os.path.join(args.output_dir, 'loss.csv')
            # newline="" is what the csv module expects for files it writes to.
            with open(loss_file, 'a', newline="") as f:
                writer = csv.writer(f)
                writer.writerow([state.global_step, latest_loss])  # iter, loss

        return control

In [8]:
# Authenticate with the Hugging Face Hub only if you intend to push models; uncomment to log in.
# notebook_login()

# Use the current CUDA device index when a GPU is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.cuda.current_device()
else:
    device = 'cpu'

# Whisper feature extractor, tokenizer and processor, all from the same checkpoint.
model_path = "openai/whisper-base"
task = "transcribe"
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_path)
tokenizer = WhisperTokenizer.from_pretrained(model_path, task=task)
processor = WhisperProcessor.from_pretrained(model_path, task=task)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [9]:
# Load the pre-trained Whisper checkpoint (model_path), relax its decoding config, and move it to `device`.
model = WhisperForConditionalGeneration.from_pretrained(model_path)
# Clear decoder constraints: force no prompt tokens and suppress no tokens.
model.config.update({"forced_decoder_ids": None, "suppress_tokens": []})
model = model.to(device)

config.json:   0%|          | 0.00/1.98k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/290M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/3.81k [00:00<?, ?B/s]

In [10]:
# Dialect pair to keep from SD-QA. These names are also the audio keys of each example
# (see 'ind_n' / 'usa' in the printed sample).
source_dialect = 'ind_n'
target_dialect = 'usa'
sd_qa = filter_data(
    load_sd_qa_dataset(),
    source=source_dialect,
    target=target_dialect,
)

print(sd_qa['dev'][0])

Downloading readme:   0%|          | 0.00/810 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/376M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/366M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/382M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/373M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/382M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/386M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/377M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/386M [00:00<?, ?B/s]

Generating dev split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1031 [00:00<?, ? examples/s]

{'id': '-1008642825401516622', 'ind_n': {'path': None, 'array': array([ 0.00000000e+00, -3.05175781e-05, -3.05175781e-05, ...,
        3.96728516e-04,  2.13623047e-04,  6.10351562e-05]), 'sampling_rate': 16000}, 'usa': {'path': None, 'array': array([0.        , 0.        , 0.        , ..., 0.00201416, 0.00259399,
       0.00262451]), 'sampling_rate': 16000}}


In [None]:
# prepare data: log-Mel features for both dialect recordings
def prepare_source_data(data):
    """Add Whisper log-Mel input features for both dialect recordings of one example.

    Reads ``data[source_dialect]`` and ``data[target_dialect]`` (each a dict with
    ``"array"`` and ``"sampling_rate"``), runs the global ``feature_extractor`` on each,
    and stores the first (only) feature matrix under ``"source_input_features"`` and
    ``"target_input_features"`` respectively. Mutates and returns ``data``.
    """
    feature_targets = (
        ("source_input_features", source_dialect),
        ("target_input_features", target_dialect),
    )
    for feature_key, dialect in feature_targets:
        audio = data[dialect]
        extracted = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"])
        data[feature_key] = extracted.input_features[0]
    return data

In [None]:
# Encoder embeddings of the target-dialect audio (computed once, before LoRA is attached).
def prepare_target_embeddings(data):
    """Add the Whisper encoder's final hidden states for the target-dialect audio.

    Meant for ``Dataset.map`` on single examples. Reads
    ``data["target_input_features"]`` (written by ``prepare_source_data``), adds a
    batch dimension of 1, runs it through the encoder of the global ``model`` on
    ``device`` without gradients, and stores ``data["target_embeddings"]`` as a list
    with one entry per batch row (one row here). Mutates and returns ``data``.
    """
    input_features = torch.tensor(data["target_input_features"]).unsqueeze(0).to(device)
    with torch.no_grad():
        # Only the encoder output is needed. For Whisper, encoder_hidden_states[-1] of a
        # full encoder-decoder pass equals the encoder's last_hidden_state, so skipping
        # the decoder (and its dummy decoder_input_ids) gives the same tensor for less work.
        encoder_outputs = model.get_encoder()(input_features)
    # Move to CPU before handing the tensors back to `datasets` for serialisation.
    data["target_embeddings"] = list(encoder_outputs.last_hidden_state.cpu())
    return data

In [None]:
# Two passes over every split: first log-Mel features for both dialects, then target-dialect
# encoder embeddings (the second pass reads the features produced by the first).
sd_qa = sd_qa.map(prepare_source_data, desc="Extract features for source dialect")
sd_qa = sd_qa.map(prepare_target_embeddings, desc="Original hidden embeddings for target dialect")

In [None]:
# Batch collator from data_utils, built around the Whisper processor
# (presumably pads features/labels per batch; see DataCollatorSpeechSeq2SeqWithPadding).
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
)

In [None]:
#----------LORA PART------------
# Wrap the Whisper model with LoRA adapters on the modules named in target_modules
# (for Whisper, presumably the attention projections and feed-forward layers).
target_modules = ['k_proj', 'v_proj', 'q_proj', 'out_proj', 'fc1', 'fc2']
config = LoraConfig(
    task_type=TaskType.FEATURE_EXTRACTION,  # TODO(review): confirm this task type suits the alignment trainer
    target_modules=target_modules,
    r=32,  # adapter rank; tune as needed
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
)
model = get_peft_model(model, config)
model.print_trainable_parameters()

In [None]:
# Define training configuration.
# Evaluation is left at its default ("no"): the trainer is built with eval_dataset=None.
# With evaluation_strategy="steps", eval_steps defaults to logging_steps (100), so the
# Trainer would try to evaluate at step 100 and fail for lack of an eval dataset.
training_args = Seq2SeqTrainingArguments(
    output_dir="azure-224n/test",  # change to a repo name of your choice
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-3,
    warmup_steps=50,
    num_train_epochs=3,  # ignored while max_steps is set
    fp16=True,  # mixed precision; needs a GPU
    per_device_eval_batch_size=8,
    generation_max_length=128,
    logging_steps=100,
    max_steps=100,  # only for testing purposes, remove this from your final run :)
    remove_unused_columns=False,  # required as the PeftModel forward doesn't have the signature of the wrapped model's forward
    label_names=["labels"],  # same reason as above
)

In [None]:
# Alignment trainer (from trainer_utils) over the dev split; no evaluation set is supplied.
trainer = AlignmentSeq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=sd_qa['dev'],
    eval_dataset=None,
    # In the stock Trainer, this object is saved alongside checkpoints.
    tokenizer=processor.feature_extractor,
    # An instance is equivalent to the class here: Trainer instantiates classes it is given.
    callbacks=[SavePeftCallback()],
)

In [None]:
# Run training, then upload the LoRA-wrapped model (the adapter) to the Hub.
trainer.train()

# Pushing needs an authenticated session with write access (notebook_login() is commented out in the setup cell).
peft_model_id = "azure-224n/whisper-base-alignment"
model.push_to_hub(peft_model_id)

In [None]:
# Reload for evaluation: base Whisper weights in 8-bit (bitsandbytes) plus the pushed adapter.
# PeftConfig comes from peft; its base_model_name_or_path names the checkpoint to load.
peft_config = PeftConfig.from_pretrained(peft_model_id)
model = WhisperForConditionalGeneration.from_pretrained(
    peft_config.base_model_name_or_path, load_in_8bit=True, device_map="auto"
)
model = PeftModel.from_pretrained(model, peft_model_id)  # attaches the PEFT module to the Whisper model
# Inference only: make sure dropout, including LoRA dropout, is disabled.
model.eval()
model.config.use_cache = True  # re-enable KV caching for generation

In [None]:
# Evaluate ASR quality on the small dataset returned by get_mini_cv() (presumably a Common Voice subset).
dataset = get_mini_cv()
metrics = evaluate_asr(
    model,
    processor,
    dataset,
    True,  # NOTE(review): meaning of this positional flag is defined in eval_utils; pass by keyword once confirmed
)
print(metrics)