# API Development

This notebook contains code to run a model using the current API. It exists as a playground for developing the API.

In [None]:
# Enable IPython's autoreload extension in mode 2, so that edited modules are
# re-imported automatically before code in later cells is executed.
%load_ext autoreload
%autoreload 2

In [None]:
import pathlib
import typing

import torch
from google.protobuf import text_format

from myrtlespeech.model.speech_to_text import SpeechToText
from myrtlespeech.run.callbacks.csv_logger import CSVLogger
from myrtlespeech.run.callbacks.callback import Callback, ModelCallback
from myrtlespeech.run.callbacks.report_mean_batch_loss import ReportMeanBatchLoss
from myrtlespeech.run.callbacks.stop_epoch_after import StopEpochAfter
from myrtlespeech.post_process.utils import levenshtein
from myrtlespeech.post_process.ctc_greedy_decoder import CTCGreedyDecoder
from myrtlespeech.post_process.ctc_beam_decoder import CTCBeamDecoder
from myrtlespeech.builders.task_config import build
from myrtlespeech.run.train import fit
from myrtlespeech.protos import task_config_pb2
from myrtlespeech.run.stage import Stage

In [None]:
# Read the example task configuration (protobuf text format) from disk and
# merge it into a fresh TaskConfig message.
config_path = pathlib.Path(
    "../src/myrtlespeech/configs/2-block-vgg_5-bidir-lstm_ctc.config"
)
task_config = task_config_pb2.TaskConfig()
text_format.Merge(config_path.read_text(), task_config)

In [None]:
# Build everything described by the task config; `build` returns five
# components, unpacked here in the order it returns them.
model, epochs, optim, train_loader, eval_loader = build(task_config)

In [None]:
# Instantiate the greedy and the beam (width 12) CTC decoders; both are
# configured with index 0 as the blank symbol.
_blank_index = 0
ctc_greedy = CTCGreedyDecoder(blank_index=_blank_index)
ctc_beam = CTCBeamDecoder(blank_index=_blank_index, beam_width=12)

class ReportCTCDecoderWER(Callback):
    """Report an edit-distance error rate for the output of a CTC decoder.

    After every batch the decoder is applied to ``kwargs["last_output"]`` and
    each decoded sequence is compared with its target via ``levenshtein``.
    At the end of every epoch the accumulated distance, expressed as a
    percentage of the accumulated target length, is stored under
    ``kwargs["reports"][<decoder class name>]["wer"]``.

    Args:
        ctc_decoder: Callable invoked as
            ``ctc_decoder(*kwargs["last_output"])``. Its result is iterated
            in lockstep with the batch targets.
        separator_index: Optional symbol index. When given, both the decoded
            and the target sequence are split at this index with
            ``word_chunk`` before the distance is computed, so distance and
            length are counted in chunks rather than single symbols. When
            ``None`` the sequences are compared symbol by symbol.
    """

    def __init__(self, ctc_decoder, separator_index=None):
        # NOTE(review): ``model`` is not a parameter of this method; the name
        # resolves to the notebook-level ``model`` variable, so the cell that
        # defines it must have been run first. Confirm that
        # ``Callback.__init__`` actually accepts this argument.
        super().__init__(model)
        self.ctc_decoder = ctc_decoder
        self.separator_index = separator_index

    def _reset(self, **kwargs):
        """Install a placeholder report entry and clear the accumulators."""
        # 100.0 is the value reported until ``on_epoch_end`` overwrites it.
        kwargs["reports"][self.ctc_decoder.__class__.__name__] = {"wer": 100.0}
        self.distances = []
        self.lengths = []

    def on_train_begin(self, **kwargs):
        """Reset the report entry and accumulators when training starts."""
        self._reset(**kwargs)

    def on_epoch_begin(self, **kwargs):
        """Reset the report entry and accumulators at the start of an epoch."""
        self._reset(**kwargs)

    def on_batch_end(self, **kwargs):
        """Decode the batch output and accumulate distances and lengths."""
        targets = kwargs["last_target"][0]
        target_lens = kwargs["last_target"][1]

        acts = self.ctc_decoder(*kwargs["last_output"])
        for act, target, target_len in zip(acts, targets, target_lens):
            # Only the first ``target_len`` entries of ``target`` are used.
            exp = [int(e) for e in target[:target_len]]
            if self.separator_index is not None:
                act = word_chunk(act, self.separator_index)
                exp = word_chunk(exp, self.separator_index)
            self.distances.append(levenshtein(act, exp))
            self.lengths.append(len(exp))

    def on_epoch_end(self, **kwargs):
        """Write the epoch's error rate (as a percentage) into the reports."""
        total_length = sum(self.lengths)
        if total_length == 0:
            # No reference tokens were accumulated this epoch (no batches, or
            # only empty targets), so the rate is undefined. Keep the
            # placeholder written by ``_reset`` instead of dividing by zero.
            return
        wer = float(sum(self.distances)) / total_length * 100
        kwargs["reports"][self.ctc_decoder.__class__.__name__]["wer"] = wer
        
def word_chunk(seq, sep_idx):
    """Split ``seq`` into lists of consecutive items that lie between separators.

    Items equal to ``sep_idx`` act as delimiters and are dropped. Runs of
    adjacent separators, and separators at either end, produce no empty
    chunks.

    Args:
        seq: Iterable of items to split.
        sep_idx: Value that marks a boundary between chunks.

    Returns:
        A list of non-empty lists, in their original order.
    """
    chunks = []
    current = []
    for item in seq:
        if item == sep_idx:
            # A separator closes the chunk in progress, if there is one.
            if current:
                chunks.append(current)
                current = []
        else:
            current.append(item)
    if current:
        chunks.append(current)
    return chunks

In [None]:
# Train the model.
# NOTE: the epoch count is hard-coded to 10000 here; the `epochs` value
# returned by `build` is not used. The beam decoder (`ctc_beam`) is not
# reported on either; only the greedy decoder gets a WER callback.
space_index = model.alphabet.get_index(" ")
training_callbacks = [
    ReportMeanBatchLoss(),
    ReportCTCDecoderWER(ctc_greedy, space_index),
    StopEpochAfter(epoch_batches=1),
    CSVLogger("/tmp/foo.csv", exclude=["epochs"]),
]
fit(
    model,
    10000,
    optim,
    train_loader=train_loader,
    eval_loader=eval_loader,
    callbacks=training_callbacks,
)