In [1]:
pip -q install evaluate==0.3.0 jiwer==2.5.1

Note: you may need to restart the kernel to use updated packages.


In [None]:
# Log in to Weights & Biases.
# The API key is read from the WANDB_API_KEY environment variable (e.g. a Kaggle/Colab
# secret) rather than being pasted into the notebook. If the variable is unset, key=None
# leaves wandb.login to its default behaviour (netrc / interactive prompt).
import os

import wandb

wandb.login(key=os.environ.get("WANDB_API_KEY"))

# Import modules

In [2]:
# Restrict this process to GPU 0. This must run before torch is imported
# (and CUDA is initialised), which is why it sits between the imports.
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"
# Silence all Python warnings, including those emitted while importing the
# libraries below. NOTE(review): this also hides deprecation warnings from
# transformers/datasets; consider narrowing the filter.
import warnings
warnings.filterwarnings('ignore')

# Model, processor and training API.
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration 
from transformers import TrainingArguments, Trainer

# Dataset handling and the typing helpers used by the data collator.
import datasets
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
import numpy as np

# Metric loading (the WER metric is loaded later via evaluate.load("wer")).
import evaluate

2024-05-29 14:01:07.053708: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-29 14:01:07.053838: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-29 14:01:07.193844: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


# Build the data pipeline (collator and datasets)

In [None]:
@dataclass
class DataCollatorWhisperCTCEncoder:
    """Collate raw ASR rows into a Whisper training batch.

    Each feature is a mapping with an ``'audio'`` entry (a dict holding at least
    ``'array'``) and a ``'transcription'`` string. Audio arrays are converted to model
    input features by the processor's feature extractor. Transcriptions are tokenized
    and padded, and padded positions are replaced by -100 so the loss ignores them.

    Attributes:
        processor: WhisperProcessor providing ``feature_extractor`` and ``tokenizer``.
        padding: currently unused; kept for interface compatibility.
        max_length: currently unused; kept for interface compatibility.
        truncation: forwarded to the feature extractor.
        max_length_labels: maximum label length in tokens. ``None`` keeps the
            previous hard-coded limit of 448.
        pad_to_multiple_of: forwarded to the feature extractor.
        pad_to_multiple_of_labels: forwarded to ``tokenizer.pad`` for the labels.
        sampling_rate: sampling rate the audio arrays are declared to have.
        decoder_start_token_id: token the model prepends when it shifts labels right
            to build decoder inputs. If every label sequence already starts with it,
            it is stripped so the decoder does not see it twice. ``None`` resolves it
            from the tokenizer's ``<|startoftranscript|>`` token.
    """
    processor: WhisperProcessor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    truncation: Optional[bool] = True
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None
    sampling_rate: int = 16000
    decoder_start_token_id: Optional[int] = None

    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """Return the feature-extractor batch with an added ``'labels'`` tensor."""
        audio_arrays = [feature['audio']['array'] for feature in features]
        transcriptions = [feature['transcription'] for feature in features]

        batch = self.processor.feature_extractor(
            audio_arrays,
            truncation=self.truncation,
            sampling_rate=self.sampling_rate,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        batch["labels"] = self._encode_labels(transcriptions)
        return batch

    def _encode_labels(self, transcriptions: List[str]) -> torch.Tensor:
        """Tokenize, pad and mask transcriptions into a (batch, seq) label tensor."""
        tokenizer = self.processor.tokenizer
        max_len = 448 if self.max_length_labels is None else self.max_length_labels

        label_features = [
            {"input_ids": np.asarray(tokenizer(text, truncation=True, max_length=max_len)["input_ids"])}
            for text in transcriptions
        ]
        labels_batch = tokenizer.pad(
            label_features,
            pad_to_multiple_of=self.pad_to_multiple_of_labels,
            return_tensors="pt",
        )
        # Padded positions (attention_mask != 1) become -100 so cross-entropy ignores them.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # The model builds decoder_input_ids by shifting the labels right and prepending
        # decoder_start_token_id. If the tokenizer already put that token at the front of
        # every sequence (the Whisper tokenizer does this once prefix tokens are set),
        # drop it. Otherwise the decoder would see <|startoftranscript|> twice.
        start_id = self.decoder_start_token_id
        if start_id is None:
            start_id = tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
        if labels.size(1) > 0 and bool((labels[:, 0] == start_id).all()):
            labels = labels[:, 1:]
        return labels


In [None]:
from datasets import load_dataset, DatasetDict

DATASET_ID = 'linhtran92/viet_youtube_asr_corpus_v2'

# Load the train and test splits (train first, then test) into one DatasetDict.
youtube = DatasetDict(
    {split: load_dataset(DATASET_ID, split=split) for split in ("train", "test")}
)

print(youtube)

In [None]:
# Drop columns the collator never reads (it only uses 'audio' and 'transcription').
# The w2v2_* / WER columns are presumably a wav2vec2 baseline; confirm against the dataset card.
unused_columns = ["w2v2_transcription", "WER", "sum"]
train_dataset = youtube["train"].remove_columns(unused_columns)
eval_dataset = youtube["test"].remove_columns(unused_columns)

# Load model

In [None]:
# # Whisper
# model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
# processor = WhisperProcessor.from_pretrained("openai/whisper-base", language="vi", task="transcribe")
# processor.tokenizer.pad_token = processor.tokenizer.eos_token
# processor.tokenizer.max_length = 448
# processor.tokenizer.set_prefix_tokens(language="vi", task="transcribe")

In [3]:
# PhoWhisper checkpoint; the processor is configured for Vietnamese transcription.
MODEL_ID = "vinai/PhoWhisper-base"
LANGUAGE, TASK = "vi", "transcribe"

model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID)
processor = WhisperProcessor.from_pretrained(MODEL_ID, language=LANGUAGE, task=TASK)

# Pad with EOS. The data collator masks padded label positions to -100 via the attention mask.
processor.tokenizer.pad_token = processor.tokenizer.eos_token
# NOTE(review): `max_length` is not a standard tokenizer setting (the HF attribute is
# `model_max_length`). The collator passes max_length=448 explicitly anyway; confirm intent.
processor.tokenizer.max_length = 448
processor.tokenizer.set_prefix_tokens(language=LANGUAGE, task=TASK)

config.json:   0%|          | 0.00/1.32k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/290M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/3.77k [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/339 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/804 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/836k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.20M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/494k [00:00<?, ?B/s]

normalizer.json:   0%|          | 0.00/52.7k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/2.08k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.08k [00:00<?, ?B/s]

In [None]:
# Freeze the encoder's weights (WhisperForConditionalGeneration.freeze_encoder) so that
# training only updates the remaining, decoder-side parameters.
model.freeze_encoder()

In [None]:
# Force the Vietnamese-transcription decoder prompt and disable token suppression,
# on both the model config and its generation config (the model config first).
for cfg in (model.config, model.generation_config):
    cfg.forced_decoder_ids = processor.tokenizer.get_decoder_prompt_ids(
        language="vi", task="transcribe"
    )
    cfg.suppress_tokens = []


# Training settings

In [None]:
# Approximate optimizer steps: (number of training samples / batch_size) * num_epochs.
batch_size = 64
num_epochs = 5
eval_accumulation_steps = 100  # only relevant once evaluation is re-enabled (see below)

# Evaluation is currently disabled. To re-enable it, add
#   per_device_eval_batch_size=batch_size, evaluation_strategy="steps", eval_steps=4000,
#   eval_accumulation_steps=eval_accumulation_steps, metric_for_best_model='wer'
# here, and pass compute_metrics=compute_wer and eval_dataset=eval_dataset to the Trainer.
training_args = TrainingArguments(
    # Output and logging locations.
    output_dir='/kaggle/working/',
    logging_dir='/kaggle/working/',
    logging_steps=1000,
    # Checkpointing: keep only the two most recent checkpoints.
    save_strategy="steps",
    save_steps=2000,
    save_total_limit=2,
    # Batching and data loading. The collator needs the raw 'audio'/'transcription'
    # columns, hence remove_unused_columns=False.
    per_device_train_batch_size=batch_size,
    group_by_length=False,
    remove_unused_columns=False,
    dataloader_num_workers=2,
    label_names=["labels"],
    # Optimisation schedule.
    # NOTE(review): 5e-4 is high compared with common Whisper fine-tuning rates (~1e-5);
    # confirm it is intended.
    num_train_epochs=num_epochs,
    learning_rate=5e-4,
    warmup_steps=2000,
    # Memory and speed.
    fp16=True,
    gradient_checkpointing=True,
    # Model selection: only meaningful once metric_for_best_model is set.
    greater_is_better=False,
    # When resuming, do not fast-forward the dataloader to the saved position.
    ignore_data_skip=True,
)

In [None]:
# Collator that turns dataset rows ('audio', 'transcription') into model batches.
data_collator = DataCollatorWhisperCTCEncoder(processor=processor)

# Word error rate metric, consumed by compute_wer.
metric = evaluate.load("wer")

In [None]:
def compute_wer(eval_prediction, tokenizer=None, wer_metric=None):
    """Compute teacher-forced word error rate (in percent) for a Trainer evaluation.

    Args:
        eval_prediction: object with ``predictions`` (decoder logits, or a tuple whose
            first element is the logits) and ``label_ids`` (token ids, with -100 marking
            ignored positions), e.g. ``transformers.EvalPrediction``.
        tokenizer: tokenizer used to decode ids. Defaults to ``processor.tokenizer``.
        wer_metric: object with ``compute(predictions=..., references=...)``. Defaults
            to the module-level ``metric``.

    Returns:
        ``{'wer': 100 * WER}``.

    Hypotheses are the per-position argmax of the logits (teacher forcing), not the
    output of ``generate``. The caller's ``label_ids`` array is not modified.
    """
    if tokenizer is None:
        tokenizer = processor.tokenizer
    if wer_metric is None:
        wer_metric = metric

    logits = eval_prediction.predictions
    # When the model returns extra outputs, predictions is a tuple led by the logits.
    # A bare array is already the logits; indexing [0] would pick the first sample.
    if isinstance(logits, tuple):
        logits = logits[0]
    pred_ids = np.argmax(logits, axis=-1)

    # np.where builds new arrays, so the caller's label_ids are left untouched.
    label_ids = np.asarray(eval_prediction.label_ids)
    ignored = label_ids == -100
    label_ids = np.where(ignored, tokenizer.pad_token_id, label_ids)
    # Argmax tokens at ignored (padding) positions are arbitrary. Blank them out as well
    # so they do not leak into the decoded hypothesis and inflate the WER.
    if pred_ids.shape == ignored.shape:
        pred_ids = np.where(ignored, tokenizer.pad_token_id, pred_ids)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * wer_metric.compute(predictions=pred_str, references=label_str)
    return {'wer': wer}

# Training

In [None]:
# Evaluation is disabled: neither compute_metrics nor eval_dataset is passed.
# The feature extractor is handed to Trainer as `tokenizer` so it is saved with checkpoints.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    tokenizer=processor.feature_extractor,
)

# To resume instead: trainer.train(resume_from_checkpoint='path/to/checkpoint')
trainer.train()

In [None]:
# Save the fine-tuned weights AND the full processor (tokenizer + feature extractor).
# trainer.save_model alone stores only what Trainer was given as `tokenizer` (the
# feature extractor), so the directory could not be reloaded with
# WhisperProcessor.from_pretrained on its own.
# (Call trainer.save_state() as well if trainer_state.json is needed.)
FINAL_MODEL_DIR = "/kaggle/working/model_final"
trainer.save_model(FINAL_MODEL_DIR)
processor.save_pretrained(FINAL_MODEL_DIR)

# Demo

In [None]:
# import librosa
# input_speech, rate = librosa.load('/kaggle/input/vin100h-test-wave/vin100h_test_wave.wav', sr=16000)

# forced_decoder_ids = processor.get_decoder_prompt_ids(language="vi", task="transcribe")
# input_features = processor(input_speech, sampling_rate=rate, return_tensors="pt").input_features

In [None]:
# model_trained = WhisperForConditionalGeneration.from_pretrained('/kaggle/working/model_final')
# predicted_ids_model_trained = model_trained.generate(input_features, forced_decoder_ids=forced_decoder_ids)
# processor.batch_decode(predicted_ids_model_trained)