Description
Because a model is not thread-safe, I'm trying to clone an existing model so that each clone can be used in a separate thread. I'm using `save_weights` + `load_weights` to recreate the clone (if there is another way to do this without going through files, I'm interested).
Tested on:
- TensorFlow.NET 0.40.1, TensorFlow.Keras 0.5.1, SciSharp.TensorFlow.Redist-Windows-GPU 2.5.0
- TensorFlow.NET 0.60.5, TensorFlow.Keras 0.6.5, SciSharp.TensorFlow.Redist-Windows-GPU 2.6.0
I reduced the problem to this unit test:
```csharp
using System;
using System.Threading.Tasks;
using Tensorflow.Keras.Engine;      // Functional
using Tensorflow.NumPy;             // np on TensorFlow.NET 0.60.x (assumed: on 0.40.x, np comes from NumSharp)
using Xunit;
using static Tensorflow.Binding;    // tf
using static Tensorflow.KerasApi;   // keras

public class MultithreadingTests
{
    public Functional BuildModel()
    {
        tf.Context.reset_context();
        var inputs = keras.Input(shape: 2);
        // 1st dense layer
        var DenseLayer = keras.layers.Dense(1, activation: keras.activations.Sigmoid);
        var outputs = DenseLayer.Apply(inputs);
        // build keras model
        Functional model = keras.Model(inputs, outputs, name: Guid.NewGuid().ToString());
        // show model summary
        model.summary();
        // compile keras model into tensorflow's static graph
        model.compile(loss: keras.losses.MeanSquaredError(name: Guid.NewGuid().ToString()),
            optimizer: keras.optimizers.Adam(name: Guid.NewGuid().ToString()),
            metrics: new[] { "accuracy" });
        return model;
    }

    [Fact]
    public void Test3Multithreading()
    {
        //Arrange
        string savefile = "mymodel3.h5";
        var model = BuildModel();
        model.save_weights(savefile);
        var TensorflowMutex = new object();

        //Sanity check without multithreading
        for (int i = 0; i < 2; i++)
        {
            Functional clone;
            lock (TensorflowMutex)
            {
                clone = BuildModel();
            }
            clone.load_weights(savefile);
            //Predict something
            clone.predict(np.array(new float[,] { { 0, 0 } }));
        }

        //act: each iteration builds its clone under the lock, then loads weights and predicts outside it
        ParallelOptions parallelOptions = new ParallelOptions();
        parallelOptions.MaxDegreeOfParallelism = 2;
        Parallel.For(0, 2, parallelOptions, i =>
        {
            Functional clone;
            lock (TensorflowMutex)
            {
                clone = BuildModel();
            }
            clone.load_weights(savefile);
            //Predict something
            clone.predict(np.array(new float[,] { { 0, 0 } }));
        });
    }
}
```
Depending on the race conditions, I get various exceptions:
- `Attempting to capture an EagerTensor without building a function.` on `clone.predict`
- `SafeHandle cannot be null. (Parameter 'pHandle')` while building the model, at `var outputs = DenseLayer.Apply(inputs);`
How can I properly use multithreading with TensorFlow.NET?
Attached: TestProject1.zip
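One general .NET pattern for an object that is not thread-safe is thread confinement: a single dedicated thread creates and owns the object, and other threads hand it work through a queue instead of calling it directly. This is a minimal sketch in plain C# with no TensorFlow.NET calls. `SingleThreadHost` and `RunAsync` are hypothetical names, and whether TensorFlow.NET behaves correctly when every call (including model construction) comes from one non-main thread is an assumption this sketch does not verify.

```csharp
#nullable enable
using System;
using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical helper: confines a non-thread-safe object to one dedicated thread.
// Callers on any thread submit work; the owner thread runs it one item at a time.
public sealed class SingleThreadHost<TModel> : IDisposable
{
    private readonly BlockingCollection<Action<TModel>> _queue = new BlockingCollection<Action<TModel>>();
    private readonly Thread _owner;

    public SingleThreadHost(Func<TModel> createModel)
    {
        var ready = new ManualResetEventSlim(false);
        Exception? initError = null;

        _owner = new Thread(() =>
        {
            TModel model;
            try
            {
                // The object is created on the owner thread, so it never leaves it.
                model = createModel();
            }
            catch (Exception ex)
            {
                initError = ex;
                ready.Set();
                return;
            }
            ready.Set();

            // Process queued work serially until CompleteAdding() is called.
            foreach (var work in _queue.GetConsumingEnumerable())
                work(model);
        }) { IsBackground = true };

        _owner.Start();
        ready.Wait();
        if (initError != null)
            throw new InvalidOperationException("Model creation failed.", initError);
    }

    // Queues func to run on the owner thread; the returned task carries its result or exception.
    public Task<TResult> RunAsync<TResult>(Func<TModel, TResult> func)
    {
        var tcs = new TaskCompletionSource<TResult>(TaskCreationOptions.RunContinuationsAsynchronously);
        _queue.Add(model =>
        {
            try { tcs.SetResult(func(model)); }
            catch (Exception ex) { tcs.SetException(ex); }
        });
        return tcs.Task;
    }

    // Lets already-queued work finish, then stops the owner thread.
    public void Dispose()
    {
        _queue.CompleteAdding();
        _owner.Join();
        _queue.Dispose();
    }
}
```

Applied to the test above, this would hypothetically mean creating one `SingleThreadHost<Functional>` whose factory calls `BuildModel()` and `load_weights(savefile)`, then calling `host.RunAsync(m => m.predict(...))` from each `Parallel.For` iteration. The trade-off is that predictions are serialized onto one thread rather than run concurrently; running several hosts would give parallelism only if the library tolerates one model per thread.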