In [1]:
from huggingface_hub import login, HfFolder
import os
huggingface_token = os.getenv('HUGGINGFACE_TOKEN')

login(token=huggingface_token, add_to_git_credential=True)

Token is valid (permission: read).
Your token has been saved in your configured git credential helpers (manager).
Your token has been saved to C:\Users\Zhenya\.cache\huggingface\token
Login successful


In [2]:
import csv

import os
import sys
import time
import warnings
from dataclasses import dataclass, field
from datetime import timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import datasets
import evaluate
import numpy as np
import torch
import transformers
from accelerate import Accelerator, InitProcessGroupKwargs

from datasets import (
    DatasetDict,
    IterableDatasetDict,
    load_dataset,
)
from huggingface_hub import HfFolder, create_repo, get_full_repo_name, snapshot_download, upload_folder
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    HfArgumentParser,
    Seq2SeqTrainingArguments,
    WhisperConfig,
    WhisperFeatureExtractor,
    WhisperForConditionalGeneration,
    WhisperProcessor,
    WhisperTokenizerFast,
)
from transformers.models.whisper.english_normalizer import BasicTextNormalizer, EnglishTextNormalizer
from transformers.utils import check_min_version
from transformers.utils.versions import require_version


# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.34.0.dev0")

require_version("datasets>=2.14.6", "To fix: `pip install --upgrade datasets`")



In [3]:
DEBUG_MODE = True
raw_datasets = DatasetDict()
# sampling_rate = 16_000
# 3. Load dataset
all_eval_datasets_list = []

spits_to_load = ['train', 'validation','test']

# 3. Load dataset
raw_datasets = DatasetDict()
token = HfFolder().get_token()

for split in spits_to_load:
    print(split)
    raw_datasets[split] = datasets.load_from_disk(
            f'datasets/mozila_uk/{split}',
        )

    if DEBUG_MODE:
        raw_datasets[split] = raw_datasets[split].select(range(100))
    

train
validation
test


In [4]:
from run_pseudo_labelling import ModelArguments, DataTrainingArguments, shift_tokens_right, DataCollatorSpeechSeq2SeqWithPadding, log_metric, log_pred 

In [6]:
%load_ext autoreload
%autoreload 1

%aimport src.utils
%aimport src.load_model

from src.utils import prepare_accelerator
from src.load_model import load_config_feature_ext_tokenizer, load_processor, load_whisper_model

In [7]:
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
model_args, data_args, training_args = parser.parse_json_file(json_file="pipe_configs/pseudo_v0.json")

In [8]:
accelerator, model_dtype, logger = prepare_accelerator(input_dtype=model_args.dtype, project_name=data_args.wandb_project, training_args=training_args)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mzekamrozek[0m. Use [1m`wandb login --relogin`[0m to force relogin




In [9]:
config, feature_extractor, tokenizer = load_config_feature_ext_tokenizer(
    model_name_or_path="openai/whisper-medium", model_args=model_args)

processor = load_processor(processor_path="openai/whisper-medium", model_args=model_args)

model = load_whisper_model(model_args.model_name_or_path, model_args, dtype=model_dtype)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [10]:
config

WhisperConfig {
  "_name_or_path": "openai/whisper-medium",
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "apply_spec_augment": false,
  "architectures": [
    "WhisperForConditionalGeneration"
  ],
  "attention_dropout": 0.0,
  "begin_suppress_tokens": [
    220,
    50257
  ],
  "bos_token_id": 50257,
  "classifier_proj_size": 256,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 24,
  "decoder_start_token_id": 50258,
  "dropout": 0.0,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 24,
  "eos_token_id": 50257,
  "forced_decoder_ids": [
    [
      1,
      50259
    ],
    [
      2,
      50359
    ],
    [
      3,
      50363
    ]
  ],
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "mask_feature_length": 10,
  "mask_feature_min_masks": 0,
  "mask_feature_prob": 0.0,
  "mask_time_length": 10,
  "mask_time_min_masks": 2,
 

In [10]:
model.eval()

if model.config.decoder_start_token_id is None:
    raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

return_timestamps = data_args.return_timestamps
if hasattr(model.generation_config, "is_multilingual") and model.generation_config.is_multilingual:
    # We need to set the language and task ids for multilingual checkpoints
    tokenizer.set_prefix_tokens(
        language=data_args.language, task=data_args.task, predict_timestamps=return_timestamps
    )
elif data_args.language is not None:
    raise ValueError(
        "Setting language token for an English-only checkpoint is not permitted. The language argument should "
        "only be set for multilingual checkpoints."
    )

In [11]:
def prepare_normilazer(language, tokenizer=None):

    if language is not None:
        normalizer = BasicTextNormalizer()
    else:
        normalizer = EnglishTextNormalizer(tokenizer.english_spelling_normalizer)
    return normalizer

normalizer = prepare_normilazer(language=data_args.language, tokenizer=tokenizer)

In [12]:
raw_datasets = raw_datasets.cast_column(
    data_args.audio_column_name,
    datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate),
)

In [13]:
# import matplotlib.pyplot as plt
# plt.plot(raw_datasets['train'][0]['audio']['array'])

In [14]:
max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
max_label_length = (
    data_args.max_label_length if data_args.max_label_length is not None else model.config.max_length
)
audio_column_name = data_args.audio_column_name
sampling_rate = feature_extractor.sampling_rate

preprocessing_batch_size = data_args.preprocessing_batch_size
num_workers = data_args.preprocessing_num_workers
dataloader_num_workers = training_args.dataloader_num_workers

text_column_name = data_args.text_column_name
model_input_name = feature_extractor.model_input_names[0]
id_column_name = data_args.id_column_name
speaker_id_column_name = data_args.speaker_id_column_name

In [15]:
timestamp_position = 1
decoder_prev_token_id = tokenizer.convert_tokens_to_ids("<|startofprev|>")
decoder_eot_token_id = tokenizer.eos_token_id

In [16]:
if speaker_id_column_name is not None:
    raw_datasets = raw_datasets.sort(speaker_id_column_name)

In [17]:
def concatenate_dataset(batch):
    audio = [sample["array"] for sample in batch[audio_column_name]]
    input_lengths = [len(sample) for sample in audio]

    text = batch[text_column_name]
    speaker_id = batch[speaker_id_column_name] if speaker_id_column_name else len(text) * [None]

    concatenated_audio = []
    concatenated_text = []
    concatenated_speaker = []
    condition_on_prev = []
    audio_sample = audio[0]
    text_sample = text[0]

    for idx in range(1, len(audio)):
        prev_speaker = speaker_id[idx - 1]
        speaker = speaker_id[idx]

        if len(audio_sample) + input_lengths[idx] < max_input_length:
            if speaker == prev_speaker:
                # we have no information about whether the segments follow on sequentially
                # so we just ensure the same speaker as we concatenate across files
                audio_sample = np.append(audio_sample, audio[idx])
                # extra spaces in the text transcription don't matter, since we only use it for the WER computation
                text_sample += " " + text[idx]
            else:
                # speakers do not follow sequentially, save the audio and start looping again
                concatenated_audio.append(audio_sample)
                concatenated_text.append(text_sample)
                concatenated_speaker.append(speaker)
                condition_on_prev.append(0)
                audio_sample = audio[idx]
                text_sample = text[idx]

        else:
            # concatenated audio exceeds max length, save the audio and start looping again
            concatenated_audio.append(audio_sample)
            concatenated_text.append(text_sample)
            concatenated_speaker.append(speaker)
            condition_on_prev.append(1)
            audio_sample = audio[idx]
            text_sample = text[idx]

    batch[audio_column_name] = [{"array": array, "sampling_rate": sampling_rate} for array in concatenated_audio]
    batch[text_column_name] = concatenated_text
    batch[id_column_name] = concatenated_speaker
    batch["condition_on_prev"] = condition_on_prev

    return batch

In [18]:
raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
if data_args.concatenate_audio and not data_args.streaming:
    with accelerator.main_process_first():
        raw_datasets = raw_datasets.map(
            concatenate_dataset,
            batched=True,
            batch_size=preprocessing_batch_size,
            num_proc=1,
            remove_columns=set(raw_datasets_features)
            - {audio_column_name, text_column_name, id_column_name, "condition_on_prev"},
            desc="Concatenating dataset...",
        )



In [19]:
dataset_name ="mozila-uk"

raw_datasets = raw_datasets.cast_column(
        audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate)
    )

In [20]:

def postprocess_ids(speaker_ids, indices):
    speaker_ids_formatted = []
    for speaker, idx in zip(speaker_ids, indices):
        formatted_idx = f"{dataset_name}-{speaker}-{idx}" if speaker is not None else f"{dataset_name}-{idx}"
        speaker_ids_formatted.append(formatted_idx)
    return {id_column_name: speaker_ids_formatted}

In [21]:
with accelerator.main_process_first():
    raw_datasets = raw_datasets.map(
        postprocess_ids,
        input_columns=[id_column_name],
        with_indices=True,
        desc="Setting sample idxs...",
        batched=True,
        batch_size=preprocessing_batch_size,
        num_proc=1,
    )

In [22]:
file_ids_dataset = IterableDatasetDict() if data_args.streaming else DatasetDict()
for split in raw_datasets:
    file_ids_dataset[split] = raw_datasets[split][id_column_name]

In [23]:
def prepare_dataset(batch):
    # process audio
    sample = batch[audio_column_name]
    inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
    # process audio length
    batch[model_input_name] = inputs.get(model_input_name)[0]

    # process targets
    input_str = batch[text_column_name]
    batch["labels"] = tokenizer(input_str, max_length=max_label_length, truncation=True).input_ids
    return batch

In [24]:
raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())

with accelerator.main_process_first():
    vectorized_datasets = raw_datasets.map(
        prepare_dataset,
        remove_columns=raw_datasets_features,
        num_proc=1,
        desc="preprocess dataset",
    )

preprocess dataset:   0%|          | 0/1658 [00:00<?, ? examples/s]

In [25]:
os.makedirs(training_args.output_dir, exist_ok=True)

In [26]:
metric = evaluate.load("wer")

In [27]:
def compute_metrics(preds, labels, file_ids):
    # replace padded labels by the padding token
    for idx in range(len(labels)):
        labels[idx][labels[idx] == -100] = tokenizer.pad_token_id

    pred_str = tokenizer.batch_decode(preds, skip_special_tokens=False, decode_with_timestamps=return_timestamps)
    # we do not want to group tokens when computing the metrics
    label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # normalize everything and re-compute the WER
    norm_pred_str = [normalizer(pred) for pred in pred_str]
    norm_label_str = [normalizer(label) for label in label_str]
    # for logging, we need the pred/labels to match the norm_pred/norm_labels, so discard any filtered samples here
    pred_str = [pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
    label_str = [label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
    file_ids = [file_ids[i] for i in range(len(file_ids)) if len(norm_label_str[i]) > 0]
    # filtering step to only evaluate the samples that correspond to non-zero normalized references:
    norm_pred_str = [norm_pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
    norm_label_str = [norm_label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]

    wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)

    return {"wer": wer}, pred_str, label_str, norm_pred_str, norm_label_str, file_ids

def filter_eot_tokens(preds):
    for idx in range(len(preds)):
        # remove the EOT tokens to get the 'true' token length
        token_ids = [token for token in preds[idx] if token != decoder_eot_token_id]
        token_ids = token_ids + [decoder_eot_token_id]
        preds[idx] = token_ids
    return preds

In [28]:
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)

data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,  # <|startoftranscript|>
    input_padding="longest",
    target_padding="max_length",
    max_target_length=max_label_length,
)

In [29]:
gen_kwargs = {
    "max_length": max_label_length,
    "num_beams": 1,
    "return_timestamps": data_args.return_timestamps,
    "language": data_args.language,
    "task": data_args.task,
}

In [30]:
model.generation_config.forced_decoder_ids = None
model.config.forced_decoder_ids = None

# 15. Prepare everything with accelerate
model = accelerator.prepare(model)

In [31]:
output_dir = training_args.output_dir

In [32]:
def eval_step_with_save(split="eval", decoder_eot_token_id=decoder_eot_token_id, decoder_prev_token_id=decoder_prev_token_id, timestamp_position=timestamp_position):
    # ======================== Evaluating ==============================
    eval_preds = []
    eval_labels = []
    eval_ids = []
    pred_str = []
    eval_start = time.time()

    eval_loader = DataLoader(
        vectorized_datasets[split],
        batch_size=per_device_eval_batch_size,
        collate_fn=data_collator,
        num_workers=dataloader_num_workers,
        pin_memory=True,
    )
    file_loader = DataLoader(
        file_ids_dataset[split],
        batch_size=per_device_eval_batch_size * accelerator.num_processes,
        num_workers=dataloader_num_workers,
    )

    eval_loader = accelerator.prepare(eval_loader)
    batches = tqdm(eval_loader, desc=f"Evaluating {split}...", disable=not accelerator.is_local_main_process)

    # make the split name pretty for librispeech etc
    split = split.replace(".", "-").split("/")[-1]
    output_csv = os.path.join(output_dir, f"{split}-transcription.csv")

    for step, (batch, file_ids) in enumerate(zip(batches, file_loader)):
        # Generate predictions and pad to max generated length
        generate_fn = model.module.generate if accelerator.num_processes > 1 else model.generate
        generated_ids = generate_fn(batch["input_features"].to(dtype=model_dtype), **gen_kwargs)
        generated_ids = accelerator.pad_across_processes(generated_ids, dim=1, pad_index=tokenizer.pad_token_id)
        # Gather all predictions and targets
        generated_ids, labels = accelerator.gather_for_metrics((generated_ids, batch["labels"]))
        eval_preds.extend(generated_ids.cpu().numpy())
        eval_labels.extend(labels.cpu().numpy())
        eval_ids.extend(file_ids)

        if step % training_args.logging_steps == 0 and step > 0:
            batches.write(f"Saving transcriptions for split {split} step {step}")
            accelerator.wait_for_everyone()
            pred_ids = eval_preds[-(len(eval_preds) - len(pred_str)) :]
            pred_ids = filter_eot_tokens(pred_ids)
            pred_str.extend(
                tokenizer.batch_decode(
                    pred_ids, skip_special_tokens=False, decode_with_timestamps=return_timestamps
                )
            )
            csv_data = [[eval_ids[i], pred_str[i]] for i in range(len(eval_preds))]

            with open(output_csv, "w", encoding="UTF8", newline="") as f:
                writer = csv.writer(f)
                # write multiple rows
                writer.writerow(["file_id", "whisper_transcript"])
                writer.writerows(csv_data)

            # if training_args.push_to_hub and accelerator.is_main_process:
            #     upload_folder(
            #         folder_path=output_dir,
            #         repo_id=repo_name,
            #         repo_type="dataset",
            #         commit_message=f"Saving transcriptions for split {split} step {step}.",
            #     )

    accelerator.wait_for_everyone()
    eval_time = time.time() - eval_start

    # compute WER metric for eval sets
    wer_desc = ""
    if "validation" in split or "test" in split:
        eval_preds = filter_eot_tokens(eval_preds)
        wer_metric, pred_str, label_str, norm_pred_str, norm_label_str, eval_ids = compute_metrics(
            eval_preds, eval_labels, eval_ids
        )
        wer_desc = " ".join([f"Eval {key}: {value} |" for key, value in wer_metric.items()])
        # Save metrics + predictions
        log_metric(
            accelerator,
            metrics=wer_metric,
            train_time=eval_time,
            prefix=split,
        )
        log_pred(
            accelerator,
            pred_str,
            label_str,
            norm_pred_str,
            norm_label_str,
            prefix=split,
        )
    else:
        pred_ids = eval_preds[-(len(eval_preds) - len(pred_str)) :]
        pred_ids = filter_eot_tokens(pred_ids)
        pred_str.extend(
            tokenizer.batch_decode(pred_ids, skip_special_tokens=False, decode_with_timestamps=return_timestamps)
        )

    batches.write(f"Saving final transcriptions for split {split}.")
    csv_data = [[eval_ids[i], eval_preds[i]] for i in range(len(eval_preds))]
    with open(output_csv, "w", encoding="UTF8", newline="") as f:
        writer = csv.writer(f)
        # write multiple rows
        writer.writerow(["file_id", "whisper_transcript"])
        writer.writerows(csv_data)

    # Print metrics
    logger.info(wer_desc)

    if not data_args.streaming:
        raw_datasets[split] = raw_datasets[split].add_column("whisper_transcript", pred_str)
        raw_datasets[split] = raw_datasets[split].add_column("eval_preds", eval_preds)

        def add_concatenated_text(eval_preds, condition_on_prev):
            concatenated_prev = [None]
            for token_ids, condition in zip(eval_preds[:-1], condition_on_prev[1:]):
                if condition is False:
                    concatenated_prev.append(None)
                else:
                    prompt_ids = [token for token in token_ids if token != decoder_eot_token_id]
                    prompt_ids = [decoder_prev_token_id] + prompt_ids[timestamp_position:]
                    concatenated_prev.append(prompt_ids)
            return {"condition_on_prev": concatenated_prev}

        with accelerator.main_process_first():
            raw_datasets[split] = raw_datasets[split].map(
                add_concatenated_text,
                input_columns=["eval_preds", "condition_on_prev"],
                remove_columns=["eval_preds"],
                desc="Setting condition on prev...",
                batched=True,
                batch_size=preprocessing_batch_size,
                num_proc=num_workers,
            )

In [33]:
logger.info("***** Running Labelling *****")
logger.info("  Instantaneous batch size per device =" f" {training_args.per_device_eval_batch_size}")
logger.info(
    f"  Total eval batch size (w. parallel & distributed) = {training_args.per_device_eval_batch_size * accelerator.num_processes}"
)
logger.info(f"  Predict labels with timestamps = {return_timestamps}")


for split in spits_to_load:
    eval_step_with_save(split=split)
    accelerator.wait_for_everyone()

raw_datasets.save_to_disk(output_dir, num_proc=num_workers)

accelerator.end_training()

05/07/2024 19:24:41 - INFO - src.utils - ***** Running Labelling *****
05/07/2024 19:24:41 - INFO - src.utils -   Instantaneous batch size per device = 35
05/07/2024 19:24:41 - INFO - src.utils -   Total eval batch size (w. parallel & distributed) = 35
05/07/2024 19:24:41 - INFO - src.utils -   Predict labels with timestamps = True
  attn_output = torch.nn.functional.scaled_dot_product_attention(
Evaluating train...: 100%|██████████| 101/101 [15:40<00:00,  9.32s/it]


Saving final transcriptions for split train.


05/07/2024 19:40:28 - INFO - src.utils - 


Setting condition on prev... (num_proc=4):   0%|          | 0/3529 [00:00<?, ? examples/s]

Evaluating validation...: 100%|██████████| 48/48 [07:34<00:00,  9.46s/it]
05/07/2024 19:48:13 - INFO - src.utils - Eval wer: 16.204646768812918 |


Saving final transcriptions for split validation.


Setting condition on prev... (num_proc=4):   0%|          | 0/1658 [00:00<?, ? examples/s]

Evaluating test...: 100%|██████████| 50/50 [2:45:08<00:00, 198.18s/it]    
05/07/2024 22:33:37 - INFO - src.utils - Eval wer: 17.418362864255382 |


Saving final transcriptions for split test.


Saving the dataset (0/7 shards):   0%|          | 0/3529 [00:00<?, ? examples/s]

Saving the dataset (0/4 shards):   0%|          | 0/1658 [00:00<?, ? examples/s]

Saving the dataset (0/4 shards):   0%|          | 0/1734 [00:00<?, ? examples/s]

VBox(children=(Label(value='66.085 MB of 66.085 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/time,▁
test/wer,▁
validation/time,▁
validation/wer,▁

0,1
test/time,9909.19393
test/wer,17.41836
validation/time,454.54624
validation/wer,16.20465


In [34]:
output_dir

'datasets/labeled_mozila_3'

In [35]:
decoder_eot_token_id

50257