In [None]:
!pip install nemo_toolkit['all']
!pip install hydra-core==1.1

In [3]:
import json
import random
from typing import Any, List, Optional, Union

import torch
from omegaconf import DictConfig, open_dict
from omegaconf.listconfig import ListConfig
from pytorch_lightning.callbacks import BasePredictionWriter
from torch.utils.data import ChainDataset

import modified_audio_to_text, modified_audio_to_text_dali
from nemo.utils import logging

In [None]:
def get_bpe_dataset(
    config: dict, tokenizer: 'TokenizerSpec', augmentor: Optional['AudioAugmentor'] = None
) -> modified_audio_to_text.AudioToBPEDataset:
    """
    Builds a subword (BPE / Word Piece) AudioToBPEDataset from a config mapping.

    Required config keys: 'manifest_filepath' and 'sample_rate'.
    Optional config keys (with the fallback used when absent): 'int_values' (False),
    'max_duration' (None), 'min_duration' (None), 'max_utts' (0), 'trim_silence' (False),
    'use_start_end_token' (True) and 'return_sample_id' (False).

    Args:
        config: Config of the AudioToBPEDataset.
        tokenizer: An instance of a TokenizerSpec object, forwarded to the dataset.
        augmentor: Optional AudioAugmentor object, forwarded to the dataset.

    Returns:
        An instance of AudioToBPEDataset.

    Raises:
        KeyError: If a required key is missing from a plain-dict config.
    """
    # Required entries are read first so a missing key fails before anything else happens.
    manifest_filepath = config['manifest_filepath']
    sample_rate = config['sample_rate']

    # (constructor keyword, config key, fallback value)
    optional_settings = (
        ('int_values', 'int_values', False),
        ('max_duration', 'max_duration', None),
        ('min_duration', 'min_duration', None),
        ('max_utts', 'max_utts', 0),
        ('trim', 'trim_silence', False),
        ('use_start_end_token', 'use_start_end_token', True),
        ('return_sample_id', 'return_sample_id', False),
    )
    optional_kwargs = {
        keyword: config.get(config_key, fallback) for keyword, config_key, fallback in optional_settings
    }

    return modified_audio_to_text.AudioToBPEDataset(
        manifest_filepath=manifest_filepath,
        tokenizer=tokenizer,
        sample_rate=sample_rate,
        augmentor=augmentor,
        **optional_kwargs,
    )

In [None]:
def get_dali_bpe_dataset(
    config: dict,
    tokenizer,
    shuffle: bool,
    device_id: int,
    global_rank: int,
    world_size: int,
    preprocessor_cfg: Optional[DictConfig] = None,
) -> modified_audio_to_text_dali.AudioToBPEDALIDataset:
    """
    Instantiates a Subword Encoding based AudioToBPEDALIDataset.

    The DALI device is chosen automatically: 'gpu' when ``torch.cuda.is_available()``
    is true, otherwise 'cpu'.

    Required config keys: 'manifest_filepath', 'batch_size' and 'sample_rate'.
    Optional config keys (with the fallback used when absent): 'tarred_audio_filepaths' (None),
    'tarred_audio_index_filepaths' (None), 'max_duration' (None), 'min_duration' (None),
    'trim_silence' (False), 'use_start_end_token' (True), 'tarred_shard_strategy' ('scatter')
    and 'return_sample_id' (False).

    Args:
        config: Config of the AudioToBPEDALIDataset.
        tokenizer: An implementation of NeMo TokenizerSpec.
        shuffle: Bool flag whether to shuffle the dataset.
        device_id: Index of the GPU to be used (local_rank). Only applicable when device == 'gpu'.
        global_rank: Global rank of this device.
        world_size: Global world size in the training method.
        preprocessor_cfg: Preprocessor configuration. Supports AudioToMelSpectrogramPreprocessor and AudioToMFCCPreprocessor.

    Returns:
        An instance of AudioToBPEDALIDataset.
    """
    device = 'gpu' if torch.cuda.is_available() else 'cpu'
    return modified_audio_to_text_dali.AudioToBPEDALIDataset(
        manifest_filepath=config['manifest_filepath'],
        tokenizer=tokenizer,
        device=device,
        batch_size=config['batch_size'],
        sample_rate=config['sample_rate'],
        audio_tar_filepaths=config.get('tarred_audio_filepaths', None),
        audio_tar_index_filepaths=config.get('tarred_audio_index_filepaths', None),
        max_duration=config.get('max_duration', None),
        min_duration=config.get('min_duration', None),
        trim=config.get('trim_silence', False),
        use_start_end_token=config.get('use_start_end_token', True),
        shuffle=shuffle,
        shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
        device_id=device_id,
        global_rank=global_rank,
        world_size=world_size,
        preprocessor_cfg=preprocessor_cfg,
        return_sample_id=config.get('return_sample_id', False),
    )

In [None]:
def get_tarred_dataset(
    config: dict,
    shuffle_n: int,
    global_rank: int,
    world_size: int,
    tokenizer: Optional['TokenizerSpec'] = None,
    augmentor: Optional['AudioAugmentor'] = None,
) -> Union[modified_audio_to_text.TarredAudioToBPEDataset, modified_audio_to_text.TarredAudioToCharDataset]:
    """
    Builds one tarred dataset per bucket and combines the buckets via ``get_chain_dataset``.

    A TarredAudioToBPEDataset is created for every bucket when a tokenizer is given;
    otherwise a char based TarredAudioToCharDataset is created.

    Args:
        config: Config of the TarredAudioToBPEDataset or TarredAudioToCharDataset. Must contain
            'tarred_audio_filepaths', 'manifest_filepath' and 'sample_rate'.
        shuffle_n: How many samples to look ahead and load to be shuffled.
            See WebDataset documentation for more details.
        global_rank: Global rank of this device.
        world_size: Global world size in the training method.
        tokenizer: An instance of a TokenizerSpec object if BPE dataset is needed.
            Passing None would return a char-based dataset.
        augmentor: Optional AudioAugmentor object for augmentations on audio data.

    Returns:
        Whatever ``get_chain_dataset`` returns for the created buckets: the single dataset
        when there is one bucket, otherwise a chained dataset over all buckets.

    Raises:
        ValueError: If the number of manifest buckets differs from the number of tarred audio buckets.
    """
    # Both raw entries are read before normalization so a missing key surfaces first.
    raw_audio_paths = config['tarred_audio_filepaths']
    raw_manifest_paths = config['manifest_filepath']
    audio_buckets = convert_to_config_list(raw_audio_paths)
    manifest_buckets = convert_to_config_list(raw_manifest_paths)

    if len(manifest_buckets) != len(audio_buckets):
        raise ValueError(
            f"manifest_filepaths (length={len(manifest_buckets)}) and tarred_audio_filepaths (length={len(audio_buckets)}) need to have the same number of buckets."
        )

    if 'labels' not in config:
        logging.warning("dataset does not have explicitly defined labels")

    bucket_datasets = []
    for audio_paths, manifest_path in zip(audio_buckets, manifest_buckets):
        # A bucket holding a single tar path is passed on as that path rather than a one-item list.
        if len(audio_paths) == 1:
            audio_paths = audio_paths[0]

        # Arguments common to the char based and the BPE based dataset.
        shared_kwargs = dict(
            audio_tar_filepaths=audio_paths,
            manifest_filepath=manifest_path,
            sample_rate=config['sample_rate'],
            int_values=config.get('int_values', False),
            augmentor=augmentor,
            shuffle_n=shuffle_n,
            max_duration=config.get('max_duration', None),
            min_duration=config.get('min_duration', None),
            max_utts=config.get('max_utts', 0),
            trim=config.get('trim_silence', False),
            shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
            global_rank=global_rank,
            world_size=world_size,
            return_sample_id=config.get('return_sample_id', False),
        )

        if tokenizer is not None:
            bucket = modified_audio_to_text.TarredAudioToBPEDataset(
                tokenizer=tokenizer,
                use_start_end_token=config.get('use_start_end_token', True),
                **shared_kwargs,
            )
        else:
            bucket = modified_audio_to_text.TarredAudioToCharDataset(
                labels=config.get('labels', None),
                blank_index=config.get('blank_index', -1),
                unk_index=config.get('unk_index', -1),
                normalize=config.get('normalize_transcripts', False),
                parser=config.get('parser', 'en'),
                **shared_kwargs,
            )

        bucket_datasets.append(bucket)

    return get_chain_dataset(datasets=bucket_datasets, ds_config=config)

In [None]:
def convert_to_config_list(initial_list):
    """
    Normalizes a manifest / tarred-audio path specification into a ListConfig of buckets.

    Accepted inputs:
        - a comma separated string, e.g. "a,b": split on "," into a single bucket;
        - a flat list, tuple or ListConfig of paths: treated as a single bucket;
        - a list, tuple or ListConfig whose elements are themselves lists: one bucket per element;
        - any other single value: wrapped as the only entry of a single bucket.

    Plain Python lists and tuples are converted to a ListConfig directly rather than being
    wrapped in an extra level, so the element type check below applies to them and a
    plain list of lists keeps one bucket per inner list.

    Args:
        initial_list: The path specification described above.

    Returns:
        A ListConfig whose first level enumerates the buckets.

    Raises:
        ValueError: If the specification is None or empty, or if its elements are not all
            of the same type (e.g. strings mixed with lists).
    """
    if type(initial_list) is str:
        initial_list = initial_list.split(",")

    is_sequence = isinstance(initial_list, (list, tuple, ListConfig))
    if initial_list is None or (is_sequence and len(initial_list) == 0):
        raise ValueError("manifest_filepaths and tarred_audio_filepaths must not be empty.")

    if isinstance(initial_list, (list, tuple)):
        # Convert rather than wrap: wrapping would hide the elements from the type check
        # and collapse a list of lists into a single bucket.
        initial_list = ListConfig(list(initial_list))
    elif not isinstance(initial_list, ListConfig):
        initial_list = ListConfig([initial_list])

    first_type = type(initial_list[0])
    if any(type(list_val) is not first_type for list_val in initial_list):
        raise ValueError(
            "manifest_filepaths and tarred_audio_filepaths need to be a list of lists for bucketing or just a list of strings"
        )
    # A flat list of paths is a single bucket; nest it so callers always iterate over buckets.
    if first_type is not ListConfig:
        initial_list = ListConfig([initial_list])
    return initial_list


def get_chain_dataset(datasets, ds_config):
    """
    Combines per-bucket datasets into one dataset according to the bucketing configuration.

    A single dataset is returned unchanged. With several datasets and 'bucketing_batch_size'
    set in the config, every entry of ``datasets`` is replaced in place by a BucketingDataset
    using the batch sizes from ``calc_bucketing_batch_sizes``. The buckets are then chained
    according to 'bucketing_strategy' (default 'synced_randomized'):
        - 'fixed_order': ChainDataset over the buckets;
        - 'synced_randomized': RandomizedChainDataset with rnd_seed=0;
        - 'fully_randomized': RandomizedChainDataset with a seed drawn from random.randint(0, 30000).

    Args:
        datasets: List of bucket datasets; may be mutated in place (see above).
        ds_config: Dataset config mapping providing 'bucketing_batch_size', 'batch_size'
            and 'bucketing_strategy'.

    Returns:
        The single dataset, or the chained dataset for the selected strategy.

    Raises:
        ValueError: If 'bucketing_strategy' is not one of the supported values.
    """
    bucket_count = len(datasets)
    if bucket_count == 1:
        return datasets[0]

    if bucket_count > 1:
        if ds_config.get('bucketing_batch_size', None) is None:
            logging.info(
                f"Batch bucketing is enabled for {bucket_count} buckets with fixed batch size of {ds_config['batch_size']}!"
            )
        else:
            per_bucket_sizes = calc_bucketing_batch_sizes(ds_config, bucket_count)
            logging.info(
                f"Batch bucketing is enabled for {bucket_count} buckets with adaptive batch sizes of {per_bucket_sizes}!"
            )
            # Replace the entries one by one so the caller's list object is updated in place.
            for position in range(bucket_count):
                datasets[position] = modified_audio_to_text.BucketingDataset(
                    dataset=datasets[position], bucketing_batch_size=per_bucket_sizes[position]
                )

    strategy = ds_config.get('bucketing_strategy', 'synced_randomized')
    if strategy == 'fixed_order':
        return ChainDataset(datasets)

    if strategy == 'synced_randomized':
        chain_seed = 0
    elif strategy == 'fully_randomized':
        chain_seed = random.randint(0, 30000)
    else:
        raise ValueError(
            f'bucketing_strategy={strategy} is not supported! Supported strategies are [fixed_order, fully_randomized, synced_randomized].'
        )
    return modified_audio_to_text.RandomizedChainDataset(datasets=datasets, rnd_seed=chain_seed)


def calc_bucketing_batch_sizes(ds_config, datasets_len):
    """
    Computes the batch size to use for every bucket when adaptive bucketing is enabled.

    When 'bucketing_batch_size' is an integer, bucket ``idx`` gets
    ``(datasets_len - idx) * bucketing_batch_size``, i.e. the first bucket receives the
    largest batch size and the last bucket receives ``bucketing_batch_size`` itself.
    When it is a list (or ListConfig), that object is returned as-is after its length
    has been validated against the number of buckets.

    Args:
        ds_config: Dataset config mapping; must contain 'bucketing_batch_size' and 'batch_size'.
        datasets_len: Number of buckets.

    Returns:
        A sequence with one batch size per bucket.

    Raises:
        ValueError: If 'batch_size' is not 1, if 'bucketing_batch_size' is neither an integer
            nor a list, or if a list of batch sizes does not have one entry per bucket.
    """
    bucketing_batch_size = ds_config['bucketing_batch_size']
    if ds_config['batch_size'] != 1:
        raise ValueError(
            f"batch_size should be set to one when bucketing_batch_size is set and adaptive bucketing is enabled (batch_size={ds_config['batch_size']})!"
        )
    # Exact type check on purpose: bool is a subclass of int and must not be accepted here.
    if type(bucketing_batch_size) is int:
        bucketing_batch_sizes = [
            (datasets_len - idx) * bucketing_batch_size for idx in range(datasets_len)
        ]
    elif isinstance(bucketing_batch_size, list) or isinstance(bucketing_batch_size, ListConfig):
        bucketing_batch_sizes = bucketing_batch_size
    else:
        raise ValueError(
            f"bucketing_batch_size should be an integer or a list (bucketing_batch_size={bucketing_batch_size})!"
        )

    if len(bucketing_batch_sizes) != datasets_len:
        raise ValueError(
            f"bucketing_batch_size should have the same length as the number of buckets ({len(bucketing_batch_sizes)}!={datasets_len})"
        )
    return bucketing_batch_sizes