# API Development

This notebook contains code to run a model using the current API. It exists as a playground for developing the API.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import pathlib
import typing

import torch
from google.protobuf import text_format

from myrtlespeech.model.speech_to_text import SpeechToText
from myrtlespeech.run.callbacks.csv_logger import CSVLogger
from myrtlespeech.run.callbacks.callback import Callback, ModelCallback
from myrtlespeech.run.callbacks.report_mean_batch_loss import ReportMeanBatchLoss
from myrtlespeech.run.callbacks.stop_epoch_after import StopEpochAfter
from myrtlespeech.run.callbacks.mixed_precision import MixedPrecision
from myrtlespeech.post_process.utils import levenshtein
from myrtlespeech.post_process.ctc_greedy_decoder import CTCGreedyDecoder
from myrtlespeech.post_process.ctc_beam_decoder import CTCBeamDecoder
from myrtlespeech.builders.task_config import build
from myrtlespeech.run.train import fit
from myrtlespeech.protos import task_config_pb2
from myrtlespeech.run.stage import Stage

In [None]:
# Expose only GPU index 0 to CUDA for this notebook's process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
# Turn on cuDNN benchmark mode so convolution algorithms are auto-tuned.
torch.backends.cudnn.benchmark = True

In [None]:
# Load the example Deep Speech 2 task config (protobuf text format) by merging
# the file contents into a freshly created TaskConfig message.
task_config = task_config_pb2.TaskConfig()
with open("../src/myrtlespeech/configs/deep_speech_2_en.proto") as config_file:
    text_format.Merge(config_file.read(), task_config)

In [None]:
from typing import Optional
from typing import Tuple

from myrtlespeech.model.encoder_decoder.encoder.encoder import conv_to_rnn_size


class DeepSpeech2(torch.nn.Module):
    """Deep Speech 2 style stack: optional CNN, RNN, optional lookahead, FC.

    Every submodule is called with a ``(tensor, seq_lens)`` tuple and must
    return a tuple of the same form. Tensor layouts through :meth:`forward`:

        input:            [batch, channels, features, max_in_seq_len]

        cnn:           -> [batch, channels, out_features, max_out_seq_len]

          reshape:     -> [seq_len, batch, channels*out_features]

        rnn:           -> [seq_len, batch, out_features]

          reshape:     -> [batch, out_features, seq_len]

        lookahead:     -> [batch, features, seq_len]

          reshape:     -> [batch, seq_len, features]

        fully_connected:  [batch, seq_len, out_features]

          reshape:     -> [seq_len, batch, out_features]

    Args:
        cnn: Optional convolutional front end; skipped when ``None``.
        rnn: Recurrent stage.
        lookahead: Optional lookahead stage; skipped when ``None``.
        fully_connected: Final projection stage.

    When CUDA is available the whole module is moved to the GPU at
    construction time, and :meth:`forward` moves its inputs there as well.
    """

    def __init__(
        self,
        cnn: Optional[torch.nn.Module],
        rnn: torch.nn.Module,
        lookahead: Optional[torch.nn.Module],
        fully_connected: torch.nn.Module
    ):
        super().__init__()

        self.cnn = cnn
        self.rnn = rnn
        self.lookahead = lookahead
        self.fully_connected = fully_connected

        self.use_cuda = torch.cuda.is_available()

        if self.use_cuda:
            # All submodules are registered by the assignments above, so a
            # single call moves every one of them. This avoids relying on the
            # truthiness of a module (e.g. an empty container is falsy).
            self.cuda()

    def rnn_to_lookahead_size(self, x: torch.Tensor) -> torch.Tensor:
        """Rearrange ``[seq_len, batch, features]`` to ``[batch, features, seq_len]``.

        Raises:
            ValueError: If ``x`` does not have exactly three dimensions.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected a 3-dimensional tensor, got {x.dim()} dimensions"
            )
        return x.permute(1, 2, 0)

    def forward(
        self,
        x: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the network on a ``(inputs, seq_lens)`` tuple.

        Args:
            x: Tuple whose first element is the input tensor laid out as
                ``[batch, channels, features, max_in_seq_len]`` and whose
                second element is the matching sequence-lengths tensor.

        Returns:
            Tuple of the output tensor, laid out as
            ``[seq_len, batch, out_features]``, and the sequence lengths as
            returned by the last submodule.
        """
        h = x

        if self.use_cuda:
            h = (h[0].cuda(), h[1].cuda())

        if self.cnn is not None:
            h = self.cnn(h)

        # Collapse to the [seq_len, batch, features] layout the RNN expects
        h = (conv_to_rnn_size(h[0]), h[1])

        h = self.rnn(h)

        if self.lookahead is not None:
            h = (self.rnn_to_lookahead_size(h[0]), h[1])
            h = self.lookahead(h)
            # [batch, features, seq_len] -> [batch, seq_len, features]
            h = (h[0].transpose(1, 2), h[1])
        else:
            # [seq_len, batch, features] -> [batch, seq_len, features]
            h = (h[0].transpose(0, 1), h[1])

        h = self.fully_connected(h)

        # Back to sequence-first: [seq_len, batch, out_features]
        h = (h[0].transpose(0, 1), h[1])

        return h

In [None]:
def pad_same(in_dim, ks, stride, dilation=1):
    """Return the ``(before, after)`` padding for one axis of a "same" conv.

    The output length is taken to be ``ceil(in_dim / stride)``; the total
    padding is whatever a kernel of size ``ks`` (with ``dilation``) needs to
    produce that many outputs, clamped at zero. When the total is odd the
    extra element goes after the data.

    References:
          https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/common_shape_fns.h
          https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/common_shape_fns.cc#L21
    """
    assert stride > 0
    assert dilation >= 1

    # Extent covered by the kernel once dilation gaps are included
    dilated_extent = (ks - 1) * dilation + 1
    # Ceiling division: number of output positions along this axis
    n_out = (in_dim + stride - 1) // stride
    needed = (n_out - 1) * stride + dilated_extent - in_dim
    total = max(0, needed)

    before = total // 2
    after = total - before
    return before, after

In [None]:
import math
from typing import Tuple

from myrtlespeech.builders.fully_connected import build as build_fully_connected
from myrtlespeech.builders.rnn import build as build_rnn
from myrtlespeech.model.seq_len_wrapper import SeqLenWrapper
from myrtlespeech.model.utils import Lambda
from myrtlespeech.protos import conv_layer_pb2
from myrtlespeech.protos import deep_speech_2_pb2

def build_ds2(
    ds2_cfg: deep_speech_2_pb2.DeepSpeech2,
    input_features: int,
    output_features: int,
    input_channels: int = 1
) -> torch.nn.Module:
    """Assemble a ``DeepSpeech2`` model from its config message.

    The stages are built in order and sized from one another: the CNN's
    flattened output size feeds the RNN, and the RNN's output size (hidden
    size, doubled when bidirectional) feeds both the lookahead and the fully
    connected stage.

    Args:
        ds2_cfg: Config providing the ``conv_layer``, ``rnn``, ``lookahead``
            and ``fully_connected`` sub-configs.
        input_features: Number of input features per time step.
        output_features: Number of outputs of the fully connected stage.
        input_channels: Number of channels of the input tensor.

    Returns:
        The constructed ``DeepSpeech2`` module (a single module, not a tuple).
    """
    cnn, cnn_out_features = _build_cnn(
        ds2_cfg.conv_layer,
        input_features,
        input_channels
    )

    rnn = build_rnn(
        ds2_cfg.rnn,
        input_features=cnn_out_features
    )

    # A bidirectional RNN emits forward and backward states side by side
    rnn_out_features = rnn.rnn.hidden_size
    if rnn.rnn.bidirectional:
        rnn_out_features *= 2

    # NOTE(review): a lookahead stage is always built here, so the
    # ``lookahead is None`` path of DeepSpeech2 is never taken via this
    # builder -- confirm that is intended when the config leaves it unset.
    lookahead = _build_lookahead(
        ds2_cfg.lookahead,
        input_features=rnn_out_features
    )

    # The lookahead keeps the channel count, so FC is sized from the RNN
    fully_connected = build_fully_connected(
        ds2_cfg.fully_connected,
        input_features=rnn_out_features,
        output_features=output_features
    )

    return DeepSpeech2(cnn, rnn, lookahead, fully_connected)



def _build_lookahead(
    lookahead_layer_config,
    input_features: int
):
    """Build the lookahead stage for ``[batch, features, seq_len]`` tensors.

    The stage right-pads the last axis with ``kernel_size - 1`` zeros and
    then applies a bias-free ``Conv1d`` with one group per feature, so the
    channel count and the sequence length are both preserved. Sequence
    lengths are passed through unchanged.

    Args:
        lookahead_layer_config: Config whose first ``kernel_dim`` entry is
            used as the kernel size.
        input_features: Number of channels entering (and leaving) the stage.

    Returns:
        A ``SeqLenWrapper`` around the pad + convolution sequence.
    """
    kernel_dims = list(lookahead_layer_config.kernel_dim)
    kernel_size = int(kernel_dims[0])

    def _pad_right(x):
        # Pad only the end of the last axis
        return torch.nn.functional.pad(x, (0, kernel_size - 1))

    stack = torch.nn.Sequential(
        Lambda(_pad_right),
        torch.nn.Conv1d(
            in_channels=input_features,
            out_channels=input_features,
            kernel_size=kernel_size,
            groups=input_features,
            bias=False,
        ),
    )

    return SeqLenWrapper(stack, seq_lens_fn=lambda x: x)


def _build_cnn(
    conv_layer_configs,
    input_features: int,
    input_channels: int
):
    """Build the convolutional front end as padded ``Conv2d`` layers.

    Each configured layer becomes a padding ``Lambda`` (amounts computed by
    ``pad_same`` from the runtime input size) followed by a ``Conv2d``. A
    layer with a single ``kernel_dim`` entry is expressed as a ``Conv2d``
    whose kernel spans ``input_features`` on the feature axis with unit
    stride there. A layer with no ``stride`` entries uses unit stride on
    every axis.

    Args:
        conv_layer_configs: Iterable of conv layer configs providing
            ``kernel_dim``, ``stride``, ``output_channels`` and ``padding``.
        input_features: Size of the feature axis of the input.
        input_channels: Number of channels of the input.

    Returns:
        Tuple of a ``SeqLenWrapper`` around the layer sequence and the
        flattened per-time-step output size (``channels * features``). The
        wrapper's ``seq_lens_fn`` maps lengths to
        ``ceil(len / product of time-axis strides)``.

    Raises:
        ValueError: If ``kernel_dim`` and a non-empty ``stride`` differ in
            length, or if ``kernel_dim`` has neither one nor two entries.
        NotImplementedError: If a layer's ``padding`` is not ``VALID``.
    """
    def _make_pad(kernel_size, stride):
        # Bind this layer's kernel/stride now; sizes are read at call time
        def _pad(x):
            pad_len = pad_same(x.size(3), kernel_size[1], stride[1])
            pad_freq = pad_same(x.size(2), kernel_size[0], stride[0])
            # F.pad takes the last axis first: (time pads, feature pads)
            return torch.nn.functional.pad(x, pad_len + pad_freq)
        return _pad

    layers = []

    # Product of the time-axis strides of every layer
    size_scalar = 1

    output_features = input_features

    for conv_layer in conv_layer_configs:
        n_dims = len(conv_layer.kernel_dim)
        n_stride = len(conv_layer.stride)

        if n_stride != 0 and n_dims != n_stride:
            raise ValueError("must be same number of kernel_dim and stride entries")

        if n_dims not in (1, 2):
            raise ValueError("only Conv1d and Conv2d supported")

        # The check above allows an omitted stride; treat it as unit stride
        # instead of failing with IndexError further down.
        strides = list(conv_layer.stride) if n_stride != 0 else [1] * n_dims

        if n_dims == 1:
            kernel_size = [input_features, conv_layer.kernel_dim[0]]
            # A stride of 0 on the feature axis is invalid (pad_same asserts
            # stride > 0 and the feature-size update divides by it); use 1.
            stride = [1, strides[0]]
        else:
            kernel_size = list(conv_layer.kernel_dim)
            stride = strides

        # NOTE(review): only VALID is accepted here, yet the layer is padded
        # with pad_same amounts below -- confirm the intended enum value.
        if conv_layer.padding != conv_layer_pb2.ConvLayer.PADDING.VALID:
            raise NotImplementedError(f"padding {conv_layer.padding} not supported")

        layers.append(Lambda(_make_pad(kernel_size, stride)))

        layers.append(
            torch.nn.Conv2d(
                in_channels=input_channels,
                out_channels=conv_layer.output_channels,
                kernel_size=kernel_size,
                stride=stride,
                bias=True
            )
        )
        input_channels = conv_layer.output_channels

        size_scalar *= stride[1]

        output_features = math.ceil(output_features / float(stride[0]))

    def _scale_seq_lens(seq_lens):
        return torch.ceil(seq_lens.float() / size_scalar).long()

    return (
        SeqLenWrapper(torch.nn.Sequential(*layers), seq_lens_fn=_scale_seq_lens),
        int(input_channels * output_features),
    )

In [None]:
# Display the parsed task config.
task_config

In [None]:
# Build the model for 83 input features, one input channel, and one output
# per character of the alphabet string below.
ds2 = build_ds2(
    task_config.speech_to_text.deepspeech_2,
    83,
    len("_abcdefghijklmnopqrstuvwxyz '"),
    1,
)

In [None]:
# Display the module structure of the built model.
ds2

In [None]:
# Smoke test of the CNN stage alone: feed a random [5, 1, 83, 100] batch (moved
# to the GPU, so CUDA is required) and show element [1] of the returned tuple,
# i.e. the sequence lengths after the CNN.
ds2.cnn((torch.empty([5, 1, 83, 100]).normal_().cuda(), torch.tensor([15, 20, 50, 75, 99])))[1]

In [None]:
# Smoke test of the full model: feed a random [5, 1, 83, 100] batch (moved to
# the GPU, so CUDA is required) and show the size of the output tensor.
ds2(
    (torch.empty([5, 1, 83, 100]).normal_().cuda(), torch.tensor([10, 15, 25, 77, 99]))
)[0].size()

In [None]:
# Load the example Deep Speech 2 task config (protobuf text format) by merging
# the file contents into a freshly created TaskConfig message.
task_config = task_config_pb2.TaskConfig()
with open("../src/myrtlespeech/configs/deep_speech_2_en.proto") as config_file:
    text_format.Merge(config_file.read(), task_config)

In [None]:
# create all components for config: build() returns four values, unpacked here
# as the seq_to_seq task object, the epoch count, and the train/eval loaders
seq_to_seq, epochs, train_loader, eval_loader = build(task_config)

In [None]:
# Show the length of one pass over the training loader.
# NOTE(review): len(train_loader) may give the same number without creating an
# iterator first -- confirm against the loader type returned by build().
len(iter(train_loader))

In [None]:
# Swap the hand-built DeepSpeech2 in as the task's model. The unconditional
# .cuda() means this cell requires a CUDA device.
seq_to_seq.model = ds2.cuda()

In [None]:
# Replace the task's optimizer with Adam over the DeepSpeech2 parameters.
# The commented-out keyword arguments below are left disabled.
seq_to_seq.optim = torch.optim.Adam(
    params = ds2.parameters(),
    lr=0.001,
    #momentum=0.9,
    #weight_decay=0.0005
)

In [None]:
from typing import List

class WordSegmentor:
    """Group a sequence of symbols into words split on a separator symbol.

    Runs of symbols between separators are joined into one string each.
    Separators themselves are dropped, and consecutive, leading or trailing
    separators never produce empty words.

    Args:
        separator: The symbol that marks a word boundary.
    """

    def __init__(self, separator: str):
        self.separator = separator

    def __call__(self, sentence: List[str]) -> List[str]:
        words: List[str] = []
        pending: List[str] = []

        def flush() -> None:
            # Emit the word collected so far, if there is one
            if pending:
                words.append("".join(pending))
                pending.clear()

        for symbol in sentence:
            if symbol != self.separator:
                pending.append(symbol)
            else:
                flush()

        # The last word is not followed by a separator
        flush()
        return words

In [None]:
# CTC decoders, both treating index 0 as the blank symbol. The beam decoder
# uses a beam width of 12.
ctc_greedy = CTCGreedyDecoder(blank_index=0)
ctc_beam = CTCBeamDecoder(blank_index=0, beam_width=12)

class ReportCTCDecoder(Callback):
    """Report decoded transcripts and an error rate for non-training batches.

    Results are written to ``kwargs["reports"]`` under the decoder's class
    name as a dict with two entries: ``"transcripts"``, a list of
    ``(actual, expected)`` word lists, and ``"wer"``, the summed Levenshtein
    distance between actual and expected word lists divided by the summed
    expected length, times 100. ``"wer"`` is ``-1.0`` until it is computed.

    Args:
        ctc_decoder: decodes output to sequence of indices based on CTC

        alphabet: converts sequences of indices to sequences of symbols (strs)

        word_segmentor: groups sequences of symbols into sequences of words
    """
    def __init__(self, ctc_decoder, alphabet, word_segmentor):
        self.ctc_decoder = ctc_decoder
        self.alphabet = alphabet
        self.word_segmentor = word_segmentor
        # Created here as well as in _reset so batch/epoch hooks never hit a
        # missing attribute when no begin hook ran first.
        self.distances = []
        self.lengths = []

    @property
    def _report_key(self) -> str:
        """Key in ``kwargs["reports"]`` under which results are stored."""
        return self.ctc_decoder.__class__.__name__

    def _reset(self, **kwargs):
        """Clear accumulated statistics and reinitialise the report entry."""
        kwargs["reports"][self._report_key] = {
            "wer": -1.0,
            "transcripts": []
        }
        self.distances = []
        self.lengths = []

    def on_train_begin(self, **kwargs):
        self._reset(**kwargs)

    def on_epoch_begin(self, **kwargs):
        self._reset(**kwargs)

    def _process(self, sentence: List[int]) -> List[str]:
        """Convert a sequence of indices to a sequence of words."""
        symbols = self.alphabet.get_symbols(sentence)
        return self.word_segmentor(symbols)

    def on_batch_end(self, **kwargs):
        """Decode the batch output and accumulate distances and lengths."""
        if self.training:
            return
        transcripts = kwargs["reports"][self._report_key]["transcripts"]

        targets, target_lens = kwargs["last_target"][0], kwargs["last_target"][1]

        acts = self.ctc_decoder(*kwargs["last_output"])
        for act, target, target_len in zip(acts, targets, target_lens):
            act = self._process(act)
            # Targets are padded; keep only the first target_len entries
            exp = self._process([int(e) for e in target[:target_len]])

            transcripts.append((act, exp))

            self.distances.append(levenshtein(act, exp))
            self.lengths.append(len(exp))

    def on_epoch_end(self, **kwargs):
        """Store the error rate for the epoch, if anything was accumulated."""
        if self.training:
            return
        total_length = sum(self.lengths)
        if total_length == 0:
            # Nothing to normalise by: skip the update rather than divide by
            # zero, leaving "wer" at whatever value it already holds.
            return
        wer = float(sum(self.distances)) / total_length * 100
        kwargs["reports"][self._report_key]["wer"] = wer

In [None]:
class Foo(Callback):
    """Print the greedy decoder's transcripts at the end of every epoch.

    The notebook cell output is cleared first, then each
    ``(actual, expected)`` pair stored under
    ``kwargs["reports"]["CTCGreedyDecoder"]["transcripts"]`` is printed.
    """

    def on_epoch_end(self, **kwargs):
        from IPython.display import clear_output

        clear_output()
        greedy_report = kwargs["reports"]["CTCGreedyDecoder"]
        for actual, expected in greedy_report["transcripts"]:
            print(actual, expected)
        print("\n\n\n")

In [None]:
# standardize = seq_to_seq.pre_process_steps[2][0]
#
# for x in train_loader.dataset:
#     seq_to_seq.pre_process(x[0][0])
#
# standardize.training = False

In [None]:
import time

from torch.utils.tensorboard import SummaryWriter

class TensorBoardLogger(ModelCallback):
    """Log training loss, and optionally histograms, to TensorBoard.

    Nothing is logged outside training. Every entry uses
    ``kwargs["total_train_batches"]`` as its global step.

    Args:
        model: Model whose named parameters are logged; passed on to
            ``ModelCallback``.
        histograms: If true, also log a histogram of every parameter at the
            end of each batch and of every available gradient at the end of
            each step.
        log_dir: Directory for the event files. Defaults to a new
            ``/tmp/writer/<current time>`` directory.
    """
    def __init__(self, model, histograms=False, log_dir=None):
        super().__init__(model)
        if log_dir is None:
            # Timestamped so each logger instance gets its own run directory
            log_dir = f'/tmp/writer/{time.time()}'
        self.writer = SummaryWriter(
            log_dir=log_dir,
        )
        self.histograms = histograms

    def on_backward_begin(self, **kwargs):
        """Log the last loss under ``train/loss``."""
        if not self.training:
            return
        # Only reached while training, so the tag is always "train/loss"
        self.writer.add_scalar(
            "train/loss",
            kwargs["last_loss"].item(),
            global_step=kwargs["total_train_batches"]
        )

    def on_step_end(self, **kwargs):
        """Log a histogram of each parameter's gradient, when enabled."""
        if not self.training or not self.histograms:
            return
        for name, param in self.model.named_parameters():
            # Parameters that took no part in the backward pass have no grad
            if param.grad is None:
                continue
            self.writer.add_histogram(
                name.replace(".", "/") + "/grad",
                param.grad,
                global_step=kwargs["total_train_batches"]
            )

    def on_batch_end(self, **kwargs):
        """Log a histogram of each parameter's values, when enabled."""
        if not self.training or not self.histograms:
            return
        for name, param in self.model.named_parameters():
            self.writer.add_histogram(
                name.replace(".", "/"),
                param,
                global_step=kwargs["total_train_batches"]
            )

    def on_train_end(self, **kwargs):
        """Close the underlying writer."""
        self.writer.close()

In [None]:
# train the model
# The epoch count is hard-coded to 1000; the `epochs` value returned by build()
# is left in the trailing comment for reference.
fit(
    seq_to_seq, 
    1000,#epochs, 
    train_loader=train_loader, 
    eval_loader=eval_loader,
    callbacks=[
        ReportMeanBatchLoss(),
        # Decode with the greedy CTC decoder, splitting words on spaces
        ReportCTCDecoder(
            ctc_greedy, 
            seq_to_seq.alphabet,
            WordSegmentor(" "),
        ),
        # Scalars only; parameter/gradient histograms are disabled
        TensorBoardLogger(seq_to_seq.model, histograms=False),
        MixedPrecision(seq_to_seq, opt_level="O1"),
        #StopEpochAfter(epoch_batches=30),
        # Write results to a CSV file, leaving out the "epochs" entry
        CSVLogger("/tmp/foo.csv", 
            exclude=[
                "epochs", 
                #"reports/CTCGreedyDecoder/transcripts",
            ]
        )
    ],
)