In [None]:
# Raw experiment configuration; wrapped into a tensorflow_asr Config object later.
# Built with the keyword-argument dict() constructor: insertion order follows the
# keyword order, so the resulting mapping is identical to a literal with the same keys.
config = dict(
    # Audio front-end settings (given to TFSpeechFeaturizer).
    speech_config=dict(
        sample_rate=16000,
        frame_ms=25,
        stride_ms=10,
        num_feature_bins=80,
        feature_type="log_mel_spectrogram",
        preemphasis=0.97,
        normalize_signal=True,
        normalize_feature=True,
        normalize_per_frame=False,
    ),
    # Text / decoding settings (given to CharFeaturizer).
    decoder_config=dict(
        vocabulary=None,
        target_vocab_size=1024,
        max_subword_length=4,
        blank_at_zero=True,
        beam_width=5,
        norm_score=True,
    ),
    # Keyword arguments for RnnTransducer; vocabulary_size is supplied separately.
    model_config=dict(
        name="streaming_transducer",
        encoder_reductions={0: 3, 1: 2},
        encoder_dmodel=320,
        encoder_rnn_type="lstm",
        encoder_rnn_units=1024,
        encoder_nlayers=8,
        encoder_layer_norm=True,
        prediction_embed_dim=320,
        prediction_embed_dropout=0.0,
        prediction_num_rnns=2,
        prediction_rnn_units=1024,
        prediction_rnn_type="lstm",
        prediction_projection_units=320,
        prediction_layer_norm=True,
        joint_dim=320,
        joint_activation="tanh",
    ),
    learning_config=dict(
        train_dataset_config=dict(
            use_tf=True,
            augmentation_config=dict(
                feature_augment=dict(
                    time_masking=dict(num_masks=10, mask_factor=100, p_upperbound=0.05),
                    freq_masking=dict(num_masks=1, mask_factor=27),
                ),
            ),
            data_paths=[
                "/mnt/h/ML/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv",
            ],
            tfrecords_dir=None,
            shuffle=True,
            cache=True,
            buffer_size=100,
            drop_remainder=True,
            stage="train",
        ),
        # NOTE(review): no data_paths here, so there is presumably no validation
        # data during fit() until this is filled in; confirm that is intended.
        eval_dataset_config=dict(
            use_tf=True,
            data_paths=None,
            tfrecords_dir=None,
            shuffle=False,
            cache=True,
            buffer_size=100,
            drop_remainder=True,
            stage="eval",
        ),
        test_dataset_config=dict(
            use_tf=True,
            data_paths=[
                "/mnt/h/ML/Datasets/ASR/Raw/LibriSpeech/test-clean/transcripts.tsv",
            ],
            tfrecords_dir=None,
            shuffle=False,
            cache=True,
            buffer_size=100,
            drop_remainder=True,
            stage="test",
        ),
        optimizer_config=dict(class_name="adam", config=dict(learning_rate=0.0001)),
        running_config=dict(
            batch_size=2,
            num_epochs=20,
            # Unpacked into tf.keras.callbacks.ModelCheckpoint.
            checkpoint=dict(
                filepath="/mnt/e/Models/local/rnn_transducer/checkpoints/{epoch:02d}.h5",
                save_best_only=True,
                save_weights_only=True,
                save_freq="epoch",
            ),
            states_dir="/mnt/e/Models/local/rnn_transducer/states",
            # Unpacked into tf.keras.callbacks.TensorBoard.
            tensorboard=dict(
                log_dir="/mnt/e/Models/local/rnn_transducer/tensorboard",
                histogram_freq=1,
                write_graph=True,
                write_images=True,
                update_freq="epoch",
                profile_batch=2,
            ),
        ),
    ),
)


In [None]:
# Per-split dataset statistics, keyed by stage name, handed to load_metadata().
# NOTE(review): 281241 / 5567 entries look like the full 960h LibriSpeech train set
# and dev-clean + dev-other, but the config trains on train-clean-100 only; confirm.
metadata = dict(
    train=dict(max_input_length=2974, max_label_length=194, num_entries=281241),
    eval=dict(max_input_length=3516, max_label_length=186, num_entries=5567),
)

In [None]:
import os
import math
import argparse
from tensorflow_asr.utils import env_util

# NOTE(review): os, math, argparse and TransformerSchedule are not used in this cell.

# setup_environment() runs before TensorFlow is imported; keep that ordering
# (presumably it sets environment/logging options that TF reads at import time).
env_util.setup_environment()
import tensorflow as tf

tf.keras.backend.clear_session()
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})
# NOTE(review): [0] presumably selects device index 0; confirm against env_util.
strategy = env_util.setup_strategy([0])

from tensorflow_asr.configs.config import Config
from tensorflow_asr.datasets import asr_dataset
from tensorflow_asr.featurizers import speech_featurizers, text_featurizers
from tensorflow_asr.models.transducer.rnn_transducer import RnnTransducer
from tensorflow_asr.optimizers.schedules import TransformerSchedule

# Wrap the raw config only once. Without this guard, re-running the cell would try
# to build a Config from an already-built Config object.
if not isinstance(config, Config):
    config = Config(config)

# --- Featurizers and datasets ------------------------------------------------
speech_featurizer = speech_featurizers.TFSpeechFeaturizer(config.speech_config)
text_featurizer = text_featurizers.CharFeaturizer(config.decoder_config)

train_dataset = asr_dataset.ASRSliceDataset(
    speech_featurizer=speech_featurizer,
    text_featurizer=text_featurizer,
    **vars(config.learning_config.train_dataset_config),
    indefinite=True
)
eval_dataset = asr_dataset.ASRSliceDataset(
    speech_featurizer=speech_featurizer,
    text_featurizer=text_featurizer,
    **vars(config.learning_config.eval_dataset_config),
    indefinite=True
)

train_dataset.load_metadata(metadata)
eval_dataset.load_metadata(metadata)
# NOTE(review): reset_length() right after load_metadata() presumably discards the
# loaded max lengths (dynamic shapes); confirm that this is intended.
speech_featurizer.reset_length()
text_featurizer.reset_length()

# The configured batch size is per replica; scale it to the global batch.
global_batch_size = config.learning_config.running_config.batch_size
global_batch_size *= strategy.num_replicas_in_sync

train_data_loader = train_dataset.create(global_batch_size)
eval_data_loader = eval_dataset.create(global_batch_size)

# The eval split may have no data at all (e.g. data_paths=None).
# NOTE(review): assumes create() returns None for a split with no entries; confirm.
has_validation = eval_data_loader is not None

# --- Model ---------------------------------------------------------------------
with strategy.scope():
    # build model
    rnnt = RnnTransducer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
    rnnt.make(speech_featurizer.shape)
    rnnt.summary(line_length=100)

    rnnt.compile(
        optimizer=config.learning_config.optimizer_config,
        experimental_steps_per_execution=10,
        global_batch_size=global_batch_size,
        blank=text_featurizer.blank
    )

# --- Training --------------------------------------------------------------------
# Copy the checkpoint kwargs so the shared config object is not mutated.
checkpoint_kwargs = dict(config.learning_config.running_config.checkpoint)
if not has_validation and checkpoint_kwargs.get("save_best_only"):
    # ModelCheckpoint monitors "val_loss" by default. Without validation data that
    # metric is never logged, so Keras skips every save. Track the training loss
    # instead, unless a monitor was configured explicitly.
    checkpoint_kwargs.setdefault("monitor", "loss")

callbacks = [
    tf.keras.callbacks.ModelCheckpoint(**checkpoint_kwargs),
    tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),
    tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard),
]

rnnt.fit(
    train_data_loader,
    epochs=config.learning_config.running_config.num_epochs,
    validation_data=eval_data_loader,
    callbacks=callbacks,
    steps_per_epoch=train_dataset.total_steps,
    validation_steps=eval_dataset.total_steps if has_validation else None,
)