# API Development

This notebook contains code to run a model using the current API. It exists as a playground for developing the API.

In [None]:
# Load IPython's autoreload extension and set it to mode 2 so that edits to
# imported modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pathlib
import typing

import torch
from google.protobuf import text_format

from myrtlespeech.model.speech_to_text import SpeechToText
from myrtlespeech.run.callbacks.csv_logger import CSVLogger
from myrtlespeech.run.callbacks.callback import Callback, ModelCallback
from myrtlespeech.run.callbacks.report_mean_batch_loss import ReportMeanBatchLoss
from myrtlespeech.run.callbacks.stop_epoch_after import StopEpochAfter
from myrtlespeech.run.callbacks.mixed_precision import MixedPrecision
from myrtlespeech.post_process.utils import levenshtein
from myrtlespeech.post_process.ctc_greedy_decoder import CTCGreedyDecoder
from myrtlespeech.post_process.ctc_beam_decoder import CTCBeamDecoder
from myrtlespeech.builders.task_config import build
from myrtlespeech.run.train import fit
from myrtlespeech.protos import task_config_pb2
from myrtlespeech.run.stage import Stage

In [None]:
# Turn on cuDNN benchmark mode, which lets cuDNN auto-tune the convolution
# algorithms it uses. NOTE(review): this mainly helps when input shapes are
# stable across batches -- confirm that holds for the data loaders used here.
torch.backends.cudnn.benchmark = True

In [None]:
# Load the example task config (protobuf text format) and merge it into a
# fresh TaskConfig message.
config_path = "../src/myrtlespeech/configs/2-block-vgg_5-bidir-lstm_ctc.config"
with open(config_path) as f:
    config_text = f.read()
task_config = text_format.Merge(config_text, task_config_pb2.TaskConfig())

In [None]:
# Create all components for the config. build() returns four values, unpacked
# here as the model, the epoch count, and the train and eval loaders.
# NOTE(review): `epochs` is not used by the fit() cell further down, which
# passes a hard-coded epoch count instead.
seq_to_seq, epochs, train_loader, eval_loader = build(task_config)

In [None]:
from typing import List

class WordSegmentor:
    """Groups a sequence of symbols into words split on a separator symbol.

    Consecutive non-separator symbols are concatenated into a single word.
    Separator symbols themselves are dropped, and runs of separators (or
    leading/trailing separators) never produce empty words.

    Args:
        separator: the symbol that marks a word boundary.
    """

    def __init__(self, separator: str):
        self.separator = separator

    def __call__(self, sentence: List[str]) -> List[str]:
        """Returns the words found in ``sentence``.

        Args:
            sentence: symbols to group, in order.

        Returns:
            The words, each the concatenation of the symbols between two
            separators, in order of appearance.
        """
        words: List[str] = []
        pending: List[str] = []
        for symbol in sentence:
            if symbol != self.separator:
                pending.append(symbol)
                continue
            # Separator reached: close the word in progress, if there is one.
            if pending:
                words.append("".join(pending))
                pending = []
        # Flush a final word that was not followed by a separator.
        if pending:
            words.append("".join(pending))
        return words

In [None]:
# Two CTC decoders, both using index 0 as the blank symbol: a greedy decoder
# and a beam decoder with a beam width of 12.
# NOTE(review): only `ctc_greedy` is handed to a callback in the visible
# cells; `ctc_beam` appears to be unused here.
ctc_greedy = CTCGreedyDecoder(blank_index=0)
ctc_beam = CTCBeamDecoder(blank_index=0, beam_width=12)

class ReportCTCDecoder(Callback):
    """Reports transcripts and word error rate (WER) for a CTC decoder.

    Results are stored in ``kwargs["reports"]`` under a key equal to the class
    name of ``ctc_decoder``. The entry is a dict with two items:

    * ``"wer"``: ``-1.0`` until an epoch end computes a value, then the WER
      as a percentage.
    * ``"transcripts"``: a list of ``(actual, expected)`` word-list pairs.

    Batches and epochs are skipped while ``self.training`` is truthy.
    NOTE(review): ``training`` is not set anywhere in this class; presumably
    the ``Callback`` base class or the training loop maintains it -- confirm.

    Args:
        ctc_decoder: decodes output to sequence of indices based on CTC

        alphabet: converts sequences of indices to sequences of symbols (strs)

        word_segmentor: groups sequences of symbols into sequences of words
    """

    def __init__(self, ctc_decoder, alphabet, word_segmentor):
        self.ctc_decoder = ctc_decoder
        self.alphabet = alphabet
        self.word_segmentor = word_segmentor
        # Initialised here as well as in ``_reset`` so the accumulators exist
        # even if a batch/epoch hook fires before any begin hook.
        self.distances = []
        self.lengths = []

    def _report(self, kwargs):
        """Returns this decoder's entry in ``kwargs["reports"]``."""
        return kwargs["reports"][self.ctc_decoder.__class__.__name__]

    def _reset(self, **kwargs):
        """Replaces the report entry with a fresh one and clears the stats."""
        kwargs["reports"][self.ctc_decoder.__class__.__name__] = {
            "wer": -1.0,
            "transcripts": [],
        }
        self.distances = []
        self.lengths = []

    def on_train_begin(self, **kwargs):
        self._reset(**kwargs)

    def on_epoch_begin(self, **kwargs):
        self._reset(**kwargs)

    def _process(self, sentence: List[int]) -> List[str]:
        """Converts a sequence of indices to a sequence of words."""
        symbols = self.alphabet.get_symbols(sentence)
        return self.word_segmentor(symbols)

    def on_batch_end(self, **kwargs):
        """Decodes the batch and accumulates transcripts and edit distances.

        Reads ``kwargs["last_output"]`` (unpacked as the decoder's positional
        arguments) and ``kwargs["last_target"]``, whose items 0 and 1 are
        used as the targets and the target lengths respectively.
        """
        if self.training:
            return
        transcripts = self._report(kwargs)["transcripts"]

        targets, target_lens = kwargs["last_target"][0], kwargs["last_target"][1]

        acts = self.ctc_decoder(*kwargs["last_output"])
        for act, target, target_len in zip(acts, targets, target_lens):
            act = self._process(act)
            # Only the first ``target_len`` entries of the target are used.
            exp = self._process([int(e) for e in target[:target_len]])

            transcripts.append((act, exp))

            self.distances.append(levenshtein(act, exp))
            self.lengths.append(len(exp))

    def on_epoch_end(self, **kwargs):
        """Stores the epoch's WER (as a percentage) in the report entry.

        When no reference words were accumulated (no batches were processed
        or every expected transcript was empty) the WER is undefined, so the
        existing ``"wer"`` value is left untouched instead of dividing by
        zero.
        """
        if self.training:
            return
        total_length = sum(self.lengths)
        if total_length == 0:
            return
        wer = float(sum(self.distances)) / total_length * 100
        self._report(kwargs)["wer"] = wer

In [None]:
# Train the model. Callbacks are constructed in the same order in which they
# are passed to fit().
num_epochs = 10000  # used in place of the config-derived `epochs`

loss_report = ReportMeanBatchLoss()
greedy_report = ReportCTCDecoder(
    ctc_greedy,
    seq_to_seq.alphabet,
    WordSegmentor(" "),
)
mixed_precision = MixedPrecision(seq_to_seq, opt_level="O1")
csv_logger = CSVLogger(
    "/tmp/foo.csv",
    exclude=[
        "epochs",
        # "reports/CTCGreedyDecoder/transcripts",
    ],
)

fit(
    seq_to_seq,
    num_epochs,
    train_loader=train_loader,
    eval_loader=eval_loader,
    callbacks=[
        loss_report,
        greedy_report,
        mixed_precision,
        # StopEpochAfter(epoch_batches=1),
        csv_logger,
    ],
)