In [None]:
from huggingface_hub import notebook_login

# Interactive Hugging Face Hub login. Later cells push to the Hub
# (`push_to_hub=True` in the training arguments and `model.push_to_hub(...)`).
notebook_login()

In [None]:
!pip install datasets

In [None]:
import torch

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook can still be loaded and
# inspected on a runtime without CUDA. Note that the training arguments further
# down set fp16=True, which still needs a GPU for the actual training run.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; the datasets and the Fisher information
# file used below are read from / written to paths under this mount point.
drive.mount('/content/drive')

In [None]:
from transformers import WhisperTokenizer

# Tokenizer of the base checkpoint, configured for Sinhala transcription.
# Used by compute_metrics to decode predictions and references.
tokenizer = WhisperTokenizer.from_pretrained(
    "openai/whisper-small",
    language="Sinhala",
    task="transcribe",
)


In [None]:
from transformers import WhisperProcessor

# Processor of the base checkpoint; its feature extractor and tokenizer are used
# by the data collator to pad audio features and label sequences.
processor = WhisperProcessor.from_pretrained(
    "openai/whisper-small",
    language="Sinhala",
    task="transcribe",
)


In [None]:
from datasets import load_from_disk, Dataset
# Dataset of the previous task; it is only used to estimate the Fisher information.
save_path = "/content/drive/My Drive/asr_sinhala/ProcessedData/FMI_dataset"
old_dataset = load_from_disk(dataset_path=save_path)

In [None]:
from datasets import load_from_disk, Dataset
# Dataset of the new task; this is what the EWC trainer is fine-tuned on.
# NOTE: `save_path` is re-bound here and no longer points at the old dataset.
save_path = "/content/drive/My Drive/asr_sinhala/ProcessedData/trained_dataset_d"
new_dataset = load_from_disk(dataset_path=save_path)

In [None]:
from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate pre-processed speech examples into one padded seq2seq batch.

    Attributes:
        processor: Object exposing ``feature_extractor.pad`` and ``tokenizer.pad``.
        decoder_start_token_id: Token id that is stripped from the front of the
            labels when every sequence in the batch starts with it.
    """

    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad inputs and labels of ``features`` and return them as one batch.

        Only examples that carry BOTH ``"input_features"`` and ``"labels"`` are
        kept, so row ``i`` of the inputs always belongs to row ``i`` of the labels.
        (Filtering the two keys independently silently misaligns the batch as soon
        as a single example lacks one of them.)

        Args:
            features: Per-example dicts with ``"input_features"`` and ``"labels"``.

        Returns:
            The padded feature-extractor batch with an added ``"labels"`` tensor in
            which padding positions are set to -100.

        Raises:
            ValueError: If no example contains both required keys.
        """
        required_keys = ("input_features", "labels")
        complete = []
        for feature in features:
            missing = [key for key in required_keys if key not in feature]
            if missing:
                # Report only the keys: the payload itself can be a huge array.
                print(f"Warning: skipping feature without {missing}; available keys: {list(feature.keys())}")
                continue
            complete.append(feature)

        if not complete:
            raise ValueError("No features with both 'input_features' and 'labels' found in the batch.")

        # Audio inputs and token labels need different padding, so pad separately.
        input_features = [{"input_features": feature["input_features"]} for feature in complete]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        label_features = [{"input_ids": feature["labels"]} for feature in complete]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Replace padding with -100 wherever the attention mask is not 1.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # Drop the leading start token if every sequence in the batch begins with it.
        if labels.size(1) > 0 and bool((labels[:, 0] == self.decoder_start_token_id).all().item()):
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

In [None]:
from transformers import WhisperForConditionalGeneration
# Start from the checkpoint fine-tuned on the previous task, then move it to `device`.
model = WhisperForConditionalGeneration.from_pretrained("RRashmini/whisper-small-sinhala-1")
model = model.to(device)

In [None]:
def compute_fisher_information(model, dataloader, device):
    """Estimate the diagonal Fisher information of ``model`` on ``dataloader``.

    For every batch the loss is back-propagated and the squared gradient of each
    parameter is accumulated; the result is the average over the batches seen.
    The gradient is taken of the batch loss returned by the model, so this is a
    batch-level (not per-sample) estimate.

    Args:
        model: Module whose forward accepts ``input_features`` and ``labels`` and
            returns an object with a ``loss`` attribute.
        dataloader: Iterable of dicts with ``"input_features"`` and ``"labels"``.
        device: Device the batches are moved to and the result is allocated on.

    Returns:
        Dict mapping every parameter name to a tensor shaped like the parameter.
        Parameters that never receive a gradient keep an all-zero entry.

    Raises:
        ValueError: If ``dataloader`` yields no batch (the average would be NaN).
    """
    fisher_info = {
        name: torch.zeros_like(param, device=device)
        for name, param in model.named_parameters()
    }

    was_training = model.training
    model.eval()
    num_batches = 0
    try:
        for batch in dataloader:
            inputs = batch['input_features'].to(device)
            labels = batch['labels'].to(device)
            model.zero_grad()
            loss = model(input_features=inputs, labels=labels).loss
            loss.backward()
            for name, param in model.named_parameters():
                # Frozen or unused parameters have no gradient; leave them at zero.
                if param.grad is None:
                    continue
                fisher_info[name] += param.grad.detach().pow(2)
            num_batches += 1
    finally:
        # Do not leak gradients or a changed train/eval mode into later cells.
        model.zero_grad()
        model.train(was_training)

    if num_batches == 0:
        raise ValueError("Cannot compute Fisher information from an empty dataloader.")

    # Count the batches instead of calling len() so plain iterables work as well.
    for name in fisher_info:
        fisher_info[name] /= num_batches
    return fisher_info

In [None]:
def ewc_loss(model, fisher_info, prev_params, lambda_ewc=10):
    """Return the EWC penalty for the current weights of ``model``.

    For every named parameter the squared distance to its entry in
    ``prev_params`` is weighted element-wise by its entry in ``fisher_info`` and
    summed; the grand total is scaled by ``lambda_ewc``.
    """
    weighted_drift = sum(
        (fisher_info[name] * (param - prev_params[name]) ** 2).sum()
        for name, param in model.named_parameters()
    )
    return lambda_ewc * weighted_drift

In [None]:
from transformers import Seq2SeqTrainer

class EWCTrainer(Seq2SeqTrainer):
    """Seq2SeqTrainer whose training loss is the task loss plus an EWC penalty.

    Args:
        model: Model to train; forwarded to ``Seq2SeqTrainer``.
        fisher_info: Dict of per-parameter importance weights, keyed by parameter name.
        prev_params: Dict of reference weights, keyed by parameter name.
        lambda_ewc: Scale of the EWC penalty.
        *args, **kwargs: Forwarded unchanged to ``Seq2SeqTrainer``.
    """

    def __init__(self, model, fisher_info, prev_params, lambda_ewc=10, *args, **kwargs):
        super().__init__(model, *args, **kwargs)
        self.fisher_info = fisher_info
        self.prev_params = prev_params
        self.lambda_ewc = lambda_ewc

    def compute_loss(self, model, inputs, return_outputs=False, *args, **kwargs):
        """Return ``outputs.loss + ewc_loss(...)`` (and the outputs if requested).

        Extra positional/keyword arguments passed by the trainer are accepted and
        ignored.
        """
        outputs = model(**inputs)

        # Wrappers such as DataParallel/DDP expose the real model as `.module` and
        # prefix parameter names with "module.", which would not match the keys of
        # `fisher_info` / `prev_params`. Look the parameters up on the bare model.
        base_model = getattr(model, "module", model)
        ewc_reg = ewc_loss(base_model, self.fisher_info, self.prev_params, self.lambda_ewc)

        # Build a new tensor instead of `loss += ...` so `outputs.loss` (returned
        # to the caller when return_outputs=True) is not modified in place.
        loss = outputs.loss + ewc_reg

        return (loss, outputs) if return_outputs else loss

In [None]:
# Detached copy of the current weights; the EWC penalty measures drift from these.
prev_params = {
    param_name: weights.detach().clone()
    for param_name, weights in model.named_parameters()
}

In [None]:
# Mark every weight of the model as trainable.
for weight in model.parameters():
    weight.requires_grad = True

In [None]:
# Collator shared by the Fisher DataLoader and the trainer; the start-token id is
# taken from the loaded model's config.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

In [None]:
from torch.utils.data import DataLoader

# Batches of the OLD dataset, used only for the Fisher information estimate.
fisher_dataloader = DataLoader(
    dataset=old_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=data_collator,
)

In [None]:
# One full pass over `fisher_dataloader`; the result is saved to Drive in the next cell.
fisher_info = compute_fisher_information(model, fisher_dataloader, device)

In [None]:
# Destination of the serialised Fisher information on the mounted Drive.
fisher_save_path = "/content/drive/MyDrive/asr_sinhala/fisher_info.pt"
def save_fisher_info(load_path):
    """Write the module-level ``fisher_info`` dict to ``load_path`` with ``torch.save``.

    NOTE(review): despite its name, ``load_path`` is the file that is WRITTEN.
    The function also reads the global ``fisher_info`` rather than taking it as
    an argument, so the cell computing it must have been run first.
    """
    torch.save(fisher_info, load_path)
    print(f"✅ Fisher Information saved to {load_path}")

save_fisher_info(fisher_save_path)

In [None]:
# Location of the serialised Fisher information on the mounted Drive.
fisher_save_path = "/content/drive/MyDrive/asr_sinhala/fisher_info.pt"
def load_fisher_info(load_path, map_location=None):
    """Load a Fisher information dict previously written with ``torch.save``.

    Args:
        load_path: File to read.
        map_location: Forwarded to ``torch.load``. Pass the training device so the
            tensors end up next to the model even when the file was written on a
            different device. ``None`` keeps the default of ``torch.load``.

    Returns:
        The deserialised object (here: the dict of per-parameter tensors).
    """
    # torch.load unpickles the file; only load files you created yourself.
    fisher_info = torch.load(load_path, map_location=map_location)
    print(f"✅ Fisher Information loaded from {load_path}")
    return fisher_info

fisher_info = load_fisher_info(fisher_save_path, map_location=device)

In [None]:
from transformers import Seq2SeqTrainingArguments
training_args = Seq2SeqTrainingArguments(
    # Output, logging and Hub upload
    output_dir="./whisper-small-specaugment-sinhala",
    report_to=["tensorboard"],
    logging_steps=100,
    push_to_hub=True,
    # Optimisation schedule
    num_train_epochs=1,
    learning_rate=1e-5,
    lr_scheduler_type="linear",
    warmup_steps=50,
    # Batching and memory
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    fp16=True,
    # Evaluation
    per_device_eval_batch_size=16,
    predict_with_generate=True,
    metric_for_best_model="wer",
    greater_is_better=False,
)

In [None]:
!pip install jiwer

In [None]:
!pip install evaluate

In [None]:
import evaluate

In [None]:
from datasets import load_from_disk, Dataset
# Held-out test split; passed to the trainer as the evaluation dataset.
save_path_test = "/content/drive/My Drive/asr_sinhala/test_dataset_s"
processed_test_dataset = load_from_disk(dataset_path=save_path_test)

In [None]:
# Load the "wer" metric; compute_metrics calls `metric.compute(...)` on decoded text.
metric = evaluate.load("wer")

In [None]:
def compute_metrics(pred):
    """Decode predictions and references and return the WER in percent.

    Args:
        pred: Object with ``predictions`` and ``label_ids`` holding token ids
            (assumed to be NumPy-compatible arrays, as handed over by the trainer
            — TODO confirm). ``predictions`` may also be a tuple whose first
            element holds the ids.

    Returns:
        ``{"wer": <100 * metric.compute(...)>}``.
    """
    import numpy as np

    pad_id = tokenizer.pad_token_id

    pred_ids = pred.predictions
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]

    # -100 marks ignored/padded positions and is not a valid token id. Replace it
    # with the pad token in new arrays so the caller's `pred.label_ids` is left
    # untouched (the previous version overwrote it in place).
    pred_ids = np.where(pred_ids != -100, pred_ids, pad_id)
    label_ids = np.where(pred.label_ids != -100, pred.label_ids, pad_id)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

In [None]:
trainer = EWCTrainer(
    model=model,
    args=training_args,
    train_dataset=new_dataset,
    eval_dataset=processed_test_dataset,
    data_collator=data_collator,
    tokenizer=processor.feature_extractor,
    compute_metrics=compute_metrics,
    # EWC-specific arguments consumed by EWCTrainer.__init__
    fisher_info=fisher_info,
    prev_params=prev_params,
    lambda_ewc=10,
)

In [None]:
# Fine-tune on `new_dataset`; EWCTrainer.compute_loss adds the EWC penalty to the loss.
trainer.train()

In [None]:
# Evaluate on the test split; `compute_metrics` contributes the "wer" entry.
trainer.evaluate(eval_dataset=processed_test_dataset)

In [None]:
# Upload the fine-tuned model to the Hub repository "RRashmini/whisper-small-sinhala-2".
# NOTE(review): the training arguments also set push_to_hub=True with a different
# output_dir name — confirm that both uploads are intended.
model.push_to_hub("RRashmini/whisper-small-sinhala-2")