# RNN-T Training profiling


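This notebook profiles RNN-T training. It builds the model from a config file, runs a truncated training stage once, and then runs a stage again under `cProfile`.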

In [None]:
# Output directory shared by the TensorBoard logger, CSV logger and Saver.
# NOTE: keep the trailing "/" — file paths may be built from it by plain
# string concatenation.
log_dir = "/logdir/"

In [None]:
# Automatically reload imported modules (e.g. edited myrtlespeech sources)
# before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
import os

# Expose only GPU 0 to this process. This must happen before torch initialises
# CUDA, which is why it runs ahead of the torch import cell.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [None]:
import os
import pathlib
import typing
import cProfile 

import torch
from google.protobuf import text_format

from myrtlespeech.model.rnn_t import RNNTEncoder, RNNT
from myrtlespeech.run.run import TensorBoardLogger, Saver
from myrtlespeech.run.callbacks.csv_logger import CSVLogger
from myrtlespeech.run.callbacks.callback import Callback, ModelCallback
from myrtlespeech.run.callbacks.clip_grad_norm import ClipGradNorm
from myrtlespeech.run.callbacks.report_mean_batch_loss import ReportMeanBatchLoss
from myrtlespeech.run.callbacks.stop_epoch_after import StopEpochAfter
from myrtlespeech.run.callbacks.mixed_precision import MixedPrecision
from myrtlespeech.post_process.utils import levenshtein
from myrtlespeech.builders.task_config import build
from myrtlespeech.run.train import fit
from myrtlespeech.protos import task_config_pb2
from myrtlespeech.run.stage import Stage

from myrtlespeech.run.train import run_stage
from myrtlespeech.run.callbacks.callback import CallbackHandler

In [None]:
from myrtlespeech.run.train import run_stage
from myrtlespeech.run.callbacks.callback import CallbackHandler
from myrtlespeech.run.callbacks.rnn_t_training import RNNTTraining
from myrtlespeech.run.run import ReportRNNTDecoder

In [None]:
# Disable the cuDNN autotuner. It benchmarks convolution algorithms per input
# shape, and variable-length inputs would trigger re-benchmarking whenever a
# new shape appears.
torch.backends.cudnn.benchmark = False

Build the RNN-T model defined in the config file:

In [None]:
# Parse the example RNN-T config into a TaskConfig protobuf message.
# The path is relative to the notebook's working directory.
config_path = "../src/myrtlespeech/configs/rnn_t_en.config"
with open(config_path) as config_file:
    config_text = config_file.read()
task_config = text_format.Merge(config_text, task_config_pb2.TaskConfig())

# Display the parsed config as this cell's output.
task_config

In [None]:
# Create all components described by the config. build() returns a 4-tuple,
# unpacked here as the seq-to-seq object (later cells use its .model,
# .post_process and .alphabet), the configured number of epochs, and the
# train/eval data loaders.
seq_to_seq, epochs, train_loader, eval_loader = build(task_config)


### Optionally load saved model weights

In [None]:
# Flip to True to restore previously saved weights into the model.
load_model = False
if load_model:
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    fp = "/home/julian/models/dsint_imported_w_zeros.pt"
    state_dict = torch.load(fp)
    seq_to_seq.model.load_state_dict(state_dict)

### Optionally change the decoder

In [None]:
# Choose how RNN-T outputs are decoded for evaluation reporting.
use_beam = False  # True: beam search decoder; False: greedy decoder
no_max = False  # True: raise the per-step symbol cap from 4 to 100

from myrtlespeech.post_process.rnn_t_beam_decoder import RNNTBeamDecoder
from myrtlespeech.post_process.rnn_t_greedy_decoder import RNNTGreedyDecoder

# NOTE(review): the blank index is hard-coded; confirm it matches the blank
# symbol's position in seq_to_seq.alphabet for the loaded config.
blank_index = 28
# Decide the symbol cap once and hand it to the constructor. Previously the
# decoder was built with 4 and the attribute overwritten afterwards, which
# left the constructor argument dead and easy to get out of sync.
max_symbols_per_step = 100 if no_max else 4

if use_beam:
    decoder = RNNTBeamDecoder(
        blank_index=blank_index,
        beam_width=4,
        length_norm=False,
        max_symbols_per_step=max_symbols_per_step,
        model=seq_to_seq.model,
    )
    beam_str = "beam"
else:
    decoder = RNNTGreedyDecoder(
        blank_index=blank_index,
        max_symbols_per_step=max_symbols_per_step,
        model=seq_to_seq.model,
    )
    beam_str = "greedy"
seq_to_seq.post_process = decoder

### Define Callbacks

In [None]:
# MixedPrecision can only be initialized once, so it lives in its own cell;
# re-running the callbacks cell then reuses this same instance.
mixed_precision_cb = MixedPrecision(seq_to_seq)

In [None]:

# Reports output of the currently selected post-processor (decoder).
rnnt_decoder_cb = ReportRNNTDecoder(
    seq_to_seq.post_process,
    seq_to_seq.alphabet,
    eval_every=4,
    skip_first_epoch=True,
)

# Keys written to the CSV log; the WER key is namespaced by the decoder's
# class name.
keys_to_log_in_csv = [
    "epoch",
    f"reports/{seq_to_seq.post_process.__class__.__name__}/wer",
    "reports/ReportMeanBatchLoss",
]

callbacks = [
    RNNTTraining(),
    rnnt_decoder_cb,
    ReportMeanBatchLoss(),
    # NOTE: the following three callbacks, if present, must appear in this
    # order (see their docstrings).
    TensorBoardLogger(log_dir, seq_to_seq.model, histograms=False),  # histograms=True adds large overhead
    mixed_precision_cb,
    ClipGradNorm(seq_to_seq, 200),
    # Stop training prematurely (useful for debugging and profiling).
    # Comment this line out to perform full training.
    StopEpochAfter(epoch_batches=2),
    # Logging. os.path.join yields a valid path whether or not log_dir ends
    # with a trailing "/", unlike plain string concatenation.
    CSVLogger(os.path.join(log_dir, "log.csv"), keys=keys_to_log_in_csv),
    # Save the model at the end of each epoch.
    Saver(log_dir, seq_to_seq.model),
]



In [None]:
# Run one (truncated) training stage outside the profiler first.
# NOTE(review): presumably a warm-up so one-off costs (CUDA initialisation,
# data loader start-up) do not dominate the cProfile run — confirm.
is_training=True
cb_handler = CallbackHandler(callbacks, is_training)
# Notify all callbacks that training is starting; epochs=1 here rather than
# the configured `epochs` returned by build().
cb_handler.on_train_begin(epochs=1)
run_stage(seq_to_seq, cb_handler, train_loader, is_training=is_training)

In [None]:
# Profile another training stage, reusing the callback handler set up above.
# runctx with explicit globals()/locals() makes the notebook namespace visible
# to the profiled statement without relying on __main__ being that namespace.
# sort="cumulative" orders the report by cumulative time, so the most expensive
# call paths come first; the default orders by function name.
cProfile.runctx(
    "run_stage(seq_to_seq, cb_handler, train_loader, is_training=is_training)",
    globals(),
    locals(),
    sort="cumulative",
)