# RNN-T Overfitting
Validate that all elements of the pipeline are working by overfitting to a small number of training examples.

In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up live without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import pathlib
import typing
import time

import torch
from google.protobuf import text_format

from myrtlespeech.model.rnn_t import RNNT
from myrtlespeech.run.callbacks.csv_logger import CSVLogger
from myrtlespeech.run.callbacks.callback import Callback, ModelCallback
from myrtlespeech.run.callbacks.clip_grad_norm import ClipGradNorm
from myrtlespeech.run.callbacks.report_mean_batch_loss import ReportMeanBatchLoss
from myrtlespeech.run.callbacks.stop_epoch_after import StopEpochAfter
from myrtlespeech.run.callbacks.mixed_precision import MixedPrecision
from myrtlespeech.post_process.utils import levenshtein
from myrtlespeech.builders.task_config import build
from myrtlespeech.run.train import fit
from myrtlespeech.protos import task_config_pb2
from myrtlespeech.run.stage import Stage

In [3]:
# Switch the cuDNN benchmark flag off for this session.
# NOTE(review): presumably done for run-to-run consistency while debugging
# the overfitting experiment -- confirm.
torch.backends.cudnn.benchmark = False

In [4]:
# Restrict this process to the GPU with index 1 via CUDA_VISIBLE_DEVICES.
# NOTE(review): the index is hard-coded for one particular machine.
# NOTE(review): this runs after `import torch`; presumably it still takes
# effect because CUDA has not been initialised at this point -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

Build the RNNT model defined in the config file:

In [5]:
# Read the example task config (protobuf text format) from disk and merge it
# into a fresh TaskConfig message; the message is displayed as the cell output.
config_path = pathlib.Path("../src/myrtlespeech/configs/rnn_t_en.config")
task_config = text_format.Merge(config_path.read_text(), task_config_pb2.TaskConfig())

task_config

speech_to_text {
  alphabet: " abcdefghijklmnopqrstuvwxyz\'_"
  pre_process_step {
    stage: TRAIN_AND_EVAL
    lmfb {
      n_mels: 80
      win_length: 400
      hop_length: 160
    }
  }
  pre_process_step {
    stage: TRAIN_AND_EVAL
    standardize {
    }
  }
  pre_process_step {
    stage: TRAIN_AND_EVAL
    left_context_frames {
      n_context: 3
      subsample: 3
    }
  }
  rnn_t {
    transcription {
      n_hidden: 1152
      rnn_layers: 2
    }
    prediction {
      n_hidden: 256
      rnn_layers: 2
    }
    joint {
      n_hidden: 512
    }
  }
  rnn_t_loss {
    blank_index: 28
    reduction: SUM
  }
  rnn_t_greedy_decoder {
    blank_index: 28
    max_symbols_per_step: 30
  }
}
train_config {
  batch_size: 8
  epochs: 40
  adam {
    learning_rate: 0.0003000000142492354
  }
  dataset {
    librispeech {
      root: "/data/"
      subset: DEV_CLEAN
      max_secs {
        value: 16.700000762939453
      }
    }
  }
  shuffle_batches_before_every_epoch: true
}
eval_c

In [6]:
# Create all components described by the task config: the model wrapper,
# an epoch count, and the train / eval data loaders.
# FYI: if using train-clean-100 & dev-clean this cell takes O(60s)
# NOTE(review): `epochs` is not used by the later `fit` call, which passes
# a hard-coded epochs=3000 instead.
seq_to_seq, epochs, train_loader, eval_loader = build(task_config)
seq_to_seq

SpeechToText(
  (alphabet): Alphabet(symbols=[' ', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', "'", '_'])
  (model): RNNT(
    (encoder): Sequential(
      (0): Linear(in_features=320, out_features=1152, bias=True)
      (1): Hardtanh(min_val=0.0, max_val=20.0)
      (2): Dropout(p=0.25, inplace=False)
      (3): Linear(in_features=1152, out_features=1152, bias=True)
      (4): Hardtanh(min_val=0.0, max_val=20.0)
      (5): Dropout(p=0.25, inplace=False)
      (6): BNRNNSum(
        (layers): ModuleList(
          (0): RNNLayer(
            (rnn): LSTM(1152, 1152)
          )
          (1): Dropout(p=0.25, inplace=False)
          (2): RNNLayer(
            (rnn): LSTM(1152, 1152)
          )
        )
      )
      (7): Lambda(lambda_fn=Access RNN output)
      (8): Linear(in_features=1152, out_features=1152, bias=True)
      (9): Hardtanh(min_val=0.0, max_val=20.0)
      (10): Dropout(p=0.25, inplac

In [7]:
# Total the element counts of every named parameter in the model.
n_params = 0
for _name, param in seq_to_seq.model.named_parameters():
    n_params += param.numel()
print("number of params:", n_params)

number of params: 26337181


In [8]:
#custom callback to monitor training and print results
class PrintCB(Callback):
    """Print a one-line evaluation summary at a fixed epoch interval.

    At the end of every non-training epoch whose index is a multiple of
    ``every``, the callback prints the first (prediction, expected)
    transcript pair together with the mean batch loss and the WER, all read
    from ``kwargs["reports"]``.

    The decoder report is looked up under the class name of the
    notebook-level ``seq_to_seq.post_process`` object, so ``seq_to_seq`` must
    be defined before the callback runs.

    Args:
        every: Report on epochs where ``epoch % every == 0``. Defaults to 10.

    Raises:
        ValueError: If ``every`` is less than 1.
    """

    def __init__(self, every=10):
        super().__init__()
        if every < 1:
            raise ValueError("every must be >= 1, got {!r}".format(every))
        self.every = every

    def on_epoch_end(self, **kwargs):
        """Print the summary line, or a short notice if a report is missing."""
        # `self.training` is presumably maintained by the Callback base class
        # -- TODO confirm. Only evaluation epochs are reported.
        if self.training:
            return
        epoch = kwargs["epoch"]
        if epoch % self.every != 0:
            return
        try:
            reports = kwargs["reports"]
            wer_reports = reports[seq_to_seq.post_process.__class__.__name__]
            wer = wer_reports["wer"]
            # Only the first transcript pair is shown to keep the output short.
            pred, exp = wer_reports["transcripts"][0]
            loss = reports["ReportMeanBatchLoss"]
        except (KeyError, IndexError) as err:
            # Name the missing key / empty list so the cause is visible
            # instead of being silently folded into a generic message.
            print("no wer - using new decoder? ({}: {})".format(type(err).__name__, err))
            return
        print(
            "{}, pred: {}, exp: {}, loss {:.8f}, wer: {:.4f}".format(
                epoch, "".join(pred), "".join(exp), loss, wer
            )
        )
        

In [9]:
# Root directory for run logs. The original location stays the default, but
# it can be overridden with the RNNT_LOG_ROOT environment variable so the
# notebook is not tied to one user's home directory.
log_root = os.environ.get("RNNT_LOG_ROOT", "/home/samgd/logs/rnnt/")
# Each run gets its own sub-directory named after the current Unix timestamp.
log_dir = os.path.join(log_root, str(time.time()))

In [10]:
class SkipEval(Callback):
    """Cut non-training epochs short while ``epoch <= up_to``.

    After each batch of a non-training epoch whose index is at most
    ``up_to``, ``on_batch_end`` returns ``{"stop_epoch": True}``.
    NOTE(review): presumably the training loop treats that dict as a request
    to end the current epoch -- confirm against ``fit``.

    Args:
        up_to: Last epoch index (inclusive) for which evaluation is skipped.
        verbose: When True (the default, matching the previous behaviour)
            print ``"stop epoch"`` each time an epoch is cut short. Pass
            False to keep long runs from flooding the cell output.
    """

    def __init__(self, up_to=0, verbose=True):
        super().__init__()
        self.up_to = up_to
        self.verbose = verbose

    def on_batch_end(self, *args, **kwargs):
        """Return the stop request for early evaluation epochs, else None."""
        epoch = kwargs["epoch"]
        # `self.training` is presumably maintained by the Callback base class
        # -- TODO confirm. Training epochs are never interrupted.
        if self.training or epoch > self.up_to:
            return None
        if self.verbose:
            print("stop epoch")
        return {"stop_epoch": True}

In [11]:
from myrtlespeech.run.callbacks.rnn_t_training import RNNTTraining
from myrtlespeech.run.run import ReportRNNTDecoder
from myrtlespeech.run.run import ReportCTCDecoder
from myrtlespeech.run.run import Saver
from myrtlespeech.run.run import TensorBoardLogger
from myrtlespeech.run.run import WordSegmentor


# Arguments for the CSV logger, pulled out so the callback list stays compact.
csv_log_path = f"{log_dir}/log.csv"
csv_exclude = [
    "epochs",
    # "reports/CTCGreedyDecoder/transcripts",
]

# Callbacks are constructed in list order. The optimisation-level summary in
# this cell's output is printed while the list is being built.
callbacks = [
    RNNTTraining(),
    ReportMeanBatchLoss(),
    ReportRNNTDecoder(seq_to_seq.post_process, seq_to_seq.alphabet),
    TensorBoardLogger(log_dir, seq_to_seq.model, histograms=False),
    MixedPrecision(seq_to_seq, opt_level="O1"),
    CSVLogger(csv_log_path, exclude=csv_exclude),
    SkipEval(up_to=100),
    # StopEpochAfter(epoch_batches=1),
]

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


In [None]:
# Note the SkipEval callback skips computing the WER as it can
# take a while whilst the predictions are poor
# NOTE(review): epochs is hard-coded to 3000 here rather than using the
# `epochs` value returned by build(task_config).
fit(
    seq_to_seq, 
    epochs=3000,
    train_loader=train_loader, 
    eval_loader=eval_loader,
    callbacks=callbacks,
)


stop epoch
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 8192.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 4096.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 2048.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 1024.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 512.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 256.0
stop epoch
stop epoch
stop epoch
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 128.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 64.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32.0
stop epoch
stop epoch
stop epoch
stop epoch
stop epoch
stop epoch
stop epoch
stop epoch
stop