Upgrade paddle API used in mixed precision training (PaddlePaddle#227)
willthefrog committed Feb 13, 2020
1 parent 59b7049 commit 2835d5a
Showing 1 changed file with 32 additions and 38 deletions.
ppdet/experimental/mixed_precision.py: 70 changes (32 additions & 38 deletions)
@@ -129,30 +129,27 @@ def __init__(self, init_loss_scale=2**15, increment_every=2000, factor=2.):
     def increment(self):
         enough_steps = layers.less_than(self.increment_every,
                                         self.good_steps + 1)
-        with layers.Switch() as switch:
-            with switch.case(enough_steps):
-                new_scale = self.scale * self.factor
-                scale_valid = layers.isfinite(new_scale)
-                with layers.Switch() as switch2:
-                    with switch2.case(scale_valid):
-                        layers.assign(new_scale, self.scale)
-                        layers.assign(
-                            layers.zeros_like(self.good_steps), self.good_steps)
-                    with switch2.default():
-                        layers.increment(self.good_steps)
-            with switch.default():
-                layers.increment(self.good_steps)
+
+        def increment_step():
+            layers.increment(self.good_steps)
+
+        def maybe_update():
+            new_scale = self.scale * self.factor
+            scale_valid = layers.isfinite(new_scale)
+
+            def update_scale_and_step():
+                layers.assign(new_scale, self.scale)
+                layers.assign(
+                    layers.zeros_like(self.good_steps), self.good_steps)
+
+            layers.cond(scale_valid, update_scale_and_step)
+
+        layers.cond(enough_steps, maybe_update, increment_step)

     def decrement(self):
         new_scale = self.scale / self.factor
         one = layers.fill_constant(shape=[1], dtype='float32', value=1.0)
-        less_than_one = layers.less_than(new_scale, one)
-        with layers.Switch() as switch:
-            with switch.case(less_than_one):
-                layers.assign(one, self.scale)
-            with switch.default():
-                layers.assign(new_scale, self.scale)
-
+        layers.assign(layers.elementwise_max(new_scale, one), self.scale)
         layers.assign(layers.zeros_like(self.good_steps), self.good_steps)
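The core of this change is replacing imperative Switch/case blocks with the callable-based layers.cond control flow. A minimal sketch, not taken from this commit and assuming Paddle 1.7's static-graph fluid.layers API, contrasting the two styles on a toy branch:

    import paddle.fluid as fluid
    import paddle.fluid.layers as layers

    x = layers.fill_constant(shape=[1], dtype='float32', value=3.0)
    y = layers.fill_constant(shape=[1], dtype='float32', value=5.0)
    pred = layers.less_than(x, y)

    # Old style: Switch branches perform side effects, typically assigning
    # into a variable created up front.
    out_old = layers.create_global_var(shape=[1], value=0.0, dtype='float32')
    with layers.Switch() as switch:
        with switch.case(pred):
            layers.assign(x, out_old)
        with switch.default():
            layers.assign(y, out_old)

    # New style: cond takes a callable per branch and returns the selected
    # branch's output directly.
    out_new = layers.cond(pred, lambda: x, lambda: y)

The increment() rewrite above follows the same recipe: each Switch case body becomes a local function, and nesting is expressed by calling layers.cond inside a branch function.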


@@ -275,12 +272,13 @@ def scale_gradient(block, context):
         fwd_var = block._var_recursive(context[name])
         if not isinstance(fwd_var, Parameter):
             continue # TODO verify all use cases
-        clip_op_desc = block.desc.append_op()
-        clip_op_desc.set_type("elementwise_div")
-        clip_op_desc.set_input("X", [name])
-        clip_op_desc.set_input("Y", [scale.name])
-        clip_op_desc.set_output("Out", [name])
-        clip_op_desc._set_attr(op_role_attr_name, bwd_role)
+        scale_op_desc = block.desc.append_op()
+        scale_op_desc.set_type("elementwise_div")
+        scale_op_desc.set_input("X", [name])
+        scale_op_desc.set_input("Y", [scale.name])
+        scale_op_desc.set_output("Out", [name])
+        scale_op_desc._set_attr("axis", -1)
+        scale_op_desc._set_attr(op_role_attr_name, bwd_role)


 def update_loss_scale(grads):
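Renaming clip_op_desc to scale_op_desc and setting the axis attribute explicitly does not change what the appended op computes: the parameter gradient is divided elementwise by the single-element loss-scale variable. A rough high-level equivalent of that division, with illustrative names that are not code from this file:

    import paddle.fluid as fluid
    import paddle.fluid.layers as layers

    # Stand-ins for a parameter gradient and the dynamic loss scale.
    grad = layers.fill_constant(shape=[8], dtype='float32', value=4.0)
    loss_scale = layers.fill_constant(shape=[1], dtype='float32', value=2.0)

    # Same computation as the appended elementwise_div op desc:
    # Out = X / Y with axis=-1, broadcasting the scalar scale over the gradient.
    unscaled = layers.elementwise_div(grad, loss_scale, axis=-1)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())
    print(exe.run(fetch_list=[unscaled]))  # each element should come out as 2.0

The commit keeps the low-level append_op form, which writes the result back in place into the gradient variable while the backward pass is being constructed.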
@@ -289,12 +287,8 @@ def update_loss_scale(grads):
         return
     per_grad_check = layers.stack([layers.reduce_sum(g) for g in grads])
     grad_valid = layers.isfinite(per_grad_check)
-
-    with layers.Switch() as switch:
-        with switch.case(grad_valid):
-            state.increment()
-        with switch.default():
-            state.decrement()
+    layers.cond(grad_valid, lambda: state.increment(),
+                lambda: state.decrement())
     return grad_valid
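layers.isfinite collapses a whole tensor to one boolean, so stacking the per-gradient sums lets a single op validate every gradient before the loss scale is adjusted. A small self-contained sketch of that check, using made-up constant gradients rather than anything from this repo:

    import paddle.fluid as fluid
    import paddle.fluid.layers as layers

    grads = [layers.fill_constant(shape=[3], dtype='float32', value=v)
             for v in (1.0, 2.0)]
    # An inf/nan anywhere propagates into the corresponding sum, so the
    # stacked vector is finite only if every gradient is finite.
    per_grad_check = layers.stack([layers.reduce_sum(g) for g in grads])
    grad_valid = layers.isfinite(per_grad_check)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())
    print(exe.run(fetch_list=[grad_valid]))  # expected to print a True flag here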


@@ -309,15 +303,15 @@ def backward(self, loss, **kwargs):
         else:
             kwargs['callbacks'] = callbacks
         param_grads = self._backward(loss, **kwargs)
+
+        def zero_grad():
+            for _, g in param_grads:
+                layers.assign(layers.zeros_like(g), g)
+
         if state is not None:
             grad_valid = update_loss_scale(v for k, v in param_grads)
             if state.dynamic_scaling:
-                with layers.Switch() as switch:
-                    with switch.case(grad_valid):
-                        pass
-                    with switch.default():
-                        for _, g in param_grads:
-                            layers.assign(layers.zeros_like(g), g)
+                layers.cond(grad_valid, None, zero_grad)

         return param_grads
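Note that layers.cond accepts None in place of a branch callable, which is how the rewritten backward expresses "keep the gradients if they are valid, otherwise zero them". A minimal sketch of that shape with a single stand-in gradient (the names here are illustrative, not the repo's):

    import paddle.fluid as fluid
    import paddle.fluid.layers as layers

    g = layers.fill_constant(shape=[4], dtype='float32', value=0.5)
    grad_valid = layers.isfinite(g)

    def zero_grad():
        # Side effects only; nothing is returned, so cond also returns None.
        layers.assign(layers.zeros_like(g), g)

    # true_fn is None: do nothing when the gradient is finite, zero it otherwise.
    layers.cond(grad_valid, None, zero_grad)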

