
Commit

update test
tiancaishaonvjituizi committed May 4, 2022
1 parent 0403b08 commit f353873
Showing 1 changed file with 20 additions and 10 deletions.
python/paddle/fluid/tests/unittests/test_optimizer.py: 20 additions & 10 deletions
@@ -80,32 +80,42 @@ def forward(self, x):
         return x * self._w
 
     with paddle.fluid.dygraph.guard():
+        np_neg_ones = np.ones(w_shape) * -1
 
         model = MyLayer()
         x = paddle.ones([1, 3, 4])
-        loss = model(x)
         asgd = paddle.optimizer.ASGD(learning_rate=1., parameters=model.parameters(), t0=1)
-        loss.backward()
 
-        np_neg_ones = np.ones(w_shape) * -1
-        print(f'grad before step: {model._w.grad}')
-        print(f'w before step: {model._w.numpy()}')
+        loss = model(x)
+        loss.backward()
         asgd.step()
-        print(f'w after step: {model._w.numpy()}')
-        print(f'grad after step: {model._w.grad}')
         assert np.allclose(model._w.numpy(), np_neg_ones)
         assert np.allclose(asgd.averaged_parameters()[0].numpy(), np_neg_ones)
-        print(f'grad: {model._w.grad}')
-        print(f'w before step: {model._w.numpy()}')
         asgd.clear_grad()
 
         loss = model(x)
         loss.backward()
         asgd.step()
-        print(f'w after step: {model._w.numpy()}')
         assert np.allclose(model._w.numpy(), np_neg_ones * 2)
         assert np.allclose(asgd.averaged_parameters()[0].numpy(), np_neg_ones * 2)
         asgd.clear_grad()
 
+        loss = model(x)
+        loss.backward()
+        asgd.step()
+        assert np.allclose(model._w.numpy(), np_neg_ones * 3)
+        assert np.allclose(asgd.averaged_parameters()[0].numpy(), np_neg_ones * 3)
+        asgd.clear_grad()
 
+        loss = model(x)
+        loss.backward()
+        asgd.step()
+        assert np.allclose(model._w.numpy(), np_neg_ones * 4)
+        assert np.allclose(asgd.averaged_parameters()[0].numpy(), np_neg_ones * 3.5)
+        asgd.clear_grad()
 
+        loss = model(x)
+        loss.backward()
+        asgd.step()
+        assert np.allclose(model._w.numpy(), np_neg_ones * 5)
+        assert np.allclose(asgd.averaged_parameters()[0].numpy(), np_neg_ones * 4)
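The constants asserted above are self-consistent: with learning_rate=1., a zero-initialized w, and a unit gradient each step, w moves to -1, -2, ..., -5, while averaged_parameters() tracks w itself for the first two steps and the running mean of w from step 3 onward (-3.5 after step 4, -4 after step 5). A minimal sketch of that arithmetic, independent of Paddle and assuming averaging begins at the third step (which these asserts imply for t0=1):

    import numpy as np

    w = 0.0
    weights = []
    for step in range(1, 6):
        w -= 1.0  # learning_rate * grad, assuming a unit gradient each step
        weights.append(w)
        # until averaging kicks in, the averaged parameter equals w itself
        avg = w if step < 3 else np.mean(weights[2:])
        print(step, w, avg)
    # prints: 1 -1.0 -1.0 | 2 -2.0 -2.0 | 3 -3.0 -3.0 | 4 -4.0 -3.5 | 5 -5.0 -4.0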