In [None]:
#| default_exp lstm.shim

# singleline_models.lstm.shim

Helpers for creating LSTM layers, optionally with layer normalization and recurrent dropout, backed by either PyTorch's built-in `nn.LSTM` or `rnnlib`.

In [None]:
#| export
import warnings

from singleline_models.lstm.rnnlib import *
from singleline_models.lstm.custom_lstm import *

In [None]:
#| export
def lstm_layer(ni: int,
               nh: int,
               bidirectional: bool = False,
               use_recurrent_dropout: bool = False,
               r_dropout_prob: float = 0.0,
               use_layer_norm: bool = False,
               layer_norm_learnable: bool = False,
               lstm_impl: str = "builtin"):
    """
    Create a single-layer LSTM using the selected underlying implementation.

    - 'builtin' uses PyTorch `nn.LSTM`. Only `ni`, `nh` and `bidirectional`
      are forwarded, so requesting recurrent dropout or layer norm fails an
      assertion.
    - 'rnnlib' uses `LayerNormLSTM` from `rnnlib` when `use_layer_norm` is set.
      Otherwise it uses an `LSTMFrame` of plain `LSTMCell`s: a forward/backward
      pair when `bidirectional`, a single cell otherwise.

    Args:
        ni: input feature size.
        nh: hidden state size.
        bidirectional: build a bidirectional layer.
        use_recurrent_dropout: enable recurrent dropout at `r_dropout_prob`.
            Only the 'rnnlib' + `use_layer_norm` path applies it. On the plain
            'rnnlib' path it is ignored and a warning is emitted. 'builtin'
            rejects it.
        r_dropout_prob: recurrent dropout probability. Only read when
            `use_recurrent_dropout` is True.
        use_layer_norm: use layer normalization (only supported by 'rnnlib').
        layer_norm_learnable: must be False for 'builtin'. It is not
            forwarded by the 'rnnlib' path.
        lstm_impl: either 'builtin' or 'rnnlib'.

    Returns:
        The constructed LSTM module.

    Raises:
        AssertionError: if 'builtin' is combined with an option it cannot honour.
        NotImplementedError: if `lstm_impl` is not one of the supported values.
    """
    if lstm_impl == 'builtin':
        # nn.LSTM below receives only (ni, nh, bidirectional), so extra features are rejected.
        assert not use_recurrent_dropout, \
            "lstm_impl='builtin' does not support recurrent dropout; use lstm_impl='rnnlib'"
        assert not use_layer_norm, \
            "lstm_impl='builtin' does not support layer norm; use lstm_impl='rnnlib'"
        assert not layer_norm_learnable, \
            "lstm_impl='builtin' does not support layer_norm_learnable"
        return nn.LSTM(ni, nh, bidirectional=bidirectional)
    elif lstm_impl == 'rnnlib':
        r_dropout = r_dropout_prob if use_recurrent_dropout else 0
        if use_layer_norm:
            # NOTE(review): layer_norm_learnable is not passed to LayerNormLSTM;
            # confirm whether its normalization parameters are always learnable.
            return LayerNormLSTM(ni, nh, num_layers=1, bidirectional=bidirectional,
                                 layer_norm_enabled=True, r_dropout=r_dropout)
        if r_dropout:
            # The plain LSTMCell path below has no place to pass r_dropout,
            # so surface the silently-dropped setting to the caller.
            warnings.warn(
                f"recurrent dropout (r_dropout_prob={r_dropout_prob}) is only applied "
                "with lstm_impl='rnnlib' and use_layer_norm=True; it is ignored here",
                stacklevel=2)
        rnn_cells = [[LSTMCell(ni, nh), LSTMCell(ni, nh)]] if bidirectional else [LSTMCell(ni, nh)]
        return LSTMFrame(rnn_cells, bidirectional=bidirectional, batch_first=False)
    raise NotImplementedError(
        f"unknown lstm_impl {lstm_impl!r}; expected 'builtin' or 'rnnlib'")


In [None]:
#| hide
# Regenerate the library modules from the notebooks' `#| export` cells.
import nbdev
nbdev.nbdev_export()