Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change type of BuildInputShape and BatchInputShape #1033

Merged
merged 2 commits into from
Apr 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/TensorFlowNET.Core/Extensions/JObjectExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using Newtonsoft.Json.Linq;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Extensions
{
public static class JObjectExtensions
{
public static T? TryGetOrReturnNull<T>(this JObject obj, string key)
{
var res = obj[key];
if(res is null)
{
return default(T);
}
else
{
return res.ToObject<T>();
}
}
}
}
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ public class TensorSpec : DenseSpec
public TensorSpec(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) :
base(shape, dtype, name)
{

}

public TensorSpec _unbatch()
Expand Down
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Keras/Activations/Activations.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using Newtonsoft.Json;
using System.Reflection;
using System.Runtime.Versioning;
using Tensorflow.Keras.Common;
using Tensorflow.Keras.Saving.Common;

namespace Tensorflow.Keras
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.ArgsDefinition
{
Expand All @@ -18,7 +19,7 @@ public class AutoSerializeLayerArgs: LayerArgs
[JsonProperty("dtype")]
public override TF_DataType DType { get => base.DType; set => base.DType = value; }
[JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)]
public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; }
public override KerasShapesWrapper BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; }
[JsonProperty("trainable")]
public override bool Trainable { get => base.Trainable; set => base.Trainable = value; }
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Newtonsoft.Json;
using Newtonsoft.Json.Serialization;
using Tensorflow.Keras.Common;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.ArgsDefinition
{
Expand All @@ -17,6 +17,6 @@ public class InputLayerArgs : LayerArgs
[JsonProperty("dtype")]
public override TF_DataType DType { get => base.DType; set => base.DType = value; }
[JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)]
public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; }
public override KerasShapesWrapper BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; }
}
}
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public class LayerArgs: IKerasConfig
/// <summary>
/// Only applicable to input layers.
/// </summary>
public virtual Shape BatchInputShape { get; set; }
public virtual KerasShapesWrapper BatchInputShape { get; set; }

public virtual int BatchSize { get; set; } = -1;

Expand Down
6 changes: 3 additions & 3 deletions src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ public interface ILayer: IWithTrackable, IKerasConfigable
string Name { get; }
bool Trainable { get; }
bool Built { get; }
void build(Shape input_shape);
void build(KerasShapesWrapper input_shape);
List<ILayer> Layers { get; }
List<INode> InboundNodes { get; }
List<INode> OutboundNodes { get; }
Expand All @@ -22,8 +22,8 @@ public interface ILayer: IWithTrackable, IKerasConfigable
void set_weights(IEnumerable<NDArray> weights);
List<NDArray> get_weights();
Shape OutputShape { get; }
Shape BatchInputShape { get; }
TensorShapeConfig BuildInputShape { get; }
KerasShapesWrapper BatchInputShape { get; }
KerasShapesWrapper BuildInputShape { get; }
TF_DataType DType { get; }
int count_params();
void adapt(Tensor data, int? batch_size = null, int? steps = null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Common
namespace Tensorflow.Keras.Saving.Common
{
public class CustomizedActivationJsonConverter : JsonConverter
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Common
namespace Tensorflow.Keras.Saving.Common
{
public class CustomizedAxisJsonConverter : JsonConverter
{
Expand Down Expand Up @@ -38,7 +38,7 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
int[]? axis;
if(reader.ValueType == typeof(long))
if (reader.ValueType == typeof(long))
{
axis = new int[1];
axis[0] = (int)serializer.Deserialize(reader, typeof(int));
Expand All @@ -51,7 +51,7 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer
{
throw new ValueError("Cannot deserialize 'null' to `Axis`.");
}
return new Axis((int[])(axis!));
return new Axis(axis!);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using Newtonsoft.Json.Linq;
using Newtonsoft.Json;

namespace Tensorflow.Keras.Common
namespace Tensorflow.Keras.Saving.Common
{
public class CustomizedDTypeJsonConverter : JsonConverter
{
Expand All @@ -16,7 +16,7 @@ public override bool CanConvert(Type objectType)

public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
{
var token = JToken.FromObject(dtypes.as_numpy_name((TF_DataType)value));
var token = JToken.FromObject(((TF_DataType)value).as_numpy_name());
token.WriteTo(writer);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
using System.Collections.Generic;
using System.Text;
using Tensorflow.Operations;

using Tensorflow.Operations.Initializers;

namespace Tensorflow.Keras.Common
namespace Tensorflow.Keras.Saving.Common
{
class InitializerInfo
{
Expand All @@ -27,7 +28,7 @@ public override bool CanConvert(Type objectType)
public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
{
var initializer = value as IInitializer;
if(initializer is null)
if (initializer is null)
{
JToken.FromObject(null).WriteTo(writer);
return;
Expand All @@ -42,7 +43,7 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
var info = serializer.Deserialize<InitializerInfo>(reader);
if(info is null)
if (info is null)
{
return null;
}
Expand All @@ -54,8 +55,8 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer
"Orthogonal" => new Orthogonal(info.config["gain"].ToObject<float>(), info.config["seed"].ToObject<int?>()),
"RandomNormal" => new RandomNormal(info.config["mean"].ToObject<float>(), info.config["stddev"].ToObject<float>(),
info.config["seed"].ToObject<int?>()),
"RandomUniform" => new RandomUniform(minval:info.config["minval"].ToObject<float>(),
maxval:info.config["maxval"].ToObject<float>(), seed: info.config["seed"].ToObject<int?>()),
"RandomUniform" => new RandomUniform(minval: info.config["minval"].ToObject<float>(),
maxval: info.config["maxval"].ToObject<float>(), seed: info.config["seed"].ToObject<int?>()),
"TruncatedNormal" => new TruncatedNormal(info.config["mean"].ToObject<float>(), info.config["stddev"].ToObject<float>(),
info.config["seed"].ToObject<int?>()),
"VarianceScaling" => new VarianceScaling(info.config["scale"].ToObject<float>(), info.config["mode"].ToObject<string>(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
using Newtonsoft.Json.Linq;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Saving.Json
{
public class CustomizedKerasShapesWrapperJsonConverter : JsonConverter
{
public override bool CanConvert(Type objectType)
{
return objectType == typeof(KerasShapesWrapper);
}

public override bool CanRead => true;

public override bool CanWrite => true;

public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
{
if (value is null)
{
JToken.FromObject(null).WriteTo(writer);
return;
}
if (value is not KerasShapesWrapper wrapper)
{
throw new TypeError($"Expected `KerasShapesWrapper` to be serialized, bug got {value.GetType()}");
}
if (wrapper.Shapes.Length == 0)
{
JToken.FromObject(null).WriteTo(writer);
}
else if (wrapper.Shapes.Length == 1)
{
JToken.FromObject(wrapper.Shapes[0]).WriteTo(writer);
}
else
{
JToken.FromObject(wrapper.Shapes).WriteTo(writer);
}
}

public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
if (reader.TokenType == JsonToken.StartArray)
{
TensorShapeConfig[] shapes = serializer.Deserialize<TensorShapeConfig[]>(reader);
if (shapes is null)
{
return null;
}
return new KerasShapesWrapper(shapes);
}
else if (reader.TokenType == JsonToken.StartObject)
{
var shape = serializer.Deserialize<TensorShapeConfig>(reader);
if (shape is null)
{
return null;
}
return new KerasShapesWrapper(shape);
}
else if (reader.TokenType == JsonToken.Null)
{
return null;
}
else
{
throw new ValueError($"Cannot deserialize the token type {reader.TokenType}");
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
using System.Text;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Common
namespace Tensorflow.Keras.Saving.Common
{
public class CustomizedNodeConfigJsonConverter : JsonConverter
{
Expand Down Expand Up @@ -46,10 +46,10 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer
{
throw new ValueError("Cannot deserialize 'null' to `Shape`.");
}
if(values.Length == 1)
if (values.Length == 1)
{
var array = values[0] as JArray;
if(array is null)
if (array is null)
{
throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Common
namespace Tensorflow.Keras.Saving.Common
{
class ShapeInfoFromPython
{
public string class_name { get; set; }
public long?[] items { get; set; }
}
public class CustomizedShapeJsonConverter: JsonConverter
public class CustomizedShapeJsonConverter : JsonConverter
{
public override bool CanConvert(Type objectType)
{
Expand All @@ -25,20 +25,20 @@ public override bool CanConvert(Type objectType)

public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
{
if(value is null)
if (value is null)
{
var token = JToken.FromObject(null);
token.WriteTo(writer);
}
else if(value is not Shape)
else if (value is not Shape)
{
throw new TypeError($"Unable to use `CustomizedShapeJsonConverter` to serialize the type {value.GetType()}.");
}
else
{
var shape = (value as Shape)!;
long?[] dims = new long?[shape.ndim];
for(int i = 0; i < dims.Length; i++)
for (int i = 0; i < dims.Length; i++)
{
if (shape.dims[i] == -1)
{
Expand All @@ -61,7 +61,7 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
long?[] dims;
try
if (reader.TokenType == JsonToken.StartObject)
{
var shape_info_from_python = serializer.Deserialize<ShapeInfoFromPython>(reader);
if (shape_info_from_python is null)
Expand All @@ -70,14 +70,22 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer
}
dims = shape_info_from_python.items;
}
catch(JsonSerializationException)
else if (reader.TokenType == JsonToken.StartArray)
{
dims = serializer.Deserialize<long?[]>(reader);
}
else if (reader.TokenType == JsonToken.Null)
{
return null;
}
else
{
throw new ValueError($"Cannot deserialize the token {reader} as Shape.");
}
long[] convertedDims = new long[dims.Length];
for(int i = 0; i < dims.Length; i++)
for (int i = 0; i < dims.Length; i++)
{
convertedDims[i] = dims[i] ?? (-1);
convertedDims[i] = dims[i] ?? -1;
}
return new Shape(convertedDims);
}
Expand Down