# Setup environment

In [None]:
# !pip install transformers  datasets   torch-summary  jiwer torchaudio   wandb >ou

In [None]:
# !pip install https://github.com/kpu/kenlm/archive/master.zip >ou

In [1]:
# !pip install pyctcdecode >ou

# Load libraries

In [1]:

from dataclasses import dataclass, field
from tqdm import  tqdm
# import wandb

import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn, optim,Tensor


# from torchsummary import summary
import torchaudio
import torchaudio.transforms as T

import transformers
from transformers import (
    Trainer,
    TrainingArguments,
    Wav2Vec2Processor,
    Wav2Vec2ForCTC,


)
from inference import Inference

from typing import Any, Dict, List, Optional, Union,Tuple
# import jiwer
import numpy as np
from datasets  import load_metric
import IPython.display as ipd

from dataset import VLSPDataset,DatasetValidated

# Read audio through torchaudio's 'soundfile' backend.
# NOTE(review): the recorded cell output shows a warning raised from this line
# (presumably the deprecation of set_audio_backend in recent torchaudio) — confirm
# it is still needed for the torchaudio version in use.
torchaudio.set_audio_backend('soundfile')

  from .autonotebook import tqdm as notebook_tqdm
  torchaudio.set_audio_backend('soundfile')


In [None]:
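# Authenticate wandb through the WANDB_API_KEY environment variable; never hardcode the key
# in the notebook (a key that was previously committed here should be revoked and rotated).
# !wandb login --relogin
# %env WANDB_PROJECT=ASR_with_NST
# wandb.init(project='ASR_with_NST', name=f"Teacher")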

In [2]:
# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cpu


In [3]:
# Local fine-tuned checkpoint directory. Forward slashes work on Windows and POSIX alike;
# the previous backslash form ('checkpoint\checkpoint_Teacher') was an invalid escape
# sequence and, on Linux/Kaggle, named a single non-existent directory.
model_name = 'checkpoint/checkpoint_Teacher'
batch_size = 64        # per-device training batch size (also used to size the LR schedule)
batch_size_text = 8    # batch size of the test/validation DataLoaders

# lr = 0.0005 * batch_size ** (1 / 2)
# NOTE(review): this lr is handed to AdamW in a later cell, but OneCycleLR (max_lr=0.0005)
# overwrites the optimizer's lr when it is constructed; confirm which value is intended.
lr = 0.000001
max_epochs = 24

processor = Wav2Vec2Processor.from_pretrained('nguyenvulebinh/wav2vec2-base-vietnamese-250h')

# Optional n-gram LM decoding. Requires `build_ctcdecoder` (pyctcdecode) and
# `Wav2Vec2ProcessorWithLM` (transformers), neither of which is imported in this notebook.
# labels=list(dict(sorted(processor.tokenizer.get_vocab().items(), key=lambda item: item[1])))
# decoder=build_ctcdecoder(labels=labels,
#             kenlm_model_path='/kaggle/input/pretrain/vi_lm_4grams.bin',
#             alpha = 0.5,
#             beta= 1.5,
#             unk_score_offset=-10.0,
#             lm_score_boundary=True,)
# processor_LM = Wav2Vec2ProcessorWithLM(
#             feature_extractor=processor.feature_extractor,
#             tokenizer=processor.tokenizer,
#             decoder=decoder)

In [None]:

# Training split (VLSP). Only the processor is passed, so the data location comes
# from whatever defaults VLSPDataset defines.
train_dataset = VLSPDataset(processor)

# Test split: Common Voice 15.0 (Vietnamese), read from the Kaggle input mounts.
test_dataset = DatasetValidated(
    processor,
    path='/kaggle/input/commonvoice-vie/cv-corpus-15.0-2023-09-08',
    path_csv='/kaggle/input/datacsv/commonvoice_test.csv',
)

# Validation split: DatasetValidated with whatever default paths it defines.
val_dataset = DatasetValidated(processor)

In [None]:
# Sanity check: display the first example of the Common Voice test split.
test_dataset[0]

In [None]:
@dataclass
class DataCollatorCTCWithPadding:
    """Collate CTC examples into padded tensors.

    Audio inputs and label token ids have independent lengths, so each is padded
    separately: audio with the processor's feature extractor, labels with its
    tokenizer. Label padding positions are replaced by -100 so the loss ignores them.

    Attributes:
        processor: Wav2Vec2 processor; only its ``feature_extractor.pad`` and
            ``tokenizer.pad`` methods are used.
        padding: padding strategy forwarded to both ``pad`` calls
            (``True`` pads to the longest sequence in the batch).
        max_length: maximum length of the padded ``input_values``.
        max_length_labels: maximum length of the padded ``labels``.
        pad_to_multiple_of: pad ``input_values`` length to a multiple of this value.
        pad_to_multiple_of_labels: pad ``labels`` length to a multiple of this value.
    """

    processor: "Wav2Vec2Processor"
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, batchs: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad a list of examples into one batch.

        Args:
            batchs: examples, each holding an ``"input_values"`` entry (audio
                features) and a ``"labels"`` entry (token ids).

        Returns:
            The feature extractor's padded output (``input_values`` plus whatever
            else its ``pad`` returns) with an added ``"labels"`` tensor in which
            padding positions are -100.
        """
        input_features = [{"input_values": item["input_values"]} for item in batchs]
        label_features = [{"input_ids": item["labels"]} for item in batchs]

        batch = self.processor.feature_extractor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        # Pad labels with the tokenizer directly. This is what `processor.pad` did inside
        # the deprecated `as_target_processor()` context manager.
        labels_batch = self.processor.tokenizer.pad(
            label_features,
            padding=self.padding,
            max_length=self.max_length_labels,
            pad_to_multiple_of=self.pad_to_multiple_of_labels,
            return_tensors="pt",
        )

        # Replace label padding with -100 so it is ignored by the loss.
        batch["labels"] = labels_batch["input_ids"].masked_fill(
            labels_batch["attention_mask"].ne(1), -100
        )
        return batch

In [None]:
# padding=True: pad every batch to its longest example.
data_collator = DataCollatorCTCWithPadding(padding=True, processor=processor)

In [None]:
# Trainer builds its own training DataLoader from `train_dataset`, so only the
# evaluation loaders are created by hand here.
def _make_eval_dataloader(dataset):
    """Return an unshuffled DataLoader over `dataset` for evaluation.

    Batches of `batch_size_text` examples are collated with `data_collator`; the last
    partial batch is kept. Memory pinning only helps host-to-GPU copies, so it is
    requested only when CUDA is available (the recorded run used the CPU).
    """
    return DataLoader(
        dataset,
        batch_size=batch_size_text,
        collate_fn=data_collator,
        shuffle=False,
        pin_memory=torch.cuda.is_available(),
        num_workers=2,
        drop_last=False,
    )


test_dataloader = _make_eval_dataloader(test_dataset)
val_dataloader = _make_eval_dataloader(val_dataset)

In [None]:
# for d in val_dataloader:
#     print(d)
#     break

In [None]:
# Load the fine-tuned CTC checkpoint onto the selected device. Then freeze the
# convolutional feature encoder so its parameters receive no gradient updates.
model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
model.freeze_feature_encoder()
# print(model)
# print(summary(model,input_size= (1, 500)))

In [None]:
wer_metric = load_metric("wer")


def compute_metrics(pred):
    """Compute word error rate (WER) for a Trainer evaluation pass.

    Args:
        pred: the prediction object passed by Trainer. ``pred.predictions`` holds the
            model logits; ``pred.label_ids`` holds the padded label ids, with -100 at
            padding positions (see the data collator).

    Returns:
        A dict with the single key ``"wer"``.
    """
    # Greedy decoding: most likely token per frame.
    pred_ids = np.argmax(pred.predictions, axis=-1)

    # Map the -100 ignore index back to the pad token so the labels can be decoded.
    # np.where builds a new array instead of writing into pred.label_ids, so the
    # caller's data is left untouched.
    label_ids = np.where(pred.label_ids == -100, processor.tokenizer.pad_token_id, pred.label_ids)

    pred_str = processor.batch_decode(pred_ids)
    # Labels are reference text, so repeated tokens must not be grouped when decoding.
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    return {"wer": wer_metric.compute(predictions=pred_str, references=label_str)}

In [None]:
# Run metadata, intended for wandb.config (wandb logging is currently commented out).
# NOTE(review): "learning_rate" records `lr`, but the effective learning-rate schedule
# comes from the OneCycleLR scheduler (max_lr=0.0005); confirm which value should be logged.
config = dict(
    model=model_name,
    learning_rate=lr,
    max_epochs=max_epochs,
    batch_size=batch_size,
    dataset="VLSP",
)
# wandb.config = config

In [None]:
repo_name = "Teacher_base"
training_args = TrainingArguments(
    output_dir=repo_name,
    group_by_length=True,
    per_device_train_batch_size=batch_size,
    # Same batch size as the hand-built evaluation DataLoaders.
    per_device_eval_batch_size=batch_size_text,
    evaluation_strategy="steps",
    num_train_epochs=max_epochs,
    # fp16 mixed precision needs a GPU. With fp16=True on a CPU-only machine,
    # TrainingArguments raises a ValueError, so enable it only when CUDA is present.
    fp16=torch.cuda.is_available(),
    use_cpu=False,
    gradient_checkpointing=True,
    save_steps=500,
    eval_steps=500,
    # NOTE: learning_rate, weight_decay and warmup_steps only configure the optimizer and
    # scheduler that Trainer would build itself. They have no effect here because a custom
    # (optimizer, scheduler) pair is passed to Trainer via `optimizers=`.
    learning_rate=1e-4,
    weight_decay=0.005,
    do_train=True,
    save_total_limit=2,
    # NOTE(review): the wandb setup cells are commented out, but report_to="wandb" still makes
    # Trainer start a wandb run — confirm this is intended (otherwise use report_to="none").
    report_to="wandb",
    run_name="Teacher",
    # NOTE(review): logging_steps only applies with logging_strategy="steps"; with "epoch"
    # it is unused — confirm which logging cadence is intended.
    logging_steps=50,
    warmup_steps=500,
    logging_strategy='epoch',
)

In [None]:
# Number of optimizer steps Trainer will take: its training DataLoader keeps the last
# partial batch, giving ceil(len / batch_size) steps per epoch. The old `len // bs + 1`
# over-counted by one step per epoch whenever len was divisible by batch_size, so the
# one-cycle schedule never finished. OneCycleLR raises if stepped past total_steps,
# so the count must not be too small either.
steps_per_epoch = -(-len(train_dataset) // batch_size)  # ceiling division
total_steps = steps_per_epoch * max_epochs

# NOTE(review): beta2=0.9999 differs from AdamW's default 0.999 — confirm it is intentional.
optimizer = optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.9999))
# OneCycleLR overwrites each param group's lr at construction (max_lr / div_factor), so the
# `lr` above does not determine the learning rate; the schedule peaks at max_lr.
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=0.0005, pct_start=0.3, total_steps=total_steps
)
print(total_steps)

In [None]:
# A custom (optimizer, scheduler) pair is supplied, so Trainer does not create its own.
# The feature extractor is passed as `tokenizer` so it is saved along with the model.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=processor.feature_extractor,
    compute_metrics=compute_metrics,
    optimizers=(optimizer, scheduler),
)

In [None]:
# Run fine-tuning with the Trainer configured above. Per the TrainingArguments cell,
# evaluation and checkpointing happen every 500 steps.
trainer.train()
# wandb.finish()

In [6]:
# inference=Inference(model,processor)

Unigrams not provided and cannot be automatically determined from LM file (only arpa format). Decoding accuracy might be reduced.
No known unigrams provided, decoding results might be a lot worse.
  self.wer_metric = load_metric("wer")


In [7]:
# running_wers=inference.test_wer(val_dataloader,50)
# print(np.mean(running_wers))

NameError: name 'val_dataloader' is not defined