Improve the API of LayerNormalization #1114

Merged · 2 commits · Jun 22, 2023 · Changes shown from 1 commit
18 changes: 18 additions & 0 deletions src/TensorFlowNET.Core/APIs/tf.nn.cs
@@ -14,8 +14,10 @@
limitations under the License.
******************************************************************************/

using System.Xml.Linq;
using Tensorflow.Operations;
using Tensorflow.Operations.Activation;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -125,6 +127,22 @@ public Tensor relu(Tensor features, string name = null)
is_training: is_training,
name: name,
exponential_avg_factor: exponential_avg_factor);
        public Tensor batch_normalization(Tensor x,
            Tensor mean,
            Tensor variance,
            Tensor offset,
            Tensor scale,
            float variance_epsilon,
            string name = null)
        {
            // Fold the scale into the reciprocal standard deviation so the
            // final expression needs one multiply and one add per element.
            var inv = math_ops.rsqrt(variance + variance_epsilon);
            tf_with(ops.name_scope(name, "batchnorm", (x, mean, variance, scale, offset)), scope =>
            {
                if (scale != null) inv *= scale;
            });
            // y = x * inv + (offset - mean * inv)  ==  (x - mean) * inv + offset
            if (offset != null)
                return x * math_ops.cast(inv, x.dtype) + math_ops.cast(offset - mean * inv, dtype: x.dtype);
            else
                return x * math_ops.cast(inv, x.dtype) + math_ops.cast(-mean * inv, dtype: x.dtype);
        }
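
For reference, the overload applies the usual normalization identity with the scale folded into the reciprocal standard deviation,

$$\frac{x-\mu}{\sqrt{\sigma^{2}+\varepsilon}}\,\gamma+\beta \;=\; x\cdot\mathrm{inv}+\bigl(\beta-\mu\cdot\mathrm{inv}\bigr),\qquad \mathrm{inv}=\frac{\gamma}{\sqrt{\sigma^{2}+\varepsilon}},$$

which is why the body computes `rsqrt(variance + variance_epsilon)`, multiplies it by `scale`, and never materializes `x - mean`.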

public Tensor max_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string name = null)
=> nn_ops.max_pool(value, ksize, strides, padding, data_format: data_format, name: name);
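
A minimal usage sketch of the new overload (the values, shapes, and constant gamma/beta below are illustrative only, not taken from the PR):

```csharp
using Tensorflow;
using static Tensorflow.Binding;

// Normalize a 3x2 batch over axis 0, with constants standing in for
// learned gamma/beta.
var x = tf.constant(new float[,] { { 1f, 2f }, { 3f, 4f }, { 5f, 6f } });
var (mean, variance) = tf.nn.moments(x, new[] { 0 }, keep_dims: true);

var gamma = tf.ones(new Shape(1, 2));   // scale
var beta = tf.zeros(new Shape(1, 2));   // offset

var y = tf.nn.batch_normalization(x, mean, variance,
    offset: beta, scale: gamma, variance_epsilon: 1e-5f);
// Each column of y now has approximately zero mean and unit variance.
```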
@@ -153,9 +153,22 @@ protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null)
}
else
{
                var input_dtype = inputs.dtype;
                // In mixed precision, compute the moments in float32: float16
                // inputs are upcast before normalization for numerical stability.
                if ((input_dtype == tf.float16) && DType == tf.float32)
                    inputs = tf.cast(inputs, tf.float32);
                (Tensor mean, Tensor variance) = tf.nn.moments(inputs, axis, keep_dims: true);

                (Tensor scale, Tensor offset) = (_broadcast(gamma), _broadcast(beta));

                outputs = tf.nn.batch_normalization(
                    inputs,
                    mean,
                    variance,
                    offset: offset,
                    scale: scale,
                    variance_epsilon: epsilon);

                // Cast back so the output dtype matches the original input.
                outputs = tf.cast(outputs, input_dtype);
            }
// If some components of the shape got lost due to adjustments, fix that.
outputs.shape = input_shape;
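
A minimal sketch of the intended mixed-precision behavior in this branch (hypothetical values; assumes the layer's default dtype is float32 and that this non-fused path is taken):

```csharp
using Tensorflow;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

var layer = keras.layers.LayerNormalization(axis: -1);
var half = tf.cast(tf.ones(new Shape(2, 4)), tf.float16);

// Moments are computed in float32 internally, but the final cast restores
// the caller's dtype:
var normalized = layer.Apply(half);
// normalized.dtype == tf.float16
```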

22 changes: 22 additions & 0 deletions test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
@@ -1,5 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
@@ -161,6 +163,26 @@ public void LayerNormalization()
Tensor output = layer.Apply(inputs);
Assert.AreEqual((5, 2), output.shape);
Assert.IsTrue(output[0].numpy().Equals(new[] { -0.99998f, 0.99998f }));

            // test_layernorm_weights: one trainable gamma and one trainable beta.
            Assert.AreEqual(len(layer.TrainableWeights), 2);
            Assert.AreEqual(len(layer.Weights), 2);

            var beta = layer.Weights.Where(x => x.Name.StartsWith("beta")).Single();
            var gamma = layer.Weights.Where(x => x.Name.StartsWith("gamma")).Single();

            // correctness_test: normalize random data and check the resulting statistics.
            layer = keras.layers.LayerNormalization(axis: -1, epsilon: (float)1e-12);
            var x = np.random.normal(loc: 5.0f, scale: 10.0f, size: (1000, 2, 2, 2)).astype(tf.float32);

            output = layer.Apply(x);

            // Undo the freshly initialized gamma/beta so y is the purely normalized signal.
            var y = (output - beta.numpy()) / gamma.numpy();

            // 8000 = 1000 * 2 * 2 * 2 elements; y should have ~zero mean and ~unit std.
            var y_mean = np.mean(y.numpy());
            var y_std = np.sqrt(np.sum(np.power(y.numpy() - np.mean(y.numpy()), 2)) / 8000);
            Assert.IsTrue(tf.greater(np.array(0.1f), tf.abs(y_std - 1.0)).ToArray<bool>()[0]);
            Assert.IsTrue(tf.greater(np.array(0.1f), tf.abs(y_mean)).ToArray<bool>()[0]);
}
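
The tolerance checks follow from the sample size: the random input has $N = 1000 \cdot 2 \cdot 2 \cdot 2 = 8000$ elements, and layer normalization drives each normalized group to zero mean and unit variance, so the aggregate mean and standard deviation over all 8000 elements should land comfortably within the 0.1 bounds.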

/// <summary>