Skip to content

Commit

Permalink
MLCE-604 Add Unidirectional Sequence Lstm support to TFLite
Browse files Browse the repository at this point in the history
 * Added Unidirectional Sequence Lstm support to TFLite Parser
 * Added support for float operations with int8 weights to TFLite Parser
    * Added to Conv2d, Conv3D, DepthwiseConv2D, FullyConnected,
      TransposeConv and UnidirectionalSequenceLstm
 * Renamed subgraphIndex to subgraph to fix name-shadowing warning.

Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: I818976ab88abc05dcb4bad246fb4108e6e879283
  • Loading branch information
MikeJKelly committed Apr 22, 2022
1 parent 4dae579 commit 5880b91
Show file tree
Hide file tree
Showing 5 changed files with 580 additions and 57 deletions.
38 changes: 31 additions & 7 deletions include/armnn/Descriptors.hpp
Expand Up @@ -1086,17 +1086,29 @@ struct LstmDescriptor : BaseDescriptor
, m_ProjectionEnabled(false)
, m_LayerNormEnabled(false)
, m_TimeMajor(false)
, m_InputIntermediateScale(0.0)
, m_ForgetIntermediateScale(0.0)
, m_CellIntermediateScale(0.0)
, m_OutputIntermediateScale(0.0)
, m_HiddenStateZeroPoint(0)
, m_HiddenStateScale(0.0)
{}

bool operator ==(const LstmDescriptor& rhs) const
{
return m_ActivationFunc == rhs.m_ActivationFunc &&
m_ClippingThresCell == rhs.m_ClippingThresCell &&
m_ClippingThresProj == rhs.m_ClippingThresProj &&
m_CifgEnabled == rhs.m_CifgEnabled &&
m_PeepholeEnabled == rhs.m_PeepholeEnabled &&
m_LayerNormEnabled == rhs.m_LayerNormEnabled &&
m_TimeMajor == rhs.m_TimeMajor;
return m_ActivationFunc == rhs.m_ActivationFunc &&
m_ClippingThresCell == rhs.m_ClippingThresCell &&
m_ClippingThresProj == rhs.m_ClippingThresProj &&
m_CifgEnabled == rhs.m_CifgEnabled &&
m_PeepholeEnabled == rhs.m_PeepholeEnabled &&
m_LayerNormEnabled == rhs.m_LayerNormEnabled &&
m_TimeMajor == rhs.m_TimeMajor &&
m_InputIntermediateScale == rhs.m_InputIntermediateScale &&
m_ForgetIntermediateScale == rhs.m_ForgetIntermediateScale &&
m_CellIntermediateScale == rhs.m_CellIntermediateScale &&
m_OutputIntermediateScale == rhs.m_OutputIntermediateScale &&
m_HiddenStateZeroPoint == rhs.m_HiddenStateZeroPoint &&
m_HiddenStateScale == rhs.m_HiddenStateScale;
}

/// @brief The activation function to use.
Expand All @@ -1116,6 +1128,18 @@ struct LstmDescriptor : BaseDescriptor
bool m_LayerNormEnabled;
/// Enable/disable time major
bool m_TimeMajor;
/// Input intermediate quantization scale
float m_InputIntermediateScale;
/// Forget intermediate quantization scale
float m_ForgetIntermediateScale;
/// Cell intermediate quantization scale
float m_CellIntermediateScale;
/// Output intermediate quantization scale
float m_OutputIntermediateScale;
/// Hidden State zero point
int32_t m_HiddenStateZeroPoint;
/// Hidden State quantization scale
float m_HiddenStateScale;
};

using UnidirectionalSequenceLstmDescriptor = LstmDescriptor;
Expand Down

0 comments on commit 5880b91

Please sign in to comment.