# API Development

This notebook contains code to run a model using the current API. It exists as a playground for developing the API.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from google.protobuf import text_format

from myrtlespeech.run.callback import Callback, ModelCallback
from myrtlespeech.post_process.ctc_greedy_decoder import CTCGreedyDecoder
from myrtlespeech.post_process.ctc_beam_decoder import CTCBeamDecoder
from myrtlespeech.builders.task_config import build
from myrtlespeech.run.train import fit
from myrtlespeech.protos import task_config_pb2
from myrtlespeech.run.stage import Stage

In [None]:
# Load the example task definition: the file is a text-format protobuf that is
# merged into a freshly constructed, empty TaskConfig message.
with open("../src/myrtlespeech/configs/2-block-vgg_5-bidir-lstm_ctc.config") as f:
    task_config = text_format.Merge(
        text=f.read(),
        message=task_config_pb2.TaskConfig(),
    )

In [None]:
# Build every component described by the parsed config in one call; `build`
# returns them as a 5-tuple that is unpacked into notebook-level names.
model, epochs, optim, train_loader, eval_loader = build(task_config, seq_len_support=True)

In [None]:
class GreedyCTCWER:
    """Metric callable that prints transcriptions and records a batch WER.

    Each call prints every hypothesis/reference pair in the batch and appends
    one word error rate to ``kwargs["metrics"][<class name>]["wer"]``. The
    rate is the word-level edit distance summed over the batch divided by the
    total number of reference words.

    Args:
        alphabet: Optional object exposing ``get_symbols(indices)``. When
            omitted, the notebook-level ``model.alphabet`` is looked up at
            call time (the behaviour of the original implementation).
    """

    def __init__(self, alphabet=None):
        # Kept so existing users of the attribute keep working; ``__call__``
        # does not use it.
        # NOTE(review): presumably ``last_output`` arrives already decoded
        # into symbol indices -- confirm against the training loop.
        self.decoder = CTCGreedyDecoder(blank_index=0)
        self.alphabet = alphabet

    @staticmethod
    def _word_edit_distance(hyp, ref):
        """Return the Levenshtein distance between two sequences of words."""
        prev = list(range(len(ref) + 1))
        for i, hyp_word in enumerate(hyp, start=1):
            curr = [i]
            for j, ref_word in enumerate(ref, start=1):
                curr.append(
                    min(
                        prev[j] + 1,  # deletion
                        curr[j - 1] + 1,  # insertion
                        prev[j - 1] + (hyp_word != ref_word),  # substitution
                    )
                )
            prev = curr
        return prev[-1]

    def __call__(self, **kwargs):
        """Print the batch transcriptions and record the batch WER.

        Reads ``training``, ``last_target`` (indexable as
        ``(targets, target_lens)``), ``last_output`` and ``metrics`` from
        ``kwargs``. Nothing is appended for an empty batch.
        """
        print(f"training={kwargs['training']}")
        targets = kwargs["last_target"][0]
        target_lens = kwargs["last_target"][1]
        xs = kwargs["last_output"]

        alphabet = self.alphabet if self.alphabet is not None else model.alphabet

        errors = 0
        ref_words = 0
        n_samples = 0
        for x, target, target_len in zip(xs, targets, target_lens):
            act = "".join(alphabet.get_symbols(x))
            # Targets are padded; only the first ``target_len`` entries count.
            exp = "".join(
                alphabet.get_symbols([int(e) for e in target[:target_len]])
            )
            print(f"{act} ||| {exp}")

            # Assumes the alphabet separates words with whitespace -- TODO
            # confirm for non character-level alphabets.
            hyp_tokens, ref_tokens = act.split(), exp.split()
            errors += self._word_edit_distance(hyp_tokens, ref_tokens)
            ref_words += len(ref_tokens)
            n_samples += 1

        metrics = kwargs["metrics"][self.__class__.__name__]
        wers = metrics.setdefault("wer", [])
        if n_samples:
            # max(..., 1) avoids dividing by zero when every reference is
            # empty; the raw error count is reported in that case.
            wers.append(errors / max(ref_words, 1))

In [None]:
class Logger(Callback):
    """Playground callback: prints the training loss and cuts each epoch short."""

    def on_backward_begin(self, **kwargs):
        """Print ``last_loss`` when it was supplied and the callback is training."""
        if "last_loss" not in kwargs or not self.training:
            return
        print(f"last_loss: {kwargs['last_loss']}")

    def on_batch_end(self, **kwargs):
        """Return ``{'stop_epoch': True}`` after every batch.

        NOTE(review): presumably the training loop reads this flag and ends
        the epoch after its first batch -- confirm against ``fit``.
        """
        return dict(stop_epoch=True)

    def on_epoch_end(self, **kwargs):
        """Do nothing; printing ``kwargs["metrics"]`` here is currently disabled."""

In [None]:
# Run training and evaluation with the components built above. `Logger`
# returns `stop_epoch` from `on_batch_end`, and `GreedyCTCWER` is passed as
# the only metric.
fit(
    model,
    epochs,
    optim,
    train_loader=train_loader,
    eval_loader=eval_loader,
    callbacks=[Logger()],
    metrics=[GreedyCTCWER()],
)