In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1" # set this before importing torch
import torch
import onnx


from myrtlespeech.protos import task_config_pb2
from google.protobuf import text_format
from myrtlespeech.builders.task_config import build
from myrtlespeech.run.run import Saver
from myrtlespeech.model.fully_connected import FullyConnected
from myrtlespeech.model.rnn import RNN

from myrtlespeech.run.train import fit
from myrtlespeech.run.load import load_seq_to_seq
from pathlib import Path

import torch
import onnx

import onnxruntime as ort
import numpy as np

In [None]:
# Directory the exported .onnx files are written to (read by `export_and_check`).
log_dir = '/home/julian/exp/onnx/tmp/'
# RNN-T experiment directory.
rnnt_log_dir = '/home/julian/exp/rnnt/wer_down/2L/2/'

## Helpers: export to ONNX and check ONNX Runtime reproduces the PyTorch outputs

In [None]:

def export_and_check(model, args, fname, input_names, output_names, example_outputs=None, 
                     dynamic_axes=None, verbose=False, opset_version=9, out_dir=None):
    """Export ``model`` to ONNX and check ONNX Runtime reproduces PyTorch's outputs.

    Args:
        model: The :py:class:`torch.nn.Module` to export. It is put in eval mode.
        args: Tuple of positional inputs passed to ``model``.
        fname: File name of the exported graph, created inside ``out_dir``.
        input_names: ONNX names for the graph inputs, in the same order as ``args``.
        output_names: ONNX names for the graph outputs.
        example_outputs: Forwarded to :py:func:`torch.onnx.export` only when it
            is not ``None`` (recent torch releases no longer accept it).
        dynamic_axes: Forwarded to :py:func:`torch.onnx.export`.
        verbose: If ``True``, print a human readable form of the exported graph.
        opset_version: ONNX opset used for the export.
        out_dir: Directory for the exported file. If ``None``, the global
            ``log_dir`` is used (looked up at call time).

    Raises:
        onnx.checker.ValidationError: If the exported model is malformed.
        AssertionError: If the ONNX Runtime outputs differ from PyTorch's.
    """
    fp = Path(log_dir if out_dir is None else out_dir) / fname
    # make sure the output directory exists before the exporter writes to it
    fp.parent.mkdir(parents=True, exist_ok=True)
    model.eval()

    # run model in torch to get expected outputs; no autograd graph is needed
    with torch.no_grad():
        exp_outputs = model(*args)

    export_kwargs = dict(
        export_params=True,
        verbose=False,  # the `verbose` argument controls printing the graph below
        dynamic_axes=dynamic_axes,
        input_names=input_names,
        output_names=output_names,
        opset_version=opset_version,
    )
    # `example_outputs` used to be accepted here but silently dropped; forward
    # it only when given so the call also works with torch versions lacking it.
    if example_outputs is not None:
        export_kwargs['example_outputs'] = example_outputs
    torch.onnx.export(model, args, str(fp), **export_kwargs)

    # Load the ONNX model
    model_onnx = onnx.load(str(fp))

    # Print a human readable representation of the graph
    if verbose:
        print("Printing graph...")
        print(onnx.helper.printable_graph(model_onnx.graph))

    # Check that the IR is well formed
    onnx.checker.check_model(model_onnx)

    # onnx runtime
    ort_session = ort.InferenceSession(str(fp))

    # ONNX Runtime consumes numpy arrays; detach/cpu so tensors that live on the
    # GPU or require grad can also be fed.
    np_args = [x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else x for x in args]
    feeds = dict(zip(input_names, np_args))

    outputs = ort_session.run(output_names, feeds)

    check_outputs_as_expected(outputs, exp_outputs)

    print('model correct!')

def check_outputs_as_expected(outputs, exp_outputs):
    """Recursively assert that ONNX Runtime ``outputs`` match PyTorch ``exp_outputs``.

    Args:
        outputs: The value returned by ``InferenceSession.run`` (a list of numpy
            arrays), or one element of it when recursing.
        exp_outputs: The PyTorch result: a tensor or a (possibly nested) tuple
            of tensors.

    Raises:
        AssertionError: If shapes, number of outputs or values differ.
        ValueError: If the two structures cannot be matched up.
    """
    if isinstance(exp_outputs, torch.Tensor):
        # A single-output graph still comes back from ONNX Runtime as a
        # one-element list; unwrap it so the shapes can be compared exactly.
        if (isinstance(outputs, (tuple, list)) and len(outputs) == 1
                and isinstance(outputs[0], (np.ndarray, torch.Tensor))):
            outputs = outputs[0]
        actual = torch.as_tensor(np.asarray(outputs))
        expected = exp_outputs.detach().cpu()
        # allclose broadcasts, so a wrongly shaped output could otherwise pass
        assert actual.shape == expected.shape, (
            f"shape {tuple(actual.shape)} != {tuple(expected.shape)}")
        assert torch.allclose(actual, expected)
    elif isinstance(exp_outputs, tuple) and isinstance(outputs, (tuple, list)):
        assert len(exp_outputs) == len(outputs), f"{len(exp_outputs)} != {len(outputs)}"
        for actual_item, expected_item in zip(outputs, exp_outputs):
            check_outputs_as_expected(actual_item, expected_item)
    else:
        raise ValueError(f'Unexpected output type(outputs)={type(outputs)} '
                         f'with type(exp_outputs)={type(exp_outputs)} ')
        

In [None]:
# create wrapper to unwrap tuple args
class CollapseTupleArgs(torch.nn.Module):
    """Let a module that takes one tuple argument be called with positional args.

    ``CollapseTupleArgs(m)(a, b)`` evaluates ``m((a, b))``, so each tensor can
    be passed (and named) as a separate input.
    """

    def __init__(self, submodel):
        super().__init__()
        self.submodel = submodel

    def forward(self, *inputs):
        packed = tuple(inputs)
        return self.submodel(packed)

class FlattenTupleArgs(torch.nn.Module):
    """Flatten Tuple Args before returning.

    Wraps ``submodel`` so that a (possibly nested) tuple output is returned as
    one flat tuple. For example, ``torch.nn.LSTM``'s ``(output, (h_n, c_n))``
    becomes ``(output, h_n, c_n)``, so each tensor can be given its own output
    name. A non-tuple output is returned unchanged.
    """

    def __init__(self, submodel):
        super().__init__()
        self.submodel = submodel

    def forward(self, *args):
        res = self.submodel(*args)
        if not isinstance(res, tuple):
            # nothing to flatten; previously this case returned an empty tuple
            return res
        return tuple(self._flatten(res))

    @staticmethod
    def _flatten(items):
        """Yield the leaves of an arbitrarily nested tuple in depth-first order."""
        for item in items:
            if isinstance(item, tuple):
                yield from FlattenTupleArgs._flatten(item)
            else:
                yield item

## Linear

In [None]:
# Simplest sanity check: a single Linear layer whose batch axis is dynamic.
export_and_check(
    model=torch.nn.Linear(2, 3),
    args=(torch.randn(5, 2),),
    fname='linear.onnx',
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={name: {0: 'batch'} for name in ('input', 'output')},
)

## RNN

In [None]:
# torch.nn.RNN: input is (seq_len, batch, input_size); two outputs are checked.
seq_len, batch, input_size, num_layers = 2, 1, 2, 1
export_and_check(
    model=torch.nn.RNN(input_size, 3, num_layers),
    args=(torch.randn(seq_len, batch, input_size),),
    fname='rnn.onnx',
    input_names=['input'],
    output_names=['output_1', 'output_2'],
    dynamic_axes={
        'input': {0: 'seq_len', 1: 'batch'},
        'output_1': {0: 'seq_len', 1: 'batch'},
        'output_2': {1: 'batch'},
    },
)

## LSTM

In [None]:
# LSTM: FlattenTupleArgs turns the nested LSTM result into three flat outputs
# so each one can be named below.
seq_len, batch, input_size, num_layers = 2, 1, 2, 1
lstm = FlattenTupleArgs(torch.nn.LSTM(input_size, 3, num_layers))
export_and_check(
    model=lstm,
    args=(torch.randn(seq_len, batch, input_size),),
    fname='lstm.onnx',
    input_names=['input'],
    output_names=['output_1', 'output_2', 'output_3'],
    dynamic_axes={
        'input': {0: 'seq_len', 1: 'batch'},
        'output_1': {0: 'seq_len', 1: 'batch'},
        'output_2': {1: 'batch'},
        'output_3': {1: 'batch'},
    },
)

## myrtlespeech submodules

In [None]:
# test myrtlespeech fully_connected
# args: ([batch, seq_len, in_features], (batch,))

seq_len = 2
batch = 1
in_features = 5
fc = CollapseTupleArgs(
    FullyConnected(
        in_features,
        out_features=3,
        num_hidden_layers=2,
        hidden_size=2,
        hidden_activation_fn=torch.nn.ReLU(),
    )
)

inp = torch.randn(batch, seq_len, in_features)
# torch.randint's `high` is exclusive: seq_len + 1 lets full-length sequences
# (length == seq_len) be drawn too.
in_lens = torch.randint(low=1, high=seq_len + 1, size=(batch,))
args = (inp, in_lens)

export_and_check(
    model=fc,
    args=args,
    fname='fc_myrtlespeech.onnx',
    input_names=['input', 'in_lens'],
    output_names=['output', 'out_lens'],
    dynamic_axes={
        'input': {0: 'batch', 1: 'seq_len'},
        'in_lens': {0: 'batch'},
        'output': {0: 'batch', 1: 'seq_len'},
        'out_lens': {0: 'batch'},
    },
)

In [None]:
from enum import IntEnum
from typing import Optional
from typing import Tuple
from typing import TypeVar
from typing import Union

import torch


class RNNType(IntEnum):
    """Recurrent cell types supported by :py:class:`RNN`."""

    LSTM = 0
    GRU = 1
    BASIC_RNN = 2


# Hidden state: one tensor (GRU / basic RNN) or an ``(h, c)`` pair (LSTM).
RNNState = TypeVar("RNNState", torch.Tensor, Tuple[torch.Tensor, torch.Tensor])

# Tensor of per-sequence lengths.
RNNLengths = TypeVar("RNNLengths", bound=torch.Tensor)

# First element of the argument to :py:meth:`RNN.forward`: ``inp`` or ``(inp, hid)``.
RNNInput = Union[torch.Tensor, Tuple[torch.Tensor, Optional[RNNState]]]


class RNN(torch.nn.Module):
    """A recurrent neural network.

    See :py:class:`torch.nn.LSTM`, :py:class:`torch.nn.GRU` and
    :py:class:`torch.nn.RNN` for more information as these are used internally
    (see Attributes).

    This wrapper ensures the sequence length information is correctly used by
    the RNN (i.e. using :py:func:`torch.nn.utils.rnn.pack_padded_sequence` and
    :py:func:`torch.nn.utils.rnn.pad_packed_sequence`).

    Args:
        rnn_type: The type of recurrent neural network cell to use. See
            :py:class:`RNNType` for a list of the supported types.

        input_size: The number of features in the input.

        hidden_size: The number of features in the hidden state.

        num_layers: The number of recurrent layers.

        bias: If :py:data:`False`, then the layer does not use the bias weights
            ``b_ih`` and ``b_hh``.

        dropout: If non-zero, introduces a dropout layer on the
            outputs of each recurrent layer except the last layer,
            with dropout probability equal to ``dropout``.

        bidirectional: If :py:data:`True`, the recurrent layers are
            bidirectional.

        forget_gate_bias: If ``rnn_type == RNNType.LSTM`` and ``bias = True``
            then the sum of forget gate bias after initialisation equals this
            value if it is not :py:data:`None`. If it is :py:data:`None` then
            the default initialisation is used.

            See `Jozefowicz et al., 2015
            <http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf>`_.

        batch_first: If :py:data:`True`, then the input and output tensors are
            provided as ``[batch, seq_len, in_features]``.

    Attributes:
        rnn: A :py:class:`torch.LSTM`, :py:class:`torch.GRU`, or
            :py:class:`torch.RNN` instance.

        batch_first: Whether inputs and outputs are batch-major.

        use_cuda: :py:data:`True` if CUDA was available at construction time;
            ``rnn`` is then moved to the GPU and :py:meth:`forward` moves the
            input and hidden state there too.

    Raises:
        ValueError: If ``rnn_type`` is not one of the :py:class:`RNNType` values.
    """

    def __init__(
        self,
        rnn_type: RNNType,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bias: bool = True,
        dropout: float = 0.0,
        bidirectional: bool = False,
        forget_gate_bias: Optional[float] = None,
        batch_first: bool = False,
    ):
        super().__init__()
        if rnn_type == RNNType.LSTM:
            rnn_cls = torch.nn.LSTM
        elif rnn_type == RNNType.GRU:
            rnn_cls = torch.nn.GRU
        elif rnn_type == RNNType.BASIC_RNN:
            rnn_cls = torch.nn.RNN
        else:
            raise ValueError(f"unknown rnn_type {rnn_type}")

        self.batch_first = batch_first

        self.rnn = rnn_cls(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bias=bias,
            batch_first=self.batch_first,
            dropout=dropout,
            bidirectional=bidirectional,
        )

        if rnn_type == RNNType.LSTM and bias and forget_gate_bias is not None:
            # The second ``hidden_size`` slice of each bias holds the forget
            # gate; putting the whole value in ``b_ih`` and zeroing ``b_hh``
            # makes their sum equal ``forget_gate_bias``.
            for l in range(num_layers):
                ih = getattr(self.rnn, f"bias_ih_l{l}")
                ih.data[hidden_size : 2 * hidden_size] = forget_gate_bias
                hh = getattr(self.rnn, f"bias_hh_l{l}")
                hh.data[hidden_size : 2 * hidden_size] = 0.0

        self.use_cuda = torch.cuda.is_available()
        if self.use_cuda:
            self.rnn = self.rnn.cuda()

    def forward(
        self, x: Tuple[RNNInput, RNNLengths]
    ) -> Tuple[RNNInput, RNNLengths]:
        r"""Returns the result of applying the rnn to ``x[0]``.

        If :py:func:`torch.cuda.is_available` was :py:data:`True` on
        initialisation, the input and hidden state are moved to the GPU with
        :py:meth:`torch.Tensor.cuda`; the lengths ``x[1]`` are left where they
        are.

        Args:
            x: A Tuple ``(x[0], x[1])``. ``x[0]`` can take two forms: either it
                is a tuple ``x[0] = (inp, hid)`` or it is a torch tensor
                ``x[0] = inp``. ``inp`` is the network input (a
                :py:class:`torch.Tensor`) with size ``[seq_len, batch,
                in_features]`` (``[batch, seq_len, in_features]`` when
                ``batch_first``). ``hid`` is the RNN hidden state which is
                either a length 2 Tuple of :py:class:`torch.Tensor`s or
                a single :py:class:`torch.Tensor` depending on the ``RNNType``
                (see :py:class:`torch.nn` documentation for more information).

                The return type `res[0]` will be the same as the `x[0]` type so
                you should pass `hid = None` if you would like the hidden state
                returned and it is the start-of-sequence. In this case, the
                hidden state(s) will be initialised to zero in PyTorch.

                ``x[1]`` is a :py:class:`torch.Tensor` where each entry
                represents the sequence length of the corresponding network
                *input* sequence.

        Returns:
            A Tuple ``(res[0], res[1])``. ``res[0]`` will take the same form as
            ``x[0]``: either a tuple ``res[0] = (out, hid)`` or a
            :py:class:`torch.Tensor` ``res[0] = out``. ``out`` is the
            result after applying the RNN to ``inp``. It will have size
            ``[seq_len, batch, out_features]`` (batch-major when
            ``batch_first``). ``hid`` is the returned RNN hidden state which is
            either a length 2 Tuple of :py:class:`torch.Tensor`s or a single
            :py:class:`torch.Tensor` depending on the ``RNNType`` (see
            :py:class:`torch.nn` documentation for more information).

            ``res[1]`` is a :py:class:`torch.Tensor` where each entry
            represents the sequence length of the corresponding network
            *output* sequence. This will be equal to ``x[1]`` as this layer
            does not change sequence length.

        Raises:
            ValueError: If ``x[0]`` is neither a tensor nor a length 2 tuple,
                or if ``hid`` is neither ``None``, a tensor, nor a length 2
                tuple.
        """

        if isinstance(x[0], torch.Tensor):
            inp = x[0]
            hid = None
            return_tuple = False
        elif isinstance(x[0], tuple) and len(x[0]) == 2:
            inp, hid = x[0]
            return_tuple = True
        else:
            raise ValueError(
                "`x[0]` must be of form (input, hidden) or (input)."
            )

        # Validate the hidden state on every device, not only on the CUDA path.
        if hid is not None and not (
            (isinstance(hid, tuple) and len(hid) == 2)
            or isinstance(hid, torch.Tensor)
        ):
            raise ValueError("hid must be a length 2 tuple or a torch.Tensor.")

        if self.use_cuda:
            inp = inp.cuda()
            if isinstance(hid, tuple):  # LSTM
                hid = hid[0].cuda(), hid[1].cuda()
            elif hid is not None:  # Vanilla RNN/GRU
                hid = hid.cuda()

        # Record sequence length to enable DataParallel
        # https://pytorch.org/docs/stable/notes/faq.html#pack-rnn-unpack-with-data-parallelism
        total_length = inp.size(0 if not self.batch_first else 1)
        inp = torch.nn.utils.rnn.pack_padded_sequence(
            input=inp,
            lengths=x[1],
            batch_first=self.batch_first,
            enforce_sorted=False,
        )

        out, hid = self.rnn(inp, hx=hid)

        out, lengths = torch.nn.utils.rnn.pad_packed_sequence(
            sequence=out,
            batch_first=self.batch_first,
            total_length=total_length,
        )

        if return_tuple:
            return (out, hid), lengths
        return out, lengths


In [None]:
# test myrtlespeech RNN
# args: ([seq_len, batch, in_features], (batch,))

seq_len = 2
batch = 2
input_size = 2
num_layers = 1
hidden_size = 3
rnn = CollapseTupleArgs(
    RNN(
        rnn_type=RNNType.BASIC_RNN,
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
    )
)

inp = torch.randn(seq_len, batch, input_size)
# torch.randint's `high` is exclusive: seq_len + 1 lets full-length sequences
# (length == seq_len) be drawn too.
in_lens = torch.randint(low=1, high=seq_len + 1, size=(batch,))
args = (inp, in_lens)

export_and_check(
    model=rnn,
    args=args,
    # own file name so this export no longer overwrites the FullyConnected one
    fname='rnn_myrtlespeech.onnx',
    input_names=['input', 'in_lens'],
    output_names=['output', 'out_lens'],
    dynamic_axes={
        'input': {1: 'batch', 0: 'seq_len'},
        'in_lens': {0: 'batch'},
        'output': {1: 'batch', 0: 'seq_len'},
        'out_lens': {0: 'batch'},
    },
    opset_version=11,
)

## DeepSpeech1 (ds1)

In [None]:
from myrtlespeech.builders.speech_to_text import build as build_stt

In [None]:
# Parse the DeepSpeech1 example config and build its speech-to-text model.
task_config = text_format.Merge(
    Path("../src/myrtlespeech/configs/deep_speech_1_en.config").read_text(),
    task_config_pb2.TaskConfig(),
)
stt = build_stt(task_config.speech_to_text)

In [None]:
# Let the DeepSpeech1 model be called with positional (input, in_lens) args.
ds1 = CollapseTupleArgs(submodel=stt.model)
ds1

In [None]:
# inputs: Tuple: [(batch, channels, in_features, seq_len), (batch,)]
# outputs: Tuple: [(seq_len, batch, out_feat), (batch,)]
seq_len = 2
batch = 1
# NOTE(review): must match the input feature count the ds1 config expects — confirm
in_features = 26 * 19
channels = 1


# generate args

inp = torch.randn(batch, channels, in_features, seq_len)
# torch.randint's `high` is exclusive: seq_len + 1 lets full-length sequences
# (length == seq_len) be drawn too.
in_lens = torch.randint(low=1, high=seq_len + 1, size=(batch,))
args = (inp, in_lens)
export_and_check(
    model=ds1,
    args=args,
    fname='ds1_1024.onnx',
    input_names=['input', 'in_lens'],
    output_names=['output', 'out_lens'],
    # 'output' is (seq_len, batch, out_feat); it is listed once (a duplicate
    # key would silently override the earlier entry).
    dynamic_axes={
        'input': {0: 'batch', 3: 'seq_len'},
        'in_lens': {0: 'batch'},
        'output': {0: 'seq_len', 1: 'batch'},
        'out_lens': {0: 'batch'},
    },
)

## RNN-T (rnnt)

In [None]:
# Parse the 2-layer RNN-T example config and display it.
task_config = text_format.Merge(
    Path("../src/myrtlespeech/configs/rnn_t_en_2L.config").read_text(),
    task_config_pb2.TaskConfig(),
)
task_config

In [None]:
# NOTE(review): this rebinds the global `log_dir`, which `export_and_check` reads
# to decide where to write .onnx files. Re-running an export cell after this one
# writes into the RNN-T experiment directory instead. The value equals
# `rnnt_log_dir`.
log_dir = '/home/julian/exp/rnnt/wer_down/2L/2/'


In [None]:
# Build the sequence-to-sequence model plus training/eval loaders from the config.
# FYI: if using train-clean-100 & dev-clean this cell takes O(60s)
(seq_to_seq, epochs, train_loader, eval_loader) = build(
    task_config,
    accumulation_steps=2,
)
seq_to_seq

In [None]:
# Optionally restore a trained checkpoint, then switch the model to eval mode.
load_model = True
epoch = 68
training_state = {}
if load_model:
    # Read from `rnnt_log_dir` rather than the global `log_dir`, which doubles
    # as the ONNX export directory and may point elsewhere.
    fp = str(Path(rnnt_log_dir) / f'state_dict_{epoch}.pt')
    training_state = load_seq_to_seq(seq_to_seq, fp)

seq_to_seq.model.eval()