Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tidy up op_conf.proto #3932

Merged
merged 19 commits into from Dec 24, 2020
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions cmake/cfg.cmake
Expand Up @@ -57,6 +57,9 @@ function(GENERATE_CFG_AND_PYBIND11_CPP SRCS HDRS PYBIND_SRCS ROOT_DIR)
oneflow/core/job/parallel_signature.proto
oneflow/core/eager/eager_instruction.proto
oneflow/core/job/cluster_instruction.proto
oneflow/core/job/initializer_conf.proto
oneflow/core/job/regularizer_conf.proto
oneflow/core/job/learning_rate_schedule_conf.proto
oneflow/core/common/cfg_reflection_test.proto
oneflow/core/common/data_type.proto
oneflow/core/common/device_type.proto
Expand Down
93 changes: 93 additions & 0 deletions oneflow/core/job/initializer_conf.proto
@@ -0,0 +1,93 @@
syntax = "proto2";
package oneflow;

// Configuration for the float constant initializer; selected through
// InitializerConf.constant_conf.
message ConstantInitializerConf {
  // Float value carried by this config. Reads as 0 when the field is unset.
  optional float value = 1 [default = 0];
}

// Configuration for the integer constant initializer; selected through
// InitializerConf.constant_int_conf.
message ConstantIntInitializerConf {
  // 64-bit integer value carried by this config. Reads as 0 when unset.
  optional int64 value = 1 [default = 0];
}

// Configuration for the float random-uniform initializer; selected through
// InitializerConf.random_uniform_conf.
// NOTE(review): whether `max` is inclusive or exclusive is not defined in
// this file - confirm against the initializer implementation.
message RandomUniformInitializerConf {
  // Lower bound of the sampling range. Defaults to 0.
  optional float min = 1 [default = 0];
  // Upper bound of the sampling range. Defaults to 1.
  optional float max = 2 [default = 1];
}

// Configuration for the integer random-uniform initializer; selected through
// InitializerConf.random_uniform_int_conf.
// NOTE(review): whether `max` is inclusive or exclusive is not defined in
// this file - confirm against the initializer implementation.
message RandomUniformIntInitializerConf {
  // Lower bound of the sampling range (32-bit). Defaults to 0.
  optional int32 min = 1 [default = 0];
  // Upper bound of the sampling range (32-bit). Defaults to 1.
  optional int32 max = 2 [default = 1];
}

// Configuration for the random-normal initializer; selected through
// InitializerConf.random_normal_conf.
message RandomNormalInitializerConf {
  // Mean parameter. Defaults to 0.
  optional float mean = 1 [default = 0];
  // Standard-deviation parameter. Defaults to 1.
  optional float std = 2 [default = 1];
}

// Configuration for the truncated-normal initializer; selected through
// InitializerConf.truncated_normal_conf.
// NOTE(review): the truncation bounds are not expressed in this message -
// confirm where they are defined.
message TruncatedNormalInitializerConf {
  // Mean parameter. Defaults to 0.0.
  optional float mean = 1 [default = 0.0];
  // Standard-deviation parameter. Defaults to 0.05.
  optional float std = 2 [default = 0.05];
}

// Normalization mode referenced by XavierInitializerConf, MsraInitializerConf
// and VarianceScalingInitializerConf.
// NOTE(review): the value names suggest normalizing by fan-in, fan-out, or
// their average - confirm the exact formula in the initializer code.
enum VarianceNorm {
  // Numeric value 0, hence also the proto2 default for an unset field.
  kFanIn = 0;
  kFanOut = 1;
  kAverage = 2;
}

// Distribution selector used by VarianceScalingInitializerConf.distribution.
enum RandomDistribution {
  // Numeric value 0, hence also the proto2 default for an unset field.
  kRandomUniform = 0;
  kRandomNormal = 1;
  kTruncatedNormal = 2;
}

// Configuration for the Xavier initializer; selected through
// InitializerConf.xavier_conf. Both fields are required, so a message with
// either one missing is rejected as uninitialized by proto2.
message XavierInitializerConf {
  // Which VarianceNorm mode to apply.
  required VarianceNorm variance_norm = 1;
  // Layout descriptor given as a free-form string.
  // NOTE(review): the accepted strings are not listed in this file - confirm.
  required string data_format = 2;
}

// Configuration for the MSRA initializer; selected through
// InitializerConf.msra_conf. Both fields are required, so a message with
// either one missing is rejected as uninitialized by proto2.
message MsraInitializerConf {
  // Which VarianceNorm mode to apply.
  required VarianceNorm variance_norm = 1;
  // Layout descriptor given as a free-form string.
  // NOTE(review): the accepted strings are not listed in this file - confirm.
  required string data_format = 2;
}

// Configuration for the floating-point range initializer; selected through
// InitializerConf.range_conf. The generated values follow:
//
//   output[D_0 ... D_(axis - 1) i D_(axis + 1) ... D_n] = start + i * stride
//
// i.e. the value depends only on the index `i` along `axis`.
message RangeInitializerConf {
  // Value produced at index 0 along `axis`. Defaults to 0.
  optional double start = 1 [default = 0];
  // Increment per index step along `axis`. Defaults to 1.
  optional double stride = 2 [default = 1];
  // Axis that `i` runs over. Defaults to -1.
  // NOTE(review): -1 presumably denotes the last axis - confirm.
  optional int64 axis = 3 [default = -1];
}

// Configuration for the integer range initializer; selected through
// InitializerConf.int_range_conf. Same field layout as RangeInitializerConf,
// with int64 `start` and `stride` instead of double.
// NOTE(review): presumably follows the same start + i * stride rule as
// RangeInitializerConf - confirm.
message IntRangeInitializerConf {
  // Starting value. Defaults to 0.
  optional int64 start = 1 [default = 0];
  // Increment per index step. Defaults to 1.
  optional int64 stride = 2 [default = 1];
  // Axis the range runs over. Defaults to -1.
  optional int64 axis = 3 [default = -1];
}

// Configuration for the variance-scaling initializer; selected through
// InitializerConf.variance_scaling_conf. All four fields are required, so
// every one of them must be set for the message to be initialized.
message VarianceScalingInitializerConf {
  // Scale factor.
  required float scale = 1;
  // Which VarianceNorm mode to apply.
  required VarianceNorm variance_norm = 2;
  // Which RandomDistribution to sample from.
  required RandomDistribution distribution = 3;
  // Layout descriptor given as a free-form string.
  // NOTE(review): the accepted strings are not listed in this file - confirm.
  required string data_format = 4;
}

// Tagged union over every initializer configuration declared in this file.
// At most one member of `type` is set at a time; setting one clears any
// previously set member.
message InitializerConf {
  oneof type {
    // Constant initializers (float and int64 variants).
    ConstantInitializerConf constant_conf = 1;
    ConstantIntInitializerConf constant_int_conf = 2;

    // Random initializers.
    RandomUniformInitializerConf random_uniform_conf = 3;
    RandomUniformIntInitializerConf random_uniform_int_conf = 4;
    RandomNormalInitializerConf random_normal_conf = 5;
    TruncatedNormalInitializerConf truncated_normal_conf = 6;

    // Initializers parameterized by VarianceNorm and data_format.
    XavierInitializerConf xavier_conf = 7;
    MsraInitializerConf msra_conf = 8;

    // Range initializers (double and int64 variants).
    RangeInitializerConf range_conf = 9;
    IntRangeInitializerConf int_range_conf = 10;

    // Variance-scaling initializer.
    VarianceScalingInitializerConf variance_scaling_conf = 11;
  }
}

// Configuration for initializing from a snapshot.
message InitializeWithSnapshotConf {
  // Snapshot location. Required.
  // NOTE(review): presumably a filesystem path - confirm.
  required string path = 1;
  // Optional key; no default is declared.
  // NOTE(review): presumably names the entry to read inside the snapshot -
  // confirm what happens when it is unset.
  optional string key = 2;
}
82 changes: 81 additions & 1 deletion oneflow/core/job/job_conf.proto
Expand Up @@ -4,9 +4,89 @@ package oneflow;
import "oneflow/core/common/data_type.proto";
import "oneflow/core/job/placement.proto";
import "oneflow/core/register/blob_desc.proto";
import "oneflow/core/operator/op_conf.proto";
import "oneflow/core/job/sbp_parallel.proto";
import "oneflow/core/framework/user_op_attr.proto";
import "oneflow/core/job/initializer_conf.proto";
import "oneflow/core/job/learning_rate_schedule_conf.proto";


// Parameter-free marker message: choosing
// NormalModelUpdateOpUserConf.naive_conf selects this update rule, and there
// is nothing further to configure.
message NaiveModelUpdateConf {}

// Hyper-parameters for the momentum update rule; selected through
// NormalModelUpdateOpUserConf.momentum_conf.
message MomentumModelUpdateConf {
  // Momentum coefficient. Defaults to 0.9.
  optional float beta = 1 [default = 0.9];
}

// Hyper-parameters for the RMSProp update rule; selected through
// NormalModelUpdateOpUserConf.rmsprop_conf.
message RMSPropModelUpdateConf {
  // Decay rate. Defaults to 0.99.
  optional float decay_rate = 1 [default = 0.99];
  // Epsilon term. Defaults to 1e-8.
  // NOTE(review): presumably added for numerical stability - confirm.
  optional float epsilon = 2 [default = 1e-8];
  // Toggles the "centered" variant. Defaults to false.
  optional bool centered = 3 [default = false];
}

// Hyper-parameters for the LARS update rule; selected through
// NormalModelUpdateOpUserConf.lars_conf.
message LARSModelUpdateConf {
  // Momentum coefficient. Defaults to 0.9.
  optional float momentum_beta = 1 [default = 0.9];
  // Epsilon term. Defaults to 1e-9.
  optional float epsilon = 2 [default = 1e-9];
  // LARS coefficient. Defaults to 0.0001.
  optional float lars_coefficient = 3 [default = 0.0001];
}

// Hyper-parameters for the Adam update rule; selected through
// NormalModelUpdateOpUserConf.adam_conf.
message AdamModelUpdateConf {
  // First beta coefficient. Defaults to 0.9.
  optional float beta1 = 1 [default = 0.9];
  // Second beta coefficient. Defaults to 0.999.
  optional float beta2 = 2 [default = 0.999];
  // Epsilon term. Defaults to 1e-8.
  optional float epsilon = 3 [default = 1e-8];
  // Enables bias correction. Off (false) by default.
  optional bool do_bias_correction = 4 [default = false];
}

// Hyper-parameters for the LazyAdam update rule; selected through
// NormalModelUpdateOpUserConf.lazy_adam_conf. Carries the same first three
// fields and defaults as AdamModelUpdateConf, without a bias-correction flag.
message LazyAdamModelUpdateConf {
  // First beta coefficient. Defaults to 0.9.
  optional float beta1 = 1 [default = 0.9];
  // Second beta coefficient. Defaults to 0.999.
  optional float beta2 = 2 [default = 0.999];
  // Epsilon term. Defaults to 1e-8.
  optional float epsilon = 3 [default = 1e-8];
}

// Hyper-parameters for the LAMB update rule; selected through
// NormalModelUpdateOpUserConf.lamb_conf. Unlike the Adam-style messages in
// this file, every field is required and no defaults are declared, so the
// caller has to supply all three.
message LambModelUpdateConf {
  // First beta coefficient.
  required float beta1 = 1;
  // Second beta coefficient.
  required float beta2 = 2;
  // Epsilon term.
  required float epsilon = 3;
}

// Parameters for clipping by global norm; selected through
// ClipConf.clip_by_global_norm.
message ClipByGlobalNormConf {
  // Norm threshold. Required.
  required float clip_norm = 1;
  // Optional global-norm value; no default is declared.
  // NOTE(review): presumably a precomputed norm that is otherwise derived
  // at runtime - confirm the behavior when unset.
  optional float global_norm = 2;
}

// Tagged union of clipping strategies. Only one strategy is declared so far.
message ClipConf {
  oneof type {
    // Clip by global norm.
    ClipByGlobalNormConf clip_by_global_norm = 1;
  }
}

// A list of pattern strings, used by WeightDecayConf for both its `includes`
// and `excludes` alternatives.
message WeightDecayFilterPatternSet {
  // Zero or more patterns.
  // NOTE(review): the pattern syntax (substring, glob, regex) and what the
  // patterns are matched against are not defined here - confirm.
  repeated string pattern = 1;
}

// Weight-decay settings: a mandatory rate plus an optional filter that is
// either an include list or an exclude list, never both.
message WeightDecayConf {
  // Decay rate. Required.
  required float weight_decay_rate = 1;

  // At most one of the two filters is set.
  // NOTE(review): presumably decay applies everywhere when neither is set -
  // confirm.
  oneof weight_decay_filter_type {
    // Patterns naming what to include.
    WeightDecayFilterPatternSet includes = 2;
    // Patterns naming what to exclude.
    WeightDecayFilterPatternSet excludes = 3;
  }
}

// User-facing model-update configuration: four optional cross-cutting
// settings (fields 1-4) plus a oneof that picks the update rule
// (fields 1000-1006).
message NormalModelUpdateOpUserConf {
  // Learning-rate decay schedule; LearningRateDecayConf is declared in
  // learning_rate_schedule_conf.proto.
  optional LearningRateDecayConf learning_rate_decay = 1;
  // Warmup schedule; WarmupConf is declared in
  // learning_rate_schedule_conf.proto.
  optional WarmupConf warmup_conf = 2;
  // Clipping settings.
  optional ClipConf clip_conf = 3;
  // Weight-decay settings.
  optional WeightDecayConf weight_decay_conf = 4;

  // The update rule. At most one member is set at a time.
  oneof normal_mdupdt {
    NaiveModelUpdateConf naive_conf = 1000;
    MomentumModelUpdateConf momentum_conf = 1001;
    RMSPropModelUpdateConf rmsprop_conf = 1002;
    LARSModelUpdateConf lars_conf = 1003;
    AdamModelUpdateConf adam_conf = 1004;
    LazyAdamModelUpdateConf lazy_adam_conf = 1005;
    LambModelUpdateConf lamb_conf = 1006;
  }
}


message DynamicLossScalePolicy {
optional float initial_loss_scale = 1 [default = 32768.0];
Expand Down
79 changes: 79 additions & 0 deletions oneflow/core/job/learning_rate_schedule_conf.proto
@@ -0,0 +1,79 @@
syntax = "proto2";
package oneflow;

// Parameters for the exponential learning-rate decay schedule; selected
// through LearningRateDecayConf.exponential_conf.
// NOTE(review): the decay formula itself lives in the implementation, not in
// this schema - confirm there.
message ExponentialDecayConf {
  // Decay period, counted in batches. Required.
  required int64 decay_batches = 1;
  // Decay rate. Required.
  required double decay_rate = 2;
  // Staircase flag. Defaults to false.
  // NOTE(review): presumably switches to stepwise decay - confirm.
  optional bool staircase = 3 [default = false];
}

// Parameters for the inverse-time learning-rate decay schedule; selected
// through LearningRateDecayConf.inverse_time_conf.
// NOTE(review): the decay formula itself lives in the implementation, not in
// this schema - confirm there.
message InverseTimeDecayConf {
  // Decay period, counted in batches. Required.
  required int64 decay_batches = 1;
  // Decay rate. Required.
  required double decay_rate = 2;
  // Staircase flag. Defaults to false.
  // NOTE(review): presumably switches to stepwise decay - confirm.
  optional bool staircase = 3 [default = false];
}

// Parameters for the natural-exponential learning-rate decay schedule;
// selected through LearningRateDecayConf.natural_exp_conf.
// NOTE(review): the decay formula itself lives in the implementation, not in
// this schema - confirm there.
message NaturalExpDecayConf {
  // Decay period, counted in batches. Required.
  required int64 decay_batches = 1;
  // Decay rate. Required.
  required double decay_rate = 2;
  // Staircase flag. Defaults to false.
  // NOTE(review): presumably switches to stepwise decay - confirm.
  optional bool staircase = 3 [default = false];
}

// Parameters for the piecewise-constant learning-rate schedule; selected
// through LearningRateDecayConf.piecewise_constant_conf.
// NOTE(review): `boundaries` and `values` are parallel lists, but the
// required length relationship between them is not expressed in the schema -
// confirm what the consumer validates.
message PiecewiseConstantConf {
  // Boundary positions (int64).
  // NOTE(review): presumably batch indices - confirm.
  repeated int64 boundaries = 1;
  // Values associated with the intervals delimited by `boundaries`.
  repeated double values = 2;
}

// Parameters for the polynomial learning-rate decay schedule; selected
// through LearningRateDecayConf.polynomial_conf.
message PolynomialDecayConf {
  // Decay period, counted in batches. Required.
  required int64 decay_batches = 1;
  // Final learning rate. Defaults to 0.0001.
  optional double end_learning_rate = 2 [default = 0.0001];
  // Polynomial power. Defaults to 1.0.
  optional double power = 3 [default = 1.0];
  // Cycle flag. Defaults to false.
  // NOTE(review): presumably restarts the schedule after decay_batches -
  // confirm.
  optional bool cycle = 4 [default = false];
}

// Parameters for the cosine learning-rate decay schedule; selected through
// LearningRateDecayConf.cosine_conf.
message CosineDecayConf {
  // Decay period, counted in batches. Required.
  required int64 decay_batches = 1;
  // Alpha parameter. Defaults to 0.0.
  // NOTE(review): presumably the floor as a fraction of the base rate -
  // confirm.
  optional double alpha = 2 [default = 0.0];
}

// Parameters for the linear-cosine learning-rate decay schedule; selected
// through LearningRateDecayConf.linear_cosine_conf.
// NOTE(review): how alpha and beta enter the formula is not defined in this
// schema - confirm in the implementation.
message LinearCosineDecayConf {
  // Decay period, counted in batches. Required.
  required int64 decay_batches = 1;
  // Number of periods. Defaults to 0.5.
  optional double num_periods = 2 [default = 0.5];
  // Alpha parameter. Defaults to 0.0.
  optional double alpha = 3 [default = 0.0];
  // Beta parameter. Defaults to 0.001.
  optional double beta = 4 [default = 0.001];
}

// Parameters for the piecewise-scaling learning-rate schedule; selected
// through LearningRateDecayConf.piecewise_scaling_conf.
// NOTE(review): `boundaries` and `scales` are parallel lists, but the
// required length relationship between them is not expressed in the schema -
// confirm what the consumer validates.
message PiecewiseScalingConf {
  // Boundary positions (int64).
  // NOTE(review): presumably batch indices - confirm.
  repeated int64 boundaries = 1;
  // Scale factors associated with the intervals delimited by `boundaries`.
  repeated double scales = 2;
}

// Tagged union over the learning-rate decay schedules declared in this file.
// At most one member of `type` is set at a time. Members are numbered from
// 2000 upward.
message LearningRateDecayConf {
  oneof type {
    // Schedules that share the decay_batches / decay_rate / staircase shape.
    ExponentialDecayConf exponential_conf = 2000;
    InverseTimeDecayConf inverse_time_conf = 2001;
    NaturalExpDecayConf natural_exp_conf = 2002;

    PiecewiseConstantConf piecewise_constant_conf = 2003;
    PolynomialDecayConf polynomial_conf = 2004;
    CosineDecayConf cosine_conf = 2005;
    LinearCosineDecayConf linear_cosine_conf = 2006;
    PiecewiseScalingConf piecewise_scaling_conf = 2007;
  }
}

// Parameters for constant warmup; selected through WarmupConf.constant_conf.
// Both fields are required.
message ConstantWarmupConf {
  // Length of the warmup phase, counted in batches.
  required int64 warmup_batches = 1;
  // Multiplier used during warmup.
  // NOTE(review): presumably applied to the base learning rate - confirm.
  required double multiplier = 2;
}

// Parameters for linear warmup; selected through WarmupConf.linear_conf.
// Both fields are required.
message LinearWarmupConf {
  // Length of the warmup phase, counted in batches.
  required int64 warmup_batches = 1;
  // Multiplier at the start of warmup.
  // NOTE(review): presumably ramps linearly to 1 over warmup_batches -
  // confirm.
  required double start_multiplier = 2;
}

// Tagged union over the warmup strategies declared in this file. At most one
// member of `type` is set at a time. Members are numbered from 3000 upward.
message WarmupConf {
  oneof type {
    // Constant-multiplier warmup.
    ConstantWarmupConf constant_conf = 3000;
    // Linear warmup.
    LinearWarmupConf linear_conf = 3001;
  }
}
5 changes: 4 additions & 1 deletion oneflow/core/job/placement.proto
@@ -1,7 +1,6 @@
syntax = "proto2";
package oneflow;

import "oneflow/core/operator/op_conf.proto";
import "oneflow/core/register/logical_blob_id.proto";

message ParallelContext {
Expand All @@ -14,6 +13,10 @@ message ParallelConf {
required string device_tag = 2;
}

// A list of operator names; used as the type of PlacementGroup.op_set.
message OpNameSet {
  // Zero or more operator names.
  repeated string op_name = 1;
}

message PlacementGroup {
required OpNameSet op_set = 1;
required ParallelConf parallel_conf = 2;
Expand Down
13 changes: 13 additions & 0 deletions oneflow/core/job/regularizer_conf.proto
@@ -0,0 +1,13 @@
syntax = "proto2";
package oneflow;

// Coefficients for combined L1/L2 regularization; selected through
// RegularizerConf.l1_l2_conf.
message L1L2RegularizerConf {
  // L1 coefficient. Defaults to 0.0.
  optional float l1 = 1 [default = 0.0];
  // L2 coefficient. Defaults to 0.0.
  optional float l2 = 2 [default = 0.0];
}

// Tagged union of regularizer configurations. Only one regularizer is
// declared so far.
message RegularizerConf {
  oneof type {
    // Combined L1/L2 regularization.
    L1L2RegularizerConf l1_l2_conf = 1;
  }
}