In [1]:
import nemo
import nemo.collections.asr as nemo_asr
import pytorch_lightning as pl
from omegaconf import DictConfig
import pathlib
import nemo.collections.asr as nemo_asr
import pytorch_lightning as pl
import os
import matplotlib.pyplot as plt
import re

[NeMo W 2021-11-06 00:13:34 optimizers:47] Apex was not found. Using the lamb optimizer will error out.
    
[NeMo W 2021-11-06 00:13:35 nmse_clustering:54] Using eigen decomposition from scipy, upgrade torch to 1.9 or higher for faster clustering
################################################################################
###          (please add 'export KALDI_ROOT=<your_path>' in your $HOME/.profile)
###          (or run as: KALDI_ROOT=<your_path> python <your_script>.py)
################################################################################

      '"sox" backend is being deprecated. '
    
[NeMo W 2021-11-06 00:13:35 experimental:28] Module <class 'nemo.collections.asr.data.audio_to_text_dali._AudioTextDALIDataset'> is experimental, not ready for production and is not fully supported. Use at your own risk.


In [2]:
# Root directory that holds the evaluation datasets, given relative to the
# working directory; calculate_score reads its metadata and audio from here.
datasets_dir = '../../datasets/'

In [3]:
import pandas as pd
import numpy as np
import fastwer

def calculate_score(dataset, model, model_name='-', k=None, log=True):
    """Transcribe an evaluation set with ``model`` and return its mean WER and CER.

    Args:
        dataset: Name of the evaluation set, either ``'LJSpeech'`` or ``'AN4'``.
        model: Object exposing ``transcribe(paths2audio_files=...)`` that returns
            one predicted string per audio file.
        model_name: Label used only in the progress message.
        k: If given, only the first ``k`` metadata rows are evaluated.
        log: When true, print the final ``wer``/``cer`` line.

    Returns:
        Tuple ``(wer, cer)``: per-utterance ``fastwer`` sentence scores averaged
        over all utterances whose score is not ``+inf``, rounded to two decimals.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Fail fast: without this check an unknown name would only surface later
    # as an unrelated error because `files`/`texts` were never assigned.
    supported = ('LJSpeech', 'AN4')
    if dataset not in supported:
        raise ValueError(f'Unsupported dataset {dataset!r}; expected one of {supported}')

    print(f'Calculating score for model {model_name} on {dataset}')
    if dataset == 'LJSpeech':
        root = os.path.join(datasets_dir, 'LJSpeech-1.1')
        metadata = pd.read_csv(os.path.join(root, 'metadata_test.csv'))
        if k is not None:
            metadata = metadata[:k]
        # LJSpeech metadata stores bare file names; expand them to wav paths.
        files = metadata['file_name'].apply(
            lambda x: os.path.join(root, 'wavs', f'{x}.wav')
        ).values
    else:
        metadata = pd.read_csv(os.path.join(datasets_dir, 'an4', 'metadata.csv'))
        if k is not None:
            metadata = metadata[:k]
        # The AN4 metadata column is passed to the model unchanged.
        files = metadata['file_name'].values
    texts = metadata['transcript'].values

    predictions = model.transcribe(paths2audio_files=files)

    # Show one random reference/prediction pair as a sanity check. The index is
    # clamped to the number of predictions so small evaluation sets do not crash.
    if len(predictions) > 0:
        sample_idx = np.random.randint(0, min(10, len(predictions)))
        print(texts[sample_idx])
        print(predictions[sample_idx].replace('⁇', ''))

    # Normalise reference and prediction identically (lower-case, letters and
    # spaces only) so casing or punctuation is not counted as an error.
    non_letters = re.compile('[^a-zA-Z ]+')
    wer = []
    cer = []
    for reference, hypothesis in zip(texts, predictions):
        reference = non_letters.sub('', reference.lower())
        hypothesis = non_letters.sub('', hypothesis.lower())
        wer.append(fastwer.score_sent(reference, hypothesis, char_level=False))
        cer.append(fastwer.score_sent(reference, hypothesis, char_level=True))
    wer = np.array(wer)
    cer = np.array(cer)

    # Drop utterances whose score came back as +inf before averaging.
    wer = wer[wer != float('+inf')]
    cer = cer[cer != float('+inf')]

    wer = np.round(np.mean(wer), 2)
    cer = np.round(np.mean(cer), 2)
    if log:
        print(f'wer:{wer}; cer:{cer}')

    return wer, cer

In [4]:
from itertools import zip_longest
from typing import Any, Callable, Dict, List, Optional

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_info

import copy

class PrintTableMetricsCallback(Callback):
    """Callback that collects the trainer's metrics and logs them as a text table.

    Every time the ``on_epoch_end`` hook runs, a shallow copy of
    ``trainer.callback_metrics`` is stored as a new row; the full history is
    then rendered with ``dicts_to_table`` and emitted through ``rank_zero_info``.

    Usage:

    .. code-block:: python

        callback = PrintTableMetricsCallback()
        trainer = pl.Trainer(callbacks=[callback])
        trainer.fit(...)
        # printed after each hook call, one row per stored snapshot:
        # loss│train_loss│val_loss│epoch
        # ──────────────────────────────
        # 2.2541470527648926│2.2541470527648926│2.2158432006835938│0
    """

    def __init__(self) -> None:
        # Metric snapshots in the order they were recorded.
        self.metrics: List = []

    def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        # Copy so later updates to the trainer's dict do not rewrite stored rows.
        self.metrics.append(copy.copy(trainer.callback_metrics))
        table = dicts_to_table(self.metrics)
        rank_zero_info(table)
        
def dicts_to_table(
    dicts: List[Dict],
    keys: Optional[List[str]] = None,
    pads: Optional[List[str]] = None,
    fcodes: Optional[List[str]] = None,
    convert_headers: Optional[Dict[str, Callable]] = None,
    header_names: Optional[List[str]] = None,
    skip_none_lines: bool = False,
    replace_values: Optional[Dict[str, Any]] = None,
) -> str:
    """Render a list of dictionaries as a text table with one row per dictionary.

    Adapted from
    https://stackoverflow.com/questions/40056747/print-a-list-of-dictionaries-in-table-form

    Args:
        dicts: input dictionary list; empty lists make keys OR header_names mandatory
        keys: ordered list of keys to generate columns for; defaults to the keys of
            the first dictionary. No key should end with '____', because that
            suffix is used internally to mark converted columns.
        pads: per-column padding direction and size, e.g. ``<10`` to right-pad
            (left-align); must have one entry per key
        fcodes: per-column format codes, e.g. ``.3f``; must have one entry per key
        convert_headers: maps a column key to a converter; the converter is called
            with the whole row dictionary and its result is shown in that column
        header_names: custom column headers to print instead of the keys
        skip_none_lines: skip a row if any of its values is None
        replace_values: per column key, a map from seen value to new value; the
            new value must comply with the column's fcode. CAUTION: the input
            dictionaries are modified in place.

    Returns:
        The header line, an underline of the same length, and one line per
        emitted row, joined with newlines.

    Raises:
        ValueError: if neither keys nor header_names can be determined, if
            ``pads`` or ``fcodes`` do not have one entry per key, or if a row
            contains None that is neither skipped nor replaced.

    Example:
        >>> a = {'a': 1, 'b': 2}
        >>> b = {'a': 3, 'b': 4}
        >>> print(dicts_to_table([a, b]))
        a│b
        ───
        1│2
        3│4
    """
    # optional arg prelude
    if keys is None:
        if len(dicts) > 0:
            keys = list(dicts[0].keys())
        elif header_names is not None:
            keys = header_names
        else:
            raise ValueError("keys or header_names mandatory on empty input list")
    if pads is None:
        pads = [""] * len(keys)
    elif len(pads) != len(keys):
        raise ValueError(f"bad pad length {len(pads)}, expected: {len(keys)}")
    if fcodes is None:
        fcodes = [""] * len(keys)
    elif len(fcodes) != len(keys):
        # Compare against the number of columns; comparing fcodes with itself
        # could never fail and let mismatched lists through to str.format.
        raise ValueError(f"bad fcodes length {len(fcodes)}, expected: {len(keys)}")
    if convert_headers is None:
        convert_headers = {}
    if header_names is None:
        header_names = keys
    if replace_values is None:
        replace_values = {}
    # build header
    headline = "│".join(f"{v:{pad}}" for v, pad in zip_longest(header_names, pads))
    underline = "─" * len(headline)
    # suffix special keys to apply converters to later on
    marked_keys = [h + "____" if h in convert_headers else h for h in keys]
    marked_values = {}
    # row template such as "{a:<10.3f}│{b:}" that is filled once per dictionary
    s = "│".join(f"{{{h}:{pad}{fcode}}}" for h, pad, fcode in zip_longest(marked_keys, pads, fcodes))
    lines = [headline, underline]
    for d in dicts:
        none_keys = [k for k, v in d.items() if v is None]
        if skip_none_lines and none_keys:
            continue
        elif replace_values:
            for k in d.keys():
                if k in replace_values and d[k] in replace_values[k]:
                    d[k] = replace_values[k][d[k]]
                if d[k] is None:
                    raise ValueError(f"bad or no mapping for key '{k}' is None. Use skip or change replace mapping.")
        elif none_keys:
            raise ValueError(f"keys {none_keys} are None in {d}. Do skip or use replace mapping.")
        for h in convert_headers:
            if h in keys:
                converter = convert_headers[h]
                # converters receive the full row, not just the column value
                marked_values[h + "____"] = converter(d)
        line = s.format(**d, **marked_values)
        lines.append(line)
    return "\n".join(lines)

In [5]:
# Import YAML from ``ruamel.yaml``; if that module is not installed, fall back
# to the ``ruamel_yaml`` module name.
try:
    from ruamel.yaml import YAML
except ModuleNotFoundError:
    from ruamel_yaml import YAML

# Directory and file name of the model configuration to start from.
config_path = 'stt_en_citrinet_256_gamma_0_25'
config_name = 'model_config.yaml'

# Parse the configuration with the safe loader.
yaml = YAML(typ='safe')
with open(os.path.join(config_path, config_name)) as config_file:
    config = yaml.load(config_file)

In [6]:
# Tokenizer location and type.
config['tokenizer'].update({
    'dir': 'citrinet_tokenizer/tokenizer_spe_unigram_v1024',
    'type': 'bpe',
})

# Training data: LJSpeech train manifest, one utterance per batch.
config['train_ds'].update({
    'manifest_filepath': "../../datasets/LJSpeech-1.1/train_manifest.json",
    'batch_size': 1,
    'num_workers': 12,
    'pin_memory': True,
})

# Validation data: LJSpeech test manifest with the same loader settings.
config['validation_ds'].update({
    'manifest_filepath': "../../datasets/LJSpeech-1.1/test_manifest.json",
    'batch_size': 1,
    'num_workers': 12,
    'pin_memory': True,
})

# Zero frequency and time masks for spec_augment.
config['spec_augment'].update({
    'freq_masks': 0,
    'time_masks': 0,
})

# Optimizer settings.
config['optim'].update({
    'lr': 0.01,
    'name': 'novograd',
    'betas': [0.8, 0.25],
    'weight_decay': 0.001,
})
# NOTE(review): warmup_steps (1000) is larger than the max_steps=10 passed to
# the Trainer in the training cell - confirm this is intended.
config['optim']['sched'].update({
    'warmup_steps': 1000,
    'min_lr': 0.00001,
})

# Tokenizer artifacts stored next to the model configuration.
config['tokenizer'].update({
    'model_path': 'stt_en_citrinet_256_gamma_0_25/3d20ebb793c84a64a20c7ad26fc64d62_tokenizer.model',
    'vocab_path': 'stt_en_citrinet_256_gamma_0_25/df5191f216004f10a268c44e90fdb63f_vocab.txt',
    'spe_tokenizer_vocab': 'stt_en_citrinet_256_gamma_0_25/b774eaac83804907843607272fde21a4_tokenizer.vocab',
})

In [10]:
def enable_bn_se(m):
    """Re-enable training for batch-norm and squeeze-excite sub-modules.

    Meant to be passed to ``encoder.apply(...)`` after ``encoder.freeze()`` so
    that only these layers keep training while the rest stays frozen.

    Args:
        m: A single sub-module visited by ``Module.apply``.
    """
    import torch  # local import: torch is not imported at the top of the notebook

    if isinstance(m, torch.nn.BatchNorm1d) or 'SqueezeExcite' in type(m).__name__:
        m.train()
        for param in m.parameters():
            param.requires_grad_(True)


callback = PrintTableMetricsCallback()

# The metrics-table callback has to be registered with the trainer; creating
# it alone does not print anything.
trainer = pl.Trainer(
    gpus=1,
    max_steps=10,
    precision=32,
    sync_batchnorm=False,
    benchmark=False,
    callbacks=[callback],
)

model_name = 'stt_en_citrinet_256_gamma_0_25'
freeze_encoder = False

# Load pretrained checkpoint
checkpoint = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(
    model_name
)

# Preserve the model's decoder weights
decoder_ckpt_copy = checkpoint.decoder.state_dict()

# Build the finetuning model from the edited configuration
asr_model = nemo_asr.models.EncDecCTCModelBPE(cfg=DictConfig(config), trainer=trainer)

# Load up weights or not; with False the model trains from its fresh initialisation
load_weights = False
if load_weights:
    # this allows decoder weights to be loaded if same shape as original citrinet (1024 subword encodings)
    asr_model.load_state_dict(checkpoint.state_dict(), strict=False)

    # Insert preserved model weights if shapes match
    if decoder_ckpt_copy['decoder_layers.0.weight'].shape == asr_model.decoder.decoder_layers[0].weight.shape:
        asr_model.decoder.load_state_dict(decoder_ckpt_copy)

# Drop the reference to the pretrained checkpoint
del checkpoint

# If freezing the encoder, unfreeze the batch norm and the squeeze and excite blocks
# for transfer learning
if freeze_encoder:
    asr_model.encoder.freeze()
    asr_model.encoder.apply(enable_bn_se)

# Train model
trainer.fit(asr_model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


[NeMo I 2021-11-06 00:15:06 cloud:56] Found existing object /home/boris/.cache/torch/NeMo/NeMo_1.4.0/stt_en_citrinet_256_gamma_0_25/d6eff3868f2f7a4791eb935c8366fc46/stt_en_citrinet_256_gamma_0_25.nemo.
[NeMo I 2021-11-06 00:15:06 cloud:62] Re-using file from: /home/boris/.cache/torch/NeMo/NeMo_1.4.0/stt_en_citrinet_256_gamma_0_25/d6eff3868f2f7a4791eb935c8366fc46/stt_en_citrinet_256_gamma_0_25.nemo
[NeMo I 2021-11-06 00:15:06 common:702] Instantiating model from pre-trained checkpoint
[NeMo I 2021-11-06 00:15:06 mixins:149] Tokenizer SentencePieceTokenizer initialized with 1024 tokens


[NeMo W 2021-11-06 00:15:07 modelPT:131] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    manifest_filepath: null
    sample_rate: 16000
    batch_size: 32
    trim_silence: false
    max_duration: 20.0
    shuffle: true
    is_tarred: false
    tarred_audio_filepaths: null
    use_start_end_token: false
    
[NeMo W 2021-11-06 00:15:07 modelPT:138] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s). 
    Validation config : 
    manifest_filepath: null
    sample_rate: 16000
    batch_size: 32
    shuffle: false
    use_start_end_token: false
    
[NeMo W 2021-11-06 00:15:07 modelPT:144] Please call the ModelPT.setup_test_data() or ModelPT.setup_multiple_test_data() method and provide a v

[NeMo I 2021-11-06 00:15:07 features:262] PADDING: 16
[NeMo I 2021-11-06 00:15:07 features:279] STFT using torch
[NeMo I 2021-11-06 00:15:08 save_restore_connector:143] Model EncDecCTCModelBPE was successfully restored from /home/boris/.cache/torch/NeMo/NeMo_1.4.0/stt_en_citrinet_256_gamma_0_25/d6eff3868f2f7a4791eb935c8366fc46/stt_en_citrinet_256_gamma_0_25.nemo.
[NeMo I 2021-11-06 00:15:08 mixins:149] Tokenizer SentencePieceTokenizer initialized with 1024 tokens
[NeMo I 2021-11-06 00:15:09 collections:173] Dataset loaded with 10480 files totalling 19.11 hours
[NeMo I 2021-11-06 00:15:09 collections:174] 0 files were filtered totalling 0.00 hours
[NeMo I 2021-11-06 00:15:09 collections:173] Dataset loaded with 2620 files totalling 4.81 hours
[NeMo I 2021-11-06 00:15:09 collections:174] 0 files were filtered totalling 0.00 hours


[NeMo W 2021-11-06 00:15:09 ctc_bpe_models:235] Could not load dataset as `manifest_filepath` was None. Provided config : {'manifest_filepath': None, 'sample_rate': 16000, 'batch_size': 4, 'shuffle': False, 'use_start_end_token': False}


[NeMo I 2021-11-06 00:15:09 features:262] PADDING: 16
[NeMo I 2021-11-06 00:15:09 features:279] STFT using torch


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


[NeMo I 2021-11-06 00:15:09 modelPT:544] Optimizer config = Novograd (
    Parameter Group 0
        amsgrad: False
        betas: [0.8, 0.25]
        eps: 1e-08
        grad_averaging: False
        lr: 0.01
        weight_decay: 0.001
    )
[NeMo I 2021-11-06 00:15:09 lr_scheduler:625] Scheduler "<nemo.core.optim.lr_scheduler.CosineAnnealing object at 0x7f65a64d5a50>" 
    will be used during training (effective maximum steps = 10) - 
    Parameters : 
    (warmup_steps: 1000
    warmup_ratio: null
    min_lr: 1.0e-05
    last_epoch: -1
    max_steps: 10
    )



  | Name              | Type                              | Params
------------------------------------------------------------------------
0 | preprocessor      | AudioToMelSpectrogramPreprocessor | 0     
1 | encoder           | ConvASREncoder                    | 9.1 M 
2 | decoder           | ConvASRDecoder                    | 657 K 
3 | loss              | CTCLoss                           | 0     
4 | spec_augmentation | SpectrogramAugmentation           | 0     
5 | _wer              | WERBPE                            | 0     
------------------------------------------------------------------------
9.8 M     Trainable params
0         Non-trainable params
9.8 M     Total params
39.145    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: -1it [00:00, ?it/s]

In [18]:
# Display the optimizer(s) held by the trainer, including the current learning rate.
# NOTE(review): the lr shown below (~9.99e-06) equals 0.01 / 1001, which looks
# like the first step of the 1000-step warmup - confirm that is expected.
trainer.optimizers

[Novograd (
 Parameter Group 0
     amsgrad: False
     betas: [0.8, 0.25]
     eps: 1e-08
     grad_averaging: False
     initial_lr: 0.01
     lr: 9.99000999000999e-06
     weight_decay: 0.001
 )]