In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# pip install torch==1.10.0+cu111 torchvision==0.11.0+cu111 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html

In [2]:
!pip install datasets transformers==4.21.3



In [3]:
!pip install sacrebleu



In [4]:
!pip install evaluate
!pip install jiwer



In [5]:
from evaluate import load
wer = load("wer")
cer = load("cer")

In [6]:
import pandas as pd
from transformers import AutoTokenizer
from transformers import (AutoTokenizer, Seq2SeqTrainingArguments, Seq2SeqTrainer,
                          DataCollatorForSeq2Seq, AutoModelForSeq2SeqLM)
import torch
from tqdm import tqdm
import numpy as np
from datasets import load_metric, Dataset, load_dataset
import os
from sklearn.model_selection import train_test_split

import gc
tqdm.pandas()

In [7]:
path = 'UrukHan/t5-russian-spell'
tokeniser = AutoTokenizer.from_pretrained(path, use_fast=True)
model = AutoModelForSeq2SeqLM.from_pretrained(path)

In [8]:
metric_bleu = load_metric("sacrebleu")
metric_meteor = load_metric("meteor")

  metric_bleu = load_metric("sacrebleu")
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [9]:
def postprocess_text_wer(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    return preds, labels

def postprocess_text_cer(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    return preds, labels

def compute_metrics(eval_preds):

    torch.cuda.empty_cache()

    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    decoded_preds = tokeniser.batch_decode(preds, skip_special_tokens=True)

    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokeniser.pad_token_id)
    decoded_labels = tokeniser.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing

    decoded_preds_wer, decoded_labels_wer = postprocess_text_wer(decoded_preds, decoded_labels)
    decoded_preds_cer, decoded_labels_cer = postprocess_text_cer(decoded_preds, decoded_labels)

    wer_score = wer.compute(
        predictions=decoded_preds_wer,
        references=decoded_labels_wer
    )

    cer_score = cer.compute(
        predictions=decoded_preds_cer,
        references=decoded_labels_cer,
    )

    result = {
        "WER": round(wer_score, 4),
        "CER": round(cer_score, 4)
    }

    return result

In [40]:
train = pd.read_csv('whisper_small_ru_train.csv', on_bad_lines='skip', encoding_errors='ignore')
valid = pd.read_csv('whisper_small_ru_validation.csv', on_bad_lines='skip', encoding_errors='ignore')[0:1000]
test = pd.read_csv('vosk_small_ru_test.csv', on_bad_lines='skip', encoding_errors='ignore')

In [41]:
test = test.dropna(subset=['pred'])

In [42]:
def tokenize_col(df_t):
  df_t['tok'] = df_t.pred.progress_apply(lambda x: tokeniser.encode(x))
  df_t = df_t[df_t.tok.apply(len) <= 128]
  df_t = df_t[df_t.tok.apply(len) >= 2]
  return df_t

train = tokenize_col(train)
valid = tokenize_col(valid)
test = tokenize_col(test)

100%|██████████| 22862/22862 [00:02<00:00, 8884.39it/s]
100%|██████████| 1000/1000 [00:00<00:00, 7224.36it/s]
100%|██████████| 9621/9621 [00:01<00:00, 8868.51it/s]


In [13]:
test.iloc[112]

text                               Мне кажется, что и он счастлив.
text_clean                           мне кажется что и он счастлив
path_relative    data/ru/common_voice/wav/test/common_voice_ru_...
path             /home/jovyan/bystrova-ov/whisper/data/ru/commo...
model                                                         nemo
pred                                 мне кажется что и он счастлив
tok                                 [109, 791, 16, 5, 39, 3875, 2]
Name: 113, dtype: object

In [14]:
test.head()

Unnamed: 0,text,text_clean,path_relative,path,model,pred,tok
0,"К сожалению, эти предложения не нашли отражени...",к сожалению эти предложения не нашли отражения...,data/ru/common_voice/wav/test/common_voice_ru_...,/home/jovyan/bystrova-ov/whisper/data/ru/commo...,nemo,к сожалению эти предложения не нашли отражения...,"[24, 2468, 287, 3719, 10, 2692, 25622, 6, 1458..."
1,"Если не будет возражений, я буду считать, что ...",если не будет возражений я буду считать что ас...,data/ru/common_voice/wav/test/common_voice_ru_...,/home/jovyan/bystrova-ov/whisper/data/ru/commo...,nemo,если не будет возражений я буду считать что ас...,"[119, 10, 127, 28582, 35, 858, 2636, 16, 37, 8..."
2,Новошахтинск — милый город,новошахтинск милый город,data/ru/common_voice/wav/test/common_voice_ru_...,/home/jovyan/bystrova-ov/whisper/data/ru/commo...,nemo,новошахтинск милый город,"[4907, 20223, 98, 15070, 13874, 690, 2]"
3,"Мы особенно рады отметить, что число скрывающи...",мы особенно рады отметить что число скрывающих...,data/ru/common_voice/wav/test/common_voice_ru_...,/home/jovyan/bystrova-ov/whisper/data/ru/commo...,nemo,мы особенно рады отметить что число скрывающих...,"[93, 733, 10609, 3436, 16, 1926, 13734, 2274, ..."
4,Контроллер,контроллер,data/ru/common_voice/wav/test/common_voice_ru_...,/home/jovyan/bystrova-ov/whisper/data/ru/commo...,nemo,контролер,"[22646, 238, 2]"


In [43]:
def preprocess_datasets(examples, tokeniser, max_length):

    inputs = examples['text_clean']
    targets = examples['text']

    model_inputs = tokeniser(inputs, max_length=max_length, truncation=True, padding=True)

    with tokeniser.as_target_tokenizer():
        labels = tokeniser(targets, max_length=max_length, truncation=True, padding=True)

    model_inputs["labels"] = labels['input_ids']

    return model_inputs

In [16]:
train_dataset = Dataset.from_pandas(train)
prep_train_dataset = train_dataset.map(
    preprocess_datasets,
    batched=True,
    fn_kwargs=dict(
        tokeniser=tokeniser,
        max_length=128+5
    ),
    remove_columns=['path_relative', 'path', 'model']
)

Map:   0%|          | 0/22856 [00:00<?, ? examples/s]

In [17]:
validation_dataset = Dataset.from_pandas(valid)

prep_eval_dataset = validation_dataset.map(
    preprocess_datasets,
    batched=True,
    fn_kwargs=dict(
        tokeniser=tokeniser,
        max_length=128+5
    ),
    remove_columns=['path_relative', 'path', 'model']
)

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [44]:
test_dataset = Dataset.from_pandas(test)

prep_test_dataset = test_dataset.map(
    preprocess_datasets,
    batched=True,
    fn_kwargs=dict(
        tokeniser=tokeniser,
        max_length=128+5
    ),
    remove_columns=['path_relative', 'path', 'model']
)

Map:   0%|          | 0/9621 [00:00<?, ? examples/s]

In [46]:
datacollator = DataCollatorForSeq2Seq(tokenizer=tokeniser, model=model, return_tensors="pt", padding="longest")

In [47]:
training_args = Seq2SeqTrainingArguments(
        remove_unused_columns=True,
        output_dir="mt5_cis_new_after_rls",
        overwrite_output_dir=True,
        evaluation_strategy='epoch',
        save_strategy='epoch',
        ddp_find_unused_parameters=False,
        learning_rate=1e-3,
        per_device_train_batch_size=10,#32,
        per_device_eval_batch_size=10,#29,
#         weight_decay=1e-6,
        save_total_limit=2,
        num_train_epochs=10,
        predict_with_generate=True,
        do_predict=True,
        dataloader_num_workers=12,
        report_to="tensorboard",
        dataloader_pin_memory=False,
        label_smoothing_factor=0.3,
#         resume_from_checkpoint="mt5_cis_new/"
    )

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=prep_train_dataset,
    eval_dataset=prep_eval_dataset,
    tokenizer=tokeniser,
    data_collator=datacollator,
    compute_metrics=compute_metrics,
)

PyTorch: setting up devices


In [48]:
os.environ["WANDB_DISABLED"] = "true"
os.environ["TOKENIZERS_PARALLELISM"] = "true"

In [49]:
preds = trainer.predict(prep_test_dataset)

The following columns in the test set don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: __index_level_0__, text, text_clean, tok, pred. If __index_level_0__, text, text_clean, tok, pred are not expected by `T5ForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 9621
  Batch size = 10


In [50]:
kek = pd.DataFrame([tokeniser.decode(i, skip_special_tokens =True) for i in preds.predictions], columns = ['ft_txt'])

In [51]:
kek['text'] = test_dataset['text']

Whisper

In [38]:
wer.compute(predictions=kek['ft_txt'], references=kek['text'])

0.15396455052954325

In [39]:
cer.compute(predictions=kek['ft_txt'], references=kek['text'])

0.07614353446805608

Vosk

In [52]:
wer.compute(predictions=kek['ft_txt'], references=kek['text'])

0.15401534776870815

In [53]:
cer.compute(predictions=kek['ft_txt'], references=kek['text'])

0.07616758335494261

Nemo

In [28]:
wer.compute(predictions=kek['ft_txt'], references=kek['text'])

0.1540011478953217

In [29]:
cer.compute(predictions=kek['ft_txt'], references=kek['text'])

0.07614112800787037