Implement GELU as function op (onnx#5277)
These changes have been made to support the GELU operator as a function
op.

This adds support for the [GELU: Gaussian Error Linear
Unit](https://paperswithcode.com/method/gelu) activation function, which
was requested in onnx#4933.

As per the discussion in onnx#4933, I have added GELU as a context-dependent
function op that uses the attribute `approximate` to select one of the
two possible function-body definitions (a rough sketch of the erf-based
body follows the formulas below).

The first function definition is the regular GELU:
`GELU(x) = x * Φ(x) = 0.5 * x * (1 + erf(x / sqrt(2)))`
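
Here Φ is the cumulative distribution function of the standard normal distribution, so the erf form follows from the standard identity:

```math
\Phi(x) = \frac{1}{2}\left(1 + \operatorname{erf}\left(\frac{x}{\sqrt{2}}\right)\right)
\quad\Longrightarrow\quad
x \cdot \Phi(x) = 0.5 \cdot x \cdot \left(1 + \operatorname{erf}\left(\frac{x}{\sqrt{2}}\right)\right)
```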

The second is the fast approximation based on `tanh`:
`GELU(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))`

This implementation uses the [PyTorch docs for
GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html?highlight=gelu#torch.nn.GELU)
as a reference.
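
The actual function bodies are registered from C++ in `onnx/defs/math/defs.cc` and are type-generic. Purely as an illustration (the `local` domain and all names below are made up, and the constants are hard-coded as float for brevity), the erf-based body corresponds to an ONNX function along these lines:

```python
import math

from onnx import TensorProto, helper

# Illustrative sketch only: the committed function body is registered from
# C++ and resolves types generically; floats are hard-coded here.
nodes = [
    helper.make_node(
        "Constant", [], ["Half"],
        value=helper.make_tensor("half", TensorProto.FLOAT, [], [0.5]),
    ),
    helper.make_node(
        "Constant", [], ["SqrtTwo"],
        value=helper.make_tensor("sqrt2", TensorProto.FLOAT, [], [math.sqrt(2.0)]),
    ),
    helper.make_node(
        "Constant", [], ["One"],
        value=helper.make_tensor("one", TensorProto.FLOAT, [], [1.0]),
    ),
    helper.make_node("Div", ["X", "SqrtTwo"], ["Scaled"]),  # x / sqrt(2)
    helper.make_node("Erf", ["Scaled"], ["ErfX"]),          # erf(x / sqrt(2))
    helper.make_node("Add", ["ErfX", "One"], ["Shifted"]),  # 1 + erf(...)
    helper.make_node("Mul", ["X", "Shifted"], ["Prod"]),    # x * (1 + erf(...))
    helper.make_node("Mul", ["Prod", "Half"], ["Y"]),       # 0.5 * x * (1 + erf(...))
]

gelu_erf = helper.make_function(
    "local", "GeluErf", ["X"], ["Y"], nodes,
    opset_imports=[helper.make_opsetid("", 20)],
)
```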

PS: I also refactored `onnx/defs/math/defs.cc` to bring the operator
implementation of `mish` right next to its doc string.

---------

Signed-off-by: pranshupant <pranshupant@gmail.com>
Co-authored-by: G. Ramalingam <grama@microsoft.com>
Signed-off-by: Aditya Goel <agoel4512@gmail.com>
2 people authored and adityagoel4512 committed Jul 28, 2023
1 parent d8333a1 commit 1207f1b
Showing 32 changed files with 354 additions and 13 deletions.
42 changes: 42 additions & 0 deletions docs/Changelog.md
@@ -23919,6 +23919,48 @@
This version of the operator has been available since version 20 of the default ONNX operator set.
<dd>Constrain output types to be numerics.</dd>
</dl>

### <a name="Gelu-20"></a>**Gelu-20**</a>

Gelu takes one input data (Tensor<T>) and produces one
output data (Tensor<T>) where the Gaussian error linear units function,
$y = 0.5 * x * (1 + erf(x/sqrt(2)))$, is applied to the tensor elementwise.
If the attribute "approximate" is set to "tanh", the approximation
$y = 0.5 * x * (1 + Tanh(sqrt(2/\pi) * (x + 0.044715 * x^3)))$ is used
instead and applied to the tensor elementwise.


#### Version

This version of the operator has been available since version 20 of the default ONNX operator set.

#### Attributes

<dl>
<dt><tt>approximate</tt> : string (default is none)</dt>
<dd>Gelu approximation algorithm: `"tanh"`, `"none"` (default). `"none"`: do not use approximation. `"tanh"`: use the tanh approximation.</dd>
</dl>

#### Inputs

<dl>
<dt><tt>X</tt> (differentiable) : T</dt>
<dd>Input tensor</dd>
</dl>

#### Outputs

<dl>
<dt><tt>Y</tt> (differentiable) : T</dt>
<dd>Output tensor</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double), tensor(bfloat16)</dt>
<dd>Constrain input and output types to float tensors.</dd>
</dl>
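
As a quick smoke test of the new schema (a sketch, assuming an `onnx` build that includes opset 20), a one-node model passes the checker and shape inference:

```python
import onnx
from onnx import TensorProto, helper

# Build a single-node model that uses the new Gelu-20 op with the
# tanh approximation enabled.
node = helper.make_node("Gelu", ["X"], ["Y"], approximate="tanh")
graph = helper.make_graph(
    [node],
    "gelu_smoke_test",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 4])],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 4])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)])

onnx.checker.check_model(model)                      # validates schema and attributes
inferred = onnx.shape_inference.infer_shapes(model)  # Y gets X's shape and dtype
```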

### <a name="GridSample-20"></a>**GridSample-20**</a>

Given an input `X` and a flow-field `grid`, computes the output `Y` using `X` values and pixel locations from the `grid`.
96 changes: 96 additions & 0 deletions docs/Operators.md
@@ -171,6 +171,7 @@
For an operator input/output's differentiability, it can be differentiable,
|<a href="#Clip">Clip</a>|<a href="Changelog.md#Clip-13">13</a>, <a href="Changelog.md#Clip-12">12</a>, <a href="Changelog.md#Clip-11">11</a>, <a href="Changelog.md#Clip-6">6</a>, <a href="Changelog.md#Clip-1">1</a>|13|
|<a href="#DynamicQuantizeLinear">DynamicQuantizeLinear</a>|<a href="Changelog.md#DynamicQuantizeLinear-11">11</a>|11|
|<a href="#Elu">Elu</a>|<a href="Changelog.md#Elu-6">6</a>, <a href="Changelog.md#Elu-1">1</a>|18|
|<a href="#Gelu">Gelu</a>|<a href="Changelog.md#Gelu-20">20</a>|20|
|<a href="#GreaterOrEqual">GreaterOrEqual</a>|<a href="Changelog.md#GreaterOrEqual-16">16</a>, <a href="Changelog.md#GreaterOrEqual-12">12</a>|16|
|<a href="#GroupNormalization">GroupNormalization</a>|<a href="Changelog.md#GroupNormalization-18">18</a>|18|
|<a href="#HammingWindow">HammingWindow</a>|<a href="Changelog.md#HammingWindow-17">17</a>|17|
@@ -9411,6 +9412,101 @@
expect(
</details>


### <a name="Gelu"></a><a name="gelu">**Gelu**</a>

Gelu takes one input data (Tensor<T>) and produces one
output data (Tensor<T>) where the Gaussian error linear units function,
$y = 0.5 * x * (1 + erf(x/sqrt(2)))$, is applied to the tensor elementwise.
If the attribute "approximate" is set to "tanh", the approximation
$y = 0.5 * x * (1 + Tanh(sqrt(2/\pi) * (x + 0.044715 * x^3)))$ is used
instead and applied to the tensor elementwise.


#### Version

This version of the operator has been available since version 20 of the default ONNX operator set.

#### Attributes

<dl>
<dt><tt>approximate</tt> : string (default is none)</dt>
<dd>Gelu approximation algorithm: `"tanh"`, `"none"` (default). `"none"`: do not use approximation. `"tanh"`: use the tanh approximation.</dd>
</dl>

#### Inputs

<dl>
<dt><tt>X</tt> (differentiable) : T</dt>
<dd>Input tensor</dd>
</dl>

#### Outputs

<dl>
<dt><tt>Y</tt> (differentiable) : T</dt>
<dd>Output tensor</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double), tensor(bfloat16)</dt>
<dd>Constrain input and output types to float tensors.</dd>
</dl>


#### Examples

<details>
<summary>gelu_default</summary>

```python
node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"])

x = np.array([-1, 0, 1]).astype(np.float32)
# expected output [-0.15865526, 0., 0.84134474]
y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1")

x = np.random.randn(3, 4, 5).astype(np.float32)
y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2")
```

</details>


<details>
<summary>gelu_tanh</summary>

```python
node = onnx.helper.make_node(
"Gelu", inputs=["x"], outputs=["y"], approximate="tanh"
)

x = np.array([-1, 0, 1]).astype(np.float32)
# expected output [-0.158808, 0., 0.841192]
y = (
0.5
* x
* (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))
).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1")

x = np.random.randn(3, 4, 5).astype(np.float32)
y = (
0.5
* x
* (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))
).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2")
```

</details>
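
Beyond comparing against the NumPy formulas above, the same expected values can be reproduced with the ONNX reference evaluator (a sketch, assuming an `onnx` version whose reference implementation covers opset 20):

```python
import numpy as np
from onnx import TensorProto, helper
from onnx.reference import ReferenceEvaluator

# One-node model exercising Gelu with the tanh approximation.
node = helper.make_node("Gelu", ["x"], ["y"], approximate="tanh")
graph = helper.make_graph(
    [node],
    "gelu_ref",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT, [3])],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [3])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)])

sess = ReferenceEvaluator(model)
x = np.array([-1, 0, 1], dtype=np.float32)
print(sess.run(None, {"x": x})[0])  # approx. [-0.158808, 0., 0.841192]
```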


### <a name="Gemm"></a><a name="gemm">**Gemm**</a>

General Matrix multiplication:
50 changes: 50 additions & 0 deletions docs/TestCoverage.md
@@ -6241,6 +6241,56 @@
expect(
</details>


### Gelu
There are 2 test cases, listed as follows:
<details>
<summary>gelu_default</summary>

```python
node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"])

x = np.array([-1, 0, 1]).astype(np.float32)
# expected output [-0.15865526, 0., 0.84134474]
y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1")

x = np.random.randn(3, 4, 5).astype(np.float32)
y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2")
```

</details>
<details>
<summary>gelu_tanh</summary>

```python
node = onnx.helper.make_node(
"Gelu", inputs=["x"], outputs=["y"], approximate="tanh"
)

x = np.array([-1, 0, 1]).astype(np.float32)
# expected output [-0.158808, 0., 0.841192]
y = (
0.5
* x
* (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))
).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1")

x = np.random.randn(3, 4, 5).astype(np.float32)
y = (
0.5
* x
* (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))
).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2")
```

</details>


### Gemm
There are 11 test cases, listed as follows:
<details>
51 changes: 51 additions & 0 deletions onnx/backend/test/case/node/gelu.py
@@ -0,0 +1,51 @@
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0

import math

import numpy as np

import onnx
from onnx.backend.test.case.base import Base
from onnx.backend.test.case.node import expect


class Gelu(Base):
@staticmethod
def export_gelu_tanh() -> None:
node = onnx.helper.make_node(
"Gelu", inputs=["x"], outputs=["y"], approximate="tanh"
)

x = np.array([-1, 0, 1]).astype(np.float32)
# expected output [-0.158808, 0., 0.841192]
y = (
0.5
* x
* (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))
).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1")

x = np.random.randn(3, 4, 5).astype(np.float32)
y = (
0.5
* x
* (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))
).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2")

@staticmethod
def export_gelu_default() -> None:
node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"])

x = np.array([-1, 0, 1]).astype(np.float32)
# expected output [-0.15865526, 0., 0.84134474]
y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1")

x = np.random.randn(3, 4, 5).astype(np.float32)
y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2")
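
As a side note (not part of the commit), a quick check of how far the tanh body strays from the erf body, using the same formulas as the test cases above:

```python
import math

import numpy as np

# Compare the exact (erf) and approximate (tanh) GELU over a coarse grid.
x = np.linspace(-5.0, 5.0, 1001)
exact = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))
approx = 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))
print(np.abs(exact - approx).max())  # small; empirically on the order of 1e-3
```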

The remaining changed files are binary test data for the new Gelu backend test cases (contents not shown).
