Skip to content

Commit

Permalink
Merge pull request #1109 from AsakusaRinne/rnn-dev
Browse files Browse the repository at this point in the history
feat: support training of RNN.
  • Loading branch information
AsakusaRinne committed Jun 17, 2023
2 parents 1d97b71 + a0df810 commit edbf89b
Show file tree
Hide file tree
Showing 90 changed files with 8,353 additions and 1,013 deletions.
14 changes: 14 additions & 0 deletions src/TensorFlowNET.Core/APIs/c_api.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

using System;
using System.Runtime.InteropServices;
using static Tensorflow.CppShapeInferenceResult.Types;

namespace Tensorflow
{
Expand Down Expand Up @@ -50,6 +51,19 @@ public static string StringPiece(IntPtr handle)
return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle);
}

public unsafe static byte[] ByteStringPiece(IntPtr handle)
{
byte* str_data = (byte*)handle.ToPointer();
List<byte> bytes = new List<byte>();
byte current = 255;
while (current != ((byte)'\0'))
{
current = *(str_data++);
bytes.Add(current);
}
return bytes.Take(bytes.Count - 1).ToArray();
}

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate void Deallocator(IntPtr data, IntPtr size, ref DeallocatorArgs args);

Expand Down
10 changes: 5 additions & 5 deletions src/TensorFlowNET.Core/APIs/tf.control_flow.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ public partial class tensorflow
Tensor loop_vars,
int parallel_iterations = 10)
{
Func<Tensor[], Tensor> cond1 = x
Func<Tensors, Tensor> cond1 = x
=> cond(x[0]);

Func<Tensor[], Tensor[]> body1 = x
Func<Tensors, Tensors> body1 = x
=> new[] { body(x[0]) };

var results = control_flow_ops.while_loop(cond1,
Expand All @@ -58,9 +58,9 @@ public partial class tensorflow
return results[0];
}

public Tensor[] while_loop(Func<Tensor[], Tensor> cond,
Func<Tensor[], Tensor[]> body,
Tensor[] loop_vars,
public Tensor[] while_loop(Func<Tensors, Tensor> cond,
Func<Tensors, Tensors> body,
Tensors loop_vars,
int parallel_iterations = 10,
string name = null)
=> control_flow_ops.while_loop(cond, body, loop_vars,
Expand Down
6 changes: 3 additions & 3 deletions src/TensorFlowNET.Core/APIs/tf.tensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,15 @@ public Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.Dt
public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null)
=> array_ops.split(
value: value,
num_split: num_split,
num_or_size_splits: num_split,
axis: axis,
name: name);

public Tensor[] split(Tensor value, int num_split, int axis, string name = null)
=> array_ops.split(
value: value,
num_split: num_split,
axis: axis,
num_or_size_splits: num_split,
axis: ops.convert_to_tensor(axis),
name: name);

public Tensor ensure_shape(Tensor x, Shape shape, string name = null)
Expand Down
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Binding.Util.cs
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ public static TF_DataType GetDataType(this object data)
case Tensors tensors:
return tensors.dtype;
case IEnumerable<Tensor> tensors:
return tensors.First().dtype;
return tensors.Where(x => x is not null).First().dtype;
case RefVariable variable:
return variable.dtype;
case ResourceVariable variable:
Expand Down
7 changes: 6 additions & 1 deletion src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@ public static IEnumerable<T> SkipLast<T>(this IEnumerable<T> sequence, int count
return sequence.Take(sequence.Count() - count);
}
#endif
public static Tensors ToTensors(this IEnumerable<Tensor> tensors)
public static Tensors ToTensors(this Tensor[] tensors)
{
return new Tensors(tensors);
}

public static Tensors ToTensors(this IList<Tensor> tensors)
{
return new Tensors(tensors);
}
Expand Down
20 changes: 20 additions & 0 deletions src/TensorFlowNET.Core/Common/Types/FakeTensorByTensorArray.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
/// <summary>
/// This is a temp solution, which should be removed after refactoring `Tensors`
/// </summary>
[Obsolete]
public class FakeTensorByTensorArray: Tensor
{
public TensorArray TensorArray { get; set; }

public FakeTensorByTensorArray(TensorArray array)
{
TensorArray = array;
}
}
}
127 changes: 28 additions & 99 deletions src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,136 +5,65 @@

namespace Tensorflow.Common.Types
{
public class GeneralizedTensorShape: IEnumerable<long?[]>, INestStructure<long?>, INestable<long?>
public class GeneralizedTensorShape: Nest<Shape>
{
public TensorShapeConfig[] Shapes { get; set; }
/// <summary>
/// create a single-dim generalized Tensor shape.
/// </summary>
/// <param name="dim"></param>
public GeneralizedTensorShape(int dim, int size = 1)
public GeneralizedTensorShape(Shape value, string? name = null)
{
var elem = new TensorShapeConfig() { Items = new long?[] { dim } };
Shapes = Enumerable.Repeat(elem, size).ToArray();
//Shapes = new TensorShapeConfig[size];
//Shapes.Initialize(new TensorShapeConfig() { Items = new long?[] { dim } });
//Array.Initialize(Shapes, new TensorShapeConfig() { Items = new long?[] { dim } });
////Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } };
NodeValue = value;
NestType = NestType.Node;
}

public GeneralizedTensorShape(Shape shape)
public GeneralizedTensorShape(IEnumerable<Shape> values, string? name = null)
{
Shapes = new TensorShapeConfig[] { shape };
ListValue = values.Select(s => new Nest<Shape>(s) as INestStructure<Shape>).ToList();
Name = name;
NestType = NestType.List;
}

public GeneralizedTensorShape(TensorShapeConfig shape)
public GeneralizedTensorShape(Dictionary<string, Shape> value, string? name = null)
{
Shapes = new TensorShapeConfig[] { shape };
DictValue = value.ToDictionary(x => x.Key, x => new Nest<Shape>(x.Value) as INestStructure<Shape>);
Name = name;
NestType = NestType.Dictionary;
}

public GeneralizedTensorShape(TensorShapeConfig[] shapes)
public GeneralizedTensorShape(Nest<Shape> other)
{
Shapes = shapes;
}

public GeneralizedTensorShape(IEnumerable<Shape> shape)
{
Shapes = shape.Select(x => (TensorShapeConfig)x).ToArray();
NestType = other.NestType;
NodeValue = other.NodeValue;
DictValue = other.DictValue;
ListValue = other.ListValue;
Name = other.Name;
}

public Shape ToSingleShape()
{
if (Shapes.Length != 1)
var shapes = Flatten().ToList();
if (shapes.Count != 1)
{
throw new ValueError("The generalized shape contains more than 1 dim.");
}
var shape_config = Shapes[0];
Debug.Assert(shape_config is not null);
return new Shape(shape_config.Items.Select(x => x is null ? -1 : x.Value).ToArray());
return shapes[0];
}

public long ToNumber()
{
if(Shapes.Length != 1 || Shapes[0].Items.Length != 1)
var shapes = Flatten().ToList();
if (shapes.Count != 1 || shapes[0].ndim != 1)
{
throw new ValueError("The generalized shape contains more than 1 dim.");
}
var res = Shapes[0].Items[0];
return res is null ? -1 : res.Value;
}

public Shape[] ToShapeArray()
{
return Shapes.Select(x => new Shape(x.Items.Select(y => y is null ? -1 : y.Value).ToArray())).ToArray();
}

public IEnumerable<long?> Flatten()
{
List<long?> result = new List<long?>();
foreach(var shapeConfig in Shapes)
{
result.AddRange(shapeConfig.Items);
}
return result;
}
public INestStructure<TOut> MapStructure<TOut>(Func<long?, TOut> func)
{
List<Nest<TOut>> lists = new();
foreach(var shapeConfig in Shapes)
{
lists.Add(new Nest<TOut>(shapeConfig.Items.Select(x => new Nest<TOut>(func(x)))));
}
return new Nest<TOut>(lists);
return shapes[0].dims[0];
}

public Nest<long?> AsNest()
public INestStructure<TensorShapeConfig> ToTensorShapeConfigs()
{
Nest<long?> DealWithSingleShape(TensorShapeConfig config)
{
if (config.Items.Length == 0)
{
return Nest<long?>.Empty;
}
else if (config.Items.Length == 1)
{
return new Nest<long?>(config.Items[0]);
}
else
{
return new Nest<long?>(config.Items.Select(x => new Nest<long?>(x)));
}
}

if(Shapes.Length == 0)
{
return Nest<long?>.Empty;
}
else if(Shapes.Length == 1)
{
return DealWithSingleShape(Shapes[0]);
}
else
{
return new Nest<long?>(Shapes.Select(s => DealWithSingleShape(s)));
}
}



public static implicit operator GeneralizedTensorShape(int dims)
=> new GeneralizedTensorShape(dims);

public IEnumerator<long?[]> GetEnumerator()
{
foreach (var shape in Shapes)
{
yield return shape.Items;
}
return MapStructure(s => new TensorShapeConfig() { Items = s.dims.Select<long, long?>(x => x == -1 ? null : x).ToArray() });
}

IEnumerator IEnumerable.GetEnumerator()
public static implicit operator GeneralizedTensorShape(Shape shape)
{
return GetEnumerator();
return new GeneralizedTensorShape(shape);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,19 @@ namespace Tensorflow.Common.Types
/// </summary>
public interface INestStructure<T>: INestable<T>
{
NestType NestType { get; }

/// <summary>
/// The item count of depth 1 of the nested structure.
/// For example, [1, 2, [3, 4, 5]] has ShallowNestedCount = 3.
/// </summary>
int ShallowNestedCount { get; }
/// <summary>
/// The total item count of depth 1 of the nested structure.
/// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5.
/// </summary>
int TotalNestedCount { get; }

/// <summary>
/// Flatten the Nestable object. Node that if the object contains only one value,
/// it will be flattened to an enumerable with one element.
Expand Down
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Common/Types/Nest.Static.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public static class Nest
/// <param name="template"></param>
/// <param name="flatItems"></param>
/// <returns></returns>
public static Nest<T> PackSequenceAs<T>(INestable<T> template, T[] flatItems)
public static Nest<TOut> PackSequenceAs<T, TOut>(INestable<T> template, TOut[] flatItems)
{
return template.AsNest().PackSequence(flatItems);
}
Expand Down

0 comments on commit edbf89b

Please sign in to comment.