<a href="https://colab.research.google.com/github/alt-f13/Dell_QA/blob/main/gigaam_ctc_hf_finetune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# @title Install libs
%%capture
!pip install -q evaluate
!pip install -q jiwer
!pip install -q pytorch_lightning
!pip install -U transformers==4.49.0 accelerate==1.5.2 datasets==3.4.1


# Finetune GigaAM-v2-CTC with 🤗 HuggingFace transformers

GigaAM-v2 is a family of open-source models for Russian speech recognition tasks. It was state of the art (SoTA) in Russian ASR as of March 2025.

Original git: https://github.com/salute-developers/GigaAM

This notebook shows how to fine-tune the CTC version of the GigaAM-v2 model with the `transformers` library.

Use a GPU environment.

In [None]:
import os

# Optional: uncomment to point the Hugging Face hub cache at the current directory.
# os.environ["HUGGINGFACE_HUB_CACHE"] = "."
# Project name for Weights & Biases logging; presumably picked up by the wandb
# reporter enabled through `report_to=["wandb"]` in the training arguments — verify.
os.environ["WANDB_PROJECT"] = "project"

In [None]:
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, List, Optional, Union

import datasets
import evaluate
import numpy as np
import peft
import pytorch_lightning as pl
import torch
import wandb
from datasets import Dataset, load_dataset
from torch.utils.data import DataLoader
from transformers import (AutoFeatureExtractor, AutoModel, AutoProcessor,
                          AutoTokenizer, Trainer, TrainingArguments)
from transformers.utils import is_datasets_available

In [None]:
# wandb.login()

parameters

In [None]:
# Hub id used to load the model, feature extractor, tokenizer and processor.
model_name = "waveletdeboshir/gigaam-ctc"
# Seed shared by the NumPy / Lightning seeding cell and by TrainingArguments.
SEED = 1234

# Set max duration of audio files to 30 seconds
# (clips whose duration is >= this value are filtered out of train and validation).
MAX_DURATION = 30.

In [None]:
# Seed NumPy's global random generator.
np.random.seed(SEED)
# Lightning's seeding helper; presumably also seeds Python's `random` and torch —
# verify against the pytorch_lightning documentation.
pl.seed_everything(SEED)

# Load model, feature extractor and tokenizer

These are `transformers` wrappers for the GigaAM-CTC model, tokenizer and feature extractor from https://huggingface.co/waveletdeboshir/gigaam-ctc

In [None]:
# All four objects come from the same hub repository (`model_name`).
# `trust_remote_code=True` allows the loaders to execute custom Python code
# shipped in that repository, so only use it with a repository you trust.
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
# Used in `prepare_dataset` to turn raw audio into model input features.
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name, trust_remote_code=True)
# Used in `prepare_dataset` (text -> ids) and in `compute_metrics` (ids -> text).
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Used by the data collators through its `.feature_extractor` and `.tokenizer` attributes.
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

In [None]:
def msize(m):
    """Return the total number of elements over all parameters of module `m`."""
    total = 0
    for param in m.parameters():
        total += param.numel()
    return total

# Counts the parameters of the `model.model` submodule only, not of the whole wrapper.
print(f"N of parameters: {msize(model.model)}")

## Load Datasets

We will load a part of the Golos dataset just as an example. **GigaAM was already trained on this dataset, so use some other data for real fine-tuning.**

In [None]:
# Download (or load from cache) the example corpus; the notebook uses its
# "train" and "validation" splits.
audio_dataset = load_dataset("bond005/sberdevices_golos_10h_crowd")
# Shuffle with an explicit seed so the order is the same on every run of this
# cell, independent of the global RNG state at the moment it is executed.
audio_dataset["train"] = audio_dataset["train"].shuffle(seed=SEED)

In [None]:
# Attach a "duration" column (number of samples / sampling rate, i.e. seconds)
# to the train and validation splits.
for split_name in ("train", "validation"):
    clip_seconds = [
        len(clip["array"]) / clip["sampling_rate"]
        for clip in audio_dataset[split_name]["audio"]
    ]
    audio_dataset[split_name] = audio_dataset[split_name].add_column(
        "duration", np.array(clip_seconds)
    )

In [None]:
# Keep only the clips strictly shorter than MAX_DURATION in both splits.
for split_name in ("train", "validation"):
    audio_dataset[split_name] = audio_dataset[split_name].filter(
        lambda example: example["duration"] < MAX_DURATION
    )

In [None]:
# Display the dataset summary after the duration filter.
audio_dataset

### Prepare Data

In [None]:
def prepare_dataset(batch, feature_extractor, tokenizer, text_column="text", val=False):
    """
    Convert one raw example into model inputs and label ids.

    Args:
        batch: example mapping with an "audio" entry (providing "array" and
            "sampling_rate") and the transcript stored under `text_column`.
        feature_extractor: callable applied to the audio array (log-Mel features
            according to the original notes); its result must expose
            `input_features` and `input_lengths`.
        tokenizer: callable applied to the transcript; its result must expose
            `input_ids`.
        text_column: key of the transcript inside `batch`.
        val: accepted for call-site compatibility; not used by this function.

    Returns:
        The same `batch` object with "input_features", "input_lengths" and
        "labels" added.
    """
    clip = batch["audio"]

    # The extractor output is indexed with [0] below, i.e. only its first item is kept.
    extracted = feature_extractor(
        clip["array"],
        sampling_rate=clip["sampling_rate"],
        padding="longest",
    )
    first_features = extracted.input_features[0]
    first_length = extracted.input_lengths[0]
    batch["input_features"] = first_features
    batch["input_lengths"] = first_length

    # Tokenize the transcript into label ids.
    token_ids = tokenizer(batch[text_column]).input_ids
    batch["labels"] = token_ids
    return batch

In [None]:
# Run `prepare_dataset` over both splits, dropping every original column so only
# the fields it adds remain. The `val` flag is forwarded as in the original calls
# (False for train, True for validation).
for split_name, is_val in (("train", False), ("validation", True)):
    prepare_fn = partial(
        prepare_dataset,
        feature_extractor=feature_extractor,
        tokenizer=tokenizer,
        text_column="transcription",
        val=is_val,
    )
    audio_dataset[split_name] = audio_dataset[split_name].map(
        prepare_fn,
        remove_columns=audio_dataset.column_names[split_name],
        num_proc=1,
    )

In [None]:
# Display the dataset summary after feature extraction and tokenization.
audio_dataset

## Training and Evaluation

### Define a Data Collator

In [None]:
@dataclass
class DataCollatorCTCWithPadding:
    """
    Collate prepared examples into one padded batch for CTC training.

    Audio features and labels are padded separately (they have different lengths
    and need different padding), and padded label positions are set to -100 so
    the loss can ignore them.
    """

    # Object exposing `.feature_extractor.pad(...)` and `.tokenizer.pad(...)`.
    processor: Any
    # Padding strategy forwarded to both `pad` calls.
    padding: str = "longest"
    # `max_length` forwarded when padding the audio features.
    max_length: Optional[int] = 3001
    # `max_length` forwarded when padding the label ids.
    max_length_tokens: Optional[int] = 1000

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """
        Build a batch from a list of examples.

        Each example must provide "input_features", "input_lengths" and "labels".
        Returns the padded feature batch extended with "input_lengths" and
        "labels"; an "attention_mask", if present, is cast to `torch.long`.
        """
        # Each feature matrix is transposed before padding and the padded batch is
        # transposed back afterwards — presumably so that padding is applied along
        # the time axis; verify against the feature extractor.
        audio_items = [
            {"input_features": np.asarray(example["input_features"]).T}
            for example in features
        ]
        batch = self.processor.feature_extractor.pad(
            audio_items,
            padding=self.padding,
            max_length=self.max_length,
            return_tensors="pt",
        )
        batch["input_features"] = batch["input_features"].transpose(1, 2)

        # Feature lengths are passed through unchanged.
        frame_counts = [example["input_lengths"] for example in features]

        # Pad the tokenized transcripts on their own.
        label_items = [{"input_ids": example["labels"]} for example in features]
        padded_labels = self.processor.tokenizer.pad(
            label_items,
            padding=self.padding,
            max_length=self.max_length_tokens,
            return_tensors="pt",
        )

        # Positions where the attention mask is not 1 are padding -> -100.
        is_padding = padded_labels.attention_mask.ne(1)
        target_ids = padded_labels["input_ids"].masked_fill(is_padding, -100)

        batch["input_lengths"] = torch.LongTensor(frame_counts)
        batch["labels"] = target_ids

        if "attention_mask" in batch:
            batch["attention_mask"] = batch["attention_mask"].to(torch.long)

        return batch


I want to use different collators for train and validation:

In [None]:
# Training collator: forwards padding="longest" to the processor's pad calls.
data_collator = DataCollatorCTCWithPadding(processor=processor, padding="longest")
# Validation collator: forwards padding="max_length", combined with the collator's
# `max_length` / `max_length_tokens` defaults — presumably so that every evaluation
# batch has the same shape; verify.
val_data_collator = DataCollatorCTCWithPadding(processor=processor, padding="max_length")

### Evaluation Metrics

In [None]:
# "wer" metric object from the `evaluate` library; used by `compute_metrics`.
metric = evaluate.load("wer")

In [None]:
def compute_metrics(pred):
    """
    Compute the word error rate (in percent) for one evaluation pass.

    Args:
        pred: prediction object exposing `predictions` (logits with the class
            scores on the last axis) and `label_ids` (label token ids in which
            -100 marks padded positions).

    Returns:
        Dict with a single key "wer": 100 times the value returned by `metric`.

    Relies on the notebook-level `tokenizer` and `metric` objects.
    """
    # Greedy decoding: pick the highest-scoring token id at every position.
    pred_ids = np.argmax(pred.predictions, axis=-1)

    # Replace -100 with the pad_token_id so that the labels can be decoded.
    # `np.where` builds a new array, so `pred.label_ids` (which belongs to the
    # caller) is left untouched instead of being overwritten in place.
    label_ids = np.where(pred.label_ids == -100, tokenizer.pad_token_id, pred.label_ids)

    # NOTE(review): the original comment stated that tokens must not be grouped
    # when computing the metric; whether `batch_decode` collapses repeated tokens
    # depends on the remotely loaded tokenizer — confirm, especially for labels.
    pred_str = tokenizer.batch_decode(pred_ids)
    label_str = tokenizer.batch_decode(label_ids)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

### If you want to use LoRA

In [None]:
# lora_config = dict(
#     r=32, lora_alpha=64,
#     lora_dropout=0.05,
#     target_modules=[
#         "linear_k", "linear_q", # change if you want
#     ],
#     bias="none"
# )

# peft_config = peft.LoraConfig(
#     inference_mode=False,
#     **lora_config,
# )

# model = peft.get_peft_model(model, peft_config)

In [None]:
# model.print_trainable_parameters()

### Define the Training Configuration

Set trainer with different collators for train and val

In [None]:
class TrainerDifCollators(Trainer):
    """
    `Trainer` variant that can use a separate data collator for evaluation.

    When `val_data_collator` is given, evaluation dataloaders are built with it;
    otherwise the regular `data_collator` is used.

    NOTE(review): `val_data_collator` is the first positional parameter of
    `__init__`, so positional arguments are shifted by one compared with
    `Trainer.__init__` — pass the `Trainer` arguments by keyword.
    """

    def __init__(self,  val_data_collator=None, *args, **kwargs):
        # Everything except `val_data_collator` is forwarded to `Trainer`.
        super().__init__(*args, **kwargs)
        self.val_data_collator = val_data_collator

    def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
        """
        Returns the evaluation [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        This override builds the dataloader with `self.val_data_collator` when it is
        set and falls back to `self.data_collator` otherwise. The remaining logic
        appears to mirror the stock `Trainer` implementation — re-check it whenever
        the installed `transformers` version changes.

        Args:
            eval_dataset (`str` or `torch.utils.data.Dataset`, *optional*):
                If a `str`, will use `self.eval_dataset[eval_dataset]` as the evaluation dataset. If a `Dataset`, will override `self.eval_dataset` and must implement `__len__`. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        # If we have persistent workers, don't do a fork bomb especially as eval datasets
        # don't change during training
        dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
        if (
            hasattr(self, "_eval_dataloaders")
            and dataloader_key in self._eval_dataloaders
            and self.args.dataloader_persistent_workers
        ):
            return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])

        # Resolve the dataset: a string selects an entry of `self.eval_dataset`,
        # an explicit dataset is used as is, `None` falls back to `self.eval_dataset`.
        eval_dataset = (
            self.eval_dataset[eval_dataset]
            if isinstance(eval_dataset, str)
            else eval_dataset
            if eval_dataset is not None
            else self.eval_dataset
        )
        # The only custom part: prefer the validation collator when one was provided.
        data_collator = self.val_data_collator if self.val_data_collator else self.data_collator

        if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
            eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")

        dataloader_params = {
            "batch_size": self.args.eval_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
        }

        # These options are only set for datasets that are not iterable-style.
        if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        # accelerator.free_memory() will destroy the references, so
        # we need to store the non-prepared version
        eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
        if self.args.dataloader_persistent_workers:
            if hasattr(self, "_eval_dataloaders"):
                self._eval_dataloaders[dataloader_key] = eval_dataloader
            else:
                self._eval_dataloaders = {dataloader_key: eval_dataloader}

        return self.accelerator.prepare(eval_dataloader)

In [None]:
# Experiment name
ex_name = "gigaam-ctc-test"

training_args = TrainingArguments(
    output_dir=f"./finetune/{ex_name}",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-4,
    warmup_steps=100,
    max_steps=2500,
    weight_decay=1e-4,
    gradient_checkpointing=False,
    # fp16=False,
    save_only_model=True,
    dataloader_num_workers=2,
    eval_strategy="steps",
    per_device_eval_batch_size=4,
    save_steps=500,
    eval_steps=500,
    logging_steps=50,
    save_total_limit=2,
    report_to=["wandb"],
    load_best_model_at_end=True,
    remove_unused_columns=False,
    label_names=["labels"],
    metric_for_best_model="val_wer",
    greater_is_better=False,
    push_to_hub=False,
    seed=SEED,
    run_name=ex_name,
)

trainer = TrainerDifCollators(
    args=training_args,
    model=model,
    train_dataset=audio_dataset["train"],
    eval_dataset={"val": audio_dataset["validation"]},
    data_collator=data_collator,
    val_data_collator=val_data_collator,
    compute_metrics=compute_metrics,
    processing_class=processor.feature_extractor,
)

# processor.save_pretrained(training_args.output_dir)

### Training

In [None]:
# Run the fine-tuning loop configured above.
trainer.train()
# Close the Weights & Biases run once training has finished.
wandb.finish()