In [None]:
# Pinned package versions for this notebook; the text is written verbatim to
# req.txt (no trailing newline) so that pip can install from it.
REQUIREMENTS = '''accelerate==0.15.0
datasets==2.18.0
huggingface-hub==0.22.2
jiwer==2.5.1
librosa==0.9.2
espnet==202211
espnet_model_zoo==0.1.7
natsort==8.2.0
numpy==1.23.5
omegaconf==2.2.3
pandas==1.5.2
parallel-wavegan==0.5.5
pyctcdecode==0.4.0
soundfile==0.11.0
torch==1.13.0
torchaudio==0.13.0
tqdm==4.64.1
transformers==4.24.0
wandb==0.13.4'''

with open('req.txt', mode='w', encoding='utf-8') as req_file:
    req_file.write(REQUIREMENTS)

In [None]:
%%capture
# %pip (unlike !pip, which runs in a subshell) installs into the environment
# of the running kernel.
%pip install -r req.txt

In [None]:
import omegaconf as oc
import transformers as hft
import librosa as lb
import datasets as hfds
import jiwer
import pandas as pd
import soundfile as sf
import os
import wandb
from tqdm import tqdm
import random
import numpy as np
import torch
import typing
import dataclasses
import json
from datasets import Dataset
from pathlib import Path
from transformers import set_seed

In [None]:
# ANSI-coloured print helpers: each one wraps its argument in a colour escape
# sequence and appends the "\033[00m" reset before printing.
# Every template has a space after the colour code except prYellow's.

def prRed(skk):
    """Print ``skk`` wrapped in colour code 91."""
    print("\033[91m {}\033[00m".format(skk))


def prGreen(skk):
    """Print ``skk`` wrapped in colour code 92."""
    print("\033[92m {}\033[00m".format(skk))


def prYellow(skk):
    """Print ``skk`` wrapped in colour code 93 (no space after the code)."""
    print("\033[93m{}\033[00m".format(skk))


def prLightPurple(skk):
    """Print ``skk`` wrapped in colour code 94."""
    print("\033[94m {}\033[00m".format(skk))


def prPurple(skk):
    """Print ``skk`` wrapped in colour code 95."""
    print("\033[95m {}\033[00m".format(skk))


def prCyan(skk):
    """Print ``skk`` wrapped in colour code 96."""
    print("\033[96m {}\033[00m".format(skk))


def prLightGray(skk):
    """Print ``skk`` wrapped in colour code 97."""
    print("\033[97m {}\033[00m".format(skk))


def prBlack(skk):
    """Print ``skk`` wrapped in colour code 98."""
    # NOTE(review): 98 looks to be outside the usual 90-97 bright-foreground
    # range -- confirm terminals render this as intended.
    print("\033[98m {}\033[00m".format(skk))

def announce(announcement):
    """Print ``announcement`` framed by five dashes and a space on each side."""
    rule = '-' * 5

    print(f"{rule} {announcement} {rule}")

def make_config(config):
    """Load the run configuration, export its environment variables and
    optionally start a wandb run.

    Parameters
    ----------
    config : mapping
        Must contain the key ``'--config'`` holding the path of a file
        readable by ``OmegaConf.load``.

    Returns
    -------
    tuple
        ``(config, run)``: ``config`` is the merged OmegaConf configuration
        and ``run`` is the value returned by ``wandb.init``, or ``None`` when
        the configuration has no ``wandb`` section.

    Raises
    ------
    AssertionError
        If any value of the merged configuration is still ``'???'``.
    """

    # Overwrite config vars with anything supplied in the command line
    config = oc.OmegaConf.merge(
        oc.OmegaConf.load(config['--config']),
        oc.OmegaConf.from_cli()
    )

    # Long-format table with one row per (dotted argument name, value) pair
    flat_args_long = pd.json_normalize(oc.OmegaConf.to_container(config), sep=".").melt(var_name='argument')
    missing_args   = flat_args_long.query("value == '???'")

    assert len(missing_args) == 0, f"""

    The following required arguments are missing:

        {','.join(missing_args['argument'].to_list())}

    """

    announce("Configuring environment")

    # Set environment variables
    for key, value in config['env'].items():
        # os.environ only accepts strings, but OmegaConf will coerce
        # number-like values into integers (e.g. CUDA_VISIBLE_DEVICES, which
        # should be a comma-separated string). Stringify every value rather
        # than special-casing a single key.
        os.environ[key] = str(value)

    if 'wandb' not in config.keys():

        return config, None

    run = wandb.init(allow_val_change=True, settings=wandb.Settings(code_dir="."), **config['wandb'])

    if config.get("--run_name"):
        # Interpolate 'lr={tranargs[learning_rate]}' to 'lr=0.0001', where config['tranargs']['learning_rate'] = 0.0001
        run.name = config["--run_name"].format(**config)

    # Log hyper-parameters not automatically tracked by wandb
    untracked_args = flat_args_long[ ~flat_args_long.argument.str.contains("w2v2|trainargs|wandb|--", regex=True) ]
    # Convert to flat dict, e.g. { 'data.base_path' : '/path/to/the/data' }
    untracked_args = dict([ (d['argument'], d['value']) for d in untracked_args.to_dict(orient='records') ])

    wandb.config.update(untracked_args, allow_val_change=True)

    config['trainargs']['report_to'] = "wandb"

    return config, run

def load_datasets(data_config, processor):
    """Read the train and eval TSV splits and convert them to model inputs and labels.

    Parameters
    ----------
    data_config : mapping
        Reads the keys ``base_path``, ``train_tsv``, ``eval_tsv``,
        ``path_col`` and ``text_col``; if ``subset_train`` is present its
        ``seed`` and ``mins`` entries are used to subsample the training split.
    processor
        Called on each example's audio to produce ``input_values``; its
        ``tokenizer`` is called on the text to produce ``labels``.

    Returns
    -------
    datasets.DatasetDict
        Splits ``'train'`` and ``'eval'`` with the ``audio`` and ``text``
        columns replaced by ``input_values`` and ``labels``.

    Raises
    ------
    AssertionError
        If a TSV lacks the configured path/text column, or an audio file is
        not sampled at 16 kHz.
    """

    announce("Loading data ...")

    # Sampling rate required of every file and reported to the processor
    sample_rate = 16_000

    def _tsv2ds(tsv_file):
        """Load the split whose TSV file name is stored under ``data_config[tsv_file]``."""

        tsv_path = os.path.join(data_config['base_path'], data_config[tsv_file])

        print(f"Reading split from {tsv_path}")

        df = pd.read_csv(tsv_path, sep='\t')

        for c in ['path_col', 'text_col']:
            col_name = data_config[c]

            assert col_name in df.columns, f"\n\n\tDataset {tsv_path} is missing '{col_name}' column\n"

        # Normalize column names
        df = df.rename(columns = {
            data_config['path_col'] : 'path',
            data_config['text_col'] : 'text'
        })

        def _read_audio(path):
            full_path = os.path.join(data_config['base_path'], path)

            # sr=None keeps the file's native sampling rate so it can be checked
            data, sr = lb.load(full_path, sr=None)

            # Name the offending file so a failure deep into a long load is actionable
            assert sr == sample_rate, f"\n\n\tExpected {sample_rate} Hz audio but {full_path} is sampled at {sr} Hz\n"

            return data

        df['audio'] = [ _read_audio(path) for path in tqdm(df['path'].to_list(), desc="Reading audio data") ]

        if 'subset_train' in data_config and tsv_file == 'train_tsv':

            # Shuffle reproducibly, then keep rows while the running total of
            # audio duration stays within the requested number of minutes
            df = df.sample(frac=1, random_state=data_config['subset_train']['seed']).copy().reset_index(drop=True)
            df = df[ df['audio'].apply(lambda s: len(s)/sample_rate).cumsum() <= (60 * data_config['subset_train']['mins']) ].copy().reset_index(drop=True)

            prYellow(f"Subsetted training data as specified: {data_config['subset_train']['mins']} minutes, random seed {data_config['subset_train']['seed']}. Rows kept: {len(df)}")

        # List the files of this split; the header names the split because
        # this runs for the eval split as well as the training split
        print(f"Files in {tsv_file} split:")
        for f in df['path'].to_list():
            print(f)

        dataset = hfds.Dataset.from_pandas(df[['audio', 'text']])

        return dataset

    datasets = hfds.DatasetDict({
        'train' : _tsv2ds('train_tsv'),
        'eval' : _tsv2ds('eval_tsv')
    })

    def _to_inputs_and_labels(batch):
        # Mapped one example at a time (map is not batched), hence the [0]
        batch["input_values"] = processor(batch["audio"], sampling_rate=sample_rate).input_values[0]

        batch["labels"] = processor.tokenizer(batch["text"]).input_ids

        return batch

    announce("Preparing input features and labels ...")

    datasets = datasets.map(_to_inputs_and_labels, remove_columns=['audio', 'text'])

    return datasets

In [None]:
def configure_hf_w2v2_model(config):
    """Build the Wav2Vec2 processor and CTC model described by ``config``.

    Parameters
    ----------
    config : mapping
        Reads ``config['w2v2']`` (sub-sections ``model``, ``tok``, ``fext``,
        ``proc``) and ``config['trainargs']['output_dir']``. The entries
        ``pad_token_id`` and ``ctc_zero_infinity`` of ``config['w2v2']['model']``
        are overwritten in place.

    Returns
    -------
    tuple
        ``(model, processor)``; the model's feature encoder is frozen and the
        processor has been saved to the training output directory.
    """

    model_name = config['w2v2']['model']['pretrained_model_name_or_path']
    vocab_file = config['w2v2']['tok']['vocab_file']

    print(f"Loading {model_name} model ...")

    # Set verbosity to error while loading models (skips warnings about loading a not-yet fine-tuned model)
    hft.logging.set_verbosity_error()

    # Re-use the vocab.json from the fine-tuned model instead of re-deriving it from the train/test data

    # !wget https://huggingface.co/facebook/wav2vec2-large-960h/raw/main/vocab.json

    if vocab_file is None:
        # Load tokenizer from model if already fine-tuned
        processor = hft.Wav2Vec2Processor.from_pretrained(model_name)

    else:
        # Create a new processor (i.e. fine-tuning for the first time)
        processor = hft.Wav2Vec2Processor(
            tokenizer=hft.Wav2Vec2CTCTokenizer(**(config['w2v2']['tok'] or {})),
            feature_extractor=hft.Wav2Vec2FeatureExtractor(**(config['w2v2']['fext'] or {})),
            **(config['w2v2']['proc'] or {})
        )

    processor.save_pretrained(config['trainargs']['output_dir'])

    model_config = hft.AutoConfig.from_pretrained(model_name)

    # Set vocab size. Without a vocab file there is nothing to open, so fall
    # back to the vocabulary of the tokenizer that was just loaded.
    if vocab_file is None:
        vocab = processor.tokenizer.get_vocab()
    else:
        # Context manager closes the handle; explicit UTF-8 because
        # vocabularies may contain non-ASCII tokens.
        with open(vocab_file, encoding='utf-8') as vocab_handle:
            vocab = json.load(vocab_handle)

    model_config.vocab_size = max(vocab.values()) + 1

    config['w2v2']['model']['pad_token_id'] = processor.tokenizer.pad_token_id
    config['w2v2']['model']['ctc_zero_infinity'] = True

    # Applied after vocab_size is set, so values from the config win on clashes
    model_config.update(config['w2v2']['model'])

    model = hft.Wav2Vec2ForCTC.from_pretrained(
        model_name,
        config=model_config
    )

    model.freeze_feature_encoder()

    return model, processor

@dataclasses.dataclass
class DataCollatorCTCWithPadding:
    """Collate a list of features into one padded batch for CTC training.

    Inputs and labels have different lengths and need different padding
    methods, so they are padded separately and then recombined.
    """

    # Processor whose ``pad`` pads the inputs and whose tokenizer pads the labels
    processor: hft.Wav2Vec2Processor
    # Padding strategy passed unchanged to both pad calls
    padding: typing.Union[bool, str] = True

    def __call__(self, features: typing.List[typing.Dict[str, typing.Union[typing.List[int], torch.Tensor]]]) -> typing.Dict[str, torch.Tensor]:
        """Return the padded inputs with a ``labels`` entry added.

        Each feature must provide ``input_values`` and ``labels``.
        """
        pad_options = dict(padding=self.padding, return_tensors="pt")

        audio_inputs = []
        label_inputs = []
        for feature in features:
            audio_inputs.append({"input_values": feature["input_values"]})
            label_inputs.append({"input_ids": feature["labels"]})

        batch = self.processor.pad(audio_inputs, **pad_options)
        padded_labels = self.processor.tokenizer.pad(label_inputs, **pad_options)

        # Padded label positions become -100 so that the loss ignores them
        is_padding = padded_labels.attention_mask.ne(1)
        batch["labels"] = padded_labels["input_ids"].masked_fill(is_padding, -100)

        return batch

class MetricsComputer:
    """Callable that turns model predictions into word and character error rates.

    The decoding strategy (``'greedy'`` or ``'beam_search'``) is chosen once,
    at construction time, from ``config['w2v2']['decode']['method']``.
    """

    def __init__(self, config, processor):
        """Store the processor and select the decoder.

        Raises ``AssertionError`` for an unrecognized decoding method.
        """

        self.processor = processor
        self.report_to = config['trainargs']['report_to']

        decode_method = config['w2v2']['decode']['method']

        assert decode_method in ['greedy', 'beam_search'], f"\n\tError: Unrecognized decoding method '{decode_method}'"

        if decode_method == 'greedy':

            self.decoder = self.greedy_decoder

        elif decode_method == 'beam_search':

            # Imported lazily so torchaudio's decoder is only needed for beam search
            from torchaudio.models.decoder import ctc_decoder
            from functools import partial

            _decoder = ctc_decoder(
                lexicon=None,
                tokens=list(processor.tokenizer.get_vocab().keys()),
                blank_token=processor.tokenizer.pad_token,
                sil_token=processor.tokenizer.word_delimiter_token,
                unk_word=processor.tokenizer.unk_token,
                # 'or {}' tolerates an empty args section, matching how the
                # tok/fext/proc sections are unpacked elsewhere in this file
                **(config['w2v2']['decode']['args'] or {})
            )

            self.decoder = partial(self.beam_search_decoder, decoder=_decoder)

    def __call__(self, pred):
        """Return ``{'wer': ..., 'cer': ...}`` for ``pred`` (reads ``label_ids`` and ``predictions``)."""

        labels = self.get_labels(pred)
        preds  = self.decoder(pred)

        wer, cer = self.compute_metrics(labels, preds)

        return { "wer" : wer, "cer" : cer }

    def get_labels(self, pred):
        """Decode ``pred.label_ids`` into reference strings without modifying ``pred``."""
        # Replace data collator padding (-100) with the tokenizer's padding.
        # np.where builds a new array, so the caller's label_ids keep their
        # -100 entries instead of being overwritten in place.
        pad_token_id = self.processor.tokenizer.pad_token_id
        label_ids = np.where(pred.label_ids == -100, pad_token_id, pred.label_ids)
        # Retrieve labels as characters, e.g. 'hello', from label_ids, e.g. [5, 3, 10, 10, 2] (where 5 = 'h')
        label_str = self.processor.tokenizer.batch_decode(label_ids, group_tokens=False)

        return label_str

    def beam_search_decoder(self, pred, decoder):
        """Decode every row of ``pred.predictions`` with ``decoder`` and return the top hypotheses as strings."""

        pred_logits = torch.tensor(pred.predictions, dtype=torch.float32)

        from tqdm import tqdm
        from joblib import Parallel, delayed

        def logits_to_preds(logits):
            # unsqueeze to make logits to shape (B=1, T, V) expected by decode
            # instead of just (T, V), where B = batch, T = time steps, V = vocab size
            hypotheses = decoder(logits.unsqueeze(0))

            # Subset to get hypotheses for first example (of batch size 1)
            hypotheses = hypotheses[0]

            # Return top hypothesis as a string
            return self.processor.decode(hypotheses[0].tokens)

        # Decode in parallel
        pred_str = Parallel(n_jobs=-1, verbose=0, prefer="threads")(delayed(logits_to_preds)(l) for l in tqdm(pred_logits, desc="Running beam search decoding ..."))

        return pred_str

    def greedy_decoder(self, pred):
        """Decode the arg-max token at every time step of ``pred.predictions``."""

        pred_logits = pred.predictions
        pred_ids = np.argmax(pred_logits, axis=-1)
        pred_str = self.processor.batch_decode(pred_ids)

        return pred_str

    def compute_metrics(self, labels, preds):
        """Print (and, when reporting to wandb, log) the reference/prediction table and return ``(wer, cer)``."""

        scoring_df = pd.DataFrame({"Reference" : labels, "Prediction"  : preds})

        if self.report_to == 'wandb':
            wandb.log({ "asr_out": wandb.Table(data=scoring_df) })

        # Print two newlines first to separate table from progress bar
        print("\n\n")
        print(scoring_df)

        wer = jiwer.wer(labels, preds)
        cer = jiwer.cer(labels, preds)

        return wer, cer

# Adapted from https://discuss.huggingface.co/t/weights-biases-supporting-wave2vec2-finetuning/4839/4
def get_flat_linear_schedule_with_warmup(optimizer, num_warmup_steps:int, num_training_steps:int, last_epoch:int =-1, lr_warmup_pc=0.1, lr_const_pc=0.4):
    """Return a ``LambdaLR`` with a linear warm-up, a flat plateau and a linear decay to zero.

    The warm-up lasts ``int(num_training_steps * lr_warmup_pc)`` steps and the
    plateau ``int(num_training_steps * lr_const_pc)`` steps; the multiplier
    then falls linearly, reaching 0 at ``num_training_steps`` and staying
    there. ``num_warmup_steps`` is accepted but not used in the computation.
    """
    ramp_steps = int(num_training_steps * lr_warmup_pc)
    plateau_end = ramp_steps + int(num_training_steps * lr_const_pc)
    # Length of the decay phase, floored at 1 to avoid dividing by zero
    decay_span = float(max(1, num_training_steps - plateau_end))

    def lr_lambda(current_step):
        # Warm-up: multiplier rises linearly from 0
        if current_step < ramp_steps:
            return float(current_step) / float(max(1, ramp_steps))

        # Plateau: hold the base learning rate
        if current_step < plateau_end:
            return 1

        # Decay: fall linearly, never going below 0
        remaining = float(num_training_steps - current_step)
        return max(0.0, remaining / decay_span)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)

def get_flat_cheduler(name = None, optimizer = None, num_warmup_steps = None, num_training_steps = None):
    """Keyword-argument wrapper around ``get_flat_linear_schedule_with_warmup``.

    ``name`` is accepted but ignored; the remaining arguments are forwarded.
    """
    schedule_kwargs = {
        "num_warmup_steps": num_warmup_steps,
        "num_training_steps": num_training_steps,
    }
    return get_flat_linear_schedule_with_warmup(optimizer, **schedule_kwargs)

class ReplicationTrainer(hft.Trainer):
    """``hft.Trainer`` subclass whose learning-rate scheduler comes from ``get_flat_cheduler``."""

    def __init__(self, *args, **kwargs):
        # Adds no state of its own; all arguments go straight to hft.Trainer
        super().__init__(*args, **kwargs)

    def create_flat_scheduler(self, num_training_steps: int):
        """Assign the flat schedule for ``self.optimizer`` to ``self.lr_scheduler``."""
        flat_scheduler = get_flat_cheduler(
            optimizer=self.optimizer,
            num_training_steps=num_training_steps,
        )
        self.lr_scheduler = flat_scheduler

    def create_optimizer_and_scheduler(self, num_training_steps):
        """Create the optimizer first, then attach the flat scheduler to it."""
        self.create_optimizer()
        self.create_flat_scheduler(num_training_steps)


In [None]:
# Load the run configuration (starting a wandb run if one is configured)
config, wandb_run = make_config({'--config': '/content/drive/MyDrive/configs/generate_config.yaml'})

announce('Configuring model and reading data')

model, processor = configure_hf_w2v2_model(config)
# Switch to eval mode, then move the model to the GPU
model.eval()
model = model.cuda()

# Location of the TSV listing the audio to transcribe
data_section = config['data']
devset_path = os.path.join(data_section.base_path, data_section.transcribe_tsv)

print(f"Data to transcribe: {devset_path}")

----- Configuring environment -----


ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mpburub[0m. Use [1m`wandb login --relogin`[0m to force relogin


----- Configuring model and reading data -----
Loading /content/drive/MyDrive/model_ft/checkpoint-60000 model ...
Data to transcribe: /content/drive/MyDrive/copt/train.tsv


In [None]:
# Tab-separated table describing the audio to transcribe
# (read_table uses a tab separator by default)
dev_ds = pd.read_table(devset_path)

In [None]:
def _read_audio(path):
    """Read one audio file, located relative to the data base path, and return its samples.

    Raises ``AssertionError`` if the file is not sampled at 16 kHz.
    """
    full_path = os.path.join(config['data'].base_path, path)

    # NOTE(review): load_datasets reads audio with librosa whereas this uses
    # soundfile -- confirm both yield the same array layout for these files.
    data, sr = sf.read(full_path)

    # Name the offending file: reading the whole set takes a long time and a
    # bare assertion would not say which file to fix
    assert sr == 16_000, f"Expected 16000 Hz audio but {full_path} is sampled at {sr} Hz"

    return data

dev_ds['audio'] = [ _read_audio(path) for path in tqdm(dev_ds['path'].to_list(), desc='Reading audio data') ]

Reading audio data: 100%|██████████| 3000/3000 [20:44<00:00,  2.41it/s]


In [None]:
# Keep only the audio samples and their paths, then convert the frame to a Dataset
dev_columns = ['audio', 'path']
dev_ds = Dataset.from_pandas(dev_ds[dev_columns])

In [None]:
announce('Evaluating model')

def evaluate(batch):
    """Transcribe one batch of audio by taking the arg-max token at every frame.

    Reads ``batch['audio']`` and returns the same batch with a
    ``'transcription'`` entry added. Uses the notebook-level ``processor``
    and ``model``.
    """
    inputs = processor(batch['audio'], sampling_rate=16_000, return_tensors='pt', padding=True)

    # Follow the model's device instead of hard-coding 'cuda'
    device = model.device

    # NOTE(review): some feature-extractor configurations do not return an
    # attention mask; pass None in that case rather than failing on the lookup.
    attention_mask = inputs.get('attention_mask')
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)

    with torch.no_grad():
        logits = model(inputs.input_values.to(device), attention_mask=attention_mask).logits

    # Arg-max with torch on the logits' own device; only the ids are copied back
    pred_ids = torch.argmax(logits, dim=-1).cpu().numpy()
    batch['transcription'] = processor.batch_decode(pred_ids)

    return batch

dev_ds = dev_ds.map(evaluate, batched=True, batch_size=2)
dev_ds = dev_ds.to_pandas()

----- Evaluating model -----


Map:   0%|          | 0/3000 [00:00<?, ? examples/s]

In [None]:
announce('Transcribing')
os.makedirs('./data/transcriptions/', exist_ok=True)

# The TSV is written outside ./data/transcriptions/, so also make sure the
# directory that is actually written to exists.
transcription_path = Path('/content/drive/MyDrive/copt/file.tsv')
transcription_path.parent.mkdir(parents=True, exist_ok=True)

dev_ds[['path', 'transcription']].to_csv(transcription_path, sep = '\t')

----- Transcribing -----
