# RNN-T Overfitting


### Log of attempts
**Numbers refer to exp_dir:**
2. Baseline: took ~1000 epochs to almost overfit (WER = 0.2) to 2 samples.
3. Tried higher lr (0.0003 -> 0.003). This did not help (loss was not able to decrease below 0.7 after 700 epochs). Using lr = 0.0003 going forward.
4. NUM_SAMPLES = 4. After 100 epochs, WER = 40% (i.e. it distinguishes between samples), but after 200, WER = 85%. After 600 epochs, WER = 83%. Stopped.
5. Same as 4. but without mixed precision. After 1346 epochs: _pred: twenty n, exp: twenty nine, loss 0.00036553, wer: 21.4286_
6. Reduced to a MUCH smaller model: encoder hidden state size = 128, prediction = 64, and recurrent cell type = GRU (previously LSTM). Takes much longer to get to a reasonable loss (hangs around loss = 30 for a number of epochs, whereas previously this went to < 1 in ~10 epochs). Stopped as it doesn't reach < 1 in 500 epochs (loss = 1.037).
7. Same as 6. but with prediction hidden state = 128. After 500 epochs, loss = 0.88. After 800, loss 0.37256373, wer: 100.0000.
8. Back to LSTM, otherwise same as 6 (prediction hidden state = 64). MUCH better than 6. After 500 epochs, loss = 0.17, WER = 62% (previously WER = 100%). After 800, loss: 0.03146841, wer: 58.9286. After 1000, loss: 0.015, wer: 53%. After 2631, loss: 0.00012446, wer: 41.0714.
9. Same as 7 (prediction hidden state = 128) but with LSTM instead of GRU. Results are weird: loss is MUCH lower than in 8. but WER is much worse. After 500, loss 0.03620639, wer: 98.2143. After 800, loss 0.00908716, wer: 100.0000. BUT at around 1900 epochs the loss increases a lot and the WER decreases: after 2010, loss 0.02600078, wer: 25.0000. After 4000, loss 0.00001341, wer: 19.6429. It reached WER = 3.5 but then increased again to WER = 19%.
10. Possibly the lr is too high? Use lr = 0.0001 (0.0003 -> 0.0001). Everything else as in 9. No, this was very slow: after 800, loss 2.69113094, wer: 100.0000. After 3837, loss 0.00392524, wer: 66.0714.
11. Back to original lr = 0.0003. Using smaller prediction network (1 layer of size 48). After 500 epochs, loss 0.20005125, wer: 98.2143. i.e. a bit worse than 8. After 800, loss 0.03236475, wer: 82.1429
12. Larger model: encoder hidden = 400, prediction = 48 (2 layers in both). These recent expts are taking a v. long time to get to low loss. After 800, loss 0.02806847, wer: 57.1429. After 3073, loss 0.01223648, wer: 48.2143.
13. Increased number of samples to 32. It found this v. hard (also much slower).
14. Repeat 13 but with bidirectional lstm in prediction network. Not sure this actually makes sense for decoding (as it is taking a partial sequence and going forwards and backwards on it). 
15. Use a bidirectional LSTM in the encoder. This is not allowed in the streaming use case, but is the problem just model capacity? Loss goes down v. slowly. After ~600 epochs: loss: 0.16034877, wer: 58.1690. This is fairly good for 32 samples.
16. Higher batch_size and bidirectional=False. Loss goes down even more slowly (not surprising given that this roughly halves the number of encoder params). After ~600, loss: 0.21108173, wer: 82.7536. After ~800, loss: 0.09715541, wer: 68.2767. The quality of the predictions suggests that the prediction network is still too strong.
17. Added dropout. Loss values are higher than in 16. but *I think* the WER is more stable. After ~600 epochs: loss: 0.97649786, wer: 60.3360. Prediction network appears to be weak.


NOTE! ALL of the above experiments where I refer to changing the lr are invalid: I was not actually changing the lr! So any effect I 'noticed' was just confirmation bias.

18. Added a larger prediction network (hidden_size = 128). Appeared to be better: after 2000, loss: 0.47531913, wer: 35.4640.
19. Added MelFB preprocessing: the loss went down much more quickly in this case. After 400, loss: 0.08690740, wer: 67.0298.
20. Added LogMelFB preprocessing. After 400, loss: 0.16579449, wer: 68.2898. After 550: loss: 0.07083742, wer: 56.9760
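The WER values in this log are assumed to be the usual word error rate: word-level Levenshtein distance (substitutions + insertions + deletions) divided by the number of reference words, times 100. The notebook imports a `levenshtein` helper from `myrtlespeech.post_process.utils` but doesn't show how it is used, so the sketch below is a self-contained stdlib version; `edit_distance` and `wer` are hypothetical names, and corpus-level pooling (rather than averaging per utterance) is an assumption. Two consequences worth keeping in mind when reading the log: an empty prediction gives exactly 100, and insertions can push WER above 100.

```python
from typing import Sequence


def edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Levenshtein distance where substitutions, insertions and deletions cost 1."""
    # at the start of row i, prev[j] is the distance between ref[:i-1] and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,             # delete r
                curr[j - 1] + 1,         # insert h
                prev[j - 1] + (r != h),  # substitute (free if the words match)
            )
        prev = curr
    return prev[-1]


def wer(refs: Sequence[str], hyps: Sequence[str]) -> float:
    """Corpus-level WER in percent: total word edits / total reference words."""
    errors = sum(edit_distance(r.split(), h.split()) for r, h in zip(refs, hyps))
    n_words = sum(len(r.split()) for r in refs)
    return 100.0 * errors / n_words


# "twenty n" vs "twenty nine": one substitution over two reference words
print(wer(["twenty nine"], ["twenty n"]))  # 50.0
# an empty prediction deletes every reference word
print(wer(["twenty nine"], [""]))  # 100.0
```
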






In [None]:
log_dir = "/home/julian/exp/speech/myrtlespeech/rnnt/overfit/21/"
NUMBER_SAMPLES = 2

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import pathlib
import typing

import torch
from google.protobuf import text_format

from myrtlespeech.model.rnn_t import RNNTEncoder, RNNT
from myrtlespeech.run.run import TensorBoardLogger
from myrtlespeech.run.callbacks.csv_logger import CSVLogger
from myrtlespeech.run.callbacks.callback import Callback, ModelCallback
from myrtlespeech.run.callbacks.clip_grad_norm import ClipGradNorm
from myrtlespeech.run.callbacks.report_mean_batch_loss import ReportMeanBatchLoss
from myrtlespeech.run.callbacks.stop_epoch_after import StopEpochAfter
from myrtlespeech.run.callbacks.mixed_precision import MixedPrecision
from myrtlespeech.post_process.utils import levenshtein
from myrtlespeech.builders.task_config import build
from myrtlespeech.run.train import fit
from myrtlespeech.protos import task_config_pb2
from myrtlespeech.run.stage import Stage

from myrtlespeech.run.train import run_stage
from myrtlespeech.run.callbacks.callback import CallbackHandler


In [None]:
import os

# restrict to a single GPU; this must be set before CUDA is first initialised
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
torch.backends.cudnn.benchmark = False

Build the RNNT model defined in the config file:

In [None]:
# parse example config file
with open("../src/myrtlespeech/configs/rnn_t_SMALL_en.config") as f:
    task_config = text_format.Merge(f.read(), task_config_pb2.TaskConfig())

task_config

In [None]:
# create all components for config
# FYI: if using train-clean-100 & dev-clean this cell takes O(60s) 
seq_to_seq, epochs, train_loader, eval_loader = build(task_config)


In [None]:
print(len(train_loader))
len(eval_loader)

In [None]:
save_model = False
fp_out = log_dir + "saved_state_dict.pt"
if save_model:
    torch.save(seq_to_seq.model.state_dict(), fp_out)

In [None]:
load_model = True
if load_model:
    seq_to_seq.model.load_state_dict(torch.load(fp_out))

In [None]:
seq_to_seq

In [None]:
total = 0
for p in seq_to_seq.model.parameters():
    total += p.numel()
print(total)
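Several log entries reason about model size (GRU vs LSTM, hidden sizes of 64/128/400, and turning bidirectionality off). As a quick cross-check, the sketch below counts recurrent-layer parameters using PyTorch's `nn.LSTM`/`nn.GRU` layout: per layer and direction, `weight_ih` (gates·h × in), `weight_hh` (gates·h × h), `bias_ih` and `bias_hh`. It covers only the recurrent layers (no embedding, joint network or output projection), so it will not match `total` above; `rnn_params` is a hypothetical helper name.

```python
def rnn_params(input_size: int, hidden_size: int, num_layers: int = 1,
               gates: int = 4, bidirectional: bool = False) -> int:
    """Parameter count of a stacked RNN in PyTorch's nn.LSTM / nn.GRU layout.

    gates=4 for LSTM (input, forget, cell, output); gates=3 for GRU
    (reset, update, new).
    """
    directions = 2 if bidirectional else 1
    total = 0
    layer_in = input_size
    for _ in range(num_layers):
        weights = gates * hidden_size * (layer_in + hidden_size)
        biases = 2 * gates * hidden_size  # bias_ih and bias_hh
        total += directions * (weights + biases)
        # later layers see the concatenated outputs of all directions
        layer_in = directions * hidden_size
    return total


print(rnn_params(10, 20))           # 2560, as for nn.LSTM(10, 20)
print(rnn_params(10, 20, gates=3))  # 1920, as for nn.GRU(10, 20)
```

With one layer, `bidirectional=True` exactly doubles the count; with two or more layers it more than doubles, because every layer after the first takes `2 * hidden_size` inputs.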

## Get overfit loaders

In [None]:
import itertools
from torch.utils.data.sampler import SequentialSampler
from torch.utils.data import DataLoader
from myrtlespeech.data.batch import seq_to_seq_collate_fn


# get the underlying dataset from train_loader:
dataset = train_loader.sampler.data_source
print(dataset)

# NOTE: islice takes the first NUMBER_SAMPLES *batches* (each of up to batch_size=8 utterances)
train_loader_no_shuffle = DataLoader(dataset, batch_size=8, shuffle=False, collate_fn=seq_to_seq_collate_fn)
train_loader_overfit = list(itertools.islice(train_loader_no_shuffle, NUMBER_SAMPLES))
eval_loader_overfit = train_loader_overfit  # i.e. use the same batches for eval
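`itertools.islice(train_loader_no_shuffle, NUMBER_SAMPLES)` counts *batches*, not utterances, so with `batch_size=8` the overfit set holds up to `8 * NUMBER_SAMPLES` utterances. Wrapping it in `list(...)` also caches those batches, so every epoch sees identical tensors and any randomness applied while loading is frozen too. The stdlib sketch below mimics this with a hypothetical `batched` generator standing in for the non-shuffling `DataLoader`.

```python
import itertools
from typing import Iterator, List, Sequence


def batched(items: Sequence[int], batch_size: int) -> Iterator[List[int]]:
    """Hypothetical stand-in for a non-shuffling DataLoader: yields lists of items."""
    for start in range(0, len(items), batch_size):
        yield list(items[start:start + batch_size])


NUMBER_SAMPLES = 2
utterances = list(range(100))  # pretend dataset of 100 utterances

overfit = list(itertools.islice(batched(utterances, batch_size=8), NUMBER_SAMPLES))
print(len(overfit))                  # 2 batches
print(sum(len(b) for b in overfit))  # 16 utterances
```
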


## Maybe change decoder

In [None]:
from myrtlespeech.post_process.rnn_t_beam_decoder import RNNTBeamDecoder
from myrtlespeech.post_process.rnn_t_greedy_decoder import RNNTGreedyDecoder

use_beam = False
no_max = True  # if True, raise max_symbols_per_step to 100 (effectively uncapped)

if use_beam:
    decoder = RNNTBeamDecoder(
        blank_index=28,
        beam_width=4,
        length_norm=False,
        max_symbols_per_step=4,
        model=seq_to_seq.model,
    )
else:
    decoder = RNNTGreedyDecoder(
        blank_index=28,
        max_symbols_per_step=4,
        model=seq_to_seq.model,
    )

seq_to_seq.post_process = decoder

if no_max:
    seq_to_seq.post_process.max_symbols_per_step = 100
else:
    seq_to_seq.post_process.max_symbols_per_step = 4
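For reference, a textbook greedy RNN-T decode works frame by frame: at encoder step `t` it repeatedly takes the joint network's most likely symbol, emits it and advances the prediction network, until either blank is chosen (move on to frame `t + 1`) or `max_symbols_per_step` symbols have been emitted for that frame. The sketch below shows that loop with a toy, hypothetical `joint` callable standing in for encoder + prediction + joint networks; it is not the `RNNTGreedyDecoder` implementation, only an illustration of what the per-step cap does.

```python
from typing import Callable, List, Sequence

BLANK = 28  # same blank index as above


def greedy_rnnt_decode(
    num_frames: int,
    joint: Callable[[int, Sequence[int]], int],
    max_symbols_per_step: int,
    blank: int = BLANK,
) -> List[int]:
    """Greedy RNN-T decoding with a cap on symbols emitted per encoder frame.

    `joint(t, history)` returns the argmax symbol at frame t given the
    symbols emitted so far.
    """
    hyp: List[int] = []
    for t in range(num_frames):
        for _ in range(max_symbols_per_step):
            symbol = joint(t, hyp)
            if symbol == blank:
                break  # blank: advance to the next encoder frame
            hyp.append(symbol)  # non-blank: emit and stay on this frame
    return hyp


def toy_joint(t: int, history: Sequence[int]) -> int:
    # toy model: frame 0 wants to emit 1, 2, 3 before blank; other frames are blank
    wanted = [1, 2, 3]
    if t == 0 and len(history) < len(wanted):
        return wanted[len(history)]
    return BLANK


print(greedy_rnnt_decode(3, toy_joint, max_symbols_per_step=100))  # [1, 2, 3]
print(greedy_rnnt_decode(3, toy_joint, max_symbols_per_step=2))    # [1, 2]
```

The second call shows why the cap matters: if a model wants to emit several symbols on one frame, a low `max_symbols_per_step` silently truncates the output.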

## Callbacks
* Use callbacks to inject features into the training loop.
* It is necessary (for now) to use the `RNNTTraining()` callback, but the others are optional.
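The callback mechanism can be pictured as a handler that forwards each training-loop event (`on_train_begin`, `on_batch_end`, `on_epoch_end`, ...) to every callback in order, passing shared state such as `epoch` and `reports` as keyword arguments. The sketch below is a minimal, hypothetical version of that pattern implementing only two hooks (`MiniCallback`, `MiniHandler` and `MeanLoss` are made-up names); myrtlespeech's `CallbackHandler` has its own signature, e.g. it is constructed with an `is_training` flag as at the end of this notebook.

```python
from typing import Any, Dict, List


class MiniCallback:
    """Hypothetical base class: every hook is a no-op unless overridden."""

    def on_batch_end(self, **kwargs: Any) -> None:
        pass

    def on_epoch_end(self, **kwargs: Any) -> None:
        pass


class MiniHandler:
    """Hypothetical handler: forwards each event to all callbacks in order."""

    def __init__(self, callbacks: List[MiniCallback]):
        self.callbacks = callbacks
        self.state: Dict[str, Any] = {"epoch": 0, "reports": {}}

    def on_batch_end(self, loss: float) -> None:
        self.state["last_loss"] = loss
        for cb in self.callbacks:
            cb.on_batch_end(**self.state)

    def on_epoch_end(self) -> None:
        for cb in self.callbacks:
            cb.on_epoch_end(**self.state)
        self.state["epoch"] += 1


class MeanLoss(MiniCallback):
    """Accumulates batch losses and writes their mean into the shared reports."""

    def __init__(self) -> None:
        self.losses: List[float] = []

    def on_batch_end(self, **kwargs: Any) -> None:
        self.losses.append(kwargs["last_loss"])

    def on_epoch_end(self, **kwargs: Any) -> None:
        # "reports" is the handler's own dict, so this write is visible to later callbacks
        kwargs["reports"]["MeanLoss"] = sum(self.losses) / len(self.losses)
        self.losses = []


handler = MiniHandler([MeanLoss()])
for loss in [1.0, 2.0, 3.0]:  # a fake epoch of three batches
    handler.on_batch_end(loss)
handler.on_epoch_end()
print(handler.state["reports"])  # {'MeanLoss': 2.0}
```

Because callbacks share one `reports` dict, ordering matters: a callback that reads a report (like `PrintCB` below) should come after the one that writes it.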


In [None]:
# custom callback to monitor training and print results
class PrintCB(Callback):
    """Prints eval-time loss, WER and the first (pred, exp) transcript pair."""

    def on_batch_end(self, **kwargs):
        if self.training:
            return
        # report progress every 100 eval batches
        if kwargs["epoch_batches"] % 100 == 0 and kwargs["epoch_batches"] != 0:
            print(f"{kwargs['epoch_batches']} batches completed")
            try:
                wer_reports = kwargs["reports"][seq_to_seq.post_process.__class__.__name__]
                wer = wer_reports["wer"]
                if len(wer_reports["transcripts"]) > 0:
                    pred, exp = wer_reports["transcripts"][0]  # take first element
                    pred = "".join(pred)
                    exp = "".join(exp)
                    print("batch end, pred: {}, exp: {}, wer: {:.4f}".format(pred, exp, wer))
            except KeyError:
                print("no wer - using new decoder?")

    def on_epoch_end(self, **kwargs):
        if self.training:
            return
        epoch = kwargs["epoch"]
        try:
            loss = kwargs["reports"]["ReportMeanBatchLoss"]
            wer_reports = kwargs["reports"][seq_to_seq.post_process.__class__.__name__]
            wer = wer_reports["wer"]

            out_str = "{}, loss: {:.8f}".format(epoch, loss)
            if len(wer_reports["transcripts"]) > 0:
                pred, exp = wer_reports["transcripts"][0]  # take first element
                pred = "".join(pred)
                exp = "".join(exp)
                out_str += ", wer: {:.4f}, pred: {}, exp: {},".format(wer, pred, exp)
            print(out_str)
        except KeyError:
            print("no wer - using new decoder?")

In [None]:
from myrtlespeech.run.callbacks.rnn_t_training import RNNTTraining
from myrtlespeech.run.run import ReportRNNTDecoder
from myrtlespeech.run.run import Saver  # assumed to live alongside TensorBoardLogger

rnnt_decoder_cb = ReportRNNTDecoder(
    seq_to_seq.post_process, seq_to_seq.alphabet, eval_every=50, skip_first_epoch=True
)

keys_to_log = [
    "epoch",
    f"reports/{seq_to_seq.post_process.__class__.__name__}/wer",
    "reports/ReportMeanBatchLoss",
]

callbacks = [
    RNNTTraining(),
    ReportMeanBatchLoss(),
    TensorBoardLogger(log_dir, seq_to_seq.model, histograms=True),
    MixedPrecision(seq_to_seq),
    ClipGradNorm(seq_to_seq, 400),
    rnnt_decoder_cb,
    # stop prematurely (useful for debug). Ensure the following line is commented out to perform full training
    # StopEpochAfter(epoch_batches=2),
    # logging
    CSVLogger(log_dir + "log.csv", keys=keys_to_log),
    PrintCB(),
    Saver(log_dir, seq_to_seq.model),
]
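`ClipGradNorm(seq_to_seq, 400)` presumably clips gradients by their global L2 norm with a maximum of 400 (an assumption based on the name; PyTorch's `torch.nn.utils.clip_grad_norm_` implements this technique). The idea: compute one norm over all parameters' gradients together and, if it exceeds `max_norm`, rescale every gradient by the same factor so the update direction is preserved. A pure-Python sketch on lists of floats, with a small epsilon in the denominator as PyTorch does:

```python
import math
from typing import List


def clip_grad_norm(grads: List[List[float]], max_norm: float, eps: float = 1e-6) -> float:
    """Scale all gradients in place so their global L2 norm is at most max_norm.

    Returns the total norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= clip_coef  # same factor everywhere: direction unchanged
    return total_norm


grads = [[3.0], [4.0]]  # two "parameters"; global norm is 5
norm = clip_grad_norm(grads, max_norm=1.0)
print(norm, grads)  # 5.0 and roughly [[0.6], [0.8]]
```

Whether a threshold of 400 ever triggers depends on the gradient scale of the RNN-T loss in this setup, which the log does not record.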


In [None]:
change_lr = False

for param_group in seq_to_seq.optim.param_groups:
    print("current lr: ", param_group['lr'])

new_lr = 0.0001
if change_lr:
    for param_group in seq_to_seq.optim.param_groups:
        param_group['lr'] = new_lr

for param_group in seq_to_seq.optim.param_groups:
    print("new lr: ", param_group['lr'])
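The log at the top of this notebook notes that the earlier "lr changes" never took effect. It doesn't say why, but one common way this happens is changing a Python variable or config value after the optimizer has been built: PyTorch optimizers copy `lr` into each dict in `optimizer.param_groups` at construction, so only mutating those dicts (as this cell does) changes the step size. `ToyOptimizer` below is a hypothetical stand-in that mimics just that `param_groups` behaviour, so the pitfall can be shown without torch.

```python
from typing import Any, Dict, List


class ToyOptimizer:
    """Hypothetical stand-in: like torch.optim optimizers, it copies lr into param_groups."""

    def __init__(self, params: List[float], lr: float):
        self.param_groups: List[Dict[str, Any]] = [{"params": params, "lr": lr}]

    def step(self, grads: List[float]) -> None:
        for group in self.param_groups:
            for i, g in enumerate(grads):
                group["params"][i] -= group["lr"] * g  # plain SGD update


lr = 0.0003
opt = ToyOptimizer([1.0], lr=lr)

lr = 0.0001  # rebinding the variable: the optimizer never sees this
print(opt.param_groups[0]["lr"])  # still 0.0003

for group in opt.param_groups:  # mutating the groups is what actually takes effect
    group["lr"] = 0.0001
print(opt.param_groups[0]["lr"])  # 0.0001
```
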

In [None]:
fit(
    seq_to_seq, 
    epochs=40,
    train_loader=train_loader_overfit, 
    eval_loader=eval_loader_overfit,
    callbacks=callbacks,
)



## Run eval?


In [None]:
run_eval = True
eval_cbs = None
if run_eval:
    eval_cbs = [RNNTTraining(),
            ReportMeanBatchLoss(), 
            ReportRNNTDecoder(seq_to_seq.post_process, seq_to_seq.alphabet),
            CSVLogger(log_dir + "log_eval.csv", keys=keys_to_log),
            TensorBoardLogger(log_dir, seq_to_seq.model),
            PrintCB(),] 
    cb_handler = CallbackHandler(eval_cbs, False)
    cb_handler.on_train_begin(epochs=2)
    
    run_stage(seq_to_seq, cb_handler, eval_loader, is_training=False)

In [None]:
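from typing import Collection, Optional

from myrtlespeech.model.seq_to_seq import SeqToSeq  # assumed module path for SeqToSeq


def evaluate(
    seq_to_seq: SeqToSeq,
    eval_loader: DataLoader,
    callbacks: Optional[Collection[Callback]] = None,
) -> None:
    """Run a single evaluation pass over eval_loader with the given callbacks."""
    is_training = False
    cb_handler = CallbackHandler(callbacks, is_training)
    cb_handler.on_train_begin(epochs=1)

    run_stage(seq_to_seq, cb_handler, eval_loader, is_training=is_training)

In [None]:
evaluate(
    seq_to_seq,
    eval_loader=eval_loader_overfit,
    callbacks=callbacks,
)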