Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support the multiple inputs of keras model.fit. #996

Merged
merged 5 commits into from
Mar 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/TensorFlowNET.Core/Data/DatasetV2.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ public class DatasetV2 : IDatasetV2

public TensorSpec[] structure { get; set; }

public int FirstInputTensorCount { get; set; } = 1;

public Shape[] output_shapes => structure.Select(x => x.shape).ToArray();

public TF_DataType[] output_types => structure.Select(x => x.dtype).ToArray();
Expand Down Expand Up @@ -131,6 +133,7 @@ public IDatasetV2 apply_options()

// (4) Apply stats aggregator options

dataset.FirstInputTensorCount = this.FirstInputTensorCount;
return dataset;
}

Expand All @@ -142,7 +145,7 @@ public override string ToString()
$"types: {string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()))}, " +
$"len: {length}";

public IEnumerator<(Tensor, Tensor)> GetEnumerator()
public IEnumerator<(Tensors, Tensors)> GetEnumerator()
{
using var ownedIterator = new OwnedIterator(this);

Expand All @@ -158,7 +161,8 @@ public IEnumerator<(Tensor, Tensor)> GetEnumerator()
break;
}

yield return (results[0], results.Length == 1 ? null : results[1]);
yield return (new Tensors(results.Take(FirstInputTensorCount)), results.Length == FirstInputTensorCount ?
null : new Tensors(results.Skip(FirstInputTensorCount)));
}
}

Expand Down
4 changes: 3 additions & 1 deletion src/TensorFlowNET.Core/Data/IDatasetV2.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

namespace Tensorflow
{
public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)>
public interface IDatasetV2 : IEnumerable<(Tensors, Tensors)>
AsakusaRinne marked this conversation as resolved.
Show resolved Hide resolved
{
string[] class_names { get; set; }

Expand All @@ -18,6 +18,8 @@ public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)>

TensorSpec[] structure { get; set; }

int FirstInputTensorCount { get; set; }

/// <summary>
/// Caches the elements in this dataset.
/// </summary>
Expand Down
5 changes: 3 additions & 2 deletions src/TensorFlowNET.Core/Data/OwnedIterator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ void _create_iterator(IDatasetV2 dataset)
_dataset = dataset;
_element_spec = dataset.element_spec;
// _flat_output_types =
(_iterator_resource, _deleter) = ops.anonymous_iterator_v2(_dataset.output_types, _dataset.output_shapes);
_iterator_resource = ops.anonymous_iterator_v3(_dataset.output_types, _dataset.output_shapes);
// TODO(Rinne): deal with graph mode.
ops.make_iterator(dataset.variant_tensor, _iterator_resource);
}

Expand All @@ -48,7 +49,7 @@ public Tensor[] next()

public void Dispose()
{
tf.Runner.Execute(tf.Context, "DeleteIterator", 0, new[] { _iterator_resource, _deleter }, null);
//tf.Runner.Execute(tf.Context, "DeleteIterator", 0, new[] { _iterator_resource, _deleter }, null);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ namespace Tensorflow.Keras.ArgsDefinition
{
public class DataAdapterArgs: IKerasConfig
{
public Tensor X { get; set; }
public Tensor Y { get; set; }
public Tensors X { get; set; }
public Tensors Y { get; set; }
public IDatasetV2 Dataset { get; set; }
public int BatchSize { get; set; } = 32;
public int Steps { get; set; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ namespace Tensorflow.Keras.ArgsDefinition
{
public class DataHandlerArgs: IKerasConfig
{
public Tensor X { get; set; }
public Tensor Y { get; set; }
public Tensors X { get; set; }
public Tensors Y { get; set; }
public IDatasetV2 Dataset { get; set; }
public int BatchSize { get; set; } = 32;
public int StepsPerEpoch { get; set; } = -1;
Expand Down
11 changes: 11 additions & 0 deletions src/TensorFlowNET.Core/Keras/Engine/IModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@ public interface IModel : ILayer
int workers = 1,
bool use_multiprocessing = false);

ICallback fit(IEnumerable<NDArray> x, NDArray y,
int batch_size = -1,
int epochs = 1,
int verbose = 1,
float validation_split = 0f,
bool shuffle = true,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
bool use_multiprocessing = false);

void save(string filepath,
bool overwrite = true,
bool include_optimizer = true,
Expand Down
71 changes: 70 additions & 1 deletion src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,76 @@ public void Deconstruct(out byte blue, out byte green, out byte red)
red = data[2];
}

public static implicit operator NDArray(Array array)
public static implicit operator NDArray(int[] array)
AsakusaRinne marked this conversation as resolved.
Show resolved Hide resolved
=> new NDArray(array);

public static implicit operator NDArray(byte[] array)
=> new NDArray(array);

public static implicit operator NDArray(float[] array)
=> new NDArray(array);

public static implicit operator NDArray(double[] array)
=> new NDArray(array);

public static implicit operator NDArray(long[] array)
=> new NDArray(array);

public static implicit operator NDArray(bool[] array)
=> new NDArray(array);

public static implicit operator NDArray(uint[] array)
=> new NDArray(array);

public static implicit operator NDArray(ulong[] array)
=> new NDArray(array);

public static implicit operator NDArray(int[,] array)
=> new NDArray(array);

public static implicit operator NDArray(byte[,] array)
=> new NDArray(array);

public static implicit operator NDArray(float[,] array)
=> new NDArray(array);

public static implicit operator NDArray(double[,] array)
=> new NDArray(array);

public static implicit operator NDArray(long[,] array)
=> new NDArray(array);

public static implicit operator NDArray(bool[,] array)
=> new NDArray(array);

public static implicit operator NDArray(uint[,] array)
=> new NDArray(array);

public static implicit operator NDArray(ulong[,] array)
=> new NDArray(array);

public static implicit operator NDArray(int[,,] array)
=> new NDArray(array);

public static implicit operator NDArray(byte[,,] array)
=> new NDArray(array);

public static implicit operator NDArray(float[,,] array)
=> new NDArray(array);

public static implicit operator NDArray(double[,,] array)
=> new NDArray(array);

public static implicit operator NDArray(long[,,] array)
=> new NDArray(array);

public static implicit operator NDArray(bool[,,] array)
=> new NDArray(array);

public static implicit operator NDArray(uint[,,] array)
=> new NDArray(array);

public static implicit operator NDArray(ulong[,,] array)
=> new NDArray(array);

public unsafe static implicit operator bool(NDArray nd)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ private NDArray OpenEntry(ZipArchiveEntry entry)
return array;

using var s = entry.Open();
return LoadMatrix(s);
return (NDArray)LoadMatrix(s);
}

public Array LoadMatrix(Stream stream)
Expand Down
3 changes: 3 additions & 0 deletions src/TensorFlowNET.Core/Numpy/NDArray.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,8 @@ public IEnumerator<NDArray> GetEnumerator()

IEnumerator IEnumerable.GetEnumerator()
=> GetEnumerator();

public static explicit operator NDArray(Array array)
=> new NDArray(array);
}
}
34 changes: 34 additions & 0 deletions src/TensorFlowNET.Core/Operations/dataset_ops.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
using System;
using Tensorflow.Contexts;
using Tensorflow.Eager;
using Tensorflow.Framework.Models;
using Tensorflow.Functions;
using Tensorflow.Operations;
using static Tensorflow.Binding;

namespace Tensorflow
Expand Down Expand Up @@ -220,6 +223,37 @@ public Tensor dummy_memory_cache(string name = "")
return (results[0], results[1]);
}

public Tensor anonymous_iterator_v3(TF_DataType[] output_types, Shape[] output_shapes, string name = null)
{
var ctx = tf.Context;
Dictionary<string, object> attrs = new();
attrs["output_types"] = output_types;
attrs["output_shapes"] = output_shapes;
if (ctx.executing_eagerly())
{
try
{
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("AnonymousIteratorV3", name)
{
attrs = attrs
});
return result[0];
}
catch (Exception)
{
return anonymous_iterator_v3_eager_fallback(output_types, output_shapes, name, ctx);
}
}
return tf.OpDefLib._apply_op_helper("AnonymousIteratorV3", name, attrs).outputs[0];
}

public Tensor anonymous_iterator_v3_eager_fallback(TF_DataType[] output_types, Shape[] output_shapes, string name, Context ctx)
{
object[] attrs = new object[] { output_types, output_shapes };
var result = execute.quick_execute("AnonymousIteratorV3", 1, new Tensor[] { }, attrs, ctx, name);
return result[0];
}

/// <summary>
/// Makes a new iterator from the given `dataset` and stores it in `iterator`.
/// </summary>
Expand Down
101 changes: 101 additions & 0 deletions src/TensorFlowNET.Core/Tensors/Tensors.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,93 @@ public void Insert(int index, Tensor tensor)
IEnumerator IEnumerable.GetEnumerator()
=> GetEnumerator();

public NDArray numpy()
{
EnsureSingleTensor(this, "nnumpy");
return this[0].numpy();
}

public T[] ToArray<T>() where T: unmanaged
{
EnsureSingleTensor(this, $"ToArray<{typeof(T)}>");
return this[0].ToArray<T>();
}

#region Explicit Conversions
public unsafe static explicit operator bool(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to bool");
return (bool)tensor[0];
}

public unsafe static explicit operator sbyte(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to sbyte");
return (sbyte)tensor[0];
}

public unsafe static explicit operator byte(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to byte");
return (byte)tensor[0];
}

public unsafe static explicit operator ushort(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to ushort");
return (ushort)tensor[0];
}

public unsafe static explicit operator short(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to short");
return (short)tensor[0];
}

public unsafe static explicit operator int(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to int");
return (int)tensor[0];
}

public unsafe static explicit operator uint(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to uint");
return (uint)tensor[0];
}

public unsafe static explicit operator long(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to long");
return (long)tensor[0];
}

public unsafe static explicit operator ulong(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to ulong");
return (ulong)tensor[0];
}

public unsafe static explicit operator float(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to byte");
return (byte)tensor[0];
}

public unsafe static explicit operator double(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to double");
return (double)tensor[0];
}

public unsafe static explicit operator string(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to string");
return (string)tensor[0];
}
#endregion

#region Implicit Conversions
public static implicit operator Tensors(Tensor tensor)
=> new Tensors(tensor);

Expand All @@ -87,12 +174,26 @@ IEnumerator IEnumerable.GetEnumerator()
public static implicit operator Tensor[](Tensors tensors)
=> tensors.items.ToArray();

#endregion

public void Deconstruct(out Tensor a, out Tensor b)
{
a = items[0];
b = items[1];
}

private static void EnsureSingleTensor(Tensors tensors, string methodnName)
{
if(tensors.Length == 0)
{
throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains no Tensor.");
}
else if(tensors.Length > 1)
{
throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains more than one Tensor.");
}
}

public override string ToString()
=> items.Count() == 1
? items.First().ToString()
Expand Down