Skip to content

Commit 59672a4

Browse files
committed
1: change learning rate and lr_tensor.
2: override _prepare() for AdamOptimizer. 3: fix key name in _get_non_slot_variable.
1 parent 9a71f75 commit 59672a4

File tree

4 files changed

+22
-14
lines changed

4 files changed

+22
-14
lines changed

src/TensorFlowNET.Core/Train/AdamOptimizer.cs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ private Operation _apply_sparse_shared(Tensor grad, RefVariable var, Tensor indi
4646
var lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power));
4747
var m = get_slot(var, "m");
4848
var m_scaled_g_values = grad * (1 - beta1_t);
49-
var m_t = state_ops.assign(m, m * beta1_t, use_locking: _use_locking);
49+
var mul = m * beta1_t;
50+
var m_t = state_ops.assign(m, mul, use_locking: _use_locking);
5051
with(ops.control_dependencies(new[] { m_t }), delegate
5152
{
5253
m_t = scatter_add(m, indices, m_scaled_g_values);
@@ -88,9 +89,15 @@ protected override void _create_slots(RefVariable[] var_list)
8889

8990
public override void _prepare()
9091
{
91-
//copied from GradientDescentOptimizer
92-
LearningRate = _call_if_callable(LearningRate);
93-
LearningRateTensor = ops.convert_to_tensor(LearningRate, name: "learning_rate");
92+
var lr = _call_if_callable(_lr);
93+
var beta1 = _call_if_callable(_beta1);
94+
var beta2 = _call_if_callable(_beta2);
95+
var epsilon = _call_if_callable(_epsilon);
96+
97+
_lr_t = ops.convert_to_tensor(lr, name: "learning_rate");
98+
_beta1_t = ops.convert_to_tensor(beta1, name: "beta1");
99+
_beta2_t = ops.convert_to_tensor(beta2, name: "beta2");
100+
_epsilon_t = ops.convert_to_tensor(epsilon, name: "epsilon");
94101
}
95102
}
96103
}

src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,13 @@ public class GradientDescentOptimizer : Optimizer
2626
public GradientDescentOptimizer(float learning_rate, bool use_locking = false, string name = "GradientDescent")
2727
: base(learning_rate, use_locking, name)
2828
{
29-
LearningRate = learning_rate;
30-
LearningRateTensor = null;
29+
_lr = learning_rate;
3130
}
3231

3332
public override void _prepare()
3433
{
35-
LearningRate = _call_if_callable(LearningRate);
36-
LearningRateTensor = ops.convert_to_tensor(LearningRate, name: "learning_rate");
34+
var lr = _call_if_callable(_lr);
35+
_lr_t = ops.convert_to_tensor(lr, name: "learning_rate");
3736
}
3837
}
3938
}

src/TensorFlowNET.Core/Train/Optimizer.cs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,10 @@ public abstract class Optimizer : Trackable
2323

2424
string _name;
2525
public string Name => _name;
26-
public float LearningRate { get; set; }
27-
public Tensor LearningRateTensor { get; set; }
26+
protected float _lr;
27+
public float LearningRate => _lr;
28+
protected Tensor _lr_t;
29+
public Tensor LearningRateTensor => _lr_t;
2830
public bool _use_locking;
2931
public Dictionary<string, Dictionary<string, RefVariable>> _slots;
3032
public Dictionary<string, RefVariable> _non_slot_dict;
@@ -38,7 +40,7 @@ public Optimizer(float learning_rate, bool use_locking, string name = null)
3840

3941
_name = name;
4042
_use_locking = use_locking;
41-
LearningRate = learning_rate;
43+
_lr = learning_rate;
4244
// Dictionary of slots.
4345
_slots = new Dictionary<string, Dictionary<string, RefVariable>>();
4446
_non_slot_dict = new Dictionary<string, RefVariable>();
@@ -302,7 +304,7 @@ private string _var_key(RefVariable var)
302304

303305
protected RefVariable _get_non_slot_variable(string name, Graph graph = null)
304306
{
305-
var key = $"{graph.graph_key}.{name}";
307+
var key = $"{name}.{graph.graph_key}";
306308
var non_slot = _non_slot_dict.ContainsKey(key) ? _non_slot_dict[key] : null;
307309

308310
return non_slot;

src/TensorFlowNET.Core/Variables/state_ops.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ public static Tensor assign(Tensor @ref, object value,
3636
validate_shape: validate_shape,
3737
use_locking: use_locking,
3838
name: name);
39-
else
40-
throw new NotImplementedException("state_ops.assign");
39+
throw new NotImplementedException("state_ops.assign");
40+
//return @ref.assign(value, name: name);
4141
}
4242

4343
public static Tensor assign_sub(RefVariable @ref,

0 commit comments

Comments
 (0)