bug fix (#21)
Fix ODPS table io_slicing type error.
Fix restoring checkpoint in distributed evaluation.
Fix hang when enabling amp dynamic loss scale and gradient accumulation.
SeaOfOcean committed Sep 16, 2022
1 parent 6388b3c · commit 4178818
Showing 7 changed files with 60 additions and 23 deletions.
2 changes: 1 addition & 1 deletion epl/parallel/graph_editor.py
@@ -117,7 +117,7 @@ def table_io_slicing(self, dataset_api_op):
"""Slicing table to balance data load among all model replicas."""
slice_id = 0
all_devices = dataset_api_op.taskgraph.virtual_device.all_devices
list.sort(all_devices)
all_devices = sorted(all_devices)
if self._graph.num_constructors > 1:
total_num_slices = len(all_devices)
for idx, dev in enumerate(all_devices):
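The one-line change above replaces list.sort with sorted. A minimal sketch of why that fixes the type error (the device strings are made up for illustration): list.sort is a list-only method and raises TypeError when all_devices arrives as a tuple or another non-list iterable, while sorted accepts any iterable and returns a new sorted list.

devices = ("/job:worker/task:1/device:GPU:0", "/job:worker/task:0/device:GPU:0")

try:
    list.sort(devices)        # TypeError: sort is a method of list, not of tuple
except TypeError as err:
    print("list.sort failed:", err)

devices = sorted(devices)     # works on any iterable and returns a new list
print(devices)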
27 changes: 15 additions & 12 deletions epl/parallel/hooks.py
@@ -210,23 +210,25 @@ def apply_gradients(self, grads_and_vars, *args, **kwargs):
with ModelPhase(ModelPhase.APPLY):
if not apply_opt:
apply_fn = lambda: fn(self, grads_and_vars, *args, **kwargs)
elif zero_enabled():
apply_fn = lambda: apply_zero(self, fn, grads_and_vars,
global_step, ga_iters,
num_apply_group, name)
elif ga_enabled():
apply_fn = lambda: apply_ga(self, fn, grads_and_vars,
global_step, ga_iters,
num_apply_group, name)
elif num_apply_group > 1:
apply_fn = lambda: apply_grad_group(self, fn, grads_and_vars,
num_apply_group,
global_step, name=name)
else:
apply_fn = lambda: fn(self, grads_and_vars, *args, **kwargs)
if zero_enabled():
apply_fn = lambda: apply_zero(self, fn, grads_and_vars,
global_step, ga_iters,
num_apply_group, name)

elif num_apply_group > 1:
apply_fn = lambda: apply_grad_group(self, fn, grads_and_vars,
num_apply_group,
global_step, name=name)
else:
apply_fn = lambda: fn(self, grads_and_vars, *args, **kwargs)

if apply_opt and amp_enabled() and Env.get().config.amp.loss_scale == "dynamic":
return amp_update(grads_and_vars, apply_fn, name)
if amp_enabled() and Env.get().config.amp.loss_scale == "dynamic":
return amp_update(grads_and_vars, apply_fn, name)
return apply_fn()
return apply_gradients
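
For context on the loss-scale part of this hunk: with dynamic loss scaling, the scale update must run exactly once per real optimizer step. A rough, framework-agnostic sketch of the usual policy follows; the class name and defaults are illustrative, and this is not EPL's amp_update.

class DynamicLossScale(object):
    """Grow the scale after a run of overflow-free steps, shrink it on overflow."""

    def __init__(self, init_scale=2.0 ** 15, growth_interval=2000, factor=2.0):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self.factor = factor
        self.good_steps = 0

    def update(self, grads_are_finite, apply_fn):
        if grads_are_finite:
            apply_fn()                        # do the real optimizer update
            self.good_steps += 1
            if self.good_steps >= self.growth_interval:
                self.scale *= self.factor     # stable for a while: try a larger scale
                self.good_steps = 0
        else:
            self.scale = max(1.0, self.scale / self.factor)  # overflow: skip the step, back off
            self.good_steps = 0

When gradient accumulation is on, apply_fn only performs a real apply every ga_iters micro-steps, which is presumably why this commit guards the outer amp_update call with apply_opt and moves the dynamic-scale update into the accumulation path (see gradient_accumulation.py below).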

@@ -578,7 +580,8 @@ def restore(self, sess, save_path):
# TODO(wangang.wa): This code will be removed after merging
# variables for split strategy.
if Graph.get().first_constructor_rank == Env.get().cluster.worker_index or \
any(taskgraph.strategy_context.split_strategy is not None for taskgraph in Graph.get().taskgraphs):
any(taskgraph.strategy_context.split_strategy is not None for taskgraph in Graph.get().taskgraphs) or \
not Graph.get().need_parallel:
with ModelPhase(ModelPhase.SAVE_AND_RESTORE):
ret = fn(self, sess, save_path)
return ret
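The restore hunk above widens the condition so that the checkpoint is also restored when the graph needs no parallel transformation at all, which covers the distributed-evaluation case from the commit message. Paraphrased as a standalone helper (hypothetical name, same condition as the diff):

def should_restore(graph, env):
    # Only the first constructor restores when variables are plain replicas.
    is_first_constructor = graph.first_constructor_rank == env.cluster.worker_index
    # With a split strategy, each worker holds part of the variables and must restore.
    has_split_strategy = any(tg.strategy_context.split_strategy is not None
                             for tg in graph.taskgraphs)
    # Graphs that need no parallelism (e.g. evaluation-only) restore on every worker.
    return is_first_constructor or has_split_strategy or not graph.need_parallel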
11 changes: 9 additions & 2 deletions epl/runtime/gradient_accumulation.py
@@ -33,6 +33,9 @@
from epl.ir.graph import Graph
from epl.runtime.optimizer_helper import filter_none_grads, \
apply_grad_group
from epl.runtime.amp.loss_scale import amp_update
from epl.runtime.amp.auto_mixed_precision import amp_enabled


def ga_iter_num():
"""Return gradient accumulation iteration number."""
@@ -72,8 +75,12 @@ def apply_accmulation(optimizer, apply_gradients_fn,
grads_and_vars.append((g, v))
Graph.get().add_grads_and_vars(grads_and_vars)
update_ops = []
apply_op = apply_grad_group(optimizer, apply_gradients_fn, grads_and_vars,
ngroup, global_step, "epl_apply_grad_ga")
apply_fn = lambda: apply_grad_group(optimizer, apply_gradients_fn, grads_and_vars,
ngroup, global_step, "epl_apply_grad_ga")
if amp_enabled() and Env.get().config.amp.loss_scale == "dynamic":
apply_op = amp_update(grads_and_vars, apply_fn, 'amp_update')
else:
apply_op = apply_fn()
update_ops.append(apply_op)
with ops.control_dependencies(update_ops):
clear_ops = [state_ops.assign(s, array_ops.zeros_like(s)) for s in slots]
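The hunk above wraps the grouped apply in amp_update when dynamic loss scaling is enabled, so the scale update is tied to the accumulated apply rather than to every micro-step. The surrounding accumulation logic follows the usual pattern; a small sketch with hypothetical names is below (EPL keeps the running sums in the slots variables and resets them with the assign/zeros_like clear_ops shown at the end of the hunk).

import numpy as np

def make_accumulating_step(compute_grads, apply_grads, variables, ga_iters):
    """Sum gradients over ga_iters micro-batches, then do one real update."""
    slots = [np.zeros_like(v) for v in variables]   # one accumulator per variable
    state = {"step": 0}

    def micro_step(batch):
        for slot, grad in zip(slots, compute_grads(batch)):
            slot += grad                            # in-place accumulation for this micro-batch
        state["step"] += 1
        if state["step"] % ga_iters == 0:           # time for a real optimizer update
            apply_grads([s / ga_iters for s in slots])   # averaged here; summing is also common
            for s in slots:
                s[:] = 0.0                          # reset accumulators, like clear_ops above
    return micro_step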
2 changes: 1 addition & 1 deletion epl/utils/version.py
@@ -14,4 +14,4 @@
# =============================================================================
"""EPL version."""

VERSION = "0.3.0"
VERSION = "0.6.0"
3 changes: 2 additions & 1 deletion tests/Makefile
@@ -8,8 +8,9 @@ GPU_8 = 0,1,2,3,4,5,6,7

.PHONY: test
test:
PYTHONPATH=../ $(PYTHON) -m epl.utils.launcher --num_workers=2 --gpu_per_worker=1 --debug=True test_launcher.sh
CUDA_VISIBLE_DEVICES='' ./launch.sh header_test.py
PYTHONPATH=../ $(PYTHON) -m epl.utils.launcher --num_workers=2 --gpu_per_worker=1 --debug=True test_launcher.sh
PYTHONPATH=../ $(PYTHON) -m epl.utils.launcher --num_workers=2 --gpu_per_worker=1 --debug=True test_amp_parallel.sh
CUDA_VISIBLE_DEVICES=$(GPU_4) ./launch.sh strategy_new_test.py
CUDA_VISIBLE_DEVICES=$(GPU_4) ./launch.sh auto_cluster_test.py
CUDA_VISIBLE_DEVICES=$(GPU_2) ./launch.sh estimator_test.py
34 changes: 28 additions & 6 deletions tests/dnn_data_parallel.py
@@ -23,11 +23,19 @@
import tensorflow as tf
import epl

tf.logging.set_verbosity(tf.logging.INFO)
flags = tf.app.flags
flags.DEFINE_integer("max_steps", 10, "max training step")
flags.DEFINE_float("learning_rate", 0.001, "learning_rate")
flags.DEFINE_integer("num_micro_batch", 1, "num_micro_batch")
flags.DEFINE_string("amp", None, "amp")
FLAGS = tf.app.flags.FLAGS

epl.init()
config_json = {}
if FLAGS.amp:
config_json["amp.level"] = "o1"
config_json["amp.loss_scale"] = float(FLAGS.amp) if FLAGS.amp != "dynamic" else "dynamic"
config_json["pipeline.num_micro_batch"] = FLAGS.num_micro_batch
epl.init(epl.Config(config_json))
epl.set_default_strategy(epl.replicate(1))

# dataset
@@ -43,12 +51,26 @@
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

global_step = tf.train.get_or_create_global_step()
optimizer = tf.train.MomentumOptimizer(learning_rate=0.001, momentum=0.9)
optimizer = tf.train.MomentumOptimizer(FLAGS.learning_rate, momentum=0.9)
train_op = optimizer.minimize(loss, global_step=global_step)

hooks = [tf.train.StopAtStepHook(last_step=FLAGS.max_steps)]
max_steps = (FLAGS.max_steps+1) * FLAGS.num_micro_batch
cum_steps = 0
with tf.train.MonitoredTrainingSession(hooks=hooks) as sess:
while not sess.should_stop():
train_loss, _, step = sess.run([loss, train_op, global_step])
print("Iteration %s , Loss: %s ." % (step, train_loss))
while not sess.should_stop() or cum_steps > max_steps:
train_ops = [loss, train_op, global_step]
if FLAGS.amp:
train_ops.append(epl.Env.get().parallel_information["AMP_LOSS_SCALE"]._num_good_steps) # pylint: disable=protected-access
train_ops.append(epl.Env.get().parallel_information["AMP_LOSS_SCALE"]._current_loss_scale) # pylint: disable=protected-access
res = sess.run(train_ops)
print("Iteration %s , Loss: %s ." % (res[2], res[0]))
if FLAGS.amp:
num_good_steps = res[3]
current_loss_scale = res[4]
if FLAGS.learning_rate >= 100:
assert num_good_steps <= 1
assert res[2] <= 1
print('good_steps: {}, current_loss_scale: {}'.format(num_good_steps, current_loss_scale))
cum_steps += 1
print("Train Finished.")
4 changes: 4 additions & 0 deletions tests/test_amp_parallel.sh
@@ -0,0 +1,4 @@
python dnn_data_parallel.py \
--amp=dynamic \
--num_micro_batch=3 \
--learning_rate=1000
