Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Export MultiTensor Update and FuseUpdateCast to GraphConfig #9209

Merged
merged 5 commits into from
Oct 10, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions oneflow/core/job/job_conf.proto
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@ message JobConfigProto {
optional bool enable_auto_mixed_precision = 602 [default = false];
optional bool enable_quantization_aware_training = 603 [default = false];
optional DataType mixed_precision_data_type = 604 [default = kFloat16]; // kFloat16 or kBFloat16
// Enables the multi-tensor model-update pass, which merges small per-tensor
// optimizer update kernels to cut kernel launch overhead (settable from
// GraphConfig.enable_multi_tensor_update).
optional bool enable_multi_tensor_update = 605 [default = false];
// Fuses the optimizer update with the cast of model weights to half precision.
// NOTE(review): the corresponding pass appears to require AMP
// (enable_auto_mixed_precision) to take effect — confirm against the pass.
optional bool enable_fused_model_update_cast = 606 [default = false];

optional bool enable_auto_parallel = 700 [default = false];
optional double auto_parallel_computation_cost_ratio = 701 [default = 0.05];
Expand Down
2 changes: 2 additions & 0 deletions oneflow/core/job/job_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class JobDesc final {
// Read-only accessors that forward to the underlying JobConfigProto fields.
bool enable_reuse_mem() const { return job_conf_.enable_reuse_mem(); }
bool enable_inplace() const { return job_conf_.enable_inplace(); }
bool enable_auto_mixed_precision() const { return job_conf_.enable_auto_mixed_precision(); }
// True when the multi-tensor model-update pass was requested in the job config.
bool enable_multi_tensor_update() const { return job_conf_.enable_multi_tensor_update(); }
// True when fusing the optimizer update with the model-weight cast was requested.
bool enable_fused_model_update_cast() const { return job_conf_.enable_fused_model_update_cast(); }
DataType mixed_precision_data_type() const { return job_conf_.mixed_precision_data_type(); }
bool do_parallel_cast_before_widening_type_cast() const {
return job_conf_.do_parallel_cast_before_widening_type_cast();
Expand Down
3 changes: 2 additions & 1 deletion oneflow/core/job_rewriter/fuse_model_update_cast_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ class FuseModelUpdateCastOpsPass final : public JobPass {
~FuseModelUpdateCastOpsPass() override = default;

bool IsEnabled(const JobPassCtx& ctx) const {
return ParseBooleanFromEnv("ONEFLOW_FUSE_MODEL_UPDATE_CAST", false)
return (ctx.job_desc().enable_fused_model_update_cast()
|| ParseBooleanFromEnv("ONEFLOW_FUSE_MODEL_UPDATE_CAST", false))
&& ctx.job_desc().enable_auto_mixed_precision();
}
Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder) const;
Expand Down
3 changes: 2 additions & 1 deletion oneflow/core/job_rewriter/multi_tensor_model_update.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,8 @@ class MultiTensorModelUpdatePass final : public JobPass {
~MultiTensorModelUpdatePass() override = default;

bool IsEnabled(const JobPassCtx& ctx) const {
return ParseBooleanFromEnv("ONEFLOW_ENABLE_MULTI_TENSOR_MODEL_UPDATE", false);
return ctx.job_desc().enable_multi_tensor_update()
|| ParseBooleanFromEnv("ONEFLOW_ENABLE_MULTI_TENSOR_MODEL_UPDATE", false);
}
Maybe<void> Apply(const OpGraph& op_graph, JobBuilder* job_builder) const;

Expand Down
12 changes: 12 additions & 0 deletions python/oneflow/nn/graph/graph_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,18 @@ def enable_auto_parallel_sbp_collector(self, mode: bool = True):
"""
self.proto.enable_auto_parallel_sbp_collector = mode

def enable_multi_tensor_update(self, mode: bool = True):
    """Enable the multi-tensor-update pass, which merges small optimizer
    update kernels to reduce kernel launch overhead.

    Args:
        mode (bool): enable the pass when ``True``, disable it when
            ``False``. Defaults to ``True``.
    """
    self.proto.enable_multi_tensor_update = mode

def enable_fused_model_update_cast(self, mode: bool = True):
    """Fuse the optimizer update with the cast of model weights to half
    precision. This option only takes effect in AMP (auto mixed
    precision) mode.

    Args:
        mode (bool): enable the fusion when ``True``, disable it when
            ``False``. Defaults to ``True``.
    """
    self.proto.enable_fused_model_update_cast = mode

def _generate_optimizer_and_variable_configs(
self, opt_dict: OptDict = None, variables_conf: OrderedDict = None,
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,6 @@ def compare_with_numpy_adam(
do_bias_correction,
amsgrad,
):
os.environ["ONEFLOW_ENABLE_MULTI_TENSOR_MODEL_UPDATE"] = "1"
os.environ["ONEFLOW_FUSE_MODEL_UPDATE_CAST"] = "1"

random_weight_seq = []
init_value_seq = []

Expand Down Expand Up @@ -103,6 +100,8 @@ def __init__(self):
self.add_optimizer(adam0)
self.config.enable_amp(True)
self.config.allow_fuse_model_update_ops(True)
self.config.enable_multi_tensor_update(True)
self.config.enable_fused_model_update_cast(True)

def build(self, mask_tensor_list):
loss = flow.sum(self.m(mask_tensor_list))
Expand Down Expand Up @@ -193,5 +192,3 @@ def test_multi_tensor_adam(test_case):

if __name__ == "__main__":
unittest.main()
os.environ["ONEFLOW_ENABLE_MULTI_TENSOR_MODEL_UPDATE"] = "0"
os.environ["ONEFLOW_FUSE_MODEL_UPDATE_CAST"] = "0"
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@
def compare_with_numpy_sgd(
test_case, device, x_shape, tensor_num, learning_rate, train_iters, weight_decay
):
os.environ["ONEFLOW_ENABLE_MULTI_TENSOR_MODEL_UPDATE"] = "1"
os.environ["ONEFLOW_FUSE_MODEL_UPDATE_CAST"] = "1"

random_weight_seq = []
init_value_seq = []

Expand Down Expand Up @@ -89,6 +86,8 @@ def __init__(self):
self.add_optimizer(sgd0)
self.config.enable_amp(True)
self.config.allow_fuse_model_update_ops(True)
self.config.enable_multi_tensor_update(True)
self.config.enable_fused_model_update_cast(True)

def build(self, mask_tensor_list):
loss = flow.sum(self.m(mask_tensor_list))
Expand Down Expand Up @@ -155,5 +154,3 @@ def test_multi_tensor_sgd(test_case):

if __name__ == "__main__":
unittest.main()
os.environ["ONEFLOW_ENABLE_MULTI_TENSOR_MODEL_UPDATE"] = "0"
os.environ["ONEFLOW_FUSE_MODEL_UPDATE_CAST"] = "0"