In [1]:
import os
import sys

# Local TensorFlowASR checkout providing the `tensorflow_asr` package imported by
# the training cell. Set the TFASR_ROOT environment variable to use another
# checkout instead of editing this path.
TFASR_ROOT = os.environ.get("TFASR_ROOT", "/mydata/hassan/TensorFlowASR")

# Append only once so re-running this cell does not stack duplicate entries.
if TFASR_ROOT not in sys.path:
    sys.path.append(TFASR_ROOT)

In [5]:
# Filesystem roots used by the config below: the LibriSpeech transcripts and
# this run's outputs (checkpoints, BackupAndRestore state, TensorBoard logs).
# Change them here rather than inside the nested dict.
LIBRISPEECH_DIR = "/mydata/hassan/data/LibriSpeech"
EXPERIMENT_DIR = "/mydata/hassan/conformer2"

# Raw experiment configuration; the training cell wraps it in tensorflow_asr's
# Config and reads each section from there.
config = {
    # Passed to TFSpeechFeaturizer in the training cell.
    "speech_config": {
        "sample_rate": 16000,
        "frame_ms": 25,
        "stride_ms": 10,
        "num_feature_bins": 80,
        "feature_type": "log_mel_spectrogram",
        "preemphasis": 0.97,
        "normalize_signal": True,
        "normalize_feature": True,
        "normalize_per_frame": False,
    },
    # Passed to CharFeaturizer in the training cell.
    # NOTE(review): target_vocab_size / max_subword_length look like subword
    # settings; with a character featurizer they are presumably unused — confirm.
    "decoder_config": {
        "vocabulary": None,
        "target_vocab_size": 1000,
        "max_subword_length": 10,
        "blank_at_zero": True,
        "beam_width": 0,
        "norm_score": True,
        "corpus_files": None,
    },
    # Unpacked into Conformer(**model_config, vocabulary_size=...) in the training
    # cell. encoder_dmodel (144) equals encoder_head_size (36) * encoder_num_heads (4).
    "model_config": {
        "name": "conformer",
        "encoder_subsampling": {
            "type": "conv2d",
            "filters": 144,
            "kernel_size": 3,
            "strides": 2,
        },
        "encoder_positional_encoding": "sinusoid_concat",
        "encoder_dmodel": 144,
        "encoder_num_blocks": 16,
        "encoder_head_size": 36,
        "encoder_num_heads": 4,
        "encoder_mha_type": "relmha",
        "encoder_kernel_size": 32,
        "encoder_fc_factor": 0.5,
        "encoder_dropout": 0.1,
        "prediction_embed_dim": 320,
        "prediction_embed_dropout": 0,
        "prediction_num_rnns": 1,
        "prediction_rnn_units": 320,
        "prediction_rnn_type": "lstm",
        "prediction_rnn_implementation": 2,
        "prediction_layer_norm": True,
        "prediction_projection_units": 0,
        "joint_dim": 320,
        "prejoint_linear": True,
        "joint_activation": "tanh",
        "joint_mode": "add",
    },
    "learning_config": {
        # Each *_dataset_config section is forwarded as keyword arguments to
        # ASRSliceDataset in the training cell.
        "train_dataset_config": {
            "use_tf": True,
            "augmentation_config": {
                "feature_augment": {
                    "time_masking": {
                        "num_masks": 10,
                        "mask_factor": 100,
                        "p_upperbound": 0.05,
                    },
                    "freq_masking": {"num_masks": 1, "mask_factor": 27},
                }
            },
            "data_paths": [LIBRISPEECH_DIR + "/train-clean-100/transcript.tsv"],
            "tfrecords_dir": None,
            "shuffle": True,
            "cache": True,
            "buffer_size": 100,
            "drop_remainder": True,
            "stage": "train",
        },
        # NOTE(review): the training cell passes the eval dataset to
        # fit(validation_data=...), and the metadata cell has "eval" statistics,
        # but no transcript file is listed here. TODO: set data_paths to the dev
        # transcripts.
        "eval_dataset_config": {
            "use_tf": True,
            "data_paths": None,
            "tfrecords_dir": None,
            "shuffle": False,
            "cache": True,
            "buffer_size": 100,
            "drop_remainder": True,
            "stage": "eval",
        },
        "test_dataset_config": {
            "use_tf": True,
            "data_paths": None,
            "tfrecords_dir": None,
            "shuffle": False,
            "cache": True,
            "buffer_size": 100,
            "drop_remainder": True,
            "stage": "test",
        },
        # The training cell takes "warmup_steps" for TransformerSchedule; the
        # remaining keys are passed to tf.keras.optimizers.Adam.
        "optimizer_config": {
            "warmup_steps": 40000,
            "beta_1": 0.9,
            "beta_2": 0.98,
            "epsilon": 1e-09,
        },
        "running_config": {
            # Per-replica; the training cell multiplies it by the replica count.
            "batch_size": 4,
            "num_epochs": 15,
            # ModelCheckpoint kwargs; Keras fills "{epoch:02d}" in the filepath.
            "checkpoint": {
                "filepath": EXPERIMENT_DIR + "/checkpoints/{epoch:02d}.h5",
                "save_best_only": False,
                "save_weights_only": True,
                "save_freq": "epoch",
            },
            # Directory handed to BackupAndRestore.
            "states_dir": EXPERIMENT_DIR + "/states",
            # TensorBoard callback kwargs.
            "tensorboard": {
                "log_dir": EXPERIMENT_DIR + "/tensorboard",
                "histogram_freq": 1,
                "write_graph": True,
                "write_images": True,
                "update_freq": "epoch",
                "profile_batch": 2,
            },
        },
    },
}

In [6]:
# Per-split dataset statistics handed to `load_metadata` in the training cell.
# The split names match the "stage" values in the config cell.
# NOTE(review): 281241 train / 5567 eval entries appear to be the full 960 h
# LibriSpeech training set and dev-clean + dev-other, not train-clean-100 alone.
# Confirm these stats describe the data that is actually configured.
metadata = dict(
    train=dict(max_input_length=2974, max_label_length=194, num_entries=281241),
    eval=dict(max_input_length=3516, max_label_length=186, num_entries=5567),
)

In [None]:
import os
import math
import argparse
from tensorflow_asr.utils import env_util

# setup_environment() runs before `import tensorflow`; keep that order
# (presumably it sets env vars TensorFlow reads at import time — confirm).
env_util.setup_environment()
import tensorflow as tf

tf.keras.backend.clear_session()
# Grappler auto mixed precision: per the TF docs it rewrites eligible float32 ops
# to float16, on GPU only.
# NOTE(review): the saved output of this cell logs MirroredStrategy on CPU:0 only.
# Confirm the device passed to setup_strategy is visible; otherwise this setting
# does nothing and training runs on CPU.
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})
strategy = env_util.setup_strategy([0])

from tensorflow_asr.configs.config import Config
from tensorflow_asr.datasets import asr_dataset
from tensorflow_asr.featurizers import speech_featurizers, text_featurizers
from tensorflow_asr.models.transducer.conformer import Conformer
from tensorflow_asr.optimizers.schedules import TransformerSchedule

# The config cell binds `config` to a plain dict, and this cell rebinds it to a
# Config object. Wrap it only once, so re-running just this cell (e.g. after
# `fit` crashes) does not try to build a Config from a Config.
if isinstance(config, dict):
    config = Config(config)

speech_featurizer = speech_featurizers.TFSpeechFeaturizer(config.speech_config)
text_featurizer = text_featurizers.CharFeaturizer(config.decoder_config)


def _make_slice_dataset(dataset_config, speech_feat, text_feat):
    """Build an ``ASRSliceDataset`` from one ``*_dataset_config`` section.

    Every attribute of ``dataset_config`` is forwarded as a keyword argument,
    plus ``indefinite=True``. ``fit`` below bounds each epoch explicitly with
    ``steps_per_epoch`` / ``validation_steps``.
    """
    return asr_dataset.ASRSliceDataset(
        speech_featurizer=speech_feat,
        text_featurizer=text_feat,
        **vars(dataset_config),
        indefinite=True
    )


train_dataset = _make_slice_dataset(
    config.learning_config.train_dataset_config, speech_featurizer, text_featurizer
)
# NOTE(review): eval_dataset_config.data_paths is None in the config cell, so no
# transcript file is listed for this validation set. Point it at the dev
# transcripts before trusting validation metrics.
eval_dataset = _make_slice_dataset(
    config.learning_config.eval_dataset_config, speech_featurizer, text_featurizer
)
# NOTE(review): the saved output of this cell shows a vocabulary dict and a
# character list printed once per transcript, until Jupyter's IOPub rate limit
# trips. Nothing in this cell prints them, so they presumably come from a debug
# print inside the TensorFlowASR checkout on sys.path; remove it there.

# NOTE(review): metadata["train"]["num_entries"] (281241) looks like the full
# 960 h LibriSpeech count, not train-clean-100 alone. If total_steps is derived
# from it, each "epoch" cycles the indefinite training data several times.
# Confirm.
train_dataset.load_metadata(metadata)
eval_dataset.load_metadata(metadata)
# NOTE(review): confirm reset_length() does not discard the maximum lengths that
# load_metadata() just applied to these shared featurizers.
speech_featurizer.reset_length()
text_featurizer.reset_length()

# The config's batch_size is per replica; scale it by the number of replicas.
global_batch_size = (
    config.learning_config.running_config.batch_size
    * strategy.num_replicas_in_sync
)

train_data_loader = train_dataset.create(global_batch_size)
eval_data_loader = eval_dataset.create(global_batch_size)

with strategy.scope():
    # Build the model and print its summary.
    conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
    conformer.make(speech_featurizer.shape)
    conformer.summary(line_length=100)

    # Pop "warmup_steps" from a copy so the Config object is left intact.
    # Otherwise a re-run of this cell would silently fall back to 10000.
    # The remaining keys go to Adam.
    optimizer_kwargs = dict(config.learning_config.optimizer_config)
    warmup_steps = optimizer_kwargs.pop("warmup_steps", 10000)
    optimizer = tf.keras.optimizers.Adam(
        TransformerSchedule(
            d_model=conformer.dmodel,
            warmup_steps=warmup_steps,
            max_lr=(0.05 / math.sqrt(conformer.dmodel)),
        ),
        **optimizer_kwargs
    )

    conformer.compile(
        optimizer=optimizer,
        experimental_steps_per_execution=10,
        global_batch_size=global_batch_size,
        blank=text_featurizer.blank,
    )

running_config = config.learning_config.running_config
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(**running_config.checkpoint),
    tf.keras.callbacks.experimental.BackupAndRestore(running_config.states_dir),
    tf.keras.callbacks.TensorBoard(**running_config.tensorboard),
]

conformer.fit(
    train_data_loader,
    epochs=running_config.num_epochs,
    validation_data=eval_data_loader,
    callbacks=callbacks,
    steps_per_epoch=train_dataset.total_steps,
    validation_steps=eval_dataset.total_steps,
)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
{' ': 1, 'a': 2, 'b': 3, 'c': 4, 'd': 5, 'e': 6, 'f': 7, 'g': 8, 'h': 9, 'i': 10, 'j': 11, 'k': 12, 'l': 13, 'm': 14, 'n': 15, 'o': 16, 'p': 17, 'q': 18, 'r': 19, 's': 20, 't': 21, 'u': 22, 'v': 23, 'w': 24, 'x': 25, 'y': 26, 'z': 27, "'": 28}
INFO:tensorflow:Reading /mydata/hassan/data/LibriSpeech/train-clean-100/transcript.tsv ...
{' ': 1, 'a': 2, 'b': 3, 'c': 4, 'd': 5, 'e': 6, 'f': 7, 'g': 8, 'h': 9, 'i': 10, 'j': 11, 'k': 12, 'l': 13, 'm': 14, 'n': 15, 'o': 16, 'p': 17, 'q': 18, 'r': 19, 's': 20, 't': 21, 'u': 22, 'v': 23, 'w': 24, 'x': 25, 'y': 26, 'z': 27, "'": 28} ['c', 'h', 'a', 'r', 'l', 'e', 's', ' ', 'g', 'a', 'v', 'e', ' ', 't', 'h', 'e', 'm', ' ', 'a', ' ', 'g', 'r', 'a', 'c', 'i', 'o', 'u', 's', ' ', 'a', 'n', 'd', ' ', 'a', ' ', 'c', 'o', 'm', 'p', 'l', 'i', 'a', 'n', 't', ' ', 'a', 'n', 's', 'w', 'e', 'r', ' ', 't', 'o', ' ', 'a', 'l', 'l', ' ', 't', 'h', 'e', 'i', 'r'

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



{' ': 1, 'a': 2, 'b': 3, 'c': 4, 'd': 5, 'e': 6, 'f': 7, 'g': 8, 'h': 9, 'i': 10, 'j': 11, 'k': 12, 'l': 13, 'm': 14, 'n': 15, 'o': 16, 'p': 17, 'q': 18, 'r': 19, 's': 20, 't': 21, 'u': 22, 'v': 23, 'w': 24, 'x': 25, 'y': 26, 'z': 27, "'": 28} ['y', 'e', 's', ' ', 's', 'a', 'i', 'd', ' ', 'h', 'e', ' ', 'a', 'l', 'l', ' ', 't', 'h', 'a', 't', ' ', 'r', 'e', 'm', 'a', 'i', 'n', 's', ' ', 'o', 'f', ' ', 'm', 'e', ' ', 's', 'i', 'n', 'c', 'e', ' ', 'i', ' ', 'l', 'e', 'f', 't', ' ', 'l', 'o', 'n', 'd', 'o', 'n', ' ', 'w', 'h', 'a', 't', ' ', 'd', 'o', ' ', 'y', 'o', 'u', ' ', 't', 'h', 'i', 'n', 'k', ' ', 'o', 'f', ' ', 't', 'h', 'e', ' ', 'd', 'i', 'g', 'g', 'i', 'n', 'g', 's', ' ', 's', 'a', 'i', 'd', ' ', 'm', 'i', 's', 't', 'e', 'r', ' ', 'p', 'o', 's', 't', 'm', 'a', 'n']
{' ': 1, 'a': 2, 'b': 3, 'c': 4, 'd': 5, 'e': 6, 'f': 7, 'g': 8, 'h': 9, 'i': 10, 'j': 11, 'k': 12, 'l': 13, 'm': 14, 'n': 15, 'o': 16, 'p': 17, 'q': 18, 'r': 19, 's': 20, 't': 21, 'u': 22, 'v': 23, 'w': 24, 'x': 25

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



{' ': 1, 'a': 2, 'b': 3, 'c': 4, 'd': 5, 'e': 6, 'f': 7, 'g': 8, 'h': 9, 'i': 10, 'j': 11, 'k': 12, 'l': 13, 'm': 14, 'n': 15, 'o': 16, 'p': 17, 'q': 18, 'r': 19, 's': 20, 't': 21, 'u': 22, 'v': 23, 'w': 24, 'x': 25, 'y': 26, 'z': 27, "'": 28} ['o', 'b', 's', 'e', 'r', 'v', 'e', 'd', ' ', 'a', ' ', 'm', 'a', 'n', ' ', 'h', 'i', 's', ' ', 'f', 'r', 'i', 'e', 'n', 'd', ' ', 'r', 'e', 'c', 'e', 'n', 't', 'l', 'y', ' ', 'a', 'r', 'o', 'u', 's', 'e', 'd', ' ', 'w', 'a', 's', ' ', 's', 't', 'i', 'l', 'l', ' ', 'v', 'e', 'r', 'y', ' ', 'd', 'r', 'o', 'w', 's', 'y', ' ', 'h', 'e', ' ', 'l', 'o', 'o', 'k', 'e', 'd', ' ', 'b', 'e', 'h', 'i', 'n', 'd', ' ', 'h', 'i', 'm', ' ', 'u', 'n', 't', 'i', 'l', ' ', 'h', 'i', 's', ' ', 'm', 'i', 'n', 'd', ' ', 't', 'o', 'o', 'k', ' ', 'i', 'n', ' ', 't', 'h', 'e', ' ', 'm', 'e', 'a', 'n', 'i', 'n', 'g', ' ', 'o', 'f', ' ', 't', 'h', 'e', ' ', 'm', 'o', 'v', 'e', 'm', 'e', 'n', 't', ' ', 't', 'h', 'e', 'n', ' ', 'h', 'e', ' ', 's', 'i', 'g', 'h', 'e', 'd', 

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



{' ': 1, 'a': 2, 'b': 3, 'c': 4, 'd': 5, 'e': 6, 'f': 7, 'g': 8, 'h': 9, 'i': 10, 'j': 11, 'k': 12, 'l': 13, 'm': 14, 'n': 15, 'o': 16, 'p': 17, 'q': 18, 'r': 19, 's': 20, 't': 21, 'u': 22, 'v': 23, 'w': 24, 'x': 25, 'y': 26, 'z': 27, "'": 28} ['w', 'h', 'a', 't', ' ', 'a', 'r', 'e', ' ', 'y', 'o', 'u', ' ', 'd', 'o', 'i', 'n', 'g', ' ', 'w', 'i', 't', 'h', ' ', 't', 'h', 'e', 'm', ' ', 'i', ' ', 't', 'h', 'o', 'u', 'g', 'h', 't', ' ', 't', 'h', 'e', 'y', ' ', 'b', 'e', 'l', 'o', 'n', 'g', 'e', 'd', ' ', 't', 'o', ' ', 'm', 'y', ' ', 's', 'h', 'e', 'e', 'p', ' ', 'a', 'n', 's', 'w', 'e', 'r', 'e', 'd', ' ', 'b', 'o', ' ', 'p', 'e', 'e', 'p', ' ', 's', 'o', 'r', 'r', 'o', 'w', 'f', 'u', 'l', 'l', 'y']
{' ': 1, 'a': 2, 'b': 3, 'c': 4, 'd': 5, 'e': 6, 'f': 7, 'g': 8, 'h': 9, 'i': 10, 'j': 11, 'k': 12, 'l': 13, 'm': 14, 'n': 15, 'o': 16, 'p': 17, 'q': 18, 'r': 19, 's': 20, 't': 21, 'u': 22, 'v': 23, 'w': 24, 'x': 25, 'y': 26, 'z': 27, "'": 28} ['i', ' ', 'm', 'u', 's', 't', ' ', 'h', 'u', 