# RNN-T Example Usage


In [None]:
# Root directory for run artefacts (CSV logs, TensorBoard events, saved
# weights). Later cells build file paths by plain string concatenation
# (e.g. log_dir + "log.csv"), so the value is normalised to end with "/".
# Set MYRTLESPEECH_LOG_DIR in the environment to override the default.
import os

log_dir = os.environ.get(
    "MYRTLESPEECH_LOG_DIR", "/home/julian/exp/speech/myrtlespeech/rnnt/test/1/"
)
if not log_dir.endswith("/"):
    log_dir += "/"

In [None]:
# IPython magics: reload imported modules automatically before running code,
# so edits to the myrtlespeech sources are picked up without restarting.
%load_ext autoreload
%autoreload 2

In [None]:
import os

# Restrict CUDA to device "0". This cell runs before torch is imported, so
# the setting is in place before torch can initialise CUDA.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [None]:
import os
import pathlib
import typing
import cProfile 

import torch
from google.protobuf import text_format

from myrtlespeech.model.rnn_t import RNNTEncoder, RNNT
from myrtlespeech.run.run import TensorBoardLogger, Saver
from myrtlespeech.run.callbacks.csv_logger import CSVLogger
from myrtlespeech.run.callbacks.callback import Callback, ModelCallback
from myrtlespeech.run.callbacks.clip_grad_norm import ClipGradNorm
from myrtlespeech.run.callbacks.report_mean_batch_loss import ReportMeanBatchLoss
from myrtlespeech.run.callbacks.stop_epoch_after import StopEpochAfter
from myrtlespeech.run.callbacks.mixed_precision import MixedPrecision
from myrtlespeech.post_process.utils import levenshtein
from myrtlespeech.builders.task_config import build
from myrtlespeech.run.train import fit
from myrtlespeech.protos import task_config_pb2
from myrtlespeech.run.stage import Stage

from myrtlespeech.run.train import run_stage
from myrtlespeech.run.callbacks.callback import CallbackHandler

In [None]:
from myrtlespeech.run.train import run_stage
from myrtlespeech.run.callbacks.callback import CallbackHandler
from myrtlespeech.run.callbacks.rnn_t_training import RNNTTraining
from myrtlespeech.run.run import ReportRNNTDecoder

In [None]:
# Input sizes vary between batches, so cuDNN autotuning is left disabled.
torch.backends.cudnn.benchmark = False


Build the RNN-T model defined in the config file:

In [None]:
# Parse the example task config (protobuf text format) into a TaskConfig.
# The path is relative to the notebook's working directory; the encoding is
# given explicitly so parsing doesn't depend on the system locale.
config_path = "../src/myrtlespeech/configs/rnn_t_en_ds_int.config"
with open(config_path, encoding="utf-8") as f:
    task_config = text_format.Merge(f.read(), task_config_pb2.TaskConfig())

task_config

In [None]:
# Build the model wrapper, epoch count and train/eval data loaders described
# by `task_config`.
# FYI: with train-clean-100 & dev-clean this cell takes O(60s).
components = build(task_config)
seq_to_seq, epochs, train_loader, eval_loader = components


In [None]:
# Display `seq_to_seq` (the notebook renders a cell's last expression).
seq_to_seq

### Maybe save the model

In [None]:
# Optionally write the model's state dict to <log_dir>saved_state_dict.pt.
save_model = False
fp_out = f"{log_dir}saved_state_dict.pt"
if save_model:
    state = seq_to_seq.model.state_dict()
    torch.save(state, fp_out)

### Maybe load the model

In [None]:
# Optionally restore model weights from a previously saved state dict.
load_model = False
if load_model:
    fp = "/home/julian/models/dsint_imported_w_zeros.pt"
    # map_location="cpu" lets a checkpoint saved from GPU tensors load even
    # when CUDA is unavailable; load_state_dict then copies the values into
    # the model's existing parameters on whatever device they live on.
    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    seq_to_seq.model.load_state_dict(torch.load(fp, map_location="cpu"))

### Maybe change the decoder

In [None]:
# Choose how model outputs are decoded into transcripts.
use_beam = False  # True: beam search decoder; False: greedy decoder
no_max = False  # True: allow up to 100 symbols per step instead of 4

from myrtlespeech.post_process.rnn_t_beam_decoder import RNNTBeamDecoder
from myrtlespeech.post_process.rnn_t_greedy_decoder import RNNTGreedyDecoder

# NOTE(review): 28 is presumably the blank symbol's index for this config's
# alphabet - confirm against task_config if the alphabet changes.
blank_index = 28
# Decide the per-step symbol limit once and pass it to the constructor,
# instead of constructing with 4 and overwriting the attribute afterwards.
max_symbols_per_step = 100 if no_max else 4

if use_beam:
    decoder = RNNTBeamDecoder(
        blank_index=blank_index,
        beam_width=4,
        length_norm=False,
        max_symbols_per_step=max_symbols_per_step,
        model=seq_to_seq.model,
    )
    beam_str = "beam"  # also used in the eval CSV filename
else:
    decoder = RNNTGreedyDecoder(
        blank_index=blank_index,
        max_symbols_per_step=max_symbols_per_step,
        model=seq_to_seq.model,
    )
    beam_str = "greedy"
seq_to_seq.post_process = decoder

## Callbacks
* Use callbacks to inject features into the training loop.
* It is necessary (for now) to use the `RNNTTraining()` callback; the others are optional.


In [None]:
# Custom callback that monitors training/evaluation and prints results.
class PrintCB(Callback):
    """Print per-batch training loss and decoder WER during evaluation.

    While training, every batch prints the batch index, the
    ``ReportMeanBatchLoss`` report and the last batch's loss. While
    evaluating, progress (plus WER and the first transcript, when available)
    is printed every ``print_every`` batches, and a loss/WER summary is
    printed at the end of each epoch.

    Args:
        decoder_name: Key in ``kwargs["reports"]`` under which the decoder's
            report dict (with ``"wer"`` and ``"transcripts"`` entries) is
            stored. If ``None`` (default), the class name of the
            notebook-global ``seq_to_seq.post_process`` is looked up at call
            time, so swapping the decoder after construction is still honoured.
        print_every: Positive number of evaluation batches between progress
            prints (default 100).
    """

    def __init__(self, decoder_name=None, print_every=100):
        super().__init__()
        self.decoder_name = decoder_name
        self.print_every = print_every

    def _decoder_key(self):
        """Return the ``reports`` key holding the decoder's WER report."""
        if self.decoder_name is not None:
            return self.decoder_name
        return seq_to_seq.post_process.__class__.__name__

    def _wer_report(self, kwargs):
        """Return ``(wer, decoder_report)``, or ``None`` if not reported."""
        try:
            wer_reports = kwargs["reports"][self._decoder_key()]
            return wer_reports["wer"], wer_reports
        except KeyError:
            return None

    @staticmethod
    def _first_transcript(wer_reports):
        """Return the first ``(prediction, expected)`` pair joined into
        strings, or ``None`` when no transcripts were collected."""
        transcripts = wer_reports.get("transcripts", [])
        if len(transcripts) == 0:
            return None
        pred, exp = transcripts[0]
        return "".join(pred), "".join(exp)

    def on_batch_end(self, **kwargs):
        batch = kwargs["epoch_batches"]
        if self.training:
            print(batch, "loss", kwargs["reports"]["ReportMeanBatchLoss"], kwargs["last_loss"].item())
            return

        # Evaluation: only report every `print_every` batches (skip batch 0).
        if batch == 0 or batch % self.print_every != 0:
            return
        print(f"{batch} batches completed")

        found = self._wer_report(kwargs)
        if found is None:
            print("no wer - using new decoder?")
            return
        wer, wer_reports = found
        first = self._first_transcript(wer_reports)
        if first is None:
            print("batch end, wer: {:.4f}".format(wer))
        else:
            pred, exp = first
            print("batch end, pred: {}, exp: {}, wer: {:.4f}".format(pred, exp, wer))

    def on_epoch_end(self, **kwargs):
        if self.training:
            return
        epoch = kwargs["epoch"]

        try:
            loss = kwargs["reports"]["ReportMeanBatchLoss"]
        except KeyError:
            print("no loss - is ReportMeanBatchLoss registered?")
            return
        out_str = "{}, loss: {:.8f}".format(epoch, loss)

        found = self._wer_report(kwargs)
        if found is None:
            # Still show the loss: the decoder may not have reported this
            # epoch (e.g. a skipped first epoch) or may use another key.
            print(out_str)
            print("no wer - using new decoder?")
            return
        wer, wer_reports = found
        # WER is reported even when no transcripts were collected.
        out_str += ", wer: {:.4f}".format(wer)
        first = self._first_transcript(wer_reports)
        if first is not None:
            out_str += ", pred: {}, exp: {}".format(*first)
        print(out_str)
        

### Maybe change the learning rate

In [None]:
# Inspect and optionally override the learning rate of every optimiser
# parameter group.
change_lr = False
new_lr = 0.0003

param_groups = seq_to_seq.optim.param_groups

for group in param_groups:
    print("current lr: ", group['lr'])

if change_lr:
    for group in param_groups:
        group['lr'] = new_lr

for group in param_groups:
    print("new lr: ", group['lr'])

In [None]:
# Optionally replace the optimiser with Adam using one parameter group each
# for the encoder, prediction network and joint network, so each component's
# learning rate can be set independently below.
# NOTE: this discards the previous optimiser, including any lr set by the
# "change lr" cell above, so set the rates here instead.
# NOTE(review): if seq_to_seq holds an lr scheduler bound to the old
# optimiser, it will not drive this new one - confirm.
optim_layer_wise = True
base_lr = 0.0003  # Adam's default lr for all groups
encoder_lr = base_lr
predict_net_lr = base_lr
joint_net_lr = base_lr

from torch.optim import Adam

if optim_layer_wise:
    optim = Adam(
        [
            {"params": seq_to_seq.model.encode.parameters(), "lr": encoder_lr},
            {"params": seq_to_seq.model.predict_net.parameters(), "lr": predict_net_lr},
            {"params": seq_to_seq.model.joint_net.parameters(), "lr": joint_net_lr},
        ],
        lr=base_lr,
    )
    seq_to_seq.optim = optim

In [None]:
from myrtlespeech.run.callbacks.rnn_t_training import RNNTTraining
from myrtlespeech.run.run import ReportRNNTDecoder

rnnt_decoder_cb = ReportRNNTDecoder(
    seq_to_seq.post_process,
    seq_to_seq.alphabet,
    eval_every=1,
    skip_first_epoch=True,
)

# Columns written by CSVLogger; the WER key depends on the decoder's class.
keys_to_log = [
    "epoch",
    f"reports/{seq_to_seq.post_process.__class__.__name__}/wer",
    "reports/ReportMeanBatchLoss",
]

# Switches instead of (un)commenting code:
use_mixed_precision = False
# Stop each epoch prematurely after this many batches (useful for debugging).
# Set to None to perform full training.
stop_epoch_after_batches = 1000

callbacks = [
    RNNTTraining(),  # required (for now); the remaining callbacks are optional
    ReportMeanBatchLoss(),
]
if use_mixed_precision:
    callbacks.append(MixedPrecision(seq_to_seq))
callbacks.append(ClipGradNorm(seq_to_seq, 200))
callbacks.append(rnnt_decoder_cb)
if stop_epoch_after_batches is not None:
    callbacks.append(StopEpochAfter(epoch_batches=stop_epoch_after_batches))
# logging
callbacks += [
    CSVLogger(log_dir + "log.csv", keys=keys_to_log),
    TensorBoardLogger(log_dir, seq_to_seq.model, histograms=False),
    PrintCB(),
    Saver(log_dir, seq_to_seq.model),
]


In [None]:
# Run one training pass over train_loader with the callbacks above,
# outside of fit().
is_training = True
cb_handler = CallbackHandler(callbacks, is_training)
cb_handler.on_train_begin(epochs=1)
run_stage(
    seq_to_seq,
    cb_handler,
    train_loader,
    is_training=is_training,
)

In [None]:
# Profile another training pass. Stats are sorted by cumulative time so the
# most expensive call paths come first. runctx takes explicit namespaces
# rather than relying on cProfile.run's use of __main__'s namespace.
cProfile.runctx(
    "run_stage(seq_to_seq, cb_handler, train_loader, is_training=is_training)",
    globals(),
    locals(),
    sort="cumulative",
)

In [None]:
# Full training run. By default the epoch count is overridden here; set
# num_epochs_override = None to use `epochs` returned by build(task_config).
num_epochs_override = 40
fit(
    seq_to_seq,
    epochs=num_epochs_override if num_epochs_override is not None else epochs,
    train_loader=train_loader,
    eval_loader=eval_loader,
    callbacks=callbacks,
)



In [None]:
# Debug helper: print the type and size of every object tracked by the
# garbage collector that is a tensor or exposes a tensor via `.data`
# (useful when investigating memory usage).
import torch
import gc

for obj in gc.get_objects():
    try:
        if torch.is_tensor(obj) or (hasattr(obj, "data") and torch.is_tensor(obj.data)):
            print(type(obj), obj.size())
    except Exception:
        # Best effort: some tracked objects raise on attribute access or lack
        # .size(); skip them. Catching Exception (not a bare except) keeps
        # KeyboardInterrupt usable during this potentially long scan.
        continue

### Maybe evaluate

In [None]:
# Optionally run one evaluation pass over eval_loader with the current
# decoder, logging to log_eval<beam_str>.csv inside log_dir.
run_eval = True

eval_cbs = None
if run_eval:
    eval_csv = log_dir + f"log_eval{beam_str}.csv"
    eval_cbs = [
        RNNTTraining(),
        ReportMeanBatchLoss(),
        ReportRNNTDecoder(seq_to_seq.post_process, seq_to_seq.alphabet),
        CSVLogger(eval_csv, keys=keys_to_log),
        PrintCB(),
    ]
    cb_handler = CallbackHandler(eval_cbs, False)
    cb_handler.on_train_begin(epochs=1)
    run_stage(seq_to_seq, cb_handler, eval_loader, is_training=False)
    