
test: Gradient descent optimizer tests #1184

Merged 12 commits on Nov 10, 2023

Conversation

novikov-alexander (Contributor)

No description provided.

}
}

public void assertAllCloseAccordingToType<T>(
novikov-alexander (Contributor, author):
@Wanglongzhi2001 Does it make sense to avoid duplicating all the assertion code and instead combine it in one assembly?

Oceania2018 (Member):
 Failed TestBasic [1 s]
  Error Message:
   Test method Tensorflow.Keras.UnitTest.Optimizers.GradientDescentOptimizerTest.TestBasic threw exception: 
System.NullReferenceException: Object reference not set to an instance of an object.
  Stack Trace:
      at TensorFlowNET.UnitTest.PythonTest.evaluate[T](Tensor tensor) in D:\a\TensorFlow.NET\TensorFlow.NET\test\TensorFlowNET.UnitTest\PythonTest.cs:line 245
   at Tensorflow.Keras.UnitTest.Optimizers.GradientDescentOptimizerTest.TestBasicGeneric[T]() in D:\a\TensorFlow.NET\TensorFlow.NET\test\TensorFlowNET.UnitTest\Training\GradientDescentOptimizerTests.cs:line 40
   at Tensorflow.Keras.UnitTest.Optimizers.GradientDescentOptimizerTest.TestBasic() in D:\a\TensorFlow.NET\TensorFlow.NET\test\TensorFlowNET.UnitTest\Training\GradientDescentOptimizerTests.cs:line 62
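The NullReferenceException at PythonTest.cs:245 is consistent with the diff discussed below, where `tf.Session()` is replaced by `tf.get_default_session()`: a "get the default session" call returns null/None when no session is currently active, and dereferencing that result fails. A minimal Python sketch of the default-session pattern (hypothetical stand-in, not the TensorFlow.NET API) shows why the caller must guard against an absent default:

```python
# Hypothetical sketch of default-session semantics (not the TensorFlow.NET API).
class Session:
    _default = None  # process-wide default, set while a session is "entered"

    def __enter__(self):
        Session._default = self
        return self

    def __exit__(self, *exc):
        Session._default = None

def get_default_session():
    # Returns None when no session is active -- callers that dereference the
    # result unconditionally hit exactly the kind of failure in the log above.
    return Session._default

assert get_default_session() is None
with Session() as sess:
    assert get_default_session() is sess
assert get_default_session() is None
```

Under this reading, the test must run inside an active (default) session, or `evaluate<T>` must fall back to creating one.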

@novikov-alexander novikov-alexander marked this pull request as ready for review November 7, 2023 19:29
@@ -196,17 +245,25 @@ public T evaluate<T>(Tensor tensor)
 // return self._eval_helper(tensors)
 // else:
 {
-    var sess = tf.Session();
+    var sess = tf.get_default_session();
novikov-alexander (Contributor, author):

@Wanglongzhi2001 The same functions exist in another assembly. Should we generalize them into one common assembly in another PR?

Collaborator:

I think so, if you would like to.

{
throw new ValueError(@"The config used to get the cached session is
different than the one that was used to create the
session. Maybe create a new session with
novikov-alexander (Contributor, author):

@Wanglongzhi2001 This code is also copy-pasted from my previous PR, because that's how the test architecture is implemented right now. Should I move it to one common assembly?

Collaborator:

I think moving it to a common namespace is a good idea if it can be re-used by other modules. Please open another PR if you would like to do that.

@@ -196,17 +245,25 @@ public T evaluate<T>(Tensor tensor)
 // return self._eval_helper(tensors)
 // else:
 {
-    var sess = tf.Session();
+    var sess = tf.get_default_session();
     var ndarray = tensor.eval(sess);
novikov-alexander (Contributor, author):

@Wanglongzhi2001 What is the reason tensor.eval is used here? Maybe replace it with sess.run(tensor), as in the original tests?

Collaborator:

I think it's to map the tensor to a node in the session (graph).

self.assertAllCloseAccordingToType(
new[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 },
self.evaluate<T[]>(var1));
// TODO: self.assertEqual(0, len(optimizer.variables()));
novikov-alexander (Contributor, author):

@Wanglongzhi2001 Do I understand correctly that this one is not applicable to TensorFlow.NET?
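The expected values in the assertion above follow from the plain gradient descent update rule, v ← v − learning_rate · grad, with learning_rate = 3.0 and grads1 = [0.01, 0.01] applied to var1 = [3.0, 4.0]. A dependency-free check of that arithmetic:

```python
import math

# Gradient descent update: v <- v - learning_rate * grad, matching the test's
# literals 3.0 - 3.0 * 0.01 and 4.0 - 3.0 * 0.01.
learning_rate = 3.0
var1 = [3.0, 4.0]
grads1 = [0.01, 0.01]

updated = [v - learning_rate * g for v, g in zip(var1, grads1)]

# Compare with a tolerance, since these are floating-point results.
assert all(math.isclose(u, e) for u, e in zip(updated, [2.97, 3.97]))
```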

var var1 = tf.Variable(new[] { 3.0, 4.0 }, dtype: dtype);
var grads0 = tf.constant(new[] { 0.1, 0.1 }, dtype: dtype);
var grads1 = tf.constant(new[] { 0.01, 0.01 }, dtype: dtype);
var optimizer = tf.train.GradientDescentOptimizer(3.0f);
novikov-alexander (Contributor, author):

@Wanglongzhi2001 TensorFlow can accept a lambda in this constructor, and there's a test covering it. Is it intentional that TensorFlow.NET doesn't support that interface?

Collaborator:

I don't think so; it's more of a compromise made during the early fast development, so only the smallest, most-used implementation was provided. I think it's feasible to support it.
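The interface under discussion — a learning rate supplied either as a constant or as a zero-argument callable — can be sketched with a minimal stand-in optimizer (hypothetical, not the TensorFlow.NET or TensorFlow API; class and method names here are illustrative only):

```python
# Hypothetical sketch of accepting either a float or a zero-argument callable
# as the learning rate, the pattern the lambda-based TensorFlow test exercises.
class SketchGradientDescentOptimizer:
    def __init__(self, learning_rate):
        self._learning_rate = learning_rate

    def _resolve_lr(self):
        # Resolve lazily, so a callable can supply a fresh value on each step
        # (e.g. for learning-rate schedules).
        lr = self._learning_rate
        return lr() if callable(lr) else lr

    def apply_gradients(self, grads_and_vars):
        lr = self._resolve_lr()
        return [v - lr * g for g, v in grads_and_vars]

# Both forms behave identically for a constant rate:
const_opt = SketchGradientDescentOptimizer(3.0)
lambda_opt = SketchGradientDescentOptimizer(lambda: 3.0)
assert const_opt.apply_gradients([(0.01, 3.0)]) == lambda_opt.apply_gradients([(0.01, 3.0)])
```

Supporting this in TensorFlow.NET would presumably mean accepting a `Func<float>` overload alongside the existing `float` constructor parameter.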

Wanglongzhi2001 (Collaborator):

@novikov-alexander Please fix the CI error.

@novikov-alexander novikov-alexander force-pushed the alnovi/optimizer_tests branch 3 times, most recently from 9c8d624 to e6c7c79 Compare November 9, 2023 22:19
novikov-alexander (Contributor, author):

> @novikov-alexander Please fix the CI error.

It's not clear how they interfere. I'll try to debug.

@novikov-alexander novikov-alexander force-pushed the alnovi/optimizer_tests branch 3 times, most recently from 239e8de to 178296b Compare November 10, 2023 03:57
@novikov-alexander novikov-alexander changed the title Gradient descent optimizer tests test: Gradient descent optimizer tests Nov 10, 2023
novikov-alexander (Contributor, author):

@Wanglongzhi2001 That's not clear. I'm afraid LinearRegression is a flaky test: I just put my tests next to it and that caused a failure inside it, with no visible reason. Moreover, it stopped failing when I changed the namespace of my tests.

novikov-alexander (Contributor, author):

All checks have passed

Wanglongzhi2001 (Collaborator):

@Oceania2018 LGTM.

AsakusaRinne (Collaborator) left a review:

Thanks a lot for this contribution! Everything seems to be well done. If you'd like to change anything mentioned in the previous comments, please open a new PR. :)


@AsakusaRinne AsakusaRinne merged commit d020897 into SciSharp:master Nov 10, 2023
3 checks passed
4 participants