<a target="_blank" href="https://colab.research.google.com/github/Deep-unlearning/notebooks/blob/main/finetune_phi4mm.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Fine-tuning Phi-4 for ASR

In this notebook, we will fine-tune pretrained Phi-4-Multimodal-Instruct on a small split of earnings22 test set.

Let's get started by installing necessary libraries.

In [4]:
!pip install huggingface_hub==0.25.2
!pip install scipy==1.15.1
!pip install peft==0.13.2
!pip install backoff==2.2.1
!pip install transformers==4.46.1
!pip install accelerate==1.3.0
!pip install sacrebleu
!pip install torchvision
!pip install hf_transfer
!pip install transformers
!pip install librosa
!pip install soundfile
!pip install datasets
!pip install jiwer

Collecting huggingface_hub==0.25.2
  Using cached huggingface_hub-0.25.2-py3-none-any.whl.metadata (13 kB)
Using cached huggingface_hub-0.25.2-py3-none-any.whl (436 kB)
Installing collected packages: huggingface_hub
  Attempting uninstall: huggingface_hub
    Found existing installation: huggingface-hub 0.29.2
    Uninstalling huggingface-hub-0.29.2:
      Successfully uninstalled huggingface-hub-0.29.2
Successfully installed huggingface_hub-0.25.2
Collecting jiwer
  Downloading jiwer-3.1.0-py3-none-any.whl.metadata (2.6 kB)
Collecting click>=8.1.8 (from jiwer)
  Downloading click-8.1.8-py3-none-any.whl.metadata (2.3 kB)
Collecting rapidfuzz>=3.9.7 (from jiwer)
  Downloading rapidfuzz-3.12.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading jiwer-3.1.0-py3-none-any.whl (22 kB)
Downloading click-8.1.8-py3-none-any.whl (98 kB)
Downloading rapidfuzz-3.12.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━

In [2]:
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [5]:
import json
import os
from pathlib import Path

import torch
import jiwer
from accelerate import Accelerator
from accelerate.utils import gather_object
from datasets import load_dataset
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    BatchFeature,
    Trainer,
    TrainingArguments,
    StoppingCriteria,
    StoppingCriteriaList,
)

# Fixed ASR instruction and other constants
INSTRUCTION = "Transcribe the audio."
ANSWER_SUFFIX = "<|end|><|endoftext|>"
_IGNORE_INDEX = -100


In [3]:
class ESBEarnings22Dataset(Dataset):
    def __init__(self, processor, dataset, training=True):
        """
        processor: the AutoProcessor instance
        dataset: a Hugging Face Dataset (already split into train/validation)
        training: whether this dataset is for training (affects concatenation of target tokens)
        """
        self.data = dataset
        self.training = training
        self.processor = processor
        self.instruction = INSTRUCTION

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        # The dataset contains an "audio" dict and a "text" field for transcription.
        user_message = {
            'role': 'user',
            'content': '<|audio_1|>\n' + self.instruction,
        }
        prompt = self.processor.tokenizer.apply_chat_template(
            [user_message], tokenize=False, add_generation_prompt=True
        )
        inputs = self.processor(
            text=prompt,
            audios=[(data["audio"]["array"], data["audio"]["sampling_rate"])],
            return_tensors='pt'
        )
        
        answer = f"{data['text']}{ANSWER_SUFFIX}"
        answer_ids = self.processor.tokenizer(answer, return_tensors='pt').input_ids
        if self.training:
            # Concatenate prompt and answer, but mask all tokens except the answer.
            input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
            labels = torch.full_like(input_ids, _IGNORE_INDEX)
            labels[:, -answer_ids.shape[1]:] = answer_ids
        else:
            input_ids = inputs.input_ids
            labels = answer_ids

        return {
            'input_ids': input_ids,
            'labels': labels,
            'input_audio_embeds': inputs.input_audio_embeds,
            'audio_embed_sizes': inputs.audio_embed_sizes,
        }


In [4]:
def pad_sequence(sequences, padding_side='right', padding_value=0):
    assert padding_side in ['right', 'left']
    max_size = sequences[0].size()
    trailing_dims = max_size[1:]
    max_len = max(len(seq) for seq in sequences)
    batch_size = len(sequences)
    output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
    for i, seq in enumerate(sequences):
        length = seq.size(0)
        if padding_side == 'right':
            output.data[i, :length] = seq
        else:
            output.data[i, -length:] = seq
    return output

def cat_with_pad(tensors, dim, padding_value=0):
    ndim = tensors[0].dim()
    assert all(t.dim() == ndim for t in tensors[1:]), 'All tensors must have the same number of dimensions'
    out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
    out_size[dim] = sum(t.shape[dim] for t in tensors)
    output = tensors[0].new_full(out_size, padding_value)
    index = 0
    for t in tensors:
        slices = [slice(0, t.shape[d]) for d in range(ndim)]
        slices[dim] = slice(index, index + t.shape[dim])
        output[slices] = t
        index += t.shape[dim]
    return output


In [5]:
def esb_collate_fn(batch):
    input_ids_list = []
    labels_list = []
    input_audio_embeds_list = []
    audio_embed_sizes_list = []
    audio_attention_mask_list = []
    for inputs in batch:
        input_ids_list.append(inputs['input_ids'][0])
        labels_list.append(inputs['labels'][0])
        input_audio_embeds_list.append(inputs['input_audio_embeds'])
        audio_embed_sizes_list.append(inputs['audio_embed_sizes'])
        audio_attention_mask_list.append(
            inputs['input_audio_embeds'].new_full((inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool)
        )
    try:
        input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
        labels = pad_sequence(labels_list, padding_side='left', padding_value=0)
        audio_attention_mask = (
            pad_sequence(audio_attention_mask_list, padding_side='right', padding_value=False)
            if len(audio_attention_mask_list) > 1 else None
        )
    except Exception as e:
        print(e)
        print(input_ids_list)
        print(labels_list)
        raise
    attention_mask = (input_ids != 0).long()
    input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0)
    audio_embed_sizes = torch.cat(audio_embed_sizes_list)
    return BatchFeature({
        'input_ids': input_ids,
        'labels': labels,
        'attention_mask': attention_mask,
        'input_audio_embeds': input_audio_embeds,
        'audio_embed_sizes': audio_embed_sizes,
        'audio_attention_mask': audio_attention_mask,
        'input_mode': 2,  # speech mode
    })


In [6]:
class MultipleTokenBatchStoppingCriteria(StoppingCriteria):
    def __init__(self, stop_tokens: torch.LongTensor, batch_size: int = 1) -> None:
        self.stop_tokens = stop_tokens
        self.max_stop_tokens = stop_tokens.shape[-1]
        self.stop_tokens_idx = torch.zeros(batch_size, dtype=torch.long, device=stop_tokens.device)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        generated_inputs = torch.eq(input_ids[:, -self.max_stop_tokens :].unsqueeze(1), self.stop_tokens)
        equal_generated_inputs = torch.all(generated_inputs, dim=2)
        sequence_idx = torch.any(equal_generated_inputs, dim=1)
        sequence_set_mask = self.stop_tokens_idx == 0
        self.stop_tokens_idx[sequence_idx & sequence_set_mask] = input_ids.shape[-1]
        return torch.all(self.stop_tokens_idx)


In [7]:
def create_model(model_name_or_path, use_flash_attention=False):
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=torch.bfloat16 if use_flash_attention else torch.float32,
        _attn_implementation='flash_attention_2' if use_flash_attention else 'sdpa',
        trust_remote_code=True,
    ).to('cuda')
    return model


In [8]:
@torch.no_grad()
def evaluate(model, processor, eval_dataset, save_path=None, disable_tqdm=False, eval_batch_size=1):
    rank = int(os.environ.get('RANK', 0))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))

    model.eval()
    all_generated_texts = []
    all_labels = []

    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=eval_batch_size,
        collate_fn=esb_collate_fn,
        shuffle=False,
        drop_last=False,
        num_workers=8,
        prefetch_factor=2,
        pin_memory=True,
    )
    stop_tokens = ["<|end|>", processor.tokenizer.eos_token]
    stop_tokens_ids = processor.tokenizer(stop_tokens, add_special_tokens=False, padding="longest", return_tensors="pt")["input_ids"]
    stop_tokens_ids = stop_tokens_ids.to(f'cuda:{local_rank}')

    for inputs in tqdm(eval_dataloader, disable=(rank != 0) or disable_tqdm, desc='running eval'):
        stopping_criteria = StoppingCriteriaList([MultipleTokenBatchStoppingCriteria(stop_tokens_ids, batch_size=inputs.input_ids.size(0))])
        inputs = inputs.to(f'cuda:{local_rank}')
        generated_ids = model.generate(
            **inputs, eos_token_id=processor.tokenizer.eos_token_id, max_new_tokens=64,
            stopping_criteria=stopping_criteria,
        )

        stop_tokens_idx = stopping_criteria[0].stop_tokens_idx.reshape(inputs.input_ids.size(0), -1)[:, 0]
        stop_tokens_idx = torch.where(
            stop_tokens_idx > 0,
            stop_tokens_idx - stop_tokens_ids.shape[-1],
            generated_ids.shape[-1],
        )
        generated_text = [
            processor.decode(_pred_ids[inputs["input_ids"].shape[1] : _stop_tokens_idx],
                               skip_special_tokens=True,
                               clean_up_tokenization_spaces=False)
            for _pred_ids, _stop_tokens_idx in zip(generated_ids, stop_tokens_idx)
        ]
        all_generated_texts.extend(generated_text)
        labels = [processor.decode(_label_ids[_label_ids != 0]).rstrip(ANSWER_SUFFIX) for _label_ids in inputs["labels"]]
        all_labels.extend(labels)

    all_generated_texts = gather_object(all_generated_texts)
    all_labels = gather_object(all_labels)
    
    if rank == 0:
        wer = jiwer.wer(all_labels, all_generated_texts)
        print("WER:", wer)
        if save_path:
            with open(save_path, 'w') as f:
                save_dict = {
                    'all_generated_texts': all_generated_texts,
                    'all_labels': all_labels,
                    'wer': wer,
                }
                json.dump(save_dict, f)
        return wer
    return None


In [None]:
# Configuration variables
MODEL_NAME = 'microsoft/Phi-4-multimodal-instruct'
OUTPUT_DIR = './output/'
USE_FLASH_ATTENTION = False
BATCH_SIZE_PER_GPU = 2
NUM_TRAIN_EPOCHS = 1
LEARNING_RATE = 4.0e-5
WEIGHT_DECAY = 0.01

# Initialize Accelerator for potential multi-GPU training.
accelerator = Accelerator()

# Load processor and model
with accelerator.local_main_process_first():
    processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
    model = create_model(MODEL_NAME, use_flash_attention=USE_FLASH_ATTENTION)

# Set LoRA adapter (if available) for speech tasks.
model.set_lora_adapter('speech')

# Load and split the dataset.
ds_full = load_dataset("AdrienB134/Emilia-dataset-french-split", split="fr")
split_ds = ds_full.train_test_split(test_size=0.1, seed=42)
train_ds = split_ds["train"]
val_ds = split_ds["test"]

# Create dataset objects.
train_dataset = ESBEarnings22Dataset(processor, train_ds, training=True)
val_dataset = ESBEarnings22Dataset(processor, val_ds, training=False)

num_gpus = accelerator.num_processes
print(f"Training on {num_gpus} GPUs")

# Compute gradient accumulation steps (for multi-GPU training).
gradient_accumulation_steps = (BATCH_SIZE_PER_GPU * num_gpus) // BATCH_SIZE_PER_GPU

# Set mixed precision flags.
fp16 = not USE_FLASH_ATTENTION
bf16 = USE_FLASH_ATTENTION

# Define training arguments.
training_args = TrainingArguments(
    num_train_epochs=NUM_TRAIN_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE_PER_GPU,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={'use_reentrant': False},
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim='adamw_torch',
    adam_beta1=0.9,
    adam_beta2=0.95,
    adam_epsilon=1e-7,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    max_grad_norm=1.0,
    lr_scheduler_type='linear',
    warmup_steps=50,
    logging_steps=10,
    output_dir=OUTPUT_DIR,
    save_strategy='no',
    save_total_limit=10,
    save_only_model=True,
    bf16=bf16,
    fp16=fp16,
    remove_unused_columns=False,
    report_to='none',
    dataloader_num_workers=4,
    ddp_find_unused_parameters=True,
)

In [None]:
# Create output directory if it doesn't exist.
Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)

# Evaluate the model before fine-tuning.
print("Evaluating before fine-tuning...")
wer_before = evaluate(
    model,
    processor,
    val_dataset,
    save_path=Path(training_args.output_dir) / 'eval_before.json',
    eval_batch_size=BATCH_SIZE_PER_GPU,
)
print(f"WER before fine-tuning: {wer_before}")

# Setup the Trainer for fine-tuning.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=esb_collate_fn,
    train_dataset=train_dataset,
)

# Train the model.
trainer.train()
trainer.save_model()
processor.save_pretrained(training_args.output_dir)
accelerator.wait_for_everyone()

In [None]:
# Free up memory before re-loading the model.
del model, trainer
torch.cuda.empty_cache()

# Reload the fine-tuned model.
model = AutoModelForCausalLM.from_pretrained(
    training_args.output_dir,
    torch_dtype=torch.bfloat16 if USE_FLASH_ATTENTION else torch.float32,
    trust_remote_code=True,
    _attn_implementation='flash_attention_2' if USE_FLASH_ATTENTION else 'sdpa',
).to('cuda')

# Evaluate the model after fine-tuning.
print("Evaluating after fine-tuning...")
wer_after = evaluate(
    model,
    processor,
    val_dataset,
    save_path=Path(training_args.output_dir) / 'eval_after.json',
    eval_batch_size=BATCH_SIZE_PER_GPU,
)
print(f"WER after fine-tuning: {wer_after}")

In [None]:
# Push to the hub

model.push_to_hub("your-username/your-model-name", commit_message="Uploading fine-tuned model")
processor.push_to_hub("your-username/your-model-name", commit_message="Uploading processor")


WER before finetuning: 0.3346904527379519

WER after finetuning: 0.2521907263507351