feat: implement GRU layer #1168

Merged · 1 commit · Sep 3, 2023
29 changes: 29 additions & 0 deletions src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUArgs.cs
@@ -0,0 +1,29 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class GRUArgs : AutoSerializeLayerArgs
{
public int Units { get; set; }
public Activation Activation { get; set; }
public Activation RecurrentActivation { get; set; }
public bool UseBias { get; set; } = true;
public float Dropout { get; set; } = .0f;
public float RecurrentDropout { get; set; } = .0f;
public IInitializer KernelInitializer { get; set; }
public IInitializer RecurrentInitializer { get; set; }
public IInitializer BiasInitializer { get; set; }
public bool ReturnSequences { get; set; }
public bool ReturnState { get; set; }
public bool GoBackwards { get; set; }
public bool Stateful { get; set; }
public bool Unroll { get; set; }
public bool TimeMajor { get; set; }
public bool ResetAfter { get; set; }
public int Implementation { get; set; } = 2;

}

}
13 changes: 13 additions & 0 deletions src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUOptionalArgs.cs
@@ -0,0 +1,13 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.ArgsDefinition
{
public class GRUOptionalArgs : IOptionalArgs
{
public string Identifier => "GRU";

public Tensor Mask { get; set; } = null;
}
}
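For context, `GRUOptionalArgs` is the container a caller uses to hand a mask tensor through to the layer's `Call`. A minimal usage sketch (illustrative only, not part of this diff; it assumes `Apply` exposes an `optional_args` parameter that is forwarded to `Call`, and the mask shape follows Keras conventions):

// Hypothetical usage sketch: passing a mask via GRUOptionalArgs.
var inputs = tf.ones((32, 10, 8));                         // (batch, timesteps, features)
var mask = tf.ones((32, 10));                              // all-ones mask keeps every timestep; a real mask would zero out padding
var gru = tf.keras.layers.GRU(4, return_sequences: true);
var outputs = gru.Apply(inputs, optional_args: new GRUOptionalArgs { Mask = mask });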
19 changes: 19 additions & 0 deletions src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
@@ -259,6 +259,25 @@ public partial interface ILayersApi
float recurrent_dropout = 0f,
bool reset_after = true);

public ILayer GRU(
int units,
string activation = "tanh",
string recurrent_activation = "sigmoid",
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string recurrent_initializer = "orthogonal",
string bias_initializer = "zeros",
float dropout = 0f,
float recurrent_dropout = 0f,
bool return_sequences = false,
bool return_state = false,
bool go_backwards = false,
bool stateful = false,
bool unroll = false,
bool time_major = false,
bool reset_after = true
);

/// <summary>
/// Bidirectional wrapper for RNNs.
/// </summary>
61 changes: 60 additions & 1 deletion src/TensorFlowNET.Keras/Layers/LayersApi.cs
@@ -784,7 +784,7 @@ public ILayer LeakyReLU(float alpha = 0.3f)
string recurrent_activation = "sigmoid",
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string recurrent_initializer = "orthogonal", // TODO(Wanglongzhi2001),glorot_uniform has not been developed.
string recurrent_initializer = "orthogonal",
string bias_initializer = "zeros",
bool unit_forget_bias = true,
float dropout = 0f,
@@ -908,6 +908,65 @@ public ILayer LeakyReLU(float alpha = 0.3f)
ResetAfter = reset_after
});

/// <summary>
/// Gated Recurrent Unit - Cho et al. 2014.
/// </summary>
/// <param name="units">Positive integer, dimensionality of the output space.</param>
/// <param name="activation">Activation function to use. If you pass `None`, no activation is applied.(ie. "linear" activation: `a(x) = x`).</param>
/// <param name="recurrent_activation">Activation function to use for the recurrent step. If you pass `None`, no activation is applied. (ie. "linear" activation: `a(x) = x`).</param>
/// <param name="use_bias">Boolean, (default `True`), whether the layer uses a bias vector.</param>
/// <param name="kernel_initializer">Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs. Default: `glorot_uniform`.</param>
/// <param name="recurrent_initializer">Initializer for the `recurrent_kernel` weights matrix, used for the linear transformation of the recurrent state. Default: `orthogonal`.</param>
/// <param name="bias_initializer">Initializer for the bias vector. Default: `zeros`.</param>
/// <param name="dropout">Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs. Default: 0.</param>
/// <param name="recurrent_dropout">Float between 0 and 1. Fraction of the units to drop for the linear transformation of the recurrent state. Default: 0.</param>
/// <param name="implementation"></param>
/// <param name="return_sequences">Boolean. Whether to return the last output in the output sequence, or the full sequence. Default: `False`.</param>
/// <param name="return_state">Boolean. Whether to return the last state in addition to the output. Default: `False`.</param>
/// <param name="go_backwards">Boolean (default `False`). If True, process the input sequence backwards and return the reversed sequence.</param>
/// <param name="stateful">Boolean (default False). If True, the last state for each sample at index i in a batch will be used as initial state for the sample of index i in the following batch.</param>
/// <param name="unroll">Boolean (default False). If True, the network will be unrolled, else a symbolic loop will be used. Unrolling can speed-up a RNN,</param>
/// <param name="time_major">The shape format of the `inputs` and `outputs` tensors.</param>
/// <param name="reset_after">GRU convention (whether to apply reset gate after or before matrix multiplication). False = "before", True = "after" (default and cuDNN compatible).</param>
/// <returns></returns>
public ILayer GRU(
int units,
string activation = "tanh",
string recurrent_activation = "sigmoid",
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string recurrent_initializer = "orthogonal",
string bias_initializer = "zeros",
float dropout = 0f,
float recurrent_dropout = 0f,
bool return_sequences = false,
bool return_state = false,
bool go_backwards = false,
bool stateful = false,
bool unroll = false,
bool time_major = false,
bool reset_after = true
)
=> new GRU(new GRUArgs
{
Units = units,
Activation = keras.activations.GetActivationFromName(activation),
RecurrentActivation = keras.activations.GetActivationFromName(recurrent_activation),
KernelInitializer = GetInitializerByName(kernel_initializer),
RecurrentInitializer = GetInitializerByName(recurrent_initializer),
BiasInitializer = GetInitializerByName(bias_initializer),
UseBias = use_bias,
Dropout = dropout,
RecurrentDropout = recurrent_dropout,
ReturnSequences = return_sequences,
ReturnState = return_state,
GoBackwards = go_backwards,
Stateful = stateful,
TimeMajor = time_major,
Unroll = unroll,
ResetAfter = reset_after
});
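To make the new surface concrete, a short usage sketch of the API added above (not part of the diff; shapes follow the XML docs and standard Keras semantics):

// Default: only the last output is returned -> shape (batch, units).
var gru = tf.keras.layers.GRU(4);
var lastOutput = gru.Apply(tf.ones((32, 10, 8)));       // shape: (32, 4)

// With return_sequences the full sequence is returned -> (batch, timesteps, units).
var gruSeq = tf.keras.layers.GRU(4, return_sequences: true);
var sequence = gruSeq.Apply(tf.ones((32, 10, 8)));      // shape: (32, 10, 4)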

public ILayer Bidirectional(
ILayer layer,
string merge_mode = "concat",
168 changes: 168 additions & 0 deletions src/TensorFlowNET.Keras/Layers/Rnn/GRU.cs
@@ -0,0 +1,168 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Common.Extensions;
using Tensorflow.Common.Types;
using Tensorflow.Keras.Saving;


namespace Tensorflow.Keras.Layers
{
public class GRU : RNN
{
GRUArgs _args;
private GRUCell _cell;

bool _return_runtime;
public GRUCell Cell { get => _cell; }
public int units { get => _args.Units; }
public Activation activation { get => _args.Activation; }
public Activation recurrent_activation { get => _args.RecurrentActivation; }
public bool use_bias { get => _args.UseBias; }
public float dropout { get => _args.Dropout; }
public float recurrent_dropout { get => _args.RecurrentDropout; }
public IInitializer kernel_initializer { get => _args.KernelInitializer; }
public IInitializer recurrent_initializer { get => _args.RecurrentInitializer; }
public IInitializer bias_initializer { get => _args.BiasInitializer; }
public int implementation { get => _args.Implementation; }
public bool reset_after { get => _args.ResetAfter; }

public GRU(GRUArgs args) : base(CreateCell(args), PreConstruct(args))
{
_args = args;

if (_args.Implementation == 0)
{
// Print the warning in red so it is visible even in release builds.
Console.ForegroundColor = ConsoleColor.Red;
Console.WriteLine("Warning: `implementation=0` has been deprecated, " +
"and now defaults to `implementation=2`. " +
"Please update your layer call.");
Console.ResetColor();
}

GRUCell cell = new GRUCell(new GRUCellArgs
{
Units = _args.Units,
Activation = _args.Activation,
RecurrentActivation = _args.RecurrentActivation,
UseBias = _args.UseBias,
Dropout = _args.Dropout,
RecurrentDropout = _args.RecurrentDropout,
KernelInitializer = _args.KernelInitializer,
RecurrentInitializer = _args.RecurrentInitializer,
BiasInitializer = _args.BiasInitializer,
ResetAfter = _args.ResetAfter,
Implementation = _args.Implementation
});
_cell = cell;
}

protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
GRUOptionalArgs? gru_optional_args = optional_args as GRUOptionalArgs;
if (optional_args is not null && gru_optional_args is null)
{
throw new ArgumentException("The type of optional args should be `GRUOptionalArgs`.");
}
Tensors? mask = gru_optional_args?.Mask;

// Ragged input is not supported yet.
int row_length = 0; // placeholder until ragged input support is added
bool is_ragged_input = false;

_validate_args_if_ragged(is_ragged_input, mask);

// GRU does not support constants; ignore them during processing.
(inputs, initial_state, _) = this._process_inputs(inputs, initial_state, null);

if (mask != null && mask.Length > 1)
{
mask = mask[0];
}

var input_shape = inputs.shape;
var timesteps = _args.TimeMajor ? input_shape[0] : input_shape[1];


// TODO(Wanglongzhi2001), finish _could_use_gpu_kernel part
Func<Tensors, Tensors, (Tensors, Tensors)> step = (cell_inputs, cell_states) =>
{
var res = Cell.Apply(cell_inputs, cell_states, training ?? true);
var (output, state) = res;
return (output, state);
};

var (last_output, outputs, states) = keras.backend.rnn(
step,
inputs,
initial_state,
constants: null,
go_backwards: _args.GoBackwards,
mask: mask,
unroll: _args.Unroll,
input_length: ops.convert_to_tensor(timesteps),
time_major: _args.TimeMajor,
zero_output_for_mask: base.Args.ZeroOutputForMask,
return_all_outputs: _args.ReturnSequences
);

Tensors output;
if (_args.ReturnSequences)
{
output = outputs;
}
else
{
output = last_output;
}

if (_args.ReturnState)
{
output = new Tensors { output, states };
}
return output;
}

private static IRnnCell CreateCell(GRUArgs gruArgs)
{
return new GRUCell(new GRUCellArgs
{
Units = gruArgs.Units,
Activation = gruArgs.Activation,
RecurrentActivation = gruArgs.RecurrentActivation,
UseBias = gruArgs.UseBias,
Dropout = gruArgs.Dropout,
RecurrentDropout = gruArgs.RecurrentDropout,
KernelInitializer = gruArgs.KernelInitializer,
RecurrentInitializer = gruArgs.RecurrentInitializer,
BiasInitializer = gruArgs.BiasInitializer,
ResetAfter = gruArgs.ResetAfter,
Implementation = gruArgs.Implementation
});
}

private static RNNArgs PreConstruct(GRUArgs args)
{
return new RNNArgs
{
ReturnSequences = args.ReturnSequences,
ReturnState = args.ReturnState,
GoBackwards = args.GoBackwards,
Stateful = args.Stateful,
Unroll = args.Unroll,
TimeMajor = args.TimeMajor,
Units = args.Units,
Activation = args.Activation,
RecurrentActivation = args.RecurrentActivation,
UseBias = args.UseBias,
Dropout = args.Dropout,
RecurrentDropout = args.RecurrentDropout,
KernelInitializer = args.KernelInitializer,
RecurrentInitializer = args.RecurrentInitializer,
BiasInitializer = args.BiasInitializer
};
}
}
}
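For reference, the recurrence this layer steps through is the standard GRU of Cho et al. 2014. A sketch in the Keras sign convention (assumed here): `reset_after` controls where the reset gate $r_t$ is applied, and the `reset_after=true` variant keeps separate input and recurrent biases for cuDNN compatibility.

$$
\begin{aligned}
z_t &= \sigma(W_z x_t + U_z h_{t-1} + b_z) \\
r_t &= \sigma(W_r x_t + U_r h_{t-1} + b_r) \\
\tilde{h}_t &= \tanh\!\left(W_h x_t + U_h \,(r_t \odot h_{t-1}) + b_h\right) && \text{reset\_after = false} \\
\tilde{h}_t &= \tanh\!\left(W_h x_t + r_t \odot (U_h h_{t-1} + b_h^{rec}) + b_h^{in}\right) && \text{reset\_after = true} \\
h_t &= z_t \odot h_{t-1} + (1 - z_t) \odot \tilde{h}_t
\end{aligned}
$$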
42 changes: 2 additions & 40 deletions src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
@@ -25,8 +25,8 @@ public class RNN : RnnBase
private RNNArgs _args;
private object _input_spec = null; // or NoneValue??
private object _state_spec = null;
- private Tensors _states = null;
private object _constants_spec = null;
+ private Tensors _states = null;
private int _num_constants;
protected IVariableV1 _kernel;
protected IVariableV1 _bias;
@@ -469,7 +469,7 @@ public override Tensors Apply(Tensors inputs, Tensors initial_states = null, boo
return (inputs, initial_state, constants);
}

- private void _validate_args_if_ragged(bool is_ragged_input, Tensors mask)
+ protected void _validate_args_if_ragged(bool is_ragged_input, Tensors mask)
{
if (!is_ragged_input)
{
@@ -528,44 +528,6 @@ public Tensors __call__(Tensors inputs, Tensor state = null, Tensor training = n
throw new NotImplementedException();
}

- // It seems the cell cannot be passed in as an interface type.
- //public RNN New(IRnnArgCell cell,
- //    bool return_sequences = false,
- //    bool return_state = false,
- //    bool go_backwards = false,
- //    bool stateful = false,
- //    bool unroll = false,
- //    bool time_major = false)
- //    => new RNN(new RNNArgs
- //    {
- //        Cell = cell,
- //        ReturnSequences = return_sequences,
- //        ReturnState = return_state,
- //        GoBackwards = go_backwards,
- //        Stateful = stateful,
- //        Unroll = unroll,
- //        TimeMajor = time_major
- //    });
-
- //public RNN New(List<IRnnArgCell> cell,
- //    bool return_sequences = false,
- //    bool return_state = false,
- //    bool go_backwards = false,
- //    bool stateful = false,
- //    bool unroll = false,
- //    bool time_major = false)
- //    => new RNN(new RNNArgs
- //    {
- //        Cell = cell,
- //        ReturnSequences = return_sequences,
- //        ReturnState = return_state,
- //        GoBackwards = go_backwards,
- //        Stateful = stateful,
- //        Unroll = unroll,
- //        TimeMajor = time_major
- //    });


protected Tensors get_initial_state(Tensors inputs)
{
var input = inputs[0];
9 changes: 9 additions & 0 deletions test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
@@ -146,6 +146,15 @@ public void GRUCell()

}

[TestMethod]
public void GRU()
{
var inputs = tf.ones((32, 10, 8));
var gru = tf.keras.layers.GRU(4);
var output = gru.Apply(inputs);
Assert.AreEqual((32, 4), output.shape);
}
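A natural companion test, sketched here under the same conventions (hypothetical, not part of this PR):

[TestMethod]
public void GRUReturnSequences()
{
    // With return_sequences the full sequence is returned: (batch, timesteps, units).
    var inputs = tf.ones((32, 10, 8));
    var gru = tf.keras.layers.GRU(4, return_sequences: true);
    var output = gru.Apply(inputs);
    Assert.AreEqual((32, 10, 4), output.shape);
}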

[TestMethod]
public void Bidirectional()
{