Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support training of RNN. #1109

Merged
merged 8 commits into from
Jun 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/TensorFlowNET.Core/APIs/c_api.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

using System;
using System.Runtime.InteropServices;
using static Tensorflow.CppShapeInferenceResult.Types;

namespace Tensorflow
{
Expand Down Expand Up @@ -50,6 +51,19 @@ public static string StringPiece(IntPtr handle)
return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle);
}

public unsafe static byte[] ByteStringPiece(IntPtr handle)
{
byte* str_data = (byte*)handle.ToPointer();
List<byte> bytes = new List<byte>();
byte current = 255;
while (current != ((byte)'\0'))
{
current = *(str_data++);
bytes.Add(current);
}
return bytes.Take(bytes.Count - 1).ToArray();
}

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate void Deallocator(IntPtr data, IntPtr size, ref DeallocatorArgs args);

Expand Down
10 changes: 5 additions & 5 deletions src/TensorFlowNET.Core/APIs/tf.control_flow.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ public partial class tensorflow
Tensor loop_vars,
int parallel_iterations = 10)
{
Func<Tensor[], Tensor> cond1 = x
Func<Tensors, Tensor> cond1 = x
=> cond(x[0]);

Func<Tensor[], Tensor[]> body1 = x
Func<Tensors, Tensors> body1 = x
=> new[] { body(x[0]) };

var results = control_flow_ops.while_loop(cond1,
Expand All @@ -58,9 +58,9 @@ public partial class tensorflow
return results[0];
}

public Tensor[] while_loop(Func<Tensor[], Tensor> cond,
Func<Tensor[], Tensor[]> body,
Tensor[] loop_vars,
public Tensor[] while_loop(Func<Tensors, Tensor> cond,
Func<Tensors, Tensors> body,
Tensors loop_vars,
int parallel_iterations = 10,
string name = null)
=> control_flow_ops.while_loop(cond, body, loop_vars,
Expand Down
6 changes: 3 additions & 3 deletions src/TensorFlowNET.Core/APIs/tf.tensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,15 @@ public Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.Dt
public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null)
=> array_ops.split(
value: value,
num_split: num_split,
num_or_size_splits: num_split,
axis: axis,
name: name);

public Tensor[] split(Tensor value, int num_split, int axis, string name = null)
=> array_ops.split(
value: value,
num_split: num_split,
axis: axis,
num_or_size_splits: num_split,
axis: ops.convert_to_tensor(axis),
name: name);

public Tensor ensure_shape(Tensor x, Shape shape, string name = null)
Expand Down
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Binding.Util.cs
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ public static TF_DataType GetDataType(this object data)
case Tensors tensors:
return tensors.dtype;
case IEnumerable<Tensor> tensors:
return tensors.First().dtype;
return tensors.Where(x => x is not null).First().dtype;
case RefVariable variable:
return variable.dtype;
case ResourceVariable variable:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@ public static IEnumerable<T> SkipLast<T>(this IEnumerable<T> sequence, int count
return sequence.Take(sequence.Count() - count);
}
#endif
public static Tensors ToTensors(this IEnumerable<Tensor> tensors)
public static Tensors ToTensors(this Tensor[] tensors)
{
return new Tensors(tensors);
}

public static Tensors ToTensors(this IList<Tensor> tensors)
{
return new Tensors(tensors);
}
Expand Down
20 changes: 20 additions & 0 deletions src/TensorFlowNET.Core/Common/Types/FakeTensorByTensorArray.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
/// <summary>
/// This is a temp solution, which should be removed after refactoring `Tensors`
/// </summary>
[Obsolete]
public class FakeTensorByTensorArray: Tensor
{
public TensorArray TensorArray { get; set; }

public FakeTensorByTensorArray(TensorArray array)
{
TensorArray = array;
}
}
}
127 changes: 28 additions & 99 deletions src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,136 +5,65 @@

namespace Tensorflow.Common.Types
{
public class GeneralizedTensorShape: IEnumerable<long?[]>, INestStructure<long?>, INestable<long?>
public class GeneralizedTensorShape: Nest<Shape>
{
public TensorShapeConfig[] Shapes { get; set; }
/// <summary>
/// create a single-dim generalized Tensor shape.
/// </summary>
/// <param name="dim"></param>
public GeneralizedTensorShape(int dim, int size = 1)
public GeneralizedTensorShape(Shape value, string? name = null)
{
var elem = new TensorShapeConfig() { Items = new long?[] { dim } };
Shapes = Enumerable.Repeat(elem, size).ToArray();
//Shapes = new TensorShapeConfig[size];
//Shapes.Initialize(new TensorShapeConfig() { Items = new long?[] { dim } });
//Array.Initialize(Shapes, new TensorShapeConfig() { Items = new long?[] { dim } });
////Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } };
NodeValue = value;
NestType = NestType.Node;
}

public GeneralizedTensorShape(Shape shape)
public GeneralizedTensorShape(IEnumerable<Shape> values, string? name = null)
{
Shapes = new TensorShapeConfig[] { shape };
ListValue = values.Select(s => new Nest<Shape>(s) as INestStructure<Shape>).ToList();
Name = name;
NestType = NestType.List;
}

public GeneralizedTensorShape(TensorShapeConfig shape)
public GeneralizedTensorShape(Dictionary<string, Shape> value, string? name = null)
{
Shapes = new TensorShapeConfig[] { shape };
DictValue = value.ToDictionary(x => x.Key, x => new Nest<Shape>(x.Value) as INestStructure<Shape>);
Name = name;
NestType = NestType.Dictionary;
}

public GeneralizedTensorShape(TensorShapeConfig[] shapes)
public GeneralizedTensorShape(Nest<Shape> other)
{
Shapes = shapes;
}

public GeneralizedTensorShape(IEnumerable<Shape> shape)
{
Shapes = shape.Select(x => (TensorShapeConfig)x).ToArray();
NestType = other.NestType;
NodeValue = other.NodeValue;
DictValue = other.DictValue;
ListValue = other.ListValue;
Name = other.Name;
}

public Shape ToSingleShape()
{
if (Shapes.Length != 1)
var shapes = Flatten().ToList();
if (shapes.Count != 1)
{
throw new ValueError("The generalized shape contains more than 1 dim.");
}
var shape_config = Shapes[0];
Debug.Assert(shape_config is not null);
return new Shape(shape_config.Items.Select(x => x is null ? -1 : x.Value).ToArray());
return shapes[0];
}

public long ToNumber()
{
if(Shapes.Length != 1 || Shapes[0].Items.Length != 1)
var shapes = Flatten().ToList();
if (shapes.Count != 1 || shapes[0].ndim != 1)
{
throw new ValueError("The generalized shape contains more than 1 dim.");
}
var res = Shapes[0].Items[0];
return res is null ? -1 : res.Value;
}

public Shape[] ToShapeArray()
{
return Shapes.Select(x => new Shape(x.Items.Select(y => y is null ? -1 : y.Value).ToArray())).ToArray();
}

public IEnumerable<long?> Flatten()
{
List<long?> result = new List<long?>();
foreach(var shapeConfig in Shapes)
{
result.AddRange(shapeConfig.Items);
}
return result;
}
public INestStructure<TOut> MapStructure<TOut>(Func<long?, TOut> func)
{
List<Nest<TOut>> lists = new();
foreach(var shapeConfig in Shapes)
{
lists.Add(new Nest<TOut>(shapeConfig.Items.Select(x => new Nest<TOut>(func(x)))));
}
return new Nest<TOut>(lists);
return shapes[0].dims[0];
}

public Nest<long?> AsNest()
public INestStructure<TensorShapeConfig> ToTensorShapeConfigs()
{
Nest<long?> DealWithSingleShape(TensorShapeConfig config)
{
if (config.Items.Length == 0)
{
return Nest<long?>.Empty;
}
else if (config.Items.Length == 1)
{
return new Nest<long?>(config.Items[0]);
}
else
{
return new Nest<long?>(config.Items.Select(x => new Nest<long?>(x)));
}
}

if(Shapes.Length == 0)
{
return Nest<long?>.Empty;
}
else if(Shapes.Length == 1)
{
return DealWithSingleShape(Shapes[0]);
}
else
{
return new Nest<long?>(Shapes.Select(s => DealWithSingleShape(s)));
}
}



public static implicit operator GeneralizedTensorShape(int dims)
=> new GeneralizedTensorShape(dims);

public IEnumerator<long?[]> GetEnumerator()
{
foreach (var shape in Shapes)
{
yield return shape.Items;
}
return MapStructure(s => new TensorShapeConfig() { Items = s.dims.Select<long, long?>(x => x == -1 ? null : x).ToArray() });
}

IEnumerator IEnumerable.GetEnumerator()
public static implicit operator GeneralizedTensorShape(Shape shape)
{
return GetEnumerator();
return new GeneralizedTensorShape(shape);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,19 @@ namespace Tensorflow.Common.Types
/// </summary>
public interface INestStructure<T>: INestable<T>
{
NestType NestType { get; }

/// <summary>
/// The item count of depth 1 of the nested structure.
/// For example, [1, 2, [3, 4, 5]] has ShallowNestedCount = 3.
/// </summary>
int ShallowNestedCount { get; }
/// <summary>
/// The total item count of depth 1 of the nested structure.
/// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5.
/// </summary>
int TotalNestedCount { get; }

/// <summary>
/// Flatten the Nestable object. Node that if the object contains only one value,
/// it will be flattened to an enumerable with one element.
Expand Down
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Common/Types/Nest.Static.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public static class Nest
/// <param name="template"></param>
/// <param name="flatItems"></param>
/// <returns></returns>
public static Nest<T> PackSequenceAs<T>(INestable<T> template, T[] flatItems)
public static Nest<TOut> PackSequenceAs<T, TOut>(INestable<T> template, TOut[] flatItems)
{
return template.AsNest().PackSequence(flatItems);
}
Expand Down