Merge pull request #1100 from Wanglongzhi2001/rnn-dev
Add feature (not completed): add SimpleRNNCell, StackedRNNCell, RNN, and tests.
AsakusaRinne committed Jun 12, 2023
2 parents 81a9d23 + db8e43b commit 1d97b71
Showing 14 changed files with 445 additions and 119 deletions.
14 changes: 12 additions & 2 deletions src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs
@@ -12,9 +12,14 @@ public class GeneralizedTensorShape: IEnumerable<long?[]>, INestStructure<long?>
/// create a single-dim generalized Tensor shape.
/// </summary>
/// <param name="dim"></param>
public GeneralizedTensorShape(int dim)
public GeneralizedTensorShape(int dim, int size = 1)
{
Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } };
var elem = new TensorShapeConfig() { Items = new long?[] { dim } };
Shapes = Enumerable.Repeat(elem, size).ToArray();
}

public GeneralizedTensorShape(Shape shape)
@@ -113,6 +118,11 @@ public INestStructure<TOut> MapStructure<TOut>(Func<long?, TOut> func)
return new Nest<long?>(Shapes.Select(s => DealWithSingleShape(s)));
}
}



public static implicit operator GeneralizedTensorShape(int dims)
=> new GeneralizedTensorShape(dims);

public IEnumerator<long?[]> GetEnumerator()
{
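A minimal usage sketch of the widened constructor and the implicit conversion (these call sites are illustrative, not part of the commit):

var cellStates = new GeneralizedTensorShape(32, size: 4); // four identical [32] shape entries
GeneralizedTensorShape single = 16;                       // implicit int -> single-dim shape

Note that Enumerable.Repeat reuses one TensorShapeConfig instance across all `size` slots, which is safe only while the configs are treated as read-only.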
3 changes: 3 additions & 0 deletions src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs
@@ -10,6 +10,9 @@ public class RNNArgs : AutoSerializeLayerArgs
[JsonProperty("cell")]
// TODO: the cell should be serialized with `serialize_keras_object`.
public IRnnCell Cell { get; set; } = null;
[JsonProperty("cells")]
public IList<IRnnCell> Cells { get; set; } = null;

[JsonProperty("return_sequences")]
public bool ReturnSequences { get; set; } = false;
[JsonProperty("return_state")]
3 changes: 2 additions & 1 deletion src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs
@@ -1,10 +1,11 @@
using System.Collections.Generic;
using Tensorflow.Keras.Layers.Rnn;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
public class StackedRNNCellsArgs : LayerArgs
{
public IList<RnnCell> Cells { get; set; }
public IList<IRnnCell> Cells { get; set; }
public Dictionary<string, object> Kwargs { get; set; } = null;
}
}
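With Cells retyped to IList<IRnnCell>, the args can be populated straight from the cell factories this PR adds to LayersApi; a hedged sketch, assuming the usual `using static Tensorflow.KerasApi;` binding:

var args = new StackedRNNCellsArgs
{
    Cells = new List<IRnnCell>
    {
        keras.layers.SimpleRNNCell(4),
        keras.layers.SimpleRNNCell(4),
    },
};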
34 changes: 34 additions & 0 deletions src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
@@ -1,5 +1,6 @@
using System;
using Tensorflow.Framework.Models;
using Tensorflow.Keras.Layers.Rnn;
using Tensorflow.NumPy;
using static Google.Protobuf.Reflection.FieldDescriptorProto.Types;

@@ -192,6 +193,19 @@ public partial interface ILayersApi
float offset = 0,
Shape input_shape = null);

public IRnnCell SimpleRNNCell(
int units,
string activation = "tanh",
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string recurrent_initializer = "orthogonal",
string bias_initializer = "zeros",
float dropout = 0f,
float recurrent_dropout = 0f);

public IRnnCell StackedRNNCells(
IEnumerable<IRnnCell> cells);

public ILayer SimpleRNN(int units,
string activation = "tanh",
string kernel_initializer = "glorot_uniform",
@@ -200,6 +214,26 @@ public partial interface ILayersApi
bool return_sequences = false,
bool return_state = false);

public ILayer RNN(
IRnnCell cell,
bool return_sequences = false,
bool return_state = false,
bool go_backwards = false,
bool stateful = false,
bool unroll = false,
bool time_major = false
);

public ILayer RNN(
IEnumerable<IRnnCell> cell,
bool return_sequences = false,
bool return_state = false,
bool go_backwards = false,
bool stateful = false,
bool unroll = false,
bool time_major = false
);

public ILayer Subtract();
}
}
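Taken together, the new interface members support a two-step construction: build a cell, then wrap it in the generic RNN layer. A sketch, assuming the standard `keras.layers` accessor (parameter values are illustrative):

var cell = keras.layers.SimpleRNNCell(64, dropout: 0.5f, recurrent_dropout: 0.5f);
var rnn = keras.layers.RNN(cell, return_sequences: true, return_state: true);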
14 changes: 13 additions & 1 deletion src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs
@@ -109,7 +109,19 @@ public TensorArray scatter(Tensor indices, Tensor value, string name = null)
return ta;
});*/
throw new NotImplementedException("");
// TODO: implement eager scatter: unstack `value` and write each element at its
// corresponding index (see the sketch at the end of this file's diff); for now
// the array is returned unchanged.
return this;
}

public void _merge_element_shape(Shape shape)
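The commented-out body in scatter above hints at the intended eager behavior; a hedged reconstruction of that idea (not what this commit ships; the method currently returns the array unchanged):

public TensorArray scatter(Tensor indices, Tensor value, string name = null)
{
    // Write each unstacked element of `value` at its matching index.
    foreach (var (index, val) in zip(indices.ToArray<int>(), array_ops.unstack(value)))
        this.write(index, val);
    return this;
}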
5 changes: 4 additions & 1 deletion src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs
@@ -17,6 +17,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Eager;
using static Tensorflow.Binding;

namespace Tensorflow.Operations
@@ -146,7 +147,9 @@ public TensorArray scatter(Tensor indices, Tensor value, string name = null)
return ta;
});*/
throw new NotImplementedException("");

return this;
}

public void _merge_element_shape(Shape shape)
27 changes: 14 additions & 13 deletions src/TensorFlowNET.Keras/BackendImpl.cs
@@ -510,7 +510,7 @@ Tensor swap_batch_timestep(Tensor input_t)
}

}

// tf.where needs its condition tensor to be the same shape as its two
// result tensors, but in our case the condition (mask) tensor is
// (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
@@ -535,7 +535,7 @@ Tensors _expand_mask(Tensors mask_t, Tensors input_t, int fixed_dim = 1)
{
mask_t = tf.expand_dims(mask_t, -1);
}
var multiples = Enumerable.Repeat(1, fixed_dim).ToArray().concat(input_t.shape.as_int_list().ToList().GetRange(fixed_dim, input_t.rank));
var multiples = Enumerable.Repeat(1, fixed_dim).ToArray().concat(input_t.shape.as_int_list().Skip(fixed_dim).ToArray());
return tf.tile(mask_t, multiples);
}
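// Editorial worked example (not part of the commit): with mask_t of shape (8, 1)
// and input_t of shape (8, 10, 16), fixed_dim = 1 gives as_int_list() = { 8, 10, 16 }.
// The old GetRange(fixed_dim, input_t.rank) requested 3 items starting at index 1
// and threw; Skip(1) yields { 10, 16 }, so multiples = { 1, 10, 16 } and the tiled
// mask matches input_t's shape.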

@@ -570,9 +570,6 @@ Tensors _expand_mask(Tensors mask_t, Tensors input_t, int fixed_dim = 1)
// individually. The result will be a tuple of lists; each item in the
// tuple is a list of tensors with shape (batch, feature)




Tensors _process_single_input_t(Tensor input_t)
{
var unstaked_input_t = array_ops.unstack(input_t); // unstack for time_step dim
@@ -609,7 +606,7 @@ object _get_input_tensor(int time)
var mask_list = tf.unstack(mask);
if (go_backwards)
{
mask_list.Reverse();
mask_list = mask_list.Reverse().ToArray();
}

for (int i = 0; i < time_steps; i++)
@@ -629,9 +626,10 @@ object _get_input_tensor(int time)
}
else
{
prev_output = successive_outputs[successive_outputs.Length - 1];
prev_output = successive_outputs.Last();
}

// output could be a tensor
output = tf.where(tiled_mask_t, output, prev_output);

var flat_states = Nest.Flatten(states).ToList();
@@ -661,13 +659,13 @@ object _get_input_tensor(int time)
}

}
last_output = successive_outputs[successive_outputs.Length - 1];
new_states = successive_states[successive_states.Length - 1];
last_output = successive_outputs.Last();
new_states = successive_states.Last();
outputs = tf.stack(successive_outputs);

if (zero_output_for_mask)
{
last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output));
last_output = tf.where(_expand_mask(mask_list.Last(), last_output), last_output, tf.zeros_like(last_output));
outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs));
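// Editorial note (not in the original commit): for a sample whose mask ends in 0,
// the final step is masked, so this branch zeroes that sample's last_output and
// all masked positions of `outputs` rather than carrying stale values forward.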
}
else // mask is null
@@ -689,8 +687,8 @@ object _get_input_tensor(int time)
successive_states = new Tensors { newStates };
}
}
last_output = successive_outputs[successive_outputs.Length - 1];
new_states = successive_states[successive_states.Length - 1];
last_output = successive_outputs.Last();
new_states = successive_states.Last();
outputs = tf.stack(successive_outputs);
}
}
@@ -701,6 +699,8 @@ object _get_input_tensor(int time)
// Create input tensor array, if the inputs is nested tensors, then it
// will be flattened first, and tensor array will be created one per
// flattened tensor.


var input_ta = new List<TensorArray>();
for (int i = 0; i < flatted_inptus.Count; i++)
{
@@ -719,6 +719,7 @@ object _get_input_tensor(int time)
}
}


// Get the time(0) input and compute the output for that, the output will
// be used to determine the dtype of output tensor array. Don't read from
// input_ta due to TensorArray clear_after_read default to True.
@@ -773,7 +774,7 @@ object _get_input_tensor(int time)
return res;
};
}
// TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor)?
// TODO(Wanglongzhi2001): input_length may be either an integer or a single tensor.
else if (input_length is Tensor)
{
if (go_backwards)
77 changes: 77 additions & 0 deletions src/TensorFlowNET.Keras/Layers/LayersApi.cs
@@ -685,6 +685,34 @@ public ILayer LeakyReLU(float alpha = 0.3f)
Alpha = alpha
});


public IRnnCell SimpleRNNCell(
int units,
string activation = "tanh",
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string recurrent_initializer = "orthogonal",
string bias_initializer = "zeros",
float dropout = 0f,
float recurrent_dropout = 0f)
=> new SimpleRNNCell(new SimpleRNNCellArgs
{
Units = units,
Activation = keras.activations.GetActivationFromName(activation),
UseBias = use_bias,
KernelInitializer = GetInitializerByName(kernel_initializer),
RecurrentInitializer = GetInitializerByName(recurrent_initializer),
BiasInitializer = GetInitializerByName(bias_initializer),
Dropout = dropout,
RecurrentDropout = recurrent_dropout
});

public IRnnCell StackedRNNCells(
IEnumerable<IRnnCell> cells)
=> new StackedRNNCells(new StackedRNNCellsArgs
{
Cells = cells.ToList()
});

/// <summary>
/// Fully-connected RNN where the output of the previous timestep is fed back as input.
/// </summary>
@@ -709,6 +737,55 @@ public ILayer LeakyReLU(float alpha = 0.3f)
ReturnState = return_state
});

/// <summary>
/// Applies the given RNN cell (or stack of cells) to a sequence of inputs.
/// </summary>
/// <param name="cell"></param>
/// <param name="return_sequences"></param>
/// <param name="return_state"></param>
/// <param name="go_backwards"></param>
/// <param name="stateful"></param>
/// <param name="unroll"></param>
/// <param name="time_major"></param>
/// <returns></returns>
public ILayer RNN(
IRnnCell cell,
bool return_sequences = false,
bool return_state = false,
bool go_backwards = false,
bool stateful = false,
bool unroll = false,
bool time_major = false)
=> new RNN(new RNNArgs
{
Cell = cell,
ReturnSequences = return_sequences,
ReturnState = return_state,
GoBackwards = go_backwards,
Stateful = stateful,
Unroll = unroll,
TimeMajor = time_major
});

public ILayer RNN(
IEnumerable<IRnnCell> cell,
bool return_sequences = false,
bool return_state = false,
bool go_backwards = false,
bool stateful = false,
bool unroll = false,
bool time_major = false)
=> new RNN(new RNNArgs
{
Cells = cell.ToList(),
ReturnSequences = return_sequences,
ReturnState = return_state,
GoBackwards = go_backwards,
Stateful = stateful,
Unroll = unroll,
TimeMajor = time_major
});

/// <summary>
/// Long Short-Term Memory layer - Hochreiter 1997.
/// </summary>
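A usage sketch of the new cell-list overload, which passes the cells through RNNArgs.Cells (names come from this PR; sizes are illustrative):

var cells = new List<IRnnCell>
{
    keras.layers.SimpleRNNCell(32),
    keras.layers.SimpleRNNCell(32),
};
var stacked = keras.layers.RNN(cells, return_sequences: true);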
15 changes: 15 additions & 0 deletions src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs
@@ -17,6 +17,21 @@ public DropoutRNNCellMixin(LayerArgs args): base(args)

}

protected void _create_non_trackable_mask_cache()
{

}

public void reset_dropout_mask()
{

}

public void reset_recurrent_dropout_mask()
{

}

public Tensors? get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
{
if (dropout == 0f)
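The three new members above are stubs in this commit. In Keras, the reset hooks run once per batch so that a single dropout mask is sampled per batch and reused across timesteps; a hedged sketch of how the cache might eventually look (the field and logic are hypothetical):

private Tensors _dropoutMaskCache;  // hypothetical per-batch cache, not in this commit

public void reset_dropout_mask()
{
    // Clear the cache so the next call to get_dropout_maskcell_for_cell samples afresh.
    _dropoutMaskCache = null;
}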