Fix foward forward #60972

Merged 1 commit on Jan 19, 2024
2 changes: 1 addition & 1 deletion paddle/fluid/jit/compilation_unit.cc
@@ -27,7 +27,7 @@ std::shared_ptr<BaseEngine> CompilationUnit::GetEngine(
       engine_map_.count(name),
       1,
       phi::errors::InvalidArgument(
-          "Funciton named %s is not existed in engine_map_.", name));
+          "Function named %s is not existed in engine_map_.", name));
   return engine_map_.at(name);
 }

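For context on the code this typo fix touches: GetEngine verifies that the requested name is present in engine_map_ before the lookup, and reports an InvalidArgument error otherwise. A minimal self-contained sketch of the same enforce-then-lookup pattern in plain C++ (illustrative Engine type and plain exceptions, not Paddle's actual PADDLE_ENFORCE_EQ machinery):

#include <map>
#include <memory>
#include <stdexcept>
#include <string>

class Engine {};  // illustrative stand-in for jit::BaseEngine

// Check that the key exists exactly once, then return the mapped engine;
// Paddle expresses this same guard with PADDLE_ENFORCE_EQ + phi::errors.
const std::shared_ptr<Engine>& GetEngine(
    const std::map<std::string, std::shared_ptr<Engine>>& engine_map,
    const std::string& name) {
  if (engine_map.count(name) != 1) {
    throw std::invalid_argument("Function named " + name +
                                " does not exist in engine_map.");
  }
  return engine_map.at(name);
}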
16 changes: 8 additions & 8 deletions paddle/fluid/jit/layer.cc
@@ -65,7 +65,7 @@ const std::shared_ptr<jit::FunctionInfo>& Layer::FunctionInfo(
       info_map_.count(name),
       1,
       phi::errors::InvalidArgument(
-          "FuncitonInfo named %s is not existed in info_map_.", name));
+          "FunctionInfo named %s is not existed in info_map_.", name));
   return info_map_.at(name);
 }

@@ -77,7 +77,7 @@ std::vector<std::string> Layer::FunctionNames() const {
   return names;
 }

-#define PD_SPECIALZE_ATTRIBUTE_TYPE(T)                       \
+#define PD_SPECIALIZE_ATTRIBUTE_TYPE(T)                      \
   template <>                                                \
   T Layer::Attribute<T>(const std::string& name) const {     \
     if (attrs_map_->find(name) == attrs_map_->end()) {       \
@@ -90,12 +90,12 @@ std::vector<std::string> Layer::FunctionNames() const {
     return ret;                                              \
   }

-PD_SPECIALZE_ATTRIBUTE_TYPE(int)
-PD_SPECIALZE_ATTRIBUTE_TYPE(float)
-PD_SPECIALZE_ATTRIBUTE_TYPE(framework::String)
-PD_SPECIALZE_ATTRIBUTE_TYPE(std::vector<int>)
-PD_SPECIALZE_ATTRIBUTE_TYPE(std::vector<float>)
-PD_SPECIALZE_ATTRIBUTE_TYPE(std::vector<std::string>)
+PD_SPECIALIZE_ATTRIBUTE_TYPE(int)
+PD_SPECIALIZE_ATTRIBUTE_TYPE(float)
+PD_SPECIALIZE_ATTRIBUTE_TYPE(framework::String)
+PD_SPECIALIZE_ATTRIBUTE_TYPE(std::vector<int>)
+PD_SPECIALIZE_ATTRIBUTE_TYPE(std::vector<float>)
+PD_SPECIALIZE_ATTRIBUTE_TYPE(std::vector<std::string>)

 std::shared_ptr<Layer> Layer::Clone(void* stream) {
   std::shared_ptr<Layer> x =
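The macro renamed above follows an X-macro style pattern: Layer::Attribute<T> is declared once as a template, and the macro stamps out one explicit specialization per supported attribute type. A minimal sketch of the pattern, assuming a free function, a simplified body, and illustrative default values (the real specializations read attrs_map_ and enforce that the key exists):

#include <iostream>
#include <string>

// Declare the generic accessor once...
template <typename T>
T Attribute(const std::string& name);

// ...and stamp out explicit specializations with a macro, mirroring how
// PD_SPECIALIZE_ATTRIBUTE_TYPE expands in layer.cc (body simplified here).
#define SPECIALIZE_ATTRIBUTE_TYPE(T, default_value)    \
  template <>                                          \
  T Attribute<T>(const std::string& name) {            \
    std::cout << "reading attribute " << name << "\n"; \
    return default_value;                              \
  }

SPECIALIZE_ATTRIBUTE_TYPE(int, 0)
SPECIALIZE_ATTRIBUTE_TYPE(float, 0.0f)

int main() {
  return Attribute<int>("num_layers");  // resolves to the int specialization
}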
4 changes: 2 additions & 2 deletions paddle/fluid/operators/generator/parse_utils.py
@@ -627,11 +627,11 @@ def validate_backward_attrs(op, forward_attrs, backward_attrs):
 def validate_backward_inputs(
     op, forward_inputs, forward_outputs, backward_inputs
 ):
-    foward_input_names = [item["name"] for item in forward_inputs]
+    forward_input_names = [item["name"] for item in forward_inputs]
     forward_output_names = [item["name"] for item in forward_outputs]
     backward_input_names = [item["name"] for item in backward_inputs]

-    assert len(backward_input_names) <= len(foward_input_names) + 2 * len(
+    assert len(backward_input_names) <= len(forward_input_names) + 2 * len(
         forward_output_names
     ), f"{op } has too many inputs."

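Beyond the spelling fix, the assertion encodes an invariant worth spelling out. One plausible reading of the bound (my inference from the expression, not something the PR states): a backward op can consume each forward input, each forward output, and the gradient of each forward output, so its input count is at most n_in + 2 * n_out. A rough restatement of that check, paraphrased in C++ rather than the generator's Python:

#include <cassert>
#include <cstddef>

// The grad op may consume every forward input, every forward output, and
// every forward output's gradient, hence the n_in + 2 * n_out ceiling.
void ValidateBackwardInputs(std::size_t num_forward_inputs,
                            std::size_t num_forward_outputs,
                            std::size_t num_backward_inputs) {
  assert(num_backward_inputs <=
             num_forward_inputs + 2 * num_forward_outputs &&
         "op has too many inputs");
}

int main() {
  ValidateBackwardInputs(2, 1, 4);  // 4 <= 2 + 2 * 1, so the check passes
  return 0;
}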
16 changes: 8 additions & 8 deletions paddle/fluid/pir/dialect/op_generator/op_gen.py
@@ -95,7 +95,7 @@
 extern std::set<std::string> onednn_only_op_set;
 """

-GET_OP_LIST_TEMPALTE = """{}
+GET_OP_LIST_TEMPLATE = """{}
 """

 DECLARE_OP_TYPE_ID = """
@@ -338,7 +338,7 @@ class {TEST_API} {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{


 def to_phi_and_fluid_op_name(op_item):
-    # Templat: - op : phi_name (fluid_name)
+    # Template: - op : phi_name (fluid_name)
     names = op_item.split('(')
     if len(names) == 1:
         phi_fluid_name = names[0].strip()
@@ -350,7 +350,7 @@ def to_phi_and_fluid_op_name(op_item):


 def to_phi_and_fluid_grad_op_name(op_item):
-    # Templat: sum_grad (reduce_sum_grad), sum_double_grad
+    # Template: sum_grad (reduce_sum_grad), sum_double_grad
     rtn = []
     all_names = op_item.split(', ')
     for name in all_names:
@@ -1499,7 +1499,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
         # =================================== #
         #       gen GetOpInfo func str        #
         # =================================== #
-        # generate get op info funciton: inputs
+        # generate get op info function: inputs
         input_info_list = []
         for idx in range(len(op_input_name_list)):
             input_info_list.append(
@@ -1529,7 +1529,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
             inputs_info_str = ", ".join(input_info_list)
         else:
             inputs_info_str = ""
-        # generate get op info funciton: outputs
+        # generate get op info function: outputs
         outputs_info_str = ""
         if len(op_output_name_list) > 0:
             output_info_list = []
@@ -1543,7 +1543,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
                 )
             )
             outputs_info_str = ", ".join(output_info_list)
-        # generate get op info funciton: attributes
+        # generate get op info function: attributes
         attribute_info_str = ""
         if len(op_non_mutable_attribute_name_list) > 0:
             attribute_info_list = []
@@ -1866,7 +1866,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
     ops_name_with_namespace_list = []
     for name in ops_name_list:
         ops_name_with_namespace_list.append(op_namespaces_prev + name)
-    op_list_str = GET_OP_LIST_TEMPALTE.format(
+    op_list_str = GET_OP_LIST_TEMPLATE.format(
         ", ".join(ops_name_with_namespace_list)
     )
@@ -2154,7 +2154,7 @@ def OpGenerator(

     # (6) write to files for xx_vjp_op.cc.tmp
     # NOTE(Aurelius84): op_gen.py is called multiply times,
-    # and vjp is only avaible for pd dialect.
+    # and vjp is only available for pd dialect.
     vjp_source_file_str = "\n".join(vjp_source_file_strs)
     vjp_source_file_str = VJP_CC_FILE_TEMPLATE.format(input=vjp_source_file_str)
     if (
4 changes: 2 additions & 2 deletions paddle/fluid/pir/dialect/op_generator/op_infermeta_gen.py
@@ -98,7 +98,7 @@ def get_infermeta_inputs_str(
     infermeta_inputs_str += "\n"

     infermeta_inputs_str += '  VLOG(4) << "Builder construction outputs";\n'
-    # Prepar input type
+    # Prepare input type
     for idx in range(len(op_input_name_list)):
         if op_input_name_list[idx] not in inuse_infer_meta_args:
             continue
@@ -367,7 +367,7 @@ def GenBuildOutputsPart2(
     elif attr_dtype[0] == "pir::StrAttribute":
         build_output_str += ""
     else:
-        assert "mutable attribtue type is not right."
+        assert "mutable attribute type is not right."
     build_output_str += "\n"

     # Prepare inputs_meta_tensor & attributes for infer meta
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/op_generator/op_interface_gen.py
@@ -96,7 +96,7 @@
   VLOG(6) << "Vjp prepare Prepare attributes of {op_grad_name}";
   {attribute_code}

-  VLOG(6) << "Vjp prepare call {op_phi_name}'s vjp inteface";
+  VLOG(6) << "Vjp prepare call {op_phi_name}'s vjp interface";
   {call_vjp_code}

   VLOG(6) << "Vjp prepare stop gradient of {op_grad_name}";
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
@@ -469,7 +469,7 @@ std::vector<std::vector<pir::OpResult>> WhileOp::Vjp(
       PADDLE_ENFORCE_EQ(push_op.container().use_empty(),
                         true,
                         phi::errors::InvalidArgument(
-                            "The last container in foward while op must used "
+                            "The last container in forward while op must used "
                             "empty while construct while_grad op"));
       break;
     }
10 changes: 5 additions & 5 deletions paddle/phi/api/yaml/generator/api_gen.py
@@ -494,12 +494,12 @@ def generate_api(
     source_file.write(namespace[0])

     for api in apis:
-        foward_api = ForwardAPI(api)
-        if foward_api.is_dygraph_api:
-            foward_api.is_dygraph_api = False
+        forward_api = ForwardAPI(api)
+        if forward_api.is_dygraph_api:
+            forward_api.is_dygraph_api = False

-        header_file.write(foward_api.gene_api_declaration())
-        source_file.write(foward_api.gene_api_code())
+        header_file.write(forward_api.gene_api_declaration())
+        source_file.write(forward_api.gene_api_code())

     header_file.write(namespace[1])
     source_file.write(namespace[1])
12 changes: 6 additions & 6 deletions paddle/phi/api/yaml/generator/dist_api_gen.py
@@ -1889,15 +1889,15 @@ def generate_api(
     source_file.write(namespace[0])

     for api in apis:
-        dist_foward_api = DistForwardAPI(api)
-        if dist_foward_api.is_dygraph_api:
-            dist_foward_api.is_dygraph_api = False
+        dist_forward_api = DistForwardAPI(api)
+        if dist_forward_api.is_dygraph_api:
+            dist_forward_api.is_dygraph_api = False

-        header_file.write(dist_foward_api.gene_api_declaration())
+        header_file.write(dist_forward_api.gene_api_declaration())
         if is_fused_ops_yaml is True:
-            source_file.write(dist_foward_api.gene_api_code())
+            source_file.write(dist_forward_api.gene_api_code())
         else:
-            source_file.write(dist_foward_api.gene_api_code())
+            source_file.write(dist_forward_api.gene_api_code())

     header_file.write(namespace[1])
     source_file.write(namespace[1])
10 changes: 6 additions & 4 deletions paddle/phi/api/yaml/generator/intermediate_api_gen.py
@@ -121,10 +121,12 @@ def generate_intermediate_api(
     apis.extend(api_list)

     for api in apis:
-        foward_api = DistForwardAPI(api) if gen_dist_branch else ForwardAPI(api)
-        if foward_api.is_dygraph_api:
-            dygraph_header_file.write(foward_api.gene_api_declaration())
-            dygraph_source_file.write(foward_api.gene_api_code())
+        forward_api = (
+            DistForwardAPI(api) if gen_dist_branch else ForwardAPI(api)
+        )
+        if forward_api.is_dygraph_api:
+            dygraph_header_file.write(forward_api.gene_api_declaration())
+            dygraph_source_file.write(forward_api.gene_api_code())

     dygraph_header_file.write(sparse_namespace_pair[0])
     dygraph_source_file.write(sparse_namespace_pair[0])