## Importing Libraries

In [4]:
import os
from dotenv import load_dotenv
from IPython.display import display, Markdown

# pytorch
import torch

# huggingface
from huggingface_hub import login
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, Seq2SeqTrainingArguments, Seq2SeqTrainer
import evaluate
from datasets import Audio, load_dataset, DatasetDict

# wandb
import wandb

## Login to Hugging Face

In [1]:
# Pull variables from a local .env file into the environment, then
# authenticate against the Hugging Face Hub with the token found there.
load_dotenv()
token = os.environ.get("HUGGINGFACE_TOKEN")  # ADD YOUR TOKEN HERE (via .env / environment)
login(token=token, add_to_git_credential=True)

In [2]:
# Hub repository coordinates for the fine-tuned model
username = "PathFinderKR"  # ADD YOUR USERNAME HERE
model_name = "Waktaverse-whisper-KO-large-v3"  # ADD YOUR MODEL NAME HERE
repo_id = "/".join((username, model_name))  # repository id: "<username>/<model_name>"

## Login to Weights & Biases

In [3]:
# Authenticate with Weights & Biases using the key from the environment,
# then start a run in a project named after the model.
api_key = os.environ.get("WANDB_API_KEY")  # ADD YOUR API KEY HERE (via .env / environment)
wandb.login(key=api_key)
wandb.init(project=model_name)

## Device

In [5]:
# Device setup: use the first CUDA GPU when one is available, else the CPU.
# (The Apple Silicon "mps" backend is not selected by this cell.)
if torch.cuda.is_available():
    device = "cuda:0"  # Nvidia GPU
else:
    device = "cpu"
print(f"Device = {device}")

In [6]:
# Attention implementation and dtype selection.
# Flash Attention 2 is chosen only on a CUDA device whose compute capability
# major version is >= 8 (Ampere, Ada, or Hopper GPUs); everything else uses
# the eager implementation.
if device == "cuda:0" and torch.cuda.get_device_capability()[0] >= 8:
    attn_implementation = "flash_attention_2"
else:
    attn_implementation = "eager"

# Half precision on the GPU, full precision on the CPU.
torch_dtype = torch.float16 if device == "cuda:0" else torch.float32
print(f"Attention Implementation = {attn_implementation}")

## Hyperparameters

In [7]:
# ---------------------------------------------------------------------------
# Language
# ---------------------------------------------------------------------------
language = "Korean"
language_code = "ko"

# ---------------------------------------------------------------------------
# Generation parameters
# ---------------------------------------------------------------------------
max_new_tokens = 100
chunk_length_s = 30
batch_size = 4
sampling_rate = 16000

# ---------------------------------------------------------------------------
# Training parameters
# ---------------------------------------------------------------------------
# Output, logging and checkpointing
output_dir = "./results"
logging_dir = "./logs"
save_strategy = "epoch"
logging_strategy = "steps"  # "steps", "epoch"
# A step interval is only meaningful when logging by steps.
logging_steps = 10 if logging_strategy == "steps" else None
evaluation_strategy = "epoch"  # "steps", "epoch"
load_best_model_at_end = True
metric_for_best_model = "wer"
greater_is_better = False  # paired with metric_for_best_model above
save_total_limit = 1
report_to = "wandb"

# Optimisation
num_train_epochs = 2
per_device_train_batch_size = 4
gradient_accumulation_steps = 4
gradient_checkpointing = True
learning_rate = 2e-5
lr_scheduler_type = "cosine"  # "constant", "linear", "cosine"
warmup_ratio = 0.1
optim = "adamw_torch"
weight_decay = 0.01

## Model

In [8]:
# Model ID for base model: the checkpoint name passed to `from_pretrained`
# when the processor and the model are loaded.
model_id = "openai/whisper-large-v3"

In [9]:
# Load the processor and the model for the base checkpoint.
processor = AutoProcessor.from_pretrained(model_id)

# NOTE(review): `attn_implementation` is not forwarded here, so the value
# chosen in the device section is unused — confirm this is intended.
model_load_kwargs = dict(
    device_map=device,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
)
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, **model_load_kwargs)

In [10]:
# Display the model architecture.
# The opening and closing fences must sit on their own lines: with the repr
# glued to the opening fence, its first line is parsed as the fence's info
# string and the trailing fence is not recognised as a closing fence.
display(Markdown(f"```\n{model}\n```"))

## Dataset

In [11]:
# Dataset ID: the dataset name passed to `load_dataset` for every split.
dataset_id = "mozilla-foundation/common_voice_17_0"

In [12]:
# Load the train / validation / test splits into a single DatasetDict.
dataset = DatasetDict()
for split_name in ("train", "validation", "test"):
    dataset[split_name] = load_dataset(
        dataset_id,
        language_code,
        split=split_name,
        trust_remote_code=True,
    )

In [13]:
# Dataset information: display the DatasetDict holding the loaded splits.
dataset

In [14]:
# Dataset sample: display the first record of the training split.
dataset["train"][0]

In [15]:
# Tokenization sample: encode the first training sentence, then decode the
# resulting ids back to text to inspect what the tokenizer produces.
sample_sentence = dataset["train"][0]["sentence"]
tokenized_sentence = processor.tokenizer(sample_sentence)
sample_token_ids = tokenized_sentence["input_ids"]
decoded_sentence = processor.tokenizer.decode(sample_token_ids)

print(f"Tokenized Sentence: {tokenized_sentence}")
print(f"Decoded Sentence: {decoded_sentence}")

## Preprocessing

In [16]:
def preprocess_function(examples):
    """Turn one raw dataset record into model inputs.

    Args:
        examples: a single record with an "audio" entry (providing "array"
            and "sampling_rate") and a "sentence" transcript.

    Returns:
        The same record with two added fields:
        - "input_features": the first (only) feature array the feature
          extractor produced for the clip.
        - "labels": the token ids of the transcript as a flat list.
    """
    audio = examples["audio"]

    # Pass the clip's own sampling rate rather than a hard-coded one, so the
    # extractor is never told a rate the array does not actually have.
    examples["input_features"] = processor.feature_extractor(
        audio["array"], sampling_rate=audio["sampling_rate"]
    ).input_features[0]
    # No return_tensors here: a "pt" tensor carries a leading batch dimension,
    # which would be stored as a nested list and break label padding later.
    examples["labels"] = processor.tokenizer(examples["sentence"]).input_ids

    return examples

# Down sample audio: decode every clip at `sampling_rate` before extracting features
dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))

# Preprocess dataset
dataset = dataset.map(preprocess_function, remove_columns=dataset.column_names["train"], num_proc=4)

In [17]:
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate preprocessed speech examples into one padded batch.

    Audio features and label ids are padded independently, because they have
    different lengths and use different padding methods.

    Attributes:
        processor: object exposing `feature_extractor.pad` and `tokenizer.pad`.
        decoder_start_token_id: token id that, when it starts every label
            sequence in the batch, is stripped from the labels.
    """

    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Build a batch with padded "input_features" and "labels".

        Args:
            features: examples, each holding "input_features" and "labels".

        Returns:
            The padded feature batch with an added "labels" tensor in which
            padded positions are set to -100.
        """
        # Separate the two modalities so each gets its own padding routine.
        audio_items = []
        label_items = []
        for feature in features:
            audio_items.append({"input_features": feature["input_features"]})
            label_items.append({"input_ids": feature["labels"]})

        # Audio side: simply returned as torch tensors.
        batch = self.processor.feature_extractor.pad(audio_items, return_tensors="pt")

        # Label side: pad the token id sequences to the longest one in the batch.
        padded_labels = self.processor.tokenizer.pad(label_items, return_tensors="pt")

        # Replace padding with -100 so those positions are ignored by the loss.
        padding_mask = padded_labels.attention_mask != 1
        labels = padded_labels["input_ids"].masked_fill(padding_mask, -100)

        # If the bos token was added during tokenization, cut it here,
        # as it is appended again later anyway.
        starts_with_bos = bool((labels[:, 0] == self.decoder_start_token_id).all())
        if starts_with_bos:
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch


## Inference

In [18]:
# Assemble the speech-recognition pipeline around the already loaded model,
# reusing the processor's tokenizer and feature extractor.
asr_pipeline_kwargs = {
    "model": model,
    "tokenizer": processor.tokenizer,
    "feature_extractor": processor.feature_extractor,
    "torch_dtype": torch_dtype,
    "return_timestamps": True,
    # Generation / chunking settings from the hyperparameter section
    "max_new_tokens": max_new_tokens,
    "chunk_length_s": chunk_length_s,
    "batch_size": batch_size,
}
pipe = pipeline("automatic-speech-recognition", **asr_pipeline_kwargs)

In [19]:
def speech_recognition(audio):
    """Transcribe `audio` with the notebook-level `pipe` pipeline.

    Args:
        audio: input forwarded unchanged to `pipe`.

    Returns:
        The "text" entry of the pipeline output.
    """
    return pipe(audio)["text"]

In [20]:
# The preprocessing cell replaced the raw columns (including "audio") with
# model inputs, and the pipeline needs the audio dict itself rather than a
# whole dataset row. Reload one raw training clip, decode it at the working
# sampling rate, and transcribe its "audio" field.
raw_sample = load_dataset(
    dataset_id, language_code, split="train[:1]", trust_remote_code=True
).cast_column("audio", Audio(sampling_rate=sampling_rate))[0]
result = speech_recognition(raw_sample["audio"])

In [None]:
# Display the text returned by speech_recognition
result

## Training

In [21]:
# Load the WER metric once; compute_metrics may be called many times during
# training, and reloading the metric on every call is wasted work.
wer_metric = evaluate.load("wer")


def compute_metrics(pred):
    """Compute the word error rate (in percent) for a batch of predictions.

    Args:
        pred: object exposing `predictions` (predicted token ids) and
            `label_ids` (reference token ids, with -100 at ignored positions).

    Returns:
        A dict `{"wer": value}` where value is 100 times the metric result.
    """
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # -100 marks label positions that were masked out for the loss; swap in
    # the pad token id so the sequences can be decoded.
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

In [22]:
# Training configuration assembled from the hyperparameter section.
training_args = Seq2SeqTrainingArguments(
    # Output, logging and checkpointing
    output_dir=output_dir,
    logging_dir=logging_dir,
    save_strategy=save_strategy,
    logging_strategy=logging_strategy,
    logging_steps=logging_steps,
    evaluation_strategy=evaluation_strategy,
    load_best_model_at_end=load_best_model_at_end,
    metric_for_best_model=metric_for_best_model,
    greater_is_better=greater_is_better,
    save_total_limit=save_total_limit,
    report_to=report_to,

    # compute_metrics decodes `pred.predictions` as token ids, so evaluation
    # must run generation instead of returning raw model outputs.
    predict_with_generate=True,

    # Optimisation
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    gradient_checkpointing=gradient_checkpointing,
    learning_rate=learning_rate,
    lr_scheduler_type=lr_scheduler_type,
    warmup_ratio=warmup_ratio,
    optim=optim,
    weight_decay=weight_decay
)

In [27]:
# Label sequences differ in length from example to example, so batches must
# be built by the padding collator defined in the preprocessing section.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.feature_extractor,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [26]:
# Start training with the trainer configured above
trainer.train()