# RNN-T Example Usage


This notebook provides example usage of `myrtlespeech` for RNN-T training.

In [None]:
# Enable autoreload so that edits to imported modules are picked up live,
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

# Restrict this process to the GPU with index "1".
# Set this before importing torch (torch is imported in the next cell).
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})

In [None]:
import os
import pathlib

import torch

from myrtlespeech.builders.task_config import build
from myrtlespeech.run.train import fit
from myrtlespeech.run.eval import eval
from myrtlespeech.protos import task_config_pb2
from google.protobuf import text_format


from myrtlespeech.run.callbacks.callback import CallbackHandler
from myrtlespeech.run.run import TensorBoardLogger, Saver
from myrtlespeech.run.callbacks.csv_logger import CSVLogger
from myrtlespeech.run.callbacks.callback import Callback, ModelCallback
from myrtlespeech.run.callbacks.clip_grad_norm import ClipGradNorm
from myrtlespeech.run.callbacks.report_mean_batch_loss import ReportMeanBatchLoss
from myrtlespeech.run.callbacks.stop_epoch_after import StopEpochAfter
from myrtlespeech.run.callbacks.mixed_precision import MixedPrecision
from myrtlespeech.run.callbacks.rnn_t_training import RNNTTraining
from myrtlespeech.run.run import ReportRNNTDecoder

In [None]:
# Keep cuDNN benchmark mode switched off because the inputs have variable sizes.
torch.backends.cudnn.benchmark = False

In [None]:
# Placeholder: replace with the output directory used by the logging
# callbacks, `Saver` and the final "save model" cell below.
# The trailing slash matters wherever a path is built as `log_dir + <name>`.
log_dir = "/home/USER/INSERT/PATH/"

Build the RNNT model defined in the config file:

In [None]:
# Read the example config (protobuf text format) and merge it into an empty
# TaskConfig message.
config_path = "../src/myrtlespeech/configs/rnn_t_en.config"
with open(config_path) as config_file:
    config_text = config_file.read()
task_config = text_format.Merge(config_text, task_config_pb2.TaskConfig())

# Display the parsed config as the cell output.
task_config

In [None]:
# Build every component described by the config.
# FYI: with train-clean-100 & dev-clean this cell takes on the order of 60s.
components = build(task_config)
seq_to_seq, epochs, train_loader, eval_loader = components

# Display the built model wrapper as the cell output.
seq_to_seq

## Maybe load model?

In [None]:
# Optionally restore model weights from a state-dict checkpoint.
# Set `load_model = True` and point `fp` at the checkpoint file to enable.
load_model = False
if load_model:
    fp = "/home/user/model/fp/model.pt"
    # Deserialize onto the CPU first so that a checkpoint written from a GPU
    # index that is not visible to this process still loads;
    # `load_state_dict` then copies the values into the model's existing
    # parameters on whatever device they live on.
    # NOTE: `torch.load` unpickles the file - only load trusted checkpoints.
    state_dict = torch.load(fp, map_location="cpu")
    seq_to_seq.model.load_state_dict(state_dict)

## Callbacks
* Use callbacks to inject custom behaviour into the training loop.
* It is necessary (for now) to use the `RNNTTraining()` callback, but the others are optional.


In [None]:
# MixedPrecision can only be initialized once, so it is created in its own
# cell rather than alongside the other callbacks.
mixed_precision_cb = MixedPrecision(seq_to_seq)

In [None]:

# Decoder-report callback, built from the model's post-processing step and
# alphabet.
rnnt_decoder_cb = ReportRNNTDecoder(
    seq_to_seq.post_process,
    seq_to_seq.alphabet,
    eval_every=1,
    skip_first_epoch=True,
)

# Keys written by the CSV loggers (this list is reused by the eval cell).
keys_to_log_in_csv = [
    "epoch",
    f"reports/{seq_to_seq.post_process.__class__.__name__}/wer",
    "reports/ReportMeanBatchLoss",
]

callbacks = [
    RNNTTraining(),
    rnnt_decoder_cb,
    ReportMeanBatchLoss(),
    # Note: the following three callbacks, if present, must appear in this
    # order (see docstrings):
    TensorBoardLogger(log_dir, seq_to_seq.model, histograms=False),
    mixed_precision_cb,
    ClipGradNorm(seq_to_seq, 200),
    # Stop training prematurely (useful for debugging).
    # Ensure the following line is commented out to perform full training.
    # StopEpochAfter(epoch_batches=20),
    # Logging. os.path.join keeps the path correct whether or not log_dir
    # ends with a path separator.
    CSVLogger(os.path.join(log_dir, "log.csv"), keys=keys_to_log_in_csv),
    # Save model at end of epoch:
    Saver(log_dir, seq_to_seq.model),
]


In [None]:
# Run training with the callbacks configured above.
# NOTE(review): the epoch count is hard-coded to 40 here, so the `epochs`
# value returned by `build(task_config)` is not used - confirm this is intended.
fit_kwargs = dict(
    epochs=40,
    train_loader=train_loader,
    eval_loader=eval_loader,
    callbacks=callbacks,
)
fit(seq_to_seq, **fit_kwargs)



### Maybe eval?

In [None]:
# Optionally run a standalone evaluation pass over `eval_loader`.
run_eval = True

eval_cbs = None
if run_eval:
    # `keys_to_log_in_csv` comes from the callbacks cell above.
    # os.path.join keeps the CSV path correct whether or not log_dir ends
    # with a path separator.
    eval_cbs = [
        RNNTTraining(),
        ReportMeanBatchLoss(),
        ReportRNNTDecoder(seq_to_seq.post_process, seq_to_seq.alphabet),
        CSVLogger(
            os.path.join(log_dir, "log_eval.csv"), keys=keys_to_log_in_csv
        ),
    ]

    # NOTE: `eval` here is `myrtlespeech.run.eval.eval` (imported at the top
    # of the notebook), which shadows the Python builtin of the same name.
    eval(
        seq_to_seq,
        eval_loader=eval_loader,
        callbacks=eval_cbs,
    )
    

### Maybe save model

In [None]:
# Optionally write the model's state dict to `fp_out`.
save_model = False
# os.path.join keeps the path correct whether or not log_dir ends with a
# path separator.
fp_out = os.path.join(log_dir, "model_saved.pt")
if save_model:
    torch.save(seq_to_seq.model.state_dict(), fp_out)