# API Development

This notebook is a scratchpad for developing the API before the code is moved into well-tested, documented Python files and any new stable models are moved into Protobuf configurations.

It is currently being used to bring up the RNNT model.

In [None]:
# IPython magics: enable the autoreload extension in mode 2, which reloads
# imported modules before executing each cell so source edits are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import pathlib
import typing

import torch
from google.protobuf import text_format

from myrtlespeech.model.rnnt import RNNTEncoder, RNNT
from myrtlespeech.run.callbacks.csv_logger import CSVLogger
from myrtlespeech.run.callbacks.callback import Callback, ModelCallback
from myrtlespeech.run.callbacks.clip_grad_norm import ClipGradNorm
from myrtlespeech.run.callbacks.report_mean_batch_loss import ReportMeanBatchLoss
from myrtlespeech.run.callbacks.stop_epoch_after import StopEpochAfter
from myrtlespeech.run.callbacks.mixed_precision import MixedPrecision
from myrtlespeech.post_process.utils import levenshtein
from myrtlespeech.builders.task_config import build
from myrtlespeech.run.train import fit
from myrtlespeech.protos import task_config_pb2
from myrtlespeech.run.stage import Stage

In [None]:
# Turn off the cuDNN auto-tuner, which otherwise benchmarks several
# algorithms for each new input configuration and picks the fastest one.
torch.backends.cudnn.benchmark = False

Build the RNNT model defined in the config file:

In [None]:
# Read the text-format RNNT example config, then parse it into a TaskConfig
# once the file has been closed.
with open("../src/myrtlespeech/configs/rnnt_en.config") as f:
    config_text = f.read()

task_config = text_format.Merge(config_text, task_config_pb2.TaskConfig())

# Bare expression so the notebook shows the parsed config as the cell output.
task_config

In [None]:
# create all components for config
# build() is unpacked into four values; going by the names used here these
# are presumably the seq-to-seq task object, the epoch count and the
# train/eval data loaders — TODO confirm against the builder.
seq_to_seq, epochs, train_loader, eval_loader = build(task_config)


In [None]:
import math
from typing import Tuple

from myrtlespeech.builders.fully_connected import build as build_fully_connected
from myrtlespeech.builders.rnn import build as build_rnn
from myrtlespeech.protos import rnn_t_pb2
from myrtlespeech.protos import rnn_t_encoder_pb2
from myrtlespeech.data.stack import StackTime
from myrtlespeech.model.utils import Lambda
from torch import nn

"""

oneof time_reduction {
bool time_reduction_NULL = 2;
uint32 time_reduction_factor = 3;
}
  
  
  """



# rnn_t {
#     rnn_t_encoder {
#       rnn1 {
#         rnn_type: LSTM;
#         hidden_size: 640;
#         num_layers: 3;
#         bias: true;
#         bidirectional: false;
#       }
# }
    
def build_rnn_t(
    rnn_t_cfg: rnn_t_pb2.RNNT,
    input_features: int,
    vocab_size: int
) -> RNNT:
    """Assemble an ``RNNT`` model from its Protobuf configuration.

    Args:
        rnn_t_cfg: RNNT config. Its ``rnn_t_encoder``, ``dec_rnn`` and
            ``fully_connected`` fields are read.
        input_features: Feature size handed to the encoder builder.
        vocab_size: Number of rows in the embedding table. The fully
            connected network is built with ``vocab_size + 1`` outputs
            (presumably one extra for a blank symbol — TODO confirm).

    Returns:
        ``RNNT`` constructed from the encoder, the embedding, the decoder
        RNN and the fully connected network, in that order.
    """
    # Encoder; also yields the feature size of its output.
    encoder, enc_features = build_rnnt_enc(
        rnn_t_cfg.rnn_t_encoder, input_features
    )

    # Decoder / prediction network: the embedding width is taken from the
    # decoder RNN's hidden size.
    pred_cfg = rnn_t_cfg.dec_rnn
    embedding = nn.Embedding(vocab_size, pred_cfg.hidden_size)
    # NOTE(review): the decoder RNN is built with ``vocab_size`` input
    # features while the embedding emits ``hidden_size`` features; verify
    # against how ``RNNT`` wires the embedding into ``dec_rnn``.
    dec_rnn, pred_features = build_rnn(pred_cfg, vocab_size)

    # Joint network: encoder and prediction features are concatenated, so
    # its input size is the sum of the two.
    fully_connected = build_fully_connected(
        rnn_t_cfg.fully_connected,
        input_features=enc_features + pred_features,
        output_features=vocab_size + 1,
    )

    return RNNT(encoder, embedding, dec_rnn, fully_connected)
    

def build_rnnt_enc(
    rnn_t_enc: rnn_t_encoder_pb2.RNNTEncoder,
    input_features: int,
) -> Tuple[RNNTEncoder, int]:
    """Build an ``RNNTEncoder`` from its Protobuf configuration.

    ``rnn1`` is always built. When ``time_reduction_factor`` is set, a time
    reduction step (``StackTime`` wrapped in ``Lambda``) and a second RNN,
    ``rnn2``, are added after it.

    Args:
        rnn_t_enc: Encoder config. Its ``rnn1``, ``rnn2`` and
            ``time_reduction_factor`` fields are read.
        input_features: Feature size handed to the ``rnn1`` builder.

    Returns:
        A tuple ``(encoder, encoder_out_features)`` where
        ``encoder_out_features`` is the output feature size reported by the
        builder of the last RNN in the encoder.

    Raises:
        AssertionError: If ``rnn2`` is set while ``time_reduction_factor``
            is not, or if ``time_reduction_factor`` is set but not > 1.
    """
    rnn1, rnn1_out_features = build_rnn(rnn_t_enc.rnn1, input_features)

    time_reduction_factor = rnn_t_enc.time_reduction_factor

    if time_reduction_factor == 0:  # default value (i.e. not set)
        assert not rnn_t_enc.HasField("rnn2"), (
            "rnn2 must not be set when time_reduction_factor is not set"
        )
        return RNNTEncoder(rnn1), rnn1_out_features

    # The message must be an f-string so that the offending value is
    # actually interpolated into it.
    assert time_reduction_factor > 1, (
        "time_reduction_factor must be an integer > 1 but is = "
        f"{time_reduction_factor}"
    )

    time_reducer = Lambda(StackTime(time_reduction_factor))

    # rnn2 is built with rnn1's output size multiplied by the reduction
    # factor (presumably because StackTime concatenates that many frames
    # along the feature axis — TODO confirm).
    rnn2, encoder_out_features = build_rnn(
        rnn_t_enc.rnn2,
        rnn1_out_features * time_reduction_factor,
    )
    encoder = RNNTEncoder(rnn1, time_reducer, time_reduction_factor, rnn2)

    return encoder, encoder_out_features


# Try out the builders on the parsed config.
stt = task_config.speech_to_text
# NOTE(review): build_rnnt_enc returns an (encoder, out_features) tuple, so
# ``encoder`` is bound to the whole tuple here, not just the module. The
# input feature size (16) also differs from the 80 used for the full model.
encoder = build_rnnt_enc(stt.rnn_t.rnn_t_encoder, 16)

encoder

# Full RNNT with 80 input features and a vocabulary of 28 symbols.
rnnt = build_rnn_t(stt.rnn_t, 80, 28)

rnnt

In [None]:

from myrtlespeech.model.seq_len_wrapper import SeqLenWrapper
from myrtlespeech.model.utils import Lambda
from myrtlespeech.protos import conv_layer_pb2
from myrtlespeech.protos import rnn_t_pb2

# Parameters required for the whole RNNT:
hidden_size = 640  # RNN hidden size
projection_size = 640  # LSTM projection size
num_enc_layers = 8  # number of layers in the encoder
num_pred_layers = 2  # number of layers in the prediction network
# NOTE(review): presumably 80 features stacked over 4 frames — TODO confirm.
input_features = 80 * 4
vocab_size = 28

# Planned builder hierarchy:
# build_rnnt
#     (should check that dims match up)
#     build_rnnt_enc
#     build_rnnt_pred
#     build_rnnt_joint



In [None]:
# Build a tiny LSTM to inspect the default value of ``batch_first``.
rnn = torch.nn.LSTM(input_size=2, hidden_size=4)
rnn.batch_first

In [None]:
# NOTE(review): build_ds2 is neither defined nor imported in the visible
# notebook, and at this point task_config was parsed from rnnt_en.config
# (the DS2 config is only parsed in a later cell) — confirm the intended
# cell order and the missing import.
ds2_symbols = "_abcdefghijklmnopqrstuvwxyz '"
ds2 = build_ds2(
    ds2_cfg=task_config.speech_to_text.deep_speech_2,
    input_features=83,
    output_features=len(ds2_symbols),  # one output per symbol
    input_channels=1,
)

In [None]:
ds2

Now build all the other components using the generic build function and monkey-patch in the DS2 model:

In [None]:
# Read the text-format Deep Speech 2 example config, then parse it into a
# TaskConfig once the file has been closed.
with open("../src/myrtlespeech/configs/deep_speech_2_en.config") as f:
    config_text = f.read()

task_config = text_format.Merge(config_text, task_config_pb2.TaskConfig())

In [None]:
# create all components for config
# Rebinds seq_to_seq, epochs and both loaders, this time from the Deep
# Speech 2 task config parsed in the previous cell.
seq_to_seq, epochs, train_loader, eval_loader = build(task_config)

In [None]:
# Monkey-patch: swap the model created by build() for the separately built
# DS2 model, after moving it to the GPU.
seq_to_seq.model = ds2.cuda()

In [None]:
# Replace the optimizer as well, so that it updates the parameters of the
# patched-in DS2 model.
seq_to_seq.optim = torch.optim.Adam(ds2.parameters(), lr=1e-3)

Define some useful callbacks:

In [None]:
from typing import List

class WordSegmentor:
    """Group a sequence of symbols into words split on a separator symbol.

    Args:
        separator: Symbol that marks a word boundary. It never appears in
            the output.
    """

    def __init__(self, separator: str):
        self.separator = separator

    def __call__(self, sentence: List[str]) -> List[str]:
        """Return the words in ``sentence``.

        Consecutive non-separator symbols are concatenated into one word.
        Leading, trailing and repeated separators produce no empty words.
        """
        words = []
        pending = []
        # A sentinel separator appended to the input flushes the last word
        # through the same code path as every other word.
        for symbol in [*sentence, self.separator]:
            if symbol != self.separator:
                pending.append(symbol)
                continue
            if pending:
                words.append("".join(pending))
                pending = []
        return words

In [None]:
# Decoder instances for the reporting callbacks; both use index 0 as blank.
# NOTE(review): CTCGreedyDecoder and CTCBeamDecoder are not imported in the
# visible notebook — confirm that an import cell is not missing.
ctc_greedy = CTCGreedyDecoder(blank_index=0)
ctc_beam = CTCBeamDecoder(blank_index=0, beam_width=12)

class ReportCTCDecoder(Callback):
    """Report transcripts and word error rate (WER) for a CTC decoder.

    While not training, each batch's output is decoded with ``ctc_decoder``
    and both the decoded output and the target are converted to lists of
    words. The ``(actual, expected)`` pairs and the resulting WER are stored
    in ``kwargs["reports"]`` under the decoder's class name.

    NOTE(review): ``self.training`` is read but never set in this class;
    presumably the ``Callback`` base class maintains it — confirm.

    Args:
        ctc_decoder: decodes output to sequence of indices based on CTC

        alphabet: converts sequences of indices to sequences of symbols (strs)

        word_segmentor: groups sequences of symbols into sequences of words
    """
    def __init__(self, ctc_decoder, alphabet, word_segmentor):
        self.ctc_decoder = ctc_decoder
        self.alphabet = alphabet
        self.word_segmentor = word_segmentor

    def _report(self, **kwargs):
        """Return this decoder's entry in ``kwargs["reports"]``."""
        return kwargs["reports"][self.ctc_decoder.__class__.__name__]

    def _reset(self, **kwargs):
        """Install a fresh report entry and clear the accumulated stats."""
        kwargs["reports"][self.ctc_decoder.__class__.__name__] = {
            "wer": -1.0,  # sentinel until a WER has been computed
            "transcripts": []
        }
        self.distances = []
        self.lengths = []

    def on_train_begin(self, **kwargs):
        self._reset(**kwargs)

    def on_epoch_begin(self, **kwargs):
        self._reset(**kwargs)

    def _process(self, sentence: List[int]) -> List[str]:
        """Convert a sequence of indices to a list of words."""
        symbols = self.alphabet.get_symbols(sentence)
        return self.word_segmentor(symbols)

    def on_batch_end(self, **kwargs):
        """Decode the batch and accumulate word-level edit distances."""
        if self.training:
            return
        transcripts = self._report(**kwargs)["transcripts"]

        # ``last_target`` is indexed as (targets, target lengths).
        targets = kwargs["last_target"][0]
        target_lens = kwargs["last_target"][1]

        acts = self.ctc_decoder(*kwargs["last_output"])
        for act, target, target_len in zip(acts, targets, target_lens):
            act = self._process(act)
            # Only the first ``target_len`` entries of the target are used.
            exp = self._process([int(e) for e in target[:target_len]])

            transcripts.append((act, exp))

            distance = levenshtein(act, exp)
            self.distances.append(distance)
            self.lengths.append(len(exp))

    def on_epoch_end(self, **kwargs):
        """Store the WER (in percent) for everything seen this epoch."""
        if self.training:
            return
        total_length = sum(self.lengths)
        if total_length == 0:
            # No target words were seen (no eval batches, or only empty
            # targets): keep the -1.0 sentinel instead of dividing by zero.
            return
        wer = float(sum(self.distances)) / total_length * 100
        self._report(**kwargs)["wer"] = wer

In [None]:
class Foo(Callback):
    """Print the greedy CTC decoder's transcripts at the end of each epoch."""

    def on_epoch_end(self, **kwargs):
        """Clear the cell output, then print every (actual, expected) pair.

        The pairs are read from
        ``kwargs["reports"]["CTCGreedyDecoder"]["transcripts"]``.
        """
        from IPython.display import clear_output

        clear_output()
        transcripts = kwargs["reports"]["CTCGreedyDecoder"]["transcripts"]
        for actual, expected in transcripts:
            print(actual, expected)
        print("\n\n\n")

In [None]:
import time

from torch.utils.tensorboard import SummaryWriter

class TensorBoardLogger(ModelCallback):
    """Log the training loss, and optionally histograms, to TensorBoard.

    NOTE(review): ``self.model`` and ``self.training`` are read but not set
    in this class; presumably ``ModelCallback`` provides them — confirm.

    Args:
        model: Model forwarded to ``ModelCallback``; its named parameters
            are logged when ``histograms`` is enabled.
        histograms: If True, also log a histogram of every parameter and of
            every available gradient on each training batch.
        log_dir: Directory the ``SummaryWriter`` writes to. Defaults to a
            fresh timestamped directory under ``/tmp/writer``.
    """

    def __init__(self, model, histograms=False, log_dir=None):
        super().__init__(model)
        if log_dir is None:
            # Timestamped so that separate runs do not share a directory.
            log_dir = f'/tmp/writer/{time.time()}'
        self.writer = SummaryWriter(log_dir=log_dir)
        self.histograms = histograms

    def on_backward_begin(self, **kwargs):
        """Log the loss of the current training batch."""
        if not self.training:
            return
        # Only reached while training, so the tag is always "train/loss".
        self.writer.add_scalar(
            "train/loss",
            kwargs["last_loss"].item(),
            global_step=kwargs["total_train_batches"]
        )

    def on_step_end(self, **kwargs):
        """Log gradient histograms for parameters that have a gradient."""
        if not self.training or not self.histograms:
            return
        for name, param in self.model.named_parameters():
            if param.grad is None:
                continue
            self.writer.add_histogram(
                name.replace(".", "/") + "/grad",
                param.grad,
                global_step=kwargs["total_train_batches"]
            )

    def on_batch_end(self, **kwargs):
        """Log a histogram of every parameter's values."""
        if not self.training or not self.histograms:
            return
        for name, param in self.model.named_parameters():
            self.writer.add_histogram(
                name.replace(".", "/"),
                param,
                global_step=kwargs["total_train_batches"]
            )

    def on_train_end(self, **kwargs):
        """Close the writer so pending events are written out."""
        self.writer.close()

Compute an estimate of the mean (TODO: improve interface!):

In [None]:
# Pick the standardization step out of the pre-processing pipeline.
# NOTE(review): the [2][0] indexing hard-codes the step's position in
# ``pre_process_steps`` — confirm it matches the config in use.
standardize = seq_to_seq.pre_process_steps[2][0]

# Run pre-processing over the first 10001 training examples (idx 0..10000);
# presumably this lets ``standardize`` accumulate its statistics while it is
# still in training mode — TODO confirm.
for idx, x in enumerate(train_loader.dataset):
    if idx > 10000:
        break
    seq_to_seq.pre_process(x[0][0])

# Take the step out of training mode (presumably freezing the estimate).
standardize.training = False

Train the model using the fit function:

In [None]:
# train the model
# NOTE: the epoch count is hard-coded to 1000 here instead of using the
# ``epochs`` value returned by build().
fit(
    seq_to_seq, 
    1000,#epochs, 
    train_loader=train_loader, 
    eval_loader=eval_loader,
    callbacks=[
        ReportMeanBatchLoss(),
        # Greedy-decoded transcripts and WER, with words split on spaces.
        ReportCTCDecoder(
            ctc_greedy, 
            seq_to_seq.alphabet,
            WordSegmentor(" "),
        ),
        # Loss only; parameter/gradient histograms are disabled.
        TensorBoardLogger(seq_to_seq.model, histograms=False),
        MixedPrecision(seq_to_seq, opt_level="O1"),
        ClipGradNorm(seq_to_seq, max_norm=400),
        #StopEpochAfter(epoch_batches=30),
        # Write the reports to CSV, leaving out the "epochs" entry.
        CSVLogger("/tmp/foo_0.csv", 
            exclude=[
                "epochs", 
                #"reports/CTCGreedyDecoder/transcripts",
            ]
        )
    ],
)