# Transducer Example Usage


This notebook provides example usage of `myrtlespeech` for Transducer training.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0" # set this before importing torch

In [None]:
import os
import pathlib

import torch

from myrtlespeech.builders.task_config import build
from myrtlespeech.run.train import fit
from myrtlespeech.protos import task_config_pb2
from google.protobuf import text_format


from myrtlespeech.run.callbacks.callback import CallbackHandler
from myrtlespeech.run.run import TensorBoardLogger, Saver
from myrtlespeech.run.callbacks.csv_logger import CSVLogger
from myrtlespeech.run.callbacks.callback import Callback, ModelCallback
from myrtlespeech.run.callbacks.clip_grad_norm import ClipGradNorm
from myrtlespeech.run.callbacks.report_mean_batch_loss import ReportMeanBatchLoss
from myrtlespeech.run.callbacks.stop_epoch_after import StopEpochAfter
from myrtlespeech.run.callbacks.mixed_precision import MixedPrecision
from myrtlespeech.run.run import ClearMemory
from myrtlespeech.run.run import ReportTransducerDecoder


In [None]:
# Keep cuDNN benchmark mode switched off for this run.
torch.backends.cudnn.benchmark = False # since variable size inputs

In [None]:
# Directory handed to TensorBoardLogger and Saver and used as the prefix for
# the CSV log / saved-model file names further down. Set MYRTLESPEECH_LOG_DIR
# to redirect it without editing the notebook; the default is unchanged.
# os.path.join(..., "") guarantees a trailing separator because later cells
# build file names with plain string concatenation (log_dir + "name").
log_dir = os.path.join(
    os.environ.get("MYRTLESPEECH_LOG_DIR", "/home/julian/experiments/tmp/"), ""
)
# Passed to build() as accumulation_steps and used to scale StopEpochAfter.
ACC_STEPS = 4

Build the RNNT model defined in the config file:

In [None]:
# Read the example RNN-T config (protobuf text format) and merge it into an
# empty TaskConfig message.
with open("../src/myrtlespeech/configs/rnn_t_en_SMALL.config") as config_file:
    config_text = config_file.read()
    task_config = text_format.Merge(config_text, task_config_pb2.TaskConfig())

# Display the parsed message as the cell output.
task_config

In [None]:
from myrtlespeech.builders.speech_to_text import build as build_stt

In [None]:
# Instantiate everything described by the config: the model wrapper, the
# configured number of epochs and the train / eval data loaders.
# FYI: if using train-clean-100 & dev-clean this cell takes O(60s)
seq_to_seq, epochs, train_loader, eval_loader = build(
    task_config,
    accumulation_steps=ACC_STEPS,
)
seq_to_seq

In [None]:
# Compare the number of batches in the training loader against a reference value.
# NOTE(review): 81 is an unexplained constant - presumably the expected number
# of optimiser steps per epoch for the configured dataset; confirm.
print(len(train_loader))
print(81 * ACC_STEPS)

## Maybe load model?

In [None]:
# Set to True to initialise the model from a previously saved state dict.
load_model = False
if load_model:
    fp = "/home/user/model/fp/model.pt"
    # map_location="cpu" lets a checkpoint written on any device be read on
    # this host; load_state_dict then copies the values into the model's
    # existing parameters.
    # NOTE: torch.load unpickles the file - only load checkpoints you trust.
    state_dict = torch.load(fp, map_location="cpu")
    seq_to_seq.model.load_state_dict(state_dict)

## Callbacks
* Use callbacks to inject custom behaviour into the training loop.

In [None]:
#custom callback to monitor training and print results
class PrintCB(Callback):
    """Prints progress information while ``fit`` is running.

    * Training: every ``PRINT_EVERY`` minibatches the minibatch index and the
      last batch loss are printed.
    * Evaluation: every ``PRINT_EVERY`` minibatches (except minibatch 0) the
      WER, CER and first (prediction, reference) pair are printed, and at the
      end of every epoch the mean loss is printed together with the same
      WER/CER summary.

    The WER/CER report is looked up under the class name of the notebook-level
    ``seq_to_seq.post_process``. If a required report is missing, a short
    notice is printed instead.
    """

    # Number of minibatches between two progress prints.
    PRINT_EVERY = 100

    def __init__(self):
        super().__init__()

    @staticmethod
    def _decoder_reports(kwargs):
        """Returns the report dict stored under the post-process class name.

        Raises:
            KeyError: if ``reports`` or the decoder entry is missing.
        """
        return kwargs["reports"][seq_to_seq.post_process.__class__.__name__]

    @staticmethod
    def _first_transcript(reports):
        """Returns the first (pred, exp) pair as joined strings, or None.

        Raises:
            KeyError: if ``reports`` has no ``"transcripts"`` entry.
        """
        transcripts = reports["transcripts"]
        if len(transcripts) == 0:
            return None
        pred, exp = transcripts[0]  # take first element
        return "".join(pred), "".join(exp)

    def on_batch_end(self, **kwargs):
        """Prints the training loss or the running eval WER/CER."""
        step = kwargs["epoch_minibatches"]
        if step % self.PRINT_EVERY != 0:
            return

        if self.training:
            print(step, kwargs["last_loss"].item())
            return

        if step == 0:
            return

        print(f"{step} batches completed")
        try:
            reports = self._decoder_reports(kwargs)
            wer = reports["WER"]
            cer = reports["CER"]
            first = self._first_transcript(reports)
        except KeyError:
            print("no wer - using new decoder?")
            return

        # Only the decoder report is needed here. The mean-loss report is
        # deliberately not read, so its absence no longer hides the WER line.
        if first is not None:
            pred, exp = first
            print("batch end, pred: {}, exp: {}, wer: {:.4f}, cer: {:.4f}".format(pred, exp, wer, cer))

    def on_epoch_end(self, **kwargs):
        """Prints the epoch's mean loss and WER/CER after an eval epoch."""
        if self.training:
            return
        epoch = kwargs["epoch"]

        try:
            loss = kwargs["reports"]["ReportMeanBatchLoss"]
            reports = self._decoder_reports(kwargs)
            wer = reports["WER"]
            cer = reports["CER"]
            first = self._first_transcript(reports)
        except KeyError:
            print("no wer - using new decoder?")
            return

        out_str = "{}, loss: {:.8f}".format(epoch, loss)
        if first is not None:
            pred, exp = first
            out_str += ", wer: {:.4f}, cer: {:.4f}, pred: {}, exp: {},".format(wer, cer, pred, exp)
        print(out_str)

In [None]:
#mixed_precision_cb = MixedPrecision(seq_to_seq) # this can only be initialized once so place it in separate cell

In [None]:

# Decoder-report callback. It is constructed here but is not part of
# `callbacks` below (its entry in the list is commented out).
rnnt_decoder_cb  = ReportTransducerDecoder(seq_to_seq.post_process, seq_to_seq.alphabet, eval_every=20, 
                                         skip_first_epoch=True)

# Report keys handed to CSVLogger (used by the evaluation cell; the CSVLogger
# entry in `callbacks` below is commented out).
keys_to_log_in_csv = [ "epoch", 
                      f"reports/{seq_to_seq.post_process.__class__.__name__}/WER",
                      f"reports/{seq_to_seq.post_process.__class__.__name__}/CER",
                       "reports/ReportMeanBatchLoss"]

# Callbacks passed to fit() for training.
callbacks = [#rnnt_decoder_cb,
            ReportMeanBatchLoss(),
             
            # Note: the following three callbacks, if present, must appear in this order (see docstrings):
            TensorBoardLogger(log_dir, seq_to_seq.model, histograms=False),
            #mixed_precision_cb,
            ClipGradNorm(seq_to_seq, 200),
            
            # Stop each epoch early (useful for debugging).
            # Comment out the following line to perform full training.
            StopEpochAfter(epoch_minibatches=10 * ACC_STEPS),
            
            # logging
            #CSVLogger(log_dir + "log.csv", keys=keys_to_log_in_csv),
            
            # save model at the end of each epoch:
            Saver(log_dir, seq_to_seq.model),
            
            # The following callback explicitly deletes quantities in the callback handler state_dict.
            # This is useful during transducer training as there are substantial memory pressures.
            ClearMemory(),
            PrintCB()
            ] 


In [None]:
# Scratch cell: arithmetic mean of four values.
# NOTE(review): the origin of these numbers is not recorded in the notebook.
(2.24 + 2.09 + 1.98 + 1.49) / 4


## To test that it is working
Check that `grads` keeps increasing across all steps of an accumulation window:

```python
grads = zeros_like(p)
for _ in accumulation:
    loss.backward()
    grads += p.grad.cpu()
    print("post sum:", abs(grads).mean().item())
    if accumulating:
        optim.step()
        grads = zeros_like(p)
        optim.zero_grad()
```
    

In [None]:
# Train only (no evaluation loader) for a fixed 10 epochs with the callbacks
# assembled above.
fit(
    seq_to_seq,
    callbacks=callbacks,
    train_loader=train_loader,
    eval_loader=None,
    epochs=10,
)



In [None]:
# Print each parameter's name and shape next to the shape of its gradient.
# A parameter that has no gradient (v.grad is None) is reported as None
# instead of raising AttributeError on `.shape`.
for n, v in seq_to_seq.named_parameters():
    grad_shape = None if v.grad is None else v.grad.shape
    print(n, v.shape, grad_shape)

In [None]:
# Build a new seq_to_seq from the speech_to_text section of the config.
# NOTE(review): this rebinds `seq_to_seq`, so the cells below no longer use the
# object that was passed to fit() above - presumably weights are expected to be
# re-loaded from a checkpoint before evaluating; confirm.
seq_to_seq = build_stt(task_config.speech_to_text)
seq_to_seq

### Maybe eval?

In [None]:
# Scratch cell: build a tiny Linear layer and call .half() on it
# (independent of the model used elsewhere in this notebook).
lin = torch.nn.Linear(1, 2)
lin.half()

In [None]:
# Set to False to skip the evaluation pass.
run_eval = True

if run_eval:
    # Mean loss + decoder WER/CER reports, written out to a CSV file.
    eval_cbs = [
        ReportMeanBatchLoss(),
        ReportTransducerDecoder(seq_to_seq.post_process, seq_to_seq.alphabet),
        CSVLogger(log_dir + "log_eval.csv", keys=keys_to_log_in_csv),
    ]

    fit(seq_to_seq, eval_loader=eval_loader, callbacks=eval_cbs)
    

In [None]:
import math

from typing import List
from typing import Optional
from typing import Tuple

import torch
from torch.nn import Parameter


def rnn(rnn, input_size, hidden_size, num_layers, norm=None,
        forget_gate_bias=1.0, dropout=0.0, **kwargs):
    """Builds a multi-layer LSTM, optionally with normalisation.

    Args:
        rnn: RNN flavour. Only ``"lstm"`` is accepted.
        input_size: Forwarded to the selected implementation.
        hidden_size: Forwarded to the selected implementation.
        num_layers: Number of stacked layers.
        norm: ``None`` selects :class:`LstmDrop`, ``"batch_norm"`` selects
            :class:`BNRNNSum` and ``"layer_norm"`` selects a
            ``torch.jit.script``-compiled :func:`lnlstm`.
        forget_gate_bias: Forwarded to the selected implementation.
        dropout: Forwarded to the selected implementation.
        **kwargs: Extra keyword arguments forwarded unchanged.

    Returns:
        The constructed module.

    Raises:
        ValueError: If ``rnn`` or ``norm`` is not one of the accepted values.
    """
    if rnn != "lstm":
        raise ValueError(f"Unknown rnn={rnn}")
    if norm not in [None, "batch_norm", "layer_norm"]:
        raise ValueError(f"unknown norm={norm}")

    # Arguments every implementation takes under the same keyword.
    shared = dict(
        input_size=input_size,
        hidden_size=hidden_size,
        dropout=dropout,
        forget_gate_bias=forget_gate_bias,
        **kwargs
    )

    if norm == "batch_norm":
        # BNRNNSum names its depth argument `rnn_layers`.
        return BNRNNSum(rnn_layers=num_layers, batch_norm=True, **shared)

    if norm == "layer_norm":
        return torch.jit.script(lnlstm(num_layers=num_layers, **shared))

    # norm is None
    return LstmDrop(num_layers=num_layers, **shared)


class OverLastDim(torch.nn.Module):
    """Applies a 2D-only module across the last dimension of any tensor.

    The input of shape ``(s_1, ..., s_n-1, s_n)`` is viewed as
    ``(s_1 * ... * s_n-1, s_n)``, passed through ``module`` and viewed back as
    ``(s_1, ..., s_n-1, s_n')``, where ``s_n'`` is whatever last-dimension
    size the module produced.

    Args:
        module (torch.nn.Module): Module to apply. Must map a 2D tensor to a
            2D tensor; it may change the size of the last dimension.
    """

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        """Flattens the leading dimensions, applies the module, restores them."""
        *leading, _ = x.size()

        # Number of rows once every leading dimension is merged into one.
        n_rows = 1
        for extent in leading:
            n_rows *= extent

        flat_out = self.module(x.view(n_rows, -1))
        return flat_out.view(*leading, -1)


class LstmDrop(torch.nn.Module):
    """A ``torch.nn.LSTM`` with forget-gate bias init and output dropout.

    Args:
        input_size: See `torch.nn.LSTM`.
        hidden_size: See `torch.nn.LSTM`.
        num_layers: See `torch.nn.LSTM`.
        dropout: See `torch.nn.LSTM`. When truthy, a ``torch.nn.Dropout`` with
            the same probability is also applied to the LSTM output.
        forget_gate_bias: Value written into the forget-gate slice of every
            ``bias_ih`` parameter; the matching ``bias_hh`` slice is set to 0.
            ``None`` leaves the default initialisation untouched.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout, forget_gate_bias,
                 **kwargs):
        super().__init__()

        self.lstm = torch.nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
        )

        if forget_gate_bias is not None:
            # Second hidden_size-wide chunk of each bias vector.
            forget_slice = slice(hidden_size, 2 * hidden_size)
            for name, param in self.lstm.named_parameters():
                if "bias_ih" in name:
                    param.data[forget_slice].fill_(forget_gate_bias)
                if "bias_hh" in name:
                    param.data[forget_slice].fill_(0)

        self.dropout = torch.nn.Dropout(dropout) if dropout else None

    def forward(self, x, h=None):
        """Runs the LSTM, then the optional output dropout.

        Returns:
            A tuple ``(output, state)`` as returned by ``torch.nn.LSTM``, with
            dropout applied to ``output`` when configured.
        """
        out, state = self.lstm(x, h)
        if self.dropout is not None:
            out = self.dropout(out)
        return out, state



class RNNLayer(torch.nn.Module):
    """A single RNN layer with optional batch norm on its input.

    Args:
        input_size: Size of the last dimension of the input.
        hidden_size: Hidden size of the RNN.
        rnn_type: RNN class to instantiate (e.g. ``torch.nn.LSTM``). It is
            called with ``input_size``, ``hidden_size`` and ``bias`` keywords.
        batch_norm: If True, a ``BatchNorm1d`` is applied over the last
            dimension of the input and the RNN is created without biases.
        forget_gate_bias: When ``rnn_type`` is an LSTM class and
            ``batch_norm`` is False, the forget-gate slice of every
            ``bias_ih`` is set to this value and the matching ``bias_hh``
            slice to 0. ``None`` keeps the default initialisation.
    """
    def __init__(self, input_size, hidden_size, rnn_type=torch.nn.LSTM,
                 batch_norm=True, forget_gate_bias=1.0):
        super().__init__()

        if batch_norm:
            self.bn = OverLastDim(torch.nn.BatchNorm1d(input_size))

        # batch_norm will apply bias, no need to add a second to the RNN
        self.rnn = rnn_type(input_size=input_size,
                            hidden_size=hidden_size,
                            bias=not batch_norm)

        # `rnn_type` is a class, so it must be checked with issubclass. The
        # previous isinstance() test was always False, which meant the
        # forget-gate bias was never initialised.
        is_lstm = isinstance(rnn_type, type) and issubclass(
            rnn_type, torch.nn.LSTM
        )
        if is_lstm and not batch_norm and forget_gate_bias is not None:
            forget_slice = slice(hidden_size, 2 * hidden_size)
            for name, param in self.rnn.named_parameters():
                if "bias_ih" in name:
                    param.data[forget_slice].fill_(forget_gate_bias)
                elif "bias_hh" in name:
                    param.data[forget_slice].fill_(0)

    def forward(self, x, hx=None):
        """Applies the optional batch norm, then the RNN.

        Returns:
            The ``(output, state)`` tuple returned by the wrapped RNN.
        """
        if hasattr(self, 'bn'):
            # OverLastDim uses .view(), which needs contiguous memory.
            x = x.contiguous()
            x = self.bn(x)
        x, h = self.rnn(x, hx=hx)
        return x, h

    def _flatten_parameters(self):
        """Delegates to the wrapped RNN's ``flatten_parameters``."""
        self.rnn.flatten_parameters()


class BNRNNSum(torch.nn.Module):
    """Stack of :class:`RNNLayer` modules with optional batch norm and dropout.

    Builds ``rnn_layers`` RNNLayers; ``forget_gate_bias`` is forwarded to each
    of them. Batch norm is applied to the input of every layer except the
    first, unless ``norm_first_rnn`` is True, in which case the first layer is
    normalised too. If ``dropout > 0`` a ``torch.nn.Dropout`` is inserted after
    every layer except the last.

    Args:
        input_size: Input size of the first layer.
        hidden_size: Hidden size of every layer (and input size of every
            layer after the first).
        rnn_type: RNN class forwarded to each RNNLayer.
        rnn_layers: Number of stacked RNN layers.
        batch_norm: Enables batch norm (subject to ``norm_first_rnn``).
        dropout: Dropout probability between layers.
        forget_gate_bias: Forwarded to each RNNLayer.
        norm_first_rnn: Whether the first layer is batch-normalised as well.
        **kwargs: Accepted and ignored.
    """
    def __init__(self, input_size, hidden_size, rnn_type=torch.nn.LSTM,
                 rnn_layers=1, batch_norm=True, dropout=0.0,
                 forget_gate_bias=1.0, norm_first_rnn=False, **kwargs):
        super().__init__()
        self.rnn_layers = rnn_layers

        self.layers = torch.nn.ModuleList()
        for i in range(rnn_layers):
            final_layer = (rnn_layers - 1) == i

            self.layers.append(
                RNNLayer(
                    input_size,
                    hidden_size,
                    rnn_type=rnn_type,
                    batch_norm=batch_norm and (norm_first_rnn or i > 0),
                    forget_gate_bias=forget_gate_bias,
                )
            )

            if dropout > 0.0 and not final_layer:
                self.layers.append(torch.nn.Dropout(dropout))

            # every layer after the first consumes the previous layer's output
            input_size = hidden_size

    def forward(self, x, hx=None):
        """Runs the stack and collects the per-layer states.

        Each RNN layer's returned state is indexed as ``(h, c)``, i.e. an
        LSTM-style state tuple is expected.

        Returns:
            ``(x, (h_0, c_0))`` where ``h_0`` / ``c_0`` are the per-layer
            states stacked along a new leading dimension.
        """
        hx = self._parse_hidden_state(hx)

        hs = []
        cs = []
        rnn_idx = 0  # counts RNN layers only; Dropout entries are skipped
        for layer in self.layers:
            if isinstance(layer, torch.nn.Dropout):
                x = layer(x)
            else:
                x, h_out = layer(x, hx=hx[rnn_idx])
                hs.append(h_out[0])
                cs.append(h_out[1])
                rnn_idx += 1
                del h_out

        h_0 = torch.stack(hs, dim=0)
        c_0 = torch.stack(cs, dim=0)
        return x, (h_0, c_0)

    def _parse_hidden_state(self, hx):
        """
        Dealing w. hidden state:
        Typically in pytorch: (h_0, c_0)
            h_0 = ``[num_layers * num_directions, batch, hidden_size]``
            c_0 = ``[num_layers * num_directions, batch, hidden_size]``

        Returns a list with one entry per RNN layer: ``None`` for every layer
        when ``hx`` is None, otherwise the ``(h_0[i], c_0[i])`` pairs.
        """
        if hx is None:
            return [None] * self.rnn_layers
        else:
            h_0, c_0 = hx
            assert h_0.shape[0] == self.rnn_layers
            return [(h_0[i], c_0[i]) for i in range(h_0.shape[0])]

    def _flatten_parameters(self):
        """Flattens the parameters of every RNN layer in the stack."""
        for layer in self.layers:
            # self.layers only ever holds RNNLayer and Dropout modules, so the
            # RNNs are reached through RNNLayer (the previous check against
            # torch.nn.LSTM/GRU/RNN never matched anything).
            if isinstance(layer, RNNLayer):
                layer._flatten_parameters()


class StackTime(torch.nn.Module):
    """Stacks ``factor`` consecutive time steps along the feature dimension.

    Args:
        factor: Number of consecutive frames merged into one output frame;
            converted with ``int()``.
    """

    def __init__(self, factor):
        super().__init__()
        self.factor = int(factor)

    def forward(self, x):
        """Stacks frames and rescales the sequence lengths.

        Args:
            x: Tuple ``(frames, lengths)``. ``frames`` is indexed with three
                dimensions (commented as T, B, U in the original code).

        Returns:
            Tuple of the stacked frames (every ``factor``-th time step, last
            dimension ``factor`` times larger) and
            ``ceil(lengths / factor)`` as an int tensor.
        """
        frames, lengths = x

        def shifted(offset):
            # Copy of `frames` moved `offset` steps earlier in time, with the
            # vacated tail left as zeros.
            out = torch.zeros_like(frames)
            out[:-offset, :, :] = frames[offset:, :, :]
            return out

        pieces = [frames] + [shifted(k) for k in range(1, self.factor)]
        stacked = torch.cat(pieces, dim=2)
        lengths = torch.ceil(lengths.float() / self.factor).int()
        return stacked[::self.factor, :, :], lengths


def lnlstm(input_size, hidden_size, num_layers, dropout, forget_gate_bias,
           **kwargs):
    """Builds a stacked layer-norm LSTM with the call shape of a native LSTM.

    Args:
        input_size: Input size of the first layer.
        hidden_size: Hidden size of every layer.
        num_layers: Number of stacked layers.
        dropout: Must be 0.0 (dropout is not implemented here).
        forget_gate_bias: Forwarded to every ``LayerNormLSTMCell``.
        **kwargs: Accepted and ignored.

    Returns:
        A ``StackedLSTM`` made of ``LSTMLayer`` modules.
    """
    # The following are not implemented.
    assert dropout == 0.0

    def cell_args(in_features):
        # Positional arguments for one LSTMLayer: cell class + cell arguments.
        return [LayerNormLSTMCell, in_features, hidden_size, forget_gate_bias]

    return StackedLSTM(
        num_layers,
        LSTMLayer,
        first_layer_args=cell_args(input_size),
        other_layer_args=cell_args(hidden_size),
    )


class LSTMLayer(torch.nn.Module):
    """Runs a single recurrent cell over the first dimension of its input.

    Args:
        cell: Callable that builds the cell; invoked as ``cell(*cell_args)``.
        *cell_args: Positional arguments forwarded to ``cell``.
    """
    def __init__(self, cell, *cell_args):
        super(LSTMLayer, self).__init__()
        self.cell = cell(*cell_args)

    def forward(
        self,
        input: torch.Tensor,
        state: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Feeds each slice along dim 0 to the cell, threading the state.

        Returns:
            The per-step cell outputs stacked along a new leading dimension,
            and the state returned by the last cell call.
        """
        # NOTE(review): written in an explicit index-loop style, presumably to
        # stay compatible with torch.jit.script (see `rnn`); confirm before
        # restructuring.
        inputs = input.unbind(0)
        outputs = []
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
            outputs += [out]
        return torch.stack(outputs), state


class LayerNormLSTMCell(torch.nn.Module):
    """LSTM cell with layer norm on both gate projections and the cell state.

    Args:
        input_size: Number of input features.
        hidden_size: Number of hidden units.
        forget_gate_bias: Value written into the forget-gate slice of
            ``layernorm_h.bias``; the same slice of ``layernorm_i.bias`` is
            set to 0.
    """
    def __init__(self, input_size, hidden_size, forget_gate_bias):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))

        # layernorms provide learnable biases
        self.layernorm_i = torch.nn.LayerNorm(4 * hidden_size)
        self.layernorm_h = torch.nn.LayerNorm(4 * hidden_size)
        self.layernorm_c = torch.nn.LayerNorm(hidden_size)

        self.reset_parameters()

        # Applied after reset_parameters() so the forget-gate slices are not
        # overwritten by the uniform initialisation.
        self.layernorm_i.bias.data[hidden_size:2*hidden_size].fill_(0.0)
        self.layernorm_h.bias.data[hidden_size:2*hidden_size].fill_(
            forget_gate_bias
        )

    def reset_parameters(self):
        """Draws every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size)).

        NOTE(review): ``self.parameters()`` also yields the LayerNorm weights
        and biases, so their default initialisation is replaced as well -
        confirm this is intended.
        """
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            torch.nn.init.uniform_(weight, -stdv, stdv)

    def forward(
        self,
        input: torch.Tensor,
        state: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Computes one LSTM step.

        Args:
            input: 2D input for this step (it is multiplied with
                ``weight_ih.t()`` via ``torch.mm``).
            state: Tuple ``(hx, cx)`` of previous hidden and cell state.

        Returns:
            ``(hy, (hy, cy))``: the new hidden state as the step output, plus
            the new ``(hidden, cell)`` state.
        """
        hx, cx = state
        # Layer-normalised input and recurrent projections, summed per gate.
        igates = self.layernorm_i(torch.mm(input, self.weight_ih.t()))
        hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t()))
        gates = igates + hgates
        # Chunk order along dim 1: input, forget, cell, output.
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        # The new cell state is layer-normalised before it is used and returned.
        cy = self.layernorm_c((forgetgate * cx) + (ingate * cellgate))
        hy = outgate * torch.tanh(cy)

        return hy, (hy, cy)


def init_stacked_lstm(num_layers, layer, first_layer_args, other_layer_args):
    """Creates the layers of a stacked LSTM.

    Args:
        num_layers: Total number of layers requested.
        layer: Callable that builds one layer from positional arguments.
        first_layer_args: Positional arguments for the first layer.
        other_layer_args: Positional arguments for each remaining layer.

    Returns:
        A ``torch.nn.ModuleList`` holding the first layer followed by
        ``num_layers - 1`` further layers. The first layer is always created.
    """
    modules = [layer(*first_layer_args)]
    modules.extend(layer(*other_layer_args) for _ in range(num_layers - 1))
    return torch.nn.ModuleList(modules)


class StackedLSTM(torch.nn.Module):
    """Sequentially applies a list of recurrent layers.

    Args:
        num_layers: Number of layers to build.
        layer: Callable that builds one layer from positional arguments.
        first_layer_args: Positional arguments for the first layer.
        other_layer_args: Positional arguments for every other layer.
    """
    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
        super(StackedLSTM, self).__init__()
        # NOTE(review): `Final` is not imported in the visible part of this
        # file - confirm where it is expected to come from.
        self.layers: Final[torch.nn.ModuleList] = init_stacked_lstm(
            num_layers, layer, first_layer_args, other_layer_args
        )

    def forward(
        self,
        input: torch.Tensor,
        states: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]=None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """Runs the input through every layer in order.

        Args:
            input: Input tensor; ``input.size(1)`` is used as the batch size
                when default states are created.
            states: One ``(h, c)`` pair per layer. When None, zero states of
                shape ``(batch, layer.cell.hidden_size)`` are created with the
                dtype and device of ``input``.

        Returns:
            The output of the last layer and the list of per-layer states
            returned by the layers.
        """
        if states is None:
            states: List[Tuple[torch.Tensor, torch.Tensor]] = []
            batch = input.size(1)
            for layer in self.layers:
                states.append(
                    (torch.zeros(
                        batch,
                        layer.cell.hidden_size,
                        dtype=input.dtype,
                        device=input.device
                     ),
                     torch.zeros(
                         batch,
                         layer.cell.hidden_size,
                         dtype=input.dtype,
                         device=input.device
                     )
                    )
                )

        # NOTE(review): `Tensor` is not imported in the visible part of this
        # file (only `torch` and `Parameter` are) - confirm it resolves in the
        # context this module is compiled/run in.
        output_states: List[Tuple[Tensor, Tensor]] = []
        output = input
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        # (manual counter used instead of enumerate)
        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            output, out_state = rnn_layer(output, state)
            output_states += [out_state]
            i += 1
        return output, output_states



### Maybe save model

In [None]:
# Set to True to write the current model weights into the log directory.
save_model = False
fp_out = log_dir + "model_saved.pt"
if save_model:
    trained_weights = seq_to_seq.model.state_dict()
    torch.save(trained_weights, fp_out)