Skip to content

Commit

Permalink
[Dy2St] Fix typo runable -> runnable (#62779)
Browse files Browse the repository at this point in the history
  • Loading branch information
SigureMo committed Mar 17, 2024
1 parent fc0d76a commit 9eb80b4
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 19 deletions.
34 changes: 18 additions & 16 deletions python/paddle/jit/dy2static/pir_partial_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def iter_elements(self):
yield from self.father.keys()


class RunableProgram:
class RunnableProgram:
"""a pir program ready for run_program_op to run. constructed by 3 parts:
- pir program (pir::Program)
- in_out_values
Expand Down Expand Up @@ -240,7 +240,7 @@ def clone(self):
cloned_program, _ = paddle.base.libpaddle.pir.clone_program(
self.program
)
return RunableProgram(
return RunnableProgram(
cloned_program,
(self.x_names, self.param_names, self.out_names),
None,
Expand Down Expand Up @@ -462,7 +462,7 @@ def __init__(

# program_id -> list(scope)
self._scope_cache = {}
self._hooker = []
self._hookers = []
self._backend = kwargs.get('backend', None)
self._grad_var_names = {}
self._debug_name = None
Expand Down Expand Up @@ -506,7 +506,7 @@ def sot_call(self, inputs):
return out_vars

@cached_property
def origin_runable_program(self):
def origin_runnable_program(self):
inputs = list(self._inputs.var_list)
outputs = list(self._outputs.var_list)
params = self._param_values
Expand All @@ -516,7 +516,7 @@ def origin_runable_program(self):
len(self._origin_main_program.global_block().ops),
"output_",
)
return RunableProgram(
return RunnableProgram(
self._origin_main_program, (inputs, params, outputs)
)

Expand All @@ -536,7 +536,7 @@ def _sync_lr_value_with_scheduler(self):
lr_var.set_value(data)

def add_hooker(self, hooker):
self._hooker.append(hooker)
self._hookers.append(hooker)

def _get_scope(self, program_id=None, use_scope_cache=False):
if not use_scope_cache:
Expand Down Expand Up @@ -571,13 +571,15 @@ def pass_fn(forward_program, backward_program):
return forward_program, backward_program

# TODO(xiongkun) who to transfer the pruning program?
infer_program = self.origin_runable_program.clone()
for hooker in self._hooker:
infer_program = self.origin_runnable_program.clone()
for hooker in self._hookers:
hooker.after_infer(infer_program)
infer_program.apply_pir_program_pass(pass_fn)
return infer_program
else:
train_program: RunableProgram = self.origin_runable_program.clone()
train_program: RunnableProgram = (
self.origin_runnable_program.clone()
)
train_program = self._append_backward_desc(train_program)
# Note: Only set grad type once after initializing train program. So we put it here.
self._set_grad_type(self._params, train_program)
Expand Down Expand Up @@ -722,11 +724,11 @@ def _insert_aggregation_ops_for_var(target_program, var):
_insert_aggregation_ops_for_var(target_program, _var)

@switch_to_static_graph
def _append_backward_desc(self, train_runnable_program: RunableProgram):
def _append_backward_desc(self, train_runnable_program: RunnableProgram):
program = train_runnable_program.program
targets = train_runnable_program.out_values
# TODO(@zhuoge): refine the interface, use runable_program to apply passes.
for hooker in self._hooker:
# TODO(@zhuoge): refine the interface, use runnable_program to apply passes.
for hooker in self._hookers:
program, targets = hooker.before_append_backward(program, targets)
inputs = train_runnable_program.x_values
params = train_runnable_program.param_values
Expand Down Expand Up @@ -793,7 +795,7 @@ def _append_backward_desc(self, train_runnable_program: RunableProgram):
forward_end_idx = idx + 1
break

for hooker in self._hooker:
for hooker in self._hookers:
(
program,
forward_end_idx,
Expand All @@ -817,7 +819,7 @@ def _append_backward_desc(self, train_runnable_program: RunableProgram):
p_grad_value = list(map(mapping_value, grad_info_map[inputs_size:]))
o_grad_value = list(map(mapping_value, forward_outputs_grads))

# insert grads name for RunableProgram (we need name for grad_inputs and grad_outputs)
# insert grads name for RunnableProgram (we need name for grad_inputs and grad_outputs)
input_grads_to_append = list(
filter(lambda x: not is_fake_value(x), o_grad_value)
)
Expand All @@ -836,7 +838,7 @@ def _append_backward_desc(self, train_runnable_program: RunableProgram):
forward_end_idx + op_between_forward_and_backward
)
# construct a runnable program.
return RunableProgram(
return RunnableProgram(
program,
(inputs, params, targets),
(x_grad_value, p_grad_value, o_grad_value),
Expand Down Expand Up @@ -1005,7 +1007,7 @@ def _remove_no_value(self, out_vars):

return out_vars

def _set_grad_type(self, params, train_program: RunableProgram):
def _set_grad_type(self, params, train_program: RunnableProgram):
# NOTE: if user set sparse gradient mode, the param's gradient
# will be SelectedRows, not LoDTensor. But tracer will just
# set param grad Tensor by forward Tensor(LoDTensor)
Expand Down
4 changes: 2 additions & 2 deletions test/legacy_test/test_network_with_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def run_net_on_place(self, place):
exe.run(startup)
for data in train_reader():
exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)
# the main program is runable, the datatype is fully supported
# the main program is runnable, the datatype is fully supported
break

def init_dtype(self):
Expand All @@ -66,7 +66,7 @@ def test_gpu(self):
self.run_net_on_place(place)


# TODO(dzhwinter): make sure the fp16 is runable
# TODO(dzhwinter): make sure the fp16 is runnable
# class TestFloat16(TestNetWithDtype):
# def init_dtype(self):
# self.dtype = "float16"
Expand Down
2 changes: 1 addition & 1 deletion test/prim/pir_prim/test_pir_prim_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def train(self):
def check_prim(self, net):
program = net.forward.program_cache.last()[-1][-1].train_program
if isinstance(
program, paddle.jit.dy2static.pir_partial_program.RunableProgram
program, paddle.jit.dy2static.pir_partial_program.RunnableProgram
):
program = program.program
block = program.global_block()
Expand Down

0 comments on commit 9eb80b4

Please sign in to comment.