# RNN-T Overfitting
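Validate that every part of the training pipeline works by overfitting the model to a small number of training examples. If the pipeline is correct, the loss on those examples should fall towards zero and the WER (evaluated on the same examples) should drop to 0.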

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import pathlib
import typing

import torch
from google.protobuf import text_format

from myrtlespeech.model.rnn_t import RNNTEncoder, RNNT
from myrtlespeech.run.callbacks.csv_logger import CSVLogger
from myrtlespeech.run.callbacks.callback import Callback, ModelCallback
from myrtlespeech.run.callbacks.clip_grad_norm import ClipGradNorm
from myrtlespeech.run.callbacks.report_mean_batch_loss import ReportMeanBatchLoss
from myrtlespeech.run.callbacks.stop_epoch_after import StopEpochAfter
from myrtlespeech.run.callbacks.mixed_precision import MixedPrecision
from myrtlespeech.post_process.utils import levenshtein
from myrtlespeech.builders.task_config import build
from myrtlespeech.run.train import fit
from myrtlespeech.protos import task_config_pb2
from myrtlespeech.run.stage import Stage

In [None]:
# Disable cuDNN autotuning so kernel selection does not vary between runs.
torch.backends.cudnn.benchmark = False

# Restrict this notebook to the first GPU. This only takes effect if it is set
# before CUDA is first initialised in this process (i.e. before any model or
# tensor is moved to the GPU). `os` is already imported in the import cell above.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

Build the RNN-T model and the data loaders defined in the config file:

In [None]:
# Read the example RNN-T config (protobuf text format) and merge it into a
# fresh TaskConfig message.
config_path = pathlib.Path("../src/myrtlespeech/configs/rnn_t_en.config")
task_config = text_format.Merge(config_path.read_text(), task_config_pb2.TaskConfig())

task_config

In [None]:
# Instantiate every component described by the config: the seq-to-seq wrapper,
# the configured epoch count, and the train/eval loaders.
# FYI: if using train-clean-100 & dev-clean this cell takes O(60s)
built_components = build(task_config)
seq_to_seq, epochs, train_loader, eval_loader = built_components
seq_to_seq

In [None]:
# Total number of scalar values across all model parameters.
# `parameters()` yields the same tensors as `named_parameters()` without the
# names, which were being discarded anyway.
n_params = sum(param.numel() for param in seq_to_seq.model.parameters())
print("number of params:", n_params)

In [None]:
import itertools
from torch.utils.data.sampler import SequentialSampler
from torch.utils.data import DataLoader

NUMBER_SAMPLES = 2

# Get the dataset from the training loader. `DataLoader.dataset` is always
# set, whereas `sampler.data_source` only exists on some sampler types.
dataset = train_loader.dataset
print(dataset)

# Re-wrap the dataset without shuffling so the same first NUMBER_SAMPLES
# examples are always selected. Reuse the original loader's collate_fn so the
# batches keep the structure the training loop already expects (this is the
# default collate function if the original loader did not set one).
train_loader_no_shuffle = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=train_loader.collate_fn,
)
# Materialise the batches once so every epoch sees exactly the same batches.
train_loader_overfit = list(itertools.islice(train_loader_no_shuffle, NUMBER_SAMPLES))
eval_loader_overfit = train_loader_overfit  # i.e. evaluate on the same batches


In [None]:
# custom callback to monitor training and print results
class PrintCB(Callback):
    """Print a one-line progress summary at the end of selected eval epochs.

    Reads two entries from the ``reports`` kwarg passed to ``on_epoch_end``:
    the decoder report, keyed by the class name of ``seq_to_seq.post_process``
    (a mapping with ``"wer"`` and ``"transcripts"``), and the
    ``"ReportMeanBatchLoss"`` report.

    Args:
        print_every: Only print when ``epoch % print_every == 0``.
    """

    def __init__(self, print_every=10):
        super().__init__()
        self.print_every = print_every

    def on_epoch_end(self, **kwargs):
        # Skip when the callback is in training mode; presumably this leaves
        # only the end-of-evaluation call, where reports are populated.
        if self.training:
            return
        epoch = kwargs["epoch"]
        if epoch % self.print_every != 0:
            return

        reports = kwargs.get("reports", {})

        # Decoder report is keyed by the decoder's class name, so swapping
        # greedy <-> beam decoders is picked up automatically.
        try:
            wer_reports = reports[seq_to_seq.post_process.__class__.__name__]
            wer = wer_reports["wer"]
            # Show only the first (prediction, expected) transcript pair.
            pred, exp = wer_reports["transcripts"][0]
        except KeyError:
            print("no wer - using new decoder?")
            return

        # Looked up separately so a missing loss report is not misreported
        # as a missing WER report.
        try:
            loss = reports["ReportMeanBatchLoss"]
        except KeyError:
            print("no loss - is ReportMeanBatchLoss in the callbacks?")
            return

        print(
            "{}, pred: {}, exp: {}, loss {:.8f}, wer: {:.4f}".format(
                epoch, "".join(pred), "".join(exp), loss, wer
            )
        )

In [None]:
from myrtlespeech.post_process.rnn_t_decoders import RNNTBeamDecoder, RNNTGreedyDecoder

use_beam = False
no_max = False

# NOTE(review): hard-coded; confirm it matches the blank symbol's index in
# seq_to_seq.alphabet.
BLANK_INDEX = 28
# Presumably caps how many symbols the decoder may emit per step; `no_max`
# raises the cap to 100. Computed once and passed to the constructor instead
# of being set twice (constructor, then attribute override).
max_symbols_per_step = 100 if no_max else 4

if use_beam:
    decoder = RNNTBeamDecoder(
        blank_index=BLANK_INDEX,
        beam_width=4,
        length_norm=False,
        max_symbols_per_step=max_symbols_per_step,
        model=seq_to_seq.model,
    )
else:
    decoder = RNNTGreedyDecoder(
        blank_index=BLANK_INDEX,
        max_symbols_per_step=max_symbols_per_step,
        model=seq_to_seq.model,
    )

seq_to_seq.post_process = decoder

In [None]:
from myrtlespeech.run.callbacks.rnn_t_training import RNNTTraining
from myrtlespeech.run.run import ReportRNNTDecoder

# Callback that reports decoder output for the post-processor selected above,
# using the model's alphabet.
rnnt_decoder_cb = ReportRNNTDecoder(seq_to_seq.post_process, seq_to_seq.alphabet)

# Assembled in a fixed order. PrintCB reads the loss and decoder reports, so it
# is kept last (NOTE(review): confirm whether fit() relies on this ordering).
callbacks = []
callbacks.append(RNNTTraining())
callbacks.append(ReportMeanBatchLoss())
callbacks.append(rnnt_decoder_cb)
callbacks.append(PrintCB())


In [None]:
# Overfit: train and evaluate on the same handful of batches for many epochs.
# This deliberately ignores the `epochs` value built from the config.
OVERFIT_EPOCHS = 3000

fit(
    seq_to_seq,
    epochs=OVERFIT_EPOCHS,
    train_loader=train_loader_overfit,
    eval_loader=eval_loader_overfit,  # same batches as train_loader_overfit
    callbacks=callbacks,
)
