In [2]:
!pip install peft
!pip install bitsandbytes
!pip install accelerate

Collecting peft
  Downloading peft-0.11.1-py3-none-any.whl.metadata (13 kB)
Downloading peft-0.11.1-py3-none-any.whl (251 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m251.6/251.6 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: peft
Successfully installed peft-0.11.1
Collecting bitsandbytes
  Downloading bitsandbytes-0.43.1-py3-none-manylinux_2_24_x86_64.whl.metadata (2.2 kB)
Downloading bitsandbytes-0.43.1-py3-none-manylinux_2_24_x86_64.whl (119.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.8/119.8 MB[0m [31m13.8 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.43.1


In [64]:
import os
import gc
import torch
import peft
import torchaudio
import accelerate
import numpy as np
import pandas as pd
from typing import Any
import bitsandbytes as bnb
from dataclasses import dataclass
from torch.utils.data import Dataset
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
from transformers import WhisperForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer, WhisperTokenizer, WhisperProcessor, DataCollatorForSeq2Seq, BitsAndBytesConfig

# Checkpoint id shared by the processor and the tokenizer so the two cannot drift apart.
WHISPER_CHECKPOINT = "openai/whisper-large-v3"
# Root directory handed to torchaudio's LibriSpeech loader.
LIBRISPEECH_ROOT = '/kaggle/input/librispeech-clean'

# Processor configured for English transcription.
processor = WhisperProcessor.from_pretrained(WHISPER_CHECKPOINT, language='en', task="transcribe")
# Stand-alone tokenizer loaded with the same language/task settings.
whisper_tokenizer = WhisperTokenizer.from_pretrained(WHISPER_CHECKPOINT, language='en', task="transcribe")
# download=False: nothing is fetched, the split is expected to be present under LIBRISPEECH_ROOT.
train_dataset = torchaudio.datasets.LIBRISPEECH(LIBRISPEECH_ROOT, url='train-clean-360', download=False)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [65]:
class training_dataset(Dataset):
    """Map-style dataset that turns raw speech items into Whisper training examples.

    Wraps an indexable dataset whose items are sequences holding the waveform
    at position 0, the sample rate at position 1 and the transcript at
    position 2. Relies on the notebook-level ``processor`` and
    ``whisper_tokenizer`` objects.
    """

    def __init__(self, dataset):
        super().__init__()
        self.data = dataset

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the processor output for item ``idx`` with a ``labels`` entry added."""
        # Index the wrapped dataset exactly once; every indexing may reload the
        # audio, and the previous version indexed twice (audio, then text).
        sample = self.data[idx]
        waveform, sample_rate, transcript = sample[0], sample[1], sample[2]

        # The sample rate reported by the dataset is forwarded instead of a
        # hard-coded 16000, so a mismatch is visible to the processor rather
        # than hidden. The former `padding_size=3000` argument was removed: it
        # is not a documented feature-extractor parameter.
        data = processor(
            waveform.numpy(),
            sampling_rate=sample_rate,
            truncation=True,
            return_tensors='pt',
            return_attention_mask=True,
        )
        # Transcripts are truncated to at most 100 tokens.
        data['labels'] = whisper_tokenizer(
            transcript,
            padding='longest',
            truncation=True,
            max_length=100,
            return_tensors='pt',
        ).input_ids
        return data
    
# Wrap the raw LibriSpeech split so each item yields processor features plus token labels.
dataset = training_dataset(train_dataset)

In [53]:
class CustomWhisperModel(WhisperForConditionalGeneration):
    """Whisper model whose ``forward`` tolerates text-style seq2seq argument names.

    The audio features may arrive under ``input_features`` or ``input_ids``;
    ``inputs_embeds`` and ``task_type`` are accepted and ignored. Presumably
    this exists so the model can be driven by the PEFT ``SEQ_2_SEQ_LM``
    wrapper used later in this notebook -- confirm against the PEFT version
    in use.
    """

    def __init__(self, model_name):
        # The argument is handed unchanged to the parent constructor. The
        # parameter name is kept as-is so existing callers are unaffected.
        super().__init__(model_name)

    def forward(self, input_ids=None,
                    input_features=None,
                    inputs_embeds=None,
                    attention_mask=None,
                    decoder_input_ids=None,
                    decoder_attention_mask=None,
                    labels=None,
                    decoder_inputs_embeds=None,
                    output_attentions=None,
                    output_hidden_states=None,
                    return_dict=None,
                    output_attention=None,
                    task_type=None,
                    **kwargs):
        """Map the accepted arguments onto the parent ``forward`` and call it.

        Args:
            input_ids: Alias for ``input_features``; used only when
                ``input_features`` is None.
            input_features: Audio features; takes precedence over ``input_ids``.
            inputs_embeds: Accepted for call compatibility and ignored.
            attention_mask, decoder_input_ids, decoder_attention_mask, labels,
            decoder_inputs_embeds, output_hidden_states, return_dict:
                Forwarded unchanged.
            output_attentions: Forwarded; when None, the legacy
                ``output_attention`` value is used instead.
            output_attention: Legacy spelling kept for backward compatibility.
            task_type: Accepted for call compatibility and ignored.
            **kwargs: Any further keyword arguments are forwarded unchanged to
                the parent ``forward``.

        Returns:
            Whatever the parent ``forward`` returns.
        """
        # Identity checks: comparing a tensor with `!= None` is not a reliable
        # "was it passed?" test.
        if input_features is None:
            input_features = input_ids
        # Previously only the misspelt `output_attention` was forwarded and the
        # standard `output_attentions` argument was silently dropped.
        if output_attentions is None:
            output_attentions = output_attention

        return super().forward(
            input_features=input_features,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_inputs_embeds=decoder_inputs_embeds,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )
    
# Load the checkpoint with 8-bit weights and let device_map='auto' place it.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
custom_model = CustomWhisperModel.from_pretrained("openai/whisper-large-v3", quantization_config = bnb_config, device_map='auto')
q_model = prepare_model_for_kbit_training(custom_model)


def make_inputs_require_grad(module, input, output):
    """Forward hook that marks a module's output as requiring gradients."""
    output.requires_grad_(True)


# The audio features fed to the encoder do not require grad and the base
# weights are frozen, so without this hook the activations entering the encoder
# layers carry no grad and the encoder LoRA adapters may receive no gradient
# when gradient checkpointing is active. This mirrors the hook used in the PEFT
# int8 Whisper fine-tuning example.
q_model.model.encoder.conv1.register_forward_hook(make_inputs_require_grad)

# LoRA adapters on the attention query/value projections.
peft_config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, target_modules=["q_proj", "v_proj"], r=32, lora_alpha=64, lora_dropout=0.1)
final_model = get_peft_model(q_model, peft_config)

In [29]:
def apply_masking(text, mask_rate=0.07):
    """Randomly zero out elements of a tensor.

    One uniform sample in [0, 1) is drawn per element; the element is kept
    when its sample is greater than ``mask_rate`` and replaced by zero
    otherwise.

    Args:
        text: Tensor to mask (in this notebook the collator passes the input
            features, despite the parameter name).
        mask_rate: Threshold for the uniform sample; larger values zero out
            more elements.

    Returns:
        A tensor with the same shape as ``text`` with the masked elements set
        to zero.
    """
    # Draw the mask on the tensor's own device; a CPU-only mask would fail to
    # multiply with a tensor that lives on another device.
    keep = torch.rand(text.shape, device=text.device) > mask_rate
    return text * keep

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate dataset items into a padded batch of input features and labels.

    Each item must provide ``input_features`` and ``labels`` tensors with a
    leading dimension of size one (it is squeezed away here). Input features
    are passed through ``apply_masking`` before padding. Label padding
    positions are set to -100.
    """

    # Object exposing `feature_extractor.pad` and `tokenizer.pad`.
    processor: Any
    # Id of the start token to strip from the front of the labels. When None it
    # is taken from the tokenizer (see `_leading_token_id`).
    decoder_start_token_id: Any = None

    def _leading_token_id(self):
        """Return the token id that should be stripped from the start of the labels."""
        if self.decoder_start_token_id is not None:
            return self.decoder_start_token_id
        tokenizer = self.processor.tokenizer
        # Whisper tokenizers list the tokens they prepend in `prefix_tokens`,
        # beginning with the start-of-transcript token. Their `bos_token_id` is
        # a different token, so comparing against it never matched and the
        # start token stayed in the labels.
        prefix = getattr(tokenizer, "prefix_tokens", None)
        if prefix:
            return prefix[0]
        # Tokenizers without `prefix_tokens` keep the previous behaviour.
        return tokenizer.bos_token_id

    def __call__(self, features):
        """Build the batch from a list of dataset items.

        Args:
            features: List of items, each with ``input_features`` and ``labels``.

        Returns:
            The padded feature batch with a ``labels`` entry added.
        """
        input_features = [{"input_features": apply_masking(feature["input_features"].squeeze(0))} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        label_features = [{"input_ids": feature["labels"].squeeze(0)} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Padding positions become -100.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # Drop the leading start token only when every row begins with it.
        start_id = self._leading_token_id()
        if start_id is not None and labels.shape[1] > 0:
            if (labels[:, 0] == start_id).all().cpu().item():
                labels = labels[:, 1:]

        batch["labels"] = labels
        return batch
    
# Collator instance built around the shared processor; passed to the trainer as data_collator.
collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

In [30]:
def total_params(model):
    """Return the total number of elements across all of ``model``'s parameters."""
    count = 0
    for param in model.parameters():
        count += param.numel()
    return count

# Report the memory footprint (in GB, two decimals) and the parameter counts.
footprint_gb = round(final_model.get_memory_footprint()/1024/1024/1024, 2)
param_count = total_params(final_model)
print(f'Memory used by model: {footprint_gb} GB')
print(f'total number of parameters is {param_count}')
final_model.print_trainable_parameters()

Memory used by model: 1.71 GB
total number of parameters is 1559219200
trainable params: 15,728,640 || all params: 1,559,219,200 || trainable%: 1.0088


In [51]:
from transformers import TrainerState, TrainerCallback, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

class SavePeftModelCallback(TrainerCallback):
    """Trainer callback that writes the model into each checkpoint folder on save."""

    def on_save(
        self,
        args: Seq2SeqTrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Save ``kwargs["model"]`` under the current checkpoint directory.

        The model is written with ``save_pretrained`` to
        ``<output_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>/adapter_model``.
        A ``pytorch_model.bin`` file in that checkpoint directory is deleted
        if it exists. Returns ``control`` unchanged.
        """
        step_dir_name = f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        checkpoint_dir = os.path.join(args.output_dir, step_dir_name)

        kwargs["model"].save_pretrained(os.path.join(checkpoint_dir, "adapter_model"))

        # Remove the full-weights file from the checkpoint when one is present.
        full_weights_file = os.path.join(checkpoint_dir, "pytorch_model.bin")
        if os.path.exists(full_weights_file):
            os.remove(full_weights_file)
        return control

In [None]:
import warnings

# Suppress all Python warnings from this point on.
warnings.filterwarnings("ignore")

training_args = Seq2SeqTrainingArguments(
    output_dir="/kaggle/working/",
    report_to="none",
    # Optimisation settings.
    learning_rate=35e-6,
    warmup_steps=50,
    per_device_train_batch_size=32,
    gradient_accumulation_steps=1,
    fp16=True,
    # Run length: the recorded output notes that max_steps overrides num_train_epochs.
    max_steps=3000,
    num_train_epochs=1,
    logging_steps=10,
    # Batch handling.
    remove_unused_columns=False,
    label_names=["labels"],
)

trainer = Seq2SeqTrainer(
    model=final_model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collator,
    tokenizer=processor.feature_extractor,
    callbacks=[SavePeftModelCallback],
)
trainer.train()

# Persist the state dict of the trained model after training finishes.
state_dict = final_model.state_dict()
torch.save(state_dict, '/kaggle/working/model_weights.pth')

max_steps is given, it will override any value given in num_train_epochs


Step,Training Loss
