Skip to content

Commit

Permalink
Automatically add KerasInterface to tf.keras.
Browse files Browse the repository at this point in the history
  • Loading branch information
AsakusaRinne authored and Oceania2018 committed Mar 4, 2023
1 parent e5837dc commit ece36e6
Show file tree
Hide file tree
Showing 12 changed files with 132 additions and 28 deletions.
2 changes: 0 additions & 2 deletions src/TensorFlowNET.Console/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ class Program
{
static void Main(string[] args)
{
tf.UseKeras<KerasInterface>();

var diag = new Diagnostician();
// diag.Diagnose(@"D:\memory.txt");

Expand Down
44 changes: 39 additions & 5 deletions src/TensorFlowNET.Core/Keras/IKerasApi.cs
Original file line number Diff line number Diff line change
@@ -1,26 +1,60 @@
using System;
using System.Collections.Generic;
using System.Text;
using System.Threading;
using Tensorflow.Framework.Models;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Models;

namespace Tensorflow.Keras
{
    /// <summary>
    /// The surface of the Keras API exposed through `tf.keras`.
    /// </summary>
    public interface IKerasApi
    {
        IInitializersApi initializers { get; }
        ILayersApi layers { get; }
        ILossesApi losses { get; }
        IOptimizerApi optimizers { get; }
        IMetricsApi metrics { get; }
        IModelsApi models { get; }

        /// <summary>
        /// `Model` groups layers into an object with training and inference features.
        /// </summary>
        /// <param name="inputs">Input tensor(s) of the model.</param>
        /// <param name="outputs">Output tensor(s) of the model.</param>
        /// <param name="name">Optional name of the model.</param>
        /// <returns>A model connecting <paramref name="inputs"/> to <paramref name="outputs"/>.</returns>
        IModel Model(Tensors inputs, Tensors outputs, string name = null);

        /// <summary>
        /// Instantiate a Keras tensor.
        /// </summary>
        /// <param name="shape">Shape of the input, not including the batch dimension.</param>
        /// <param name="batch_size">Static batch size; -1 means unspecified.</param>
        /// <param name="name">Optional name for the input layer.</param>
        /// <param name="dtype">Expected data type of the input.</param>
        /// <param name="sparse">
        /// A boolean specifying whether the placeholder to be created is sparse.
        /// </param>
        /// <param name="ragged">
        /// A boolean specifying whether the placeholder to be created is ragged.
        /// </param>
        /// <param name="tensor">
        /// Optional existing tensor to wrap into the `Input` layer.
        /// If set, the layer will not create a placeholder tensor.
        /// </param>
        /// <param name="type_spec">Optional type spec for the input — presumably overrides shape/dtype; confirm against implementation.</param>
        /// <param name="batch_input_shape">Optional shape including the batch dimension (legacy spelling — confirm precedence over <paramref name="shape"/>).</param>
        /// <param name="batch_shape">Optional shape including the batch dimension.</param>
        /// <returns>The symbolic input tensor(s).</returns>
        Tensors Input(Shape shape = null,
            int batch_size = -1,
            string name = null,
            TF_DataType dtype = TF_DataType.DtInvalid,
            bool sparse = false,
            Tensor tensor = null,
            bool ragged = false,
            TypeSpec type_spec = null,
            Shape batch_input_shape = null,
            Shape batch_shape = null);
    }
}
47 changes: 47 additions & 0 deletions src/TensorFlowNET.Core/Keras/IOptimizerApi.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras
{
    /// <summary>
    /// Factory surface for the built-in Keras optimizers.
    /// </summary>
    public interface IOptimizerApi
    {
        /// <summary>
        /// Adam optimization is a stochastic gradient descent method that is based on
        /// adaptive estimation of first-order and second-order moments.
        /// </summary>
        /// <param name="learning_rate">The learning rate. Defaults to 0.001.</param>
        /// <param name="beta_1">Exponential decay rate for the first-moment estimates. Defaults to 0.9.</param>
        /// <param name="beta_2">Exponential decay rate for the second-moment estimates. Defaults to 0.999.</param>
        /// <param name="epsilon">Small constant for numerical stability. Defaults to 1e-7.</param>
        /// <param name="amsgrad">Whether to apply the AMSGrad variant. Defaults to false.</param>
        /// <param name="name">Name prefix for the operations created by the optimizer. Defaults to "Adam".</param>
        /// <returns>A new Adam optimizer.</returns>
        IOptimizer Adam(float learning_rate = 0.001f,
            float beta_1 = 0.9f,
            float beta_2 = 0.999f,
            float epsilon = 1e-7f,
            bool amsgrad = false,
            string name = "Adam");

        /// <summary>
        /// Construct a new RMSprop optimizer.
        /// </summary>
        /// <param name="learning_rate">The learning rate. Defaults to 0.001.</param>
        /// <param name="rho">Discounting factor for the gradient history. Defaults to 0.9.</param>
        /// <param name="momentum">Momentum factor. Defaults to 0.0.</param>
        /// <param name="epsilon">Small constant for numerical stability. Defaults to 1e-7.</param>
        /// <param name="centered">Whether to use the centered variant of RMSprop. Defaults to false.</param>
        /// <param name="name">Name prefix for the operations created by the optimizer. Defaults to "RMSprop".</param>
        /// <returns>A new RMSprop optimizer.</returns>
        IOptimizer RMSprop(float learning_rate = 0.001f,
            float rho = 0.9f,
            float momentum = 0.0f,
            float epsilon = 1e-7f,
            bool centered = false,
            string name = "RMSprop");

        /// <summary>
        /// Construct a stochastic gradient descent optimizer with the given learning rate.
        /// </summary>
        /// <param name="learning_rate">The learning rate.</param>
        /// <returns>A new SGD optimizer.</returns>
        IOptimizer SGD(float learning_rate);
    }
}
12 changes: 12 additions & 0 deletions src/TensorFlowNET.Core/Keras/Models/IModelsApi.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras.Models
{
    /// <summary>
    /// Model-level operations exposed as `keras.models`.
    /// </summary>
    public interface IModelsApi
    {
        /// <summary>
        /// Load a model previously saved to disk.
        /// </summary>
        /// <param name="filepath">Path of the saved model.</param>
        /// <param name="compile">Whether to compile the model after loading. Defaults to true.</param>
        /// <param name="options">Optional load options; null uses the defaults.</param>
        /// <returns>The loaded model.</returns>
        public IModel load_model(string filepath, bool compile = true, LoadOptions? options = null);
    }
}
8 changes: 0 additions & 8 deletions src/TensorFlowNET.Core/tensorflow.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,6 @@ public tensorflow()
InitGradientEnvironment();
}

public void UseKeras<T>() where T : IKerasApi, new()
{
if (keras == null)
{
keras = new T();
}
}

public string VERSION => c_api.StringPiece(c_api.TF_Version());

private void InitGradientEnvironment()
Expand Down
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Keras/KerasApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ namespace Tensorflow
/// </summary>
public static class KerasApi
{
public static KerasInterface keras { get; } = new KerasInterface();
public static KerasInterface keras { get; } = KerasInterface.Instance;
}
}
26 changes: 24 additions & 2 deletions src/TensorFlowNET.Keras/KerasInterface.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,28 @@ namespace Tensorflow.Keras
{
public class KerasInterface : IKerasApi
{
private static KerasInterface _instance = null;
private static readonly object _lock = new object();
private KerasInterface()
{
Tensorflow.Binding.tf.keras = this;
}

public static KerasInterface Instance
{
get
{
lock (_lock)
{
if (_instance is null)
{
_instance = new KerasInterface();
}
return _instance;
}
}
}

public KerasDataset datasets { get; } = new KerasDataset();
public IInitializersApi initializers { get; } = new InitializersApi();
public Regularizers regularizers { get; } = new Regularizers();
Expand All @@ -27,9 +49,9 @@ public class KerasInterface : IKerasApi
public Preprocessing preprocessing { get; } = new Preprocessing();
ThreadLocal<BackendImpl> _backend = new ThreadLocal<BackendImpl>(() => new BackendImpl());
public BackendImpl backend => _backend.Value;
public OptimizerApi optimizers { get; } = new OptimizerApi();
public IOptimizerApi optimizers { get; } = new OptimizerApi();
public IMetricsApi metrics { get; } = new MetricsApi();
public ModelsApi models { get; } = new ModelsApi();
public IModelsApi models { get; } = new ModelsApi();
public KerasUtils utils { get; } = new KerasUtils();

public Sequential Sequential(List<ILayer> layers = null,
Expand Down
4 changes: 2 additions & 2 deletions src/TensorFlowNET.Keras/Models/ModelsApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@

namespace Tensorflow.Keras.Models
{
public class ModelsApi
public class ModelsApi: IModelsApi
{
public Functional from_config(ModelConfig config)
=> Functional.from_config(config);

public Model load_model(string filepath, bool compile = true, LoadOptions? options = null)
public IModel load_model(string filepath, bool compile = true, LoadOptions? options = null)
{
return KerasLoadModelUtils.load_model(filepath, compile: compile, options: options) as Model;
}
Expand Down
9 changes: 5 additions & 4 deletions src/TensorFlowNET.Keras/Optimizers/OptimizerApi.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras.Optimizers
{
public class OptimizerApi
public class OptimizerApi: IOptimizerApi
{
/// <summary>
/// Adam optimization is a stochastic gradient descent method that is based on
Expand All @@ -15,7 +16,7 @@ public class OptimizerApi
/// <param name="amsgrad"></param>
/// <param name="name"></param>
/// <returns></returns>
public OptimizerV2 Adam(float learning_rate = 0.001f,
public IOptimizer Adam(float learning_rate = 0.001f,
float beta_1 = 0.9f,
float beta_2 = 0.999f,
float epsilon = 1e-7f,
Expand All @@ -38,7 +39,7 @@ public class OptimizerApi
/// <param name="centered"></param>
/// <param name="name"></param>
/// <returns></returns>
public OptimizerV2 RMSprop(float learning_rate = 0.001f,
public IOptimizer RMSprop(float learning_rate = 0.001f,
float rho = 0.9f,
float momentum = 0.0f,
float epsilon = 1e-7f,
Expand All @@ -54,7 +55,7 @@ public class OptimizerApi
Name = name
});

public SGD SGD(float learning_rate)
public IOptimizer SGD(float learning_rate)
=> new SGD(learning_rate);
}
}
2 changes: 0 additions & 2 deletions test/TensorFlowNET.Keras.UnitTest/EagerModeTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ public class EagerModeTestBase
[TestInitialize]
public void TestInit()
{
tf.UseKeras<KerasInterface>();

if (!tf.executing_eagerly())
tf.enable_eager_execution();
tf.Context.ensure_initialized();
Expand Down
1 change: 0 additions & 1 deletion test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ public void EinsumDense()
[TestMethod, Ignore("WIP")]
public void SimpleRNN()
{
tf.UseKeras<KerasInterface>();
var inputs = np.arange(6 * 10 * 8).reshape((6, 10, 8)).astype(np.float32);
/*var simple_rnn = keras.layers.SimpleRNN(4);
var output = simple_rnn.Apply(inputs);
Expand Down
3 changes: 2 additions & 1 deletion test/TensorFlowNET.Keras.UnitTest/Layers/ModelSaveTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.Diagnostics;
using static Tensorflow.KerasApi;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Models;

namespace TensorFlowNET.Keras.UnitTest
{
Expand All @@ -18,7 +19,7 @@ public void GetAndFromConfig()
var model = GetFunctionalModel();
var config = model.get_config();
Debug.Assert(config is ModelConfig);
var new_model = keras.models.from_config(config as ModelConfig);
var new_model = new ModelsApi().from_config(config as ModelConfig);
Assert.AreEqual(model.Layers.Count, new_model.Layers.Count);
}

Expand Down

0 comments on commit ece36e6

Please sign in to comment.