Skip to content

Commit

Permalink
[autodiff] Clear all dual fields when exiting context manager (taichi…
Browse files Browse the repository at this point in the history
  • Loading branch information
erizmr authored and Ailing Zhang committed Aug 10, 2022
1 parent c9091bf commit d63bbfc
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
10 changes: 5 additions & 5 deletions python/taichi/ad/_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,11 @@ def shape_flatten(shape):
else:
assert parameters_shape_flatten == len(self.seed)

# Clear gradients
if self.clear_gradients:
# TODO: the clear gradients should be controlled to clear adjoint/dual/adjoint_visited respectively
clear_all_gradients()

# Set seed for each variable
if len(self.seed) == 1:
if len(self.param.shape) == 0:
Expand All @@ -286,11 +291,6 @@ def shape_flatten(shape):
for idx, s in enumerate(self.seed):
self.param.dual[idx] = 1.0 * s

# Clear gradients
if self.clear_gradients:
for ls in self.loss:
ls.dual.fill(0)

# Attach the context manager to the runtime
self.runtime.fwd_mode_manager = self

Expand Down
20 changes: 20 additions & 0 deletions tests/python/test_ad_basics_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,23 @@ def func():

with ti.ad.FwdMode(loss=d, param=c):
func()


@test_utils.test()
def test_clear_all_dual_field():
x = ti.field(float, shape=(), needs_dual=True)
y = ti.field(float, shape=(), needs_dual=True)
loss = ti.field(float, shape=(), needs_dual=True)

x[None] = 2.0
y[None] = 3.0

@ti.kernel
def clear_dual_test():
y[None] = x[None]**2
loss[None] += y[None]

for _ in range(5):
with ti.ad.FwdMode(loss=loss, param=x):
clear_dual_test()
assert y.dual[None] == 4.0

0 comments on commit d63bbfc

Please sign in to comment.