# API Development

This notebook contains code to run a model using the current API. It exists as a playground for developing the API.

In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell executes, so edits to the library source are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from google.protobuf import text_format

from myrtlespeech.run.callback import Callback, ModelCallback
from myrtlespeech.post_process.ctc_greedy_decoder import CTCGreedyDecoder
from myrtlespeech.post_process.ctc_beam_decoder import CTCBeamDecoder
from myrtlespeech.builders.task_config import build
from myrtlespeech.run.train import fit
from myrtlespeech.protos import task_config_pb2

In [3]:
# Read the example task configuration from disk and parse its protobuf
# text-format contents into a fresh TaskConfig message.
config_path = "../src/myrtlespeech/configs/2-block-vgg_5-bidir-lstm_ctc.config"
with open(config_path) as config_file:
    config_text = config_file.read()
task_config = text_format.Merge(config_text, task_config_pb2.TaskConfig())

In [4]:
# Build everything the task config describes. `build` yields five components,
# unpacked here into the names used by the rest of the notebook.
built_components = build(task_config, seq_len_support=True)
model, epochs, optim, train_loader, eval_loader = built_components

In [5]:
class Decoder(Callback):
    """Print beam-search and greedy CTC decodings of the model output.

    Decoding only happens outside of training, and only on iterations that
    are a multiple of ``print_every``. For each such iteration the beam
    transcripts are printed first, followed by the greedy transcripts.

    Args:
        alphabet: Object whose ``get_symbols(indices)`` converts decoded
            symbol indices into symbols. When ``None`` (the default) the
            alphabet of the notebook-level ``model`` is looked up each time
            transcripts are printed, which matches the previous behaviour.
        blank_index: Blank index handed to both decoders.
        beam_width: Beam width handed to ``CTCBeamDecoder``.
        print_every: Print transcripts when ``iteration % print_every == 0``.
    """

    def __init__(
        self, alphabet=None, blank_index=0, beam_width=16, print_every=10
    ):
        super().__init__()
        self.alphabet = alphabet
        self.print_every = print_every
        self.beam = CTCBeamDecoder(
            blank_index=blank_index,
            beam_width=beam_width
        )
        self.greedy = CTCGreedyDecoder(blank_index=blank_index)

    def _print_transcripts(self, decoded):
        """Print one joined transcript per decoded sequence of indices."""
        # Prefer the injected alphabet; fall back to the notebook global
        # `model` only when none was supplied.
        alphabet = self.alphabet
        if alphabet is None:
            alphabet = model.alphabet
        for symbol_indices in decoded:
            print("".join(alphabet.get_symbols(symbol_indices)))

    def on_loss_begin(self, **kwargs):
        """Decode and print ``kwargs['last_output']`` when evaluating.

        ``last_output`` must be a mapping with ``'inputs'`` and
        ``'input_lengths'`` entries. Returns ``None``, so no callback state
        is modified.
        """
        if self.training or kwargs['iteration'] % self.print_every != 0:
            return

        out = kwargs['last_output']

        # The beam decoder is given softmax-normalised scores over the last
        # dimension; the greedy decoder receives the output as-is.
        # NOTE(review): presumably greedy decoding only needs the argmax,
        # which softmax does not change - confirm against CTCGreedyDecoder.
        probabilities = torch.nn.functional.softmax(out['inputs'], dim=-1)
        self._print_transcripts(self.beam(probabilities, out['input_lengths']))
        self._print_transcripts(
            self.greedy(out['inputs'], out['input_lengths'])
        )


# Pass the alphabet explicitly so the callback does not depend on a global.
decoder_callback = Decoder(alphabet=model.alphabet)

In [6]:
class TestCallback(ModelCallback):
    """Adapt model output for ``model.loss(**args)`` and log batch progress."""

    def on_loss_begin(self, **kwargs):
        """Repackage the indexable model output as a keyword dictionary.

        Element 0 of ``kwargs['last_output']`` becomes ``'inputs'`` and
        element 1 becomes ``'input_lengths'``.
        """
        raw_output = kwargs['last_output']
        return {
            'last_output': {
                'inputs': raw_output[0],
                'input_lengths': raw_output[1],
            }
        }

    def on_batch_begin(self, **kwargs):
        """Print the current mode and iteration, plus the last loss if any."""
        print(f"training: {self.training},  iteration: {kwargs['iteration']}")
        if "last_loss" not in kwargs:
            return
        print(f"last_loss: {kwargs['last_loss']}")

    def on_batch_end(self, **kwargs):
        """Return ``{'stop_epoch': True}`` whenever not in training mode.

        NOTE(review): presumably this makes ``fit`` end evaluation after a
        single batch - confirm against the training loop.
        """
        if self.training:
            return None
        return {'stop_epoch': True}


callback = TestCallback(model)

In [None]:
# Train the model, evaluating with the eval loader.
# `callback` is listed before `decoder_callback`: TestCallback turns the model
# output into the dict that Decoder indexes by 'inputs' / 'input_lengths'.
# NOTE(review): assumes fit applies callbacks in list order - confirm.
active_callbacks = [callback, decoder_callback]
fit(
    model, epochs, optim,
    train_loader=train_loader,
    eval_loader=eval_loader,
    callbacks=active_callbacks,
)