From 26a88e13feb6b897414d5f9633ee249235bf564c Mon Sep 17 00:00:00 2001
From: Nyakku Shigure
Date: Fri, 13 Oct 2023 21:57:33 +0800
Subject: [PATCH] [SOT] merge PaddleSOT into Paddle (#57824)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

PaddleSOT is a bytecode-level implementation of a Symbolic OpCode Translator
for PaddlePaddle. We originally developed it in
[PaddleSOT](https://github.com/PaddlePaddle/PaddleSOT); to keep it consistent
across Paddle versions, we are now merging PaddleSOT into Paddle.

Thanks to all the contributors of this project! See more details in
https://github.com/PaddlePaddle/PaddleSOT/graphs/contributors

---------

Co-authored-by: xiongkun
Co-authored-by: feifei-111 <2364819892@qq.com>
Co-authored-by: 0x45f <23097963+0x45f@users.noreply.github.com>
Co-authored-by: gouzil <66515297+gouzil@users.noreply.github.com>
Co-authored-by: 六个骨头 <46243324+zrr1999@users.noreply.github.com>
Co-authored-by: Aurelius84
Co-authored-by: Wang Xin
Co-authored-by: haozi <64006169+NotHaozi@users.noreply.github.com>
Co-authored-by: RedContritio
Co-authored-by: Sanbu <96160062+sanbuphy@users.noreply.github.com>
Co-authored-by: Difer
Co-authored-by: cyberslack_lee
Co-authored-by: jjyaoao
Co-authored-by: PuQing
Co-authored-by: Ran chongzhi <57489288+ranchongzhi@users.noreply.github.com>
Co-authored-by: Zhenghai Zhang <65210872+ccsuzzh@users.noreply.github.com>
---
 .flake8 | 3 +
 paddle/scripts/paddle_build.sh | 12 +-
 .../jit/dy2static/program_translator.py | 14 +-
 python/paddle/jit/sot/__init__.py | 22 +
 python/paddle/jit/sot/infer_meta.py | 282 +++
 .../jit/sot/opcode_translator/__init__.py | 15 +
 .../jit/sot/opcode_translator/breakpoint.py | 179 ++
 .../jit/sot/opcode_translator/custom_code.py | 23 +
 .../opcode_translator/executor/__init__.py | 15 +
 .../executor/dispatch_functions.py | 54 +
 .../opcode_translator/executor/dispatcher.py | 294 +++
 .../executor/executor_cache.py | 230 ++
 .../executor/function_graph.py | 680 ++++++
 .../sot/opcode_translator/executor/guard.py | 183 ++
 .../opcode_translator/executor/instr_flag.py | 36 +
 .../executor/mutable_data.py | 289 +++
 .../executor/opcode_executor.py | 2073 +++++++++++++++++
 .../executor/opcode_inline_executor.py | 330 +++
 .../executor/pycode_generator.py | 1058 +++++++++
 .../executor/side_effects.py | 234 ++
 .../sot/opcode_translator/executor/tracker.py | 387 +++
 .../executor/tracker_viewer.py | 115 +
 .../executor/variable_dispatch.py | 1109 +++++++++
 .../executor/variable_stack.py | 216 ++
 .../executor/variables/__init__.py | 63 +
 .../executor/variables/base.py | 618 +++++
 .../executor/variables/basic.py | 888 +++++++
 .../executor/variables/callable.py | 759 ++++++
 .../executor/variables/container.py | 1011 ++++++++
 .../executor/variables/iter.py | 203 ++
 .../instruction_utils/__init__.py | 34 +
 .../instruction_utils/instruction_utils.py | 407 ++++
 .../instruction_utils/opcode_analysis.py | 217 ++
 .../instruction_utils/opcode_info.py | 58 +
 .../jit/sot/opcode_translator/skip_files.py | 177 ++
 .../jit/sot/opcode_translator/transform.py | 112 +
 python/paddle/jit/sot/profiler.py | 78 +
 python/paddle/jit/sot/psdb.py | 68 +
 .../paddle/jit/sot/symbolic/compile_cache.py | 143 ++
 python/paddle/jit/sot/symbolic/interpreter.py | 194 ++
 .../paddle/jit/sot/symbolic/statement_ir.py | 338 +++
 .../jit/sot/symbolic/symbolic_context.py | 161 ++
 python/paddle/jit/sot/translate.py | 125 +
 python/paddle/jit/sot/utils/__init__.py | 62 +
 python/paddle/jit/sot/utils/code_status.py | 90 +
python/paddle/jit/sot/utils/exceptions.py | 64 + python/paddle/jit/sot/utils/magic_methods.py | 130 ++ .../paddle/jit/sot/utils/paddle_api_config.py | 102 + python/paddle/jit/sot/utils/utils.py | 730 ++++++ python/paddle/tensor/creation.py | 3 + python/setup.py.in | 7 + setup.py | 7 + test/dygraph_to_static/CMakeLists.txt | 39 +- .../dygraph_to_static_util.py | 3 +- .../dygraph_to_static_utils_new.py | 320 +++ test/dygraph_to_static/test_assert.py | 12 +- test/dygraph_to_static/test_ast_util.py | 13 +- .../test_backward_without_params.py | 22 +- .../test_basic_api_transformation.py | 9 +- test/dygraph_to_static/test_bert.py | 7 +- test/dygraph_to_static/test_break_continue.py | 1 + test/dygraph_to_static/test_build_strategy.py | 1 + test/dygraph_to_static/test_cache_program.py | 3 + test/dygraph_to_static/test_cast.py | 14 +- test/dygraph_to_static/test_cinn.py | 6 +- test/dygraph_to_static/test_cinn_prim.py | 1 + .../test_cinn_prim_layer_norm.py | 4 +- .../test_closure_analysis.py | 16 +- test/dygraph_to_static/test_convert_call.py | 2 + .../test_convert_call_generator.py | 7 +- .../test_convert_operators.py | 10 +- .../test_cpu_cuda_to_tensor.py | 2 + test/dygraph_to_static/test_cycle_gan.py | 16 +- test/dygraph_to_static/test_declarative.py | 37 +- .../test_decorator_transform.py | 7 +- test/dygraph_to_static/test_deepcopy.py | 12 +- test/dygraph_to_static/test_dict.py | 8 +- test/dygraph_to_static/test_drop_path.py | 6 +- .../test_duplicate_output.py | 6 +- test/dygraph_to_static/test_error.py | 16 +- test/dygraph_to_static/test_fallback.py | 3 +- test/dygraph_to_static/test_fetch_feed.py | 6 +- test/dygraph_to_static/test_for_enumerate.py | 3 + .../dygraph_to_static/test_full_name_usage.py | 3 +- test/dygraph_to_static/test_grad.py | 1 + .../test_gradient_aggregation.py | 6 +- test/dygraph_to_static/test_grid_generator.py | 7 +- test/dygraph_to_static/test_ifelse.py | 5 + test/dygraph_to_static/test_isinstance.py | 6 +- .../test_jit_property_save.py | 3 + test/dygraph_to_static/test_jit_setitem.py | 2 + test/dygraph_to_static/test_lac.py | 3 + test/dygraph_to_static/test_lambda.py | 2 + test/dygraph_to_static/test_layer_hook.py | 6 +- test/dygraph_to_static/test_len.py | 3 + test/dygraph_to_static/test_list.py | 3 + .../test_load_transformer.py | 14 +- test/dygraph_to_static/test_logical.py | 3 + test/dygraph_to_static/test_loop.py | 5 + test/dygraph_to_static/test_mnist.py | 7 +- test/dygraph_to_static/test_mobile_net.py | 3 +- test/dygraph_to_static/test_multi_forward.py | 6 +- .../test_new_ir_selectedrows.py | 8 +- test/dygraph_to_static/test_op_attr.py | 3 +- test/dygraph_to_static/test_origin_info.py | 3 + test/dygraph_to_static/test_param_guard.py | 7 +- test/dygraph_to_static/test_params_no_grad.py | 3 + .../dygraph_to_static/test_partial_program.py | 23 +- .../test_partial_program_hook.py | 4 + test/dygraph_to_static/test_place.py | 3 + test/dygraph_to_static/test_print.py | 6 +- .../test_program_translator.py | 7 +- test/dygraph_to_static/test_ptb_lm.py | 6 +- test/dygraph_to_static/test_ptb_lm_v2.py | 2 + test/dygraph_to_static/test_pylayer.py | 4 + .../test_reinforcement_learning.py | 6 +- test/dygraph_to_static/test_resnet.py | 3 +- test/dygraph_to_static/test_resnet_amp.py | 6 +- .../test_resnet_pure_fp16.py | 6 +- test/dygraph_to_static/test_resnet_v2.py | 3 +- test/dygraph_to_static/test_return.py | 3 +- test/dygraph_to_static/test_rollback.py | 1 + .../test_save_inference_model.py | 8 +- test/dygraph_to_static/test_save_load.py | 7 +- 
test/dygraph_to_static/test_se_resnet.py | 3 +- test/dygraph_to_static/test_sentiment.py | 6 +- test/dygraph_to_static/test_seq2seq.py | 2 + test/dygraph_to_static/test_simnet.py | 8 +- test/dygraph_to_static/test_simnet_v2.py | 8 +- test/dygraph_to_static/test_slice.py | 7 +- test/dygraph_to_static/test_spec_names.py | 11 +- test/dygraph_to_static/test_tensor_hook.py | 2 + test/dygraph_to_static/test_tensor_methods.py | 11 +- test/dygraph_to_static/test_tensor_shape.py | 11 +- test/dygraph_to_static/test_to_tensor.py | 12 + test/dygraph_to_static/test_transformer.py | 6 +- test/dygraph_to_static/test_tsm.py | 6 +- test/dygraph_to_static/test_typehint.py | 6 +- .../dygraph_to_static/test_unuseful_inputs.py | 6 +- test/dygraph_to_static/test_utils.py | 4 + .../test_variable_trans_func.py | 3 + test/dygraph_to_static/test_word2vec.py | 6 +- test/dygraph_to_static/test_yolov3.py | 6 +- test/sot/extract_errors.py | 30 + test/sot/test_01_basic.py | 55 + test/sot/test_02_store_inplace.py | 47 + test/sot/test_03_tuple.py | 91 + test/sot/test_04_list.py | 327 +++ test/sot/test_05_dict.py | 264 +++ test/sot/test_06_call_function.py | 153 ++ test/sot/test_07_unpack.py | 70 + test/sot/test_08_rot.py | 97 + test/sot/test_09_f_string.py | 41 + test/sot/test_10_build_unpack.py | 97 + test/sot/test_11_jumps.py | 118 + test/sot/test_12_for_loop.py | 298 +++ test/sot/test_13_make_function.py | 39 + test/sot/test_14_operators.py | 387 +++ test/sot/test_15_slice.py | 137 ++ test/sot/test_16_paddle_api.py | 60 + test/sot/test_17_paddle_layer.py | 94 + test/sot/test_18_tensor_method.py | 90 + test/sot/test_19_closure.py | 260 +++ test/sot/test_20_string.py | 83 + test/sot/test_21_global.py | 175 ++ test/sot/test_analysis_inputs.py | 249 ++ test/sot/test_break_graph.py | 157 ++ test/sot/test_builtin_dispatch.py | 329 +++ test/sot/test_call_object.py | 83 + test/sot/test_case_base.py | 158 ++ test/sot/test_code_status.py | 154 ++ test/sot/test_constant_graph.py | 54 + test/sot/test_cost_model.py | 114 + test/sot/test_delete_fast.py | 38 + test/sot/test_dup_top.py | 49 + test/sot/test_enumerate.py | 116 + test/sot/test_error_handling.py | 39 + test/sot/test_exception.py | 94 + test/sot/test_execution_base.py | 62 + test/sot/test_guard_outputs.py | 78 + test/sot/test_guard_user_defined_fn.py | 88 + test/sot/test_inplace_api.py | 147 ++ test/sot/test_instruction_translator_cache.py | 165 ++ test/sot/test_map.py | 124 + test/sot/test_multiple_args.py | 35 + test/sot/test_mutable_data.py | 354 +++ test/sot/test_numpy.py | 44 + test/sot/test_numpy_var_if.py | 53 + test/sot/test_output_restoration.py | 95 + test/sot/test_range.py | 92 + test/sot/test_resnet.py | 59 + test/sot/test_resnet50_backward.py | 107 + test/sot/test_segment_linear.py | 71 + test/sot/test_side_effects.py | 333 +++ test/sot/test_simulate_initialize.py | 51 + test/sot/test_sir_rollback.py | 88 + test/sot/test_stack.py | 56 + test/sot/test_str_format.py | 37 + test/sot/test_tensor_dtype_in_guard.py | 76 + test/sot/test_tensor_slice.py | 33 + test/sot/test_trace_list_arg.py | 63 + 201 files changed, 22410 insertions(+), 196 deletions(-) create mode 100644 python/paddle/jit/sot/__init__.py create mode 100644 python/paddle/jit/sot/infer_meta.py create mode 100644 python/paddle/jit/sot/opcode_translator/__init__.py create mode 100644 python/paddle/jit/sot/opcode_translator/breakpoint.py create mode 100644 python/paddle/jit/sot/opcode_translator/custom_code.py create mode 100644 python/paddle/jit/sot/opcode_translator/executor/__init__.py create mode 
100644 python/paddle/jit/sot/opcode_translator/executor/dispatch_functions.py create mode 100644 python/paddle/jit/sot/opcode_translator/executor/dispatcher.py create mode 100644 python/paddle/jit/sot/opcode_translator/executor/executor_cache.py create mode 100644 python/paddle/jit/sot/opcode_translator/executor/function_graph.py create mode 100644 python/paddle/jit/sot/opcode_translator/executor/guard.py create mode 100644 python/paddle/jit/sot/opcode_translator/executor/instr_flag.py create mode 100644 python/paddle/jit/sot/opcode_translator/executor/mutable_data.py create mode 100644 python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py create mode 100644 python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py create mode 100644 python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py create mode 100644 python/paddle/jit/sot/opcode_translator/executor/side_effects.py create mode 100644 python/paddle/jit/sot/opcode_translator/executor/tracker.py create mode 100644 python/paddle/jit/sot/opcode_translator/executor/tracker_viewer.py create mode 100644 python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py create mode 100644 python/paddle/jit/sot/opcode_translator/executor/variable_stack.py create mode 100644 python/paddle/jit/sot/opcode_translator/executor/variables/__init__.py create mode 100644 python/paddle/jit/sot/opcode_translator/executor/variables/base.py create mode 100644 python/paddle/jit/sot/opcode_translator/executor/variables/basic.py create mode 100644 python/paddle/jit/sot/opcode_translator/executor/variables/callable.py create mode 100644 python/paddle/jit/sot/opcode_translator/executor/variables/container.py create mode 100644 python/paddle/jit/sot/opcode_translator/executor/variables/iter.py create mode 100644 python/paddle/jit/sot/opcode_translator/instruction_utils/__init__.py create mode 100644 python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py create mode 100644 python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_analysis.py create mode 100644 python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_info.py create mode 100644 python/paddle/jit/sot/opcode_translator/skip_files.py create mode 100644 python/paddle/jit/sot/opcode_translator/transform.py create mode 100644 python/paddle/jit/sot/profiler.py create mode 100644 python/paddle/jit/sot/psdb.py create mode 100644 python/paddle/jit/sot/symbolic/compile_cache.py create mode 100644 python/paddle/jit/sot/symbolic/interpreter.py create mode 100644 python/paddle/jit/sot/symbolic/statement_ir.py create mode 100644 python/paddle/jit/sot/symbolic/symbolic_context.py create mode 100644 python/paddle/jit/sot/translate.py create mode 100644 python/paddle/jit/sot/utils/__init__.py create mode 100644 python/paddle/jit/sot/utils/code_status.py create mode 100644 python/paddle/jit/sot/utils/exceptions.py create mode 100644 python/paddle/jit/sot/utils/magic_methods.py create mode 100644 python/paddle/jit/sot/utils/paddle_api_config.py create mode 100644 python/paddle/jit/sot/utils/utils.py create mode 100644 test/dygraph_to_static/dygraph_to_static_utils_new.py create mode 100644 test/sot/extract_errors.py create mode 100644 test/sot/test_01_basic.py create mode 100644 test/sot/test_02_store_inplace.py create mode 100644 test/sot/test_03_tuple.py create mode 100644 test/sot/test_04_list.py create mode 100644 test/sot/test_05_dict.py create mode 100644 test/sot/test_06_call_function.py create mode 100644 
test/sot/test_07_unpack.py create mode 100644 test/sot/test_08_rot.py create mode 100644 test/sot/test_09_f_string.py create mode 100644 test/sot/test_10_build_unpack.py create mode 100644 test/sot/test_11_jumps.py create mode 100644 test/sot/test_12_for_loop.py create mode 100644 test/sot/test_13_make_function.py create mode 100644 test/sot/test_14_operators.py create mode 100644 test/sot/test_15_slice.py create mode 100644 test/sot/test_16_paddle_api.py create mode 100644 test/sot/test_17_paddle_layer.py create mode 100644 test/sot/test_18_tensor_method.py create mode 100644 test/sot/test_19_closure.py create mode 100644 test/sot/test_20_string.py create mode 100644 test/sot/test_21_global.py create mode 100644 test/sot/test_analysis_inputs.py create mode 100644 test/sot/test_break_graph.py create mode 100644 test/sot/test_builtin_dispatch.py create mode 100644 test/sot/test_call_object.py create mode 100644 test/sot/test_case_base.py create mode 100644 test/sot/test_code_status.py create mode 100644 test/sot/test_constant_graph.py create mode 100644 test/sot/test_cost_model.py create mode 100644 test/sot/test_delete_fast.py create mode 100644 test/sot/test_dup_top.py create mode 100644 test/sot/test_enumerate.py create mode 100644 test/sot/test_error_handling.py create mode 100644 test/sot/test_exception.py create mode 100644 test/sot/test_execution_base.py create mode 100644 test/sot/test_guard_outputs.py create mode 100644 test/sot/test_guard_user_defined_fn.py create mode 100644 test/sot/test_inplace_api.py create mode 100644 test/sot/test_instruction_translator_cache.py create mode 100644 test/sot/test_map.py create mode 100644 test/sot/test_multiple_args.py create mode 100644 test/sot/test_mutable_data.py create mode 100644 test/sot/test_numpy.py create mode 100644 test/sot/test_numpy_var_if.py create mode 100644 test/sot/test_output_restoration.py create mode 100644 test/sot/test_range.py create mode 100644 test/sot/test_resnet.py create mode 100644 test/sot/test_resnet50_backward.py create mode 100644 test/sot/test_segment_linear.py create mode 100644 test/sot/test_side_effects.py create mode 100644 test/sot/test_simulate_initialize.py create mode 100644 test/sot/test_sir_rollback.py create mode 100644 test/sot/test_stack.py create mode 100644 test/sot/test_str_format.py create mode 100644 test/sot/test_tensor_dtype_in_guard.py create mode 100644 test/sot/test_tensor_slice.py create mode 100644 test/sot/test_trace_list_arg.py diff --git a/.flake8 b/.flake8 index d9585ef248701..91137a006d088 100644 --- a/.flake8 +++ b/.flake8 @@ -26,6 +26,9 @@ per-file-ignores = # These files need tabs for testing. 
    test/dygraph_to_static/test_error.py:E101,W191
+    # Ignore compare with True in sot unittest
+    test/sot/test_dup_top.py:E712
+
    # temp ignore base directory
    python/paddle/base/*: E712,
diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh
index 6c880988eceeb..a9c7b8416b2c9 100644
--- a/paddle/scripts/paddle_build.sh
+++ b/paddle/scripts/paddle_build.sh
@@ -933,23 +933,21 @@ set -ex
 }
 
 function run_sot_test() {
-    PADDLE_SOT_ROOT=$1
-    PY_VERSION=$2
+    PY_VERSION=$1
     PYTHON_WITH_SPECIFY_VERSION=python$PY_VERSION
     PY_VERSION_NO_DOT=`echo $PY_VERSION | sed 's/\.//g'`
 
     export STRICT_MODE=1
     export COST_MODEL=False
     export MIN_GRAPH_SIZE=0
+    export SOT_LOG_LEVEL=0
 
     # Install PaddlePaddle
     $PYTHON_WITH_SPECIFY_VERSION -m pip install ${PADDLE_ROOT}/dist/paddlepaddle-0.0.0-cp${PY_VERSION_NO_DOT}-cp${PY_VERSION_NO_DOT}-linux_x86_64.whl
     # Install PaddleSOT
-    cd $PADDLE_SOT_ROOT
-    $PYTHON_WITH_SPECIFY_VERSION -m pip install -e .
+    cd $PADDLE_ROOT/test/sot/
 
     # Run unittest
-    cd tests
     failed_tests=()
 
     for file in ./test_*.py; do
@@ -4128,14 +4126,12 @@ function main() {
         ;;
         cicheck_sot)
             export WITH_SHARED_PHI=ON
-            PADDLE_SOT_ROOT=${PADDLE_ROOT}/sot
-            git clone https://github.com/PaddlePaddle/PaddleSOT.git ${PADDLE_SOT_ROOT}
             PYTHON_VERSIONS=(3.8 3.9 3.10 3.11)
             for PY_VERSION in ${PYTHON_VERSIONS[@]}; do
                 ln -sf $(which python${PY_VERSION}) /usr/local/bin/python
                 ln -sf $(which pip${PY_VERSION}) /usr/local/bin/pip
                 run_setup ${PYTHON_ABI:-""} bdist_wheel ${parallel_number}
-                run_sot_test $PADDLE_SOT_ROOT $PY_VERSION
+                run_sot_test $PY_VERSION
                 rm -rf ${PADDLE_ROOT}/build/CMakeCache.txt
             done
             ;;
diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py
index 642b4c8b9529e..65da105499b20 100644
--- a/python/paddle/jit/dy2static/program_translator.py
+++ b/python/paddle/jit/dy2static/program_translator.py
@@ -692,13 +692,15 @@ class SymbolicStaticFunction(StaticFunction):
     def __init__(self, function, input_spec=None, **kwargs):
         if input_spec is not None:
             warnings.warn(
-                "\nSymbolic Trace don't support input_spec arguments. It will Will not produce any effect.\n"
+                "\nSymbolic Trace doesn't support input_spec arguments. It will not produce any effect.\n"
                 "1. You can disable fallback mode by `paddle.jit.to_static(enable_fallback=False)` to switch to AST to static, then you can assign input spec.\n"
             )
         super().__init__(function, input_spec, **kwargs)
         self.last_call_input_spec = None
 
     def _perform_call(self, *args, **kwargs):
+        from ..sot import symbolic_translate
+
         args, kwargs = self._function_spec.unified_args_and_kwargs(args, kwargs)
         (
             input_args_with_spec,
@@ -706,16 +708,6 @@ def _perform_call(self, *args, **kwargs):
         ) = self._function_spec.args_to_input_spec(args, kwargs)
         self.last_call_input_spec = input_args_with_spec
 
-        try:
-            from sot import symbolic_translate
-        except:
-            import os
-
-            os.system(
-                "pip install git+https://github.com/PaddlePaddle/PaddleSOT@develop"
-            )
-            from sot import symbolic_translate
-
         build_strategy = self._kwargs.get("build_strategy", None)
         backend = self._kwargs.get("backend", None)
         traced_fun = symbolic_translate(
diff --git a/python/paddle/jit/sot/__init__.py b/python/paddle/jit/sot/__init__.py
new file mode 100644
index 0000000000000..1b45c0c55389b
--- /dev/null
+++ b/python/paddle/jit/sot/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import psdb  # noqa: F401
+from .opcode_translator.breakpoint import (  # noqa: F401
+    BM,
+    add_breakpoint,
+    add_event,
+)
+from .opcode_translator.skip_files import skip_function  # noqa: F401
+from .translate import symbolic_translate  # noqa: F401
diff --git a/python/paddle/jit/sot/infer_meta.py b/python/paddle/jit/sot/infer_meta.py
new file mode 100644
index 0000000000000..8ea3ec28f19a4
--- /dev/null
+++ b/python/paddle/jit/sot/infer_meta.py
@@ -0,0 +1,282 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddle.amp.auto_cast import amp_state
+from paddle.base.unique_name import UniqueNameGenerator
+from paddle.base.unique_name import guard as UniqueNameGuard
+from paddle.static import Program
+from paddle.utils import flatten, is_sequence
+
+from .utils import Cache, Singleton, map_if_extend, meta_str
+
+
+class MetaInfo:
+    def __init__(
+        self, shape, dtype, stop_gradient, name, persistable, type, place
+    ):
+        self.name = name
+        self.persistable = persistable
+        self.type = type
+        self.place = place
+        self.shape = shape
+        self.dtype = dtype
+        self.stop_gradient = stop_gradient
+
+    @staticmethod
+    def from_tensor(tensor):
+        # We always use float32 in simulation if AMP is enabled.
+        dtype = tensor.dtype
+        current_amp_state = amp_state()
+        if (
+            dtype == paddle.float16
+            and current_amp_state is not None
+            and current_amp_state["dtype"] == "float16"
+        ):
+            dtype = paddle.float32
+        return MetaInfo(
+            list(tensor.shape),
+            dtype,
+            tensor.stop_gradient,
+            tensor.name,
+            tensor.persistable,
+            tensor.type,
+            tensor.place,
+        )
+
+    def is_dynamic_shape(self):
+        """
+        Return True if -1 is in the shape, otherwise False.
+        """
+        return -1 in self.shape
+
+    def to_input_spec(self):
+        return paddle.static.InputSpec(
+            self.shape, dtype=self.dtype, stop_gradient=self.stop_gradient
+        )
+
+    def guard_str(self):
+        return f"({self.shape}, {self.dtype}, {self.stop_gradient})"
+
+    def __repr__(self):
+        return meta_str(self.shape, self.dtype, self.stop_gradient)
+
+    def __eq__(self, meta):
+        return (
+            self.shape == meta.shape
+            and self.dtype == meta.dtype
+            and self.stop_gradient == meta.stop_gradient
+        )
+
+    def __hash__(self):
+        return hash((tuple(self.shape), self.dtype, self.stop_gradient))
+
+
+@Singleton
+class VariableCreator:
+    """
+    We use the static graph Variable to infer the meta information of Tensor.
+    This singleton class is used to create Variable for infer meta.
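+
+    An illustrative sketch of the flow (hypothetical values; MetaInfo fields
+    are positional: shape, dtype, stop_gradient, name, persistable, type,
+    place):
+
+        >>> meta = MetaInfo([2, 3], paddle.float32, True, None, False, None, None)
+        >>> VariableCreator().infer_meta(paddle.add, meta, meta).shape
+        [2, 3]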
+ """ + + def __init__(self): + self.var_cache = {} + self.main_program = Program() + self.startup_program = Program() + self.var_name_generator = UniqueNameGenerator("infer_meta_variable_") + + def gen_name(self, meta): + name = f"{meta.dtype}_{meta.stop_gradient}" + for l in meta.shape: + name += f"_{l}" + return name + + def create_var(self, meta): + var = self.main_program.global_block().create_var( + shape=meta.shape, + dtype=meta.dtype, + stop_gradient=meta.stop_gradient, + ) + assert not isinstance( + var, paddle.Tensor + ), "Expect a Variable, but got a Tensor." + return var + + def get_variable(self, meta): + var_feature_name = self.gen_name(meta) + if var_feature_name not in self.var_cache: + self.var_cache[var_feature_name] = self.create_var(meta) + return self.var_cache[var_feature_name] + + def infer_meta(self, func, *args, **kwargs): + with paddle.base.framework._dygraph_guard(None), UniqueNameGuard( + self.var_name_generator + ): + args, kwargs = convert_meta_to_variable( + args + ), convert_meta_to_variable(kwargs) + + with paddle.static.program_guard( + self.main_program, self.startup_program + ): + if isinstance(func, str): + # TODO(Aurelius84): Is length of args always greater than 0? + # Do we need add condition check here? + out = getattr(args[0], func)(*args[1:], **kwargs) + else: + out = func(*args, **kwargs) + + return convert_variable_to_meta_info(out) + + +def convert_meta_to_variable(args): + return map_if_extend( + args, + pred=lambda x: isinstance(x, MetaInfo), + true_fn=lambda x: VariableCreator().get_variable(x), + false_fn=lambda x: x, + ) + + +def convert_meta_to_input_spec(args): + return map_if_extend( + args, + pred=lambda x: isinstance(x, MetaInfo), + true_fn=lambda x: x.to_input_spec(), + # TODO(xiongkun): can x be tensor ? + false_fn=lambda x: paddle.static.InputSpec.from_tensor(x) + if isinstance(x, paddle.Tensor) + else x, + ) + + +def convert_variable_to_meta_info(args): + return map_if_extend( + args, + pred=lambda x: isinstance(x, paddle.static.Variable), + true_fn=lambda x: MetaInfo.from_tensor(x), + false_fn=lambda x: x, + ) + + +def infer_meta(func, *args, **kwargs): + fn = SpecialInferMeta().get_infermeta_fn(func) + if fn: + return fn(*args, **kwargs) + return VariableCreator().infer_meta(func, *args, **kwargs) + + +def infer_meta_for_layer(layer, *args, **kwargs): + assert isinstance( + layer, paddle.nn.Layer + ), f"Expect a Layer, but got {layer}." + layer = paddle.jit.to_static(layer, enable_fallback=False) + + args_, kwargs_ = convert_meta_to_input_spec((args, kwargs)) + + ( + concrete_program, + partial_program_layer, + ) = layer.forward.get_concrete_program(*args_, **kwargs_) + + out = partial_program_layer._restore_out( + paddle.utils.flatten( + convert_variable_to_meta_info(concrete_program.outputs) + ) + ) + layer.forward.rollback() + return out + + +@Singleton +class SpecialInferMeta: + """ + There are some functions that cannot be inferred directly through static graph, + and need to be implemented manually. This class is used to implement infer meta + for these functions. 
+ """ + + def __init__(self): + pass + + def get_infermeta_fn(self, fn): + try: + funcname = fn.__name__ + return getattr(self, f"infermeta_{funcname}") + except: + pass + return None + + def infermeta_grad( + self, + outputs, + inputs, + grad_outputs=None, + retain_graph=None, + create_graph=False, + only_inputs=True, + allow_unused=False, + no_grad_vars=None, + ): + if not is_sequence(inputs): + inputs = [inputs] + return inputs + + +@Singleton +class InferMetaCache(Cache): + def key_fn( + self, func, *args, **kwargs + ): # args & kwargs have transformed to MetaInfo + try: + retval = hash( + ( + func, + tuple(flatten(args)), + tuple(kwargs.keys()), + tuple(flatten(kwargs)), + ) + ) + except Exception as e: + return None + return retval + + def value_fn(self, func, *args, **kwargs): + return infer_meta(func, *args, **kwargs) + + +@Singleton +class LayerInferMetaCache(Cache): + def key_fn(self, layer, *args, **kwargs): + params = [ + MetaInfo.from_tensor(x) + for x in layer.parameters(include_sublayers=True) + ] + try: + retval = hash( + ( + layer, + tuple(params), + tuple(flatten(args)), + tuple(kwargs.keys()), + tuple(flatten(kwargs)), + ) + ) + except Exception as e: + return None + return retval + + def value_fn(self, layer, *args, **kwargs): + return infer_meta_for_layer(layer, *args, **kwargs) diff --git a/python/paddle/jit/sot/opcode_translator/__init__.py b/python/paddle/jit/sot/opcode_translator/__init__.py new file mode 100644 index 0000000000000..bf230190e3e11 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .transform import eval_frame_callback # noqa: F401 diff --git a/python/paddle/jit/sot/opcode_translator/breakpoint.py b/python/paddle/jit/sot/opcode_translator/breakpoint.py new file mode 100644 index 0000000000000..6f3217dd8776e --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/breakpoint.py @@ -0,0 +1,179 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import inspect
+import traceback
+from dataclasses import dataclass
+
+from ..opcode_translator.instruction_utils import instrs_info
+from ..utils import Singleton, log
+from .executor.opcode_executor import OpcodeExecutorBase
+
+# This file provides debugging utilities for quick debugging:
+# >>> sot.add_breakpoint(file, line)
+# >>> sot.remove_breakpoint(file, line)
+
+
+@dataclass
+class Breakpoint:
+    file: str
+    line: int
+    co_name: str
+    offset: int
+
+    def __hash__(self):
+        return hash((self.file, self.line, self.co_name, self.offset))
+
+
+@Singleton
+class BreakpointManager:
+    def __init__(self):
+        self.breakpoints = set()
+        self.executors = OpcodeExecutorBase.call_stack
+        self.activate = 0
+        self.record_event = []
+
+    def clear_event(self, event):
+        self.record_event.clear()
+
+    def add_event(self, event):
+        """
+        event in ['All', 'FallbackError', 'BreakGraphError', 'InnerError']
+        """
+        self.record_event.append(event)
+
+    def add(self, file, line, coname=None, offset=None):
+        log(1, f"add breakpoint at {file}:{line}\n")
+        self.breakpoints.add(Breakpoint(file, line, coname, offset))
+
+    def addn(self, *lines):
+        """
+        Called inside an executor. Add breakpoints for a list of line
+        numbers in the current file.
+        """
+        if not isinstance(lines, (list, tuple)):
+            lines = [lines]
+        for line in lines:
+            file = self.cur_exe._code.co_filename
+            self.add(file, line)
+
+    def clear(self):
+        self.breakpoints.clear()
+
+    def hit(self, file, line, co_name, offset):
+        if Breakpoint(file, line, None, None) in self.breakpoints:
+            return True
+        if Breakpoint(file, line, co_name, offset) in self.breakpoints:
+            return True
+        return False
+
+    def locate(self, exe):
+        for i, _e in enumerate(self.executors):
+            if _e is exe:
+                self.activate = i
+                return
+        raise RuntimeError("Executor not found.")
+
+    def up(self):
+        if self.activate == 0:
+            return
+        self.activate -= 1
+        print("current function is: ", self.cur_exe._code.co_name)
+
+    def down(self):
+        if self.activate >= len(self.executors) - 1:
+            return
+        self.activate += 1
+        print("current function is: ", self.cur_exe._code.co_name)
+
+    def opcode(self, cur_exe=None):
+        if cur_exe is None:
+            cur_exe = self.cur_exe
+        instr = cur_exe._instructions[cur_exe._lasti - 1]
+        message = f"[Translate {cur_exe}]: (line {cur_exe._current_line:>3}) {instr.opname:<12} {instr.argval}, stack is {cur_exe._stack}\n"
+        return message
+
+    def bt(self):
+        """
+        display all inline calls: backtrace.
+        """
+        for exe in self.executors:
+            lines, _ = inspect.getsourcelines(exe._code)
+            print(
+                " "
+                + exe._code.co_filename
+                + f"({exe._current_line})"
+                + f"{exe._code.co_name}()"
+            )
+            print(f"-> {lines[0].strip()}")
+            print(f"-> {self.opcode(exe)}")
+
+    def on_event(self, event):
+        if "All" in self.record_event or event in self.record_event:
+            print("event captured.")
+            self.activate = len(self.executors) - 1
+            breakpoint()
+
+    def _dis_source_code(self):
+        cur_exe = self.executors[self.activate]
+        lines, start_line = inspect.getsourcelines(cur_exe._code)
+        cur_line = cur_exe._current_line
+        lines[
+            cur_line - start_line + 1 : cur_line - start_line + 1
+        ] = [" ^^^^^ HERE \n"]
+        print("\033[31mSource Code is: \033[0m")
+        print("".join(lines))
+
+    def dis(self, range=5):
+        """
+        display all instruction code and source code.
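+
+        A typical interactive session looks like this (illustrative; the
+        file name and line number are hypothetical):
+
+            >>> BM.add("my_model.py", 42)  # doctest: +SKIP
+            >>> BM.dis(range=3)  # doctest: +SKIP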
+ """ + print("displaying debug info...") + cur_exe = self.cur_exe + print(self._dis_source_code()) + + print(f"\n{cur_exe._code}") + lasti = cur_exe._lasti + lines = instrs_info(cur_exe._instructions, lasti - 1, range) + print("\n".join(lines)) + + @property + def cur_exe(self): + exe = self.executors[self.activate] + return exe + + def sir(self): + """ + display sir in a page. + """ + print("displaying sir...") + self.cur_exe.print_sir() + + def pe(self, e): + """ + print exception. + """ + lines = traceback.format_tb(e.__traceback__) + print("".join(lines)) + + +def add_breakpoint(file, line, co_name=None, offset=None): + BM.add(file, line, co_name, offset) + + +def add_event(event): + BM.add_event(event) + + +BM = BreakpointManager() diff --git a/python/paddle/jit/sot/opcode_translator/custom_code.py b/python/paddle/jit/sot/opcode_translator/custom_code.py new file mode 100644 index 0000000000000..da674fb673170 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/custom_code.py @@ -0,0 +1,23 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import types +from typing import NamedTuple + + +class CustomCode(NamedTuple): + code: types.CodeType | None + disable_eval_frame: bool diff --git a/python/paddle/jit/sot/opcode_translator/executor/__init__.py b/python/paddle/jit/sot/opcode_translator/executor/__init__.py new file mode 100644 index 0000000000000..4d9db28d22707 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import variable_dispatch # noqa: F401 diff --git a/python/paddle/jit/sot/opcode_translator/executor/dispatch_functions.py b/python/paddle/jit/sot/opcode_translator/executor/dispatch_functions.py new file mode 100644 index 0000000000000..9b00dcde0462b --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/dispatch_functions.py @@ -0,0 +1,54 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This file stores the custom functions that will be called by the dispatch mechanism.
+
+from ...utils import BreakGraphError, FallbackError
+
+
+def raise_break_graph_fn(*args, **kwargs):
+    raise BreakGraphError("raised by raise_break_graph_fn.")
+
+
+def raise_not_implement_fn(*args, **kwargs):
+    raise FallbackError("raised by raise_not_implement_fn.")
+
+
+# just a stand-in function for the `in` operator
+def operator_in(left, right):
+    return left in right
+
+
+def operator_not_in(left, right):
+    return left not in right
+
+
+def operator_exception_match(left, right):
+    pass
+
+
+def operator_BAD(left, right):
+    pass
+
+
+def operator_is_none(val):
+    pass
+
+
+def operator_is_not_none(val):
+    pass
+
+
+def tensor_numel(x):
+    pass
diff --git a/python/paddle/jit/sot/opcode_translator/executor/dispatcher.py b/python/paddle/jit/sot/opcode_translator/executor/dispatcher.py
new file mode 100644
index 0000000000000..315066f27e820
--- /dev/null
+++ b/python/paddle/jit/sot/opcode_translator/executor/dispatcher.py
@@ -0,0 +1,294 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import copy
+import inspect
+import operator
+from functools import cached_property, reduce
+from typing import TYPE_CHECKING, Any, Callable, Dict, Tuple, TypeVar
+
+from ...utils import InnerError, NameGenerator, hashable
+
+if TYPE_CHECKING:
+    T = TypeVar("T")
+    Args = Tuple[T, ...]
+    Kwargs = Dict[str, T]
+
+
+def format_type(type_: type[Any] | tuple[type[Any], ...]) -> str:
+    if not isinstance(type_, tuple):
+        type_ = (type_,)
+    return " | ".join([t.__name__ for t in type_])
+
+
+def format_param(param: Parameter) -> str:
+    kind = param.kind
+    # TODO: support VAR_KEYWORD
+    if kind == inspect.Parameter.VAR_POSITIONAL:
+        return f"*{format_type(param.type)}"
+    else:
+        return format_type(param.type)
+
+
+def convert_annotation_to_type(type_str: str) -> tuple[type[Any], ...]:
+    """
+    Convert type annotation to runtime value. Because we are using :pep:`563`
+    to use the future annotation syntax, we cannot use `get_type_hints `_
+    directly. Currently, only the builtins and variables namespaces are supported.
+
+    Returns:
+        tuple: The converted type.
+    """
+
+    import builtins
+
+    from .
import variables + + type_str = type_str.strip() + if type_str == "Any": + type_str = "object" + + if "|" in type_str: + return reduce( + operator.add, map(convert_annotation_to_type, type_str.split("|")) + ) + + search_namespaces = [variables, builtins] + for namespace in search_namespaces: + if hasattr(namespace, type_str): + return (getattr(namespace, type_str),) + raise InnerError(f"Cannot find type {type_str} in {search_namespaces}") + + +class Parameter: + name_gen = NameGenerator("param_") + annotation: str + name: str + + def __init__( + self, + annotation: str, + *, + kind: inspect._ParameterKind = inspect.Parameter.POSITIONAL_OR_KEYWORD, + name: str | None = None, + default: Any = inspect._empty, + ): + self.name = name if name is not None else Parameter.name_gen.next() + self.annotation = annotation + self.kind = kind + self.default = default + + def to_parameter(self) -> inspect.Parameter: + return inspect.Parameter( + self.name, + kind=self.kind, + annotation=self.annotation, + default=copy.copy(self.default), + ) + + @cached_property + def type(self) -> tuple[type[Any], ...]: + return convert_annotation_to_type(self.annotation) + + def match_arg(self, arg: Any) -> bool: + # TODO: support VAR_KEYWORD + if self.kind == inspect.Parameter.VAR_POSITIONAL: + is_tuple = isinstance(arg, tuple) + return is_tuple and all(isinstance(a, self.type) for a in arg) + else: + return isinstance(arg, self.type) + + @staticmethod + def from_str(annotation: str) -> Parameter: + return Parameter(annotation) + + @staticmethod + def from_parameter(parameter: inspect.Parameter) -> Parameter: + if parameter.annotation != parameter.empty and not isinstance( + parameter.annotation, str + ): + raise InnerError( + f"Parameter {parameter} has annotation {parameter.annotation} " + "which is not a string. Please add `from __future__ import annotations` " + "to the top of your file." + ) + annotation = ( + parameter.annotation + if parameter.annotation != parameter.empty + else "Any" + ) + + return Parameter( + annotation, + kind=parameter.kind, + name=parameter.name, + default=parameter.default, + ) + + def __repr__(self) -> str: + default_repr = f"= {self.default!r}" + return f"Parameter({', '.join([self.annotation, default_repr])})" + + +def optional(annotation: str, default: Any = None) -> Parameter: + return Parameter(annotation, default=default) + + +class Pattern: + parameters: dict[str, Parameter] + signature: inspect.Signature + + def __init__( + self, + *parameters: Parameter, + ): + self.parameters = { + parameter.name: parameter for parameter in parameters + } + self.signature = inspect.Signature( + [parameter.to_parameter() for parameter in self.parameters.values()] + ) + + def match_inputs(self, /, *args: Any, **kwargs: Any) -> bool: + """ + Match the input parameters of the function. + + Returns: + bool: Whether the input parameters match the pattern. + """ + try: + bound_args = self.signature.bind(*args, **kwargs) + except TypeError: + return False + for arg_name, arg_value in bound_args.arguments.items(): + if arg_name not in self.parameters: + continue + if not self.parameters[arg_name].match_arg(arg_value): + return False + return True + + def __repr__(self) -> str: + types_repr = ", ".join( + [format_param(param) for param in self.parameters.values()] + ) + return f"Pattern({types_repr})" + + +class Dispatcher: + """ + Used for pattern registration and distribution. + + For more design ideas, refer to the `Builtin dispatcher `_ for details. 
+
+    Examples:
+
+        >>> def builtin_add(a: int, b: int) -> int:
+        ...     ...
+        ...
+        >>> Dispatcher.register(builtin_add, ("int", "int"), lambda a, b: a + b)
+        >>> handler = Dispatcher.dispatch(builtin_add, 1, 2)
+        >>> handler(1, 2)
+        3
+    """
+
+    handlers: dict[
+        Callable[..., Any], list[tuple[Pattern, Callable[..., Any]]]
+    ] = {}
+    graph: Any = None
+
+    @classmethod
+    def register(
+        cls,
+        fn: Callable[..., Any],
+        parameters: tuple[str | Parameter, ...],
+        handler: Callable[..., Any],
+    ):
+        """
+        Register a function signature.
+
+        Args:
+            fn: The function to be registered.
+            parameters: The parameters of the function to be registered.
+            handler: The handler function.
+        """
+        _parameters = tuple(
+            Parameter.from_str(parameter)
+            if isinstance(parameter, str)
+            else parameter
+            for parameter in parameters
+        )
+        if fn not in cls.handlers:
+            cls.handlers[fn] = []
+        cls.handlers[fn].append((Pattern(*_parameters), handler))
+
+    @classmethod
+    def register_decorator(cls, fn: Callable[..., Any]):
+        """
+        Decorator form of register, used to register more complex functions.
+
+        Args:
+            fn: The function to be registered.
+
+        Examples:
+            >>> def builtin_add(a: int, b: int) -> int:
+            ...     ...
+            ...
+            >>> @Dispatcher.register_decorator(builtin_add)
+            ... def builtin_add_dispatcher(a: int, b: int) -> int:
+            ...     return a + b
+            ...
+            >>> handler = Dispatcher.dispatch(builtin_add, 1, 2)
+            >>> handler(1, 2)
+            3
+        """
+
+        def decorator(handler: Callable[..., Any]):
+            signature = inspect.signature(handler)
+            parameters = tuple(
+                Parameter.from_parameter(parameter)
+                for parameter in signature.parameters.values()
+            )
+            cls.register(fn, parameters, handler)
+
+        return decorator
+
+    @classmethod
+    def call(cls, fn, *args, **kwargs):
+        func = cls.dispatch(fn, *args, **kwargs)
+        if func is None:
+            raise InnerError(
+                f"Cannot find handler for {fn} with args {args} and kwargs {kwargs}"
+            )
+        return func(*args, **kwargs)
+
+    @classmethod
+    def dispatch(
+        cls, fn: Callable[..., Any], *args: Any, **kwargs: Any
+    ) -> Callable[..., Any] | None:
+        """
+        Find the matching handler from the registered functions.
+
+        Args:
+            fn: The function to be dispatched.
+            args: The args of the function.
+            kwargs: The kwargs of the function.
+        """
+        if not hashable(fn) or fn not in cls.handlers:
+            return None
+        for pattern, handler in cls.handlers[fn]:
+            if pattern.match_inputs(*args, **kwargs):
+                return handler
+        return None
diff --git a/python/paddle/jit/sot/opcode_translator/executor/executor_cache.py b/python/paddle/jit/sot/opcode_translator/executor/executor_cache.py
new file mode 100644
index 0000000000000..67d656f4dcd75
--- /dev/null
+++ b/python/paddle/jit/sot/opcode_translator/executor/executor_cache.py
@@ -0,0 +1,230 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
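+
+# The guarded cache lookup below, as a rough pseudo-Python sketch of the
+# flow (illustrative only, not executable code):
+#
+#     for custom_code, guard_fn in cache[frame.f_code]:
+#         if guard_fn(frame):      # every recorded assumption still holds
+#             return custom_code   # reuse previously translated bytecode
+#     return translate(frame)      # otherwise translate again and cache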
+
+from __future__ import annotations
+
+import traceback
+import types
+from typing import List, Tuple
+
+from ...profiler import EventGuard, event_register
+from ...psdb import NO_FALLBACK_CODES
+from ...utils import (
+    BreakGraphError,
+    FallbackError,
+    InnerError,
+    Singleton,
+    is_strict_mode,
+    log,
+    log_do,
+)
+from ..custom_code import CustomCode
+from .guard import Guard
+from .opcode_executor import OpcodeExecutor, OpcodeExecutorBase
+from .pycode_generator import PyCodeGen
+
+GuardedFunction = Tuple[CustomCode, Guard]
+GuardedFunctions = List[GuardedFunction]
+
+dummy_guard: Guard = lambda frame: True
+dummy_guard.expr = "lambda frame: True"
+dummy_guard.lambda_expr = "lambda frame: True"
+
+
+@Singleton
+class OpcodeExecutorCache:
+    """
+    A singleton class that implements a cache for translated instructions.
+    This cache is used to store previously translated instructions along with their corresponding guard functions.
+
+    Attributes:
+        cache (dict): A dictionary that maps code objects to their list of guarded functions (custom code plus guard).
+        translate_count (int): The number of translations performed. It is used in tests to check whether the cache was hit.
+    """
+
+    MAX_CACHE_SIZE = 20
+    cache: dict[types.CodeType, GuardedFunctions]
+    translate_count: int
+
+    def __init__(self):
+        self.cache = {}
+        self.translate_count = 0
+
+    def clear(self):
+        """
+        Clears the cache and resets the translate count.
+        """
+        self.cache.clear()
+        self.translate_count = 0
+
+    def __call__(self, frame: types.FrameType, **kwargs) -> CustomCode:
+        code: types.CodeType = frame.f_code
+        if code not in self.cache:
+            log(2, f"[Cache]: Firstly call {code}\n")
+            new_custom_code, guard_fn = self.translate(frame, **kwargs)
+            self.cache[code] = [(new_custom_code, guard_fn)]
+            return new_custom_code
+        guarded_fns = self.cache[code]
+        return self.lookup(frame, guarded_fns, **kwargs)
+
+    @event_register("lookup")
+    def lookup(
+        self, frame: types.FrameType, guarded_fns: GuardedFunctions, **kwargs
+    ) -> CustomCode:
+        """
+        Looks up the cache for a matching code object and returns the cached custom code if a matching guard function is found; otherwise the frame is translated again.
+
+        Args:
+            frame (types.FrameType): The frame whose code object needs to be looked up in the cache.
+            guarded_fns (GuardedFunctions): The list of guarded functions associated with the code object.
+
+        Returns:
+            CustomCode: The custom code object whose guard matches, or a newly translated one if no guard matches.
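+
+        Example (illustrative; translate_count grows only when no cached
+        guard matches and a fresh translation is performed):
+
+            >>> cache = OpcodeExecutorCache()
+            >>> cache.clear()
+            >>> cache.translate_count
+            0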
+ """ + + if len(guarded_fns) >= self.MAX_CACHE_SIZE: + log(2, "[Cache]: Exceed max cache size, skip it\n") + return CustomCode(None, False) + + for custom_code, guard_fn in guarded_fns: + try: + with EventGuard("try guard"): + guard_result = guard_fn(frame) + if guard_result: + log( + 2, + f"[Cache]: Cache hit, Guard is \n{getattr(guard_fn, 'expr', 'None')}\n", + ) + return custom_code + else: + log_do( + 4, + self.analyse_guard_global_object(guard_fn), + ) + log( + 2, + f"[Cache]: Cache miss, Guard is \n{getattr(guard_fn, 'expr', 'None')}\n", + ) + log_do( + 2, + self.analyse_guard_error(guard_fn, frame), + ) + except Exception as e: + log(2, f"[Cache]: Guard function error: {e}\n") + continue + + log(2, "[Cache]: all guards missed\n") + new_custom_code, guard_fn = self.translate(frame, **kwargs) + guarded_fns.append((new_custom_code, guard_fn)) + return new_custom_code + + def translate( + self, frame: types.FrameType, **kwargs + ) -> tuple[CustomCode, Guard]: + """ + Translates the given frame's code object and returns the cache getter function and a guarded function for the translated code object. + + Args: + frame (types.FrameType): The frame whose code object needs to be translated. + + Returns: + tuple[CustomCode, Guard]: The cache getter function and a guarded function for the translated code object. + """ + code: types.CodeType = frame.f_code + self.translate_count += 1 + custom_new_code, guard_fn = start_translate(frame, **kwargs) + return custom_new_code, guard_fn + + def analyse_guard_global_object(self, guard_fn): + def inner(): + for key in guard_fn.__globals__.keys(): + if key.startswith("__object"): + print( + f"[Cache] meet global object: {key} : {guard_fn.__globals__[key]}", + ) + + return inner + + def analyse_guard_error(self, guard_fn, frame): + def inner(): + guard_expr = guard_fn.lambda_expr + lambda_head = "lambda frame: " + guard_expr = guard_expr.replace(lambda_head, "") + guards = guard_expr.split(" and ") + for guard_str in guards: + guard = eval(lambda_head + guard_str, guard_fn.__globals__) + result = False + try: + result = guard(frame) + except Exception as e: + print( + f"[Cache]: skip checking {guard_str}\n because error occured {e}" + ) + if result is False: + print(f"[Cache]: missed at {guard_str}") + return + print("[Cache]: missed guard not found.") + + return inner + + +def start_translate(frame: types.FrameType, **kwargs) -> GuardedFunction: + """ + Starts the translation process for the given frame and returns the translated code object and its guard function, or None if translation fails. + + Args: + frame: The frame to be translated. + + Returns: + GuardedFunction | None: The translated code object and its guard function, or None if translation fails. 
+ """ + simulator = OpcodeExecutor(frame, **kwargs) + try: + new_custom_code, guard_fn = simulator.transform() + return new_custom_code, guard_fn + # TODO(zrr1999): InnerError maybe place before (FallbackError, BreakGraphError) + # TODO(0x45f): handle BreakGraphError to trigger fallback + except BreakGraphError as e: + raise RuntimeError( + f"Found BreakGraphError raised, it should not be catch at start_translate!\n{e}" + ) + except FallbackError as e: + if simulator._code in NO_FALLBACK_CODES: + raise InnerError( + f"{simulator._code.co_name} should not fallback, but got '{e}'" + ) + # if disable_eval_frame is True, it means we want fallback to speedup rather than error occured + if is_strict_mode() and e.disable_eval_frame is False: + raise + log( + 2, + f"Unsupport Frame is {frame.f_code}, error message is: \n" + + "".join(traceback.format_exception(type(e), e, e.__traceback__)), + ) + + # NOTE: If resume fn need fallback, we should replace NullVariable using NULL otherwise will fail to run + py_codegen = PyCodeGen(frame) + new_code = py_codegen.replace_null_variable() + # simulation not complete, not sure whether this code has sir, set disable_eval_frame = False + guard_fn = ( + dummy_guard if e.disable_eval_frame is False else simulator.guard_fn + ) + return ( + CustomCode(new_code, e.disable_eval_frame), + guard_fn, + ) + except Exception as e: + raise InnerError(OpcodeExecutorBase.error_message_summary(e)) from e + finally: + simulator.cleanup() diff --git a/python/paddle/jit/sot/opcode_translator/executor/function_graph.py b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py new file mode 100644 index 0000000000000..61f72b267b2de --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py @@ -0,0 +1,680 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is specifically used to handle the problem +# of generating a Graph from a linear function call. 
+
+from __future__ import annotations
+
+import builtins
+import inspect
+from collections import namedtuple
+from copy import deepcopy
+from functools import cached_property
+from typing import Any, Callable
+
+from ...infer_meta import InferMetaCache, LayerInferMetaCache, MetaInfo
+from ...profiler import EventGuard, event_register
+from ...symbolic.statement_ir import Symbol
+from ...symbolic.symbolic_context import SymbolicTraceContext
+from ...utils import (
+    NameGenerator,
+    OrderedSet,
+    inner_error_default_handler,
+    is_inplace_api,
+    is_paddle_api,
+    log,
+    log_do,
+    map_if,
+    show_trackers,
+    tmp_name_guard,
+)
+from .guard import Guard, StringifyExpression, make_guard
+from .mutable_data import MutationDel, MutationNew, MutationSet
+from .pycode_generator import PyCodeGen
+from .side_effects import (
+    DictSideEffectRestorer,
+    GlobalDelSideEffectRestorer,
+    GlobalSetSideEffectRestorer,
+    ListSideEffectRestorer,
+    ObjDelSideEffectRestorer,
+    ObjSetSideEffectRestorer,
+    SideEffectRestorer,
+    SideEffects,
+)
+from .tracker import BuiltinTracker, DummyTracker
+from .variables import (
+    DictVariable,
+    GlobalVariable,
+    ListVariable,
+    NullVariable,
+    PaddleLayerVariable,
+    TensorVariable,
+    VariableBase,
+    VariableFactory,
+    find_traceable_vars,
+    map_variables,
+)
+
+
+def convert_to_meta(inputs: Any):
+    """
+    Convert the input variables to meta if they are TensorVariable.
+    """
+
+    def func(x):
+        if isinstance(x, TensorVariable):
+            return x.meta
+        if isinstance(x, VariableBase):
+            return x.get_py_value()
+        return x
+
+    return map_variables(func, inputs)
+
+
+def convert_to_symbol(inputs: Any):
+    """
+    Convert the input variables to symbols if they can be symbolic.
+    """
+
+    def func(x):
+        if isinstance(x, (TensorVariable, PaddleLayerVariable)):
+            return x.get_symbol()
+        if isinstance(x, VariableBase):
+            return x.get_py_value()
+        return x
+
+    return map_variables(func, inputs)
+
+
+class FunctionGraph:
+    """
+    A graph representation corresponding to each function frame.
+    It records the input bindings, the traced statements, and the output
+    settings of the current call, and can be compiled into a function over
+    f_locals that produces the same outputs.
+    """
+
+    OUT_VAR_PREFIX = "___SIR_out_"
+    Memo = namedtuple(
+        "function_graph_memo",
+        [
+            'inner_out',
+            'input_variables',
+            "stmt_ir",
+            "global_guards",
+            "side_effects_state",
+            "print_variables",
+            "inplace_tensors",
+        ],
+    )
+
+    def __init__(self, frame, **kwargs):
+        self.sir_ctx = SymbolicTraceContext()
+        self.inner_out = set()
+        self.input_variables = []  # Store variables required within a function
+        self.pycode_gen = PyCodeGen(frame, disable_eval_frame=True)
+        self.side_effects = SideEffects()
+        self._global_guarded_variables: OrderedSet[VariableBase] = OrderedSet()
+        self._print_variables = []
+        self._inplace_tensors = OrderedSet()
+        self.build_strategy = kwargs.get('build_strategy', None)
+        self._kwargs = kwargs
+
+    @cached_property
+    def _builtins(self):
+        builtins_ = {}
+        # prepare builtins
+        for name, value in builtins.__dict__.items():
+            builtins_[name] = VariableFactory.from_value(
+                value, self, BuiltinTracker(name), debug_name=name
+            )
+        return builtins_
+
+    def add_print_variables(self, variable):
+        """
+        Used to support psdb_print.
+        """
+        self._print_variables.append(variable)
+
+    def add_inplace_tensors(self, variable):
+        """
+        Used to support inplace tensors.
+        """
+        self._inplace_tensors.add(variable)
+
+    def need_add_input(self, var):
+        """
+        Determine whether the variable is an input of the graph.
+
+        Args:
+            var: The input variable.
+
+        """
+        if var.id in self.inner_out:
+            return False
+        for v in self.input_variables:
+            if v.id == var.id:
+                return False
+        return True
+
+    def save_memo(self) -> FunctionGraph.Memo:
+        """
+        Save the state of the current FunctionGraph for future recovery; it is
+        used to restore the state when an inline call reports an error.
+
+        NOTE:
+            We don't use __deepcopy__ because the memo is not a full deepcopy:
+            inner_out is only a shallow copy, while the SIR is a deepcopy.
+        """
+        saved_stmt_ir = deepcopy(self.sir_ctx.TOS)
+        return FunctionGraph.Memo(
+            inner_out=set(self.inner_out),
+            input_variables=list(self.input_variables),
+            stmt_ir=saved_stmt_ir,
+            global_guards=OrderedSet(self._global_guarded_variables),
+            side_effects_state=self.side_effects.get_state(),
+            print_variables=list(self._print_variables),
+            inplace_tensors=OrderedSet(self._inplace_tensors),
+        )
+
+    def restore_memo(self, memo: FunctionGraph.Memo):
+        """
+        Restore the state of the graph to the given memo.
+
+        Args:
+            memo: Previously recorded memo
+
+        """
+        self.inner_out = memo.inner_out
+        self.input_variables = memo.input_variables
+        self.sir_ctx.replace_TOS(memo.stmt_ir)
+        self._global_guarded_variables = memo.global_guards
+        self.side_effects.restore_state(memo.side_effects_state)
+        self._print_variables = memo.print_variables
+        self._inplace_tensors = memo.inplace_tensors
+
+    def collect_input_variables(self, inputs: list[VariableBase]):
+        """
+        Collect the variables required within the function.
+
+        Args:
+            inputs: Required VariableBase
+        """
+
+        def collect(inp):
+            if isinstance(inp, VariableBase) and self.need_add_input(inp):
+                self.input_variables.append(inp)
+
+        map_variables(
+            collect,
+            inputs,
+        )
+
+    @property
+    @event_register("guard_fn")
+    def guard_fn(self) -> Guard:
+        with tmp_name_guard():
+            guards = []
+            with EventGuard(
+                "guard_fn: find vars and make stringify guard", event_level=1
+            ):
+                for variable in find_traceable_vars(
+                    self.input_variables + list(self._global_guarded_variables)
+                ):
+                    guards.extend(variable.make_stringify_guard())
+
+            guards = OrderedSet(guards)
+
+            for guard in guards:
+                assert isinstance(
+                    guard, StringifyExpression
+                ), "guard must be StringifyExpression."
+
+            return make_guard(guards)
+
+    def start_compile_with_name_store(self, ret_vars, to_store_vars):
+        class VariableLoader:
+            def __init__(self, index_for_load, pycode_gen):
+                self._index_for_load = index_for_load
+                self._pycode_gen = pycode_gen
+
+            def load(self, var):
+                if isinstance(var, NullVariable):
+                    var.reconstruct(self._pycode_gen)
+                    return
+                self._pycode_gen.gen_load_fast(self._index_for_load[var.id])
+
+        # var_id -> local_name mapping
+        index_for_load = {}
+        to_store_vars = list(
+            filter(lambda x: not isinstance(x, NullVariable), to_store_vars)
+        )
+        self.start_compile(*(ret_vars + to_store_vars))
+        name_gen = NameGenerator("__start_compile_saved_")
+        for var in to_store_vars:
+            index_for_load[var.id] = name_gen.next()
+
+            def _log_fn():
+                print(
+                    f"[StartCompile] saved var: {index_for_load[var.id]} = ",
+                    var,
+                )
+
+            log_do(4, _log_fn)
+
+        for var in to_store_vars[::-1]:
+            self.pycode_gen.gen_store_fast(index_for_load[var.id])
+        return VariableLoader(index_for_load, self.pycode_gen)
+
+    @event_register("start_compile", event_level=2)
+    def start_compile(self, *ret_vars: VariableBase):
+        """
+        Generate bytecode based on the information collected during simulated execution.
+
+        This consists of the following steps:
+        - Compile the FunctionGraph into a dy2st StaticFunction and load it in the generated bytecode
+        - Load the network inputs
+        - Call the generated dy2st StaticFunction
+        - Restore the side effects
+        - Restore the outputs
+        - Return the top of the stack
+        """
+        from ..breakpoint import BreakpointManager
+
+        BreakpointManager().on_event("start_compile")
+
+        ret_items = [
+            ret_item
+            for ret_var in ret_vars
+            for ret_item in ret_var.flatten_items()
+        ]
+
+        tensor_items = self._find_tensor_outputs(ret_items)
+        compiled_fn, statement_ir = self.sir_ctx.compile_fn(
+            [Symbol(tensor_var.var_name) for tensor_var in tensor_items],
+            **self._kwargs,
+        )
+        input_names = statement_ir.inputs
+        compiled_fn_name = f"__compiled_fn_{statement_ir.name}"
+        # prepare function and inputs
+        self.pycode_gen.gen_load_object(compiled_fn, compiled_fn_name)
+        for name in input_names:
+            found = False
+            for variable in self.input_variables:
+                if (
+                    isinstance(variable, TensorVariable)
+                    and variable.get_symbol().name == name
+                ):
+                    variable.tracker.gen_instructions(self.pycode_gen)
+                    found = True
+                    break
+            assert found, f"can't find input {name} in SIR."
+        # Pack all args into a tuple, because we don't support *args now.
+        self.pycode_gen.gen_build_tuple(count=len(input_names))
+        # call the compiled_fn
+        self.pycode_gen.gen_call_function(argc=1)
+
+        # Store outputs to f_locals
+        self.pycode_gen.gen_unpack_sequence(count=len(tensor_items))
+        for tensor_var in tensor_items:
+            self.pycode_gen.gen_store_fast(tensor_var.out_var_name)
+        # restore the outputs.
+        for ret_var in ret_vars:
+            ret_var.reconstruct(self.pycode_gen)
+
+        # deal with side effects
+        self.restore_inplace_tensor(self._inplace_tensors)
+        self.restore_print_stmts(self._print_variables)
+        self.restore_side_effects(self.side_effects.proxy_variables)
+        self.pycode_gen.gen_enable_eval_frame()
+
+        tracker_output_path = show_trackers()
+        if tracker_output_path:
+            from .tracker_viewer import view_tracker
+
+            view_tracker(list(ret_vars), tracker_output_path, format="png")
+
+    def call_paddle_api(
+        self,
+        func: Callable[..., Any],
+        *args: VariableBase,
+        **kwargs: VariableBase,
+    ):
+        """
+        Record a Paddle network-building API into the SIR.
+
+        Args:
+            func: paddle api
+        """
+        assert is_paddle_api(func)
+        # not a fallback api, start symbolic trace.
+        # TODO(xiokgun): may have python builtin object inside metas.
+        # TODO(xiokgun): 4 kinds of python arguments. support it!!
+        log(3, f"call paddle.api : {func.__name__}", "\n")
+
+        def message_handler(*args, **kwargs):
+            return f"Call paddle_api error: {func.__name__}, maybe it is not an operator api?"
+
+        return inner_error_default_handler(self.symbolic_call, message_handler)(
+            InferMetaCache(), self.sir_ctx.call_API, func, *args, **kwargs
+        )
+
+    def call_tensor_method(
+        self, method_name: str, *args: VariableBase, **kwargs
+    ):
+        """
+        Call a tensor method and start symbolic trace.
+
+        Args:
+            method_name: tensor method name
+        """
+
+        def message_handler(*args, **kwargs):
+            return f"Call tensor_method error: Tensor.{method_name}, maybe it is not a valid operator api?"
+
+        return inner_error_default_handler(self.symbolic_call, message_handler)(
+            InferMetaCache(),
+            self.sir_ctx.call_METHOD,
+            method_name,
+            *args,
+            **kwargs,
+        )
+
+    @staticmethod
+    def get_opcode_executor_stack():
+        # NOTE: only for debug.
+        # dependent on OpcodeExecutor.
+        from .opcode_executor import OpcodeExecutorBase
+
+        if len(OpcodeExecutorBase.call_stack) == 0:
+            # In test cases, we can meet this scenario.
+ return [] + current_executor = OpcodeExecutorBase.call_stack[-1] + current_line = current_executor._current_line + filename = current_executor._code.co_filename + source_lines, start_line = inspect.getsourcelines( + current_executor._code + ) + # TODO(SigureMo): In 3.11, lineno maybe changed after multiple breakgraph, + # We need to find a way to fix this. + line_idx = min(current_line - start_line, len(source_lines) - 1) + code_line = source_lines[line_idx] + stack = [] + stack.append( + ' File "{}", line {}, in {}'.format( + filename, + current_line, + current_executor._code.co_name, + ) + ) + stack.append(f' {code_line}') + return stack + + def call_layer( + self, + layer: PaddleLayerVariable, + *args: VariableBase, + **kwargs: VariableBase, + ): + """ + call paddle layer, start symbolic trace. + + Args: + layer: paddle layer + """ + + def infer_meta_fn(layer, *metas, **kwmetas): + metas = LayerInferMetaCache()(layer.value, *metas, **kwmetas) + return metas + + def compute_fn(layer, inputs, outputs, stacks): + self.sir_ctx.call_LAYER( + layer.value, + inputs=inputs, + outputs=outputs, + stacks=stacks, + ) + + def message_handler(*args, **kwargs): + return f"Call paddle layer error: {layer}, may be not a valid paddle layer ?" + + return inner_error_default_handler(self.symbolic_call, message_handler)( + infer_meta_fn, compute_fn, layer, *args, **kwargs + ) + + def symbolic_call(self, infer_meta_fn, compute_fn, func, *args, **kwargs): + """ + Using infer_meta_fn and compute_fn convert func to symbolic function. + + Args: + infer_meta_fn: function for infer meta, (func, metas, kwmetas) -> output_metas + compute_fn : function for sir compile, (func, input_symbols, outputs_symbols) -> None + func : symbolic function + """ + self.collect_input_variables(list(args)) + self.collect_input_variables(list(kwargs.values())) + metas = convert_to_meta(args) + kwmetas = convert_to_meta(kwargs) + + out_metas = infer_meta_fn(func, *metas, **kwmetas) + inputs_symbols = ( + convert_to_symbol(args), + convert_to_symbol(kwargs), + ) + log(3, f" inputs : {inputs_symbols}", "\n") + + outputs = map_if( + out_metas, + pred=lambda x: isinstance(x, MetaInfo), + true_fn=lambda x: TensorVariable( + x, + self, + tracker=DummyTracker(list(args) + list(kwargs.values())), + ), + false_fn=lambda x: x, + ) + stmt_stacks = [] + log_do( + 3, + lambda: stmt_stacks.extend( + FunctionGraph.get_opcode_executor_stack() + ), + ) + if outputs is not None: + if is_inplace_api(func): + # if we want to use a non-inplace api (static api) to replace an inplace behavior (in simulation) + # just set it back in SIR, and return outputs to replace tensor meta (it might changes?) + # in this case, the output will not exactly be used + compute_fn( + func, + inputs_symbols, + convert_to_symbol(args[0]), + stmt_stacks, + ) + else: + compute_fn( + func, + inputs_symbols, + convert_to_symbol(outputs), + stmt_stacks, + ) # symbolic only contain symbols. 
+ self._put_inner(outputs) + return VariableFactory.from_value( + outputs, self, DummyTracker(list(args) + list(kwargs.values())) + ) + else: + return None + + def _put_inner(self, vars: VariableBase): + """ + put inner variable to inner_out + """ + map_if( + vars, + pred=lambda x: isinstance(x, VariableBase), + true_fn=lambda x: self.inner_out.add(x.id), + false_fn=lambda x: None, + ) + + def add_global_guarded_variable(self, variable: VariableBase): + """ + Add variable to global guarded variable + """ + self._global_guarded_variables.add(variable) + + def remove_global_guarded_variable(self, variable: VariableBase): + """ + Remove variable to global guarded variable + """ + if variable in self._global_guarded_variables: + self._global_guarded_variables.remove(variable) + + def _find_tensor_outputs( + self, outputs: list[VariableBase] + ) -> OrderedSet[TensorVariable]: + """ + Return all TensorVariable. find TensorVariables participating in networking from the output Variables + + Args: + outputs: output variables + """ + output_tensors: OrderedSet[TensorVariable] = OrderedSet() + # Find Tensor Variables from outputs. + for output in outputs: + if isinstance(output.tracker, DummyTracker): + if isinstance(output, TensorVariable): + output_tensors.add(output) + else: + # Guard output that can not be traced. + self.add_global_guarded_variable(output) + # Find Tensor Variables from side effects Variables. + for side_effect_var in self.side_effects.proxy_variables: + if isinstance(side_effect_var, (ListVariable, DictVariable)): + for var in side_effect_var.flatten_items(): + if ( + isinstance(var.tracker, DummyTracker) + and isinstance(var, TensorVariable) + and side_effect_var.tracker.is_traceable() + ): + output_tensors.add(var) + else: + if isinstance(side_effect_var, GlobalVariable): + proxy_records = side_effect_var.proxy.records + elif side_effect_var.tracker.is_traceable(): + # for attr side effect + proxy_records = side_effect_var.attr_proxy.records + else: + continue + for record in proxy_records: + if isinstance(record, (MutationSet, MutationNew)): + for var in record.value.flatten_items(): + if isinstance( + var.tracker, DummyTracker + ) and isinstance(var, TensorVariable): + output_tensors.add(var) + # Find Tensor in print_stmts + for print_stmt in self._print_variables: + for var in print_stmt.flatten_items(): + if isinstance(var.tracker, DummyTracker) and isinstance( + var, TensorVariable + ): + output_tensors.add(var) + + # add inplace tensors into output tensors. + for inplace_tensor in self._inplace_tensors: + output_tensors.add(inplace_tensor) + + return output_tensors + + def restore_print_stmts(self, variables: list[VariableBase]): + for var in variables: + var.reconstruct( + self.pycode_gen, + use_tracker=False, + add_to_global_guarded_vars=False, + ) + + def restore_inplace_tensor(self, variables: list[VariableBase]): + for var in variables: + if not var.tracker.is_traceable(): + continue + var.reconstruct( + self.pycode_gen, + use_tracker=True, + add_to_global_guarded_vars=False, + ) + self.pycode_gen.gen_load_method( + "_inplace_assign" + ) # NOTE: paddle related logic. + var.reconstruct( + self.pycode_gen, + use_tracker=False, + add_to_global_guarded_vars=True, + ) + self.pycode_gen.gen_call_method(1) + self.pycode_gen.gen_pop_top() + + def restore_side_effects(self, variables: list[VariableBase]): + """ + Generate side effect recovery code for variables with side effects + + Args: + variables: Variables that may have side effects. 
+ """ + restorers: list[SideEffectRestorer] = [] + + for var in variables: + # skip inner variables + if not var.tracker.is_traceable() and not isinstance( + var, GlobalVariable + ): + continue + if isinstance(var, DictVariable): + restorers.append(DictSideEffectRestorer(var)) + elif isinstance(var, ListVariable): + restorers.append(ListSideEffectRestorer(var)) + else: + if isinstance(var, GlobalVariable): + for record in var.proxy.records[::-1]: + if isinstance(record, (MutationSet, MutationNew)): + restorers.append( + GlobalSetSideEffectRestorer( + record.key, + record.value, + ) + ) + elif isinstance(record, MutationDel): + restorers.append( + GlobalDelSideEffectRestorer(record.key) + ) + else: + for record in var.attr_proxy.records[::-1]: + if isinstance(record, (MutationSet, MutationNew)): + restorers.append( + ObjSetSideEffectRestorer( + var, + record.key, + record.value, + ) + ) + elif isinstance(record, MutationDel): + restorers.append( + ObjDelSideEffectRestorer( + var, + record.key, + ) + ) + + for restorer in restorers: + restorer.pre_gen(self.pycode_gen) + for restorer in restorers[::-1]: + restorer.post_gen(self.pycode_gen) diff --git a/python/paddle/jit/sot/opcode_translator/executor/guard.py b/python/paddle/jit/sot/opcode_translator/executor/guard.py new file mode 100644 index 0000000000000..b839c064f407d --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/guard.py @@ -0,0 +1,183 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import types +import weakref +from typing import TYPE_CHECKING, Any, Callable, TypeVar + +from ...profiler import EventGuard +from ...utils import InnerError, current_tmp_name_records, log, log_do + +Guard = Callable[[types.FrameType], bool] + +if TYPE_CHECKING: + from .variables import VariableBase + + CheckGuardInputT = TypeVar("CheckGuardInputT", bound=VariableBase) + +# NOTE(SigureMo): [How to write Stringify Guard?] +# 1. we should capture free variables manually, the string cannot capture free +# variables automatically. +# 2. Be aware that the comparison logic before and after stringify may be different. +# 3. we should compute as much as possible at "compile time" and encode the +# computation in the Guard string, rather than passing it to runtime to minimize +# runtime overhead. + + +class StringifyExpression: + """ + Used to store string based expressions for generating Guard. 
+    """
+
+    def __init__(self, str_expr, sub_exprs, free_vars):
+        expr = str_expr.format(*[arg.expr for arg in sub_exprs])
+        self.expr = current_tmp_name_records().add_tmp_var(expr)
+        self.debug_expr = str_expr.format(
+            *[arg.debug_expr for arg in sub_exprs]
+        )
+        self.free_vars = free_vars
+
+    def __post_init__(self):
+        self.check_expr(self.expr)
+
+    def check_expr(self, expr: str):
+        try:
+            pass
+            # ast.parse(expr)  # TODO(xiongkun): too slow
+        except SyntaxError as e:
+            raise InnerError(f"Invalid expression: {expr}") from e
+
+    def __hash__(self):
+        if self.free_vars:
+            return hash((self.debug_expr, id(self)))
+        else:
+            return hash(self.debug_expr)
+
+
+def union_free_vars(*free_vars: dict[str, Any]):
+    return {k: v for d in free_vars for k, v in d.items()}
+
+
+def make_guard(stringify_guards: list[StringifyExpression]) -> Guard:
+    """
+    Make a guard from a list of StringifyExpression.
+
+    For more design ideas, refer to the Stringify guard documentation for details.
+
+    Args:
+        stringify_guards: a list of StringifyExpression.
+    """
+    with EventGuard("make_guard"):
+        num_guards = len(stringify_guards)
+        if not num_guards:
+            guard = lambda frame: True
+            guard.expr = "lambda frame: True"
+            return guard
+
+        def analyse_expressions(stringify_exprs, tmp_names):
+            func_string = "def built_guard_fn(frame):\n"
+            lambda_string = "lambda frame: "
+            free_vars = {}
+
+            for k, v in tmp_names.items():
+                func_string += f"    {v} = {k}\n"
+
+            func_result = ""
+            for str_expr in stringify_exprs:
+                func_result += str_expr.expr + " and "
+                lambda_string += str_expr.debug_expr + " and "
+                free_vars = union_free_vars(free_vars, str_expr.free_vars)
+
+            func_string += f"    return {func_result[:-5]}"
+
+            return func_string, free_vars, lambda_string[:-5]
+
+        (
+            func_string,
+            free_vars,
+            lambda_string,
+        ) = analyse_expressions(
+            stringify_guards, current_tmp_name_records().tmp_names_record
+        )
+
+        exec(
+            func_string,
+            free_vars,
+        )
+
+        guard = free_vars['built_guard_fn']
+        log(3, f"[Guard]: {lambda_string}\n")
+        guard.lambda_expr = lambda_string
+        guard.expr = func_string
+        assert callable(guard), "guard must be callable."
+
+        return guard
+
+
+def support_weak_ref(obj):
+    if isinstance(obj, types.FunctionType):
+        return True
+    return False
+
+
+def check_guard(
+    fn: Callable[[CheckGuardInputT], list[StringifyExpression]]
+) -> Callable[[CheckGuardInputT], list[StringifyExpression]]:
+    def wrapper(self: CheckGuardInputT) -> list[StringifyExpression]:
+        assert (
+            self.tracker.is_traceable()
+        ), "Cannot make guard from a non-traceable guard variable."
+ + def guard_log(): + frame_value_tracer = self.tracker.trace_value_from_frame() + print( + f"[Guard]: guard_fn for {self}, tracker={self.tracker.__class__.__name__}, value={frame_value_tracer.expr}" + ) + + log_do(4, guard_log) + return fn(self) + + return wrapper + + +@check_guard +def object_equal_stringify_guard(self) -> list[StringifyExpression]: + frame_value_tracer = self.tracker.trace_value_from_frame() + + obj_free_var_name = f"__{self.id}" + weak_ref_obj = self.get_py_value() + if support_weak_ref(weak_ref_obj): + weak_ref_obj = weakref.ref(self.get_py_value()) + return [ + StringifyExpression( + f"{obj_free_var_name}() is not None and {{}} == {obj_free_var_name}()", + [frame_value_tracer], + union_free_vars( + frame_value_tracer.free_vars, + {obj_free_var_name: weak_ref_obj}, + ), + ) + ] + return [ + StringifyExpression( + f"{{}} == {obj_free_var_name}", + [frame_value_tracer], + union_free_vars( + frame_value_tracer.free_vars, + {obj_free_var_name: self.get_py_value()}, + ), + ) + ] diff --git a/python/paddle/jit/sot/opcode_translator/executor/instr_flag.py b/python/paddle/jit/sot/opcode_translator/executor/instr_flag.py new file mode 100644 index 0000000000000..1dd795439d459 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/instr_flag.py @@ -0,0 +1,36 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flags for instructions + + +class FORMAT_VALUE_FLAG: + FVC_MASK = 0x3 + FVC_NONE = 0x0 + FVC_STR = 0x1 + FVC_REPR = 0x2 + FVC_ASCII = 0x3 + FVS_MASK = 0x4 + FVS_HAVE_SPEC = 0x4 + + +class MAKE_FUNCTION_FLAG: + MF_HAS_CLOSURE = 0x08 + MF_HAS_ANNOTATION = 0x04 + MF_HAS_KWDEFAULTS = 0x02 + MF_HAS_DEFAULTS = 0x01 + + +class CALL_FUNCTION_EX_FLAG: + CFE_HAS_KWARGS = 0x01 diff --git a/python/paddle/jit/sot/opcode_translator/executor/mutable_data.py b/python/paddle/jit/sot/opcode_translator/executor/mutable_data.py new file mode 100644 index 0000000000000..d6bda43d42ef4 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/mutable_data.py @@ -0,0 +1,289 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
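The flag classes defined a bit earlier in instr_flag.py are plain bit masks tested against an instruction's oparg. A small, self-contained illustration (flag values copied from the patch; the helper function name is ours):

```python
# Decoding MAKE_FUNCTION oparg bits, mirroring the checks that the
# MAKE_FUNCTION handler in opcode_executor.py performs later in this patch.
from __future__ import annotations

MF_HAS_DEFAULTS = 0x01
MF_HAS_KWDEFAULTS = 0x02
MF_HAS_ANNOTATION = 0x04
MF_HAS_CLOSURE = 0x08


def describe_make_function(oparg: int) -> list[str]:
    parts = []
    if oparg & MF_HAS_CLOSURE:
        parts.append("closure tuple on stack")
    if oparg & MF_HAS_ANNOTATION:
        parts.append("annotations on stack")
    if oparg & MF_HAS_KWDEFAULTS:
        parts.append("kw-only defaults on stack")
    if oparg & MF_HAS_DEFAULTS:
        parts.append("positional defaults on stack")
    return parts


# ['closure tuple on stack', 'positional defaults on stack']
print(describe_make_function(0x09))
```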
+ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar + +if TYPE_CHECKING: + from typing_extensions import Concatenate, ParamSpec, TypeAlias + + P = ParamSpec("P") + R = TypeVar("R") + + MutableDataT = TypeVar("MutableDataT", bound="MutableData") + DataGetter: TypeAlias = Callable[[MutableDataT, Any], Any] + +InnerMutableDataT = TypeVar( + "InnerMutableDataT", bound="dict[str, Any] | list[Any]" +) + + +class Mutation: + ABBR: str + + +class MutationSet(Mutation): + """ + Setting a value. + This mutation is used for MutableDictLikeData and MutableListLikeData. + """ + + ABBR = "S" + + def __init__(self, key, value): + self.key = key + self.value = value + + def __repr__(self): + return f"MutationSet({self.key}, {self.value})" + + +class MutationDel(Mutation): + """ + Deleting a value. + This mutation is used for MutableDictLikeData and MutableListLikeData. + """ + + ABBR = "D" + + def __init__(self, key): + self.key = key + + def __repr__(self): + return f"MutationDel({self.key})" + + +class MutationNew(Mutation): + """ + Adding a new value. + This mutation is only used for MutableDictLikeData. + """ + + ABBR = "N" + + def __init__(self, key, value): + self.key = key + self.value = value + + def __repr__(self): + return f"MutationNew({self.key}, {self.value})" + + +class MutationInsert(Mutation): + """ + Inserting a value. + This mutation is only used for MutableListLikeData. + """ + + ABBR = "I" + + def __init__(self, index, value): + self.index = index + self.value = value + + def __repr__(self): + return f"MutationInsert({self.index}, {self.value})" + + +class MutationPermutate(Mutation): + """ + Permutating all the values. + This mutation is only used for MutableListLikeData. + """ + + ABBR = "P" + + def __init__(self, permutation): + self.permutation = permutation + + def __repr__(self): + return f"MutationPermutate({self.permutation})" + + +def record_mutation( + mutation_fn: Callable[Concatenate[MutableDataT, P], Mutation] +) -> Callable[Concatenate[MutableDataT, P], None]: + def wrapper(self, *args: P.args, **kwargs: P.kwargs): + mutation = mutation_fn(self, *args, **kwargs) + self.records.append(mutation) + + return wrapper + + +class MutableData(Generic[InnerMutableDataT]): + """ + An intermediate data structure between data and variable, it records all the mutations. 
+ """ + + read_cache: InnerMutableDataT + + class Empty: + def __repr__(self): + return "Empty()" + + def __init__(self, data: Any, getter: DataGetter): + self.original_data = data + self.getter = getter + self.records: list[Mutation] = [] + + def is_empty(self, value): + return isinstance(value, MutableData.Empty) + + @property + def version(self): + return len(self.records) + + @property + def has_changed(self): + return self.version != 0 + + def rollback(self, version: int): + assert version <= self.version + self.records[:] = self.records[:version] + + def get(self, key): + raise NotImplementedError() + + def set(self, key, value): + raise NotImplementedError() + + def apply(self, mutation: Mutation, write_cache: InnerMutableDataT): + raise NotImplementedError() + + def reproduce(self, version: int | None = None) -> InnerMutableDataT: + if version is None: + version = self.version + write_cache = self.read_cache.copy() + for mutation in self.records[:version]: + self.apply(mutation, write_cache) + return write_cache + + def __repr__(self) -> str: + records_abbrs = "".join([mutation.ABBR for mutation in self.records]) + return f"{self.__class__.__name__}({records_abbrs})" + + +class MutableDictLikeData(MutableData["dict[str, Any]"]): + def __init__(self, data: Any, getter: DataGetter): + super().__init__(data, getter) + self.read_cache = {} + + def clear_read_cache(self): + self.read_cache.clear() + + def get(self, key: Any): + # TODO(SigureMo): Optimize performance of this. + write_cache = self.reproduce(self.version) + if key not in write_cache: + self.read_cache[key] = self.getter(self, key) + return self.reproduce(self.version)[key] + + def get_all(self): + original_keys = list(self.original_data.keys()) + for mutation in self.records: + if isinstance(mutation, MutationNew): + original_keys.append(mutation.key) + elif isinstance(mutation, MutationDel): + original_keys.remove(mutation.key) + return {key: self.get(key) for key in original_keys} + + @record_mutation + def set(self, key: Any, value: Any) -> Mutation: + is_new = False + if self.is_empty(self.get(key)): + is_new = True + return ( + MutationSet(key, value) if not is_new else MutationNew(key, value) + ) + + @record_mutation + def delete(self, key): + return MutationDel(key) + + def apply(self, mutation: Mutation, write_cache: dict[str, Any]): + if isinstance(mutation, MutationNew): + write_cache[mutation.key] = mutation.value + elif isinstance(mutation, MutationSet): + write_cache[mutation.key] = mutation.value + elif isinstance(mutation, MutationDel): + write_cache[mutation.key] = MutableData.Empty() + else: + raise ValueError(f"Unknown mutation type {mutation}") + + def reproduce(self, version: int | None = None): + if version is None: + version = self.version + write_cache = self.read_cache.copy() + for mutation in self.records[:version]: + self.apply(mutation, write_cache) + return write_cache + + +class MutableListLikeData(MutableData["list[Any]"]): + def __init__(self, data: Any, getter: DataGetter): + super().__init__(data, getter) + self.read_cache = [ + self.getter(self, idx) for idx in range(len(self.original_data)) + ] + + def clear_read_cache(self): + self.read_cache[:] = [] + + @property + def length(self): + return len(self.reproduce()) + + def get(self, key): + write_cache = self.reproduce(self.version) + return write_cache[key] + + def get_all(self) -> list[Any]: + items = self.reproduce(self.version) + return items + + @record_mutation + def set(self, key: int, value: Any): + return 
MutationSet(self._regularize_index(key), value) + + @record_mutation + def delete(self, key: int): + return MutationDel(self._regularize_index(key)) + + @record_mutation + def insert(self, index: int, value: Any): + return MutationInsert(self._regularize_index(index), value) + + @record_mutation + def permutate(self, permutation: list[int]): + return MutationPermutate(permutation) + + def _regularize_index(self, index: int): + if index < 0: + index += self.length + return index + + def apply(self, mutation: Mutation, write_cache: list[Any]): + if isinstance(mutation, MutationSet): + write_cache[mutation.key] = mutation.value + elif isinstance(mutation, MutationDel): + write_cache[:] = ( + write_cache[: mutation.key] + write_cache[mutation.key + 1 :] + ) + elif isinstance(mutation, MutationInsert): + write_cache.insert(mutation.index, mutation.value) + elif isinstance(mutation, MutationPermutate): + write_cache[:] = [write_cache[i] for i in mutation.permutation] + else: + raise ValueError(f"Unknown mutation type {mutation}") diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py new file mode 100644 index 0000000000000..240ca8f1b889e --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py @@ -0,0 +1,2073 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import dis +import functools +import inspect +import operator +import sys +import traceback +import types +from dataclasses import dataclass +from itertools import chain +from typing import Any, Callable + +import opcode + +from ...profiler import EventGuard, event_register +from ...psdb import NO_BREAKGRAPH_CODES +from ...utils import ( + BreakGraphError, + FallbackError, + InnerError, + OrderedSet, + SotUndefinedVar, + log, + log_do, + min_graph_size, +) +from ..custom_code import CustomCode +from ..instruction_utils import ( + Instruction, + Space, + analysis_inputs, + analysis_used_names_with_space, + calc_stack_effect, + get_instructions, +) +from ..instruction_utils.opcode_info import JumpDirection, PopJumpCond +from .dispatch_functions import ( + operator_BAD, + operator_exception_match, + operator_in, + operator_is_none, + operator_is_not_none, + operator_not_in, +) +from .dispatcher import Dispatcher +from .function_graph import FunctionGraph +from .instr_flag import CALL_FUNCTION_EX_FLAG as CFE +from .instr_flag import FORMAT_VALUE_FLAG as FV +from .instr_flag import MAKE_FUNCTION_FLAG as MF +from .pycode_generator import PyCodeGen +from .tracker import ( + CellTracker, + ConstTracker, + DanglingTracker, + DummyTracker, + LocalTracker, +) +from .variable_stack import VariableStack +from .variables import ( + BuiltinVariable, + CellVariable, + ConstantVariable, + ContainerVariable, + DictVariable, + GlobalVariable, + ListVariable, + MethodVariable, + NullVariable, + SequenceIterVariable, + SliceVariable, + TensorVariable, + TupleVariable, + UserDefinedFunctionVariable, + VariableBase, + VariableFactory, +) + +SUPPORT_COMPARE_OP = { + ">": operator.gt, + "<": operator.lt, + ">=": operator.ge, + "<=": operator.le, + "==": operator.eq, + "!=": operator.ne, + "is not": operator.is_not, + "is": operator.is_, + "in": operator_in, + "not in": operator_not_in, + "exception match": operator_exception_match, + "BAD": operator_BAD, +} + + +@dataclass +class Stop: + state: str + + +def tos_op_wrapper(fn: Callable): + """ + A decorator function that wraps an opcode operation and applies certain functionality to it. + + Args: + fn: The opcode operation to be wrapped. + + Returns: + The wrapped opcode operation. + """ + nargs = len(inspect.signature(fn).parameters) + + @call_break_graph_decorator(push_n=1) + def inner(self: OpcodeExecutorBase, instr: Instruction): + args = self.stack.pop_n(nargs) + res = BuiltinVariable(fn, graph=self._graph, tracker=DanglingTracker())( + *args + ) + self.stack.push(res) + + return inner + + +def tos_inplace_op_wrapper(fn: Callable): + """ + A decorator function that wraps an inplace opcode operation and applies certain functionality to it. + + Args: + fn: The inplace opcode operation to be wrapped. + + Returns: + The wrapped inplace opcode operation. + + """ + + @call_break_graph_decorator(push_n=1) + def inner(self: OpcodeExecutorBase, instr: Instruction): + """ + Inner function that represents the wrapped inplace opcode operation. + + Args: + self: The instance of the OpcodeExecutorBase class. + instr: The instruction to be executed. + + """ + args = self.stack.pop_n(2) + res = BuiltinVariable(fn, graph=self._graph, tracker=DanglingTracker())( + *args + ) + res.debug_name = args[0].debug_name + self.stack.push(res) + + return inner + + +def pop_jump_if_op_wrapper(fns: list[Callable[[Any], Any]]): + """ + A decorator function that wraps a POP_JUMP_*_IF_* opcode operation and applies certain functionality to it. 
+ + Args: + fn: The condition function. + + Returns: + The wrapped POP_JUMP_*_IF_* opcode operation. + + """ + + @jump_break_graph_decorator + def inner(self: OpcodeExecutorBase, instr: Instruction): + """ + Inner function that represents the wrapped POP_JUMP_IF opcode operation. + + Args: + self: The instance of the OpcodeExecutorBase class. + instr: The instruction to be executed. + + """ + pred_obj = self.stack.pop() + + try: + self._graph.add_global_guarded_variable(pred_obj) + res = pred_obj + for fn in fns: + res = BuiltinVariable( + fn, graph=self._graph, tracker=DanglingTracker() + )(res) + + assert isinstance(res, ConstantVariable) + is_jump = res.get_py_value() + assert isinstance(is_jump, bool) + if is_jump: + assert instr.jump_to is not None + self.jump_to(instr.jump_to) + except BreakGraphError: + raise FallbackError( + f"Currently don't support predicate {pred_obj.__class__.__name__}" + ) + + return inner + + +def jump_break_graph_decorator(normal_jump: Callable): + """ + A decorator function that breaks off the graph when a JUMP-related instruction is encountered. + + Args: + normal_jump: The normal jump operation. + + Returns: + The wrapped jump operation. + + """ + + def inner(self: OpcodeExecutor, instr: Instruction): + result = self.stack.top + if isinstance(result, TensorVariable): + self.stack.pop() + # fallback when in OpcodeExecutor + # raise error in OpcodeInlineExecutor + log(3, "[BreakGraph] jump break graph, because if tensor\n") + self._break_graph_in_jump(result, instr) + return Stop(state="BreakGraph") + else: + return normal_jump(self, instr) + + return inner + + +def call_break_graph_decorator(push_n: int | Callable[[int | None], int]): + """ + A decorator function that breaks off the graph when a function CALL instruction is encountered. + + Args: + push_n: The number of arguments to be pushed onto the stack. + + Returns: + The decorated function. + + """ + + def decorate(call_fn: Callable): + @functools.wraps(call_fn) + def wrapper(self: OpcodeExecutor, instr: Instruction): + origin_stack = self.stack.copy() + try: + return call_fn(self, instr) + except BreakGraphError as e: + if self._code in NO_BREAKGRAPH_CODES: + raise InnerError( + f"{self._code.co_name} should not break graph, but got '{e}'" + ) + if isinstance(self, OpcodeExecutor): + log(3, f"[BreakGraph] call function Break graph: {e}\n") + self._break_graph_in_call(origin_stack, instr, push_n) + return Stop(state="BreakGraph") + else: + raise e + + return wrapper + + return decorate + + +def fallback_when_occur_error(fn: Callable): + """ + A decorator function that provides fallback behavior when an error occurs during graph processing. + + Args: + fn: The function to be wrapped. + + Returns: + The wrapped function. + + """ + + def inner(*args, **kwargs): + try: + return fn(*args, **kwargs) + except Exception as e: + raise FallbackError( + f'[Fallback] An exception occurred when processing break graph, fallback to dygraph, error message is: \n{type(e)} : {e}\n' + ) + + return inner + + +class OpcodeExecutorBase: + """ + Base class for executing opcode instructions. + + The OpcodeExecutorBase class provides methods and functionality to execute opcode instructions. + + If you want to learn more about Python instructions, see https://docs.python.org/3/library/dis.html for details. + + Args: + code: The bytecode of the function to be executed. + graph: The function graph. + + Attributes: + call_stack (list[OpcodeExecutorBase]): A list to keep track of the call stack. 
+ _stack (list[VariableBase]): The stack used for storing variables during execution. + _co_consts: List to store constants. + _locals (dict): Dictionary to store local variables. + _globals (dict): Dictionary to store global variables. + _builtins (dict): Dictionary to store built-in variables. + _lasti (int): Index of the last executed instruction. + _code (types.CodeType): The code object to be executed. + _instructions: Iterator of opcode instructions. + _graph (FunctionGraph): The function graph representing the code. + _current_line: The current line number of the execution. + new_code: Placeholder for new code (to be generated by PyCodeGen). + guard_fn: Placeholder for guard function. + _name (str): Name of the executor. + + """ + + call_stack: list[OpcodeExecutorBase] = [] + + @staticmethod + def validate_value(value): + assert isinstance( + value, VariableBase + ), f"value: {value}, type shoule be VariableBase(or derived), but get {type(value)}" + assert not isinstance(value.tracker, DanglingTracker) or isinstance( + value, (NullVariable, CellVariable) + ), f"dangling variable {value} should not be pushed into stack." + + def __init__(self, code: types.CodeType, graph: FunctionGraph): + OpcodeExecutorBase.call_stack.append(self) + # fake env for run, new env should be gened by PyCodeGen + self.stack = VariableStack(validate_value_func=self.validate_value) + self._co_consts = [] + self._locals = {} + self._globals: GlobalVariable = None # type: ignore + self._builtins = {} + self._cells = {} # position to put cells + self._lasti = 0 # idx of instruction list + self._code = code + self._current_line: int = -1 + self._instructions = get_instructions(self._code) + self._graph = graph + self.new_code: types.CodeType | None = None + self.guard_fn = None + self._name = "Executor" + self._call_shape: tuple[ + str, ... + ] | None = None # store kwnames for Python 3.11+ + self._prepare_virtual_env() + + self.stop_state = None + + def print_sir(self): + """ + Prints the Static Instruction Representation (SIR) in the executor. + + """ + print(self._graph.sir_ctx.TOS) + + def _prepare_virtual_env(self): + """ + Prepares the virtual environment for the executor. + + Raises: + NotImplementedError: If the method is not implemented. + + """ + raise NotImplementedError("Please implement virtual_env.") + + def _break_graph_in_jump(self, result, instr: Instruction): + """ + Breaks the graph in JUMP instructions. + + Args: + result: The execution result. + instr: The jump instruction. + + Raises: + NotImplementedError: If the method is not implemented. + + """ + raise NotImplementedError() + + def transform(self): + """ + Abstract method need to be implemented to symbolic translate each instruction. + + Raises: + NotImplementedError: If the method is not implemented. + + """ + raise NotImplementedError() + + def get_var(self, name: str): + """ + Gets the variable with the given name. + + Args: + name: The name of the variable. + + Returns: + The variable. + + Raises: + InnerError: If the variable cannot be found. 
+ + """ + if name in self._locals.keys(): + return self._locals[name] + elif name in self._cells.keys(): # in closure + return self._cells[name].cell_content() + elif name in self._globals.keys(): + return self._globals.get(name) + elif name in self._builtins.keys(): + return self._builtins[name] + else: + raise InnerError(f'Can not get var: {name}') + + def has_var(self, name: str, space: str = "any"): + if space == "any": + return name in set( + chain( + self._locals.keys(), + self._cells.keys(), + self._globals.keys(), + self._builtins.keys(), + ) + ) + elif space == Space.locals: + return name in self._locals + elif space == Space.cells: + return name in self._cells + elif space == Space.globals: + return name in set( + chain( + self._globals.keys(), + self._builtins.keys(), + ) + ) + return False + + def pop_call_stack_until_self(self): + """ + Pops the call stack until the current executor. + + """ + assert ( + self in OpcodeExecutorBase.call_stack + ), f"{self} not in call stack" + while OpcodeExecutorBase.call_stack.pop() is not self: + pass + + @staticmethod + def error_message_summary(original_error: Exception) -> str: + """ + Creates a summary of the error message during execution. + + Args: + original_error: The original error. + + Returns: + The summary error message. + + """ + indent = 2 * " " + message_lines = ["In simulate execution:", ""] + for current_simulator in OpcodeExecutorBase.call_stack: + code = current_simulator._code + current_line = current_simulator._current_line + lines, start = inspect.getsourcelines(code) + real_name = code.co_name + message_lines.append( + f"{indent} File \"{code.co_filename}\", line {current_line}, in {real_name}" + ) + if current_line != -1: + message_lines.append( + f"{indent} {lines[current_line-start].rstrip()}" + ) + error_message = traceback.format_exception_only( + type(original_error), original_error + ) + for line in error_message: + line = line.rstrip() + message_lines.append(f"{indent} {line}") + return "\n".join(message_lines) + + def run(self): + """ + Executes the opcode. + + """ + log(3, f"start execute opcode: {self._code}\n") + self._lasti = 0 + while True: + if self._lasti >= len(self._instructions): + raise InnerError("lasti out of range, InnerError.") + cur_instr = self._instructions[self._lasti] + self._lasti += 1 + is_stop = self.step(cur_instr) + if is_stop: + self.stop_state = is_stop.state + self.pop_call_stack_until_self() + break + + def step(self, instr: Instruction): + """ + Executes a single step of the opcode. + + Args: + instr: The instruction to be executed. + + Returns: + True if execution should stop, False otherwise. + + Raises: + FallbackError: If the opcode is not supported. + + """ + if instr.starts_line is not None: + self._current_line = instr.starts_line + if not hasattr(self, instr.opname): + raise FallbackError(f"opcode: {instr.opname} is not supported.") + log_message = f"[Translate {self._name}]: (line {self._current_line:>3}) {instr.opname:<12} {instr.argval}, stack is {self.stack}\n" + log(3, log_message) + code_file = self._code.co_filename + code_line = self._current_line + code_name = self._code.co_name + code_offset = instr.offset + from ..breakpoint import BreakpointManager + + if BreakpointManager().hit( + code_file, code_line, code_name, code_offset + ): + BreakpointManager().locate(self) + print(log_message) + breakpoint() # breakpoint for debug + + with EventGuard(f"{instr.opname}", event_level=1): + return getattr(self, instr.opname)(instr) # run single step. 
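run() and step() above implement a simple dispatch: each opcode name is looked up as a method on the executor, and a sentinel return value stops the loop. A stripped-down sketch of the same pattern (illustrative, not Paddle code; real handlers also maintain a variable stack):

```python
# Minimal opcode-dispatch loop in the style of OpcodeExecutorBase.run/step.
import dis
import types


class MiniExecutor:
    def __init__(self, code: types.CodeType):
        self.instructions = list(dis.get_instructions(code))
        self.lasti = 0

    def run(self):
        while self.lasti < len(self.instructions):
            instr = self.instructions[self.lasti]
            self.lasti += 1
            handler = getattr(self, instr.opname, None)
            if handler is None:
                # OpcodeExecutorBase raises FallbackError at this point.
                raise NotImplementedError(f"opcode {instr.opname} unsupported")
            if handler(instr) == "stop":
                break

    def RESUME(self, instr):
        pass  # no-op bookkeeping opcode (Python 3.11+)

    def RETURN_VALUE(self, instr):
        return "stop"  # plays the role of Stop(state=...) in the real executor
```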
+ + def indexof(self, instr: Instruction): + """ + Gets the index of the instruction. + + Args: + instr: The instruction. + + Returns: + The index of the instruction. + + """ + return self._instructions.index(instr) + + def jump_to(self, instr: Instruction): + """ + Jumps to the given instruction. + + Args: + instr: The instruction to jump to. + + """ + self._lasti = self.indexof(instr) + + def COPY(self, instr: Instruction): + assert isinstance(instr.arg, int) + self.stack.push(self.stack.peek[instr.arg]) + + def DUP_TOP(self, instr: Instruction): + self.stack.push(self.stack.top) + + def DUP_TOP_TWO(self, instr: Instruction): + for ref in self.stack.peek[:2]: + self.stack.push(ref) + + def ROT_N(self, instr: Instruction): + assert instr.argval is not None + self._rot_top_n(instr.argval) + + def _rot_top_n(self, n: int): + # a1 a2 a3 ... an <- TOS + # the stack changes to + # an a1 a2 a3 an-1 <- TOS + assert ( + len(self.stack) >= n + ), f"There are not enough elements on the stack. {n} is needed." + top = self.stack.pop() + self.stack.insert(n - 1, top) + + def POP_TOP(self, instr: Instruction): + self.stack.pop() + + def PUSH_NULL(self, instr: Instruction): + self.stack.push(NullVariable()) + + def ROT_TWO(self, instr: Instruction): + self._rot_top_n(2) + + def ROT_THREE(self, instr: Instruction): + self._rot_top_n(3) + + def ROT_FOUR(self, instr: Instruction): + self._rot_top_n(4) + + def RESUME(self, instr: Instruction): + # RESUME is a no-op, it just for internal tracing, debugging and optimization checks. + pass + + def SWAP(self, instr: Instruction): + assert isinstance(instr.arg, int) + self.stack.top, self.stack.peek[instr.arg] = ( + self.stack.peek[instr.arg], + self.stack.top, + ) + + # unary operators + UNARY_POSITIVE = tos_op_wrapper(operator.pos) + UNARY_NEGATIVE = tos_op_wrapper(operator.neg) + UNARY_NOT = tos_op_wrapper(operator.not_) + UNARY_INVERT = tos_op_wrapper(operator.invert) + + # binary operators + BINARY_POWER = tos_op_wrapper(operator.pow) + BINARY_MULTIPLY = tos_op_wrapper(operator.mul) + BINARY_MATRIX_MULTIPLY = tos_op_wrapper(operator.matmul) + BINARY_FLOOR_DIVIDE = tos_op_wrapper(operator.floordiv) + BINARY_TRUE_DIVIDE = tos_op_wrapper(operator.truediv) + BINARY_MODULO = tos_op_wrapper(operator.mod) + BINARY_ADD = tos_op_wrapper(operator.add) + BINARY_SUBTRACT = tos_op_wrapper(operator.sub) + BINARY_LSHIFT = tos_op_wrapper(operator.lshift) + BINARY_RSHIFT = tos_op_wrapper(operator.rshift) + BINARY_AND = tos_op_wrapper(operator.and_) + BINARY_OR = tos_op_wrapper(operator.or_) + BINARY_XOR = tos_op_wrapper(operator.xor) + + def BINARY_OP(self, instr: Instruction): + opname, _ = opcode._nb_ops[instr.arg] + opname = ( + opname.replace("NB_", "BINARY_") + .replace("BINARY_INPLACE", "INPLACE") + .replace("REMAINDER", "MODULO") + ) + return getattr(self, opname)(instr) + + @call_break_graph_decorator(push_n=1) + def BINARY_SUBSCR(self, instr: Instruction): + key = self.stack.pop() + container = self.stack.pop() + assert isinstance(key, VariableBase) + # TODO(xiongkun): getitem / getattr support key and attr as variable. + if isinstance(key, TensorVariable) and isinstance( + container, TensorVariable + ): + # NOTE(xiongkun): tensor[tensor] should support. 
+ output = self._graph.call_tensor_method( + "__getitem__", container, key + ) + self.stack.push(output) + return + + if isinstance(key, TensorVariable): + raise BreakGraphError( + f"Key is a TensorVariable in BINARY_SUBSCR, {container}[{key}]" + ) + + result = BuiltinVariable( + operator.getitem, self._graph, DanglingTracker() + )(container, key) + self.stack.push(result) + + # inplace operators + # paddle variable do not have inplace operators. For example when call `y **= x`, will call var.__pow__ + INPLACE_POWER = tos_inplace_op_wrapper(operator.ipow) + INPLACE_MULTIPLY = tos_inplace_op_wrapper(operator.imul) + INPLACE_MATRIX_MULTIPLY = tos_inplace_op_wrapper(operator.imatmul) + INPLACE_FLOOR_DIVIDE = tos_inplace_op_wrapper(operator.ifloordiv) + INPLACE_TRUE_DIVIDE = tos_inplace_op_wrapper(operator.itruediv) + INPLACE_MODULO = tos_inplace_op_wrapper(operator.imod) + INPLACE_ADD = tos_inplace_op_wrapper(operator.iadd) + INPLACE_SUBTRACT = tos_inplace_op_wrapper(operator.isub) + INPLACE_LSHIFT = tos_inplace_op_wrapper(operator.ilshift) + INPLACE_RSHIFT = tos_inplace_op_wrapper(operator.irshift) + INPLACE_AND = tos_inplace_op_wrapper(operator.iand) + INPLACE_OR = tos_inplace_op_wrapper(operator.ior) + INPLACE_XOR = tos_inplace_op_wrapper(operator.ixor) + + def NOP(self, instr: Instruction): + pass + + @call_break_graph_decorator(push_n=1) + def LOAD_ATTR(self, instr: Instruction): + attr_name = self._code.co_names[instr.arg] + attr_name_var = ConstantVariable.wrap_literal(attr_name, self._graph) + obj = self.stack.pop() + self.stack.push( + BuiltinVariable( + getattr, graph=self._graph, tracker=DanglingTracker() + )(obj, attr_name_var) + ) + + def LOAD_CONST(self, instr: Instruction): + var = self._co_consts[instr.arg] + self.stack.push(var) + + def MAKE_CELL(self, instr: Instruction): + self._locals[instr.argval] = self._cells[instr.argval] + + def LOAD_CLOSURE(self, instr: Instruction): + if sys.version_info >= (3, 11): + self.LOAD_FAST(instr) + return + namemap = self._code.co_cellvars + self._code.co_freevars + name = namemap[instr.arg] + self.stack.push(self._cells[name]) + + def LOAD_DEREF(self, instr: Instruction): + if sys.version_info >= (3, 11): + self.stack.push(self._locals[instr.argval].cell_content()) + return + namemap = self._code.co_cellvars + self._code.co_freevars + name = namemap[instr.arg] + self.stack.push(self._cells[name].cell_content()) + + def COPY_FREE_VARS(self, instr: Instruction): + for i in range(instr.arg): + freevar_name = self._code.co_freevars[i] + self._locals[freevar_name] = self._cells[freevar_name] + + def LOAD_FAST(self, instr: Instruction): + var = self._locals[instr.argval] + self.stack.push(var) + + def DELETE_FAST(self, instr: Instruction): + varname = self._code.co_varnames[instr.arg] + del self._locals[varname] + + def LOAD_GLOBAL(self, instr: Instruction): + namei: int = instr.arg + push_null = False + if sys.version_info >= (3, 11): + push_null = namei & 1 + namei >>= 1 + if push_null: + self.stack.push(NullVariable()) + name = self._code.co_names[namei] + if name in self._globals.keys(): + value = self._globals.get(name) + elif name in self._builtins.keys(): + value = self._builtins[name] + else: + raise InnerError(f"{name} not in globals and builtins") + self.stack.push(value) + + def LOAD_METHOD(self, instr: Instruction): + method_name = self._code.co_names[instr.arg] + method_name_var = ConstantVariable.wrap_literal( + method_name, self._graph + ) + obj = self.stack.pop() + + method = BuiltinVariable( + getattr, graph=self._graph, 
tracker=DanglingTracker() + )(obj, method_name_var) + + if isinstance(method, MethodVariable): + # bound method, push the unbound method and the self + self.stack.push(method.fn) + self.stack.push(obj) + else: + # unbound method, push the dummy and the function + self.stack.push(NullVariable()) + self.stack.push(method) + + @call_break_graph_decorator(push_n=0) + def STORE_ATTR(self, instr: Instruction): + obj = self.stack.pop() + val = self.stack.pop() + key = self._code.co_names[instr.arg] + key_var = ConstantVariable.wrap_literal(key, self._graph) + BuiltinVariable( + setattr, self._graph, DummyTracker([obj, key_var, val]) + )(obj, key_var, val) + + def DELETE_ATTR(self, instr: Instruction): + obj = self.stack.pop() + key = instr.argval + key_var = ConstantVariable.wrap_literal(key, self._graph) + BuiltinVariable(delattr, self._graph, DummyTracker([obj, key_var]))( + obj, key_var + ) + + def STORE_DEREF(self, instr: Instruction): + if sys.version_info >= (3, 11): + self._cells[instr.argval].set_value(self.stack.pop()) + self._locals[instr.argval] = self._cells[instr.argval] + return + namemap = self._code.co_cellvars + self._code.co_freevars + name = namemap[instr.arg] + self._cells[name].set_value(self.stack.pop()) + + def STORE_FAST(self, instr: Instruction): + """ + TODO: side effect may happen + """ + var = self.stack.pop() + name = self._code.co_varnames[instr.arg] + var.debug_name = name + self._locals[name] = var + + def STORE_GLOBAL(self, instr: Instruction): + var = self.stack.pop() + name = self._code.co_names[instr.arg] + var.debug_name = name + self._globals.set(name, var) + + def DELETE_GLOBAL(self, instr: Instruction): + self._globals.delete(self._code.co_names[instr.arg]) + + @call_break_graph_decorator(push_n=0) + def STORE_SUBSCR(self, instr: Instruction): + key = self.stack.pop() + container = self.stack.pop() + value = self.stack.pop() + assert isinstance(key, VariableBase) + self._graph.add_global_guarded_variable(key) + if isinstance(key, TensorVariable): + raise BreakGraphError( + f"Key is a TensorVariable in STORE_SUBSCR, {container}[{key}] = {value}" + ) + # TODO(xiongkun): support tensor[tensor] = tensor, dy2static is not the same with dygraph. + container[key.get_py_value()] = value + value.debug_name = f"{container.debug_name}[{key.debug_name}]" + + def DELETE_SUBSCR(self, instr: Instruction): + key = self.stack.pop() + container = self.stack.pop() + assert isinstance(key, VariableBase) + self._graph.add_global_guarded_variable(key) + BuiltinVariable(operator.delitem, self._graph, DanglingTracker())( + container, key + ) + + def BUILD_LIST(self, instr: Instruction): + list_size = instr.arg + assert list_size <= len( + self.stack + ), f"OpExecutor want BUILD_LIST with size {list_size}, but current stack do not have enough elems." + val_list = self.stack.pop_n(list_size) + self.stack.push( + ListVariable( + val_list, graph=self._graph, tracker=DummyTracker(val_list) + ) + ) + + def BUILD_TUPLE(self, instr: Instruction): + tuple_size = instr.arg + assert tuple_size <= len( + self.stack + ), f"OpExecutor want BUILD_TUPLE with size {tuple_size}, but current stack do not have enough elems." + val_tuple = self.stack.pop_n(tuple_size) + self.stack.push( + TupleVariable( + tuple(val_tuple), + graph=self._graph, + tracker=DummyTracker(val_tuple), + ) + ) + + def BUILD_STRING(self, instr: Instruction): + count = instr.arg + assert count <= len( + self.stack + ), f"OpExecutor want BUILD_STRING with size {count}, but current stack do not have enough elems." 
+ str_list = self.stack.pop_n(count) + new_str = '' + for s in str_list: + assert s.get_py_type() is str + new_str += s.get_py_value() + self.stack.push( + ConstantVariable(new_str, self._graph, DummyTracker(str_list)) + ) + + @call_break_graph_decorator(push_n=1) + def BUILD_SLICE(self, instr: Instruction): + if instr.arg == 3: + step = self.stack.pop() + else: + step = ConstantVariable.wrap_literal(None, self._graph) + stop = self.stack.pop() + start = self.stack.pop() + + self.stack.push( + SliceVariable( + slice(start, stop, step), + graph=self._graph, + tracker=DummyTracker([start, stop, step]), + ) + ) + + def build_map( + self, keys: list[VariableBase], values: list[VariableBase] + ) -> VariableBase: + built_map = {} + for key, value in zip(keys, values): + assert isinstance(key, VariableBase) + # Add key to global guarded variable to avoid missing the key guard + self._graph.add_global_guarded_variable(key) + key = key.get_py_value() + built_map[key] = value + return DictVariable( + built_map, + graph=self._graph, + tracker=DummyTracker(keys + values), + ) + + def BUILD_MAP(self, instr: Instruction): + map_size = instr.arg + assert map_size * 2 <= len( + self.stack + ), f"OpExecutor want BUILD_MAP with size {map_size} * 2, but current stack do not have enough elems." + val_for_dict = self.stack.pop_n(map_size * 2) + keys = val_for_dict[::2] + values = val_for_dict[1::2] + self.stack.push(self.build_map(keys, values)) + + def BUILD_CONST_KEY_MAP(self, instr: Instruction): + map_size = instr.arg + assert map_size + 1 <= len( + self.stack + ), f"OpExecutor want BUILD_CONST_KEY_MAP with size {map_size} + 1, but current stack do not have enough elems." + keys = self.stack.pop().get_items() + assert len(keys) == map_size + values = self.stack.pop_n(map_size) + self.stack.push(self.build_map(keys, values)) + + def build_seq_unpack(self, instr: Instruction): + oparg = instr.arg + assert isinstance(oparg, int) + unpack_values = self.stack.pop_n(oparg) + + retval = [] + for item in unpack_values: + assert isinstance(item, (TupleVariable, ListVariable)) + retval.extend(item.get_wrapped_items()) + + if instr.opname in { + "BUILD_TUPLE_UNPACK_WITH_CALL", + "BUILD_TUPLE_UNPACK", + }: + retval = tuple(retval) + + self.stack.push( + VariableFactory.from_value( + retval, self._graph, DummyTracker(unpack_values) + ) + ) + + def BUILD_TUPLE_UNPACK_WITH_CALL(self, instr: Instruction): + self.build_seq_unpack(instr) + + def BUILD_TUPLE_UNPACK(self, instr: Instruction): + self.build_seq_unpack(instr) + + def BUILD_LIST_UNPACK(self, instr: Instruction): + self.build_seq_unpack(instr) + + def BUILD_MAP_UNPACK(self, instr: Instruction): + oparg = instr.arg + assert isinstance(oparg, int) + unpack_values = self.stack.pop_n(oparg) + + retval = {} + for item in unpack_values: + assert item.get_py_type() is dict + retval.update(item.get_wrapped_items()) + + self.stack.push( + VariableFactory.from_value( + retval, self._graph, DummyTracker(unpack_values) + ) + ) + + def BUILD_MAP_UNPACK_WITH_CALL(self, instr: Instruction): + oparg = instr.arg + assert isinstance(oparg, int) + unpack_values = self.stack.pop_n(oparg) + + retval = {} + for item in unpack_values: + assert item.get_py_type() is dict + wrapped_item = item.get_wrapped_items() + if wrapped_item.items() & retval.items(): + raise InnerError( + "BUILD_MAP_UNPACK_WITH_CALL found repeated key." 
+ ) + retval.update(wrapped_item) + + self.stack.push( + VariableFactory.from_value( + retval, self._graph, DummyTracker(unpack_values) + ) + ) + + def PRECALL(self, instr: Instruction): + assert isinstance(instr.arg, int) + is_method_layout = not isinstance( + self.stack.peek[instr.arg + 2], NullVariable + ) + nargs = instr.arg + int(is_method_layout) + method = self.stack.peek[nargs + 1] + if not is_method_layout and isinstance(method, MethodVariable): + unbound_method = method.fn + self_var = method.bound_instance + self.stack.peek[nargs + 1] = self_var + self.stack.peek[nargs + 2] = unbound_method + + def KW_NAMES(self, instr: Instruction): + assert self._call_shape is None + assert isinstance(instr.arg, int) + self._call_shape = self._co_consts[instr.arg].get_py_value() + + @call_break_graph_decorator(push_n=1) + def CALL(self, instr: Instruction): + assert isinstance(instr.arg, int) + assert instr.arg + 2 <= len(self.stack) + is_method = not isinstance(self.stack.peek[instr.arg + 2], NullVariable) + total_args = instr.arg + int(is_method) + kwnames = self._call_shape if self._call_shape is not None else [] + n_kwargs = len(kwnames) + n_positional_args = total_args - n_kwargs + kwargs_list = self.stack.pop_n(n_kwargs) + kwargs = dict(zip(kwnames, kwargs_list)) + args = self.stack.pop_n(n_positional_args) + fn = self.stack.pop() + if not is_method: + # pop the NULL variable + self.stack.pop() + self.stack.push(fn(*args, **kwargs)) + self._call_shape = None + + @call_break_graph_decorator(push_n=1) + def CALL_FUNCTION(self, instr: Instruction): + assert isinstance(instr.arg, int) + n_args = instr.arg + assert isinstance(n_args, int) + args = self.stack.pop_n(n_args) + kwargs = {} + fn = self.stack.pop() + ret = fn(*args, **kwargs) + self.stack.push(ret) + + @call_break_graph_decorator(push_n=1) + def CALL_FUNCTION_KW(self, instr: Instruction): + n_args = instr.arg + assert n_args + 2 <= len(self.stack) + + kwargs_keys = self.stack.pop() + assert isinstance(kwargs_keys, TupleVariable) + assert len(kwargs_keys) > 0 + kwargs_keys = [ + x.get_py_value() if isinstance(x, VariableBase) else x + for x in kwargs_keys.get_py_value() + ] + + # split arg_list to args and kwargs + arg_list = self.stack.pop_n(n_args) + args = arg_list[: -len(kwargs_keys)] + kwargs_values = arg_list[-len(kwargs_keys) :] + kwargs = dict(zip(kwargs_keys, kwargs_values)) + + fn = self.stack.pop() + ret = fn(*args, **kwargs) + self.stack.push(ret) + + @call_break_graph_decorator(push_n=1) + def CALL_FUNCTION_EX(self, instr: Instruction): + flag = instr.arg + if flag & CFE.CFE_HAS_KWARGS: + kwargs_variable = self.stack.pop() + assert isinstance(kwargs_variable, DictVariable) + kwargs = kwargs_variable.get_wrapped_items() + else: + kwargs = {} + + args_variable = self.stack.pop() + assert isinstance(args_variable, (TupleVariable, ListVariable)) + args = args_variable.get_wrapped_items() + + fn = self.stack.pop() + if sys.version_info >= (3, 11): + null = self.stack.pop() + assert isinstance(null, NullVariable) + ret = fn(*args, **kwargs) + self.stack.push(ret) + + @call_break_graph_decorator(push_n=1) + def CALL_METHOD(self, instr: Instruction): + n_args = instr.arg + assert isinstance(n_args, int) + args = self.stack.pop_n(n_args) + self_var = self.stack.pop() + method = self.stack.pop() + if isinstance(method, NullVariable): + method = self_var + else: + args = [self_var] + args + self.stack.push(method(*args)) + + @call_break_graph_decorator( + push_n=1 + ) # call instance, in, not in may call 
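The kwargs slicing in CALL_FUNCTION_KW can be rehearsed in plain Python; `split_call_args` below is a hypothetical helper (not in the patch) showing the same convention: the trailing `len(kwargs_keys)` stack values are the keyword arguments.

```python
def split_call_args(arg_list, kwargs_keys):
    # CALL_FUNCTION_KW guarantees a non-empty key tuple on top of the stack.
    assert len(kwargs_keys) > 0
    args = arg_list[: -len(kwargs_keys)]
    kwargs = dict(zip(kwargs_keys, arg_list[-len(kwargs_keys):]))
    return args, kwargs

# f(1, 2, x=3, y=4) pushes [1, 2, 3, 4] plus the const tuple ('x', 'y')
args, kwargs = split_call_args([1, 2, 3, 4], ("x", "y"))
assert args == [1, 2] and kwargs == {"x": 3, "y": 4}
```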
TensorVariable.get_py_value, which raise BreakGraphError + def COMPARE_OP(self, instr: Instruction): + op = dis.cmp_op[instr.arg] + right, left = self.stack.pop(), self.stack.pop() + self.stack.push( + BuiltinVariable( + SUPPORT_COMPARE_OP[op], self._graph, DanglingTracker() + )(left, right) + ) + + @call_break_graph_decorator(push_n=1) + def IS_OP(self, instr: Instruction): + # It will only be 0 or 1 + assert instr.arg == 0 or instr.arg == 1 + right, left = self.stack.pop(), self.stack.pop() + op = "is" if instr.arg == 0 else "is not" + self.stack.push( + BuiltinVariable( + SUPPORT_COMPARE_OP[op], self._graph, DanglingTracker() + )(left, right) + ) + + def MAKE_FUNCTION(self, instr: Instruction): + if sys.version_info < (3, 11): + fn_name = self.stack.pop() + codeobj = self.stack.pop() + if sys.version_info >= (3, 11): + # MAKE_FUNCTION behavior actually changed in 3.11, see + # https://github.com/python/cpython/pull/93189/ + assert hasattr(codeobj.value, "co_qualname") + fn_name = ConstantVariable( + codeobj.value.co_qualname, self._graph, DummyTracker([codeobj]) + ) + + global_dict = self._globals.get_value() + + related_list = [fn_name, codeobj] + + flag = instr.arg + if flag & MF.MF_HAS_CLOSURE: + # closure should be a tuple of Variables + closure_variable = self.stack.pop() + assert isinstance(closure_variable, TupleVariable) + closure = [] + for item in closure_variable.get_wrapped_items(): + closure.append(types.CellType()) + closure[-1].cell_contents = item + closure = tuple(closure) + else: + closure = () + + if flag & MF.MF_HAS_ANNOTATION: + # can not set annotation in python env, skip it + related_list.append(self.stack.pop()) + + if flag & MF.MF_HAS_KWDEFAULTS: + raise FallbackError( + "Found need func_kwdefaults when MAKE_FUNCTION." + ) + + if flag & MF.MF_HAS_DEFAULTS: + ''' + default_args should have tracker too, like: + + def f(x): + def g(z=x): + pass + ''' + default_args_variable = self.stack.pop() + assert isinstance(default_args_variable, TupleVariable) + related_list.append(default_args_variable) + default_args = tuple(default_args_variable.get_wrapped_items()) + else: + default_args = () + + new_fn = types.FunctionType( + codeobj.get_py_value(), + global_dict, + fn_name.get_py_value(), + default_args, + closure, + ) + self.stack.push( + UserDefinedFunctionVariable( + new_fn, self._graph, DummyTracker(related_list) + ) + ) + + def GET_ITER(self, instr: Instruction): + source_obj = self.stack.pop() + iter_variable = BuiltinVariable(iter, self._graph, DanglingTracker())( + source_obj + ) + self.stack.push(iter_variable) + + def JUMP_ABSOLUTE(self, instr: Instruction): + assert instr.jump_to is not None + self.jump_to(instr.jump_to) + + def JUMP_FORWARD(self, instr: Instruction): + self.JUMP_ABSOLUTE(instr) + + def JUMP_BACKWARD(self, instr: Instruction): + # TODO: check interrupt + self.JUMP_ABSOLUTE(instr) + + def JUMP_BACKWARD_NO_INTERRUPT(self, instr: Instruction): + self.JUMP_ABSOLUTE(instr) + + @call_break_graph_decorator(push_n=1) + def CONTAINS_OP(self, instr: Instruction): + # It will only be 0 or 1 + assert instr.arg == 0 or instr.arg == 1 + right, left = self.stack.pop(), self.stack.pop() + op = "in" if instr.arg == 0 else "not in" + self.stack.push( + BuiltinVariable( + SUPPORT_COMPARE_OP[op], self._graph, DanglingTracker() + )(left, right) + ) + + @jump_break_graph_decorator + def JUMP_IF_FALSE_OR_POP(self, instr: Instruction): + pred_obj = self.stack.top + if isinstance(pred_obj, (ConstantVariable, ContainerVariable)): + 
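MAKE_FUNCTION ends in a plain `types.FunctionType(...)` call; the same construction can be done by hand on Python 3.8+ (a standalone sketch, names invented):

```python
import types

def outer():
    captured = "from cell"
    def inner():
        return captured
    return inner

template = outer()

# Rebuild an equivalent function the way the handler does: a code object
# with one free variable plus a matching tuple of cells.
cell = types.CellType()
cell.cell_contents = "rebuilt"
clone = types.FunctionType(
    template.__code__,  # code object, co_freevars == ('captured',)
    globals(),          # __globals__
    "clone",            # __name__
    (),                 # default args
    (cell,),            # closure: one cell per free variable
)
print(clone())  # rebuilt
```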
self._graph.add_global_guarded_variable(pred_obj) + is_jump = not bool(pred_obj) + if is_jump: + assert instr.jump_to is not None + self.jump_to(instr.jump_to) + else: + self.stack.pop() + return + raise FallbackError( + "Currently don't support predicate a non-const / non-tensor obj." + ) + + @jump_break_graph_decorator + def JUMP_IF_TRUE_OR_POP(self, instr: Instruction): + pred_obj = self.stack.top + if isinstance(pred_obj, (ConstantVariable, ContainerVariable)): + self._graph.add_global_guarded_variable(pred_obj) + is_jump = bool(pred_obj) + if is_jump: + assert instr.jump_to is not None + self.jump_to(instr.jump_to) + else: + self.stack.pop() + return + raise FallbackError( + "Currently don't support predicate a non-const / non-tensor obj." + ) + + POP_JUMP_IF_FALSE = pop_jump_if_op_wrapper([bool, operator.not_]) + POP_JUMP_FORWARD_IF_FALSE = POP_JUMP_IF_FALSE + POP_JUMP_BACKWARD_IF_FALSE = POP_JUMP_IF_FALSE + + POP_JUMP_IF_TRUE = pop_jump_if_op_wrapper([bool]) + POP_JUMP_FORWARD_IF_TRUE = POP_JUMP_IF_TRUE + POP_JUMP_BACKWARD_IF_TRUE = POP_JUMP_IF_TRUE + + POP_JUMP_FORWARD_IF_NONE = pop_jump_if_op_wrapper([operator_is_none]) + POP_JUMP_BACKWARD_IF_NONE = POP_JUMP_FORWARD_IF_NONE + + POP_JUMP_FORWARD_IF_NOT_NONE = pop_jump_if_op_wrapper( + [operator_is_not_none] + ) + POP_JUMP_BACKWARD_IF_NOT_NONE = POP_JUMP_FORWARD_IF_NOT_NONE + + @call_break_graph_decorator(push_n=lambda arg: arg) + def UNPACK_SEQUENCE(self, instr: Instruction): + sequence = self.stack.pop() + seq_iter = BuiltinVariable(iter, self._graph, DanglingTracker())( + sequence + ) + unpacked = [] + for _ in range(instr.arg): + unpacked.append(seq_iter.next()) + for item in reversed(unpacked): + self.stack.push(item) + + def UNPACK_EX(self, instr: Instruction): + getitem = BuiltinVariable( + operator.getitem, self._graph, DanglingTracker() + ) + assert instr.arg is not None + sequence = self.stack.pop() + if not isinstance( + sequence, (ListVariable, TupleVariable, TensorVariable) + ): + raise FallbackError(f"Unpack {sequence} is not implemented.") + + if instr.argval >= 256: + # NOTE: If the number of unpacked variables exceeds 256, python will report an error like: + # SyntaxError: too many expressions in star-unpacking assignmen, + # so if the number of unpacked variables exceeds 256, it will be treated as the following case. + # a, b, *c, d = e + front_nums = instr.arg & 0xFF + back_nums = instr.arg >> 8 + assert ( + len(sequence) >= front_nums + back_nums + ), f"Want unpack {sequence} to {front_nums + back_nums}, but {len(sequence)} is smaller than {front_nums + back_nums}." + + for i in range( + len(sequence) - 1, len(sequence) - back_nums - 1, -1 + ): + self.stack.push(getitem(sequence, i)) + + slice_var = SliceVariable( + slice(front_nums, len(sequence) - back_nums - 1), + self._graph, + DummyTracker([sequence]), + ) + else: + # a, b, c, *d = e + assert ( + len(sequence) >= instr.arg + ), f"Want unpack {sequence} to {instr.arg}, but {len(sequence)} is smaller than {instr.arg}." 
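UNPACK_SEQUENCE pushes the unpacked items right-to-left so the subsequent stores pop them in source order, which is exactly why the handler above pushes `reversed(unpacked)`; `dis` shows the pairing:

```python
import dis

def unpack(pair):
    a, b = pair
    return a, b

dis.dis(unpack)
# UNPACK_SEQUENCE 2 leaves the first item on top of the stack, so the
# following STORE_FAST a / STORE_FAST b consume them in source order.
```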
+ + slice_obj = slice(instr.arg, None) + slice_var = SliceVariable( + slice_obj, self._graph, ConstTracker(slice_obj) + ) + front_nums = instr.arg + self.stack.push(getitem(sequence, slice_var)) + for i in range(front_nums - 1, -1, -1): + self.stack.push(getitem(sequence, i)) + + def FORMAT_VALUE(self, instr: Instruction): + flag = instr.arg + assert flag is not None + which_conversion = flag & FV.FVC_MASK + have_fmt_spec = bool((flag & FV.FVS_MASK) == FV.FVS_HAVE_SPEC) + + fmt_spec = self.stack.pop().get_py_value() if have_fmt_spec else "" + value = self.stack.pop() + + if which_conversion == FV.FVC_NONE: + convert_fn = None + elif which_conversion == FV.FVC_STR: + convert_fn = "__str__" + elif which_conversion == FV.FVC_REPR: + convert_fn = "__repr__" + elif which_conversion == FV.FVC_ASCII: + convert_fn = "__ascii__" + else: + raise InnerError( + f"Unexpected conversion flag {flag} for FORMAT_VALUE" + ) + + # different type will lead to different Tracker, so call self.stack.push in different branch + if isinstance(value, ConstantVariable): + result = value.get_py_value() + if convert_fn is not None: + result = getattr(result, convert_fn)(result) + + if not isinstance(result, str) or fmt_spec != "": + result = format(result, fmt_spec) + + self.stack.push( + ConstantVariable(result, self._graph, DummyTracker([value])) + ) + else: + raise FallbackError(f"Do not support format {type(value)} now") + + # NOTE: This operation will generate SideEffects, and the mechanism has not been completed yet + def DICT_UPDATE(self, instr: Instruction): + dict_value = self.stack.pop() + assert isinstance(instr.arg, int) + BuiltinVariable(dict.update, self._graph, tracker=DanglingTracker())( + self.stack.peek[instr.arg], dict_value + ) + + def DICT_MERGE(self, instr: Instruction): + dict_value = self.stack.pop() + assert isinstance(instr.arg, int) + for key in dict_value.get_wrapped_items().keys(): + result = ( + self.stack.peek[instr.arg].get_wrapped_items().get(key, None) + ) + if result is not None: + raise InnerError( + f"got multiple values for keyword argument '{key}'" + ) + BuiltinVariable(dict.update, self._graph, tracker=DanglingTracker())( + self.stack.peek[instr.arg], dict_value + ) + + def LIST_APPEND(self, instr: Instruction): + list_value = self.stack.pop() + assert isinstance(instr.arg, int) + BuiltinVariable(list.append, self._graph, tracker=DanglingTracker())( + self.stack.peek[instr.arg], list_value + ) + + def MAP_ADD(self, instr: Instruction): + key, value = self.stack.pop_n(2) + assert isinstance(instr.arg, int) + BuiltinVariable(operator.setitem, self._graph, DanglingTracker())( + self.stack.peek[instr.arg], key, value + ) + + def LIST_EXTEND(self, instr: Instruction): + list_value = self.stack.pop() + assert isinstance(instr.arg, int) + BuiltinVariable(list.extend, self._graph, tracker=DanglingTracker())( + self.stack.peek[instr.arg], list_value + ) + + def LIST_TO_TUPLE(self, instr: Instruction): + list_value = self.stack.pop() + self.stack.push( + TupleVariable( + list_value.get_wrapped_items(), + self._graph, + DummyTracker([list_value]), + ) + ) + + +class OpcodeExecutor(OpcodeExecutorBase): + """ + A class that represents an executor for opcode operations. + + Args: + frame: The frame object. 
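The FORMAT_VALUE flag decoding above matches what CPython (3.11 and earlier) emits for f-strings: the low bits select the conversion, and the FVS bit records that a format spec was pushed. For a constant, the folded result is equivalent to an ordinary `format` call:

```python
import dis

def fmt(x):
    return f"{x!r:>8}"

dis.dis(fmt)
# LOAD_FAST x; LOAD_CONST '>8'; FORMAT_VALUE 6  (FVC_REPR | FVS_HAVE_SPEC)
# The constant-folding branch is equivalent to:
print(format(repr("hi"), ">8"))  # "    'hi'"
```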
+ + """ + + def __init__(self, frame: types.FrameType, **kwargs): + graph = FunctionGraph(frame, **kwargs) + self._frame = frame + self._name = "Executor" + self.call_stack[:] = [] + super().__init__(frame.f_code, graph) + Dispatcher.graph = graph + + def cleanup(self): + self._graph.pycode_gen = None + Dispatcher.graph = None + + @event_register("OpcodeExecutor: _prepare_virtual_env", event_level=2) + def _prepare_virtual_env(self): + """ + Prepare the virtual environment for execution by adding variables from locals, globals, builtins, and constants. + + """ + log( + 3, + f"[Executor] code options: co_cellvars={self._frame.f_code.co_cellvars}\n", + ) + free_or_cell_vars = ( + self._frame.f_code.co_cellvars + self._frame.f_code.co_freevars + ) + for name, value in self._frame.f_locals.items(): + tracker = ( + CellTracker(name) + if name in free_or_cell_vars + else LocalTracker(name) + ) + self._locals[name] = VariableFactory.from_value( + value, self._graph, tracker, debug_name=name + ) + + for name in free_or_cell_vars: + # create a cell for each variable. + self._cells[name] = CellVariable() # put in cells. + if name in self._locals: + self._cells[name].set_value(self._locals[name]) + + self._globals = GlobalVariable( + self._frame.f_globals, + self._graph, + DanglingTracker(), + ) + + self._builtins = self._graph._builtins + + for value in self._code.co_consts: + self._co_consts.append( + VariableFactory.from_value( + value, self._graph, ConstTracker(value) + ) + ) + + def _create_resume_fn(self, index, stack_size=0): + """ + Create a resume function and its inputs at the specified index. + + Args: + index: The index at which the resume function is created. + stack_size: The size of the stack. + + Returns: + The resume function and its inputs. + + """ + pycode_gen = PyCodeGen(self._frame) + fn, inputs = pycode_gen.gen_resume_fn_at(index, stack_size) + return fn, inputs + + @fallback_when_occur_error + def _break_graph_in_jump(self, result: VariableBase, instr: Instruction): + """ + Break the graph at a JUMP instruction. + + Args: + result: The result variable of the jump instruction. + instr: The jump instruction. + + """ + self._graph.add_global_guarded_variable(result) + stack_size = len(self.stack) + if_fn, if_inputs = self._create_resume_fn( + self.indexof(instr) + 1, stack_size + ) + else_fn, else_inputs = self._create_resume_fn( + self.indexof(instr.jump_to), stack_size + ) + + # gen call static fn opcode + inputs_name = if_inputs | else_inputs + inputs_var = [ + self.get_var(name) + for name in inputs_name + if self.get_var(name) is not result + ] + ret_vars = [ + result, + ] + inputs_var + # Collect all the to store variables. 
+ store_vars = [] + for stack_arg in self.stack: + store_vars.append(stack_arg) + for name in inputs_name: + store_vars.append(self.get_var(name)) + + var_loader = self._graph.start_compile_with_name_store( + ret_vars, store_vars + ) + # only pop the input of if/else resume fn, and keep the bool tensor result on the stack + for _ in inputs_var: + self._graph.pycode_gen.gen_pop_top() + + # gen call if/else resume fn opcode + if if_fn is not None: + self._graph.pycode_gen.gen_load_object( + if_fn, if_fn.__code__.co_name + ) + insert_index = len(self._graph.pycode_gen._instructions) - 1 + for stack_arg in self.stack: + var_loader.load(stack_arg) + for name in if_inputs: + var_loader.load(self.get_var(name)) + self._graph.pycode_gen.gen_call_function( + argc=if_fn.__code__.co_argcount, + ) + self._graph.pycode_gen.gen_return() + else: + insert_index = len(self._graph.pycode_gen._instructions) - 1 + self._graph.pycode_gen.gen_return() + + if else_fn is not None: + self._graph.pycode_gen.gen_load_object( + else_fn, else_fn.__code__.co_name + ) + jump_to = self._graph.pycode_gen._instructions[-1] + for stack_arg in self.stack: + var_loader.load(stack_arg) + for name in else_inputs: + var_loader.load(self.get_var(name)) + self._graph.pycode_gen.gen_call_function( + argc=else_fn.__code__.co_argcount, + ) + self._graph.pycode_gen.gen_return() + else: + self._graph.pycode_gen.gen_return() + jump_to = self._graph.pycode_gen._instructions[-1] + + # gen jump opcode + self._graph.pycode_gen._insert_instr( + insert_index, instr.opname, jump_to=jump_to + ) + + self.new_code = self._graph.pycode_gen.gen_pycode() + self.guard_fn = self._graph.guard_fn + + @fallback_when_occur_error + def _break_graph_in_call( + self, + origin_stack: VariableStack, + instr: Instruction, + push_n: int | Callable[[int | None], int], + ): + """ + Break the graph at a CALL instruction. + + Args: + origin_stack: The original stack. + instr: The call instruction. + push_n: The number of elements to be pushed onto the stack. + + """ + push_n = push_n(instr.arg) if callable(push_n) else push_n + index = self.indexof(instr) + self.stack = origin_stack + + # gen call static fn opcode + ret_vars = [ + arg + for arg in self.stack + if isinstance(arg, (TensorVariable, ContainerVariable)) + ] + resume_input_name = analysis_inputs(self._instructions, index + 1) + ret_vars = ret_vars + [ + self.get_var(name) + for name in resume_input_name + if self.get_var(name) not in ret_vars + ] + + # Collect all the to store variables. + store_vars = [] + for stack_arg in self.stack: + store_vars.append(stack_arg) + for name in resume_input_name: + store_vars.append(self.get_var(name)) + var_loader = self._graph.start_compile_with_name_store( + ret_vars, store_vars + ) + + for _ in ret_vars: + self._graph.pycode_gen.gen_pop_top() + + # gen graph break call fn opcode + stack_effect = calc_stack_effect(instr) + pop_n = push_n - stack_effect + + for i, stack_arg in enumerate(self.stack): + # Avoid passing NULL as a parameter to the resume function + if ( + isinstance(stack_arg, NullVariable) + and i < len(self.stack) - pop_n + ): + self._graph.pycode_gen.gen_load_object( + NullVariable(), f'null_var_{i}', push_null=False + ) + else: + var_loader.load(stack_arg) + + # gen call resume fn opcode + # NOTE(SigureMo): In Python 3.11,we need generate KW_NAMES if the call shape is not None. 
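What `_break_graph_in_jump` assembles is, at the source level, roughly the following shape (a hedged toy with invented names; the real transformation works on bytecode and keeps the original jump opcode):

```python
def compiled_subgraph(x):
    return x > 0          # stands in for the graph traced before the jump

def if_resume(x):         # resume fn generated at indexof(instr) + 1
    return x + 1

def else_resume(x):       # resume fn generated at indexof(instr.jump_to)
    return x - 1

def transformed(x):
    cond = compiled_subgraph(x)
    if cond:              # the re-inserted jump instruction
        return if_resume(x)
    return else_resume(x)

assert transformed(3) == 4 and transformed(-3) == -4
```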
+            self._graph.pycode_gen.gen_kw_names(self._call_shape)
+        self._graph.pycode_gen.add_pure_instructions([instr])
+        self.stack.pop_n(pop_n)
+        stack_size = len(self.stack) + push_n
+
+        resume_fn, _ = self._create_resume_fn(index + 1, stack_size)
+        if resume_fn:
+            self._graph.pycode_gen.gen_load_object(
+                resume_fn, resume_fn.__code__.co_name
+            )
+            # NOTE(zrr1999): We need to shift the resume_fn under its arguments.
+            # In Python 3.11+, NULL + resume_fn should be shifted together.
+            shift_n = 2 if sys.version_info >= (3, 11) else 1
+            self._graph.pycode_gen.gen_shift_n(shift_n, stack_size + shift_n)
+            for name in resume_input_name:
+                var_loader.load(self.get_var(name))
+            self._graph.pycode_gen.gen_call_function(
+                argc=resume_fn.__code__.co_argcount,
+            )
+
+        # gen RETURN_VALUE
+        self._graph.pycode_gen.gen_return()
+
+        self.new_code = self._graph.pycode_gen.gen_pycode()
+        self.guard_fn = self._graph.guard_fn
+
+    def transform(self):
+        self.run()
+        if self.new_code is None:
+            raise InnerError("OpcodeExecutor returned an empty new_code.")
+        # stopped by RETURN_VALUE and the SIR is large enough => disable_eval_frame
+        simulate_complete = bool(self.stop_state == "Return")
+        if simulate_complete:
+            if self._graph.sir_ctx.TOS.graph_size() < min_graph_size():
+                raise FallbackError(
+                    "Fallback after simulation: graph size is smaller than min_graph_size.",
+                    disable_eval_frame=True,
+                )
+            else:
+                # If the simulation stopped with a graph successfully, all the
+                # code will be surrounded by the eval_frame triggers which exist
+                # in self.new_code, so we need not set disable_eval_frame=False
+                # here (it already is).
+                return (
+                    CustomCode(self.new_code, True),
+                    self.guard_fn,
+                )
+        else:
+            # If we returned because of a break graph, the eval frame needs to
+            # stay enabled.
+            return (
+                CustomCode(self.new_code, False),
+                self.guard_fn,
+            )
+
+    def _gen_loop_body_between(
+        self, inputs: list, for_iter_idx: int, start: int, end: int
+    ) -> types.FunctionType:
+        """
+        Generates the loop body between the specified indices in the instruction list.
+
+        Args:
+            inputs: The input variable names of the loop body.
+            for_iter_idx (int): The index used to find the FOR_ITER opcode.
+            start (int): The start index of the loop body.
+            end (int): The end index of the loop body.
+
+        Returns:
+            types.FunctionType: The generated loop body function.
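`transform()` hands back a `(CustomCode, guard_fn)` pair; conceptually the caching layer consumes it like the sketch below (invented helper names, not the actual executor_cache implementation):

```python
def lookup(cache_entries, frame):
    # cache_entries: [(custom_code, guard_fn), ...] accumulated per code object
    for custom_code, guard_fn in cache_entries:
        if guard_fn(frame):     # guard revalidates the assumptions traced in
            return custom_code  # hit: reuse the generated bytecode
    return None                 # miss: run the translator again
```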
+ + """ + pycode_gen = PyCodeGen(self._frame) + origin_instrs = get_instructions(pycode_gen._origin_code) + + for_iter = origin_instrs[for_iter_idx] + + # for balance the stack (the loop body will pop iter first before break or return) + # this None is used for replace the iterator obj in stack top + pycode_gen.gen_load_const(None) + + # extend loop body main logic + pycode_gen.extend_instrs(origin_instrs[start:end]) + + # break should jump to this nop + nop_for_break = pycode_gen._add_instr("NOP") + + # need do additional operates when break + pycode_gen.gen_load_const(False) + pycode_gen.gen_store_fast(inputs[-1]) + pycode_gen.gen_load_const(None) # keep stack balance + + # continue should jump to this nop + nop_for_continue = pycode_gen._add_instr("NOP") + pycode_gen.gen_pop_top() + + # relocate jump + out_loop = for_iter.jump_to + for instr in pycode_gen._instructions: + if instr.jump_to == for_iter: + instr.jump_to = nop_for_continue + if instr.jump_to == out_loop: + instr.jump_to = nop_for_break + + # outputs is the same as inputs + pycode_gen.gen_outputs_and_return(inputs) + return pycode_gen.create_fn_with_inputs(inputs) + + @fallback_when_occur_error + def _break_graph_in_for_loop( + self, iterator: VariableBase, for_iter: Instruction + ): + ''' + for_iter: the FOR_ITER opcode + + need find out opcodes which unpack value from FOR_ITER, by analysing stack + + case 1: + for i in iter: + + FOR_ITER + STORE_FAST i + + case 2: + for i,j in iter: + + FOR_ITER + UNPACK_SEQUENCE 2 + STORE_FAST i + STORE_FAST j + + TODO: check var is in globals or builtins, only locals considered now + ''' + # 0. prepare sub functions + # 0.1 find the range of loop body + assert for_iter.jump_to is not None + loop_body_start_idx = self.indexof(for_iter) + 1 + loop_body_end_idx = self.indexof(for_iter.jump_to) + curent_stack = 1 + + while True: + if loop_body_start_idx >= len(self._instructions): + raise InnerError("Can not balance stack in loop body.") + cur_instr = self._instructions[loop_body_start_idx] + # do not consider jump instr + stack_effect = calc_stack_effect(cur_instr, jump=False) + curent_stack += stack_effect + loop_body_start_idx += 1 + if curent_stack == 0: + break + + # 0.2 create loop body function + all_used_vars = analysis_used_names_with_space( + self._instructions, loop_body_start_idx, loop_body_end_idx + ) + loop_body_inputs = [ + k + for k, v in all_used_vars.items() + if v in (Space.locals, Space.cells) + ] + ["_break_flag"] + + loop_body_fn = self._gen_loop_body_between( + loop_body_inputs, + self.indexof(for_iter), + loop_body_start_idx, + loop_body_end_idx, + ) + + log(3, "[Resumed Function]: break graph in loop create loop body as\n") + log_do(3, lambda: dis.dis(loop_body_fn)) + + # 0.3 create after loop part function + after_loop_fn, fn_inputs = self._create_resume_fn( + loop_body_end_idx, len(self.stack) + ) + + total_inputs = OrderedSet(list(fn_inputs) + list(loop_body_inputs[:-1])) + + # 1. part before for-loop, start compile + ret_names = [ + name + for name in total_inputs + if name in chain(self._locals, self._cells) + ] + ret_vars = [self.get_var(name) for name in ret_names] + store_vars = [ret_vars[idx] for idx in range(len(ret_names))] + store_vars.extend(iter(self.stack)) + store_vars.append(iterator.get_hold()) + var_loader = self._graph.start_compile_with_name_store( + ret_vars, store_vars + ) + + for _ in ret_vars: + self._graph.pycode_gen.gen_pop_top() + + # 2. 
restore vars + for idx in range(len(ret_names)): + var_loader.load(ret_vars[idx]) + self._graph.pycode_gen.gen_store(ret_names[idx], self._code) + + # 3. setup vars which is created in loop + undefined_names = set() + for name in loop_body_inputs[:-1]: + if not self.has_var(name, all_used_vars[name]): + undefined_names.add(name) + self._graph.pycode_gen.gen_load_const(SotUndefinedVar()) + self._graph.pycode_gen.gen_store(name, self._code) + + # close eval_frame + # TODO: need support effective strategies + # self._graph.pycode_gen.gen_disable_eval_frame() + + # 4.1 load iterator + iterator.reconstruct(self._graph.pycode_gen) + + # 4.2 gen FOR_ITER and unpack data + self._graph.pycode_gen.extend_instrs( + self._instructions[self.indexof(for_iter) : loop_body_start_idx] + ) + + # 5. call loop body + # 5.1 load loop body + self._graph.pycode_gen.gen_load_object( + loop_body_fn, loop_body_fn.__code__.co_name + ) + + # 5.2 load loop body inputs + for name in loop_body_inputs[:-1]: + self._graph.pycode_gen.gen_load(name) + + # 5.3 load break flag + self._graph.pycode_gen.gen_load_const(True) + + # 5.4 call loop body + self._graph.pycode_gen.gen_call_function( + argc=loop_body_fn.__code__.co_argcount + ) + + # 5.5 unpack and store retval, keep break_flag in stack + self._graph.pycode_gen.gen_unpack_sequence(len(loop_body_inputs)) + + for name in loop_body_inputs[:-1]: + self._graph.pycode_gen.gen_store(name, self._code) + + # 6. add jump if break + jump_if_break = self._graph.pycode_gen.gen_pop_jump( + direction=JumpDirection.FORWARD, suffix=PopJumpCond.FALSE + ) + + # 7. jump back to FOR_ITER + self._graph.pycode_gen.gen_jump( + for_iter, direction=JumpDirection.BACKWARD + ) + nop = self._graph.pycode_gen._add_instr("NOP") + for_iter.jump_to = nop + jump_if_break.jump_to = nop + + # open eval_frame + # TODO: need support effective strategies + # self._graph.pycode_gen.gen_enable_eval_frame() + + # 8. call after_loop_fn + self._graph.pycode_gen.gen_load_object( + after_loop_fn, after_loop_fn.__code__.co_name + ) + + for stack_arg in self.stack: + var_loader.load(stack_arg) + for name in fn_inputs: + if not self.has_var(name) and name not in undefined_names: + undefined_names.add(name) + self._graph.pycode_gen.gen_load_const(SotUndefinedVar()) + self._graph.pycode_gen.gen_store(name, self._code) + self._graph.pycode_gen.gen_load(name) + + self._graph.pycode_gen.gen_call_function( + argc=after_loop_fn.__code__.co_argcount + ) + + self._graph.pycode_gen.gen_return() + self.new_code = self._graph.pycode_gen.gen_pycode() + self.guard_fn = self._graph.guard_fn + + def _inline_call_for_loop( + self, iterator: VariableBase, for_iter: Instruction + ): + assert for_iter.jump_to is not None + pycode_gen = PyCodeGen(self._frame) + origin_instrs = get_instructions(pycode_gen._origin_code) + + start_idx = self.indexof(for_iter) + end_idx = self.indexof(for_iter.jump_to) + + all_used_vars = analysis_used_names_with_space( + origin_instrs, start_idx, end_idx + ) + + inputs = [ + k + for k, v in all_used_vars.items() + if v in (Space.locals, Space.cells) + ] + [iterator.id] + + # 1. load iter + pycode_gen.gen_load_fast(iterator.id) + + # 2. copy main logic + pycode_gen.extend_instrs(origin_instrs[start_idx:end_idx]) + + # 3. 
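Put together, steps 1-8 emit bytecode equivalent to this hand-written shape (toy values, invented names):

```python
def compiled_before_loop():            # step 1: graph traced before the loop
    return 0, iter([4, 5, 6, 7])       # restored vars + the reloaded iterator

def loop_body(total, i, _break_flag):  # from _gen_loop_body_between
    total += i
    if total > 10:
        return total, False            # the break path stores False
    return total, _break_flag

def after_loop_fn(total):              # resume fn from step 0.3
    return total * 2

def transformed():
    total, it = compiled_before_loop()
    for i in it:                                  # re-emitted FOR_ITER (4.2)
        total, keep = loop_body(total, i, True)   # steps 5.1-5.5
        if not keep:                              # step 6: jump if break
            break
    return after_loop_fn(total)                   # step 8

assert transformed() == 30
```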
add break, continue marker and relocate jump + for_iter_instr = origin_instrs[start_idx] + assert for_iter_instr.jump_to is not None + out_loop_instr = for_iter_instr.jump_to + + pycode_gen.gen_jump(out_loop_instr, direction=JumpDirection.FORWARD) + nop_for_continue = pycode_gen._add_instr("NOP") + + jump = pycode_gen.gen_jump( + for_iter_instr, direction=JumpDirection.BACKWARD + ) + + nop_for_break = pycode_gen._add_instr("NOP") + + for instr in pycode_gen._instructions: + if instr.jump_to == for_iter_instr: + instr.jump_to = nop_for_continue + + if ( + instr.jump_to in origin_instrs + and origin_instrs.index(instr.jump_to) >= end_idx + ): + instr.jump_to = nop_for_break + + jump.jump_to = for_iter_instr + pycode_gen.gen_outputs_and_return(inputs) + inline_call_fn = pycode_gen.create_fn_with_inputs(inputs) + + log( + 3, + f"[Resumed Function]: Inline call for loop function {inline_call_fn.__code__.co_name}\n", + ) + log_do(3, lambda: dis.dis(inline_call_fn)) + + # TODO: update globals builtins + fn = UserDefinedFunctionVariable( + inline_call_fn, + self._graph, + DanglingTracker(), + ) + + input_vars = [ + self.get_var(name) + if self.has_var(name, all_used_vars[name]) + else SotUndefinedVar() + for name in inputs[:-1] + ] + [iterator] + ret = fn(*input_vars) + # slice_variable is [:-1] + slice_const = slice(None, -1, None) + slice_variable = SliceVariable( + slice_const, self._graph, ConstTracker(slice_const) + ) + for name, val in zip(inputs[:-1], ret[slice_variable]): + self._locals[name] = val + + def FOR_ITER(self, instr): + iterator = self.stack.pop() + backup_iter_idx = None + + start = self.indexof(instr) + end = self.indexof(instr.jump_to) + for i in range(start, end): + if self._instructions[i].opname == "RETURN_VALUE": + raise FallbackError("Found RETURN_VALUE in for loop body.") + + self._graph.add_global_guarded_variable(iterator) + + try: + if not isinstance(iterator, SequenceIterVariable): + raise BreakGraphError() + + backup_iter_idx = iterator.idx + + self._inline_call_for_loop(iterator, instr) + self._lasti = self.indexof(instr.jump_to) + except BreakGraphError as e: + log(3, f"{e}") + if backup_iter_idx: + iterator.idx = backup_iter_idx + self._graph.remove_global_guarded_variable(iterator) + self._break_graph_in_for_loop(iterator, instr) + return Stop(state="BreakGraph") + + def RETURN_VALUE(self, instr: Instruction): + assert ( + len(self.stack) == 1 + ), f"Stack must have one element, but get {len(self.stack)} elements." + ret_val = self.stack.pop() + self._graph.start_compile(ret_val) + self._graph.pycode_gen.gen_return() + self.new_code = self._graph.pycode_gen.gen_pycode() + self.guard_fn = self._graph.guard_fn + return Stop(state="Return") diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py new file mode 100644 index 0000000000000..c24e94b07ffb2 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py @@ -0,0 +1,330 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import contextlib +import inspect +import re +from typing import TYPE_CHECKING + +from ...profiler import event_register +from ...utils import BreakGraphError, log +from ..instruction_utils import Instruction +from .guard import StringifyExpression, union_free_vars +from .opcode_executor import OpcodeExecutorBase, Stop +from .tracker import ConstTracker, DanglingTracker, DummyTracker, Tracker +from .variables import ( + CellVariable, + FunctionGlobalVariable, + IterVariable, + SequenceIterVariable, + VariableBase, +) + +if TYPE_CHECKING: + from .pycode_generator import PyCodeGen + from .variables import FunctionVariable + + +class FunctionGlobalTracker(Tracker): + """ + A tracker class that represents a function global variable. + + Args: + fn: FunctionVariable object. + name: The name of the global variable. + + """ + + def __init__(self, fn: FunctionVariable, name: str): + super().__init__([fn]) + self.fn = fn + self.name = name + + def gen_instructions(self, codegen: PyCodeGen): + """ + Generate bytecode instructions in order to put the variables at the top of the stack. + + Args: + codegen: The PyCodeGen object used to generate bytecode. + + """ + self.fn.tracker.gen_instructions(codegen) + codegen.gen_load_attr("__globals__") + codegen.gen_load_const(self.name) + codegen.gen_subscribe() + + def trace_value_from_frame(self) -> StringifyExpression: + """ + Trace the value of the function global variable from the frame. + + Returns: + StringifyExpression: The traced value of the function global variable. + + """ + fn_tracer = self.fn.tracker.trace_value_from_frame() + return StringifyExpression( + f"{{}}.__globals__['{self.name}']", + [fn_tracer], + union_free_vars(fn_tracer.free_vars), + ) + + def __repr__(self) -> str: + return f"FunctionGlobalTracker(fn={self.fn}, name={self.name})" + + +class FunctionClosureTracker(Tracker): + """ + A tracker class that represents a function closure variable. + + Args: + fn: The FunctionVariable object. + idx: The index of the closure variable. + + """ + + def __init__(self, fn: FunctionVariable, idx: int): + super().__init__([fn]) + self.fn = fn + self.idx = idx + + def gen_instructions(self, codegen: PyCodeGen): + """ + Generate bytecode instructions to trace the value of the function closure variable. + + Args: + codegen: The PyCodeGen object used to generate bytecode. + + """ + self.fn.tracker.gen_instructions(codegen) + codegen.gen_load_attr("__closure__") + codegen.gen_load_const(self.idx) + codegen.gen_subscribe() + codegen.gen_load_attr("cell_contents") + + def trace_value_from_frame(self): + """ + Trace the value of the function closure variable from the frame. + + Returns: + The traced value of the function closure variable. 
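Both trackers reconstruct attribute paths that exist on every plain Python function, so the guard expressions they stringify can be checked interactively:

```python
y = 1

def outer():
    x = 42
    def inner():
        return x + y
    return inner

fn = outer()
# What FunctionGlobalTracker traces: fn.__globals__['y']
print(fn.__globals__["y"])              # 1
# What FunctionClosureTracker traces: fn.__closure__[idx].cell_contents
print(fn.__closure__[0].cell_contents)  # 42
```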
+ + """ + fn_tracer = self.fn.tracker.trace_value_from_frame() + return StringifyExpression( + f"{{}}.__closure__[{self.idx}].cell_contents", + [fn_tracer], + union_free_vars(fn_tracer.free_vars), + ) + + def __repr__(self) -> str: + return f"FunctionClosureTracker(fn={self.fn}, idx={self.idx})" + + +@contextlib.contextmanager +def signature_clear_guard(fn, name): + if not hasattr(fn, name): + yield + else: + saved_attr = getattr(fn, name) + delattr(fn, name) + yield + setattr(fn, name, saved_attr) + + +class OpcodeInlineExecutor(OpcodeExecutorBase): + """ + A class that represents an executor for inlined opcode operations. + + Args: + fn_variable: The function variable. + + """ + + def __init__( + self, + fn_variable: FunctionVariable, + *args, + **kwargs, + ): + self._fn_var = fn_variable + self.return_value: VariableBase | None = None + self._fn_value = fn_variable.value + super().__init__(fn_variable.get_code(), fn_variable.graph) + self._name = "Inline" + self._prepare_locals(*args, **kwargs) + self._prepare_closure() + + def _handle_comps(self): + is_comp = any( + x in self._fn_value.__name__ + for x in ['', '', ''] + ) + if not is_comp: + return + pattern = r'implicit\d+' + for name in list(self._locals.keys()): + if re.match(pattern, name): + self._locals[name.replace('implicit', '.')] = self._locals[name] + + def _prepare_locals(self, *args, **kwargs): + """ + Prepare local variables for execution by adding them to the locals dictionary. + + """ + from .variables import VariableBase, VariableFactory + + # temparay clear the fn.__signature__ to avoid signature check error + with signature_clear_guard( + self._fn_value, "__signature__" + ), signature_clear_guard(self._fn_value, "__wrapped__"): + sig = inspect.signature(self._fn_value) + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + for name, value in bound_args.arguments.items(): + assert name in sig.parameters + # Convert varargs and kwargs to Variable + if sig.parameters[name].kind == inspect.Parameter.VAR_POSITIONAL: + tracker = DummyTracker(value) + elif sig.parameters[name].kind == inspect.Parameter.VAR_KEYWORD: + tracker = DummyTracker(list(value.values())) + # Convert default args to Variable + elif not isinstance(value, VariableBase): + tracker = ConstTracker(value) + else: + tracker = value.tracker + value = VariableFactory.from_value(value, self._graph, tracker) + self._locals[name] = value + + self._handle_comps() + + log( + 5, f"[INLINE CALL] {self._code.co_name} with locals: ", self._locals + ) + + def _prepare_closure(self): + """ + Prepare closure variables for execution by adding them to the closure list. + + """ + from .variables import VariableFactory + + closure = self._fn_var.get_py_value().__closure__ + for name in self._code.co_cellvars + self._code.co_freevars: + # create a cell for each variable. + self._cells[name] = CellVariable() # put in cells. 
+ if name in self._locals: + self._cells[name].set_value(self._locals[name]) + + if closure is None: + return + assert len(closure) == len(self._code.co_freevars) + for idx, (name, cell) in enumerate( + zip(self._code.co_freevars, closure) + ): + value = cell.cell_contents + value = VariableFactory.from_value( + value, self._graph, FunctionClosureTracker(self._fn_var, idx) + ) + # wrapped by a CellVariable + if not isinstance(value, CellVariable): + value = CellVariable(value) + self._cells[name] = value + + @event_register("OpcodeInlineExecutor: _prepare_virtual_env", event_level=2) + def _prepare_virtual_env(self): + """ + Prepare the virtual environment for execution by adding variables from globals, builtins, and constants. + + """ + from .variables import VariableFactory + + self._globals = FunctionGlobalVariable( + self._fn_var, + self._fn_value.__globals__, + self._graph, + DanglingTracker(), + ) + + self._builtins = self._graph._builtins + + # prepare consts + for value in self._code.co_consts: + self._co_consts.append( + VariableFactory.from_value( + value, self._graph, ConstTracker(value) + ) + ) + + def inline_call(self) -> VariableBase: + """ + Execute the inline call of the function. + """ + self.run() + assert self.return_value is not None + return self.return_value + + def RETURN_VALUE(self, instr: Instruction): + assert ( + len(self.stack) == 1 + ), f"Stack must have one element, but get {len(self.stack)} elements." + self.return_value = self.stack.pop() + return Stop(state="Return") + + def _break_graph_in_jump(self, result, instr: Instruction): + """ + Helper method to raise a BreakGraphError when breaking the graph in a jump operation. + + Args: + result: The result of the operation. + instr (Instruction): The jump instruction. + """ + raise BreakGraphError( + "OpcodeInlineExecutor want call _break_graph_in_jump." + ) + + def _create_resume_fn(self, index: int, stack_size: int = 0): + """ + Helper method to create a resume function for the executor. + + Args: + index (int): The index of the instruction to resume execution from. + stack_size (int, optional): The size of the stack. Defaults to 0. + """ + raise BreakGraphError("_create_resume_fn.") + + def FOR_ITER(self, instr: Instruction): + iterator = self.stack.top + assert isinstance(iterator, IterVariable) + + self._graph.add_global_guarded_variable(iterator) + + # simplely get next + if isinstance( + iterator, + SequenceIterVariable, + ): + try: + self.stack.push(iterator.next()) + except StopIteration: + self.stack.pop() + assert isinstance(instr.jump_to, Instruction) + self._lasti = self.indexof(instr.jump_to) + + else: + self._graph.remove_global_guarded_variable(iterator) + raise BreakGraphError( + f"Found {iterator.__class__.__name__} as iterator." + ) diff --git a/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py b/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py new file mode 100644 index 0000000000000..d8ddb23d15fc1 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py @@ -0,0 +1,1058 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
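Both FOR_ITER implementations lean on iterators that carry an explicit position, so a failed inline attempt can rewind. A toy stand-in for SequenceIterVariable (invented class, same idea):

```python
class SequenceIter:
    def __init__(self, seq):
        self.seq, self.idx = seq, 0

    def next(self):
        if self.idx >= len(self.seq):
            raise StopIteration  # FOR_ITER jumps past the loop on this
        val = self.seq[self.idx]
        self.idx += 1
        return val

it = SequenceIter([1, 2])
backup = it.idx          # what FOR_ITER saves in backup_iter_idx
assert it.next() == 1
it.idx = backup          # restore before falling back to the break-graph path
assert it.next() == 1
```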
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This class is used for abstract code generation: +# We only need to care about what type of bytecode our code needs to generate, +# without worrying about the subscripts of bytecode instructions in the code option. + +from __future__ import annotations + +import random +import sys +import types +from typing import TYPE_CHECKING + +import opcode + +import paddle + +from ...utils import ( + FallbackError, + InnerError, + OrderedSet, + ResumeFnNameFactory, + is_clean_code, + list_contain_by_id, + list_find_index_by_id, + no_eval_frame, +) +from ..instruction_utils import ( + analysis_inputs, + calc_stack_effect, + gen_instr, + get_instructions, + instrs_info, + modify_instrs, + modify_vars, +) +from ..instruction_utils.opcode_info import ( + PYOPCODE_CACHE_SIZE, + UNCONDITIONAL_JUMP, + JumpDirection, + PopJumpCond, +) +from .instr_flag import CALL_FUNCTION_EX_FLAG + +CODE_NAME_RNG = random.Random(2023) + +if TYPE_CHECKING: + from typing import Any + + from ..instruction_utils import Instruction + + +def get_pycode_attributes() -> list[str]: + """ + Returns a list of attribute names for PyCodeObject. + NOTE(SigureMo): The order should consistent with signature specified in code_doc + 3.8: https://github.com/python/cpython/blob/3.8/Objects/codeobject.c#L416-L421 + 3.10: https://github.com/python/cpython/blob/3.10/Objects/codeobject.c#L523-L543 + 3.11: https://github.com/python/cpython/blob/3.11/Objects/codeobject.c#L1494-L1516 + + Returns: + list[str]: The attribute names for PyCodeObject. + """ + pycode_attributes = [ + "co_argcount", + "co_posonlyargcount", + "co_kwonlyargcount", + "co_nlocals", + "co_stacksize", + "co_flags", + "co_code", + "co_consts", + "co_names", + "co_varnames", + "co_filename", + "co_name", + ] + if sys.version_info >= (3, 11): + pycode_attributes.append("co_qualname") + pycode_attributes.append("co_firstlineno") + if sys.version_info >= (3, 10): + pycode_attributes.append("co_linetable") + else: + pycode_attributes.append("co_lnotab") + if sys.version_info >= (3, 11): + pycode_attributes.append("co_exceptiontable") + pycode_attributes += [ + "co_freevars", + "co_cellvars", + ] + return pycode_attributes + + +PYCODE_ATTRIBUTES = get_pycode_attributes() + + +def gen_code_options(code: types.CodeType) -> dict[str, Any]: + """ + Generates a dictionary of code options for the given code object. + + Args: + code (types.CodeType): The code object. + + Returns: + dict[str, any]: The code options. + """ + code_options = {} + for k in PYCODE_ATTRIBUTES: + val = getattr(code, k) + if isinstance(val, tuple): + val = list(val) + code_options[k] = val + + return code_options + + +def gen_new_opcode( + instrs: list[Instruction], code_options: dict[str, Any], keys: list[str] +) -> types.CodeType: + """ + Generates a new code object with the given instructions, code options, and keys. + + Args: + instrs (list[Instruction]): The instructions for the new code object. + code_options (dict[str, any]): The code options for the new code object. + keys (list[str]): The keys to specify the order of code options. + + Returns: + types.CodeType: The new code object. 
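The version-dependent attribute list exists because the positional constructor of `types.CodeType` changed in 3.8/3.10/3.11. For a single known interpreter, the same round-trip can be spelled with `CodeType.replace` (3.8+), which is a handy way to sanity-check generated code options:

```python
import types

def f():
    return 1

new_code = f.__code__.replace(co_name="f_rewritten")
g = types.FunctionType(new_code, globals(), new_code.co_name)
assert g() == 1 and g.__code__.co_name == "f_rewritten"
```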
+ """ + bytecode, linetable = assemble(instrs, code_options["co_firstlineno"]) + if sys.version_info >= (3, 10): + # Python deprecated co_lnotab in 3.10, use co_linetable instead + # https://peps.python.org/pep-0626/ + code_options["co_linetable"] = linetable + else: + code_options["co_lnotab"] = linetable + code_options["co_code"] = bytecode + code_options["co_nlocals"] = len(code_options["co_varnames"]) + code_options["co_stacksize"] = stacksize(instrs) + if sys.version_info >= (3, 11): + # TODO: generate 3.11 exception table + code_options["co_exceptiontable"] = bytes([]) + for key, val in code_options.items(): + if isinstance(val, list): + code_options[key] = tuple(val) + # code_options is a dict, use keys to makesure the input order + return types.CodeType(*[code_options[k] for k in keys]) + + +def assemble( + instructions: list[Instruction], firstlineno: int +) -> tuple[bytes, bytes]: + """ + Assembles a list of instructions into bytecode and lnotab. + + Args: + instructions (list[Instruction]): The list of instructions to assemble. + firstlineno (int): The starting line number. + + Returns: + tuple[bytes, bytes]: The assembled bytecode and lnotab. + """ + code = [] + linetable = [] + + calc_linetable, update_cursor = create_linetable_calculator(firstlineno) + + for instr in instructions: + # set linetable, Python 3.11 need to set linetable for each instruction + if instr.starts_line is not None or sys.version_info >= (3, 11): + linetable.extend(calc_linetable(instr.starts_line, len(code))) + update_cursor(instr.starts_line, len(code)) + + # get bytecode + arg = instr.arg or 0 + code.extend((instr.opcode, arg & 0xFF)) + # fill CACHE + for _ in range(get_instruction_size(instr) // 2 - 1): + code.extend((0, 0)) + + if sys.version_info >= (3, 11): + # End hook for Python 3.11 + linetable.extend(calc_linetable(None, len(code))) + elif sys.version_info >= (3, 10): + # End hook for Python 3.10 + linetable.extend(calc_linetable(0, len(code))) + + return bytes(code), bytes(linetable) + + +def to_byte(num): + """ + Converts a negative number to an unsigned byte. + + Args: + num (int): The number to convert. + + Returns: + int: The converted unsigned byte. + """ + if num < 0: + num += 256 + return num + + +def get_instruction_size(instr: Instruction) -> int: + cache_size = 0 + if sys.version_info >= (3, 11): + cache_size = PYOPCODE_CACHE_SIZE.get(instr.opname, 0) + return 2 * (cache_size + 1) + + +def create_linetable_calculator(firstlineno: int): + """ + Creates a line table calculator function. + + Args: + firstlineno (int): The starting line number. + + Returns: + Callable: The line table calculator function. + """ + cur_lineno = firstlineno + cur_bytecode = 0 + line_offset = 0 # For Python 3.10 + + def update_cursor(starts_line: int | None, code_length: int): + nonlocal cur_lineno, cur_bytecode + cur_bytecode = code_length + if starts_line is not None: + cur_lineno = starts_line + + def calc_lnotab(starts_line: int, code_length: int): + """ + Calculates the lnotab for Python 3.8 and 3.9. + https://github.com/python/cpython/blob/3.9/Objects/lnotab_notes.txt + + Args: + starts_line (int): The line number where the instruction starts. + code_length (int): The length of the code. + + Returns: + list[int]: The lnotab. 
+ """ + nonlocal cur_lineno, cur_bytecode + line_offset = starts_line - cur_lineno + byte_offset = code_length - cur_bytecode + result = [] + + while line_offset or byte_offset: + line_offset_step = min(max(line_offset, -128), 127) + byte_offset_step = min(max(byte_offset, 0), 255) + result.extend((byte_offset_step, to_byte(line_offset_step))) + line_offset -= line_offset_step + byte_offset -= byte_offset_step + return result + + def calc_linetable_py310(starts_line: int, code_length: int): + """ + Calculates the linetable for Python 3.10. + https://github.com/python/cpython/blob/3.10/Objects/lnotab_notes.txt + + Args: + starts_line (int): The line number where the instruction starts. + code_length (int): The length of the code. + + Returns: + list[int]: The linetable. + """ + nonlocal cur_lineno, cur_bytecode, line_offset + byte_offset = code_length - cur_bytecode + result = [] + while line_offset or byte_offset: + line_offset_step = min(max(line_offset, -127), 127) + byte_offset_step = min(max(byte_offset, 0), 254) + result.extend((byte_offset_step, to_byte(line_offset_step))) + line_offset -= line_offset_step + byte_offset -= byte_offset_step + line_offset = starts_line - cur_lineno + return result + + def _encode_varint(num: int): + """ + Encode unsigned integer into variable-length format. + """ + continue_flag = 0b01 << 6 + stop_flag = 0b00 << 6 + while num >= 0x40: + yield (num & 0x3F) | continue_flag + num >>= 6 + yield num | stop_flag + + def _encode_svarint(num: int): + """ + Encode signed integer into variable-length format. + """ + unsigned_value = (((-num) << 1) | 1) if num < 0 else (num << 1) + yield from _encode_varint(unsigned_value) + + def _encode_bytecode_to_entries_py311(line_offset: int, byte_offset: int): + if not byte_offset: + return [] + if 0 < byte_offset <= 8: + entry_head = 0b1_1101_000 | (byte_offset - 1) + return [entry_head, *list(_encode_svarint(line_offset))] + return [ + *_encode_bytecode_to_entries_py311(line_offset, 8), + *_encode_bytecode_to_entries_py311(line_offset, byte_offset - 8), + ] + + def calc_linetable_py311(starts_line: int | None, code_length: int): + """ + Calculates the linetable for Python 3.11. + https://github.com/python/cpython/blob/3.11/Objects/locations.md + + Args: + starts_line (int): The line number where the instruction starts. + code_length (int): The length of the code. + + Returns: + list[int]: The linetable. + """ + nonlocal cur_lineno, cur_bytecode + line_offset = starts_line - cur_lineno if starts_line is not None else 0 + byte_offset = (code_length - cur_bytecode) // 2 + return _encode_bytecode_to_entries_py311(line_offset, byte_offset) + + if sys.version_info >= (3, 11): + return calc_linetable_py311, update_cursor + elif sys.version_info >= (3, 10): + return calc_linetable_py310, update_cursor + else: + return calc_lnotab, update_cursor + + +def compile_exception_table(): + """Compile the exception table, it is used for Python 3.11+. + See https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt + """ + # TODO + ... + + +def stacksize(instructions: list[Instruction]) -> float: + """ + Calculates the maximum stack size before each opcode is called. + + Args: + instructions (list[Instruction]): The list of instructions. + + Returns: + int: The maximum stack size. + """ + max_stack = [float("-inf")] * len(instructions) + + max_stack[0] = 0 + + queue = [] + queue.append(0) + + def update_stacksize(lasti: int, nexti: int, stack_effect: int): + """ + Updates the maximum stack size. 
+ + Args: + lasti (int): The index of the last instruction. + nexti (int): The index of the next instruction. + stack_effect (int): The effect on the stack size. + + Returns: + None + """ + old_max = max_stack[nexti] + max_stack[nexti] = max( + max_stack[nexti], max_stack[lasti] + stack_effect + ) + if old_max != max_stack[nexti]: + if nexti not in queue: # may be slow, we can use a flag. + queue.append(nexti) + + while len(queue) > 0: + idx = queue[0] + del queue[0] + instr = instructions[idx] + opname = instr.opname + if ( + idx + 1 < len(instructions) + and instr.opname not in UNCONDITIONAL_JUMP + ): + stack_effect = calc_stack_effect(instr, jump=False) + update_stacksize(idx, idx + 1, stack_effect) + + if instr.opcode in opcode.hasjabs or instr.opcode in opcode.hasjrel: + stack_effect = calc_stack_effect(instr, jump=True) + target_idx = instructions.index(instr.jump_to) + update_stacksize(idx, target_idx, stack_effect) + + # assert min(min_stack) >= 0 # min_stack may be a negative number when try: except is got. + return max(max_stack) + + +class PyCodeGen: + """Helper to create new code object""" + + def __init__( + self, frame: types.FrameType, disable_eval_frame: bool = False + ): + """ + Initializes a PyCodeGen object. + + Args: + frame: The frame to be translated. + disable_eval_frame (bool): Whether to disable the evaluation frame. Defaults to False. + """ + self._frame = frame + self._origin_code = frame.f_code + self._code_options = gen_code_options(self._origin_code) + self.update_code_name("", is_resumed_fn=False) + self._f_globals = frame.f_globals + self._instructions = [] + self.disable_eval_frame = disable_eval_frame + if self.disable_eval_frame: + self.gen_disable_eval_frame() + + def insert_prefix_instructions(self): + """ + Insert prefix instructions to the instruction list. + In Python 3.11+, we need to insert MAKE_CELL and COPY_FREE_VARS before the + first instruction. + The implementation is based on cpython implementation: + https://github.com/python/cpython/blob/f45ef5edabb1cc0748f3326e7114b8aaa0424392/Python/compile.c#L8177 + """ + prefixes = [] + if sys.version_info >= (3, 11): + if self._code_options["co_cellvars"]: + # Insert MAKE_CELL + name_map = list( + OrderedSet(self._code_options["co_varnames"]) + | OrderedSet(self._code_options["co_cellvars"]) + ) + + for i in self._code_options["co_cellvars"]: + idx: int = name_map.index(i) + prefixes.append(gen_instr("MAKE_CELL", arg=idx, argval=i)) + + if self._code_options["co_freevars"]: + n_freevars = len(self._code_options["co_freevars"]) + # Insert COPY_FREE_VARS + prefixes.append( + gen_instr( + "COPY_FREE_VARS", arg=n_freevars, argval=n_freevars + ) + ) + + # Insert RESUME + prefixes.append(gen_instr("RESUME", arg=0, argval=0)) + self._instructions[:] = prefixes + self._instructions + + def update_code_name(self, fn_name, is_resumed_fn): + if is_resumed_fn: + self._code_options[ + 'co_name' + ] = f"${fn_name}@{self._code_options['co_name'][1:]}" + else: + if self._code_options['co_name'].startswith("$"): + self._code_options[ + 'co_name' + ] = f"#{self._code_options['co_name']}" + elif not self._code_options['co_name'].startswith("#"): + random_number = int(CODE_NAME_RNG.random() * 100000000) + self._code_options[ + 'co_name' + ] = f"#{self._code_options['co_name']}_{hex(random_number & 0xFFFFF)[2:]:0>5}" + + def gen_pycode(self) -> types.CodeType: + """ + Generates a new pycode that is runnable. + + Returns: + CodeType: The generated code object. 
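The 3.11 prefix instructions that `insert_prefix_instructions` re-creates are visible in any ordinary disassembly of a closure-producing function:

```python
import dis
import sys

def outer():
    x = 1
    def inner():
        return x
    return inner

if sys.version_info >= (3, 11):
    dis.dis(outer)  # starts with MAKE_CELL (x) and RESUME 0, the same
                    # prefix re-inserted above for generated code
```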
+ """ + self.insert_prefix_instructions() + modify_instrs(self._instructions) + modify_vars(self._instructions, self._code_options) + new_code = gen_new_opcode( + self._instructions, self._code_options, PYCODE_ATTRIBUTES + ) + return new_code + + def gen_resume_fn_at( + self, index: int, stack_size: int = 0 + ) -> tuple[None | types.FunctionType, OrderedSet[str]]: + """ + Generates a resume function at the specified index in the instruction list. + + Args: + index (int): The index in the instruction list to generate the resume function. + stack_size (int): The size of the stack. Defaults to 0. + + Returns: + tuple: The resume function object and the inputs to the function. + + """ + self._instructions = get_instructions(self._origin_code) + # TODO(dev): could give an example code here? + if self._instructions[index].opname == 'RETURN_VALUE': + return None, OrderedSet() + inputs = analysis_inputs(self._instructions, index) + fn_name = ResumeFnNameFactory().next() + stack_arg_str = fn_name + '_stack_{}' + self._instructions = ( + [ + gen_instr('LOAD_FAST', argval=stack_arg_str.format(i)) + for i in range(stack_size) + ] + + [gen_instr('JUMP_FORWARD', jump_to=self._instructions[index])] + + self._instructions + ) + + self._code_options['co_argcount'] = len(inputs) + stack_size + # inputs should be at the front of the co_varnames + self._code_options['co_varnames'] = list( + [stack_arg_str.format(i) for i in range(stack_size)] + + list(inputs) + + [ + var_name + for var_name in self._origin_code.co_varnames + if var_name not in inputs + ] + ) + + self.update_code_name(fn_name, is_resumed_fn=True) + + new_code = self.gen_pycode() + if len(new_code.co_freevars) + len(new_code.co_cellvars) > 0: + raise FallbackError("Break graph in closure is not support.") + fn = types.FunctionType(new_code, self._f_globals, new_code.co_name) + + return fn, inputs + + def gen_disable_eval_frame(self): + """ + Generates instructions to disable the evaluation frame. + """ + if is_clean_code(): + return + self.gen_load_object( + paddle.framework.core.set_eval_frame, "paddle_set_eval_frame_fn" + ) + self.gen_load_const(None) + self.gen_call_function(1) + self.gen_store_fast("___old_eval_frame") + + def gen_enable_eval_frame(self): + """ + Generates instructions to enable the evaluation frame. + """ + if is_clean_code(): + return + self.gen_load_object( + paddle.framework.core.set_eval_frame, "paddle_set_eval_frame_fn" + ) + self.gen_load_fast("___old_eval_frame") + self.gen_call_function(1) + self.gen_pop_top() + + def gen_outputs_and_return(self, outputs): + for name in outputs: + self.gen_load(name) + self.gen_build_tuple(len(outputs)) + self.gen_return() + + def create_fn_with_inputs(self, inputs: list) -> types.FunctionType: + """ + Creates a function with specific input and output variables. + + Args: + inputs (list): The input variables. + + Returns: + function: The created function object. 
+ """ + self._code_options['co_argcount'] = len(inputs) + self._code_options['co_varnames'] = list( + list(inputs) + + [ + var_name + for var_name in self._origin_code.co_varnames + if var_name not in inputs + ] + ) + fn_name = ResumeFnNameFactory().next() + self.update_code_name(fn_name, is_resumed_fn=True) + new_code = self.gen_pycode() + if len(new_code.co_freevars) + len(new_code.co_cellvars) > 0: + raise FallbackError("Break graph in closure is not support.") + fn = types.FunctionType(new_code, self._f_globals, new_code.co_name) + return fn + + def gen_load_const(self, value: Any): + """ + Generates instructions to load a constant value. + """ + # Python `list.index` will find an item equal to query, i.e. `query == item` + # returns a value of True. Since `1 == True`, this will result in an incorrect + # index. To avoid this problem, we use id for comparison. + if not list_contain_by_id(self._code_options["co_consts"], value): + self._code_options["co_consts"].append(value) + idx = list_find_index_by_id(self._code_options["co_consts"], value) + self._add_instr("LOAD_CONST", arg=idx, argval=value) + + def gen_print_log(self, message): + """print a log""" + import paddle + + self.gen_load_object( + paddle.framework.core.set_eval_frame, "dbg_set_eval_frame" + ) + self.gen_load_const(None) + self.gen_call_function(1) + self.gen_store_fast("old_eval_frame") + self.gen_load_global("print", push_null=True) + self.gen_load_const(message) + self.gen_call_function(1) + self.gen_pop_top() + self.gen_load_object( + paddle.framework.core.set_eval_frame, "dbg_set_eval_frame" + ) + self.gen_load_fast("old_eval_frame") + self.gen_call_function(1) + self.gen_pop_top() + + def gen_dbg_function(self, dbg_fun): + """debug bytecode helper function. + Usage like: + def dbg_func(): + import inspect + import dis + print("dbg here.") + print(locals()) + frame = inspect.currentframe().f_back + code = (inspect.currentframe().f_back.f_code) + breakpoint() + print(inspect.currentframe().f_back.f_locals['y']) + + self.pycode_gen.gen_dbg_function(dbg_func) + """ + import paddle + + self.gen_load_object( + paddle.framework.core.set_eval_frame, "dbg_set_eval_frame" + ) + self.gen_load_const(None) + self.gen_call_function(1) + self.gen_store_fast("old_eval_frame") + self.gen_load_object(dbg_fun, "dbg1") + self.gen_call_function(0) + self.gen_pop_top() + self.gen_load_object( + paddle.framework.core.set_eval_frame, "dbg_set_eval_frame" + ) + self.gen_load_fast("old_eval_frame") + self.gen_call_function(1) + self.gen_pop_top() + + @property + def cell_free_storage(self): + return ( + self._code_options["co_cellvars"] + + self._code_options["co_freevars"] + ) + + def gen_load(self, name): + if name in self.cell_free_storage: + self.gen_load_deref(name) + elif name in self._code_options["co_varnames"]: + self.gen_load_fast(name) + elif name in self._code_options["co_names"]: + self.gen_load_global(name, push_null=False) + else: + raise InnerError( + f"Want gen_load, but {name} can not found in code object." + ) + + def gen_store(self, name, code): + """ + Generate the bytecode for storing a variable identified by 'name' + in the corresponding symbol table and generate the appropriate + store code based on the symbol table analysis. + + Args: + name (str): The name of the variable. 
+ """ + if name in (code.co_freevars + code.co_cellvars): + self.gen_store_deref(name) + elif name in code.co_varnames: + self.gen_store_fast(name) + elif name in code.co_names: + self.gen_store_global(name) + else: + raise InnerError( + f"Want gen_store, but {name} can not found in code object." + ) + + def gen_load_global(self, name, push_null=False): + """ + Generate the bytecode for loading a global variable. + + Args: + name (str): The name of the global variable. + """ + if name not in self._code_options["co_names"]: + self._code_options["co_names"].append(name) + idx = self._code_options["co_names"].index(name) + if sys.version_info >= (3, 11): + idx <<= 1 + if push_null: + idx |= 1 + self._add_instr("LOAD_GLOBAL", arg=idx, argval=name) + + def gen_load_object(self, obj, obj_name: str, push_null: bool = True): + """ + Generate the bytecode for loading an object. + + Args: + obj (Any): The object to load. + obj_name (str): The name of the object. + """ + + if obj_name not in self._f_globals: + self._f_globals[obj_name] = obj + self.gen_load_global(obj_name, push_null=push_null) + + def gen_load_fast(self, name): + """ + Generate the bytecode for loading a local variable. + + Args: + name (str): The name of the local variable. + """ + if name not in self._code_options["co_varnames"]: + self._code_options["co_varnames"].append(name) + idx = self._code_options["co_varnames"].index(name) + self._add_instr("LOAD_FAST", arg=idx, argval=name) + + def gen_load_deref(self, name): + if name not in self.cell_free_storage: + self._code_options["co_freevars"].append(name) + if sys.version_info >= (3, 11): + # Because the co_varnames maybe changed after other codegen + # operations, we need re-calculate the index in modify_vars + idx = ( + self._code_options["co_varnames"] + + self._code_options["co_freevars"] + ).index(name) + else: + idx = self.cell_free_storage.index(name) + self._add_instr("LOAD_DEREF", arg=idx, argval=name) + + def gen_load_attr(self, name: str): + if name not in self._code_options["co_names"]: + self._code_options["co_names"].append(name) + idx = self._code_options["co_names"].index(name) + self._add_instr("LOAD_ATTR", arg=idx, argval=name) + + def gen_store_attr(self, name: str): + if name not in self._code_options["co_names"]: + self._code_options["co_names"].append(name) + idx = self._code_options["co_names"].index(name) + self._add_instr("STORE_ATTR", arg=idx, argval=name) + + def gen_delete_attr(self, name: str): + if name not in self._code_options["co_names"]: + self._code_options["co_names"].append(name) + idx = self._code_options["co_names"].index(name) + self._add_instr("DELETE_ATTR", arg=idx, argval=name) + + def gen_load_method(self, name: str): + if name not in self._code_options["co_names"]: + self._code_options["co_names"].append(name) + idx = self._code_options["co_names"].index(name) + self._add_instr("LOAD_METHOD", arg=idx, argval=name) + + def gen_delete_global(self, name: str): + if name not in self._code_options["co_names"]: + self._code_options["co_names"].append(name) + idx = self._code_options["co_names"].index(name) + self._add_instr("DELETE_GLOBAL", arg=idx, argval=name) + + def gen_import_name(self, name: str): + if name not in self._code_options["co_names"]: + self._code_options["co_names"].append(name) + idx = self._code_options["co_names"].index(name) + self._add_instr("IMPORT_NAME", arg=idx, argval=name) + + def gen_push_null(self): + if sys.version_info >= (3, 11): + self._add_instr("PUSH_NULL") + else: + # There is no PUSH_NULL bytecode before 
python3.11, so we push + # a NULL element to the stack through the following bytecode + self.gen_load_const(0) + self.gen_load_const(None) + self.gen_import_name('sys') + self.gen_store_fast('sys') + self.gen_load_fast('sys') + self.gen_load_method('getsizeof') + self.gen_pop_top() + + def gen_store_fast(self, name): + if name not in self._code_options["co_varnames"]: + self._code_options["co_varnames"].append(name) + idx = self._code_options["co_varnames"].index(name) + self._add_instr("STORE_FAST", arg=idx, argval=name) + + def gen_store_global(self, name): + if name not in self._code_options["co_names"]: + self._code_options["co_names"].append(name) + idx = self._code_options["co_names"].index(name) + self._add_instr("STORE_GLOBAL", arg=idx, argval=name) + + def gen_store_deref(self, name): + if name not in self.cell_free_storage: + self._code_options["co_freevars"].append(name) + if sys.version_info >= (3, 11): + # Because the co_varnames maybe changed after other codegen + # operations, we need re-calculate the index in modify_vars + idx = ( + self._code_options["co_varnames"] + + self._code_options["co_freevars"] + ).index(name) + else: + idx = self.cell_free_storage.index(name) + self._add_instr("STORE_DEREF", arg=idx, argval=name) + + def gen_store_subscr(self): + self._add_instr("STORE_SUBSCR") + + def gen_subscribe(self): + self._add_instr("BINARY_SUBSCR") + + def gen_build_tuple(self, count): + self._add_instr("BUILD_TUPLE", arg=count, argval=count) + + def gen_build_list(self, count): + self._add_instr("BUILD_LIST", arg=count, argval=count) + + def gen_build_map(self, count): + self._add_instr("BUILD_MAP", arg=count, argval=count) + + def gen_build_slice(self, argc): + self._add_instr("BUILD_SLICE", arg=argc, argval=argc) + + def gen_unpack_sequence(self, count): + self._add_instr("UNPACK_SEQUENCE", arg=count, argval=count) + + def gen_call_function(self, argc=0): + if sys.version_info >= (3, 11): + self._add_instr("PRECALL", arg=argc, argval=argc) + self._add_instr("CALL", arg=argc, argval=argc) + else: + self._add_instr("CALL_FUNCTION", arg=argc, argval=argc) + + def gen_call_function_ex(self, has_kwargs): + flag = 0 + if has_kwargs: + flag |= CALL_FUNCTION_EX_FLAG.CFE_HAS_KWARGS + self._add_instr("CALL_FUNCTION_EX", arg=flag, argval=flag) + + def gen_call_method(self, argc=0): + if sys.version_info >= (3, 11): + self._add_instr("PRECALL", arg=argc, argval=argc) + self._add_instr("CALL", arg=argc, argval=argc) + else: + self._add_instr("CALL_METHOD", arg=argc, argval=argc) + + def gen_kw_names(self, kw_names: tuple[str, ...] 
| None): + if kw_names is None: + return + if sys.version_info < (3, 11): + raise InnerError("gen_kw_names is not supported before python3.11") + if kw_names not in self._code_options["co_consts"]: + self._code_options["co_consts"].append(kw_names) + idx = self._code_options["co_consts"].index(kw_names) + self._add_instr("KW_NAMES", arg=idx, argval=kw_names) + + def gen_pop_top(self): + self._add_instr("POP_TOP") + + def gen_rot_n(self, n): + if n <= 1: + return + if sys.version_info >= (3, 11): + for i in range(n, 1, -1): + self._add_instr("SWAP", arg=i) + elif sys.version_info >= (3, 10): + self._add_instr("ROT_N", arg=n) + else: + if n <= 4: + self._add_instr("ROT_" + ["TWO", "THREE", "FOUR"][n - 2]) + else: + + def rot_n_fn(n): + vars = [f"var{i}" for i in range(n)] + rotated = reversed(vars[-1:] + vars[:-1]) + fn = eval(f"lambda {','.join(vars)}: ({','.join(rotated)})") + fn = no_eval_frame(fn) + fn.__name__ = f"rot_{n}_fn" + return fn + + self.gen_build_tuple(n) + self.gen_load_const(rot_n_fn(n)) + self.gen_rot_n(2) + self._add_instr("CALL_FUNCTION_EX", arg=0) + self.gen_unpack_sequence(n) + + def gen_shift_n(self, s: int, n: int): + """ + Generate the bytecode for shifting the stack. + + Args: + s (int): Steps to shift. + n (int): The number of elements to shift. + """ + if s == 0 or n <= 1: + return + + # NOTE(zrr1999): right shift s steps is equal to left shift n-s steps + if abs(s) > n // 2: + new_s = s - n if s > 0 else s + n + self.gen_shift_n(new_s, n) + return + if s > 0: + # NOTE: s=1, n=3 [1,2,3,4,5] -> [1,2,5,3,4] + # s=2, n=3 [1,2,3,4,5] -> [1,2,4,5,3] + if s == 1: + self.gen_rot_n(n) + else: + self.gen_rot_n(n) + self.gen_shift_n(s - 1, n) + + else: # s < 0 + if sys.version_info >= (3, 11): + # NOTE: s=-1, n=3 [1,2,3,4,5] -> [1,2,4,5,3] + if s == -1: + for i in range(2, n + 1): + self._add_instr("SWAP", arg=i) + else: + self.gen_shift_n(-1, n) + self.gen_shift_n(s + 1, n) + else: + raise NotImplementedError( + "shift_n is not supported before python3.11" + ) + + def gen_swap(self, n): + if sys.version_info >= (3, 11): + self._add_instr("SWAP", arg=n) + else: + raise NotImplementedError("swap is not supported before python3.11") + + def gen_jump( + self, + jump_to: Instruction | None = None, + *, + direction: JumpDirection = JumpDirection.FORWARD, + ) -> Instruction: + if sys.version_info >= (3, 11): + return self._add_instr(f"JUMP_{direction.value}", jump_to=jump_to) + else: + return self._add_instr("JUMP_ABSOLUTE", jump_to=jump_to) + + def gen_pop_jump( + self, + jump_to: Instruction | None = None, + *, + direction: JumpDirection = JumpDirection.FORWARD, + suffix: PopJumpCond = PopJumpCond.NONE, + ) -> Instruction: + if sys.version_info >= (3, 11): + return self._add_instr( + f"POP_JUMP_{direction.value}_IF_{suffix.value}", jump_to=jump_to + ) + else: + return self._add_instr( + f"POP_JUMP_IF_{suffix.value}", jump_to=jump_to + ) + + def gen_return(self): + self._add_instr("RETURN_VALUE") + + def gen_get_iter(self): + self._add_instr("GET_ITER") + + def add_pure_instructions(self, instructions): + """ + add instructions and do nothing. 
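+        The instructions are appended as-is at this point; variable-index and
+        jump-target fixups only happen later, inside ``gen_pycode``.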
+ """ + self._instructions.extend(instructions) + + def _add_instr(self, *args, **kwargs): + instr = gen_instr(*args, **kwargs) + self._instructions.append(instr) + return instr + + def _insert_instr(self, index, *args, **kwargs): + instr = gen_instr(*args, **kwargs) + self._instructions.insert(index, instr) + + def pprint(self): + print('\n'.join(instrs_info(self._instructions))) + + def extend_instrs(self, instrs): + self._instructions.extend(instrs) + + def pop_instr(self): + self._instructions.pop() + + def replace_null_variable(self): + """ + Replace all NullVariables in the bytecode. + + Returns: + Optional[Tuple[Any, Callable]]: The new code object and its guard function, or None if no dummy variables are found. + """ + from .variables.basic import NullVariable + + instructions = get_instructions(self._origin_code) + has_null_variable = False + for instr in instructions: + if ( + instr.opname == 'LOAD_FAST' + and instr.argval in self._frame.f_locals.keys() + and isinstance(self._frame.f_locals[instr.argval], NullVariable) + ): + has_null_variable = True + self._frame.f_locals[instr.argval].reconstruct(self) + else: + self.add_pure_instructions([instr]) + + if has_null_variable: + new_code = self.gen_pycode() + return new_code + else: + return None diff --git a/python/paddle/jit/sot/opcode_translator/executor/side_effects.py b/python/paddle/jit/sot/opcode_translator/executor/side_effects.py new file mode 100644 index 0000000000000..f9f8fc20141a1 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/side_effects.py @@ -0,0 +1,234 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar + +from .mutable_data import MutableData +from .variables import VariableBase + +if TYPE_CHECKING: + from .mutable_data import DataGetter + from .pycode_generator import PyCodeGen + + MutableDataT = TypeVar("MutableDataT", bound=MutableData) + + +class SideEffectsState(NamedTuple): + data_id_to_proxy: dict[int, MutableData] + proxy_variables: list[VariableBase] + mutable_variables: list[VariableBase] + proxy_versions: list[int] + mutable_attrs: list[dict[str, Any]] + + +class SideEffects: + def __init__(self): + self.data_id_to_proxy: dict[int, MutableData] = {} + self.proxy_variables: list[VariableBase] = [] + self.mutable_variables: list[VariableBase] = [] + + def record_proxy_variable(self, variable: VariableBase): + if variable not in self.proxy_variables: + self.proxy_variables.append(variable) + + def record_mutable_variable(self, variable: VariableBase): + if variable not in self.mutable_variables: + self.mutable_variables.append(variable) + + def get_proxy( + self, + proxy_type: type[MutableDataT], + data: Any, + getter: DataGetter, + ) -> MutableDataT: + data_id = id(data) + if data_id not in self.data_id_to_proxy: + self.data_id_to_proxy[data_id] = proxy_type(data, getter) + return self.data_id_to_proxy[data_id] # type: ignore + + def get_state(self): + return SideEffectsState( + self.data_id_to_proxy.copy(), + self.proxy_variables.copy(), + self.mutable_variables.copy(), + [proxy.version for proxy in self.data_id_to_proxy.values()], + [ + {attr: getattr(var, attr)} + for var in self.mutable_variables + for attr in var.mutable_attrs + ], + ) + + def restore_state(self, state: SideEffectsState): + self.data_id_to_proxy = state.data_id_to_proxy + self.proxy_variables = state.proxy_variables + self.mutable_variables = state.mutable_variables + # NOTE(SigureMo): We can use the `strict=True` option in zip after + # Python 3.10. + assert len(self.data_id_to_proxy.values()) == len( + state.proxy_versions + ), "proxy_versions length not match" + assert len(self.mutable_variables) == len( + state.mutable_attrs + ), "mutable_attrs length not match" + + for proxy, version in zip( + self.data_id_to_proxy.values(), state.proxy_versions + ): + proxy.rollback(version) + + for (variable, attr), attr_dict in zip( + ( + (var, attr) + for var in self.mutable_variables + for attr in var.mutable_attrs + ), + (attr_dict for attr_dict in state.mutable_attrs), + ): + setattr(variable, attr, attr_dict[attr]) + + +class SideEffectRestorer: + def pre_gen(self, codegen: PyCodeGen): + raise NotImplementedError() + + def post_gen(self, codegen: PyCodeGen): + raise NotImplementedError() + + +class DictSideEffectRestorer(SideEffectRestorer): + """ + old_dict.clear() + old_dict.update(new_dict) + """ + + def __init__(self, var: VariableBase): + super().__init__() + self.var = var + + def pre_gen(self, codegen: PyCodeGen): + # Reference to the original dict. + # load old_dict.update and new_dict to stack. + self.var.reconstruct(codegen) + codegen.gen_load_method("update") + # Generate dict by each key-value pair. + self.var.reconstruct(codegen, use_tracker=False) + # load old_dict.clear to stack. + self.var.reconstruct(codegen) + codegen.gen_load_method("clear") + + def post_gen(self, codegen: PyCodeGen): + # Call methods to apply side effects. 
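+        # After pre_gen the stack layout is (bottom -> top):
+        #   old_dict.update, new_dict, old_dict.clear
+        # so clear() runs first, then update(new_dict) restores the new items.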
+ codegen.gen_call_method(0) # call clear + codegen.gen_pop_top() + codegen.gen_call_method(1) # call update + codegen.gen_pop_top() + + +class ListSideEffectRestorer(SideEffectRestorer): + """ + old_list[:] = new_list + """ + + def __init__(self, var: VariableBase): + super().__init__() + self.var = var + + def pre_gen(self, codegen: PyCodeGen): + # Reference to the original list. + # load new_list to stack. + self.var.reconstruct(codegen, use_tracker=False) + # load old_list[:] to stack. + self.var.reconstruct(codegen) + codegen.gen_load_const(None) + codegen.gen_load_const(None) + codegen.gen_build_slice(2) + + def post_gen(self, codegen: PyCodeGen): + # Call STROE_SUBSCR to apply side effects. + codegen.gen_store_subscr() + + +class GlobalSetSideEffectRestorer(SideEffectRestorer): + """ + global_var = new_value + """ + + def __init__(self, name: str, var: VariableBase): + super().__init__() + self.name = name + self.var = var + + def pre_gen(self, codegen: PyCodeGen): + self.var.reconstruct(codegen) + + def post_gen(self, codegen: PyCodeGen): + codegen.gen_store_global(self.name) + + +class GlobalDelSideEffectRestorer(SideEffectRestorer): + """ + del global_var + """ + + def __init__(self, name: str): + super().__init__() + self.name = name + + def pre_gen(self, codegen: PyCodeGen): + # do nothing + ... + + def post_gen(self, codegen: PyCodeGen): + codegen.gen_delete_global(self.name) + + +class ObjSetSideEffectRestorer(SideEffectRestorer): + """ + obj.attr = new_value + """ + + def __init__(self, obj: VariableBase, name: str, var: VariableBase): + super().__init__() + self.obj = obj + self.name = name + self.var = var + + def pre_gen(self, codegen: PyCodeGen): + # value + self.var.reconstruct(codegen) + # obj + self.obj.reconstruct(codegen) + + def post_gen(self, codegen: PyCodeGen): + codegen.gen_store_attr(self.name) + + +class ObjDelSideEffectRestorer(SideEffectRestorer): + """ + del obj.attr + """ + + def __init__(self, obj: VariableBase, name: str): + super().__init__() + self.obj = obj + self.name = name + + def pre_gen(self, codegen: PyCodeGen): + self.obj.reconstruct(codegen) + + def post_gen(self, codegen: PyCodeGen): + codegen.gen_delete_attr(self.name) diff --git a/python/paddle/jit/sot/opcode_translator/executor/tracker.py b/python/paddle/jit/sot/opcode_translator/executor/tracker.py new file mode 100644 index 0000000000000..c085e14b5b382 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/tracker.py @@ -0,0 +1,387 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from __future__ import annotations
+
+import builtins
+import sys
+from typing import TYPE_CHECKING
+
+from ...utils import InnerError, NameGenerator
+from .guard import StringifyExpression, union_free_vars
+
+if TYPE_CHECKING:
+    from typing import Sequence
+
+    from .pycode_generator import PyCodeGen
+    from .variables import VariableBase
+
+
+class Tracker:
+    """
+    Tracker is a base class responsible for tracking variables or objects in Python code.
+    It is used to identify how a variable is derived from the initial state of the frame.
+
+    Args:
+        inputs: The list of variables to be tracked.
+
+    Note:
+        It serves as an abstract class and should not be instantiated directly.
+    """
+
+    inputs: Sequence[VariableBase]
+    name_generator = NameGenerator("tracker_")
+
+    def __init__(self, inputs: Sequence[VariableBase], changed: bool = False):
+        self.inputs = inputs
+        self.changed = changed
+        self.id = Tracker.name_generator.next()
+
+    def gen_instructions(self, codegen: PyCodeGen) -> None:
+        """
+        Generate instructions based on the tracked variables.
+
+        Args:
+            codegen (PyCodeGen): An instance of PyCodeGen to generate instructions.
+        """
+        raise NotImplementedError()
+
+    # TODO(xiongkun): trace_value_from_frame is not a good name; it is about
+    # building the guard expression rather than tracing.
+    def trace_value_from_frame(self) -> StringifyExpression:
+        """
+        Trace the value of the tracked variables from the frame. It is used for generating the guard.
+
+        Returns:
+            The value of the tracked variables.
+        """
+        raise NotImplementedError()
+
+    def is_traceable(self) -> bool:
+        """
+        Determine if all the tracked variables can be traced from the frame.
+
+        Returns:
+            bool, True if all tracked variables are traceable, False otherwise.
+        """
+        if self.changed:
+            return False
+        for input in self.inputs:
+            if not input.tracker.is_traceable():
+                return False
+        return True
+
+    def need_guard(self) -> bool:
+        return self.is_traceable()
+
+
+class DummyTracker(Tracker):
+    """
+    DummyTracker is a subclass of Tracker that specifically tracks variables that cannot be reproduced from the frame.
+    It is mostly generated by complex operations (instructions).
+
+    Args:
+        inputs (list[VariableBase]): The input variables associated with the generated variables.
+    """
+
+    def __init__(self, inputs: Sequence[VariableBase]):
+        super().__init__(inputs)
+
+    def gen_instructions(self, codegen: PyCodeGen):
+        raise InnerError("DummyTracker has no instructions")
+
+    def trace_value_from_frame(self):
+        raise InnerError("DummyTracker can't trace value from frame")
+
+    def is_traceable(self):
+        return False
+
+    def __repr__(self) -> str:
+        return f"DummyTracker(num_inputs={len(self.inputs)})"
+
+    def need_guard(self) -> bool:
+        return False
+
+
+class DanglingTracker(Tracker):
+    """
+    DanglingTracker is a subclass of Tracker that specifically tracks variables that are not in the frame.
+    Variables whose tracker is DanglingTracker should not be placed on the stack, except for NullVariable.
+    DanglingTracker is often used in conjunction with BuiltinVariable to reuse the dispatch mechanism.
+ + Examples: + >>> import operator + >>> from sot.opcode_translator.executor.variables import BuiltinVariable, ConstantVariable + >>> a = ConstantVariable.wrap_literal(1, None) + >>> b = ConstantVariable.wrap_literal(2, None) + >>> c = BuiltinVariable(operator.add, None, DanglingTracker())(a, b) + >>> c.value + 3 + """ + + def __init__(self): + super().__init__([]) + + def gen_instructions(self, codegen: PyCodeGen): + raise InnerError("DanglingTracker has no instructions") + + def trace_value_from_frame(self): + raise InnerError("DanglingTracker can't trace value from frame") + + def is_traceable(self): + return False + + def __repr__(self) -> str: + return "DanglingTracker()" + + +class LocalTracker(Tracker): + """ + LocalTracker is a subclass of Tracker that specifically tracks variables from f_locals of frame. + + Args: + name (str): The name of the variable in f_locals to be tracked. + """ + + def __init__(self, name: str): + super().__init__([]) + self.name = name + + def gen_instructions(self, codegen: PyCodeGen) -> None: + codegen.gen_load_fast(self.name) + + def trace_value_from_frame(self) -> StringifyExpression: + return StringifyExpression(f"frame.f_locals['{self.name}']", [], {}) + + def __repr__(self) -> str: + return f"LocalTracker(name={self.name})" + + +class CellTracker(LocalTracker): + def gen_instructions(self, codegen: PyCodeGen): + codegen.gen_load_deref(self.name) + + def trace_value_from_frame(self): + return StringifyExpression(f"frame.f_locals['{self.name}']", [], {}) + + def __repr__(self) -> str: + return f"CellTracker(name={self.name})" + + +class GlobalTracker(Tracker): + """ + GlobalTracker is a subclass of Tracker that specifically tracks variables from f_globals of frame. + + Args: + name (str): The name of the variable in f_globals to be tracked. + """ + + def __init__(self, name: str): + super().__init__([]) + self.name = name + + def gen_instructions(self, codegen: PyCodeGen) -> None: + codegen.gen_load_global(self.name, push_null=False) + + def trace_value_from_frame(self) -> StringifyExpression: + return StringifyExpression(f"frame.f_globals['{self.name}']", [], {}) + + def __repr__(self) -> str: + return f"GlobalTracker(name={self.name})" + + +class BuiltinTracker(Tracker): + """ + BuiltinTracker is a subclass of Tracker that specifically tracks variables from f_builtins of frame. + + Args: + name (str): The name of the variable in f_builtins to be tracked. + """ + + def __init__(self, name: str): + super().__init__([]) + self.name = name + + def gen_instructions(self, codegen: PyCodeGen) -> None: + codegen.gen_load_global(self.name, push_null=False) + + def trace_value_from_frame(self) -> StringifyExpression: + return StringifyExpression( + f"builtins.__dict__['{self.name}']", [], {"builtins": builtins} + ) + + def __repr__(self) -> str: + return f"BuiltinTracker(name={self.name})" + + +class ConstTracker(Tracker): + """ + ConstTracker is a subclass of Tracker that specifically tracks a constant value. + + Args: + value (Any): The value of the constant. 
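+
+    Examples:
+        >>> # doctest: +SKIP
+        >>> tracker = ConstTracker(42)
+        >>> tracker
+        ConstTracker(value=42)
+        >>> tracker.need_guard()  # constants never need a guard
+        False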
+ """ + + def __init__(self, value): + super().__init__([]) + self.value = value + + def gen_instructions(self, codegen: PyCodeGen): + codegen.gen_load_const(self.value) + + def trace_value_from_frame(self): + return StringifyExpression(f"{self.value!r}", [], {}) + + def __repr__(self) -> str: + return f"ConstTracker(value={self.value})" + + def need_guard(self) -> bool: + return False + + +class GetAttrTracker(Tracker): + """ + GetAttrTracker is a subclass of Tracker that specifically tracks the attribute access of an variable. + + Args: + obj (VariableBase): The object whose attribute is to be tracked. + attr (str): The attribute to be tracked. + """ + + def __init__(self, obj: VariableBase, attr: str, changed: bool = False): + super().__init__([obj], changed) + self.obj = obj + self.attr = attr + + def gen_instructions(self, codegen: PyCodeGen): + self.obj.tracker.gen_instructions(codegen) + codegen.gen_load_attr(self.attr) + + def trace_value_from_frame(self): + obj_tracer = self.obj.tracker.trace_value_from_frame() + if self.attr.isidentifier(): + expr = f"{{}}.{self.attr}" + else: + expr = f"getattr({{}}, '{self.attr}')" + return StringifyExpression( + expr, + [obj_tracer], + union_free_vars(obj_tracer.free_vars), + ) + + def __repr__(self) -> str: + return f"GetAttrTracker(attr={self.attr})" + + def need_guard(self) -> bool: + return self.is_traceable() and self.obj.tracker.need_guard() + + +class GetItemTracker(Tracker): + """ + GetItemTracker is a subclass of Tracker that specifically tracks item access of a container variable. + + It generates instructions and traces the item value from the frame. + + Args: + container_var (VariableBase): The container object whose item is to be tracked. + key: The key/index of the item to be tracked. + """ + + def __init__(self, container_var: VariableBase, key: object, changed=False): + super().__init__([container_var], changed) + self.container = container_var + self.key = key + + def gen_instructions(self, codegen: PyCodeGen): + self.container.tracker.gen_instructions(codegen) + if isinstance(self.key, slice): + codegen.gen_load_const(self.key.start) + codegen.gen_load_const(self.key.stop) + codegen.gen_load_const(self.key.step) + codegen.gen_build_slice(3) + else: + codegen.gen_load_const(self.key) + codegen.gen_subscribe() + + def trace_value_from_frame(self): + container_tracer = self.container.tracker.trace_value_from_frame() + return StringifyExpression( + f"{{}}[{self.key!r}]", + [container_tracer], + union_free_vars(container_tracer.free_vars), + ) + + def __repr__(self) -> str: + return f"GetItemTracker(key={self.key!r})" + + def need_guard(self) -> bool: + return self.is_traceable() and self.container.tracker.need_guard() + + +class GetIterTracker(Tracker): + """ + GetIterTracker is a subclass of Tracker that specifically tracks iteration of a variable. + + It generates instructions and traces the iterator from the frame. + + Args: + iter_source (VariableBase): The source variable to be iterated. 
+ """ + + def __init__(self, iter_source: VariableBase): + super().__init__([iter_source]) + self.iter_source = iter_source + + def gen_instructions(self, codegen: PyCodeGen): + self.iter_source.tracker.gen_instructions(codegen) + codegen._add_instr("GET_ITER") + + def trace_value_from_frame(self): + iter_source_tracer = self.iter_source.tracker.trace_value_from_frame() + return StringifyExpression( + "iter({})", + [iter_source_tracer], + union_free_vars(iter_source_tracer.free_vars), + ) + + def __repr__(self) -> str: + return "GetIterTracker()" + + +class CreateLayerTracker(Tracker): + def __init__(self, layer_class, args, kwargs): + super().__init__([layer_class] + list(args) + list(kwargs.values())) + self.layer_class = layer_class + self.args = args + self.kwargs = kwargs + + def gen_instructions(self, codegen: PyCodeGen): + if sys.version_info >= (3, 11): + codegen.gen_push_null() + + self.layer_class.reconstruct(codegen) + for variable in self.args: + variable.reconstruct(codegen) + + if len(self.kwargs) == 0: + codegen.gen_call_function(argc=len(self.args)) + else: + codegen.gen_build_tuple(len(self.args)) + for k, v in self.kwargs.items(): + codegen.gen_load_const(k) + v.reconstruct(codegen) + codegen.gen_build_map(len(self.kwargs)) + codegen.gen_call_function_ex(has_kwargs=True) + + def __repr__(self) -> str: + return f"CreateLayerTracker(Layer={self.layer_class}, args={self.args}, kwargs={self.kwargs})" diff --git a/python/paddle/jit/sot/opcode_translator/executor/tracker_viewer.py b/python/paddle/jit/sot/opcode_translator/executor/tracker_viewer.py new file mode 100644 index 0000000000000..f132c34abcac1 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/tracker_viewer.py @@ -0,0 +1,115 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import queue +from typing import TYPE_CHECKING + +from .tracker import DummyTracker +from .variables import VariableBase + +SIR_GRAPH_CLUSTER_NAME = "cluster_sir_part" + +if TYPE_CHECKING: + import graphviz + + +def try_import_graphviz(): + try: + import graphviz + + return graphviz + except ImportError: + return None + + +def draw_variable(graph: graphviz.Digraph, var: VariableBase): + """ + Draw and colour a node in the graph. + + Args: + graph (graphviz.Digraph): The graph to draw the variable. + var (VariableBase): The variable to draw. 
+ + Returns: + None + """ + # Draw Variable + graph.attr('node', shape='oval', style="filled", fillcolor='aliceblue') + graph.attr('edge', style='solid') + graph.node(var.id, str(var)) + + # Draw Tracker + tracker = var.tracker + graph.attr('node', shape='rect', style='filled', fillcolor='beige') + if isinstance(tracker, DummyTracker): + graph.attr('edge', style='dashed') + graph.attr('node', shape='rect', style='filled', fillcolor='goldenrod') + graph.node(tracker.id, str(tracker)) + + # Draw edge (Tracker -> Variable) + graph.edge(tracker.id, var.id) + + # Draw edge (Tracker inputs -> Tracker) + graph.attr('node', shape='oval', style="filled", fillcolor='cadetblue') + graph.attr('edge', style='solid') + for input in tracker.inputs: + graph.edge(input.id, tracker.id) + + +def view_tracker( + root_variables: list[VariableBase], filename: str, format: str +): + """ + Generates a graph visualization starting from the given root variables and save it to the given file. + + Args: + root_variables (list[VariableBase]): The root variables to start the visualization from. + filename (str): The name of the file used to save the results of the visualisation. + format (str): The format (e.g., `pdf`, `png` and 'svg' etc.) of the file to save the visualization to. + + Returns: + None + """ + # TODO(SigureMo): + # 1. Colorize the trackers + # 2. Highlight the user specific node, to speedup debug process + graphviz = try_import_graphviz() + if graphviz is None: + print("Cannot import graphviz, please install it first.") + return + + graph = graphviz.Digraph("graph", filename=filename, format=format) + visited = set() + var_queue = queue.Queue() + for var in root_variables: + var_queue.put(var) + + while not var_queue.empty(): + var = var_queue.get() + if var.id in visited: + continue + visited.add(var.id) + if isinstance(var.tracker, DummyTracker): + with graph.subgraph(name=SIR_GRAPH_CLUSTER_NAME) as sir_part: + sir_part.attr(color='green') + draw_variable(sir_part, var) + else: + draw_variable(graph, var) + for input in var.tracker.inputs: + if input not in var_queue.queue: + var_queue.put(input) + + graph.render(view=False) diff --git a/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py b/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py new file mode 100644 index 0000000000000..9eb10fb81bcd5 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py @@ -0,0 +1,1109 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import math +import operator +from functools import partial, reduce +from typing import TYPE_CHECKING + +import paddle + +from ...utils import BreakGraphError, FallbackError +from ...utils.magic_methods import ( + BINARY_OPS, + UNARY_OPS, + magic_method_builtin_dispatch, +) +from .dispatch_functions import ( + operator_in, + operator_is_none, + operator_is_not_none, + operator_not_in, + raise_break_graph_fn, + tensor_numel, +) +from .dispatcher import Dispatcher, optional +from .tracker import ConstTracker, DanglingTracker, DummyTracker +from .variables import ( + BuiltinVariable, + ConstantVariable, + ContainerVariable, + DictVariable, + EnumerateVariable, + ListVariable, + MapVariable, + NumpyVariable, + RangeVariable, + SliceVariable, + TupleVariable, + VariableBase, + VariableFactory, +) + +if TYPE_CHECKING: + from .variables import DataVariable, TensorVariable + + +def add_guard(var: VariableBase): + var.graph.add_global_guarded_variable(var) + return var + + +def raise_err_handle(error): + def inner(*args, **kwargs): + raise error + + return inner + + +# slice +Dispatcher.register( + slice, + ("VariableBase",), + lambda stop: SliceVariable( + slice(stop), + graph=stop.graph, + tracker=DummyTracker([stop]), + ), +) + +Dispatcher.register( + slice, + ("VariableBase", "VariableBase"), + lambda start, stop: SliceVariable( + slice(start, stop), + graph=stop.graph, + tracker=DummyTracker([start, stop]), + ), +) + +Dispatcher.register( + slice, + ("VariableBase", "VariableBase", "VariableBase"), + lambda start, stop, step: SliceVariable( + slice(start, stop, step), + graph=stop.graph, + tracker=DummyTracker([start, stop, step]), + ), +) + + +# iter +Dispatcher.register( + iter, + ("VariableBase",), + lambda variable: variable.get_iter(), +) + + +# in +Dispatcher.register( + operator_in, + ("VariableBase", "IterVariable"), + raise_err_handle(BreakGraphError("Codes like: `variable in iterator`.")), +) + +Dispatcher.register( + operator_in, + ("TensorVariable", "VariableBase"), + lambda left, right: ConstantVariable( + left.id + in [ + x.id + for x in right.get_py_value(allow_tensor=True) + if hasattr(x, "id") + ], + left.graph, + tracker=DummyTracker([left, right]), + ), +) + +Dispatcher.register( + operator_in, + ("VariableBase", "VariableBase"), + lambda left, right: ConstantVariable( + left.get_py_value(allow_tensor=True) + in right.get_py_value(allow_tensor=True), + left.graph, + tracker=DummyTracker([left, right]), + ), +) + +Dispatcher.register( + operator_not_in, + ("VariableBase", "IterVariable"), + raise_err_handle( + BreakGraphError("Codes like: `variable not in iterator`.") + ), +) + +Dispatcher.register( + operator_not_in, + ("TensorVariable", "VariableBase"), + lambda left, right: ConstantVariable( + left.id + not in [ + x.id + for x in right.get_py_value(allow_tensor=True) + if hasattr(x, "id") + ], + left.graph, + tracker=DummyTracker([left, right]), + ), +) + +Dispatcher.register( + operator_not_in, + ("VariableBase", "VariableBase"), + lambda left, right: ConstantVariable( + left.get_py_value(allow_tensor=True) + not in right.get_py_value(allow_tensor=True), + left.graph, + tracker=DummyTracker([left, right]), + ), +) + + +# dict +Dispatcher.register( + dict, + (), + lambda: DictVariable( + {}, + graph=Dispatcher.graph, + tracker=DummyTracker([]), + ), +) + +Dispatcher.register( + dict, + ("DictVariable",), + lambda var: var.copy(), +) + + +@Dispatcher.register_decorator(dict) +def dispatch_dict(var: ListVariable | TupleVariable): + 
res_dict = {} + length_var = BuiltinVariable(len, var.graph, DanglingTracker())(var) + getitem = BuiltinVariable(operator.getitem, var.graph, DanglingTracker()) + for index in range(length_var.get_py_value()): + index_value = getitem(var, index) + # check + assert isinstance(index_value, (ListVariable, TupleVariable)) + assert len(index_value) == 2 + # recombination + key = getitem(index_value, 0) + value = getitem(index_value, 1) + value.graph.add_global_guarded_variable(key) + res_dict.update({key.get_py_value(): value}) + return DictVariable(res_dict, var.graph, DummyTracker([var])) + + +@Dispatcher.register_decorator(dict.fromkeys) +def dispatch_dict_fromkeys(seq: ListVariable | TupleVariable, default: VariableBase = None): # type: ignore + if default is None: + default = ConstantVariable.wrap_literal(None, seq.graph) + res_dict = {} + getitem = BuiltinVariable(operator.getitem, seq.graph, DanglingTracker()) + for index in range(len(seq)): + index_value = getitem(seq, index) + seq.graph.add_global_guarded_variable(index_value) + res_dict.update({index_value.get_py_value(): default}) + return DictVariable(res_dict, seq.graph, DummyTracker([seq])) + + +Dispatcher.register( + dict.get, + ("DictVariable", "ConstantVariable", optional("VariableBase")), + lambda var, key, default=None: var.get(key.get_py_value(), default), +) +Dispatcher.register( + dict.keys, + ("DictVariable",), + lambda var: var.keys(), +) + +Dispatcher.register( + operator.not_, + ("VariableBase",), + lambda x: ConstantVariable( + not x.get_py_value(allow_tensor=False), x.graph, DummyTracker([x]) + ), +) + +Dispatcher.register( + dict.values, + ("DictVariable",), + lambda var: var.values(), +) +Dispatcher.register( + dict.items, + ("DictVariable",), + lambda var: var.items(), +) +Dispatcher.register( + dict.setdefault, + ("DictVariable", "ConstantVariable", optional("VariableBase")), + lambda var, key, default=None: var.setdefault(key.get_py_value(), default), +) +Dispatcher.register( + dict.update, + ("DictVariable", "DictVariable"), + lambda var, other: var.update(other), +) +Dispatcher.register( + dict.copy, + ("DictVariable",), + lambda var: var.copy(), +) +Dispatcher.register( + dict.clear, + ("DictVariable",), + lambda var: var.clear(), +) +Dispatcher.register( + dict.pop, + ("DictVariable", "ConstantVariable"), + lambda var, key: var.pop(key.get_py_value()), +) +Dispatcher.register( + dict.pop, + ("DictVariable", "ConstantVariable", "VariableBase"), + lambda var, key, default: var.pop(key.get_py_value(), default), +) +Dispatcher.register( + dict.popitem, + ("DictVariable",), + lambda var: var.popitem(), +) + +# tuple +Dispatcher.register( + tuple, + ("ContainerVariable",), + lambda var: TupleVariable( + tuple(var.get_wrapped_items()), + graph=var.graph, + tracker=DummyTracker([var]), + ), +) +Dispatcher.register( + tuple, + ("SequenceIterVariable",), + lambda var: TupleVariable( + tuple(var.to_list()), + graph=var.graph, + tracker=DummyTracker([var]), + ), +) +Dispatcher.register( + tuple.count, + ("TupleVariable", "VariableBase"), + lambda var, value: var.count(value), +) +Dispatcher.register( + tuple.index, + ("TupleVariable", "VariableBase"), + lambda var, value: var.index(value), +) + +# list +Dispatcher.register( + list, + (), + lambda: ListVariable( + [], + graph=Dispatcher.graph, + tracker=DummyTracker([]), + ), +) + +Dispatcher.register( + list, + ("ContainerVariable",), + lambda var: ListVariable( + list(var.get_wrapped_items()), + graph=var.graph, + tracker=DummyTracker([var]), + ), +) + 
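+
+# NOTE: Each `Dispatcher.register(fn, types, handler)` entry maps a builtin
+# call in user code onto a handler that operates on the simulated variables.
+# A rough sketch of what the registrations above mean (illustrative only):
+# ```python
+# def foo(x):
+#     y = list(x)   # x is a ContainerVariable -> the ("ContainerVariable",)
+#     return y      # handler returns a new tracked ListVariable
+# ```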
+Dispatcher.register( + list, + ("IterVariable",), + lambda var: ListVariable( + var.to_list(), + graph=var.graph, + tracker=DummyTracker([var]), + ), +) +Dispatcher.register( + list.extend, + ("ListVariable", "ListVariable | TupleVariable"), + lambda var, other: var.extend(other), +) +Dispatcher.register( + list.append, + ("ListVariable", "VariableBase"), + lambda var, other: var.append(other), +) +Dispatcher.register( + list.insert, + ("ListVariable", "ConstantVariable", "VariableBase"), + lambda var, index, obj: var.insert(index.get_py_value(), obj), +) +Dispatcher.register( + list.remove, + ("ListVariable", "VariableBase"), + lambda var, other: var.remove(other), +) +Dispatcher.register( + list.pop, + ("ListVariable", optional("ConstantVariable")), + lambda var, index=None: var.pop(index), +) +Dispatcher.register( + list.clear, + ("ListVariable",), + lambda var: var.clear(), +) +Dispatcher.register( + list.sort, + ("ListVariable",), + lambda var: var.sort(), +) +Dispatcher.register( + list.reverse, + ("ListVariable",), + lambda var: var.reverse(), +) +Dispatcher.register( + list.copy, + ("ListVariable",), + lambda var: var.copy(), +) +Dispatcher.register( + list.count, + ("ListVariable", "VariableBase"), + lambda var, obj: var.count(obj), +) +Dispatcher.register( + list.index, + ("ListVariable", "VariableBase"), + lambda var, obj: var.index(obj), +) +Dispatcher.register( + operator.add, + ("ListVariable", "ListVariable"), + lambda var, other: var.concat(other), +) +Dispatcher.register( + operator.add, + ("TupleVariable", "TupleVariable"), + lambda var, other: var.concat(other), +) +Dispatcher.register( + operator.mul, + ("ListVariable | TupleVariable", "ConstantVariable"), + lambda var, other: var.repeat(other), +) + +# getattr +Dispatcher.register( + getattr, + ("VariableBase", "ConstantVariable", optional("VariableBase")), + lambda var, name, default=None: var.getattr( + add_guard(name).get_py_value(), default + ), +) + +# hasattr +Dispatcher.register( + hasattr, + ("VariableBase", "ConstantVariable"), + lambda var, name: var.hasattr(add_guard(name).get_py_value()), +) + +Dispatcher.register( + delattr, + ("VariableBase", "VariableBase"), + lambda var, name: var.delattr(add_guard(name).get_py_value()), +) + +Dispatcher.register( + setattr, + ("VariableBase", "VariableBase", "VariableBase"), + lambda var, name, value: var.setattr(add_guard(name).get_py_value(), value), +) + +# len +Dispatcher.register( + len, + ("ContainerVariable | ContainerLayerVariable",), + lambda var: var.len(), +) + +# range +# stop +Dispatcher.register( + range, + ("ConstantVariable",), + lambda stop: RangeVariable( + range(stop.get_py_value()), + graph=stop.graph, + tracker=DummyTracker([stop]), + ), +) + +# start, stop +Dispatcher.register( + range, + ("ConstantVariable", "ConstantVariable"), + lambda start, stop: RangeVariable( + range(start.get_py_value(), stop.get_py_value()), + graph=stop.graph, + tracker=DummyTracker([start, stop]), + ), +) +# start, stop, step +Dispatcher.register( + range, + ("ConstantVariable", "ConstantVariable", "ConstantVariable"), + lambda start, stop, step: RangeVariable( + range(start.get_py_value(), stop.get_py_value(), step.get_py_value()), + graph=stop.graph, + tracker=DummyTracker([start, stop, step]), + ), +) +# TODO(zmh): Modify +# enumerate +Dispatcher.register( + enumerate, + ("VariableBase",), + lambda var: EnumerateVariable.from_iterator( + var, graph=var.graph, tracker=DummyTracker([var]) + ), +) + + +# map +Dispatcher.register( + map, + ( + "CallableVariable", + 
"VariableBase", + ), + lambda fn, var: MapVariable.from_iterator( + fn, var, graph=var.graph, tracker=DummyTracker([var]) + ), +) + + +# reversed +@Dispatcher.register_decorator(reversed) +def dispatch_reversed(var: ContainerVariable): + from .tracker import DanglingTracker + from .variables import BuiltinVariable, SequenceIterVariable + + length_var = BuiltinVariable(len, var.graph, DanglingTracker())(var) + assert isinstance(length_var, ConstantVariable) + getitem = BuiltinVariable(operator.getitem, var.graph, DanglingTracker()) + out = reversed([getitem(var, i) for i in range(length_var.get_py_value())]) + out_var = ListVariable( + list(out), graph=var.graph, tracker=DummyTracker([var]) + ) + return SequenceIterVariable( + out_var, + graph=var.graph, + tracker=DummyTracker([var]), + ) + + +# isinstance +Dispatcher.register( + isinstance, + ("TensorVariable", "VariableBase"), + lambda left, right: ConstantVariable( + isinstance( + paddle.to_tensor(0), + right.get_py_value(allow_tensor=True), + ), + left.graph, + DummyTracker([left, right]), + ), +) + +Dispatcher.register( + isinstance, + ("VariableBase", "VariableBase"), + lambda left, right: ConstantVariable( + isinstance( + left.get_py_value(allow_tensor=True), + right.get_py_value(allow_tensor=True), + ), + left.graph, + DummyTracker([left, right]), + ), +) + +# bool +Dispatcher.register( + bool, + ("ContainerVariable",), + lambda var: var.bool(), +) +Dispatcher.register( + operator.truth, + ("ConstantVariable",), + lambda var: var.bool(), +) + +# str +Dispatcher.register( + str, + ("ConstantVariable",), + lambda var: var.str(), +) + + +@Dispatcher.register_decorator(str.format) +def str_format(var: ConstantVariable, *args: ConstantVariable): + return var.format(*args) + + +Dispatcher.register( + str.lower, + ("ConstantVariable",), + lambda var: var.lower(), +) + + +@Dispatcher.register_decorator(str.startswith) +def str_startswith(var: ConstantVariable, substr: ConstantVariable, beg: ConstantVariable = None, end: ConstantVariable = None): # type: ignore + value = var.get_py_value() + if end is None: + end = ConstantVariable(len(value), var.graph, DanglingTracker()) + if beg is None: + beg = ConstantVariable(0, var.graph, DanglingTracker()) + + res = value.startswith( + substr.get_py_value(), beg.get_py_value(), end.get_py_value() + ) + return ConstantVariable( + res, var.graph, DummyTracker([var, substr, beg, end]) + ) + + +@Dispatcher.register_decorator(str.endswith) +def str_endswith(var: ConstantVariable, substr: ConstantVariable, beg: ConstantVariable = None, end: ConstantVariable = None): # type: ignore + value = var.get_py_value() + if end is None: + end = ConstantVariable(len(value), var.graph, DanglingTracker()) + if beg is None: + beg = ConstantVariable(0, var.graph, DanglingTracker()) + + res = value.endswith( + substr.get_py_value(), beg.get_py_value(), end.get_py_value() + ) + return ConstantVariable( + res, var.graph, DummyTracker([var, substr, beg, end]) + ) + + +# getitem +# TODO: Should pass its Variable into the getitem and perform operations such as getting value in the getitem. 
like this:https://github.com/PaddlePaddle/PaddleSOT/pull/198#discussion_r1241110949 +Dispatcher.register( + operator.getitem, + ( + "TensorVariable", + "Any", + ), + lambda var, key: var.getitem( + VariableFactory.from_value( + key, graph=var.graph, tracker=ConstTracker(key) + ) + ), +) + +Dispatcher.register( + operator.getitem, + ( + "VariableBase", + "int | str", + ), + lambda var, key: var.getitem( + VariableFactory.from_value( + key, graph=var.graph, tracker=ConstTracker(key) + ) + ), +) + +Dispatcher.register( + operator.getitem, + ( + "VariableBase", + "ConstantVariable | SliceVariable", + ), + lambda var, key: var.getitem(key), +) + +# setitem +Dispatcher.register( + operator.setitem, + ( + "VariableBase", + "int | str | ConstantVariable | TensorVariable", + "int | str | ConstantVariable | TensorVariable", + ), + lambda var, key, value: var.setitem(key.get_py_value(), value), +) + +# delitem +Dispatcher.register( + operator.delitem, + ( + "VariableBase", + "int | str | TensorVariable", + ), + lambda var, key: var.delitem(key), +) +Dispatcher.register( + operator.delitem, + ( + "VariableBase", + "ConstantVariable", + ), + lambda var, key: var.delitem(key.get_py_value()), +) + + +# TensorVariable +Dispatcher.register( + paddle.is_tensor, + ("TensorVariable",), + lambda var: var.is_tensor(), +) +Dispatcher.register( + paddle.is_complex, + ("TensorVariable",), + lambda var: var.is_complex(), +) +Dispatcher.register( + paddle.is_integer, + ("TensorVariable",), + lambda var: var.is_integer(), +) +Dispatcher.register( + paddle.is_floating_point, + ("TensorVariable",), + lambda var: var.is_floating_point(), +) +Dispatcher.register( + paddle.rank, + ("TensorVariable",), + lambda var: var.ndim, +) + +Dispatcher.register( + operator.is_, + ("TensorVariable", "TensorVariable"), + lambda var, other: ConstantVariable( + var.get_symbol() == other.get_symbol(), + var.graph, + tracker=DummyTracker([var, other]), + ), +) + +Dispatcher.register( + operator.is_, + ("TensorVariable", "VariableBase"), + lambda var, other: ConstantVariable( + False, + var.graph, + tracker=DummyTracker([var, other]), + ), +) + +Dispatcher.register( + operator.is_, + ("VariableBase", "TensorVariable"), + lambda var, other: ConstantVariable( + False, + var.graph, + tracker=DummyTracker([var, other]), + ), +) + +# VariableBase +Dispatcher.register( + operator.is_, + ("VariableBase", "VariableBase"), + lambda var, other: ConstantVariable( + var.get_py_value() is other.get_py_value(), + var.graph, + tracker=DummyTracker([var, other]), + ), +) + + +@Dispatcher.register_decorator(operator.is_not) +def is_not_func(var: VariableBase, other: VariableBase): + handler = Dispatcher.dispatch(operator.is_, var, other) + if handler is None: + raise FallbackError( + f"Not found implementation operator.is for {var} and {other}." + ) + return handler(var, other).bool_not() + + +# is None +Dispatcher.register( + operator_is_none, + ("VariableBase",), + lambda var: BuiltinVariable(operator.is_, var.graph, DanglingTracker())( + var, ConstantVariable.wrap_literal(None, var.graph) + ), +) + +# is not None +Dispatcher.register( + operator_is_not_none, + ("VariableBase",), + lambda var: BuiltinVariable(operator.is_not, var.graph, DanglingTracker())( + var, ConstantVariable.wrap_literal(None, var.graph) + ), +) + + +# NOTE(SigureMo): Don't directly capture free var inside for-loop, use partial instead. 
+# ```python +# lambdas = [] +# for i in range(10): +# lambdas.append(lambda: i) +# for fn in lambdas: +# print(fn()) # result is 9, 9, 9, 9, 9, 9, 9, 9, 9, 9 +# ``` +# Rewrite by partial: +# ```python +# lambdas = [] +# for i in range(10): +# lambdas.append(partial(lambda i: i, i)) +# for fn in lambdas: +# print(fn()) # result is 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 +# ``` + +# Constant +for unary_fn in UNARY_OPS: + for magic_method in magic_method_builtin_dispatch(unary_fn): + Dispatcher.register( + unary_fn, + ("ConstantVariable",), + partial( + lambda fn, var: VariableFactory.from_value( + fn(var.get_py_value()), + var.graph, + tracker=DummyTracker([var]), + ), + unary_fn, + ), + ) +for binary_fn in BINARY_OPS: + for magic_method in magic_method_builtin_dispatch(binary_fn): + Dispatcher.register( + binary_fn, + ("ConstantVariable", "ConstantVariable"), + partial( + lambda fn, var, other: VariableFactory.from_value( + fn(var.get_py_value(), other.get_py_value()), + var.graph, + tracker=DummyTracker([var, other]), + ), + binary_fn, + ), + ) +# Tensor +fallback_tensor_unary_method = { + int, + bool, + operator.truth, +} + +Dispatcher.register(tensor_numel, ("TensorVariable",), lambda x: x.numel()) + +for unary_fn in UNARY_OPS: + if unary_fn in fallback_tensor_unary_method: + Dispatcher.register( + unary_fn, + ("TensorVariable",), + raise_break_graph_fn, + ) + continue + + if unary_fn is len: + Dispatcher.register( + unary_fn, + ("TensorVariable",), + lambda x: x.len(), + ) + continue + + for magic_method in magic_method_builtin_dispatch(unary_fn): + Dispatcher.register( + unary_fn, + ("TensorVariable",), + partial( + lambda magic_name, var: var.graph.call_tensor_method( + magic_name, var + ), + magic_method.name, + ), + ) +for binary_fn in BINARY_OPS: + for magic_method in magic_method_builtin_dispatch(binary_fn): + # skip all inplace magic method name, we will dispatch it to non-inplace + # magic methods + if magic_method.is_inplace: + continue + + if not magic_method.is_reverse: + Dispatcher.register( + binary_fn, + ( + "TensorVariable", + "TensorVariable | ConstantVariable | NumpyVariable", + ), + partial( + lambda magic_name, var, other: var.graph.call_tensor_method( + magic_name, var, other + ), + magic_method.name, + ), + ) + else: + # skip __mod__ for str and TensorVariable + if magic_method.name == "__rmod__": + + @Dispatcher.register_decorator(operator.mod) + def tensor_mod_dispatcher( + var: ConstantVariable, other: TensorVariable + ): + if var.get_py_type() is str: + raise BreakGraphError( + "(ConstantVariable % TensorVariable) raise a callback. " + ) + raise FallbackError("Tensor doesn't support __rmod__") + + else: + Dispatcher.register( + binary_fn, + ( + "ConstantVariable | NumpyVariable", + "TensorVariable", + ), + partial( + lambda reverse_magic_name, var, other: other.graph.call_tensor_method( + reverse_magic_name, other, var + ), + magic_method.name, + ), + ) + +# Register dispatch for NumpyVariable: fallback ! 
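+# Most numpy unary/binary operators are deliberately not simulated: the
+# handlers registered below raise FallbackError so execution falls back to
+# dygraph mode.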
+for unary_fn in UNARY_OPS:
+    if unary_fn in [bool]:
+        continue
+    for magic_method in magic_method_builtin_dispatch(unary_fn):
+
+        @Dispatcher.register_decorator(unary_fn)
+        def numpy_unary_dispatcher(var: NumpyVariable):
+            raise FallbackError('Numpy operators need to fall back to dygraph')
+
+
+Dispatcher.register(
+    operator.eq,
+    ("NumpyVariable", "ConstantVariable | NumpyVariable"),
+    lambda left, right: constant_numpy_equal(right, left),
+)
+
+
+for binary_fn in BINARY_OPS:
+    for magic_method in magic_method_builtin_dispatch(binary_fn):
+
+        @Dispatcher.register_decorator(binary_fn)
+        def numpy_binary_dispatcher(var: NumpyVariable, other: NumpyVariable):
+            raise FallbackError('Numpy operators need to fall back to dygraph')
+
+
+# Register dispatch for DataVariable: directly call and return a wrapped variable.
+def data_variable_binary_dispatcher(var, other, operator):
+    return VariableFactory.from_value(
+        operator(var.get_py_value(), other.get_py_value()),
+        var.graph,
+        DummyTracker([var, other]),
+    )
+
+
+for binary_fn in BINARY_OPS:
+    for magic_method in magic_method_builtin_dispatch(binary_fn):
+        Dispatcher.register(
+            binary_fn,
+            ("DataVariable", "Any"),
+            partial(data_variable_binary_dispatcher, operator=binary_fn),
+        )
+        Dispatcher.register(
+            binary_fn,
+            ("Any", "DataVariable"),
+            partial(data_variable_binary_dispatcher, operator=binary_fn),
+        )
+
+for unary_fn in UNARY_OPS:
+    for magic_method in magic_method_builtin_dispatch(unary_fn):
+
+        def data_variable_unary_dispatcher(var: DataVariable, fn):
+            return VariableFactory.from_value(
+                fn(var.get_py_value()),
+                var.graph,
+                DummyTracker([var]),
+            )
+
+        Dispatcher.register(
+            unary_fn,
+            ("DataVariable",),
+            partial(data_variable_unary_dispatcher, fn=unary_fn),
+        )
+
+
+Dispatcher.register(
+    math.ceil,
+    ("ConstantVariable",),
+    lambda var: ConstantVariable(
+        math.ceil(var.get_py_value()),
+        var.graph,
+        tracker=DummyTracker([var]),
+    ),
+)
+
+Dispatcher.register(
+    math.floor,
+    ("ConstantVariable",),
+    lambda var: ConstantVariable(
+        math.floor(var.get_py_value()),
+        var.graph,
+        tracker=DummyTracker([var]),
+    ),
+)
+
+Dispatcher.register(
+    ord,
+    ("ConstantVariable",),
+    lambda var: var.ord(),
+)
+
+Dispatcher.register(
+    chr,
+    ("ConstantVariable",),
+    lambda var: var.chr(),
+)
+
+
+# pow
+# pow(base, exp, mod) -> (base ** exp) % mod
+@Dispatcher.register_decorator(pow)
+def dispatch_pow(base: VariableBase, exp: VariableBase, mod: VariableBase = None):  # type: ignore
+    graph = base.graph
+    result = BuiltinVariable(operator.pow, graph, DanglingTracker())(base, exp)
+    # Only apply the modulus when it was actually provided.
+    if mod is not None:
+        result = BuiltinVariable(operator.mod, graph, DanglingTracker())(
+            result, mod
+        )
+    return result
+
+
+Dispatcher.register(
+    math.pow,
+    ("ConstantVariable", "ConstantVariable"),
+    lambda var1, var2: ConstantVariable(
+        math.pow(var1.get_py_value(), var2.get_py_value()),
+        var1.graph,
+        tracker=DummyTracker([var1, var2]),
+    ),
+)
+
+
+@Dispatcher.register_decorator(sum)
+def dispatch_sum(var: ContainerVariable | TensorVariable, start: VariableBase = None):  # type: ignore
+    if start is None:
+        start = ConstantVariable.wrap_literal(0, var.graph)
+    elements = [
+        var.getitem(ConstantVariable.wrap_literal(i, var.graph))
+        for i in range(len(var))
+    ]
+    result = reduce(
+        BuiltinVariable(operator.add, var.graph, DanglingTracker()),
+        elements,
+        start,
+    )
+    return result
+
+
+Dispatcher.register(
+    max,
+    ("ListVariable",),
+    lambda var: var.max(),
+)
+
+Dispatcher.register(
+    min,
+    ("ListVariable",),
+    lambda var: var.min(),
+)
+
+Dispatcher.register(
+    math.sqrt,
+    
("ConstantVariable",), + lambda var: ConstantVariable( + math.sqrt(var.get_py_value()), + var.graph, + tracker=DummyTracker([var]), + ), +) + + +def constant_numpy_equal(left, right): + numpy_ans = left.get_py_value() == right.get_py_value() + return NumpyVariable( + numpy_ans, + left.graph, + tracker=DummyTracker([left, right]), + ) + + +Dispatcher.register( + operator.eq, + ("ConstantVariable", "NumpyVariable"), + lambda left, right: constant_numpy_equal(left, right), +) + +Dispatcher.register( + bool, + ("NumpyVariable",), + lambda x: ConstantVariable( + bool(x.get_py_value()), + x.graph, + tracker=DummyTracker([x]), + ), +) diff --git a/python/paddle/jit/sot/opcode_translator/executor/variable_stack.py b/python/paddle/jit/sot/opcode_translator/executor/variable_stack.py new file mode 100644 index 0000000000000..e7389de5b8805 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/variable_stack.py @@ -0,0 +1,216 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, overload + +if TYPE_CHECKING: + ValidateValueFunc = Callable[[Any], None] + + +StackDataT = TypeVar("StackDataT") + + +class VariableStack(Generic[StackDataT]): + """ + A stack class for storing variables. + + Examples: + >>> var1, var2, var3, var4 = range(1, 5) + >>> stack = VariableStack() + >>> stack.push(var1) + >>> stack.push(var3) + >>> stack.insert(1, var2) + >>> stack + [1, 2, 3] + >>> stack.pop() + 3 + >>> stack.pop_n(2) + [1, 2] + >>> stack.push(var1) + >>> stack.push(var2) + >>> stack.push(var3) + >>> stack + [1, 2, 3] + >>> stack.top + 3 + >>> stack.peek[1] + 3 + >>> stack.peek[:1] + [3] + >>> stack.peek[:2] + [2, 3] + >>> stack.peek[1] = var4 + >>> stack + [1, 2, 4] + + """ + + class VariablePeeker: + @overload + def __getitem__(self, index: int) -> StackDataT: + ... + + @overload + def __getitem__(self, index: slice) -> list[StackDataT]: + ... + + @overload + def __call__(self, index: int = 1) -> StackDataT: + ... + + @overload + def __call__(self, index: slice) -> list[StackDataT]: + ... 
+ + def __init__( + self, data: list[StackDataT], validate_value_func: ValidateValueFunc + ): + self._data = data + self.validate_value_func = validate_value_func + + def __getitem__( + self, index: int | slice + ) -> StackDataT | list[StackDataT]: + if isinstance(index, int): + assert 0 < index <= len(self._data) + return self._data[-index] + if isinstance(index, slice): + assert ( + index.start is None and index.step is None + ), "slice which has start or step not supported" + assert 0 < index.stop <= len(self._data) + return self._data[-index.stop :] + raise NotImplementedError(f"index type {type(index)} not supported") + + def __setitem__(self, index: int, value: Any): + assert isinstance( + index, int + ), f"index type {type(index)} not supported" + assert ( + 0 < index <= len(self._data) + ), f"index should be in [1, {len(self._data)}], but get {index}" + self.validate_value_func(value) + self._data[-index] = value + + def __call__( + self, index: int | slice = 1 + ) -> StackDataT | list[StackDataT]: + return self[index] + + def __init__( + self, + data: list[StackDataT] | None = None, + *, + validate_value_func: ValidateValueFunc | None = None, + ): + if data is None: + data = [] + else: + data = data.copy() + self.validate_value_func = ( + (lambda _: None) + if validate_value_func is None + else validate_value_func + ) + self._data = data + self._peeker = VariableStack.VariablePeeker( + self._data, self.validate_value_func + ) + + def copy(self): + return VariableStack( + self._data, validate_value_func=self.validate_value_func + ) + + def push(self, val: StackDataT): + """ + Pushes a variable onto the stack. + + Args: + val: The variable to be pushed. + + """ + self.validate_value_func(val) + self._data.append(val) + + def insert(self, index: int, val: StackDataT): + """ + Inserts a variable onto the stack. + + Args: + index: The index at which the variable is to be inserted, the top of the stack is at index 0. + val: The variable to be inserted. + + """ + assert ( + 0 <= index <= len(self) + ), f"index should be in [0, {len(self)}], but get {index}" + self.validate_value_func(val) + self._data.insert(len(self) - index, val) + + def pop(self) -> StackDataT: + """ + Pops the top value from the stack. + + Returns: + The popped value. + + """ + assert len(self) > 0, "stack is empty" + return self._data.pop() + + def pop_n(self, n: int) -> list[StackDataT]: + """ + Pops the top n values from the stack. + + Args: + n: The number of values to pop. + + Returns: + A list of the popped values. + + """ + assert ( + len(self) >= n >= 0 + ), f"n should be in [0, {len(self)}], but get {n}" + if n == 0: + return [] + retval = self._data[-n:] + self._data[-n:] = [] + return retval + + @property + def peek(self) -> VariablePeeker: + return self._peeker + + @property + def top(self) -> StackDataT: + assert len(self) > 0, "stack is empty" + return self.peek[1] + + @top.setter + def top(self, value): + assert len(self) > 0, "stack is empty" + self.peek[1] = value + + def __iter__(self): + return iter(self._data) + + def __len__(self) -> int: + return len(self._data) + + def __repr__(self) -> str: + return str(self._data) diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/__init__.py b/python/paddle/jit/sot/opcode_translator/executor/variables/__init__.py new file mode 100644 index 0000000000000..9611734ffffcd --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/variables/__init__.py @@ -0,0 +1,63 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .base import ( # noqa: F401 + ConstTypes, + VariableBase, + VariableFactory, + find_traceable_vars, + map_variables, +) +from .basic import ( # noqa: F401 + CellVariable, + ConstantVariable, + DataVariable, + DygraphTracerVariable, + FunctionGlobalVariable, + GlobalVariable, + ModuleVariable, + NullVariable, + NumpyVariable, + ObjectVariable, + SliceVariable, + TensorVariable, +) +from .callable import ( # noqa: F401 + BuiltinVariable, + CallableVariable, + ClassVariable, + ContainerLayerVariable, + FunctionVariable, + LayerVariable, + MethodVariable, + PaddleApiVariable, + PaddleLayerVariable, + UserDefinedFunctionVariable, + UserDefinedGeneratorVariable, + UserDefinedLayerVariable, +) +from .container import ( # noqa: F401 + ContainerVariable, + DictVariable, + ListVariable, + RangeVariable, + TupleVariable, +) +from .iter import ( # noqa: F401 + EnumerateVariable, + IterVariable, + MapVariable, + SequenceIterVariable, + UserDefinedIterVariable, +) diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/base.py b/python/paddle/jit/sot/opcode_translator/executor/variables/base.py new file mode 100644 index 0000000000000..17cb99aeef516 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/variables/base.py @@ -0,0 +1,618 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import inspect +import operator +from functools import cached_property +from queue import Queue +from typing import TYPE_CHECKING, Any, Callable, Optional + +import paddle + +from ....profiler import event_register +from ....utils import NameGenerator, get_unbound_method, log +from ....utils.exceptions import FallbackError, HasNoAttributeError +from ..dispatcher import Dispatcher +from ..guard import StringifyExpression, check_guard, union_free_vars +from ..mutable_data import MutableDictLikeData +from ..pycode_generator import PyCodeGen +from ..tracker import ( + DummyTracker, + GetAttrTracker, + GetItemTracker, + GetIterTracker, + Tracker, +) + +if TYPE_CHECKING: + from ..function_graph import FunctionGraph + + # Each variable object should implement a method called `from_value`, + # which should adhere to the FromValueFunc signature. 
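+    # A minimal sketch of a conforming factory (see, e.g.,
+    # ConstantVariable.from_value in basic.py; `MyType` and `MyVariable`
+    # here are purely illustrative):
+    #
+    #     @VariableFactory.register_from_value()
+    #     def from_value(value, graph, tracker):
+    #         if isinstance(value, MyType):
+    #             return MyVariable(value, graph, tracker)
+    #         return None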
+ FromValueFunc = Callable[ + [Any, FunctionGraph, Tracker], Optional["VariableBase"] + ] + + +ConstTypes = (int, float, str, bool, type(None)) + + +@event_register("find_traceable_vars") +def find_traceable_vars( + root_vars: list[VariableBase], +) -> list[VariableBase]: + """ + This function is used to find all traceable variables in the given list of variables. + + Args: + root_vars (list[VariableBase]): A list of root variables from which the ordering starts. + + Returns: + list[VariableBase]: A list of variables that are traceable. + """ + results: list[VariableBase] = [] + visited: set[VariableBase] = set() + queue: Queue[VariableBase] = Queue() + + for root in root_vars: + queue.put(root) + + while not queue.empty(): + var = queue.get() + if var in visited: + continue + + visited.add(var) + if var.tracker.need_guard(): + results.append(var) + continue + + # Pruning traceable variable, if the variable is traceable, we don't need to + # trace its inputs. + inputs = var.get_inputs() + + for var in inputs: + if var not in visited and var not in queue.queue: + queue.put(var) + + return results + + +def map_variables(map_func, variables: list[VariableBase]): + """ + This function maps the given map_func to the given list of variables in a recursive manner. + Args: + map_func (Callable[[VariableBase], Any]): The function to be mapped to each variable. + variables (list[VariableBase]): A list of variables to which the map_func is to be applied. + + Returns: + tuple: The result of applying the map_func to the variables. + """ + + def _map_variable(variable: VariableBase | object): + from .basic import SliceVariable + from .container import ContainerVariable + + if isinstance(variable, ContainerVariable): + return paddle.utils.map_structure( + _map_variable, variable.get_wrapped_items() + ) + + if isinstance(variable, SliceVariable): + return slice( + map_func(variable.getattr("start")), + map_func(variable.getattr("stop")), + map_func(variable.getattr("step")), + ) + + return map_func(variable) + + return paddle.utils.map_structure(_map_variable, variables) + + +class VariableFactory: + """ + A factory class for creating variables from arbitrary values. + + This class provides a set of registration and factory methods for creating variables + of different types based on the type of the input value. + + """ + + registered_funcs: dict[str, list[str]] = {"default": []} + mapping_str_func: dict[str, FromValueFunc] = {} + + @staticmethod + def default_from_value(value, graph, tracker): + """ + A default factory function that creates an ObjectVariable from the given value. + + Args: + value: The input value. + graph: The FunctionGraph object that this variable is associated with. + tracker: The Tracker object that tracks the information of this variable. + + Returns: + ObjectVariable: A new ObjectVariable representing the input value. + """ + from .basic import ObjectVariable + + return ObjectVariable(value, graph, tracker) + + @staticmethod + def register_from_value(*, successor: str | None = None): + """ + A decorator function that registers a function for creating a Variable from a value. + + Args: + successor (str | None, optional): The name of the successor function that will be called after this function when creating a Variable. If None, the function is added to a default list of functions. + + Returns: + The _register_from_value decorator function, which takes the function to be registered as an argument. 
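+
+        Example:
+            A sketch of how a variable class registers its factory (for a
+            real case, see PaddleApiVariable.from_value in callable.py, which
+            passes successor="UserDefinedFunctionVariable"):
+
+                @VariableFactory.register_from_value()
+                def from_value(value, graph, tracker):
+                    ...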
+        """
+        registered_funcs = VariableFactory.registered_funcs
+        mapping_str_func = VariableFactory.mapping_str_func
+
+        def _register_from_value(func: FromValueFunc):
+            """
+            Function to register a function for creating a Variable from a value
+            """
+            # Get the name of the function
+            name = func.__qualname__.split(".")[0]
+            # Map the name of the function to the function
+            mapping_str_func[name] = func
+            if successor is None:
+                registered_funcs["default"].append(
+                    name
+                )  # If successor is None, add the function to the "default" list
+            elif successor not in registered_funcs.keys():
+                registered_funcs[successor] = [
+                    name
+                ]  # If the successor is not in the registered_funcs dictionary, set the value to a list containing only name
+            else:
+                registered_funcs[successor].append(
+                    name
+                )  # If the successor is in the registered_funcs dictionary, append name to the existing list of functions for that successor
+
+        log(
+            4, VariableFactory.registered_funcs
+        )  # Print the registered_funcs dictionary if the logging level is at least 4
+        return _register_from_value
+
+    @staticmethod
+    def from_value(
+        value: Any,
+        graph: FunctionGraph,
+        tracker: Tracker,
+        *,
+        debug_name: str | None = None,
+    ) -> VariableBase:
+        """
+        Create a new variable object from the given value.
+
+        This method searches through the registered from_value functions to find one
+        that can create a variable object from the given value. If no matching function
+        is found, the default_from_value function is used.
+
+        Args:
+            value (Any): The input value.
+            graph (FunctionGraph): The FunctionGraph object that this variable is associated with.
+            tracker (Tracker): The Tracker object that tracks the information of this variable.
+            debug_name (str | None): An optional debug name for the variable.
+
+        Returns:
+            VariableBase: A new variable object representing the input value.
+        """
+        registered_funcs = VariableFactory.registered_funcs
+
+        def _find_var(key: str = "default") -> VariableBase | None:
+            for name in registered_funcs[key]:
+                if name in registered_funcs.keys():
+                    # If the function name is a key in the registered_funcs dictionary, recursively find a Variable using that function
+                    var = _find_var(name)
+                    if var is not None:
+                        return var
+                # Get the function corresponding to the name from the mapping_str_func dictionary
+                func = VariableFactory.mapping_str_func[name]
+                var = func(
+                    value, graph, tracker
+                )  # Call the function to create a Variable from the value
+                if var is not None:
+                    return var
+
+        var = _find_var()
+        if var is None:
+            var = VariableFactory.default_from_value(
+                value, graph, tracker
+            )  # If a Variable could not be found using the registered functions, use the default function to create a new Variable
+        var.debug_name = debug_name
+        return var
+
+
+class VariableBase:
+    """
+    VariableBase is a basic concept: each symbol on the VM stack is regarded
+    as a Variable object during the symbolic tracing process.
+
+    There are two key data structures during Python runtime:
+    PyFrameObject, which holds the runtime state (locals and the evaluation stack) of a function invocation,
+    and PyCodeObject, which provides the bytecode of the corresponding function.
+    With these data, the Python virtual machine executes the bytecode sequentially on a stack to complete function logic.
+
+    Args:
+        tracker(Tracker): The Tracker object that tracks the information of this variable.
+
+    Note:
+        We should push an object of a subclass of VariableBase instead of an object of VariableBase onto the VM stack.
+ It serves as an abstract class and should not be instantiated directly. + """ + + tracker: Tracker # An attribute to store the Tracker object associated with the variable + value: Any + name_generator = NameGenerator( + "object_" + ) # A class-level attribute to generate names for new variables + mutable_attrs = [] + + def __init__(self, graph: FunctionGraph, tracker: Tracker): + self.graph = graph + self.tracker = tracker + self.id = VariableBase.name_generator.next() + self._debug_name: str | None = None + + @property + def main_info(self) -> dict[str, Any]: + """ + Property method to return a dictionary of main information about the variable + + Returns: + main_info: Main information of the variable. + """ + return {} + + @property + def debug_info(self) -> dict[str, Any]: + """ + Property method to return a dictionary of debug information about the variable + """ + return { + "debug_name": self.debug_name, + "id": self.id, + } + + @property + def debug_name(self) -> str: + """ + Generate a debug_name for each variable. + + Returns: + _debug_name: the name of variable. + """ + if self._debug_name is not None: + # Return the self._debug_name cache if it is not None. + return self._debug_name + inputs = self.tracker.inputs + if isinstance(self.tracker, GetItemTracker): + self._debug_name = ( + f"{self.tracker.container.debug_name}[{self.tracker.key}]" + ) + elif isinstance(self.tracker, GetAttrTracker): + self._debug_name = ( + f"{self.tracker.obj.debug_name}.{self.tracker.attr}" + ) + elif len(inputs) == 0: + self._debug_name = "tmp_var" + else: # len(inputs) >= 0 + for input in inputs: + assert input is not None + self._debug_name = "tmp_var_" + "_".join( + input.debug_name for input in inputs + ) + return self._debug_name + + @debug_name.setter + def debug_name(self, name): + self._debug_name = name + + def __hash__(self): + return hash(self.id) + + @check_guard + def make_stringify_guard(self) -> list[StringifyExpression]: + """ + Create a StringifyExpression object that represents a guard expression for this variable. + + Returns: + StringifyExpression: An object that contains the guard expression and the free variables used in the expression. 
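+
+        Example:
+            For a frame value ``x = 1``, the default implementation produces
+            guards roughly equivalent to the following (illustrative; the
+            type is actually compared by its ``id``):
+
+                id(type(x)) == id(int)
+                x == 1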
+ """ + + # Get a ValueTracer object from the Tracker object associated with the variable + frame_value_tracer = self.tracker.trace_value_from_frame() + + return [ + StringifyExpression( + f"id(type({{}})) == {id(self.get_py_type())}", + [frame_value_tracer], + union_free_vars(frame_value_tracer.free_vars), + ), + StringifyExpression( + f"{{}} == {self.get_py_value()!r}", + [frame_value_tracer], + union_free_vars(frame_value_tracer.free_vars), + ), + ] + + def get_py_value(self, allow_tensor=False) -> Any: + """ + Abstract method to get the value of the variable + """ + raise NotImplementedError() + + def get_py_type(self): + """ + Method to get the type of the variable's value + """ + return type(self.get_py_value()) + + def is_none(self) -> bool: + """ + Method to check if the variable's value is None + """ + return self.get_py_value() is None + + def reconstruct( + self, + codegen: PyCodeGen, + *, + use_tracker: bool = True, + add_to_global_guarded_vars: bool = True, + ): + if self.tracker.is_traceable() and use_tracker: + self.tracker.gen_instructions(codegen) + else: + if add_to_global_guarded_vars: + self.graph.add_global_guarded_variable(self) + self._reconstruct(codegen) + + def _reconstruct(self, codegen: PyCodeGen): + """ + Abstract method to construct an opcode and append it into codegen.instructions + """ + raise FallbackError( + f'{self.__class__.__name__} does not implement "_reconstruct" method' + ) + + def flatten_items(self) -> list[VariableBase]: + """ + Recursively flatten the items in this container variable to a list of Variable objects. + + Returns: + list[VariableBase]: Flattened items of a container variable. + """ + from .container import ContainerVariable + + if not isinstance(self, ContainerVariable): + return [self] + flattened_items = [] + for item in self.get_items(): + flattened_items.extend(item.flatten_items()) + return flattened_items + + def get_inputs(self) -> list[VariableBase]: + """ + This method is used to get the inputs for the current variable. + + Returns: + list[VariableBase]: Inputs for the current variable. + """ + return self.tracker.inputs + + def get_traceable_inputs(self) -> list[VariableBase]: + """ + This method is used to get the traceable inputs for the current variable. + + Returns: + list[VariableBase]: Traceable inputs for the current variable. + """ + return list( + filter(lambda x: x.tracker.is_traceable(), self.tracker.inputs) + ) + + def call_function(self, /, *args, **kwargs): + pass + + @cached_property + def attr_proxy(self): + return self.graph.side_effects.get_proxy( + MutableDictLikeData, self.get_py_value(), self.attr_proxy_getter + ) + + def attr_proxy_getter(self, proxy: MutableDictLikeData, name: str): + if not hasattr(proxy.original_data, name): # can't true. 
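+            # The attribute is genuinely absent from the wrapped object;
+            # returning Empty() marks it as missing so that getattr() can
+            # raise HasNoAttributeError instead of wrapping a bogus value.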
+ return MutableDictLikeData.Empty() + + attr = getattr(proxy.original_data, name) + if inspect.ismethod(attr) or ( + hasattr(attr, "__self__") + and inspect.ismethoddescriptor( + getattr(attr.__self__.__class__, name, None) + ) + ): + from .callable import MethodVariable + + fn = None + if inspect.ismethoddescriptor( + getattr(attr.__self__.__class__, name, None) + ): + class_var = VariableFactory.from_value( + self.get_py_type(), + self.graph, + GetAttrTracker(self, "__class__"), + ) + fn = VariableFactory.from_value( + getattr(attr.__self__.__class__, name), + self.graph, + GetAttrTracker(class_var, name), + ) + return MethodVariable.wrap_method( + value=attr, + instance=self, + fn=fn, + graph=self.graph, + tracker=GetAttrTracker(self, name), + method_name=name, + ) + + return VariableFactory.from_value( + attr, self.graph, tracker=GetAttrTracker(self, name) + ) + + def hasattr(self, name: str): + from .basic import ConstantVariable + + try: + self.getattr(name) + return ConstantVariable( + True, graph=self.graph, tracker=DummyTracker([self]) + ) + except HasNoAttributeError: + # NOTE(SigureMo): Only the HasNoAttributeError is raised, we can + # ensure that the attribute does not exist. Otherwise, we should + # raise the error. + return ConstantVariable( + False, graph=self.graph, tracker=DummyTracker([self]) + ) + + def getattr(self, name: str, default=None): + result = self.attr_proxy.get(name) + if isinstance(result, MutableDictLikeData.Empty): + if default is not None: + assert isinstance(default, VariableBase) + return default + raise HasNoAttributeError( + f"{self.__class__.__name__} {self} has no attribute {name}" + ) + return result + + def setattr(self, key: str, value): + from .basic import ConstantVariable + + self.attr_proxy.set(key, value) + self.graph.side_effects.record_proxy_variable(self) + return ConstantVariable.wrap_literal(None, self.graph) + + def delattr(self, key: str): + from .basic import ConstantVariable + + self.attr_proxy.delete(key) + self.graph.side_effects.record_proxy_variable(self) + return ConstantVariable.wrap_literal(None, self.graph) + + def __setitem__(self, key, value): + return self.setitem(key, value) + + def setitem(self, key, value): + raise FallbackError(f"{self} is not support setitem.") + + def __repr__(self): + info = {**self.main_info, **self.debug_info} + info_str = ", ".join([f"{value}" for value in info.values()]) + return f"{self.__class__.__name__}({info_str})" + + def __str__(self): + return self.__repr__() + + def __getitem__(self, idx): + return Dispatcher.call(operator.getitem, self, idx) + + def getitem(self, item): + class_var = VariableFactory.from_value( + self.get_py_value().__class__, + self.graph, + GetAttrTracker(self, '__class__'), + ) + fn_var = VariableFactory.from_value( + get_unbound_method(self.get_py_value(), '__getitem__'), + self.graph, + GetAttrTracker(class_var, '__getitem__'), + ) + self.graph.add_global_guarded_variable(item) + item = item.get_py_value() + output = fn_var(self, item) + return output + + def __call__(self, /, *args, **kwargs): + """ + Call the object represented by this variable with the given arguments. + + Args: + *args: Positional arguments to pass to the object's __call__ method. + **kwargs: Keyword arguments to pass to the object's __call__ method. + + Returns: + VariableBase: A new variable representing the result of calling the object's __call__ method. 
+ """ + from .callable import BuiltinVariable, UserDefinedFunctionVariable + + class_var = VariableFactory.from_value( + self.get_py_value().__class__, + self.graph, + GetAttrTracker(self, '__class__'), + ) + assert class_var is not None + # if __call__ is a method, we should add self to arguments. + if inspect.ismethod(self.get_py_value().__call__): + args = (self,) + args + unbound_method = get_unbound_method(self.get_py_value(), '__call__') + if hasattr(unbound_method, "__code__"): + fn_var = UserDefinedFunctionVariable( + unbound_method, + self.graph, + GetAttrTracker(class_var, '__call__'), + ) + else: + fn_var = BuiltinVariable( + self.value, + self.graph, + GetAttrTracker(class_var, '__call__'), + ) + output = fn_var(*args, **kwargs) + return output + + def get_iter(self): + from .iter import UserDefinedIterVariable + + return UserDefinedIterVariable(self, self.graph, GetIterTracker(self)) + + @VariableFactory.register_from_value() + def from_value( + value: Any, + graph: FunctionGraph | None, + tracker: Tracker, + ) -> VariableBase | None: + """ + Create a new variable from a given value, or return None if the value cannot be converted to a variable. + Args: + value (Any): The value to create a variable from. + graph (FunctionGraph | None): The graph in which the variable will be used. + tracker (Tracker): The variable tracker to put the new variable in if created. + + Returns: + VariableBase | None: A new variable if one can be created from the given value, or None if the value cannot be converted to a variable. + """ + if isinstance(value, VariableBase): + return value + return None diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py b/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py new file mode 100644 index 0000000000000..ba0a7f51c91a0 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py @@ -0,0 +1,888 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import operator +import types +from functools import cached_property, reduce +from typing import TYPE_CHECKING, Any + +import numpy as np + +import paddle + +from ....infer_meta import MetaInfo +from ....symbolic.statement_ir import Symbol +from ....utils import ( + BreakGraphError, + FallbackError, + NameGenerator, + paddle_tensor_methods, +) +from ....utils.exceptions import HasNoAttributeError, InnerError +from ..dispatch_functions import tensor_numel +from ..guard import ( + StringifyExpression, + check_guard, + object_equal_stringify_guard, + union_free_vars, +) +from ..mutable_data import MutableDictLikeData +from ..pycode_generator import PyCodeGen +from ..tracker import ( + ConstTracker, + DanglingTracker, + DummyTracker, + GetAttrTracker, + GetIterTracker, + GlobalTracker, + Tracker, +) +from .base import ConstTypes, VariableBase, VariableFactory + +if TYPE_CHECKING: + from ..function_graph import FunctionGraph + from .callable import FunctionVariable + + +FP_DTYPE_ABBRS = { + paddle.bfloat16: 'bfloat16', + paddle.float64: 'float64', + paddle.float32: 'float32', + paddle.float16: 'float16', +} + +CP_DTYPE_ABBRS = { + paddle.complex64: 'complex64', + paddle.complex128: 'complex128', +} + +INT_DTYPE_ABBRS = { + paddle.int8: 'int8', + paddle.int16: 'int16', + paddle.int32: 'int32', + paddle.int64: 'int64', + paddle.uint8: 'uint8', +} + +DTYPE_ABBRS = { + **FP_DTYPE_ABBRS, + **CP_DTYPE_ABBRS, + **INT_DTYPE_ABBRS, + paddle.bool: 'bool', +} + + +class ConstantVariable(VariableBase): + """ + ConstantVariable is a subclass of VariableBase used to wrap a Variable of the const type. + + Args: + value(Any): The value to be wrapped. + tracker(Tracker): The Tracker object that tracks the information of this variable. + """ + + def __init__( + self, + value: Any, + graph: FunctionGraph, + tracker: Tracker, + ): + super().__init__(graph, tracker) + self.value = value + + def get_py_value(self, allow_tensor=False): + return self.value + + @property + def debug_name(self) -> str: + return f"{self.value}" + + @debug_name.setter + def debug_name(self, name): + pass + + def _reconstruct(self, codegen: PyCodeGen): + codegen.gen_load_const(self.value) + + @property + def main_info(self) -> dict[str, Any]: + return {"value": self.value} + + def __bool__(self) -> bool: + return bool(self.value) + + def bool(self): + return ConstantVariable(bool(self), self.graph, DummyTracker([self])) + + def bool_not(self): + assert isinstance( + self.get_py_value(), bool + ), "Bool_not can only be applied to a bool variable." 
+        return ConstantVariable(
+            not bool(self.get_py_value()), self.graph, DummyTracker([self])
+        )
+
+    def str(self):
+        return ConstantVariable(
+            str(self.value), self.graph, DummyTracker([self])
+        )
+
+    def format(self, *args):
+        return ConstantVariable(
+            str(self.value).format(*[str(a.value) for a in args]),
+            self.graph,
+            DummyTracker([self, *args]),
+        )
+
+    def lower(self):
+        return ConstantVariable(
+            str(self.value).lower(),
+            self.graph,
+            DummyTracker([self]),
+        )
+
+    def ord(self):
+        return ConstantVariable(
+            ord(self.value),
+            self.graph,
+            DummyTracker([self]),
+        )
+
+    def chr(self):
+        return ConstantVariable(
+            chr(self.value),
+            self.graph,
+            DummyTracker([self]),
+        )
+
+    @VariableFactory.register_from_value()
+    def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
+        if type(value) in ConstTypes:
+            return ConstantVariable(value, graph, tracker)
+        return None
+
+    @staticmethod
+    def wrap_literal(value: Any, graph: FunctionGraph) -> ConstantVariable:
+        """
+        Wrap a literal value in a ConstantVariable.
+
+        Args:
+            value(Any): The literal value to be wrapped.
+
+        Returns:
+            ConstantVariable: A new ConstantVariable object that wraps the given value.
+        """
+        if isinstance(value, ConstantVariable):
+            return value
+        assert isinstance(
+            value, ConstTypes
+        ), f"value: {value}, type: {type(value)}"
+        return ConstantVariable(value, graph, ConstTracker(value))
+
+
+class PrintStmtVariable(VariableBase):
+    def __init__(self, value: Any, graph: FunctionGraph):
+        # TODO: graph should be not None
+        super().__init__(None, DanglingTracker())
+        self.args, self.kwargs = value
+        self.graph = graph
+
+    def _reconstruct(self, codegen: PyCodeGen):
+        # TODO: Do we really need to guard all print arguments? This may be too strict.
+        for var in self.args:
+            self.graph.add_global_guarded_variable(var)
+        for var in self.kwargs.values():
+            self.graph.add_global_guarded_variable(var)
+        # currently don't consider kwargs
+        codegen.gen_load_global("print", push_null=True)
+        for var in self.args:
+            var.reconstruct(codegen)
+        codegen.gen_call_function(len(self.args))
+        codegen.gen_pop_top()
+
+    def flatten_items(self):
+        return self.args
+
+
+IMPLEMENTED_TENSOR_PROPERTIES = set()
+
+
+def tensor_property(func):
+    IMPLEMENTED_TENSOR_PROPERTIES.add(func.__name__)
+    return property(func)
+
+
+class DataVariable(VariableBase):
+    """
+    A value-only object.
+    If none of its magic methods change the function_graph state (tensor ops,
+    guards, side effects), we call it a ValueObjectVariable and directly call
+    the Python operator on it.
+ """ + + def __init__( + self, + value: Any, + graph: FunctionGraph, + tracker: Tracker, + ): + super().__init__(graph, tracker) + self.value = value + + def get_py_value(self, allow_tensor=False): + return self.value + + +class TensorDtypeVariable(DataVariable): + def __init__(self, value, graph, tracker): + super().__init__(value, graph, tracker) + + @check_guard + def make_stringify_guard(self) -> list[StringifyExpression]: + if isinstance(self.tracker, GetAttrTracker) and isinstance( + self.tracker.obj, TensorVariable + ): + tensor_value_tracer = ( + self.tracker.obj.tracker.trace_value_from_frame() + ) + return [ + StringifyExpression( + f"str(MetaInfo.from_tensor({{}}).dtype) == '{str(self.value)}'", + [tensor_value_tracer], + {"MetaInfo": MetaInfo}, + ) + ] + else: + return object_equal_stringify_guard(self) + + @property + def main_info(self) -> dict[str, Any]: + return { + "dtype": self.value, + } + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if isinstance(value, paddle.dtype): + return TensorDtypeVariable(value, graph, tracker) + + +class TensorVariable(VariableBase): + """ + TensorVariable is a subclass of VariableBase used to wrap a Variable of the tensor type. + + Args: + tensor (paddle.Tensor | MetaInfo): The tensor to be wrapped. + graph (FunctionGraph): The FunctionGraph object that this variable is associated with. + tracker (Tracker): The Tracker object that tracks the information of this variable. + """ + + var_name_generator = NameGenerator("var_") + mutable_attrs = ["meta"] + + def __init__( + self, + tensor: paddle.Tensor | MetaInfo, + graph: FunctionGraph, + tracker: Tracker, + ): + super().__init__(graph, tracker) + if isinstance(tensor, paddle.Tensor): + self.value = None + self.meta = MetaInfo.from_tensor(tensor) + elif isinstance(tensor, MetaInfo): + self.value = None + self.meta = tensor + else: + raise InnerError( + "Required type(tensor) is paddle.Tensor or ProxyTensor, but received {}.".format( + type(tensor).__name__ + ) + ) + self.origin_meta = self.meta + self.var_name = TensorVariable.var_name_generator.next() + self.graph.side_effects.record_mutable_variable(self) + + def __len__(self): + if self.meta.shape[0] == -1: + raise BreakGraphError( + "length of tensor variable with first dimension == -1" + ) + return self.meta.shape[0] + + def get_py_value(self, allow_tensor=False): + if allow_tensor: + + class SotTensor: + def __init__(self, id_): + self.id = id_ + + def __eq__(self, var): + if not hasattr(var, "id"): + return False + else: + return self.id == var.id + + return SotTensor(self.id) + + raise BreakGraphError( + "Called TensorVariable.get_py_value. Should not use Tensor's value in simulating." 
+ ) + + def get_py_type(self): + return paddle.Tensor + + def get_symbol(self) -> Symbol: + return Symbol(self.var_name) + + @property + def out_var_name(self): + return f"{self.graph.OUT_VAR_PREFIX}{self.var_name}" + + def _reconstruct(self, codegen: PyCodeGen): + codegen.gen_load_fast(self.out_var_name) + + @check_guard + def make_stringify_guard(self) -> list[StringifyExpression]: + frame_value_tracer = self.tracker.trace_value_from_frame() + + return [ + StringifyExpression( + f"MetaInfo.from_tensor({{}}).guard_str() == '{self.origin_meta.guard_str()}'", + [frame_value_tracer], + union_free_vars( + {"MetaInfo": MetaInfo}, + frame_value_tracer.free_vars, + ), + ) + ] + + def get_iter(self): + from .iter import SequenceIterVariable + + return SequenceIterVariable(self, self.graph, GetIterTracker(self)) + + @property + def main_info(self) -> dict[str, Any]: + return { + "shape": self.meta.shape, + "dtype": DTYPE_ABBRS[self.meta.dtype], + "stop_gradient": self.meta.stop_gradient, + "var_name": self.var_name, + } + + def getitem(self, key): + return self.graph.call_tensor_method('__getitem__', self, key) + + def setitem(self, key, value): + self.graph.add_global_guarded_variable(value) + + key_var = VariableFactory.from_value( + key, self.graph, tracker=ConstTracker(key) + ) + new_tensor = self.graph.call_paddle_api( + paddle.static.setitem, + self, + key_var, + value, + ) + + self.meta = new_tensor.meta + self.graph.add_inplace_tensors(self) + + @tensor_property + def T(self): + """ + Return a new TensorVariable object that wraps the result of calling the transpose method on the wrapped value of this TensorVariable. + """ + from .container import ListVariable + + perm = list(range(len(self.meta.shape) - 1, -1, -1)) + perm_var = ListVariable(perm, self.graph, tracker=ConstTracker(perm)) + assert perm_var is not None + out = self.graph.call_paddle_api(paddle.transpose, self, perm_var) + return out + + @tensor_property + def ndim(self): + """ + Return a ConstantVariable object that represents the number of dimensions of the wrapped value of this TensorVariable. + """ + return ConstantVariable( + len(self.meta.shape), self.graph, DummyTracker([self]) + ) + + @tensor_property + def size(self): + """ + Return a ConstantVariable object that represents the total number of elements in the wrapped value of this TensorVariable. + """ + # TODO: maybe break graph. + if self.meta.is_dynamic_shape(): + raise BreakGraphError( + f"Getting size for a dynamic shape tensor causes graph break. shape = {self.meta.shape}" + ) + elements = reduce(operator.mul, self.meta.shape, 1) + return ConstantVariable(elements, self.graph, DummyTracker([self])) + + @tensor_property + def shape(self): + if self.meta.is_dynamic_shape(): + raise BreakGraphError( + f"Getting shape for a dynamic shape tensor causes graph break. shape = {self.meta.shape}" + ) + from .container import ListVariable + + return ListVariable( + self.meta.shape, self.graph, tracker=DummyTracker([self]) + ) + + def numel(self): + return self.size + + def len(self): + if len(self.meta.shape) == 0: + raise InnerError("len() of a 0-D tensor is wrong") + first_dim = self.meta.shape[0] + if first_dim == -1: + raise BreakGraphError( + "Getting len() for a dynamic shape tensor causes graph break." 
+ ) + + return ConstantVariable(first_dim, self.graph, DummyTracker([self])) + + def is_tensor(self): + return ConstantVariable(True, self.graph, DummyTracker([self])) + + def is_complex(self): + dtype = self.meta.dtype + is_cp_dtype = dtype in CP_DTYPE_ABBRS + return ConstantVariable(is_cp_dtype, self.graph, DummyTracker([self])) + + def is_integer(self): + dtype = self.meta.dtype + is_int_dtype = dtype in INT_DTYPE_ABBRS + return ConstantVariable(is_int_dtype, self.graph, DummyTracker([self])) + + def is_floating_point(self): + dtype = self.meta.dtype + is_fp_dtype = dtype in FP_DTYPE_ABBRS + return ConstantVariable(is_fp_dtype, self.graph, DummyTracker([self])) + + def getattr(self, name: str, default=None): + if default is not None: + raise FallbackError( + "default argument for getattr is not implemented" + ) + method_name_to_builtin_fn = { + "dim": paddle.rank, + "numel": tensor_numel, + "ndimension": paddle.rank, + "is_tensor": paddle.is_tensor, + "is_complex": paddle.is_complex, + "is_integer": paddle.is_integer, + "is_floating_point": paddle.is_floating_point, + } + if name in ["dtype", "type", "name", "persistable", "stop_gradient"]: + if name == "name" and self.meta.name.startswith( + "infer_meta_variable_tmp" + ): + raise BreakGraphError(f"{self.meta.name} is a middle tensor.") + return VariableFactory.from_value( + getattr(self.meta, name), + self.graph, + tracker=GetAttrTracker(self, name), + ) + elif name in IMPLEMENTED_TENSOR_PROPERTIES: + return getattr(self, name) + elif name in method_name_to_builtin_fn: + # TODO: backward, gradient + from .callable import BuiltinVariable + + builtin_fn = method_name_to_builtin_fn[name] + + return BuiltinVariable( + builtin_fn, self.graph, DanglingTracker() + ).bind(self, name) + elif name in paddle_tensor_methods: + from .callable import TensorFunctionVariable + + fn_var = TensorFunctionVariable( + name, graph=self.graph, tracker=DanglingTracker() + ) + return fn_var.bind(self, name) + else: + raise HasNoAttributeError(f"Unknown Tensor attribute: {name}") + + def setattr(self, key, val): + # support tensor variable store attr, like: + # t.stop_gradient = True + self.graph.call_tensor_method( + "__setattr__", + self, + VariableFactory().from_value(key, self.graph, ConstTracker(key)), + val, + ) + + def delattr(self, key): + raise BreakGraphError("Don't support TensorVariable delattr") + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if isinstance(value, (paddle.Tensor, MetaInfo)): + return TensorVariable(value, graph, tracker) + return None + + +class ObjectVariable(VariableBase): + """ + ObjectVariable is a subclass of VariableBase used to wrap a Variable of the object type. + + Args: + obj(Any): The object to be wrapped. + graph(FunctionGraph): The FunctionGraph object that this variable is associated with. + tracker(Tracker): The Tracker object that tracks the information of this variable. + """ + + make_stringify_guard = object_equal_stringify_guard + + def __init__(self, obj, graph, tracker): + super().__init__(graph, tracker) + self.value = obj + + @property + def main_info(self) -> dict[str, Any]: + return {"value": self.value} + + def get_py_value(self, allow_tensor=False) -> Any: + return self.value + + +class SliceVariable(VariableBase): + """ + SliceVariable is a subclass of VariableBase used to wrap a Variable of the slice type. + + Args: + slice_(slice): The slice to be wrapped. 
+ graph(FunctionGraph): The FunctionGraph object that this variable is associated with. + tracker(Tracker): The Tracker object that tracks the information of this variable. + """ + + def __init__(self, slice_: slice, graph, tracker): + super().__init__(graph, tracker) + self.value = slice_ + + @property + def debug_name(self) -> str: + return ":".join( + [ + str(self.value.start) if self.value.start is not None else "", + str(self.value.stop) if self.value.stop is not None else "", + str(self.value.step) if self.value.step is not None else "", + ] + ) + + @debug_name.setter + def debug_name(self, name): + pass + + @cached_property + def attr_proxy(self): + return self.graph.side_effects.get_proxy( + MutableDictLikeData, self.value, self.attr_proxy_getter + ) + + @property + def main_info(self) -> dict[str, Any]: + return {"value": self.value} + + def get_py_value(self, allow_tensor=False): + return slice( + self.getattr("start").get_py_value(), + self.getattr("stop").get_py_value(), + self.getattr("step").get_py_value(), + ) + + @check_guard + def make_stringify_guard(self) -> list[StringifyExpression]: + frame_value_tracer = self.tracker.trace_value_from_frame() + result = ( + [ + StringifyExpression( + "isinstance({}, slice)", + [frame_value_tracer], + frame_value_tracer.free_vars, + ), + ] + + self.getattr("start").make_stringify_guard() + + self.getattr("stop").make_stringify_guard() + + self.getattr("step").make_stringify_guard() + ) + return result + + def _reconstruct(self, codegen: PyCodeGen): + if all( + isinstance(x, ConstantVariable) + for x in [ + self.getattr("start"), + self.getattr("stop"), + self.getattr("step"), + ] + ): + self.graph.add_global_guarded_variable(self) + self.getattr("start").reconstruct(codegen) + self.getattr("stop").reconstruct(codegen) + self.getattr("step").reconstruct(codegen) + codegen.gen_build_slice(3) + else: + super()._reconstruct(codegen) + + def setattr(self, key, val): + raise BreakGraphError("Don't support SliceVariable setattr") + + def delattr(self, key): + raise BreakGraphError("Don't support SliceVariable delattr") + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if isinstance(value, slice): + return SliceVariable(value, graph, tracker) + return None + + +class ModuleVariable(VariableBase): + """ + ModuleVariable is a subclass of VariableBase used to wrap a Variable of the module type. + + Args: + func: The module to be wrapped. + graph: The FunctionGraph object that this variable is associated with. + tracker: The Tracker object that tracks the information of this variable. + """ + + def __init__(self, func, graph, tracker): + super().__init__(graph, tracker) + self.value = func + + def get_py_value(self, allow_tensor=False): + return self.value + + @property + def main_info(self) -> dict[str, Any]: + return {"value": self.value} + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if isinstance(value, types.ModuleType): + return ModuleVariable(value, graph, tracker) + return None + + # Happened in a inline import statement. 
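+    # The module object itself is guarded by equality. A sketch of code that
+    # reaches this path (illustrative):
+    #
+    #     def forward(self, x):
+    #         import paddle.nn.functional as F  # inline import inside traced code
+    #         return F.relu(x)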
+ make_stringify_guard = object_equal_stringify_guard + + +class DygraphTracerVariable(VariableBase): + # TODO(SigureMo): Remove this trick after we add CompareTracker + def __init__(self, value, graph, tracker): + super().__init__(graph, tracker) + self.value = value + + def get_py_value(self, allow_tensor=False): + return self.value + + @check_guard + def make_stringify_guard(self) -> list[StringifyExpression]: + return [] + + @property + def main_info(self) -> dict[str, Any]: + return { + "is_none": self.value is None, + } + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if isinstance(value, paddle.base.dygraph.tracer.Tracer): + return DygraphTracerVariable(value, graph, tracker) + return None + + +class NumpyVariable(VariableBase): + """ + NumpyVariable is a subclass of VariableBase used to wrap a Variable of the numpy type. + + Args: + value: The numpy value to be wrapped. + graph: The FunctionGraph object that this variable is associated with. + tracker: The Tracker object that tracks the information of this variable. + """ + + def __init__(self, value, graph, tracker): + super().__init__(graph, tracker) + self.value = value + + @property + def main_info(self) -> dict[str, Any]: + return {"value": self.value} + + def get_py_value(self, allow_tensor=False) -> Any: + return self.value + + @check_guard + def make_stringify_guard(self) -> list[StringifyExpression]: + if isinstance(self.get_py_value(), np.number): + frame_value_tracer = self.tracker.trace_value_from_frame() + + def format_dtype(dtype: np.dtype): + return f"np.{str(dtype)}" + + def format_number(number: np.number): + return f"{format_dtype(number.dtype)}({str(number.item())})" + + return [ + StringifyExpression( + f"{{}} == {format_number(self.get_py_value())}", + [frame_value_tracer], + union_free_vars(frame_value_tracer.free_vars, {"np": np}), + ), + StringifyExpression( + f"{{}}.dtype == {format_dtype(self.get_py_value().dtype)}", + [frame_value_tracer], + union_free_vars(frame_value_tracer.free_vars, {"np": np}), + ), + ] + else: + return object_equal_stringify_guard(self) + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if isinstance(value, (np.ndarray, np.number)): + return NumpyVariable(value, graph, tracker) + return None + + +class NullVariable(VariableBase): + """ + NullVariable is a subclass of VariableBase used to represent a placeholder variable that has no value or reference associated with it. 
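+
+    For example, it can stand in for the NULL that CPython 3.11 bytecode
+    pushes before a call (see ``reconstruct``, which emits a push-null
+    instruction via the code generator).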
+    """
+
+    def __init__(self):
+        # TODO: graph should be not None
+        super().__init__(None, DanglingTracker())
+
+    def reconstruct(self, codegen: PyCodeGen):
+        codegen.gen_push_null()
+
+
+class CellVariable(VariableBase):
+    def __init__(self, value=None):
+        # TODO: graph should be not None
+        super().__init__(
+            None, DanglingTracker()
+        )  # should reconstruct cell variable
+        assert isinstance(value, (VariableBase, type(None)))
+        self.set_value(value)
+
+    def reconstruct(
+        self,
+        codegen: PyCodeGen,
+        *,
+        use_tracker: bool = True,
+        add_to_global_guarded_vars: bool = True,
+    ):
+        raise FallbackError("Break graph in closure is not supported.")
+
+    def cell_content(self):
+        return self.value
+
+    def set_value(self, value):
+        self.value = value
+
+    def empty(self):
+        return self.value is None
+
+
+class GlobalVariable(VariableBase):
+    def __init__(
+        self,
+        val_dict,
+        graph: FunctionGraph,
+        tracker: Tracker,
+    ):
+        super().__init__(graph, tracker)
+        self.proxy = self.graph.side_effects.get_proxy(
+            MutableDictLikeData, val_dict, self.proxy_getter
+        )
+
+    def proxy_getter(self, proxy: MutableDictLikeData, key: Any):
+        if key not in proxy.original_data:
+            return MutableDictLikeData.Empty()
+        return VariableFactory.from_value(
+            proxy.original_data[key],
+            self.graph,
+            tracker=GlobalTracker(key),
+        )
+
+    def get_value(self):
+        return dict(self.proxy.get_all().items())
+
+    def keys(self):
+        return self.proxy.get_all().keys()
+
+    def get(self, key):
+        if isinstance(key, VariableBase):
+            raise InnerError(
+                f"[{self.__class__.__name__}]: received {key} to get value."
+            )
+        return self.proxy.get(key)
+
+    def set(self, key, value):
+        if isinstance(key, VariableBase):
+            raise InnerError(
+                f"[{self.__class__.__name__}]: received {key} as key."
+            )
+        if not isinstance(value, VariableBase):
+            raise InnerError(
+                f"[{self.__class__.__name__}]: received {value} to set value."
+            )
+        self.proxy.set(key, value)
+        self.graph.side_effects.record_proxy_variable(self)
+
+    def delete(self, key):
+        self.proxy.delete(key)
+        self.graph.side_effects.record_proxy_variable(self)
+
+
+class FunctionGlobalVariable(GlobalVariable):
+    def __init__(
+        self,
+        fn: FunctionVariable,
+        val_dict: dict[str, Any],
+        graph: FunctionGraph,
+        tracker: Tracker,
+    ):
+        super().__init__(val_dict, graph, tracker)
+        self.fn = fn
+
+    def proxy_getter(self, proxy: MutableDictLikeData, key: Any):
+        from ..opcode_inline_executor import FunctionGlobalTracker
+
+        if key not in proxy.original_data:
+            return MutableDictLikeData.Empty()
+        return VariableFactory.from_value(
+            proxy.original_data[key],
+            self.graph,
+            tracker=FunctionGlobalTracker(self.fn, key),
+        )
diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py b/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py
new file mode 100644
index 0000000000000..819580710beba
--- /dev/null
+++ b/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py
@@ -0,0 +1,759 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import inspect
+import operator
+import types
+from functools import reduce
+from typing import TYPE_CHECKING, Any, Callable
+
+import paddle
+
+from .... import psdb
+from ....profiler import EventGuard
+from ....utils import (
+    is_break_graph_api,
+    is_break_graph_tensor_methods,
+    is_builtin_fn,
+    is_paddle_api,
+    magic_method_builtin_dispatch,
+)
+from ....utils.exceptions import BreakGraphError, FallbackError, SotErrorBase
+from ..dispatcher import Dispatcher
+from ..guard import (
+    StringifyExpression,
+    check_guard,
+    object_equal_stringify_guard,
+    union_free_vars,
+)
+from ..tracker import (
+    ConstTracker,
+    CreateLayerTracker,
+    DanglingTracker,
+    DummyTracker,
+    GetAttrTracker,
+    GetItemTracker,
+    GetIterTracker,
+    Tracker,
+)
+from .base import VariableBase, VariableFactory
+from .basic import ConstantVariable, PrintStmtVariable, SliceVariable
+
+if TYPE_CHECKING:
+    from ..function_graph import FunctionGraph
+
+
+PD_ALL_CONTAINERS = (paddle.nn.Sequential, paddle.nn.LayerList)
+PD_SEQ_CONTAINERS = (paddle.nn.Sequential, paddle.nn.LayerList)
+
+
+class CallableVariable(VariableBase):
+    """
+    CallableVariable is a subclass of VariableBase used to wrap a callable variable.
+
+    Args:
+        graph(FunctionGraph): The FunctionGraph object that this variable is associated with.
+        tracker(Tracker): The Tracker object that tracks the information of this variable.
+    """
+
+    def __init__(self, graph: FunctionGraph, tracker: Tracker):
+        super().__init__(graph, tracker)
+
+    def __call__(self, /, *args, **kwargs) -> VariableBase:
+        """Why do we need '/' to make `self` positional-only?
+
+        If `kwargs` contained a 'self' key, this call would otherwise raise an error.
+        See test_str_format.py for details.
+        """
+        with EventGuard(f"call_function: {self.__class__.__name__}"):
+            return self.call_function(*args, **kwargs)
+
+    def call_function(self, /, *args, **kwargs):
+        raise NotImplementedError("call_function is not implemented.")
+
+
+class FunctionVariable(CallableVariable):
+    """
+    FunctionVariable is a subclass of CallableVariable used to wrap a function variable.
+
+    Args:
+        fn (Callable[..., Any]): The function to be wrapped.
+        graph(FunctionGraph): The FunctionGraph object that this variable is associated with.
+        tracker(Tracker): The Tracker object that tracks the information of this variable.
+    """
+
+    def __init__(
+        self, fn: Callable[..., Any], graph: FunctionGraph, tracker: Tracker
+    ):
+        super().__init__(graph, tracker)
+        self.value = fn
+
+    def get_py_value(self, allow_tensor=False):
+        return self.value
+
+    def get_code(self) -> types.CodeType:
+        return self.value.__code__
+
+    def bind(self, instance: VariableBase, name: str):
+        method_var = MethodVariable(
+            instance,
+            self,
+            graph=self.graph,
+            tracker=GetAttrTracker(instance, name),
+        )
+        class_var = VariableFactory.from_value(
+            instance.get_py_type(),
+            graph=self.graph,
+            tracker=GetAttrTracker(instance, "__class__"),
+        )
+        assert class_var is not None
+        self.tracker = GetAttrTracker(class_var, name)
+        return method_var
+
+    make_stringify_guard = object_equal_stringify_guard
+
+
+class UserDefinedFunctionVariable(FunctionVariable):
+    """
+    UserDefinedFunctionVariable is a subclass of FunctionVariable used to wrap a user-defined function.
+
+    Args:
+        fn (Callable[..., Any]): The user-defined function to be wrapped.
+        graph(FunctionGraph): The FunctionGraph object that this variable is associated with.
+ tracker(Tracker): The Tracker object that tracks the information of this variable. + """ + + def __init__( + self, fn: Callable[..., Any], graph: FunctionGraph, tracker: Tracker + ): + super().__init__(fn, graph, tracker) + + def handle_psdb_function(self, /, *args, **kwargs): + # special function for inner debug. + if self.value is psdb.assert_true: + return ConstantVariable.wrap_literal( + self.value(args[0].value), self.graph + ) + elif self.value is psdb.print: + sot_prefix = ConstantVariable.wrap_literal("[SOT]", self.graph) + self.graph.add_print_variables( + PrintStmtVariable(([sot_prefix, *args], kwargs), self.graph) + ) + return ConstantVariable.wrap_literal(None, self.graph) + elif self.value is psdb.breakpoint: + # do nothing. just return None. + from ...breakpoint import BM + + BM.locate(BM.executors[-1]) + BM.add(BM.cur_exe._code.co_filename, BM.cur_exe._current_line) + return ConstantVariable.wrap_literal(None, self.graph) + elif self.value is psdb.breakgraph: + raise BreakGraphError("breakgraph by psdb.breakgraph") + elif self.value is psdb.fallback: + raise FallbackError("fallback by psdb.fallback") + elif self.value is psdb.in_sot: + return ConstantVariable.wrap_literal(True, self.graph) + return None + + def call_function(self, /, *args, **kwargs) -> VariableBase: + from ..opcode_inline_executor import OpcodeInlineExecutor + + result = self.handle_psdb_function(*args, **kwargs) + if result is not None: + return result + + checkpoint = self.graph.save_memo() + try: + inline_executor = OpcodeInlineExecutor(self, *args, **kwargs) + with EventGuard( + f"Inline Call: {inline_executor._code.co_name.replace('<', '(').replace('>', ')')}, file {inline_executor._code.co_filename}, line {int(inline_executor._code.co_firstlineno)}" + ): + output = inline_executor.inline_call() + except SotErrorBase as e: + self.graph.restore_memo(checkpoint) + raise BreakGraphError( + f"({e}) raised while inline call {self.value.__code__}." + ) + return output + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if isinstance(value, (types.FunctionType)): + return UserDefinedFunctionVariable(value, graph, tracker) + if isinstance( + value, paddle.jit.dy2static.program_translator.StaticFunction + ): + return UserDefinedFunctionVariable( + value.dygraph_function, graph, tracker + ) + return None + + @property + def main_info(self) -> dict[str, Any]: + return { + "name": self.value.__name__, + } + + +class PaddleApiVariable(FunctionVariable): + """ + PaddleApiVariable is a subclass of FunctionVariable used to wrap a paddlepaddle API function. + + Args: + fn (Callable[..., Any]): The paddlepaddle API to be wrapped. + graph(FunctionGraph): The FunctionGraph object that this variable is associated with. + tracker(Tracker): The Tracker object that tracks the information of this variable. 
+    """
+
+    def __init__(
+        self, fn: Callable[..., Any], graph: FunctionGraph, tracker: Tracker
+    ):
+        super().__init__(fn, graph, tracker)
+
+    def call_function(self, /, *args, **kwargs):
+        if is_break_graph_api(self.value):
+            raise BreakGraphError(
+                f"breakgraph by unsupported function: {self.value.__name__}"
+            )
+        return self.graph.call_paddle_api(self.value, *args, **kwargs)
+
+    @VariableFactory.register_from_value(
+        successor="UserDefinedFunctionVariable"
+    )
+    def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
+        if callable(value) and is_paddle_api(value):
+            return PaddleApiVariable(value, graph, tracker)
+        return None
+
+    @property
+    def main_info(self) -> dict[str, Any]:
+        return {
+            "name": self.value.__name__,
+        }
+
+    make_stringify_guard = object_equal_stringify_guard
+
+
+class TensorFunctionVariable(FunctionVariable):
+    """
+    TensorFunctionVariable is a subclass of FunctionVariable used to wrap a method of a tensor.
+
+    Args:
+        method_name (str): The name of the tensor method to be wrapped.
+        graph(FunctionGraph): The FunctionGraph object that this variable is associated with.
+        tracker(Tracker): The Tracker object that tracks the information of this variable.
+    """
+
+    def __init__(
+        self, method_name: str, graph: FunctionGraph, tracker: Tracker
+    ):
+        fn = getattr(paddle.static.Variable, method_name)
+        super().__init__(fn, graph, tracker)
+        self.method_name = method_name
+
+    def call_function(self, /, *args, **kwargs):
+        if is_break_graph_tensor_methods(self.method_name):
+            raise BreakGraphError(
+                f"breakgraph by tensor method: {self.method_name}"
+            )
+        return self.graph.call_tensor_method(self.method_name, *args, **kwargs)
+
+    def bind(self, instance: VariableBase, name: str):
+        method_var = MethodVariable(
+            instance,
+            self,
+            graph=self.graph,
+            tracker=GetAttrTracker(instance, name),
+        )
+        class_var = VariableFactory.from_value(
+            instance.get_py_type(),
+            graph=self.graph,
+            tracker=ConstTracker(instance.get_py_type()),
+        )
+        assert class_var is not None
+        self.tracker = GetAttrTracker(class_var, name)
+        return method_var
+
+    @property
+    def main_info(self) -> dict[str, Any]:
+        return {
+            "name": self.value.__name__,
+        }
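+
+
+# For illustration, a rough sketch of how `FunctionVariable.bind` is used.
+# `Net` and `net_var` are hypothetical names; `net_var` is assumed to wrap an
+# instance whose class defines `forward`:
+#
+#     fn_var = UserDefinedFunctionVariable(Net.forward, graph, tracker)
+#     method_var = fn_var.bind(net_var, "forward")
+#     # Calling the bound method prepends the instance, so
+#     # method_var(x_var) behaves like fn_var(net_var, x_var).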
+
+
+class MethodVariable(CallableVariable):
+    """
+    MethodVariable is a subclass of CallableVariable used to wrap a method variable.
+
+    Args:
+        bound_instance (VariableBase): The instance of the method.
+        fn (VariableBase): The method to be wrapped.
+        graph(FunctionGraph): The FunctionGraph object that this variable is associated with.
+        tracker(Tracker): The Tracker object that tracks the information of this variable.
+        method_name (str): The name of the method to be wrapped.
+    """
+
+    def __init__(
+        self,
+        bound_instance: VariableBase,
+        fn: VariableBase,
+        graph: FunctionGraph,
+        tracker: Tracker,
+        *,
+        method_name: str | None = None,
+    ):
+        super().__init__(graph, tracker)
+        self.bound_instance = bound_instance
+        self.fn = fn
+        self.method_name = method_name
+
+    def get_py_value(self, allow_tensor=False):
+        return self.fn.get_py_value().__get__(
+            self.bound_instance.get_py_value(allow_tensor),
+            self.bound_instance.get_py_value(allow_tensor).__class__,
+        )
+
+    def _reconstruct(self, pycode_gen):
+        assert self.method_name is not None
+        self.bound_instance.reconstruct(pycode_gen)
+        pycode_gen.gen_load_attr(self.method_name)
+
+    def call_function(self, /, *args, **kwargs):
+        return self.fn(*(self.bound_instance, *args), **kwargs)
+
+    @staticmethod
+    def wrap_method(
+        value: types.MethodType,
+        *,
+        graph: FunctionGraph,
+        tracker: Tracker,
+        instance: VariableBase | None = None,
+        fn: VariableBase | None = None,
+        method_name: str | None = None,
+    ):
+        # NOTE(SigureMo): Since method_self needs method_var as the obj of the
+        # tracker, we need to temporarily set the tracker of method_self to
+        # DanglingTracker, and set it to GetAttrTracker after method_var is created.
+        instance_var = (
+            VariableFactory.from_value(value.__self__, graph, DanglingTracker())
+            if instance is None
+            else instance
+        )
+
+        fn_var = (
+            VariableFactory.from_value(value.__func__, graph, DanglingTracker())
+            if fn is None
+            else fn
+        )
+
+        method_var = MethodVariable(
+            instance_var,
+            fn_var,
+            method_name=method_name,
+            graph=graph,
+            tracker=tracker,
+        )
+        if instance is None:
+            instance_var.tracker = GetAttrTracker(method_var, "__self__")
+        if fn is None:
+            fn_var.tracker = GetAttrTracker(method_var, "__func__")
+        return method_var
+
+    @VariableFactory.register_from_value()
+    def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
+        if inspect.ismethod(value):
+            return MethodVariable.wrap_method(
+                value=value, tracker=tracker, graph=graph
+            )
+        return None
+
+    @property
+    def main_info(self) -> dict[str, Any]:
+        return {
+            "method": self.method_name,
+        }
+
+
+class LayerVariable(CallableVariable):
+    """
+    LayerVariable is a subclass of CallableVariable used to wrap a layer.
+
+    Args:
+        layer (paddle.nn.Layer): The layer to be wrapped.
+        graph(FunctionGraph): The FunctionGraph object that this variable is associated with.
+        tracker(Tracker): The Tracker object that tracks the information of this variable.
+ """ + + def __init__( + self, layer: paddle.nn.Layer, graph: FunctionGraph, tracker: Tracker + ): + super().__init__(graph, tracker) + self.value = layer + + def get_py_value(self, allow_tensor=False): + return self.value + + def call_function(self, /, *args, **kwargs): + fn_var = UserDefinedFunctionVariable( + self.value.__class__.__call__, + self.graph, + GetAttrTracker(self, "__call__"), + ) + + return fn_var(*(self, *args), **kwargs) + + @check_guard + def make_stringify_guard(self) -> list[StringifyExpression]: + frame_value_tracer = self.tracker.trace_value_from_frame() + return [ + StringifyExpression( + f"id({{}}) == {id(self.get_py_value())}", + [frame_value_tracer], + union_free_vars(frame_value_tracer.free_vars), + ), + StringifyExpression( + f"{{}}.training == {self.get_py_value().training}", + [frame_value_tracer], + union_free_vars(frame_value_tracer.free_vars), + ), + ] + + +class ContainerLayerVariable(LayerVariable): + def __init__( + self, layer: paddle.nn.Layer, graph: FunctionGraph, tracker: Tracker + ): + super().__init__(layer, graph, tracker) + + def __len__(self): + return len(self.value) + + def len(self): + return ConstantVariable(len(self), self.graph, DummyTracker([self])) + + def getitem(self, key): + if isinstance(self.value, PD_SEQ_CONTAINERS) and isinstance( + key, SliceVariable + ): + try: + slice_py_value = key.get_py_value() + new_layer_list = self.value[slice_py_value] + self.graph.add_global_guarded_variable(key) + return VariableFactory.from_value( + new_layer_list, + self.graph, + GetItemTracker(self, slice_py_value), + ) + except Exception as e: + raise BreakGraphError( + f"call {self.value.__class__.__name__}.__getitem__ with slice as key, and slice with py value failed: {e}." + ) + + else: + return super().getitem(key) + + def get_iter(self): + if isinstance(self.value, PD_SEQ_CONTAINERS): + from .iter import SequenceIterVariable + + return SequenceIterVariable(self, self.graph, GetIterTracker(self)) + else: + return super().get_iter() + + def make_stringify_guard(self) -> list[StringifyExpression]: + if isinstance(self.value, PD_SEQ_CONTAINERS): + frame_value_tracer = self.tracker.trace_value_from_frame() + + len_guard = StringifyExpression( + f"len({{}}) == {len(self.value)}", + [frame_value_tracer], + frame_value_tracer.free_vars, + ) + + guards = [len_guard] + for idx, layer in enumerate(self.value): + layer_variable = VariableFactory.from_value( + layer, self.graph, GetItemTracker(self, idx) + ) + guards.extend(layer_variable.make_stringify_guard()) + + return guards + else: + return super().make_stringify_guard() + + @property + def main_info(self) -> dict[str, Any]: + return { + "name": self.value.__class__.__name__, + } + + @VariableFactory.register_from_value(successor="PaddleLayerVariable") + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if isinstance(value, PD_ALL_CONTAINERS): + return ContainerLayerVariable(value, graph, tracker) + return None + + +class PaddleLayerVariable(LayerVariable): + """ + PaddleLayerVariable is a subclass of LayerVariable used to wrap a paddlepaddle layer. + + Args: + layer (paddle.nn.Layer): The paddle built-in layer to be wrapped. + graph(FunctionGraph): The FunctionGraph object that this variable is associated with. + tracker(Tracker): The Tracker object that tracks the information of this variable. 
+ """ + + def __init__( + self, layer: paddle.nn.Layer, graph: FunctionGraph, tracker: Tracker + ): + super().__init__(layer, graph, tracker) + + def call_function(self, /, *args, **kwargs): + self.graph.add_global_guarded_variable(self) + return self.graph.call_layer(self, *args, **kwargs) + + def make_stringify_guard(self) -> list[StringifyExpression]: + if isinstance(self.tracker, CreateLayerTracker): + return reduce( + operator.add, + [var.make_stringify_guard() for var in self.tracker.inputs], + ) + else: + return super().make_stringify_guard() + + @property + def main_info(self) -> dict[str, Any]: + return { + "name": self.value.__class__.__name__, + } + + @VariableFactory.register_from_value(successor="UserDefinedLayerVariable") + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + # TODO(SigureMo): Add a more common way to check if a value is a paddle builtin layer. + if isinstance(value, paddle.nn.Layer): + # If there is a user-defined behavior, such as a container class layer + # or a hook on the layer, it needs to be converted to UserDefinedLayerVariable, + # otherwise converted to PaddleLayerVariable + if ( + hasattr(value, "_forward_pre_hooks") + and value._forward_pre_hooks + or hasattr(value, "_forward_post_hooks") + and value._forward_post_hooks + ): + return None + if value.__module__.startswith("paddle.nn."): + return PaddleLayerVariable(value, graph, tracker) + return None + + +class UserDefinedLayerVariable(LayerVariable): + """ + UserDefinedLayerVariable is a subclass of LayerVariable used to wrap a user-defined layer. + + Args: + layer (paddle.nn.Layer): The user-defined layer to be wrapped. + graph(FunctionGraph): The FunctionGraph object that this variable is associated with. + tracker(Tracker): The Tracker object that tracks the information of this variable. + """ + + def __init__( + self, layer: paddle.nn.Layer, graph: FunctionGraph, tracker: Tracker + ): + super().__init__(layer, graph, tracker) + + @property + def main_info(self) -> dict[str, Any]: + return { + "name": self.value.__class__.__name__, + } + + @VariableFactory.register_from_value(successor="PaddleApiVariable") + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if isinstance(value, paddle.nn.Layer): + return UserDefinedLayerVariable(value, graph, tracker) + return None + + +class BuiltinVariable(FunctionVariable): + """ + BuiltinVariable is a subclass of FunctionVariable used to wrap a built-in function. + Args: + fn (Callable[..., Any]): The built-in function to be wrapped. + graph(FunctionGraph): The FunctionGraph object that this variable is associated with. + tracker(Tracker): The Tracker object that tracks the information of this variable. 
+    """
+
+    def __init__(
+        self, fn: Callable[..., Any], graph: FunctionGraph, tracker: Tracker
+    ):
+        super().__init__(fn, graph, tracker)
+        self.value = fn
+
+    def call_function(self, /, *args, **kwargs):
+        # Look up the handler from the dispatcher
+        handler = Dispatcher.dispatch(self.value, *args, **kwargs)
+        if handler is not None:
+            return handler(*args, **kwargs)
+
+        # Try to inline call the magic method
+        magic_methods = magic_method_builtin_dispatch(self.value)
+        for magic_method in magic_methods:
+            sorted_args = args
+            if magic_method.is_reverse:
+                sorted_args = sorted_args[::-1]
+            arg_type = sorted_args[0].get_py_type()
+            if hasattr(arg_type, magic_method.name):
+                class_fn = getattr(arg_type, magic_method.name)
+                class_var = VariableFactory.from_value(
+                    arg_type,
+                    self.graph,
+                    GetAttrTracker(args[0], "__class__"),
+                )
+                assert isinstance(class_var, VariableBase)
+                fn_var = VariableFactory.from_value(
+                    class_fn,
+                    self.graph,
+                    GetAttrTracker(class_var, class_fn.__name__),
+                )
+                assert isinstance(fn_var, VariableBase)
+                return fn_var(*args)
+
+        # Break graph if neither of the above conditions is met
+        arg_types = ", ".join([type(arg).__name__ for arg in args])
+        fn_name = (
+            self.value.__name__
+            if hasattr(self.value, '__name__')
+            else self.value
+        )
+        raise BreakGraphError(
+            f"Unsupported builtin function: {fn_name} with args: ({arg_types})"
+        )
+
+    @VariableFactory.register_from_value(successor="ClassVariable")
+    def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
+        if is_builtin_fn(value):
+            return BuiltinVariable(value, graph, tracker)
+        return None
+
+    @property
+    def main_info(self) -> dict[str, Any]:
+        return {
+            "name": self.value.__name__,
+        }
+
+
+class UserDefinedGeneratorVariable(FunctionVariable):
+    """
+    UserDefinedGeneratorVariable is a subclass of FunctionVariable used to wrap a user-defined generator.
+
+    Args:
+        fn (Callable[..., Any]): The user-defined generator to be wrapped.
+        graph(FunctionGraph): The FunctionGraph object that this variable is associated with.
+        tracker(Tracker): The Tracker object that tracks the information of this variable.
+ """ + + def __init__( + self, fn: Callable[..., Any], graph: FunctionGraph, tracker: Tracker + ): + super().__init__(fn, graph, tracker) + + def call_function(self, /, *args, **kwargs): + iter_ = self.value(*args, **kwargs) + var = VariableFactory.from_value( + iter_, self.graph, DummyTracker([self]) + ) + return var + + @property + def main_info(self) -> dict[str, Any]: + return {"name": self.value.__name__} + + @VariableFactory.register_from_value( + successor="UserDefinedFunctionVariable" + ) + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if inspect.isgeneratorfunction(value): + return UserDefinedGeneratorVariable(value, graph, tracker) + return None + + +class ClassVariable(CallableVariable): + def __init__(self, class_: type, graph: FunctionGraph, tracker: Tracker): + super().__init__(graph, tracker) + self.value = class_ + + def get_py_value(self, allow_tensor=False): + return self.value + + def call_function(self, /, *args, **kwargs): + new_object = self.value.__new__(self.value) + + # do not have init function + if self.value.__init__ is object.__init__: + return VariableFactory.from_value( + new_object, self.graph, DummyTracker([self]) + ) + + if not hasattr(self.value.__init__, "__code__"): + fn_var = BuiltinVariable( + self.value.__init__, + self.graph, + GetAttrTracker(self, "__init__"), + ) + else: + fn_var = UserDefinedFunctionVariable( + self.value.__init__, + self.graph, + GetAttrTracker(self, "__init__"), + ) + + # need classify variable type here? + new_object_variable = VariableFactory.from_value( + new_object, + self.graph, + DummyTracker([self] + list(args) + list(kwargs.values())), + ) + fn_var(new_object_variable, *args, **kwargs) + return new_object_variable + + make_stringify_guard = object_equal_stringify_guard + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if inspect.isclass(value): + return ClassVariable(value, graph, tracker) + return None + + +class PaddleLayerClassVariable(ClassVariable): + def __init__(self, class_: type, graph: FunctionGraph, tracker: Tracker): + super().__init__(class_, graph, tracker) + + def call_function(self, /, *args, **kwargs): + input_py_args = [var.get_py_value() for var in args] + input_py_kwargs = {k: v.get_py_value() for k, v in kwargs.items()} + new_layer = self.value(*input_py_args, **input_py_kwargs) + return PaddleLayerVariable( + new_layer, self.graph, CreateLayerTracker(self, args, kwargs) + ) + + @VariableFactory.register_from_value(successor="ClassVariable") + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if ( + inspect.isclass(value) + and issubclass(value, paddle.nn.Layer) + and value.__module__.startswith("paddle.nn.") + ): + return PaddleLayerClassVariable(value, graph, tracker) + return None diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/container.py b/python/paddle/jit/sot/opcode_translator/executor/variables/container.py new file mode 100644 index 0000000000000..b1c318e9187bd --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/variables/container.py @@ -0,0 +1,1011 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import operator
+from collections import OrderedDict
+from functools import reduce
+from typing import TYPE_CHECKING, Any
+
+from ....utils.exceptions import FallbackError, InnerError
+from ..dispatcher import Dispatcher
+from ..guard import StringifyExpression, check_guard
+from ..mutable_data import MutableDictLikeData, MutableListLikeData
+from ..pycode_generator import PyCodeGen
+from ..tracker import (
+    ConstTracker,
+    DanglingTracker,
+    DummyTracker,
+    GetItemTracker,
+    GetIterTracker,
+    Tracker,
+)
+from .base import ConstTypes, VariableBase, VariableFactory
+from .basic import ConstantVariable
+from .callable import BuiltinVariable, UserDefinedFunctionVariable
+
+if TYPE_CHECKING:
+    from ..function_graph import FunctionGraph
+
+
+class ContainerVariable(VariableBase):
+    """
+    ContainerVariable is a wrapper for container types, such as range, list, tuple, dict.
+    """
+
+    @property
+    def init_value(self):
+        return self.value
+
+    def get_items(self) -> list[VariableBase]:
+        raise FallbackError('ContainerVariable.get_items is not implemented')
+
+    def get_wrapped_items(self):
+        raise FallbackError(
+            "ContainerVariable.get_wrapped_items is not implemented"
+        )
+
+    def __len__(self):
+        raise FallbackError('ContainerVariable.__len__ is not implemented')
+
+    def len(self):
+        return ConstantVariable(len(self), self.graph, DummyTracker([self]))
+
+    def __bool__(self) -> bool:
+        return len(self) > 0
+
+    def bool(self):
+        return ConstantVariable(bool(self), self.graph, DummyTracker([self]))
+
+    @check_guard
+    def make_stringify_guard(self) -> list[StringifyExpression]:
+        frame_value_tracer = self.tracker.trace_value_from_frame()
+
+        type_guard = StringifyExpression(
+            f"isinstance({{}}, {self.get_py_type().__name__})",
+            [frame_value_tracer],
+            frame_value_tracer.free_vars,
+        )
+        len_guard = StringifyExpression(
+            f"len({{}}) == {len(self.init_value)}",
+            [frame_value_tracer],
+            frame_value_tracer.free_vars,
+        )
+        if isinstance(self, (ListVariable, TupleVariable)):
+            guard_variables = self.proxy.reproduce(0)
+
+        elif isinstance(self, DictVariable):
+            guard_variables = filter(
+                lambda var: not isinstance(var, MutableDictLikeData.Empty),
+                self.proxy.reproduce(0).values(),
+            )
+        else:
+            raise InnerError(f"Unsupported container type: {type(self)}")
+        return reduce(
+            operator.add,
+            [[type_guard, len_guard]]
+            + [item.make_stringify_guard() for item in guard_variables],
+        )
+
+
+class ListVariable(ContainerVariable):
+    """
+    ListVariable is a wrapper for list and contains common APIs for list methods.
+
+    Args:
+        val_list(List[VariableBase]): the list to wrap
+        graph(FunctionGraph): The FunctionGraph object that this variable is associated with.
+        tracker(Tracker): The Tracker object that tracks the information of this variable.
+ """ + + def __init__( + self, + val_list: list[VariableBase], + graph: FunctionGraph, + tracker: Tracker, + ): + super().__init__(graph, tracker) + + # everything in stack is VariableBase, so just accept the input list is ok + self.proxy = self.graph.side_effects.get_proxy( + MutableListLikeData, val_list, self.proxy_getter + ) + self.value = val_list + + def proxy_getter(self, proxy: MutableListLikeData, key: Any): + if key < 0 or key >= len(proxy.original_data): + return MutableListLikeData.Empty() + return VariableFactory.from_value( + proxy.original_data[key], + self.graph, + tracker=GetItemTracker(self, key, changed=proxy.has_changed), + ) + + def get_py_value(self, allow_tensor=False): + items = self.proxy.get_all() + return [item.get_py_value(allow_tensor) for item in items] + + def get_py_type(self): + return list + + def _reconstruct(self, codegen: PyCodeGen): + size = len(self) + for idx in range(size): + Dispatcher.call(operator.getitem, self, idx).reconstruct(codegen) + codegen.gen_build_list(size) + + def get_items(self): + size = len(self) + return [ + Dispatcher.call(operator.getitem, self, idx) for idx in range(size) + ] + + def get_wrapped_items(self): + return self.get_items() + + def get_iter(self): + from .iter import SequenceIterVariable + + return SequenceIterVariable(self, self.graph, GetIterTracker(self)) + + @property + def main_info(self) -> dict[str, Any]: + return { + "len": len(self), + } + + def __len__(self): + return self.proxy.length + + def getitem(self, key): + self.graph.add_global_guarded_variable(key) + key = key.get_py_value() + if isinstance(key, int): + res = self.proxy.get(key) + if self.proxy.is_empty(res): + raise InnerError(f"List {self} out of range (index={key})") + return res + elif isinstance(key, slice): + items = self.proxy.get_all() + return VariableFactory.from_value( + items[key], + self.graph, + tracker=GetItemTracker( + self, key, changed=self.proxy.has_changed + ), + ) + else: + raise InnerError( + f"Unsupported key type {key.__class__.__name__} for ListVariable" + ) + + def setitem(self, key, value): + if not isinstance(value, VariableBase): + raise InnerError( + f"[{self.__class__.__name__}]: received {value} to set value." + ) + if isinstance(key, int): + self.proxy.set(key, value) + elif isinstance(key, slice) and isinstance( + value, (ListVariable, TupleVariable) + ): + start, end, step = key.indices(self.proxy.length) + indices = list(range(start, end, step)) + if step == 1: + # replace a continuous range + for i, idx in enumerate(indices): + self.proxy.delete(idx - i) + for i, item in enumerate(value.get_wrapped_items()): + self.proxy.insert(start + i, item) + else: + # replace some elements + if len(indices) != len(value): + raise InnerError( + f"Attempt to replace {len(indices)} items with {len(value)}" + ) + for i, idx in enumerate(indices): + self.proxy.set(idx, value[i]) + else: + raise InnerError( + f"Unsupported key type {key.__class__.__name__} and value type {value.__class__.__name__} for ListVariable" + ) + + self.graph.side_effects.record_proxy_variable(self) + return ConstantVariable.wrap_literal(None, self.graph) + + def __delitem__(self, key): + return self.delitem(key) + + def delitem(self, key): + if isinstance(key, VariableBase): + raise InnerError( + f"[{self.__class__.__name__}]: received {key} as key to delete." 
+            )
+        self.proxy.delete(key)
+        self.graph.side_effects.record_proxy_variable(self)
+        return ConstantVariable.wrap_literal(None, self.graph)
+
+    def insert(self, index: int, value: VariableBase):
+        self.proxy.insert(index, value)
+        self.graph.side_effects.record_proxy_variable(self)
+        return ConstantVariable.wrap_literal(None, self.graph)
+
+    def append(self, value: VariableBase):
+        self.insert(self.proxy.length, value)
+        self.graph.side_effects.record_proxy_variable(self)
+        return ConstantVariable.wrap_literal(None, self.graph)
+
+    def extend(self, data):
+        for item in data.proxy.get_all():
+            self.append(item)
+        self.graph.side_effects.record_proxy_variable(self)
+        return ConstantVariable.wrap_literal(None, self.graph)
+
+    def concat(self, list_):
+        assert isinstance(list_, ListVariable)
+        return ListVariable(
+            self.proxy.get_all() + list_.proxy.get_all(),
+            self.graph,
+            DummyTracker([self, list_]),
+        )
+
+    def repeat(self, length):
+        assert isinstance(length, ConstantVariable)
+        return ListVariable(
+            self.proxy.get_all() * length.value,
+            self.graph,
+            DummyTracker([self, length]),
+        )
+
+    def pop(self, index: ConstantVariable | None = None):
+        if index is None:
+            index = ConstantVariable.wrap_literal(-1, self.graph)
+        res = self.proxy.get(index.get_py_value())
+        self.proxy.delete(index.get_py_value())
+        self.graph.side_effects.record_proxy_variable(self)
+        return res
+
+    def copy(self):
+        return ListVariable(
+            self.proxy.get_all(),
+            self.graph,
+            DummyTracker([self]),
+        )
+
+    def clear(self):
+        for idx in range(self.proxy.length):
+            self.delitem(0)
+        self.graph.side_effects.record_proxy_variable(self)
+        return ConstantVariable.wrap_literal(None, self.graph)
+
+    def remove(self, value):
+        for idx in range(self.proxy.length):
+            if self[idx].get_py_value(allow_tensor=True) == value.get_py_value(
+                allow_tensor=True
+            ):
+                self.delitem(idx)
+                break
+        else:
+            raise InnerError(f"List {self} does not contain {value}")
+        self.graph.side_effects.record_proxy_variable(self)
+        return ConstantVariable.wrap_literal(None, self.graph)
+
+    def sort(self, key=None, reverse=None):
+        if (
+            key is None
+            or isinstance(key, ConstantVariable)
+            and key.get_py_value() is None
+        ):
+            key = UserDefinedFunctionVariable(
+                lambda x: x, self.graph, DanglingTracker()
+            )
+        assert key is not None
+        if reverse is None:
+            reverse = ConstantVariable.wrap_literal(False, self.graph)
+
+        permutation = list(range(self.proxy.length))
+        permutation.sort(
+            key=lambda x: key.get_py_value()(
+                Dispatcher.call(operator.getitem, self, x).value
+            ),
+            reverse=reverse.get_py_value(),
+        )
+        self.proxy.permutate(permutation)
+        self.graph.side_effects.record_proxy_variable(self)
+        return ConstantVariable.wrap_literal(None, self.graph)
+
+    def reverse(self):
+        permutation = list(range(self.proxy.length))
+        permutation.reverse()
+        self.proxy.permutate(permutation)
+        self.graph.side_effects.record_proxy_variable(self)
+        return ConstantVariable.wrap_literal(None, self.graph)
+
+    def count(self, value: VariableBase):
+        count: int = 0
+        getitem = BuiltinVariable(
+            operator.getitem, self.graph, DanglingTracker()
+        )
+        for index in range(len(self)):
+            index_value = getitem(self, index)
+            if index_value.id == value.id:
+                count += 1
+                continue
+            eq = BuiltinVariable(operator.eq, self.graph, DanglingTracker())(
+                index_value, value
+            )
+            eq_bool = BuiltinVariable(bool, self.graph, DanglingTracker())(eq)
+            assert isinstance(
+                eq_bool, ConstantVariable
+            ), "bool should return ConstantVariable"
+            if eq_bool.get_py_value() is True:
+                count += 1
+                continue
+
+        return ConstantVariable(count, self.graph, DummyTracker([self, value]))
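+
+    # A rough sketch of how the mutating methods above are simulated
+    # (hypothetical variables; `a_var` and `b_var` wrap traced values):
+    #
+    #     list_var.append(a_var)      # recorded on the MutableListLikeData proxy
+    #     list_var.setitem(0, b_var)  # likewise; the original list is untouched
+    #     # Roughly speaking, the recorded side effects are replayed on the
+    #     # real list after the graph is built.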
+
+    def index(self, value: VariableBase):
+        res = 0
+        getitem = BuiltinVariable(
+            operator.getitem, self.graph, DanglingTracker()
+        )
+        for index in range(len(self)):
+            index_value = getitem(self, index)
+            if index_value.id == value.id:
+                return ConstantVariable(
+                    res, self.graph, DummyTracker([self, value])
+                )
+            eq = BuiltinVariable(operator.eq, self.graph, DanglingTracker())(
+                index_value, value
+            )
+            eq_bool = BuiltinVariable(bool, self.graph, DanglingTracker())(eq)
+            assert isinstance(
+                eq_bool, ConstantVariable
+            ), "bool should return ConstantVariable"
+            if eq_bool.get_py_value() is True:
+                return ConstantVariable(
+                    res, self.graph, DummyTracker([self, value])
+                )
+            res += 1
+
+        return ConstantVariable(-1, self.graph, DummyTracker([self, value]))
+
+    def max(self):
+        if len(self) == 0:
+            raise ValueError("max() arg is an empty sequence")
+        res = self[0]
+        getitem = BuiltinVariable(
+            operator.getitem, self.graph, DanglingTracker()
+        )
+        for index in range(len(self)):
+            index_value = getitem(self, index)
+            gt = BuiltinVariable(operator.gt, self.graph, DanglingTracker())(
+                index_value, res
+            )
+            if gt.get_py_value() is True:
+                res = index_value
+        return res
+
+    def min(self):
+        if len(self) == 0:
+            raise ValueError("min() arg is an empty sequence")
+        res = self[0]
+        getitem = BuiltinVariable(
+            operator.getitem, self.graph, DanglingTracker()
+        )
+        for index in range(len(self)):
+            index_value = getitem(self, index)
+            lt = BuiltinVariable(operator.lt, self.graph, DanglingTracker())(
+                index_value, res
+            )
+            if lt.get_py_value() is True:
+                res = index_value
+        return res
+
+    def getattr(self, name: str, default=None):
+        from .callable import BuiltinVariable
+
+        if default is not None:
+            raise FallbackError(
+                "default argument for getattr is not implemented"
+            )
+
+        method_name_to_builtin_fn = {
+            "insert": list.insert,
+            "append": list.append,
+            "extend": list.extend,
+            "pop": list.pop,
+            "copy": list.copy,
+            "clear": list.clear,
+            "remove": list.remove,
+            "sort": list.sort,
+            "reverse": list.reverse,
+            "count": list.count,
+            "index": list.index,
+        }
+
+        if name in method_name_to_builtin_fn:
+            builtin_fn = method_name_to_builtin_fn[name]
+            return BuiltinVariable(
+                builtin_fn, self.graph, DanglingTracker()
+            ).bind(self, name)
+        else:
+            raise FallbackError(f"attribute {name} for list is not implemented")
+
+    @VariableFactory.register_from_value()
+    def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
+        # Note(SigureMo): Why not use isinstance?
+        # Because a user may define a class that inherits from list.
+        # We should convert it to ObjectVariable instead of ListVariable.
+        if type(value) is list:  # noqa: E721
+            return ListVariable(value, graph=graph, tracker=tracker)
+        return None
+
+
+class TupleVariable(ContainerVariable):
+    """
+    TupleVariable is a wrapper for tuple and contains common APIs for tuple methods.
+
+    Args:
+        val_tuple(tuple[VariableBase, ...]): the tuple to wrap
+        graph(FunctionGraph): The FunctionGraph object that this variable is associated with.
+        tracker(Tracker): The Tracker object that tracks the information of this variable.
+    """
+
+    def __init__(
+        self,
+        val_tuple: tuple[VariableBase, ...],
+        graph: FunctionGraph,
+        tracker: Tracker,
+    ):
+        super().__init__(graph, tracker)
+
+        self.proxy = self.graph.side_effects.get_proxy(
+            MutableListLikeData, list(val_tuple), self.proxy_getter
+        )
+        self.value = val_tuple
+
+    def getattr(self, name: str, default=None):
+        from .callable import BuiltinVariable
+
+        if default is not None:
+            raise FallbackError(
+                "default argument for getattr is not implemented"
+            )
+
+        method_name_to_builtin_fn = {
+            "count": tuple.count,
+            "index": tuple.index,
+        }
+        if name in method_name_to_builtin_fn:
+            builtin_fn = method_name_to_builtin_fn[name]
+            return BuiltinVariable(
+                builtin_fn, self.graph, DanglingTracker()
+            ).bind(self, name)
+        else:
+            raise FallbackError(
+                f"attribute {name} for tuple is not implemented"
+            )
+
+    def proxy_getter(self, proxy: MutableListLikeData, key: Any):
+        if key < 0 or key >= len(proxy.original_data):
+            return MutableListLikeData.Empty()
+        return VariableFactory.from_value(
+            proxy.original_data[key],
+            self.graph,
+            tracker=GetItemTracker(self, key, changed=False),
+        )
+
+    def get_py_value(self, allow_tensor=False):
+        return tuple(
+            self[idx].get_py_value(allow_tensor) for idx in range(len(self))
+        )
+
+    def get_py_type(self):
+        return tuple
+
+    def _reconstruct(self, codegen: PyCodeGen):
+        size = len(self)
+        for idx in range(size):
+            Dispatcher.call(operator.getitem, self, idx).reconstruct(codegen)
+        codegen.gen_build_tuple(size)
+
+    def get_items(self):
+        size = len(self)
+        return [
+            Dispatcher.call(operator.getitem, self, idx) for idx in range(size)
+        ]
+
+    def get_wrapped_items(self):
+        return tuple(self.get_items())
+
+    def get_iter(self):
+        from .iter import SequenceIterVariable
+
+        return SequenceIterVariable(self, self.graph, GetIterTracker(self))
+
+    @property
+    def main_info(self) -> dict[str, Any]:
+        return {
+            "len": len(self),
+        }
+
+    def __len__(self):
+        return self.proxy.length
+
+    def getitem(self, key):
+        self.graph.add_global_guarded_variable(key)
+        key = key.get_py_value()
+        if isinstance(key, int):
+            res = self.proxy.get(key)
+            if self.proxy.is_empty(res):
+                raise InnerError(f"Tuple {self} out of range (index={key})")
+            return res
+        elif isinstance(key, slice):
+            return TupleVariable(
+                tuple(self.proxy.get_all())[key],
+                self.graph,
+                tracker=GetItemTracker(self, key, changed=False),
+            )
+        else:
+            raise InnerError(
+                f"Unsupported key type {key.__class__.__name__} for TupleVariable"
+            )
+
+    def setitem(self, key, value):
+        raise InnerError(
+            f"[{self.__class__.__name__}]: setitem is not allowed."
+        )
+
+    def __delitem__(self, key):
+        return self.delitem(key)
+
+    def delitem(self, key):
+        raise InnerError(
+            f"[{self.__class__.__name__}]: delitem is not allowed."
+        )
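+
+    # For illustration (hypothetical variables): since tuples are immutable,
+    # only read operations are simulated, e.g.
+    #
+    #     tuple_var.getitem(idx_var)    # returns the wrapped element
+    #     tuple_var.setitem(0, x_var)   # raises InnerError
+    #     tuple_var.delitem(0)          # raises InnerError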
+
+    def concat(self, tuple_):
+        assert isinstance(tuple_, TupleVariable)
+        new_tuple_variable = TupleVariable(
+            tuple(self.proxy.get_all() + tuple_.proxy.get_all()),
+            self.graph,
+            DummyTracker([self, tuple_]),
+        )
+        return new_tuple_variable
+
+    def repeat(self, length):
+        assert isinstance(length, ConstantVariable)
+        new_tuple_variable = TupleVariable(
+            tuple(self.proxy.get_all()) * length.value,
+            self.graph,
+            DummyTracker([self, length]),
+        )
+        return new_tuple_variable
+
+    def count(self, value: VariableBase):
+        count: int = 0
+        getitem = BuiltinVariable(
+            operator.getitem, self.graph, DanglingTracker()
+        )
+        for index in range(len(self)):
+            index_value = getitem(self, index)
+            if index_value.id == value.id:
+                count += 1
+                continue
+            eq = BuiltinVariable(operator.eq, self.graph, DanglingTracker())(
+                index_value, value
+            )
+            eq_bool = BuiltinVariable(bool, self.graph, DanglingTracker())(eq)
+            assert isinstance(
+                eq_bool, ConstantVariable
+            ), "bool should return ConstantVariable"
+            if eq_bool.get_py_value() is True:
+                count += 1
+                continue
+
+        return ConstantVariable(count, self.graph, DummyTracker([self, value]))
+
+    def index(self, value: VariableBase):
+        res = 0
+        getitem = BuiltinVariable(
+            operator.getitem, self.graph, DanglingTracker()
+        )
+        for index in range(len(self)):
+            index_value = getitem(self, index)
+            if index_value.id == value.id:
+                return ConstantVariable(
+                    res, self.graph, DummyTracker([self, value])
+                )
+            eq = BuiltinVariable(operator.eq, self.graph, DanglingTracker())(
+                index_value, value
+            )
+            eq_bool = BuiltinVariable(bool, self.graph, DanglingTracker())(eq)
+            assert isinstance(
+                eq_bool, ConstantVariable
+            ), "bool should return ConstantVariable"
+            if eq_bool.get_py_value() is True:
+                return ConstantVariable(
+                    res, self.graph, DummyTracker([self, value])
+                )
+            res += 1
+
+        return ConstantVariable(-1, self.graph, DummyTracker([self, value]))
+
+    @VariableFactory.register_from_value()
+    def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
+        if type(value) is tuple:
+            return TupleVariable(value, graph, tracker)
+        return None
+
+
+class RangeVariable(ContainerVariable):
+    """
+    RangeVariable is a wrapper for range.
+
+    Args:
+        val_range(range): the range to wrap
+        graph(FunctionGraph): The FunctionGraph object that this variable is associated with.
+        tracker(Tracker): The Tracker object that tracks the information of this variable.
+ """ + + def __init__( + self, + val_range: range, + graph: FunctionGraph, + tracker: Tracker, + ): + super().__init__(graph, tracker) + self.value = val_range + + def get_py_type(self): + return range + + def get_py_value(self, allow_tensor=False): + return self.value + + def getitem(self, key): + self.graph.add_global_guarded_variable(self) + self.graph.add_global_guarded_variable(key) + key = key.get_py_value() + retval = self.value[key] + return ConstantVariable.wrap_literal(retval, self.graph) + + def get_items(self): + size = len(self) + return [self[idx] for idx in range(size)] + + def get_wrapped_items(self): + return self.get_items() + + def get_iter(self): + from .iter import SequenceIterVariable + + return SequenceIterVariable(self, self.graph, GetIterTracker(self)) + + def __len__(self): + return len(self.value) + + def _reconstruct(self, codegen: PyCodeGen): + codegen.gen_load_global("range", push_null=True) + # The start default value is 0, step is 1 + # So we can always construct range with 3 args + codegen.gen_load_const(self.value.start) + codegen.gen_load_const(self.value.stop) + codegen.gen_load_const(self.value.step) + codegen.gen_call_function(3) + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if type(value) is range: + return RangeVariable(value, graph, tracker) + return None + + @check_guard + def make_stringify_guard(self) -> list[StringifyExpression]: + frame_value_tracer = self.tracker.trace_value_from_frame() + + return [ + StringifyExpression( + "isinstance({0}, range) and " + + f"{{0}}.start == {self.init_value.start} and " + + f"{{0}}.stop == {self.init_value.stop} and " + + f"{{0}}.step == {self.init_value.step}", + [frame_value_tracer], + frame_value_tracer.free_vars, + ) + ] + + @property + def debug_name(self) -> str: + return ":".join( + [ + str(self.value.start) if self.value.start is not None else "", + str(self.value.stop) if self.value.stop is not None else "", + str(self.value.step) if self.value.step is not None else "", + ] + ) + + @debug_name.setter + def debug_name(self, name): + pass + + @property + def main_info(self) -> dict[str, Any]: + return {"value": self.value} + + +class DictVariable(ContainerVariable): + """ + DictVariable is a wrapper for dict and contains common APIs for dict methods + + Args: + val_dict(dict[object, VariableBase]): the dict to wrap + graph(FunctionGraph): The FunctionGraph object that this variable is associated with. + tracker(Tracker): The Tracker object that tracks the information of this variable. 
+    """
+
+    def __init__(
+        self,
+        val_dict: dict[object, VariableBase],
+        graph: FunctionGraph,
+        tracker: Tracker,
+    ):
+        super().__init__(graph, tracker)
+
+        self.proxy = self.graph.side_effects.get_proxy(
+            MutableDictLikeData, val_dict, self.proxy_getter
+        )
+        self.value = val_dict
+
+    def proxy_getter(self, proxy: MutableDictLikeData, key: Any):
+        if key not in proxy.original_data:
+            return MutableDictLikeData.Empty()
+        return VariableFactory.from_value(
+            proxy.original_data[key],
+            self.graph,
+            tracker=GetItemTracker(self, key, changed=proxy.has_changed),
+        )
+
+    def get_py_value(self, allow_tensor=False):
+        return {
+            key: value.get_py_value(allow_tensor)
+            for key, value in self.proxy.get_all().items()
+        }
+
+    def get_py_type(self):
+        return dict
+
+    def _reconstruct(self, codegen: PyCodeGen):
+        from .basic import ConstantVariable
+
+        size = len(self)
+        for key in self.proxy.get_all().keys():
+            if not isinstance(key, ConstTypes):
+                raise InnerError(
+                    f"[{self.__class__.__name__}]: received {key} as key."
+                )
+            key_var = ConstantVariable.wrap_literal(key, self.graph)
+            value_var = self[key]
+            key_var.reconstruct(codegen)
+            value_var.reconstruct(codegen)
+        codegen.gen_build_map(size)
+
+    def get_items(self):
+        items = []
+        for key in self.proxy.get_all().keys():
+            if not isinstance(key, ConstTypes):
+                raise InnerError(
+                    f"[{self.__class__.__name__}]: received {key} as key."
+                )
+            key_var = VariableFactory.from_value(
+                key, self.graph, tracker=ConstTracker(key)
+            )
+            value_var = self[key]
+            items.extend([key_var, value_var])
+        return items
+
+    def get_wrapped_items(self):
+        items = {}
+        for key in self.proxy.get_all().keys():
+            if not isinstance(key, ConstTypes):
+                raise InnerError(
+                    f"[{self.__class__.__name__}]: received {key} as key."
+                )
+            items[key] = self[key]
+        return items
+
+    def get_iter(self):
+        return self.keys()
+
+    @property
+    def main_info(self) -> dict[str, Any]:
+        return {
+            "len": len(self),
+        }
+
+    def __len__(self):
+        return len(self.proxy.get_all())
+
+    def get(self, key, default=None):
+        if isinstance(key, VariableBase):
+            raise InnerError(
+                f"[{self.__class__.__name__}]: received {key} to get value."
+            )
+
+        if default is None:
+            return Dispatcher.call(operator.getitem, self, key)
+
+        if isinstance(self.proxy.get(key), MutableDictLikeData.Empty):
+            assert isinstance(default, VariableBase)
+            return default
+
+        return Dispatcher.call(operator.getitem, self, key)
+
+    def getitem(self, key):
+        self.graph.add_global_guarded_variable(key)
+        key = key.get_py_value()
+        return self.proxy.get(key)
+
+    def setitem(self, key, value):
+        if isinstance(key, VariableBase):
+            raise InnerError(
+                f"[{self.__class__.__name__}]: received {key} as key."
+            )
+
+        if not isinstance(value, VariableBase):
+            raise InnerError(
+                f"[{self.__class__.__name__}]: received {value} to set value."
+            )
+
+        self.proxy.set(key, value)
+        self.graph.side_effects.record_proxy_variable(self)
+
+        return ConstantVariable.wrap_literal(None, self.graph)
+
+    def clear(self):
+        # TODO: Replace with self.proxy.clear()
+        for key in self.value:
+            self.delitem(key)
+        self.graph.side_effects.record_proxy_variable(self)
+        return ConstantVariable.wrap_literal(None, self.graph)
+
+    def __delitem__(self, key):
+        return self.delitem(key)
+
+    def delitem(self, key):
+        if isinstance(key, VariableBase):
+            raise InnerError(
+                f"[{self.__class__.__name__}]: received {key} as key to delete."
+ ) + self.proxy.delete(key) + self.graph.side_effects.record_proxy_variable(self) + return ConstantVariable.wrap_literal(None, self.graph) + + def keys(self): + from .iter import SequenceIterVariable + + raw_list = [ + ConstantVariable(x, self.graph, ConstTracker(x)) + for x in self.proxy.get_all().keys() + ] + key_list = ListVariable(raw_list, self.graph, DummyTracker(raw_list)) + assert key_list is not None + return SequenceIterVariable( + key_list, self.graph, DummyTracker([key_list]) + ) + + def values(self): + from .iter import SequenceIterVariable + + raw_list = list(self.get_wrapped_items().values()) + value_list = ListVariable(raw_list, self.graph, DummyTracker([self])) + assert value_list is not None + return SequenceIterVariable( + value_list, self.graph, DummyTracker([value_list]) + ) + + def items(self): + from .iter import SequenceIterVariable + + keys = [ + ConstantVariable(x, self.graph, ConstTracker(x)) + for x in self.proxy.get_all().keys() + ] + values = list(self.get_wrapped_items().values()) + raw_list = list(zip(keys, values)) + item_list = ListVariable(raw_list, self.graph, DummyTracker([self])) + assert item_list is not None + return SequenceIterVariable( + item_list, self.graph, DummyTracker([item_list]) + ) + + def update(self, data: DictVariable): + for key, value in data.proxy.get_all().items(): + self.setitem(key, value) + return ConstantVariable.wrap_literal(None, self.graph) + + def copy(self): + new_dict_variable = DictVariable( + self.get_wrapped_items(), self.graph, DummyTracker([self]) + ) + return new_dict_variable + + def setdefault(self, key, default=None): + if isinstance(self.proxy.get(key), MutableDictLikeData.Empty): + if default is None: + self.setitem( + key, ConstantVariable.wrap_literal(default, self.graph) + ) + else: + self.setitem(key, default) + + return Dispatcher.call(operator.getitem, self, key) + + def pop(self, key, default=None): + if isinstance(self.proxy.get(key), MutableDictLikeData.Empty): + assert isinstance(default, VariableBase) + return default + + # default is not None, or key is in dict + temp_value = Dispatcher.call(operator.getitem, self, key) + self.delitem(key) + return temp_value + + def popitem(self): + key = self.keys().hold.get_py_value()[-1] + value = Dispatcher.call(operator.getitem, self, key) + # TODO: key, value should be VariableBase but key maybe a int + # assert isinstance(key, VariableBase), key + # assert isinstance(value, VariableBase), value + new_tuple_variable = TupleVariable( + (key, value), self.graph, DummyTracker([self]) + ) + self.delitem(key) + return new_tuple_variable + + def getattr(self, name: str, default=None): + from .callable import BuiltinVariable + + if default is not None: + raise FallbackError( + "default argument for getattr is not implemented" + ) + + method_name_to_builtin_fn = { + "keys": dict.keys, + "values": dict.values, + "items": dict.items, + "update": dict.update, + "setdefault": dict.setdefault, + "get": dict.get, + "copy": dict.copy, + "clear": dict.clear, + "pop": dict.pop, + "popitem": dict.popitem, + } + + if name in method_name_to_builtin_fn: + builtin_fn = method_name_to_builtin_fn[name] + return BuiltinVariable( + builtin_fn, self.graph, DanglingTracker() + ).bind(self, name) + else: + raise FallbackError(f"attribute {name} for dict is not implemented") + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if type(value) in (dict, OrderedDict): + return DictVariable(value, graph=graph, tracker=tracker) diff 
--git a/python/paddle/jit/sot/opcode_translator/executor/variables/iter.py b/python/paddle/jit/sot/opcode_translator/executor/variables/iter.py
new file mode 100644
index 0000000000000..82ff8fe2534a7
--- /dev/null
+++ b/python/paddle/jit/sot/opcode_translator/executor/variables/iter.py
@@ -0,0 +1,203 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+from ....utils import BreakGraphError, FallbackError
+from ..pycode_generator import PyCodeGen
+from ..tracker import ConstTracker, DummyTracker
+from .base import VariableBase
+from .basic import ConstantVariable
+from .container import ContainerVariable, TupleVariable
+
+if TYPE_CHECKING:
+    from ..function_graph import FunctionGraph
+    from ..tracker import Tracker
+
+
+class IterVariable(VariableBase):
+    """
+    This variable (including its subclasses) should be generated only when simulating the GET_ITER opcode.
+    """
+
+    def __init__(
+        self, obj: VariableBase, graph: FunctionGraph, tracker: Tracker
+    ):
+        super().__init__(graph, tracker)
+        self.hold = obj
+
+    def make_stringify_guard(self):
+        return self.hold.make_stringify_guard()
+
+    def next(self):
+        raise NotImplementedError(f"Cannot simulate `next` for {type(self)}")
+
+    def get_iter(self):
+        return self
+
+    def get_hold(self):
+        return self.hold
+
+
+class SequenceIterVariable(IterVariable):
+    """
+    The basic SequenceIterVariable wraps iterators which can be simulated by calling getitem.
+    Currently includes: List | Tuple | Dict (keys) | Range | Tensor | nn.LayerList
+    """
+
+    mutable_attrs = ["idx"]
+
+    def __init__(self, obj, graph: FunctionGraph, tracker: Tracker):
+        super().__init__(obj, graph, tracker)
+        self.idx = 0
+        self.graph.side_effects.record_mutable_variable(self)
+
+    def next(self):
+        # TODO: self.hold should have a __len__ method
+        if self.idx < len(self.hold):
+            val = self.hold[self.idx]
+            self.idx += 1
+            return val
+        else:
+            raise StopIteration()
+
+    def to_list(self) -> list:
+        if self.has_side_effect():
+            raise FallbackError("Cannot convert a used iterator into a list")
+        self.idx = len(self.hold)
+        retval = []
+        for i in range(len(self.hold)):
+            retval.append(self.hold[i])
+        return retval
+
+    def has_side_effect(self) -> bool:
+        return self.idx != 0
+
+    @property
+    def main_info(self) -> dict[str, Any]:
+        return {
+            "idx": self.idx,
+        }
+
+    def _reconstruct(self, codegen: PyCodeGen):
+        if self.has_side_effect():
+            super()._reconstruct(codegen)
+        else:
+            self.hold.reconstruct(codegen)
+            codegen.gen_get_iter()
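+
+
+# A rough sketch of how a traced `for x in xs` is simulated with the classes
+# above (hypothetical variables; `list_var` is assumed to wrap a two-element list):
+#
+#     it = list_var.get_iter()  # -> SequenceIterVariable
+#     it.next()                 # -> first element variable (idx: 0 -> 1)
+#     it.next()                 # -> second element variable (idx: 1 -> 2)
+#     it.next()                 # raises StopIteration, ending the loop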
+
+
+class EnumerateVariable(SequenceIterVariable):
+    """
+    EnumerateVariable holds a SequenceIterVariable and returns an additional index along with each value.
+    """
+
+    def __init__(self, val_iterator, graph, tracker):
+        super().__init__(val_iterator, graph, tracker)
+
+    def next(self):
+        val = self.hold.next()
+        idx_var = ConstantVariable(self.idx, self.graph, ConstTracker(self.idx))
+        self.idx += 1
+        return TupleVariable(
+            (idx_var, val), self.graph, DummyTracker([idx_var, val])
+        )
+
+    def to_list(self):
+        values = self.hold.to_list()
+        idx = [
+            ConstantVariable(i, self.graph, ConstTracker(i))
+            for i in range(len(values))
+        ]
+        return list(zip(idx, values))
+
+    def has_side_effect(self) -> bool:
+        return self.hold.has_side_effect() or self.idx != 0
+
+    def _reconstruct(self, codegen: PyCodeGen):
+        if self.has_side_effect():
+            super()._reconstruct(codegen)
+        else:
+            codegen.gen_load_global("enumerate", push_null=True)
+            self.hold.reconstruct(codegen)
+            codegen.gen_call_function(1)
+
+    def get_hold(self):
+        return self.hold.get_hold()
+
+    @staticmethod
+    def from_iterator(value, graph: FunctionGraph | None, tracker: Tracker):
+        iter_variable = value.get_iter()
+        if isinstance(iter_variable, SequenceIterVariable):
+            return EnumerateVariable(iter_variable, graph, tracker)
+        else:
+            return UserDefinedIterVariable(value, graph, tracker)
+
+
+class MapVariable(SequenceIterVariable):
+    """
+    MapVariable holds a SequenceIterVariable and returns an iterable variable produced by applying the map function.
+    """
+
+    def __init__(self, func, val_iterator, graph, tracker):
+        super().__init__(val_iterator, graph, tracker)
+        self.func = func
+
+    def next(self):
+        return self.func(self.hold.next())
+
+    def to_list(self) -> list:
+        retval = []
+        while True:
+            try:
+                retval.append(self.func(self.hold.next()))
+            except StopIteration:
+                break
+        return retval
+
+    def has_side_effect(self) -> bool:
+        return self.hold.has_side_effect()
+
+    def _reconstruct(self, codegen: PyCodeGen):
+        if self.has_side_effect():
+            super()._reconstruct(codegen)
+        else:
+            codegen.gen_load_global("map", push_null=True)
+            self.func.reconstruct(codegen)
+            self.hold.reconstruct(codegen)
+            codegen.gen_call_function(2)
+
+    @staticmethod
+    def from_iterator(
+        func, value, graph: FunctionGraph | None, tracker: Tracker
+    ):
+        iter_variable = (
+            value.get_iter() if isinstance(value, ContainerVariable) else value
+        )
+
+        if isinstance(iter_variable, IterVariable):
+            return MapVariable(func, iter_variable, graph, tracker)
+        else:
+            return UserDefinedIterVariable(value, graph, tracker)
+
+
+# What UserDefinedIterVariable holds doesn't matter, because using a
+# user-defined iterator will trigger a break graph.
+class UserDefinedIterVariable(IterVariable):
+    def __init__(self, obj, graph, tracker):
+        super().__init__(obj, graph, tracker)
+
+    def next(self):
+        raise BreakGraphError("Break graph when using user defined iterator")
diff --git a/python/paddle/jit/sot/opcode_translator/instruction_utils/__init__.py b/python/paddle/jit/sot/opcode_translator/instruction_utils/__init__.py
new file mode 100644
index 0000000000000..5fc71359e9386
--- /dev/null
+++ b/python/paddle/jit/sot/opcode_translator/instruction_utils/__init__.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .instruction_utils import (  # noqa: F401
+    Instruction,
+    calc_offset_from_bytecode_offset,
+    calc_stack_effect,
+    convert_instruction,
+    gen_instr,
+    get_instructions,
+    instrs_info,
+    modify_extended_args,
+    modify_instrs,
+    modify_vars,
+    relocate_jump_target,
+    replace_instr,
+    reset_offset,
+)
+from .opcode_analysis import (  # noqa: F401
+    Space,
+    analysis_inputs,
+    analysis_used_names_with_space,
+)
diff --git a/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py
new file mode 100644
index 0000000000000..182ba54279eef
--- /dev/null
+++ b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py
@@ -0,0 +1,407 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import dataclasses
+import dis
+import sys
+from typing import TYPE_CHECKING, Any
+
+from ...utils import InnerError
+from .opcode_info import ABS_JUMP, ALL_JUMP, REL_BWD_JUMP, REL_JUMP
+
+if TYPE_CHECKING:
+    import types
+
+
+@dataclasses.dataclass
+class Instruction:
+    opcode: int
+    opname: str
+    arg: int | None
+    argval: Any
+    offset: int | None = None
+    starts_line: int | None = None
+    is_jump_target: bool = False
+    jump_to: Instruction | None = None
+    is_generated: bool = True
+
+    # for analyzing EXTENDED_ARG
+    first_ex_arg: Instruction | None = None
+    ex_arg_for: Instruction | None = None
+
+    # used in modify_extended_args
+    def __hash__(self):
+        return id(self)
+
+
+def gen_instr(name, arg=None, argval=None, gened=True, jump_to=None):
+    return Instruction(
+        opcode=dis.opmap[name],
+        opname=name,
+        arg=arg,
+        argval=argval,
+        is_generated=gened,
+        jump_to=jump_to,
+    )
+
+
+def convert_instruction(instr: dis.Instruction) -> Instruction:
+    """
+    Converts a disassembled instruction to a customized Instruction object.
+
+    Args:
+        instr (dis.Instruction): The disassembled instruction.
+
+    Returns:
+        Instruction: A customized Instruction object.
+    """
+    return Instruction(
+        instr.opcode,
+        instr.opname,
+        instr.arg,
+        instr.argval,
+        instr.offset,
+        instr.starts_line,
+        instr.is_jump_target,
+        jump_to=None,
+        is_generated=False,
+    )
+
+
+def get_instructions(code: types.CodeType) -> list[Instruction]:
+    """
+    Returns parsed instructions from the given code object, excluding
+    any `EXTENDED_ARG` opcodes.
+
+    Args:
+        code (types.CodeType): The code object to extract instructions from.
+
+    Returns:
+        list[Instruction]: A list of Instruction objects representing the
+        bytecode instructions in the code object.
+    """
+    # instrs do not contain EXTENDED_ARG
+    instrs = list(map(convert_instruction, dis.get_instructions(code)))
+    for instr in instrs:
+        if instr.opname in ALL_JUMP:
+            origin_jump_target = calc_offset_from_bytecode_offset(
+                instr.argval, instrs
+            )
+            jump_offset = origin_jump_target
+
+            while instrs[jump_offset].opname == "EXTENDED_ARG":
+                jump_offset += 1
+
+            if origin_jump_target != jump_offset:
+                # copy info from the EXTENDED_ARG to the real opcode
+
+                if instrs[origin_jump_target].is_jump_target:
+                    instrs[jump_offset].is_jump_target = instrs[
+                        origin_jump_target
+                    ].is_jump_target
+                if instrs[origin_jump_target].starts_line:
+                    instrs[jump_offset].starts_line = instrs[
+                        origin_jump_target
+                    ].starts_line
+
+            instr.jump_to = instrs[jump_offset]
+
+    # if the origin opcode contains EXTENDED_ARG, it should be like:
+    #     >> EXTENDED_ARG 1
+    #        XX 388    <- 256 + 132
+    # filter all EXTENDED_ARG here
+    instrs = [x for x in instrs if x.opname != "EXTENDED_ARG"]
+    return instrs
+
+
+def modify_instrs(instructions: list[Instruction]) -> None:
+    """
+    Modifies the given list of instructions. It contains three steps:
+
+    1. reset offset
+    2. relocate jump target
+    3. add EXTENDED_ARG instruction if needed
+
+    Args:
+        instructions (list): The list of Instruction objects representing bytecode instructions.
+
+    Returns:
+        None
+    """
+    modify_completed = False
+    while not modify_completed:
+        reset_offset(instructions)
+        relocate_jump_target(instructions)
+        modify_completed = modify_extended_args(instructions)
+
+
+def reset_offset(instructions: list[Instruction]) -> None:
+    """
+    Resets the offset for each instruction in the list.
+
+    Args:
+        instructions (list): The list of Instruction objects representing bytecode instructions.
+
+    Returns:
+        None
+    """
+    from ..executor.pycode_generator import get_instruction_size
+
+    if sys.version_info >= (3, 11):
+        current_offset = 0
+        for instr in instructions:
+            instr.offset = current_offset
+            current_offset += get_instruction_size(instr)
+        return
+    for idx, instr in enumerate(instructions):
+        instr.offset = idx * 2
+
+
+def correct_jump_direction(instr: Instruction, arg: int) -> Instruction:
+    """
+    Corrects the jump direction of the given instruction.
+    NOTE(zrr1999): In Python 3.11, JUMP_ABSOLUTE is removed, so Python generates JUMP_FORWARD or JUMP_BACKWARD instead,
+    but in for-loop break graph we reuse JUMP_BACKWARD to jump forward, so we need to change it to JUMP_FORWARD.
+
+    Args:
+        instr (Instruction): The instruction to be corrected.
+    """
+    if instr.opname in ABS_JUMP:
+        instr.arg = arg
+        return instr
+    elif instr.opname in REL_JUMP:
+        if arg < 0:
+            if instr.opname in REL_BWD_JUMP:
+                forward_op_name = instr.opname.replace("BACKWARD", "FORWARD")
+                if forward_op_name not in dis.opmap:
+                    raise InnerError(f"Unknown jump type {instr.opname}")
+                instr.opname = forward_op_name
+                instr.opcode = dis.opmap[forward_op_name]
+            else:  # instr.opname in REL_FWD_JUMP
+                backward_op_name = instr.opname.replace("FORWARD", "BACKWARD")
+                if backward_op_name not in dis.opmap:
+                    raise InnerError(f"Unknown jump type {instr.opname}")
+                instr.opname = backward_op_name
+                instr.opcode = dis.opmap[backward_op_name]
+            instr.arg = -arg
+        else:
+            instr.arg = arg
+        return instr
+    else:
+        raise ValueError(f"unknown jump type: {instr.opname}")
+
+
+def relocate_jump_target(instructions: list[Instruction]) -> None:
+    """
+    If a jump instruction is found, this function will adjust the jump targets based on the presence of EXTENDED_ARG instructions.
+ If an EXTENDED_ARG instruction exists for the jump target, use its offset as the new target. + + Args: + instructions (list): The list of Instruction objects representing bytecode instructions. + + Returns: + None + """ + extended_arg = [] + for instr in instructions: + if instr.opname == "EXTENDED_ARG": + extended_arg.append(instr) + continue + + if instr.opname in ALL_JUMP: + assert instr.jump_to is not None + assert instr.offset is not None + # if jump target has extended_arg, should jump to the first extended_arg opcode + jump_target = ( + instr.jump_to.offset + if instr.jump_to.first_ex_arg is None + else instr.jump_to.first_ex_arg.offset + ) + assert jump_target is not None + + if instr.opname in ABS_JUMP: + new_arg = jump_target + else: # instr.opname in REL_JUMP + new_arg = jump_target - instr.offset - 2 + if instr.opname in REL_BWD_JUMP: + new_arg = -new_arg + + if sys.version_info >= (3, 10): + new_arg //= 2 + correct_jump_direction(instr, new_arg) + assert instr.arg is not None + if extended_arg: + instr.arg &= 0xFF + new_arg = new_arg >> 8 + for ex in reversed(extended_arg): + ex.arg = new_arg & 0xFF + new_arg = new_arg >> 8 + + # need more extended_args instr + # set arg in the first extended_arg + if new_arg > 0: + extended_arg[0].arg += new_arg << 8 + extended_arg.clear() + + +def modify_extended_args(instructions: list[Instruction]) -> bool: + """ + This function replaces any instruction with an argument greater than or equal to 256 with one or more EXTENDED_ARG instructions. + + Args: + instructions (list): The list of Instruction objects representing bytecode instructions. + + Returns: + bool: True if the modification is completed, False otherwise. + """ + + modify_completed = True + extend_args_record = {} + for instr in instructions: + if instr.arg and instr.arg >= 256: # more than one byte + _instrs = [ + instr + ] # replace instr with _instrs later (it is a set of instrs), all operations will be recorded in extend_args_record + val = instr.arg + instr.arg = val & 0xFF + val = val >> 8 + while val > 0: + _instrs.append(gen_instr("EXTENDED_ARG", arg=val & 0xFF)) + val = val >> 8 + + extend_args_record.update({instr: list(reversed(_instrs))}) + + if extend_args_record: + # if new EXTENDED_ARG inserted, we need update offset and jump target + modify_completed = False + + def bind_ex_arg_with_instr(ex_arg, instr): + # move opcode info to EXTENDED_ARG + ex_arg.starts_line = instr.starts_line + instr.starts_line = None + ex_arg.is_jump_target = instr.is_jump_target + instr.is_jump_target = False + + if instr.ex_arg_for is not None: + # instr is also an ex_arg for another instr + instr.ex_arg_for.first_ex_arg = ex_arg + ex_arg.ex_arg_for = instr.ex_arg_for + instr.ex_arg_for = None + else: + instr.first_ex_arg = ex_arg + ex_arg.ex_arg_for = instr + + for key, val in extend_args_record.items(): + bind_ex_arg_with_instr(val[0], key) + replace_instr(instructions, instr=key, new_instr=val) + + return modify_completed + + +def modify_vars(instructions, code_options): + co_names = code_options['co_names'] + co_varnames = code_options['co_varnames'] + co_freevars = code_options['co_freevars'] + for instrs in instructions: + if instrs.opname == 'LOAD_FAST' or instrs.opname == 'STORE_FAST': + assert ( + instrs.argval in co_varnames + ), f"`{instrs.argval}` not in {co_varnames}" + instrs.arg = co_varnames.index(instrs.argval) + elif instrs.opname == "LOAD_DEREF" or instrs.opname == "STORE_DEREF": + if sys.version_info >= (3, 11): + namemap = co_varnames + co_freevars + assert ( + 
instrs.argval in namemap
+                ), f"`{instrs.argval}` not in {namemap}"
+                instrs.arg = namemap.index(instrs.argval)
+
+
+def calc_offset_from_bytecode_offset(
+    bytecode_offset: int,
+    instructions: list[dis.Instruction] | list[Instruction],
+) -> int:
+    """
+    Calculate the instruction index from the bytecode offset; for Python <= 3.10 each instruction occupies 2 bytes.
+
+    Args:
+        bytecode_offset (int): The bytecode offset of the instruction.
+
+    Returns:
+        int: The index of the instruction in the instruction list.
+    """
+
+    if sys.version_info >= (3, 11):
+        instruction_offsets = [x.offset for x in instructions]
+        return instruction_offsets.index(bytecode_offset)
+    return bytecode_offset // 2
+
+
+def replace_instr(instructions, instr, new_instr):
+    idx = instructions.index(instr)
+    instructions[idx : idx + 1] = new_instr
+
+
+def instrs_info(instrs, mark=None, range=None):
+    ret = []
+    start = -1
+    end = 1000000
+    if mark is not None and range is not None:
+        start = mark - range
+        end = mark + range + 1
+    for idx, instr in enumerate(instrs):
+        if idx < start or idx >= end:
+            continue
+        if instr.starts_line is not None:
+            ret.append("")
+        ret.append(
+            "{line:<8s}{is_jump_target:>2s}{offset:>4d} {opname:<30s}{arg:<4s}{argval:<40s}{mark}".format(
+                line=str(instr.starts_line) if instr.starts_line else "",
+                is_jump_target=">>" if instr.is_jump_target else "  ",
+                offset=instr.offset if instr.offset is not None else -1,
+                opname=instr.opname,
+                arg=str(instr.arg) if instr.arg is not None else "",
+                argval=f"({instr.argval})" if instr.argval else "",
+                mark="",
+            )
+        )
+        if idx == mark:
+            ret[-1] = "\033[31m" + ret[-1] + "\033[0m"
+    return ret
+
+
+def calc_stack_effect(instr: Instruction, *, jump: bool | None = None) -> int:
+    """
+    Gets the stack effect of the given instruction. In Python 3.11, the stack effect of `CALL` is -1,
+    refer to https://github.com/python/cpython/blob/3.11/Python/compile.c#L1123-L1124.
+
+    Args:
+        instr: The instruction.
+
+    Returns:
+        The stack effect of the instruction.
+
+    """
+    if sys.version_info[:2] == (3, 11):
+        if instr.opname == "PRECALL":
+            return 0
+        elif instr.opname == "CALL":
+            # NOTE(zrr1999): push_n = 1, pop_n = oparg + 2, stack_effect = push_n - pop_n = -oparg - 1
+            assert instr.arg is not None
+            return -instr.arg - 1
+    return dis.stack_effect(instr.opcode, instr.arg, jump=jump)
diff --git a/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_analysis.py b/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_analysis.py
new file mode 100644
index 0000000000000..dcda7558e5a39
--- /dev/null
+++ b/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_analysis.py
@@ -0,0 +1,217 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
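+
+# A sketch of what `analysis_inputs` below computes (illustrative; assumes
+# CPython <= 3.10 bytecode, and note that only local/free names are tracked):
+#
+#     def f(a):
+#         b = a + 1   # LOAD_FAST a: read before any write -> an input
+#         return b    # LOAD_FAST b: `b` was written first -> not an input
+#
+#     analysis_inputs(get_instructions(f.__code__), 0)  # -> {"a"}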
+ +from __future__ import annotations + +import dataclasses +from enum import Enum + +from ...utils import InnerError, OrderedSet +from .instruction_utils import Instruction +from .opcode_info import ALL_JUMP, HAS_FREE, HAS_LOCAL, UNCONDITIONAL_JUMP + + +@dataclasses.dataclass +class State: + reads: OrderedSet[str] + writes: OrderedSet[str] + visited: OrderedSet[int] + + +def is_read_opcode(opname): + if opname in [ + "LOAD_FAST", + "LOAD_DEREF", + "LOAD_NAME", + "LOAD_GLOBAL", + "LOAD_CLOSURE", + ]: + return True + if opname in ( + "DELETE_FAST", + "DELETE_DEREF", + "DELETE_NAME", + "DELETE_GLOBAL", + ): + return True + return False + + +def is_write_opcode(opname): + if opname in ["STORE_FAST", "STORE_NAME", "STORE_DEREF", "STORE_GLOBAL"]: + return True + if opname in ( + "DELETE_FAST", + "DELETE_DEREF", + "DELETE_NAME", + "DELETE_GLOBAL", + ): + return True + return False + + +def analysis_inputs( + instructions: list[Instruction], + current_instr_idx: int, + stop_instr_idx: int | None = None, +) -> OrderedSet[str]: + """ + Analyze the inputs of the instructions from current_instr_idx to stop_instr_idx. + + Args: + instructions (list[Instruction]): The instructions to analyze. + current_instr_idx (int): The index of the current instruction. + stop_instr_idx (int | None, optional): The index of the instruction to stop. Defaults to None. + If None, the analysis will stop at the end of the instructions. + + Returns: + set[str]: The analysis result. + """ + root_state = State(OrderedSet(), OrderedSet(), OrderedSet()) + + def fork( + state: State, start: int, jump: bool, jump_target: int + ) -> OrderedSet[str]: + new_start = start + 1 if not jump else jump_target + new_state = State( + OrderedSet(state.reads), + OrderedSet(state.writes), + OrderedSet(state.visited), + ) + return walk(new_state, new_start) + + def walk(state: State, start: int) -> OrderedSet[str]: + end = len(instructions) if stop_instr_idx is None else stop_instr_idx + for i in range(start, end): + if i in state.visited: + return state.reads + state.visited.add(i) + + instr = instructions[i] + if instr.opname in HAS_LOCAL | HAS_FREE: + if is_read_opcode(instr.opname) and instr.argval not in ( + state.writes + ): + state.reads.add(instr.argval) + elif is_write_opcode(instr.opname): + state.writes.add(instr.argval) + elif instr.opname in ALL_JUMP: + assert instr.jump_to is not None + target_idx = instructions.index(instr.jump_to) + # Fork to two branches, jump or not + jump_branch = fork(state, i, True, target_idx) + not_jump_branch = ( + fork(state, i, False, target_idx) + if instr.opname not in UNCONDITIONAL_JUMP + else OrderedSet() + ) + return jump_branch | not_jump_branch + elif instr.opname == "RETURN_VALUE": + return state.reads + return state.reads + + return walk(root_state, current_instr_idx) + + +@dataclasses.dataclass +class SpaceState: + reads: dict[str, Space] + writes: dict[str, Space] + visited: OrderedSet[int] + + def __or__(self, other): + reads = {} + reads.update(other.reads) + reads.update(self.reads) + writes = {} + writes.update(other.writes) + writes.update(self.writes) + return SpaceState(reads, writes, OrderedSet()) + + +class Space(Enum): + locals = 1 + globals = 2 + cells = 3 + all = 4 + + +def get_space(opname: str): + if "FAST" in opname: + return Space.locals + elif "GLOBAL" in opname: + return Space.globals + elif "DEREF" in opname or "CLOSURE" in opname: + return Space.cells + elif "NAME" in opname: + return Space.all + else: + raise InnerError(f"Unknown space for {opname}") + + +def 
analysis_used_names_with_space( + instructions: list[Instruction], + start_instr_idx: int, + stop_instr_idx: int | None = None, +): + root_state = SpaceState({}, {}, OrderedSet()) + + def fork( + state: SpaceState, start: int, jump: bool, jump_target: int + ) -> SpaceState: + new_start = start + 1 if not jump else jump_target + new_state = SpaceState( + dict(state.reads), + dict(state.writes), + OrderedSet(state.visited), + ) + return walk(new_state, new_start) + + def walk(state: SpaceState, start: int) -> SpaceState: + end = len(instructions) if stop_instr_idx is None else stop_instr_idx + for i in range(start, end): + if i in state.visited: + return state + state.visited.add(i) + + instr = instructions[i] + if instr.opname in HAS_LOCAL | HAS_FREE: + if is_read_opcode(instr.opname) and instr.argval not in ( + state.writes + ): + space = get_space(instr.opname) + state.reads[instr.argval] = space + elif is_write_opcode(instr.opname): + space = get_space(instr.opname) + state.writes[instr.argval] = space + elif instr.opname in ALL_JUMP: + assert instr.jump_to is not None + target_idx = instructions.index(instr.jump_to) + # Fork to two branches, jump or not + jump_branch = fork(state, i, True, target_idx) + not_jump_branch = ( + fork(state, i, False, target_idx) + if instr.opname not in UNCONDITIONAL_JUMP + else SpaceState({}, {}, OrderedSet()) + ) + return jump_branch | not_jump_branch + elif instr.opname == "RETURN_VALUE": + return state + return state + + state = walk(root_state, start_instr_idx) + all_used_vars = {} + all_used_vars.update(state.writes) + all_used_vars.update(state.reads) + return all_used_vars diff --git a/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_info.py b/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_info.py new file mode 100644 index 0000000000000..cc63d5ecde967 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_info.py @@ -0,0 +1,58 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
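+
+# For orientation (illustrative; the concrete sets depend on the running
+# interpreter): on CPython <= 3.10, `ABS_JUMP` contains e.g. "JUMP_ABSOLUTE",
+# while on 3.11 `opcode.hasjabs` is empty, every jump is relative, and the
+# new "JUMP_BACKWARD" opcodes appear in `REL_BWD_JUMP`.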
+ +import sys +from enum import Enum + +import opcode + +REL_JUMP = {opcode.opname[x] for x in opcode.hasjrel} +REL_BWD_JUMP = {opname for opname in REL_JUMP if "BACKWARD" in opname} +REL_FWD_JUMP = REL_JUMP - REL_BWD_JUMP +ABS_JUMP = {opcode.opname[x] for x in opcode.hasjabs} +HAS_LOCAL = {opcode.opname[x] for x in opcode.haslocal} +HAS_FREE = {opcode.opname[x] for x in opcode.hasfree} +ALL_JUMP = REL_JUMP | ABS_JUMP +UNCONDITIONAL_JUMP = {"JUMP_ABSOLUTE", "JUMP_FORWARD"} +if sys.version_info >= (3, 11): + UNCONDITIONAL_JUMP.add("JUMP_BACKWARD") + + +class JumpDirection(Enum): + FORWARD = "FORWARD" + BACKWARD = "BACKWARD" + + +class PopJumpCond(Enum): + FALSE = "FALSE" + TRUE = "TRUE" + NONE = "NONE" + NOT_NONE = "NOT_NONE" + + +# Cache for some opcodes, it's for Python 3.11+ +# https://github.com/python/cpython/blob/3.11/Include/internal/pycore_opcode.h#L41-L53 +PYOPCODE_CACHE_SIZE = { + "BINARY_SUBSCR": 4, + "STORE_SUBSCR": 1, + "UNPACK_SEQUENCE": 1, + "STORE_ATTR": 4, + "LOAD_ATTR": 4, + "COMPARE_OP": 2, + "LOAD_GLOBAL": 5, + "BINARY_OP": 1, + "LOAD_METHOD": 10, + "PRECALL": 1, + "CALL": 4, +} diff --git a/python/paddle/jit/sot/opcode_translator/skip_files.py b/python/paddle/jit/sot/opcode_translator/skip_files.py new file mode 100644 index 0000000000000..7753309debce9 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/skip_files.py @@ -0,0 +1,177 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
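+
+# Behaviour sketch (illustrative): code living under any of the modules
+# registered below is never translated, so e.g. `need_skip_path(os.__file__)`
+# is True, while a user model file outside these prefixes is False.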
+ +import abc +import codecs +import collections +import contextlib +import copy +import copyreg +import dataclasses +import distutils +import enum +import functools +import importlib +import inspect +import linecache +import logging +import multiprocessing +import operator +import os +import posixpath +import random +import re +import selectors +import signal +import sys +import tempfile +import threading +import tokenize +import traceback +import types +import typing +import unittest +import uuid +import warnings +import weakref + +import _collections_abc +import _weakrefset +import decorator +import google.protobuf +import numpy +import setuptools + +import paddle + +from ..utils import log + +NEED_SKIP_THIRD_PARTIY_MODULES = { + abc, + collections, + contextlib, + copy, + copyreg, + dataclasses, + enum, + functools, + google.protobuf, + importlib, + inspect, + linecache, + logging, + multiprocessing, + numpy, + operator, + os, + posixpath, + random, + re, + selectors, + signal, + tempfile, + threading, + tokenize, + traceback, + types, + typing, + unittest, + weakref, + _collections_abc, + _weakrefset, + decorator, + codecs, + uuid, + setuptools, + distutils, + warnings, +} + +if sys.version_info < (3, 11): + import sre_compile + import sre_parse + + NEED_SKIP_THIRD_PARTIY_MODULES.add(sre_compile) + NEED_SKIP_THIRD_PARTIY_MODULES.add(sre_parse) + + +def _strip_init_py(s): + return re.sub(r"__init__.py$", "", s) + + +def _module_dir(m: types.ModuleType): + return _strip_init_py(m.__file__) + + +skip_file_names = {_module_dir(m) for m in NEED_SKIP_THIRD_PARTIY_MODULES} + + +sot_path = os.path.dirname(__file__).rpartition(os.sep)[0] + os.sep +paddle_path = sys.modules["paddle"].__file__.rpartition(os.sep)[0] + os.sep + +skip_file_names.add(sot_path) +skip_file_names.add(paddle_path) +skip_file_names.add( + "") + +skip_file_name_re = re.compile( + f"^({'|'.join(map(re.escape, skip_file_names))})" +) + +customed_skip_code = set() + +no_skip_code = {paddle.nn.Sequential.forward.__code__} + + +def need_skip_path(filepath: str) -> bool: + """ + Check if the file should be skipped and not transcribed. + + Args: + filepath: The path of the file to check. + + Returns: + bool: True if the file should be skipped. + """ + if not filepath.startswith("<"): + filepath = os.path.abspath(filepath) + return bool(skip_file_name_re.match(filepath)) + + +def skip_function(function): + customed_skip_code.add(function.__code__) + return function + + +def need_skip(frame): + pycode = frame.f_code + if pycode in no_skip_code: + return False + if pycode in customed_skip_code: + log(3, f"Skip frame by code: {pycode}\n") + return True + filename = pycode.co_filename + if sys.version_info >= (3, 11) and filename.startswith(" CustomCode: + with EventGuard( + f"eval_frame_callback: {frame.f_code.co_name}", event_level=2 + ): + # is generator + if frame.f_code.co_flags & 0x20 > 0: + return CustomCode(None, True) + + # NOTE(SigureMo): Temporary fallback when code has exception handling. 
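+        # In CPython 3.11, try/except compiles to a `co_exceptiontable`
+        # (zero-cost exception handling) rather than SETUP_FINALLY-style
+        # opcodes, so such frames are executed as-is instead of translated.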
+ if sys.version_info >= (3, 11) and frame.f_code.co_exceptiontable: + log( + 3, + f"[eval_frame_callback] {frame.f_code} has co_exceptiontable\n", + ) + return CustomCode(None, False) + + if need_skip(frame): + log(3, f"[eval_frame_callback] skip {frame.f_code}\n") + custom_code = CustomCode(None, False) + new_code = frame.f_code + else: + log( + 2, f"[eval_frame_callback] start to translate: {frame.f_code}\n" + ) + log_do(4, partial(print_locals, frame)) + + log(3, f"[transform] OriginCode: {frame.f_code.co_name}\n") + log_do(3, lambda: dis.dis(frame.f_code)) + + custom_code = OpcodeExecutorCache()(frame, **kwargs) + + if custom_code.code is None: + log( + 3, + "[transform] NewCode (same as origin code): " + + frame.f_code.co_name + + "\n", + ) + new_code = frame.f_code + else: + log( + 3, + "[transform] NewCode: " + custom_code.code.co_name + "\n", + ) + log_do(3, lambda: dis.dis(custom_code.code)) + new_code = custom_code.code + + # just check those codes which need open eval_frame + if ( + custom_code.disable_eval_frame is False + and CodeStatus().is_code_without_graph(new_code) + ): + log( + 3, + "[eval_frame_callback] Code has no graph, block it.\n", + ) + return CustomCode(None, True) + + return custom_code diff --git a/python/paddle/jit/sot/profiler.py b/python/paddle/jit/sot/profiler.py new file mode 100644 index 0000000000000..8315e03dd37f5 --- /dev/null +++ b/python/paddle/jit/sot/profiler.py @@ -0,0 +1,78 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from contextlib import contextmanager +from functools import wraps + +from paddle.framework import core + +_event_level = int(os.environ.get("EVENT_LEVEL", "-1")) + + +class SotProfiler: + def __enter__(self): + self.enable() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.disable() + + def enable(self, tag=None): + core.nvprof_start() + core.nvprof_enable_record_event() + + def disable(self): + core.nvprof_stop() + + +@contextmanager +def EventGuard(event_name, event_level=0): + try: + global _event_level + need_pop = False + if _event_level >= event_level: + core.nvprof_nvtx_push(event_name) + need_pop = True + yield + finally: + if need_pop: + core.nvprof_nvtx_pop() + + +if _event_level == -1: + + @contextmanager + def _EmptyEventGuard(event_name, event_level=0): + yield + + EventGuard = _EmptyEventGuard # noqa: F811 + + +def event_register(event_name, event_level=0): + def event_wrapper(func): + @wraps(func) + def call_with_event(*args, **kwargs): + with EventGuard(event_name, event_level=0): + return func(*args, **kwargs) + + return call_with_event + + def do_nothing(func): + return func + + global _event_level + if _event_level >= event_level: + return event_wrapper + else: + return do_nothing diff --git a/python/paddle/jit/sot/psdb.py b/python/paddle/jit/sot/psdb.py new file mode 100644 index 0000000000000..38fa4d7479e16 --- /dev/null +++ b/python/paddle/jit/sot/psdb.py @@ -0,0 +1,68 @@ +# Copyright (c) 2023 PaddlePaddle Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import builtins +import types +from typing import TYPE_CHECKING, Callable + +if TYPE_CHECKING: + from typing import TypeVar + + from typing_extensions import ParamSpec + + T = TypeVar("T") + P = ParamSpec("P") + +NO_BREAKGRAPH_CODES: set[types.CodeType] = set() +NO_FALLBACK_CODES: set[types.CodeType] = set() + + +def assert_true(input: bool): + assert input + + +def print(*args, **kwargs): + builtins.print("[Dygraph]", *args, **kwargs) + + +def breakpoint(): + import paddle + + old = paddle.framework.core.set_eval_frame(None) + builtins.breakpoint() + paddle.framework.core.set_eval_frame(old) + + +def check_no_breakgraph(fn: Callable[P, T]) -> Callable[P, T]: + NO_BREAKGRAPH_CODES.add(fn.__code__) + return fn + + +def breakgraph(): + pass + + +def check_no_fallback(fn: Callable[P, T]) -> Callable[P, T]: + NO_FALLBACK_CODES.add(fn.__code__) + return fn + + +def fallback(): + pass + + +def in_sot(): + return False diff --git a/python/paddle/jit/sot/symbolic/compile_cache.py b/python/paddle/jit/sot/symbolic/compile_cache.py new file mode 100644 index 0000000000000..8fa7444ff0684 --- /dev/null +++ b/python/paddle/jit/sot/symbolic/compile_cache.py @@ -0,0 +1,143 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
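+
+# Usage sketch (illustrative): `CompileSIRCache()(context, "SIR_0", ...)`
+# returns a `FallbackWrapper`; its first call compiles the SIR via
+# `paddle.jit.to_static` and caches the resulting partial program, which
+# later calls reuse directly.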
+ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import paddle + +from ..profiler import EventGuard +from ..utils import ( + Cache, + CodeStatus, + GraphLogger, + Singleton, + StepInfoManager, + log_do, +) +from .interpreter import compile_sir + +if TYPE_CHECKING: + from .symbolic_context import SymbolicTraceContext + + +def clear_eager_tensor_name(output_tensors): + for output_tensor in output_tensors: + output_tensor.name = "" + + +class FallbackWrapper: + """ + Used to store and call static graph methods generated by paddle.jit.to_static + """ + + def __init__(self, compiled_fn, SIR): + self.compiled_fn = compiled_fn + self.partial_program = None + self.concrete_program = None + self.SIR = SIR # for debug + + def __call__(self, *args, **kwargs): + with EventGuard(f"FallbackWrapper: {self.SIR.name}"): + if StepInfoManager().need_back_trace: + CodeStatus().trace_back_frames() + + log_do( + 2, + lambda: print("[FallbackWrapper] start run SIR: \n", self.SIR), + ) + log_do( + 4, + lambda: print( + self.compiled_fn.get_concrete_program(*args, **kwargs)[ + 1 + ].train_program + ), + ) + if self.partial_program is None: + with EventGuard("FallbackWrapper: call compiled_fn"): + outputs = self.compiled_fn(*args, **kwargs) + ( + self.concrete_program, + self.partial_program, + ) = self.compiled_fn.get_concrete_program(*args, **kwargs) + else: + # Speed up Resnet from 0.0068 --> 0.0057 + with EventGuard("FallbackWrapper: call partial_program"): + outputs = self.partial_program(*args, **kwargs) + + clear_eager_tensor_name(outputs) + log_do( + 1, + lambda: GraphLogger().add_subgraph( + self.concrete_program.main_program + ), + ) + log_do( + 4, + lambda: print("[CompileCache] run sir forward success."), + ) + return outputs + + +@Singleton +class CompileSIRCache(Cache): + """ + Cache the compiled function of SIR + """ + + def __init__(self): + super().__init__(weak=False) + + def key_fn(self, context: SymbolicTraceContext, sir_name: str, **kwargs): + """ + generate a hash key for a SIR + + Args: + context: The context to compile + sir_name: The name of the sir to compile + build_strategy: The build strategy to compile + + Returns: + The hash key of the SIR + """ + sir = context.get_sir(sir_name) + # NOTE(dev): Is str(sir) a heavy opearation ? + hash_key = hash(str(sir)) + return hash_key + + def value_fn(self, context: SymbolicTraceContext, sir_name: str, **kwargs): + """ + Generate static graph function + + Args: + context: The context to compile + sir_name: The name of the sir to compile + build_strategy: The build strategy to compile + + Returns: + The static graph function + """ + build_strategy = kwargs.get("build_strategy", None) + backend = kwargs.get("backend", None) + return FallbackWrapper( + paddle.jit.to_static( + compile_sir(context, sir_name), + build_strategy=build_strategy, + backend=backend, + enable_fallback=False, + ), + context.get_sir(sir_name), + ) diff --git a/python/paddle/jit/sot/symbolic/interpreter.py b/python/paddle/jit/sot/symbolic/interpreter.py new file mode 100644 index 0000000000000..13265bbab4e38 --- /dev/null +++ b/python/paddle/jit/sot/symbolic/interpreter.py @@ -0,0 +1,194 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import paddle +from paddle.utils import to_sequence + +from ..utils import InnerError, map_if, map_if_extend +from .statement_ir import SIRRuntimeCache, Symbol + +if TYPE_CHECKING: + from .statement_ir import Statement, StatementIR + from .symbolic_context import SymbolicTraceContext + + +def replace_symbol( + values: list[Symbol] | list[object], state: dict[str, Symbol] +): + """ + Replaces Symbol objects with their corresponding values. + + Args: + values: A list of values that may contain Symbol objects. + state: A dict mapping Symbol names to their corresponding values. + + Returns: + A new list with Symbol objects replaced by their corresponding values in the state dict. + """ + # deal with list / map etc. + values = map_if_extend( + values, + pred=lambda x: isinstance(x, Symbol), + true_fn=lambda x: state[x.name], + false_fn=lambda x: x, + ) + return values + + +def _append_opstack_between(start, end, stack): + # NOTE(xiongkun): we don't sync for speed. careful!! + # [start, end) + from paddle.framework import core + + op_maker = core.op_proto_and_checker_maker + callstack_attr_name = op_maker.kOpCreationCallstackAttrName() + for op in for_each_ops_between(start, end): + op._set_attr(callstack_attr_name, stack) + + +def for_each_ops_between(start, end): + # NOTE(xiongkun): we don't sync for speed. careful!! + # [start, end) + program = paddle.static.default_main_program() + ops = program.current_block().ops[start:end] + yield from ops + + +def opnum_in_program(): + # NOTE(xiongkun): we don't sync for speed. careful!! + program = paddle.static.default_main_program() + return len(program.current_block().ops) + + +class Interpreter: + """ + Interpreter is used to interpret and execute SIR. + """ + + def __init__(self, symbolic_context: SymbolicTraceContext): + self._context = symbolic_context + + def get_sir(self, name: str) -> StatementIR: + """ + Returns the StatementIR object by given name. + + Args: + name: The name of the StatementIR. + + Returns: + The StatementIR object with the given name. + """ + return self._context.get_sir(name) + + def run_sir(self, name: str, state: dict[str, Symbol]): + """ + Runs the StatementIR with the given name using the provided state. + + Args: + name: The name of the given StatementIR to run. + state: A dict mapping Symbol names to their corresponding values. + + Returns: + A list of the Symbol of the StatementIR after execution. 
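+
+        Example (illustrative): for a SIR whose single statement prints as
+        `api || var_1 = paddle.add (var_0, var_0)` and with
+        `state = {"var_0": t}`, this binds `state["var_1"]` to the result of
+        `paddle.add(t, t)` and returns the values bound to `SIR.outputs`.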
+ """ + SIR = self.get_sir(name) + for stmt in SIR.statements: + stmt: Statement + before_stmt_opnum = opnum_in_program() + inputs = replace_symbol(stmt.inputs, state) + outs = getattr(self, stmt.type)(stmt, inputs) + + def _set(v, s): + state[s.name] = v + + if len(to_sequence(outs)) != len(to_sequence(stmt.outputs)): + raise InnerError("Number output mismatch, some error happen.") + + _append_opstack_between( + before_stmt_opnum, opnum_in_program() + 1, stmt.stmt_stack + ) + + map_if( + outs, + stmt.outputs, + pred=lambda v, s: isinstance(s, Symbol), + true_fn=lambda v, s: _set(v, s), + false_fn=lambda v, s: None, + ) + # fetch outputs + return replace_symbol(SIR.outputs, state) + + def call(self, stmt: Statement, inputs): + SIR = self.get_sir(stmt.sir_name) + state = prepare_state(SIR, inputs) + return self.run_sir(stmt.sir_name, state) + + def api(self, stmt, inputs): + args, kwargs = inputs + return stmt.api(*args, **kwargs) + + def method(self, stmt, inputs): + args, kwargs = inputs + var = args[0] + return getattr(var, stmt.method)(*args[1:], **kwargs) + + def layer(self, stmt, inputs): + args, kwargs = inputs + layer = stmt.layer() + assert layer is not None, "SIR bound layer is None." + return layer(*args, **kwargs) + + +def compile_sir(context: SymbolicTraceContext, name: str): + """ + Compile a SIR to a new function + + Args: + context: The context to compile + name: The name of the sir to compile + + """ + + @paddle.jit.not_to_static + def wrapper(args): + """ + This function will be decorated by paddle.to_static. + so the args is variables, not eager tensors. + """ + interpreter = Interpreter(context) + SIR = interpreter.get_sir(name) + state = prepare_state(SIR, args) + return interpreter.run_sir(name, state) + + return wrapper + + +def prepare_state(SIR, inputs): + state = {} + + # update free vars if exsits + if SIRRuntimeCache().has_key(SIR.name): # noqa: W601 + free_var_seeker = SIRRuntimeCache().get_free_vars(SIR.name) + if free_var_seeker: + state = free_var_seeker() + + # bind inputs + for sir_inp, inp in zip(SIR.inputs, inputs): + state[sir_inp.name] = inp + + return state diff --git a/python/paddle/jit/sot/symbolic/statement_ir.py b/python/paddle/jit/sot/symbolic/statement_ir.py new file mode 100644 index 0000000000000..11a08f36acd9d --- /dev/null +++ b/python/paddle/jit/sot/symbolic/statement_ir.py @@ -0,0 +1,338 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +THIS FILE IS PRIVATE !! + +use interface in symbolic_context.py first. +""" +from __future__ import annotations + +import weakref +from typing import Any, Callable + +import paddle +from paddle.utils import is_sequence, map_structure + +from ..utils import NameGenerator, OrderedSet, Singleton, flatten_extend + + +class Symbol: + """ + Symbol is used to distinguish a string and a `math variable`. 
+ """ + + def __init__(self, name: str): + self.name = name + + def __str__(self): + return self.name + + def __repr__(self): + return str(self) + + def __eq__(self, other): + if isinstance(other, str): + return self.name == other + return self.name == other.name + + def __hash__(self): + return hash(self.name) + + def __deepcopy__(self, memo=None): + return Symbol(self.name) + + +class Statement: + """ + Statement is used to represent a sentence of code for building the neural network model, + which has four types: "call", "api", "method", and "layer". + + Note: + Statement temporarily does not support control flow. + """ + + def __init__( + self, + type: str, + name: str, + inputs: list[Symbol], + outputs: list[Symbol], + stacks: list[str], + ): + assert type in ["call", "api", "method", "layer"] + self.name = name + self.inputs = inputs # (list of Symbols, dict of Symbols) + self.outputs = outputs # list of Symbol | PythonObj + self.stmt_stack = ( + stacks # a list of string to record the source code callstack. + ) + self.type = type + + def __str__(self): + def to_string(inps): + if isinstance(inps, str) or not is_sequence(inps): + return inps.__str__() + inps = (x.__str__() for x in inps) + return ", ".join(inps) + + return "{} || {} = {} ({}) ".format( + self.type + " " * (10 - len(self.type)), + to_string(self.outputs), + self.name, + to_string(self.inputs), + ) + + def __repr__(self): + return self.__str__() + + +class CallStatement(Statement): + def __init__( + self, + name: str, + inputs: list[Symbol], + outputs: list[Symbol], + stacks: list[str], + ): + super().__init__("call", name, inputs, outputs, stacks) + self.sir_name = name + + +class ApiStatement(Statement): + def __init__( + self, + api: Callable, + inputs: list[Symbol], + outputs: list[Symbol], + stacks: list[str], + ): + super().__init__( + "api", "paddle." + api.__name__, inputs, outputs, stacks + ) + self.api = api + + +class MethodStatement(Statement): + def __init__( + self, + name: str, + inputs: list[Symbol], + outputs: list[Symbol], + stacks: list[str], + ): + super().__init__("method", name, inputs, outputs, stacks) + self.method = name + + +class LayerStatement(Statement): + def __init__( + self, + layer: paddle.nn.Layer, + inputs: list[Symbol], + outputs: list[Symbol], + stacks: list[str], + ): + super().__init__( + "layer", layer.__class__.__name__, inputs, outputs, stacks + ) + self.layer = weakref.ref(layer) + + +class StatementIR: + """ + StatementIR is the carrier that records the code for building the neural network model.It is + a representation of a purely computational structure, and does not care about specific values. + The function converted from StatementIR can ensure that it can be turned into a static state. + In this way, we can reuse the original `to_static` function to realize the execution of the static graph. 
+ + Note: + Don't create by yourself, just use the StatementIRCache.get() + """ + + def __init__(self, name: str): + self.name = name + self.inputs = [] # list of Symbol | PythonObj + self.outputs = [] # list of Symbol | PythonObj + self.statements = [] # list of Statement + + def __len__(self): + return len(self.statements) + + def __deepcopy__(self, memo=None): + new_sir = StatementIR(self.name) + new_sir.inputs = list(self.inputs) + new_sir.outputs = list(self.outputs) + new_sir.statements = list(self.statements) + return new_sir + + def add_input(self, input): + self.inputs.append(input) + + def add_output(self, output): + self.outputs.append(output) + + def add_statement(self, statement): + assert isinstance(statement, Statement) + self.statements.append(statement) + + def analyse_inputs(self): + used_symbols = OrderedSet() + generated_symbols = OrderedSet() + for stmt in self.statements: + for inp in flatten_extend(stmt.inputs): + if isinstance(inp, Symbol) and inp not in generated_symbols: + used_symbols.add(inp) + for out in flatten_extend(stmt.outputs): + if isinstance(out, Symbol): + generated_symbols.add(out) + + input_symbols = sorted(used_symbols, key=lambda x: x.name) + return input_symbols + + def __str__(self): + strs = [] + strs.append("StatmentIR: %s" % self.name) + strs.append(f" inputs: {map_structure(lambda x: x.name, self.inputs)}") + strs.append( + f" outputs: {map_structure(lambda x: x.name, self.outputs)}" + ) + strs.append(" statements: ") + for stmt in self.statements: + strs.append(f" {stmt}") + return "\n".join(strs) + + def __repr__(self): + return self.__str__() + + def graph_size(self): + call_layers = [x for x in self.statements if x.type == "layer"] + return len(self.statements) + len(call_layers) + + +@Singleton +class StatementIRFactory: + """ + It is used to create a StatementIR. + """ + + def __init__(self): + self.cache = {} + self.name_generator = NameGenerator("SIR_") + + def __getitem__(self, key): + return self.cache[key] + + def create(self, input_name=None): + if input_name: + name = input_name + else: + name = self.name_generator.next() + + sir = StatementIR(name) + self.cache[name] = sir + return sir + + def update(self, stmt_ir): + name = stmt_ir.name + self.cache[name] = stmt_ir + + def clear(self): + want_clear = [ + key + for key in self.cache.keys() + if self.name_generator.match_name(key) + ] + for key in want_clear: + del self.cache[key] + + +@Singleton +class SIRRuntimeCache: + """ + It is used to cache the runtime information of the StatementIR. + """ + + def __init__(self): + self.cache = {} + # { name : (inputs, outputs, free_vars) } + # inputs : can be used when call_SIR, if free_vars exist + # outputs : used for generator new ProxyTensor output before fallback + # free_vars: (name, function) + + def __getitem__(self, key): + return self.cache[key] + + def has_key(self, key: str) -> bool: + """ + has_key is used to check whether the key is in the cache. 
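+
+        For example (illustrative), `SIRRuntimeCache().has_key("SIR_0")` is
+        False until one of the setters below has cached something for "SIR_0".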
+ """ + return key in self.cache.keys() + + def set_origin_inputs(self, key: str, inputs: Any): + """ + Set Cache origin Inputs of the StatementIR + """ + if key in self.cache.keys(): + val = self.cache[key] + self.cache[key] = (inputs, val[1], val[2]) + else: + self.cache[key] = (inputs, None, None) + + def set_origin_outputs(self, key: str, outputs: Any): + """ + Set Cache origin outputs of the StatementIR + """ + if key in self.cache.keys(): + val = self.cache[key] + self.cache[key] = (val[0], outputs, val[2]) + else: + self.cache[key] = (None, outputs, None) + + def set_free_vars(self, key: str, free_vars: Any): + """ + Set Cache free variables of the StatementIR + """ + if key in self.cache.keys(): + val = self.cache[key] + self.cache[key] = (val[0], val[1], free_vars) + else: + self.cache[key] = (None, None, free_vars) + + def get_origin_inputs(self, key: str): + """ + Get the origin inputs of the StatementIR. + """ + if key in self.cache.keys(): + return self.cache[key][0] + else: + return None + + def get_origin_outputs(self, key: str): + """ + Get the origin outputs of the StatementIR. + """ + if key in self.cache.keys(): + return self.cache[key][1] + else: + return None + + def get_free_vars(self, key: str): + """ + Get the free variables of the StatementIR. + """ + if key in self.cache.keys(): + return self.cache[key][2] + else: + return None diff --git a/python/paddle/jit/sot/symbolic/symbolic_context.py b/python/paddle/jit/sot/symbolic/symbolic_context.py new file mode 100644 index 0000000000000..47f40bbcc9ec7 --- /dev/null +++ b/python/paddle/jit/sot/symbolic/symbolic_context.py @@ -0,0 +1,161 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from ..utils import log +from .compile_cache import CompileSIRCache +from .statement_ir import ( + ApiStatement, + CallStatement, + LayerStatement, + MethodStatement, + StatementIR, + StatementIRFactory, + Symbol, +) + + +class SymbolicTraceContext: + """ + SymbolicTraceContext is a context manager, which is used to record the symbolic trace. + + """ + + def __init__(self): + self.reset() + + def reset(self): + """ + Reset the context. + """ + + # TODO(dev): StatementIRFactory is a singleton, but SymbolicTraceContext is not. + # whether will two different SymbolicTraceContext objects be conflict ? + self.statement_factory = StatementIRFactory() + self.sir_stack = [self.statement_factory.create()] + + @property + def TOS(self): + """ + The top SIR of sir_stack. + + Returns: + StatementIR: the top of stack. + """ + + return self.sir_stack[-1] + + def call_SIR(self, sirname, inputs, outputs, stacks): + """ + Call a SIR, which is a subgraph. + """ + + stmt = CallStatement(sirname, inputs, outputs, stacks) + self.TOS.add_statement(stmt) + + def call_API(self, api, inputs, outputs, stacks): + """ + Call a paddle api. + """ + + assert callable(api), "call_API must receive a paddle api." 
+        stmt = ApiStatement(api, inputs, outputs, stacks)
+        self.TOS.add_statement(stmt)
+
+    def call_METHOD(self, method_name, inputs, outputs, stacks):
+        """
+        Call a method of an API. The method here can be a Python or a Paddle one.
+        """
+        assert isinstance(
+            method_name, str
+        ), "call_METHOD requires the method name as a string."
+        assert isinstance(
+            inputs[0][0], Symbol
+        ), "the first argument of call_METHOD must be a Symbol variable."
+        stmt = MethodStatement(method_name, inputs, outputs, stacks)
+        self.TOS.add_statement(stmt)
+
+    def call_LAYER(self, layer, inputs, outputs, stacks):
+        """
+        Call a paddle.nn.Layer.
+        """
+        stmt = LayerStatement(layer, inputs, outputs, stacks)
+        self.TOS.add_statement(stmt)
+
+    def get_sir(self, name: str):
+        """
+        Get a SIR from statement_factory.
+
+        Args:
+            name (str): the name of SIR.
+
+        Returns:
+            StatementIR: the SIR.
+        """
+        return self.statement_factory[name]
+
+    def reset_TOS(self):
+        """
+        Reset the TOS.
+        """
+        self.sir_stack.pop()
+        self.sir_stack.append(self.statement_factory.create())
+
+    def replace_TOS(self, sir):
+        """
+        Use a deepcopied sir to replace the TOS.
+        This function will update the statement_factory.
+        """
+        self.sir_stack.pop()
+        self.sir_stack.append(sir)
+        self.statement_factory.update(sir)
+
+    def compile_do_nothing(self, ret_vals):
+        """
+        Return a dummy function, which will return an empty list.
+
+        Args:
+            ret_vals (list[Symbol]): the return values of the function.
+        """
+
+        def dummy_func(*args, **kwargs):
+            return []
+
+        # return None function
+        dummy_stmt_ir = StatementIR("dummy_func")
+        dummy_stmt_ir.outputs = []
+        dummy_stmt_ir.inputs = []
+        return dummy_func, dummy_stmt_ir
+
+    def compile_fn(self, ret_vals, **kwargs):
+        """
+        Start compilation and return a Python function that is guaranteed to
+        be convertible by to_static without errors.
+        """
+        cur_sir: StatementIR = self.TOS
+        # step0: if no statement, return a dummy function
+        if len(cur_sir.statements) == 0:
+            return self.compile_do_nothing(ret_vals)
+        # step1: analyse sir inputs and outputs
+        cur_sir.inputs = cur_sir.analyse_inputs()
+        # TODO: output analysis
+        cur_sir.outputs = ret_vals
+        log(2, "start subgraph compile and execution.\n")
+        log(2, self.TOS, "\n")
+        # step2: call compile_sir and get python function, the third cache is
+        # triggered here.
+        static_func = CompileSIRCache()(self, cur_sir.name, **kwargs)
+        # step3: GC and reset TOS
+        # self.reset_TOS()
+
+        return static_func, cur_sir
diff --git a/python/paddle/jit/sot/translate.py b/python/paddle/jit/sot/translate.py
new file mode 100644
index 0000000000000..88f569460a5ca
--- /dev/null
+++ b/python/paddle/jit/sot/translate.py
@@ -0,0 +1,125 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Callable, TypeVar
+
+import paddle
+
+from .opcode_translator import eval_frame_callback
+from .utils import GraphLogger, StepInfoManager, StepState, log_do
+
+if TYPE_CHECKING:
+    from typing_extensions import ParamSpec
+
+    P = ParamSpec("P")
+    R = TypeVar("R")
+
+
+def symbolic_translate(fn: Callable[P, R], **kwargs) -> Callable[P, R]:
+    """
+    This function is the entry point of PaddleSOT. It sets eval_frame_callback
+    before running the input function to achieve opcode-level translation. The
+    translation relies on simulated execution, during which information,
+    especially the network code, is collected. After the simulated execution
+    completes, the network code is compiled into a static graph Program to
+    improve performance.
+
+    Args:
+        fn: The input function.
+
+    Returns:
+        Callable, The wrapped function.
+
+    Examples:
+        >>> # doctest: +SKIP("Could not get source code of function foo.")
+        >>> import paddle
+        >>> import numpy as np
+        >>> from paddle.jit.sot.translate import symbolic_translate
+        >>> def foo(cond: paddle.Tensor, x: paddle.Tensor):
+        ...     x += 1
+        ...     if cond:
+        ...         x += 1
+        ...     else:
+        ...         x -= 1
+        ...     return x
+        >>> symbolic_translate_foo = symbolic_translate(foo)
+        >>> # For the true branch, the output is 2.
+        >>> cond = paddle.to_tensor(True)
+        >>> x = paddle.to_tensor(0)
+        >>> dygraph_out = foo(cond, x)
+        >>> symbolic_translate_out = symbolic_translate_foo(cond, x)
+        >>> dygraph_out
+        Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=True,
+        2)
+        >>> symbolic_translate_out
+        Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=True,
+        2)
+        >>> np.testing.assert_allclose(
+        ...     dygraph_out.numpy(), symbolic_translate_out.numpy()
+        ... )
+        >>> # For the false branch, the output is 0.
+        >>> cond = paddle.to_tensor(False)
+        >>> dygraph_out = foo(cond, x)
+        >>> symbolic_translate_out = symbolic_translate_foo(cond, x)
+        >>> dygraph_out
+        Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=True,
+        0)
+        >>> symbolic_translate_out
+        Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=True,
+        0)
+        >>> np.testing.assert_allclose(
+        ...     dygraph_out.numpy(), symbolic_translate_out.numpy()
+        ... )
+
+    """
+
+    def callback(frame):
+        return eval_frame_callback(frame, **kwargs)
+
+    def impl_sot(*args: P.args, **kwargs: P.kwargs) -> R:
+        assert hasattr(
+            fn, "__code__"
+        ), "Target function doesn't have code for simulating."
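+        # Callables without `__code__` (e.g. builtins or C extensions) cannot
+        # be simulated at the bytecode level.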
+ StepInfoManager().sot_step() + GraphLogger().clear() + paddle.framework.core.set_eval_frame(callback) + try: + outs = fn(*args, **kwargs) + except Exception as e: + raise e + finally: + paddle.framework.core.set_eval_frame(None) + + log_do(1, lambda: GraphLogger().print_info()) + return outs + + def impl_dynamic(*args: P.args, **kwargs: P.kwargs) -> R: + outs = fn(*args, **kwargs) + return outs + + def impl(*args: P.args, **kwargs: P.kwargs) -> R: + with StepInfoManager().step_guard(fn.__code__): + state = StepInfoManager().current_state + + if state == StepState.RUN_SOT: + return impl_sot(*args, **kwargs) + elif state == StepState.RUN_DYN: + return impl_dynamic(*args, **kwargs) + elif state == StepState.COLLECT_INFO: + return StepInfoManager().collect_info( + impl_dynamic, impl_sot, *args, **kwargs + ) + + return impl diff --git a/python/paddle/jit/sot/utils/__init__.py b/python/paddle/jit/sot/utils/__init__.py new file mode 100644 index 0000000000000..a1f26ea622772 --- /dev/null +++ b/python/paddle/jit/sot/utils/__init__.py @@ -0,0 +1,62 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .code_status import CodeStatus # noqa: F401 +from .exceptions import ( # noqa: F401 + BreakGraphError, + FallbackError, + InnerError, + inner_error_default_handler, +) +from .magic_methods import magic_method_builtin_dispatch # noqa: F401 +from .paddle_api_config import ( # noqa: F401 + is_break_graph_tensor_methods, + is_inplace_api, + paddle_tensor_methods, +) +from .utils import ( # noqa: F401 + Cache, + GraphLogger, + NameGenerator, + OrderedSet, + ResumeFnNameFactory, + Singleton, + SotUndefinedVar, + StepInfoManager, + StepState, + cost_model, + count_if, + current_tmp_name_records, + execute_time, + flatten_extend, + get_unbound_method, + hashable, + in_paddle_module, + is_break_graph_api, + is_builtin_fn, + is_clean_code, + is_paddle_api, + is_strict_mode, + list_contain_by_id, + list_find_index_by_id, + log, + log_do, + map_if, + map_if_extend, + meta_str, + min_graph_size, + no_eval_frame, + show_trackers, + tmp_name_guard, +) diff --git a/python/paddle/jit/sot/utils/code_status.py b/python/paddle/jit/sot/utils/code_status.py new file mode 100644 index 0000000000000..007e77f634004 --- /dev/null +++ b/python/paddle/jit/sot/utils/code_status.py @@ -0,0 +1,90 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
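+
+# Behaviour sketch (illustrative): a code object starts in the UNKNOW state;
+# after `is_code_without_graph` has been queried about it 10 times without
+# `trace_back_frames` marking it WITH_GRAPH, it flips to WITHOUT_GRAPH and
+# subsequent queries return True, so callers stop translating that code.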
+ +import inspect +from enum import Enum + +import paddle + +from .utils import Singleton, log + + +class CodeState(Enum): + UNKNOW = 1 + WITH_GRAPH = 2 + WITHOUT_GRAPH = 3 + + +class CodeInfo: + def __init__(self): + self.state = CodeState.UNKNOW + self.counter = 0 + + def __repr__(self): + return f"state: {self.state}, counter: {self.counter}" + + +@Singleton +class CodeStatus: + WITH_GRAPH_API = [ + paddle.nn.Layer.__call__.__code__, + paddle.nn.Layer._dygraph_call_func.__code__, + ] + + def __init__(self): + self.code_map = {} + self.setup_code_map() + + def setup_code_map(self): + for code in self.WITH_GRAPH_API: + info = CodeInfo() + info.state = CodeState.WITH_GRAPH + self.code_map[code] = info + + def clear(self): + self.code_map.clear() + self.setup_code_map() + + def is_code_without_graph(self, code): + if code not in self.code_map: + info = CodeInfo() + self.code_map[code] = info + else: + info = self.code_map[code] + + if info.state == CodeState.WITHOUT_GRAPH: + return True + if info.state == CodeState.UNKNOW: + info.counter += 1 + if info.counter >= 10: + log( + 3, + f"[CodeStatus] Switch state to WITHOUT_GRAPH for {code}\n", + ) + info.state = CodeState.WITHOUT_GRAPH + return False + + def trace_back_frames(self): + frame = inspect.currentframe() + while frame.f_back is not None: + frame = frame.f_back + code = frame.f_code + if code in self.code_map: + info = self.code_map[code] + if info.state != CodeState.WITH_GRAPH: + log( + 3, + f"[CodeStatus] Switch state to WITH_GRAPH for {code}\n", + ) + info.state = CodeState.WITH_GRAPH diff --git a/python/paddle/jit/sot/utils/exceptions.py b/python/paddle/jit/sot/utils/exceptions.py new file mode 100644 index 0000000000000..ff26f4ee2ba10 --- /dev/null +++ b/python/paddle/jit/sot/utils/exceptions.py @@ -0,0 +1,64 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import traceback + + +class SotErrorBase(Exception): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + from ..opcode_translator.breakpoint import BreakpointManager + + BreakpointManager().on_event(f"{self.__class__.__name__}") + + def print(self): + lines = traceback.format_tb(self.__traceback__) + print("".join(lines)) + + +class InnerError(SotErrorBase): + pass + + +class HasNoAttributeError(InnerError): + pass + + +class FallbackError(SotErrorBase): + def __init__(self, msg, disable_eval_frame=False): + super().__init__(msg) + self.disable_eval_frame = disable_eval_frame + + +# raise in inline function call strategy. 
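+# A BreakGraphError splits the traced graph: the offending code runs eagerly
+# and tracing resumes afterwards, while a FallbackError abandons translation
+# of the whole frame.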
diff --git a/python/paddle/jit/sot/utils/magic_methods.py b/python/paddle/jit/sot/utils/magic_methods.py
new file mode 100644
index 0000000000000..56b20abdb0541
--- /dev/null
+++ b/python/paddle/jit/sot/utils/magic_methods.py
@@ -0,0 +1,130 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import operator
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Callable
+
+from .utils import hashable
+
+if TYPE_CHECKING:
+    BinaryOp = Callable[[Any, Any], Any]
+    UnaryOp = Callable[[Any], Any]
+
+
+INPLACE_BINARY_OPS_TO_MAGIC_NAMES: dict[BinaryOp, tuple[str, BinaryOp]] = {
+    # inplace op fn: (magic name, non-inplace op fn)
+    operator.iadd: ("__iadd__", operator.add),
+    operator.iand: ("__iand__", operator.and_),
+    operator.iconcat: ("__iconcat__", operator.concat),
+    operator.ifloordiv: ("__ifloordiv__", operator.floordiv),
+    operator.ilshift: ("__ilshift__", operator.lshift),
+    operator.imatmul: ("__imatmul__", operator.matmul),
+    operator.imod: ("__imod__", operator.mod),
+    operator.imul: ("__imul__", operator.mul),
+    operator.ior: ("__ior__", operator.or_),
+    operator.ipow: ("__ipow__", operator.pow),
+    operator.irshift: ("__irshift__", operator.rshift),
+    operator.isub: ("__isub__", operator.sub),
+    operator.itruediv: ("__itruediv__", operator.truediv),
+    operator.ixor: ("__ixor__", operator.xor),
+}
+
+# NOTE: for a binary op, Python first tries the plain magic method and falls
+# back to the reverse one (e.g. `x - y` tries `x.__sub__(y)`, then
+# `y.__rsub__(x)`). Comparisons have no `__r*__` form; their "reverse" is the
+# mirrored comparison (`x > y` falls back to `y < x`), hence `operator.gt`
+# maps to ("__gt__", "__lt__") below.
+NON_INPLACE_BINARY_OPS_TO_MAGIC_NAMES: dict[
+    BinaryOp, tuple[str, str | None]
+] = {
+    # op fn: (magic name, reverse magic name)
+    operator.add: ("__add__", "__radd__"),
+    operator.and_: ("__and__", "__rand__"),
+    operator.contains: ("__contains__", None),
+    operator.delitem: ("__delitem__", None),
+    operator.eq: ("__eq__", "__eq__"),
+    operator.floordiv: ("__floordiv__", "__rfloordiv__"),
+    operator.ge: ("__ge__", "__le__"),
+    operator.getitem: ("__getitem__", None),
+    operator.gt: ("__gt__", "__lt__"),
+    operator.le: ("__le__", "__ge__"),
+    operator.lshift: ("__lshift__", "__rlshift__"),
+    operator.lt: ("__lt__", "__gt__"),
+    operator.matmul: ("__matmul__", "__rmatmul__"),
+    operator.mod: ("__mod__", "__rmod__"),
+    operator.mul: ("__mul__", "__rmul__"),
+    operator.ne: ("__ne__", "__ne__"),
+    operator.or_: ("__or__", "__ror__"),
+    operator.pow: ("__pow__", "__rpow__"),
+    operator.rshift: ("__rshift__", "__rrshift__"),
+    operator.sub: ("__sub__", "__rsub__"),
+    operator.truediv: ("__truediv__", "__rtruediv__"),
+    operator.xor: ("__xor__", "__rxor__"),
+}
+
+UNARY_OPS_TO_MAGIC_NAMES: dict[UnaryOp, str] = {
+    operator.neg: "__neg__",
+    # NOTE: operator.inv is an alias of operator.invert, so a single entry
+    # covers both.
+    operator.invert: "__invert__",
+    operator.pos: "__pos__",
+    operator.abs: "__abs__",
+    operator.index: "__index__",
+    operator.not_: "__not__",
+    operator.truth: "__bool__",
+    bool: "__bool__",
+    abs: "__abs__",
+    float: "__float__",
+    len: "__len__",
+    int: "__int__",
+}
+# TODO(SigureMo): support any, all, sum
+
+
+INPLACE_BINARY_OPS = set(INPLACE_BINARY_OPS_TO_MAGIC_NAMES.keys())
+NON_INPLACE_BINARY_OPS = set(NON_INPLACE_BINARY_OPS_TO_MAGIC_NAMES.keys())
+BINARY_OPS = INPLACE_BINARY_OPS | NON_INPLACE_BINARY_OPS
+UNARY_OPS = set(UNARY_OPS_TO_MAGIC_NAMES.keys())
+
+
+@dataclass
+class MagicMethod:
+    name: str
+    is_inplace: bool = False
+    is_reverse: bool = False
+
+
+def magic_method_builtin_dispatch(fn: BinaryOp | UnaryOp) -> list[MagicMethod]:
+    """Map an operator function to the magic methods it may invoke.
+
+    For example, ``operator.iadd`` expands to its inplace, non-inplace and
+    reverse variants: ``__iadd__`` (inplace), ``__add__`` and ``__radd__``
+    (reverse).
+    """
+    if not hashable(fn):
+        return []
+    if fn in INPLACE_BINARY_OPS:
+        inplace_magic_name, non_inplace_op = INPLACE_BINARY_OPS_TO_MAGIC_NAMES[
+            fn
+        ]
+        return [
+            MagicMethod(inplace_magic_name, is_inplace=True)
+        ] + magic_method_builtin_dispatch(non_inplace_op)
+    elif fn in NON_INPLACE_BINARY_OPS:
+        magic_name, reverse_magic_name = NON_INPLACE_BINARY_OPS_TO_MAGIC_NAMES[
+            fn
+        ]
+        magic_methods = [MagicMethod(magic_name)]
+        if reverse_magic_name is not None:
+            magic_methods.append(
+                MagicMethod(reverse_magic_name, is_reverse=True)
+            )
+        return magic_methods
+    elif fn in UNARY_OPS:
+        magic_name = UNARY_OPS_TO_MAGIC_NAMES[fn]
+        return [MagicMethod(magic_name)]
+    return []
diff --git a/python/paddle/jit/sot/utils/paddle_api_config.py b/python/paddle/jit/sot/utils/paddle_api_config.py
new file mode 100644
index 0000000000000..06852d186a76c
--- /dev/null
+++ b/python/paddle/jit/sot/utils/paddle_api_config.py
@@ -0,0 +1,102 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+
+import paddle
+
+
+def is_inplace_api(func):
+    inplace_apis = {paddle.static.setitem}
+    return func in inplace_apis
+
+
+def get_tensor_methods():
+    return [
+        member_name
+        for member_name, member in inspect.getmembers(paddle.static.Variable)
+        if inspect.isfunction(member)
+    ]
+
+
+def get_paddle_api():
+    modules = [
+        paddle,
+        paddle.nn.functional,
+        paddle.linalg,
+        paddle.signal,
+        paddle.fft,
+        paddle.vision.ops,
+    ]
+    special_paddle_apis = [paddle.tensor.fill_constant]
+    non_operator_related_apis = [
+        paddle.in_dynamic_mode,
+        paddle.save,
+        paddle.load,
+        paddle.get_cuda_rng_state,
+        paddle.set_rng_state,
+        paddle.set_cuda_rng_state,
+        paddle.get_rng_state,
+        paddle.set_default_dtype,
+        paddle.check_shape,
+        paddle.summary,
+        paddle.finfo,
+        paddle.iinfo,
+        paddle.enable_static,
+        paddle.disable_static,
+        paddle.is_grad_enabled,
+    ]
+    # TODO: users should not call static_apis directly, but we need to use
+    # them internally, so add static_apis here temporarily
+    static_apis = [paddle.static.setitem, paddle.static.accuracy]
+    paddle_api_list = []
+    for module in modules:
+        for fn_name in getattr(module, "__all__", []):
+            fn = getattr(module, fn_name)
+            if inspect.isfunction(fn):
+                paddle_api_list.append(fn)
+    return list(
+        set(special_paddle_apis)
+        | set(static_apis)
+        | (set(paddle_api_list) - set(non_operator_related_apis))
+    )
+
+
+paddle_tensor_methods = get_tensor_methods()
+paddle_api_list = get_paddle_api()
+
+# TODO(Aurelius84): It seems that we use it to judge 'in_paddle_module()'.
+# But what does 'is_paddle_module' really mean? Are all paddle.xx submodules
+# considered as paddle modules?
+paddle_api_module_prefix = {
+    "paddle.nn.functional",
+    "paddle.nn.layer.activation",
+}
+
+break_graph_set = set()
+
+
+break_graph_tensor_method = {
+    'register_hook',
+    'numpy',
+    'clear_gradient',
+    # TODO: Browse all possible functions and make prior judgments.
+}
+
+
+def is_break_graph_tensor_methods(method_name):
+    return method_name in break_graph_tensor_method
+
+
+def add_break_graph_apis(apis: list):
+    break_graph_set.update(apis)
diff --git a/python/paddle/jit/sot/utils/utils.py b/python/paddle/jit/sot/utils/utils.py
new file mode 100644
index 0000000000000..ad4ff3faaa4dc
--- /dev/null
+++ b/python/paddle/jit/sot/utils/utils.py
@@ -0,0 +1,730 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
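+
+# Several helpers below are controlled by environment variables; a typical
+# debugging invocation (values illustrative, `train.py` is a hypothetical
+# entry script) might look like:
+#
+#     COST_MODEL=False MIN_GRAPH_SIZE=0 SOT_LOG_LEVEL=3 python train.py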
+ +from __future__ import annotations + +import builtins +import inspect +import os +import time +import types +import weakref +from collections import OrderedDict +from contextlib import contextmanager +from enum import Enum +from typing import Any, Generic, Iterable, Iterator, TypeVar +from weakref import WeakValueDictionary + +import numpy as np + +import paddle +from paddle.framework import Program +from paddle.utils import flatten, map_structure + +from .paddle_api_config import ( + break_graph_set, + paddle_api_list, + paddle_api_module_prefix, +) + +T = TypeVar("T") + + +def cost_model(): + return os.environ.get("COST_MODEL", "True") == "True" + + +def min_graph_size(): + return int(os.environ.get("MIN_GRAPH_SIZE", 10)) + + +class Singleton(Generic[T]): + def __init__(self, cls: type[T]): + self._cls = cls + self._instance = {} + + def __call__(self) -> T: + if self._cls not in self._instance: + self._instance[self._cls] = self._cls() + return self._instance[self._cls] + + +class NameGenerator: + def __init__(self, prefix): + self.counter = 0 + self.prefix = prefix + + def next(self): + name = self.prefix + str(self.counter) + self.counter += 1 + return name + + def match_name(self, name: str) -> bool: + return name.startswith(self.prefix) + + +_tmp_name_records = None + + +class TmpNameRecords: + def __init__(self): + self.name_generator = NameGenerator(prefix="_sot_tmp_") + self.tmp_names_record = OrderedDict() + + def next_name(self): + return self.name_generator.next() + + def add_tmp_var(self, expr): + if expr in self.tmp_names_record: + return self.tmp_names_record[expr] + else: + tmp_name = self.next_name() + self.tmp_names_record[expr] = tmp_name + return tmp_name + + +@contextmanager +def tmp_name_guard(): + global _tmp_name_records + old = _tmp_name_records + _tmp_name_records = TmpNameRecords() + yield + _tmp_name_records = old + + +def current_tmp_name_records(): + global _tmp_name_records + return _tmp_name_records + + +@Singleton +class ResumeFnNameFactory: + def __init__(self) -> None: + self.gen = NameGenerator('resume_') + + def next(self): + name = self.gen.next() + return name + + +def log(level, *args): + cur_level = int(os.environ.get("SOT_LOG_LEVEL", "0")) + if level <= cur_level: + print(*args, end="") + + +def log_do(level, fn): + cur_level = int(os.environ.get("SOT_LOG_LEVEL", "0")) + if level <= cur_level: + fn() + + +def no_eval_frame(func): + def no_eval_frame_func(*args, **kwargs): + old_cb = paddle.framework.core.set_eval_frame(None) + try: + retval = func(*args, **kwargs) + except: + raise + finally: + paddle.framework.core.set_eval_frame(old_cb) + return retval + + return no_eval_frame_func + + +def is_paddle_api(func): + if isinstance(func, paddle.nn.Layer): # ignore all the classes + return False + if hasattr(func, "__self__"): # ignore all the methods + return False + if inspect.isclass( + func + ): # paddle.Tensor should not be wrapped, but how about other situations? 
+ return False + return in_paddle_module(func) or func in paddle_api_list + + +def is_builtin_fn(fn): + special_builtin_fns = [weakref.ref] + if fn in special_builtin_fns: + return True + if isinstance(fn, types.BuiltinFunctionType): + return True + for member_name, member in inspect.getmembers(builtins): + if member is fn and isinstance(member, type): + return True + return False + + +def in_paddle_module(func): + if hasattr(func, "__module__"): + module_str = func.__module__ + if module_str is None: + return False + log(5, "find paddle function with __module__: ", module_str, "\n") + if hasattr(func, "__name__"): + log( + 5, " with __name__ : ", func.__name__, "\n" + ) + log(5, " with results : ") + for prefix in paddle_api_module_prefix: + if module_str.startswith(prefix): + log(5, " True\n") + return True + log(5, " False\n") + return False + + +def is_break_graph_api(func): + return func in break_graph_set + + +def map_if(*structures, pred, true_fn, false_fn): + def replace(*args): + if pred(*args): + return true_fn(*args) + return false_fn(*args) + + return map_structure(replace, *structures) + + +def flatten_extend(structure): + for item in flatten(structure): + if isinstance(item, slice): + yield item.start + yield item.stop + yield item.step + else: + yield item + + +def map_if_extend(structure, pred, true_fn, false_fn): + """support extended structures like slice and SliceVariable""" + + def wrapped_pred(x): + if isinstance(x, slice): + return True + return pred(x) + + def wrapped_true_fn(x): + if isinstance(x, (slice)): + l = [x.start, x.stop, x.step] + l = map_if_extend(l, pred, true_fn, false_fn) + return slice(*l) + return true_fn(x) + + return map_if( + structure, pred=wrapped_pred, true_fn=wrapped_true_fn, false_fn=false_fn + ) + + +def count_if(*structures, pred): + def is_true(*args): + if pred(*args): + return 1 + return 0 + + return sum(flatten(map_structure(is_true, *structures))) + + +class Cache: + def __init__(self, weak=False): + if not weak: + self.cache = {} + else: + self.cache = WeakValueDictionary() + self.hit_num = 0 + + def __call__(self, *args, **kwargs): + cache_key = self.key_fn(*args, **kwargs) + if cache_key is None: + return self.value_fn(*args, **kwargs) + if cache_key in self.cache: + log(5, "cache hit: ", cache_key, "\n") + self.hit_num += 1 + return self.cache[cache_key] + value = self.value_fn(*args, **kwargs) + self.cache[cache_key] = value + return value + + def clear(self): + self.cache.clear() + self.hit_num = 0 + + def key_fn(self, *args, **kwargs): + raise NotImplementedError() + + def value_fn(self, *args, **kwargs): + raise NotImplementedError() + + +def execute_time(func): + def wrapper(*args, **kwargs): + start_time = time.time() + result = func(*args, **kwargs) + end_time = time.time() + execution_time = end_time - start_time + print("Execute time:", execution_time) + return result + + return wrapper + + +def meta_str(shape, dtype, stop_gradient): + return f"(shape: {shape}, dtype: {dtype}, stop_gradient: {stop_gradient})" + + +def is_strict_mode(): + return os.environ.get("STRICT_MODE", "0") == "1" + + +def show_trackers() -> str | None: + return os.environ.get("SHOW_TRACKERS", None) + + +def is_clean_code() -> bool: + return os.environ.get('CLEAN_CODE', "False") == "True" + + +def list_find_index_by_id(li: list[Any], item: Any) -> int: + return [id(it) for it in li].index(id(item)) + + +def list_contain_by_id(li: list[Any], item: Any) -> int: + return id(item) in [id(it) for it in li] + + +def get_unbound_method(obj, name): + # 
TODO(dev): Consider the case of patching methods to instances
+    return getattr(obj.__class__, name)
+
+
+@Singleton
+class GraphLogger:
+    graph_num: int
+    op_num: int
+    graphs: list[Program]
+    ops: list[list[paddle.base.framework.Operator]]
+
+    def __init__(self):
+        self.clear()
+
+    def clear(self):
+        self.graph_num = 0
+        self.op_num = 0
+        self.graphs = []
+        self.ops = []
+
+    def get_graph_num(self):
+        return self.graph_num
+
+    def get_op_num(self):
+        return self.op_num
+
+    def add_subgraph(self, program: Program):
+        self.graph_num += 1
+        self.graphs.append(program)
+
+        for block in program.blocks:
+            sub_op = []
+            for op in block.ops:
+                self.op_num += 1
+                sub_op.append(op)
+            self.ops.append(sub_op)
+
+    def add_subgraph_info(self, strs):
+        for i in range(len(self.graphs)):
+            strs.append(
+                "------------------------------------------------------"
+            )
+
+            strs.append(f"subgraph {i}, OpNum: {len(self.ops[i])}")
+            strs.append(f"{self.graphs[i]}")
+
+    def __str__(self):
+        strs = []
+        strs.append("---------------- PaddleSOT graph info ----------------")
+        strs.append(f"SubgraphNum: {self.get_graph_num()}")
+        strs.append(f"OpNum: {self.get_op_num()}")
+
+        # We can display every subgraph info
+        log_do(5, lambda: self.add_subgraph_info(strs))
+
+        strs.append("---------------- PaddleSOT graph info ----------------")
+        return "\n".join(strs)
+
+    def __repr__(self):
+        return self.__str__()
+
+    def print_info(self):
+        print(self)
+
+
+@Singleton
+class SotUndefinedVar:
+    pass
+
+
+def hashable(obj):
+    try:
+        hash(obj)
+        return True
+    except TypeError:
+        return False
+
+
+class OrderedSet(Generic[T]):
+    """
+    A set that preserves the order of insertion.
+    """
+
+    _data: dict[T, None]
+
+    def __init__(self, items: Iterable[T] | None = None):
+        """
+        Examples:
+            >>> s = OrderedSet([1, 2, 3])
+            >>> s
+            OrderedSet(1, 2, 3)
+            >>> s = OrderedSet()
+            >>> s
+            OrderedSet()
+        """
+        self._data = dict.fromkeys(items) if items is not None else {}
+
+    def __iter__(self) -> Iterator[T]:
+        """
+        Examples:
+            >>> s = OrderedSet([1, 2, 3])
+            >>> for item in s:
+            ...     print(item)
+            1
+            2
+            3
+        """
+        return iter(self._data)
+
+    def __or__(self, other: OrderedSet[T]) -> OrderedSet[T]:
+        """
+        Union two sets.
+
+        Args:
+            other: Another set to be unioned.
+
+        Returns:
+            The union of two sets.
+
+        Examples:
+            >>> s1 = OrderedSet([1, 2, 3])
+            >>> s2 = OrderedSet([2, 3, 4])
+            >>> s1 | s2
+            OrderedSet(1, 2, 3, 4)
+        """
+        return OrderedSet(list(self) + list(other))
+
+    def __ior__(self, other: OrderedSet[T]):
+        """
+        Union two sets in place.
+
+        Args:
+            other: Another set to be unioned.
+
+        Examples:
+            >>> s1 = OrderedSet([1, 2, 3])
+            >>> s2 = OrderedSet([2, 3, 4])
+            >>> s1 |= s2
+            >>> s1
+            OrderedSet(1, 2, 3, 4)
+        """
+        self._data.update(dict.fromkeys(other))
+        # NOTE: in-place operators must return self, otherwise `s1 |= s2`
+        # rebinds s1 to None.
+        return self
+
+    def __and__(self, other: OrderedSet[T]) -> OrderedSet[T]:
+        """
+        Intersect two sets.
+
+        Args:
+            other: Another set to be intersected.
+
+        Returns:
+            The intersection of two sets.
+
+        Examples:
+            >>> s1 = OrderedSet([1, 2, 3])
+            >>> s2 = OrderedSet([2, 3, 4])
+            >>> s1 & s2
+            OrderedSet(2, 3)
+        """
+        return OrderedSet([item for item in self if item in other])
+
+    def __iand__(self, other: OrderedSet[T]):
+        """
+        Intersect two sets in place.
+
+        Args:
+            other: Another set to be intersected.
+ + Examples: + >>> s1 = OrderedSet([1, 2, 3]) + >>> s2 = OrderedSet([2, 3, 4]) + >>> s1 &= s2 + >>> s1 + OrderedSet(2, 3) + """ + self._data = {item: None for item in self if item in other} + return self + + def __sub__(self, other: OrderedSet[T]) -> OrderedSet[T]: + """ + Subtract two sets. + + Args: + other: Another set to be subtracted. + + Returns: + The subtraction of two sets. + + Examples: + >>> s1 = OrderedSet([1, 2, 3]) + >>> s2 = OrderedSet([2, 3, 4]) + >>> s1 - s2 + OrderedSet(1) + """ + return OrderedSet([item for item in self if item not in other]) + + def __isub__(self, other: OrderedSet[T]): + """ + Subtract two sets in place. + + Args: + other: Another set to be subtracted. + + Examples: + >>> s1 = OrderedSet([1, 2, 3]) + >>> s2 = OrderedSet([2, 3, 4]) + >>> s1 -= s2 + >>> s1 + OrderedSet(1) + """ + self._data = {item: None for item in self if item not in other} + return self + + def add(self, item: T): + """ + Add an item to the set. + + Args: + item: The item to be added. + + Examples: + >>> s = OrderedSet([1, 2, 3]) + >>> s.add(4) + >>> s + OrderedSet(1, 2, 3, 4) + """ + self._data.setdefault(item) + + def remove(self, item: T): + """ + Remove an item from the set. + + Args: + item: The item to be removed. + + Examples: + >>> s = OrderedSet([1, 2, 3]) + >>> s.remove(2) + >>> s + OrderedSet(1, 3) + """ + del self._data[item] + + def __contains__(self, item: T) -> bool: + """ + Examples: + >>> s = OrderedSet([1, 2, 3]) + >>> 1 in s + True + >>> 4 in s + False + """ + return item in self._data + + def __len__(self) -> int: + """ + Examples: + >>> s = OrderedSet([1, 2, 3]) + >>> len(s) + 3 + """ + return len(self._data) + + def __bool__(self) -> bool: + """ + Examples: + >>> s = OrderedSet([1, 2, 3]) + >>> bool(s) + True + >>> s = OrderedSet() + >>> bool(s) + False + """ + return bool(self._data) + + def __eq__(self, other: object) -> bool: + """ + Examples: + >>> s1 = OrderedSet([1, 2, 3]) + >>> s2 = OrderedSet([1, 2, 3]) + >>> s1 == s2 + True + >>> s3 = OrderedSet([3, 2, 1]) + >>> s1 == s3 + False + """ + if not isinstance(other, OrderedSet): + return NotImplemented + return list(self) == list(other) + + def __repr__(self) -> str: + data_repr = ", ".join(map(repr, self._data)) + return f"OrderedSet({data_repr})" + + +class StepState(Enum): + COLLECT_INFO = 1 + RUN_SOT = 2 + RUN_DYN = 3 + + +class StepInfo: + REQUIRED_DYN_INFOS = 10 + REQUIRED_SOT_INFOS = 10 + + USED_DYN_INFOS = 5 + + COLLECT_INFO_MAX_STEP = 50 + CV_BOUNDARY = 0.1 + + BACK_TRACE_STEPS = 20 + + def __init__(self): + self.step_count = -1 + self.state = ( + StepState.COLLECT_INFO if cost_model() else StepState.RUN_SOT + ) + self.dyn_time_costs = [] + self.avg_dyn_time = 0 + self.sot_time_costs = [] + self.sot_step = -1 + + def add_dynamic_time_info(self, time_cost): + self.dyn_time_costs.append(time_cost) + if len(self.dyn_time_costs) == self.REQUIRED_DYN_INFOS: + self.avg_dyn_time = np.mean( + self.dyn_time_costs[-self.USED_DYN_INFOS :] + ) + + def add_sot_time_info(self, time_cost, current_code): + self.sot_time_costs.append(time_cost) + if len(self.sot_time_costs) == self.REQUIRED_SOT_INFOS: + avg_sot_time = np.mean(self.sot_time_costs) + log( + 1, + f"[Cost Model] sot: {avg_sot_time}, dyn: {self.avg_dyn_time}\n", + ) + if avg_sot_time < self.avg_dyn_time: + log(1, f"[Cost Model] Switch to RUN_SOT: {current_code} \n") + self.state = StepState.RUN_SOT + elif ( + self.step_count > self.COLLECT_INFO_MAX_STEP + or np.std(self.sot_time_costs) / avg_sot_time < self.CV_BOUNDARY + ): + log(1, f"[Cost Model] Switch 
to RUN_DYN: {current_code}\n")
+                self.state = StepState.RUN_DYN
+            else:
+                log(1, f"[Cost Model] Decision delayed: {current_code}\n")
+                self.sot_time_costs.clear()
+
+    def need_back_trace(self):
+        return self.step_count < self.BACK_TRACE_STEPS
+
+    def need_dynamic_info(self):
+        return len(self.dyn_time_costs) < self.REQUIRED_DYN_INFOS
+
+
+@Singleton
+class StepInfoManager:
+    def __init__(self):
+        self.step_record = {}
+        self.current_code = None
+        self.current_step_info = None
+
+    @contextmanager
+    def step_guard(self, code):
+        try:
+            old_code = self.current_code
+            old_info = self.current_step_info
+
+            self.current_code = code
+            if code not in self.step_record:
+                self.step_record[code] = StepInfo()
+            self.current_step_info = self.step_record[code]
+
+            self.current_step_info.step_count += 1
+
+            log(
+                2,
+                f"[Cost Model] New step start, current state is {self.current_state}\n",
+            )
+            yield
+        finally:
+            self.current_code = old_code
+            self.current_step_info = old_info
+
+    def sot_step(self):
+        self.current_step_info.sot_step += 1
+
+    def collect_info(self, impl_dynamic, impl_sot, /, *args, **kwargs):
+        if self.current_step_info.need_dynamic_info():
+            start_time = time.perf_counter()
+            outs = impl_dynamic(*args, **kwargs)
+            time_cost = time.perf_counter() - start_time
+            self.current_step_info.add_dynamic_time_info(time_cost)
+        else:
+            start_time = time.perf_counter()
+            outs = impl_sot(*args, **kwargs)
+            time_cost = time.perf_counter() - start_time
+            self.current_step_info.add_sot_time_info(
+                time_cost, self.current_code
+            )
+        return outs
+
+    @property
+    def need_back_trace(self):
+        return self.current_step_info.need_back_trace()
+
+    @property
+    def current_step(self):
+        return self.current_step_info.step_count
+
+    @property
+    def current_state(self):
+        return self.current_step_info.state
+
+    def clear(self):
+        self.step_record.clear()
+        self.current_code = None
+        # NOTE: `current_step` is a read-only property; reset the underlying
+        # step info instead of assigning to it.
+        self.current_step_info = None
diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py
index 108bfa5762804..d8dfcfb80bf8d 100644
--- a/python/paddle/tensor/creation.py
+++ b/python/paddle/tensor/creation.py
@@ -692,6 +692,9 @@ def _to_tensor_static(data, dtype=None, stop_gradient=None):
     # fix numpy default dtype
     if data.dtype in ['float16', 'float32', 'float64']:
         data = data.astype(paddle.get_default_dtype())
+    # Windows default type is 'int32', while Linux/Mac is 'int64'. Unify them.
+ elif data.dtype in ['int32']: + data = data.astype("int64") if dtype: target_dtype = dtype diff --git a/python/setup.py.in b/python/setup.py.in index 4f2bce4bfbaad..10cbd7d54a86d 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -430,6 +430,13 @@ packages=['paddle', 'paddle.jit', 'paddle.jit.dy2static', 'paddle.jit.newir_dy2static', + 'paddle.jit.sot', + 'paddle.jit.sot.opcode_translator', + 'paddle.jit.sot.opcode_translator.executor', + 'paddle.jit.sot.opcode_translator.executor.variables', + 'paddle.jit.sot.opcode_translator.instruction_utils', + 'paddle.jit.sot.symbolic', + 'paddle.jit.sot.utils', 'paddle.inference', 'paddle.inference.contrib', 'paddle.inference.contrib.utils', diff --git a/setup.py b/setup.py index 221e0a0770e06..e12d676cb8a5f 100644 --- a/setup.py +++ b/setup.py @@ -1425,6 +1425,13 @@ def get_setup_parameters(): 'paddle.jit', 'paddle.jit.dy2static', 'paddle.jit.newir_dy2static', + 'paddle.jit.sot', + 'paddle.jit.sot.opcode_translator', + 'paddle.jit.sot.opcode_translator.executor', + 'paddle.jit.sot.opcode_translator.executor.variables', + 'paddle.jit.sot.opcode_translator.instruction_utils', + 'paddle.jit.sot.symbolic', + 'paddle.jit.sot.utils', 'paddle.inference', 'paddle.inference.contrib', 'paddle.inference.contrib.utils', diff --git a/test/dygraph_to_static/CMakeLists.txt b/test/dygraph_to_static/CMakeLists.txt index 4231938cf1ee6..1beadd642a66e 100644 --- a/test/dygraph_to_static/CMakeLists.txt +++ b/test/dygraph_to_static/CMakeLists.txt @@ -3,34 +3,9 @@ file( RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") +set(SOT_ENVS SOT_LOG_LEVEL=0 COST_MODEL=False MIN_GRAPH_SIZE=0 STRICT_MODE=0) set(GC_ENVS FLAGS_eager_delete_tensor_gb=0.0) -set(DY2ST_EAGER_TEST_ENVS ${GC_ENVS}) -set(TEST_EAGER_OPS - test_bmn - test_break_continue - test_ifelse - test_loop - test_mnist_amp - test_mnist_pure_fp16 - test_mobile_net - test_program_translator - test_ptb_lm - test_reinforcement_learning - test_resnet - test_resnet_amp - test_resnet_pure_fp16 - test_se_resnet - test_sentiment - test_seq2seq - test_tsm - test_word2vec - test_yolov3 - test_bert - test_cycle_gan - test_lstm - test_simnet - test_transformer) list(REMOVE_ITEM TEST_OPS test_lac) # NOTE(Aurelius84): In case of Windows CI, if open ON_INFER, RWLOCK of Scope # will be removed and will cause some random failed in multi-thread. 
@@ -52,12 +27,7 @@ if(NOT WITH_GPU) endif() foreach(TEST_OP ${TEST_OPS}) - list(FIND TEST_EAGER_OPS ${TEST_OP} WAS_FOUND) - if(NOT WAS_FOUND EQUAL -1) - py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${DY2ST_EAGER_TEST_ENVS}) - else() - py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS}) - endif() + py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS} ${SOT_ENVS}) endforeach() set_tests_properties(test_se_resnet PROPERTIES TIMEOUT 900) @@ -67,10 +37,11 @@ set_tests_properties(test_mobile_net PROPERTIES TIMEOUT 120) set_tests_properties(test_seq2seq PROPERTIES TIMEOUT 150) set_tests_properties(test_cycle_gan PROPERTIES TIMEOUT 150) set_tests_properties(test_bert PROPERTIES TIMEOUT 180) -set_tests_properties(test_basic_api_transformation PROPERTIES TIMEOUT 120) +set_tests_properties(test_basic_api_transformation PROPERTIES TIMEOUT 240) set_tests_properties(test_reinforcement_learning PROPERTIES TIMEOUT 120) set_tests_properties(test_transformer PROPERTIES TIMEOUT 200) -set_tests_properties(test_bmn PROPERTIES TIMEOUT 120) +set_tests_properties(test_bmn PROPERTIES TIMEOUT 300) +set_tests_properties(test_bert PROPERTIES TIMEOUT 240) #set_tests_properties(test_mnist PROPERTIES TIMEOUT 120) set_tests_properties(test_build_strategy PROPERTIES TIMEOUT 120) diff --git a/test/dygraph_to_static/dygraph_to_static_util.py b/test/dygraph_to_static/dygraph_to_static_util.py index 3202621228710..9a5b9bf22d92a 100644 --- a/test/dygraph_to_static/dygraph_to_static_util.py +++ b/test/dygraph_to_static/dygraph_to_static_util.py @@ -49,7 +49,8 @@ def to_sot(func): """ convert run fall_back to ast """ - enable_sot = os.environ.get("ENABLE_SOT", "False") == "True" + # TODO(SigureMo): ENABLE_SOT should always be True, remove this + enable_sot = os.environ.get("ENABLE_SOT", "True") == "True" def impl(*args, **kwargs): if enable_sot: diff --git a/test/dygraph_to_static/dygraph_to_static_utils_new.py b/test/dygraph_to_static/dygraph_to_static_utils_new.py new file mode 100644 index 0000000000000..5e0ebacd8e1e3 --- /dev/null +++ b/test/dygraph_to_static/dygraph_to_static_utils_new.py @@ -0,0 +1,320 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
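+
+# NOTE: The SOT / legacy AST test modes below are toggled through the
+# ENABLE_FALL_BACK environment variable ("True" runs SOT, "False" runs the
+# legacy AST transform); see `enable_fallback_guard`. An illustrative manual
+# run of a single test file:
+#
+#     ENABLE_FALL_BACK=True python test_mnist.py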
+ +import contextlib +import inspect +import logging +import os +import unittest +from enum import Flag, auto +from functools import wraps + +import numpy as np + +from paddle import set_flags, static +from paddle.base import core + +""" +# Usage: +class MyTest(Dy2StTestBase): + @set_to_static_mode( + ToStaticMode.LEGACY_AST | ToStaticMode.SOT | ToStaticMode.PIR_AST + ) + @set_ir_mode(IrMode.LEGACY_PROGRAM | IrMode.PIR) + def test_case1(self): + raise ValueError("MyTest 1") + + def test_case2(self): + raise ValueError("MyTest 2") + + +class MyTest2(MyTest): + def test_case1(self): + raise ValueError("MyTest2 1") +""" + +logger = logging.getLogger("Dygraph to static utils") +logger.setLevel(logging.WARNING) + + +class ToStaticMode(Flag): + LEGACY_AST = auto() + PIR_AST = auto() + SOT = auto() + + def lower_case_name(self): + return self.name.lower() + + +class IrMode(Flag): + LEGACY_PROGRAM = auto() + PIR = auto() + + def lower_case_name(self): + return self.name.lower() + + +DEFAULT_TO_STATIC_MODE = ToStaticMode.LEGACY_AST | ToStaticMode.SOT +DEFAULT_IR_MODE = IrMode.LEGACY_PROGRAM + + +def in_sot_mode(): + return os.getenv("ENABLE_FALL_BACK", "False") == "True" + + +@contextlib.contextmanager +def enable_fallback_guard(enable): + flag = os.environ.get("ENABLE_FALL_BACK", None) + os.environ["ENABLE_FALL_BACK"] = enable + yield + if flag is not None: + os.environ["ENABLE_FALL_BACK"] = flag + else: + del os.environ["ENABLE_FALL_BACK"] + + +def to_legacy_ast_test(fn): + """ + convert run fall_back to ast + """ + + @wraps(fn) + def impl(*args, **kwargs): + logger.info("[AST] running AST") + with enable_fallback_guard("False"): + fn(*args, **kwargs) + + return impl + + +def to_sot_test(fn): + """ + convert run fall_back to ast + """ + + @wraps(fn) + def impl(*args, **kwargs): + logger.info("[SOT] running SOT") + with enable_fallback_guard("True"): + fn(*args, **kwargs) + + return impl + + +def to_pir_ast_test(fn): + raise TypeError("Don't enable PIR AST mode now!") + + +def to_legacy_program_test(fn): + def impl(*args, **kwargs): + logger.info("[Program] running legacy program") + return fn(*args, **kwargs) + + return impl + + +def to_pir_test(fn): + @wraps(fn) + def impl(*args, **kwargs): + logger.info("[PIR] running pir") + ir_outs = None + if os.environ.get('FLAGS_use_stride_kernel', False): + return + with static.scope_guard(static.Scope()): + with static.program_guard(static.Program()): + try: + new_ir_flag = 'FLAGS_enable_new_ir_in_executor' + os.environ[new_ir_flag] = 'True' + set_flags({new_ir_flag: True}) + ir_outs = fn(*args, **kwargs) + finally: + del os.environ[new_ir_flag] + set_flags({new_ir_flag: False}) + return ir_outs + + return impl + + +# Metaclass and BaseClass +class Dy2StTestMeta(type): + TO_STATIC_HANDLER_MAP = { + ToStaticMode.SOT: to_sot_test, + ToStaticMode.LEGACY_AST: to_legacy_ast_test, + ToStaticMode.PIR_AST: to_pir_ast_test, + } + + IR_HANDLER_MAP = { + IrMode.LEGACY_PROGRAM: to_legacy_program_test, + IrMode.PIR: to_pir_test, + } + + def __new__(cls, name, bases, attrs): + new_attrs = {} + original_test_cases = { + key: value + for key, value in attrs.items() + if key.startswith("test") and inspect.isfunction(value) + } + logger.info(f"[creating {name}]") + new_attrs.update( + { + key: value + for key, value in attrs.items() + if key not in original_test_cases + } + ) + for fn_name, fn in original_test_cases.items(): + logger.info(f"Generating {fn_name}") + # Disable inherited test cases + for base in bases: + for attr in dir(base): + if attr.startswith(fn_name): 
+                        new_attrs[attr] = None
+            fn_to_static_modes = getattr(
+                fn, "to_static_mode", DEFAULT_TO_STATIC_MODE
+            )
+            fn_ir_modes = getattr(fn, "ir_mode", DEFAULT_IR_MODE)
+            fn_disabled_test_cases = getattr(fn, "disabled_test_cases", [])
+            logger.info(f"fn_to_static_modes: {fn_to_static_modes}")
+            logger.info(f"fn_ir_modes: {fn_ir_modes}")
+            logger.info(f"fn_disabled_test_cases: {fn_disabled_test_cases}")
+            # Get all valid test cases with to_static_mode and ir_mode
+            to_static_with_ir_modes = [
+                (to_static_mode, ir_mode)
+                for to_static_mode in ToStaticMode
+                for ir_mode in IrMode
+                if to_static_mode & fn_to_static_modes and ir_mode & fn_ir_modes
+            ]
+            # Filter out disabled test cases and test cases already in compare groups
+            to_static_with_ir_modes = list(
+                filter(
+                    lambda flags: (flags not in fn_disabled_test_cases),
+                    to_static_with_ir_modes,
+                )
+            )
+            # Generate all test cases
+            for to_static_mode, ir_mode in to_static_with_ir_modes:
+                if (
+                    to_static_mode == ToStaticMode.PIR_AST
+                    and ir_mode == IrMode.LEGACY_PROGRAM
+                ):
+                    # PIR_AST with LEGACY_PROGRAM is not a valid combination
+                    continue
+                new_attrs[
+                    Dy2StTestMeta.test_case_name(
+                        fn_name, to_static_mode, ir_mode
+                    )
+                ] = Dy2StTestMeta.convert_test_case(fn, to_static_mode, ir_mode)
+        return type.__new__(cls, name, bases, new_attrs)
+
+    @staticmethod
+    def test_case_name(original_name: str, to_static_mode, ir_mode):
+        return f"{original_name}__{to_static_mode.lower_case_name()}_{ir_mode.lower_case_name()}"
+
+    @staticmethod
+    def convert_test_case(fn, to_static_mode, ir_mode):
+        fn = Dy2StTestMeta.IR_HANDLER_MAP[ir_mode](fn)
+        fn = Dy2StTestMeta.TO_STATIC_HANDLER_MAP[to_static_mode](fn)
+        return fn
+
+
+class Dy2StTestBase(unittest.TestCase, metaclass=Dy2StTestMeta):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+
+# Base decorators
+def set_to_static_mode(mode: ToStaticMode):
+    def decorator(fn):
+        fn.to_static_mode = mode
+        return fn
+
+    return decorator
+
+
+def set_ir_mode(mode: IrMode):
+    def decorator(fn):
+        fn.ir_mode = mode
+        return fn
+
+    return decorator
+
+
+def disable_test_case(flags):
+    def decorator(fn):
+        disabled_test_cases = getattr(fn, "disabled_test_cases", [])
+        disabled_test_cases.append(flags)
+        fn.disabled_test_cases = disabled_test_cases
+        return fn
+
+    return decorator
+
+
+# Sugar decorators
+# These decorators can be simply composed from the base decorators
+def ast_only_test(fn):
+    fn = set_to_static_mode(ToStaticMode.LEGACY_AST)(fn)
+    return fn
+
+
+def sot_only_test(fn):
+    fn = set_to_static_mode(ToStaticMode.SOT)(fn)
+    return fn
+
+
+def test_with_new_ir(fn):
+    fn = set_ir_mode(IrMode.PIR)(fn)
+    return fn
+
+
+def _test_and_compare_with_new_ir(fn):
+    @wraps(fn)
+    def impl(*args, **kwargs):
+        outs = fn(*args, **kwargs)
+        if core._is_bwd_prim_enabled() or core._is_fwd_prim_enabled():
+            return outs
+        # Disable SOT + PIR test temporarily
+        if in_sot_mode():
+            return outs
+        ir_outs = to_pir_test(fn)(*args, **kwargs)
+        np.testing.assert_equal(
+            outs,
+            ir_outs,
+            err_msg=f'Dy2St Unittest Check ({fn.__name__}) has diff \n'
+            + f'Expect {outs}\n'
+            + f'But Got {ir_outs}',
+        )
+        return outs
+
+    return impl
+
+
+def test_and_compare_with_new_ir(need_check_output: bool = True):
+    def decorator(fn):
+        fn = set_ir_mode(IrMode.LEGACY_PROGRAM | IrMode.PIR)(fn)
+        if need_check_output:
+            logger.info(f"[need_check_output] {fn.__name__}")
+            fn = _test_and_compare_with_new_ir(fn)
+        return fn
+
+    return decorator
+
+
+# For debug
+def show_all_test_cases(test_class):
+    logger.info(f"[showing 
{test_class.__name__}]") + for attr in dir(test_class): + if attr.startswith("test"): + fn = getattr(test_class, attr) + logger.info(f"{attr}: {fn}") diff --git a/test/dygraph_to_static/test_assert.py b/test/dygraph_to_static/test_assert.py index dc01413d0c8be..210e904454fd9 100644 --- a/test/dygraph_to_static/test_assert.py +++ b/test/dygraph_to_static/test_assert.py @@ -15,7 +15,11 @@ import unittest import numpy -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + ast_only_test, + test_and_compare_with_new_ir, +) import paddle from paddle import base @@ -33,7 +37,8 @@ def dyfunc_assert_non_variable(x=True): assert x -class TestAssertVariable(unittest.TestCase): +# @dy2static_unittest +class TestAssertVariable(Dy2StTestBase): def _run(self, func, x, with_exception, to_static): paddle.jit.enable_to_static(to_static) if with_exception: @@ -49,6 +54,7 @@ def _run_dy_static(self, func, x, with_exception): self._run(func, x, with_exception, False) @test_and_compare_with_new_ir(False) + @ast_only_test def test_non_variable(self): self._run_dy_static( dyfunc_assert_non_variable, x=False, with_exception=True @@ -58,6 +64,7 @@ def test_non_variable(self): ) @test_and_compare_with_new_ir(False) + @ast_only_test def test_bool_variable(self): self._run_dy_static( dyfunc_assert_variable, x=numpy.array([False]), with_exception=True @@ -67,6 +74,7 @@ def test_bool_variable(self): ) @test_and_compare_with_new_ir(False) + @ast_only_test def test_int_variable(self): self._run_dy_static( dyfunc_assert_variable, x=numpy.array([0]), with_exception=True diff --git a/test/dygraph_to_static/test_ast_util.py b/test/dygraph_to_static/test_ast_util.py index 52920d81433c6..c2468765e3438 100644 --- a/test/dygraph_to_static/test_ast_util.py +++ b/test/dygraph_to_static/test_ast_util.py @@ -17,7 +17,11 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + ast_only_test, + test_and_compare_with_new_ir, +) from ifelse_simple_func import ( dyfunc_with_if_else, dyfunc_with_if_else2, @@ -31,7 +35,8 @@ from paddle.utils import gast -class TestAST2Func(unittest.TestCase): +# @dy2static_unittest +class TestAST2Func(Dy2StTestBase): """ TestCase for the transformation from ast.AST into python callable function. 
""" @@ -43,6 +48,7 @@ def _ast2func(self, func): transformed_func, _ = ast_to_func(ast_root, func) return transformed_func + @ast_only_test def test_ast2func(self): def func(x, y): return x + y @@ -50,6 +56,7 @@ def func(x, y): x, y = 10, 20 self.assertEqual(func(x, y), self._ast2func(func)(x, y)) + @ast_only_test def test_ast2func_dygraph(self): paddle.disable_static() funcs = [dyfunc_with_if_else, dyfunc_with_if_else2, nested_if_else] @@ -62,6 +69,7 @@ def test_ast2func_dygraph(self): self.assertTrue((true_ret == test_ret).all()) @test_and_compare_with_new_ir(False) + @ast_only_test def test_ast2func_static(self): paddle.enable_static() @@ -80,6 +88,7 @@ def func(x): ret = exe.run(main_program, fetch_list=[true_ret, test_ret]) self.assertTrue((ret[0] == ret[1]).all()) + @ast_only_test def test_ast2func_error(self): with self.assertRaises(Exception) as e: self.assertRaises(TypeError, ast_to_func("x = a + b", 'foo')) diff --git a/test/dygraph_to_static/test_backward_without_params.py b/test/dygraph_to_static/test_backward_without_params.py index af70b9e7a2f95..e233259dc514e 100644 --- a/test/dygraph_to_static/test_backward_without_params.py +++ b/test/dygraph_to_static/test_backward_without_params.py @@ -15,7 +15,13 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + IrMode, + ToStaticMode, + disable_test_case, + test_and_compare_with_new_ir, +) import paddle @@ -24,16 +30,17 @@ class Net(paddle.nn.Layer): def __init__(self): super().__init__() - @paddle.jit.to_static def forward(self, x): out = x + 1 return out -class TestBackwardWithoutParams(unittest.TestCase): +# @dy2static_unittest +class TestBackwardWithoutParams(Dy2StTestBase): @test_and_compare_with_new_ir(False) + @disable_test_case((ToStaticMode.SOT, IrMode.PIR)) def test_run(self): - net = Net() + net = paddle.jit.to_static(Net()) x = paddle.ones([2, 2]) x.stop_gradient = False @@ -47,7 +54,6 @@ class ZeroSizeNet(paddle.nn.Layer): def __init__(self): super().__init__() - @paddle.jit.to_static def forward(self, x): y = paddle.randn((0,)) out = paddle.nn.functional.relu(x) @@ -55,10 +61,12 @@ def forward(self, x): return y, out -class TestZeroSizeNet(unittest.TestCase): +# @dy2static_unittest +class TestZeroSizeNet(Dy2StTestBase): @test_and_compare_with_new_ir(False) + @disable_test_case((ToStaticMode.SOT, IrMode.PIR)) def test_run(self): - net = ZeroSizeNet() + net = paddle.jit.to_static(ZeroSizeNet()) x = paddle.ones([2, 2]) x.stop_gradient = False _, out = net(x) diff --git a/test/dygraph_to_static/test_basic_api_transformation.py b/test/dygraph_to_static/test_basic_api_transformation.py index efa9caa17dd51..e0998b8fe1e67 100644 --- a/test/dygraph_to_static/test_basic_api_transformation.py +++ b/test/dygraph_to_static/test_basic_api_transformation.py @@ -16,7 +16,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle from paddle import base, to_tensor @@ -69,6 +72,7 @@ def dyfunc_bool_to_tensor(x): return paddle.to_tensor(True) +@dy2static_unittest class TestDygraphBasicApi_ToVariable(unittest.TestCase): def setUp(self): self.input = np.ones(5).astype("int32") @@ -230,6 +234,7 @@ def dyfunc_Prelu(input): return res +@dy2static_unittest class TestDygraphBasicApi(unittest.TestCase): # Compare results of dynamic graph and transformed static graph function which 
only # includes basic Api. @@ -396,6 +401,7 @@ def dyfunc_PolynomialDecay(): return paddle.to_tensor(lr) +@dy2static_unittest class TestDygraphBasicApi_CosineDecay(unittest.TestCase): def setUp(self): self.dygraph_func = dyfunc_CosineDecay @@ -539,6 +545,7 @@ def _dygraph_fn(): np.random.random(1) +@dy2static_unittest class TestDygraphApiRecognition(unittest.TestCase): def setUp(self): self.src = inspect.getsource(_dygraph_fn) diff --git a/test/dygraph_to_static/test_bert.py b/test/dygraph_to_static/test_bert.py index c7b5272ff4765..ba8e2350794aa 100644 --- a/test/dygraph_to_static/test_bert.py +++ b/test/dygraph_to_static/test_bert.py @@ -20,7 +20,11 @@ import numpy as np from bert_dygraph_model import PretrainModelLayer from bert_utils import get_bert_config, get_feed_data_reader -from dygraph_to_static_util import ast_only_test, test_with_new_ir +from dygraph_to_static_util import ( + ast_only_test, + dy2static_unittest, + test_with_new_ir, +) from predictor_utils import PredictorTools import paddle @@ -74,6 +78,7 @@ def __len__(self): return len(self.src_ids) +@dy2static_unittest class TestBert(unittest.TestCase): def setUp(self): self.bert_config = get_bert_config() diff --git a/test/dygraph_to_static/test_break_continue.py b/test/dygraph_to_static/test_break_continue.py index d3a2162dc787e..a803c1d4bf49e 100644 --- a/test/dygraph_to_static/test_break_continue.py +++ b/test/dygraph_to_static/test_break_continue.py @@ -205,6 +205,7 @@ def test_optim_break_in_while(x): return x +@dy2static_unittest class TestContinueInFor(unittest.TestCase): def setUp(self): self.input = np.zeros(1).astype('int64') diff --git a/test/dygraph_to_static/test_build_strategy.py b/test/dygraph_to_static/test_build_strategy.py index 83ed8d56751dd..85e934afb020b 100644 --- a/test/dygraph_to_static/test_build_strategy.py +++ b/test/dygraph_to_static/test_build_strategy.py @@ -84,6 +84,7 @@ def test_in_static_mode_mkldnn(self): paddle.base.set_flags({'FLAGS_use_mkldnn': False}) +@dy2static_unittest class TestError(unittest.TestCase): def test_type_error(self): def foo(x): diff --git a/test/dygraph_to_static/test_cache_program.py b/test/dygraph_to_static/test_cache_program.py index 0602b15b3054b..199c3e980e20c 100644 --- a/test/dygraph_to_static/test_cache_program.py +++ b/test/dygraph_to_static/test_cache_program.py @@ -76,6 +76,7 @@ def setUp(self): self.data = np.random.random((4, 10)).astype('float32') +@dy2static_unittest class TestCacheProgramWithOptimizer(unittest.TestCase): def setUp(self): self.dygraph_class = Linear @@ -125,6 +126,7 @@ def simple_func(x): return mean +@dy2static_unittest class TestConvertWithCache(unittest.TestCase): def test_cache(self): static_func = convert_to_static(simple_func) @@ -155,6 +157,7 @@ def sum_under_while(limit): return ret_sum +@dy2static_unittest class TestToOutputWithCache(unittest.TestCase): def test_output(self): with base.dygraph.guard(): diff --git a/test/dygraph_to_static/test_cast.py b/test/dygraph_to_static/test_cast.py index 7e2b0914a5fff..a01f2712cc764 100644 --- a/test/dygraph_to_static/test_cast.py +++ b/test/dygraph_to_static/test_cast.py @@ -15,7 +15,11 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test, test_and_compare_with_new_ir +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + ast_only_test, + test_and_compare_with_new_ir, +) from paddle import base from paddle.jit.api import to_static @@ -60,7 +64,8 @@ def test_mix_cast(x): return x -class TestCastBase(unittest.TestCase): +# @dy2static_unittest +class 
TestCastBase(Dy2StTestBase): def setUp(self): self.place = ( base.CUDAPlace(0) @@ -90,6 +95,7 @@ def do_test(self): @ast_only_test # TODO: add new symbolic only test. @test_and_compare_with_new_ir(False) + # @set_to_static_mode(ToStaticMode.LEGACY_AST) def test_cast_result(self): res = self.do_test().numpy() self.assertTrue( @@ -186,9 +192,11 @@ def prepare(self): def set_func(self): self.func = test_not_var_cast - @ast_only_test # TODO: add new symbolic only test. + @ast_only_test @test_and_compare_with_new_ir(False) def test_cast_result(self): + # breakpoint() + # print("run once!!!") res = self.do_test() self.assertTrue(type(res) == int, msg='The casted dtype is not int.') ref_val = int(self.input) diff --git a/test/dygraph_to_static/test_cinn.py b/test/dygraph_to_static/test_cinn.py index 59a114d0aae58..84e619149c800 100644 --- a/test/dygraph_to_static/test_cinn.py +++ b/test/dygraph_to_static/test_cinn.py @@ -15,7 +15,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle @@ -42,6 +45,7 @@ def apply_to_static(net, use_cinn): return paddle.jit.to_static(net, build_strategy=build_strategy) +@dy2static_unittest class TestCINN(unittest.TestCase): def setUp(self): self.x = paddle.randn([2, 4]) diff --git a/test/dygraph_to_static/test_cinn_prim.py b/test/dygraph_to_static/test_cinn_prim.py index 0bf905ec846f9..2ed5326f7b9d0 100644 --- a/test/dygraph_to_static/test_cinn_prim.py +++ b/test/dygraph_to_static/test_cinn_prim.py @@ -172,6 +172,7 @@ def test_cinn_prim(self): ) +@dy2static_unittest class TestBackend(unittest.TestCase): @test_and_compare_with_new_ir(False) def test_backend(self): diff --git a/test/dygraph_to_static/test_cinn_prim_layer_norm.py b/test/dygraph_to_static/test_cinn_prim_layer_norm.py index 18c48883d75a6..42bf36d731eca 100644 --- a/test/dygraph_to_static/test_cinn_prim_layer_norm.py +++ b/test/dygraph_to_static/test_cinn_prim_layer_norm.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test +from dygraph_to_static_util import ast_only_test, dy2static_unittest import paddle import paddle.nn.functional as F @@ -52,6 +52,7 @@ def forward(self, x, w, b): return out[0] +@dy2static_unittest class TestPrimForward(unittest.TestCase): """ This case only tests prim_forward + to_static + cinn. Thus we need to @@ -124,6 +125,7 @@ def test_cinn_prim_forward(self): ) +@dy2static_unittest class TestPrimForwardAndBackward(unittest.TestCase): """ Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph diff --git a/test/dygraph_to_static/test_closure_analysis.py b/test/dygraph_to_static/test_closure_analysis.py index 95234565a6922..de1d1e12d6502 100644 --- a/test/dygraph_to_static/test_closure_analysis.py +++ b/test/dygraph_to_static/test_closure_analysis.py @@ -13,10 +13,12 @@ # limitations under the License. 
import inspect -import os import unittest -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + test_and_compare_with_new_ir, +) from numpy import append import paddle @@ -161,7 +163,7 @@ def test_push_pop_4(x, *args, **kargs): return l, k -class TestClosureAnalysis(unittest.TestCase): +class TestClosureAnalysis(Dy2StTestBase): def setUp(self): self.judge_type = "var and w_vars" self.init_dygraph_func() @@ -260,7 +262,7 @@ def init_dygraph_func(self): ] -class TestPushPopTrans(unittest.TestCase): +class TestPushPopTrans(Dy2StTestBase): @test_and_compare_with_new_ir(False) def test(self): def vlist_of_dict(x): @@ -270,7 +272,6 @@ def vlist_of_dict(x): return ma x = paddle.to_tensor([3]) - print(paddle.jit.to_static(vlist_of_dict).code) print(paddle.jit.to_static(vlist_of_dict)(x)) @test_and_compare_with_new_ir(False) @@ -284,7 +285,6 @@ def vlist_of_dict(x): return a x = paddle.to_tensor([3]) - print(paddle.jit.to_static(vlist_of_dict).code) print(paddle.jit.to_static(vlist_of_dict)(x)) @test_and_compare_with_new_ir(False) @@ -298,7 +298,6 @@ def vlist_of_dict(x): return a x = paddle.to_tensor([3]) - print(paddle.jit.to_static(vlist_of_dict).code) print(paddle.jit.to_static(vlist_of_dict)(x)) @test_and_compare_with_new_ir(False) @@ -312,7 +311,6 @@ def vlist_of_dict(x): return a x = paddle.to_tensor([3]) - print(paddle.jit.to_static(vlist_of_dict).code) print(paddle.jit.to_static(vlist_of_dict)(x)) @test_and_compare_with_new_ir(False) @@ -326,10 +324,8 @@ def vlist_of_dict(x): return a x = paddle.to_tensor([3]) - print(paddle.jit.to_static(vlist_of_dict).code) print(paddle.jit.to_static(vlist_of_dict)(x)) if __name__ == '__main__': - os.environ['ENABLE_FALL_BACK'] = "False" unittest.main() diff --git a/test/dygraph_to_static/test_convert_call.py b/test/dygraph_to_static/test_convert_call.py index 77ca5a88f012b..723d3f910debd 100644 --- a/test/dygraph_to_static/test_convert_call.py +++ b/test/dygraph_to_static/test_convert_call.py @@ -77,6 +77,7 @@ def dyfunc_with_staticmethod(x_v): return a.add(x_v, x_v) +@dy2static_unittest class TestRecursiveCall1(unittest.TestCase): def setUp(self): self.input = np.random.random([10, 16]).astype('float32') @@ -168,6 +169,7 @@ def forward(self, inputs): return self.act(out) +@dy2static_unittest class TestRecursiveCall2(unittest.TestCase): def setUp(self): self.input = np.random.random((1, 3, 3, 5)).astype('float32') diff --git a/test/dygraph_to_static/test_convert_call_generator.py b/test/dygraph_to_static/test_convert_call_generator.py index b33a41576498d..dd9d93c907c55 100644 --- a/test/dygraph_to_static/test_convert_call_generator.py +++ b/test/dygraph_to_static/test_convert_call_generator.py @@ -14,7 +14,11 @@ import unittest -from dygraph_to_static_util import ast_only_test, test_and_compare_with_new_ir +from dygraph_to_static_util import ( + ast_only_test, + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle from paddle.jit import to_static @@ -32,6 +36,7 @@ def main_func(): print(i) +@dy2static_unittest class TestConvertGenerator(unittest.TestCase): # fallback will ok. 
@ast_only_test diff --git a/test/dygraph_to_static/test_convert_operators.py b/test/dygraph_to_static/test_convert_operators.py index 420e7d8b1e887..02d0c09a70857 100644 --- a/test/dygraph_to_static/test_convert_operators.py +++ b/test/dygraph_to_static/test_convert_operators.py @@ -15,7 +15,11 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test, test_and_compare_with_new_ir +from dygraph_to_static_util import ( + ast_only_test, + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle @@ -40,6 +44,7 @@ def forward(self): net.forward = "A string so that convert forward will fail" +@dy2static_unittest class TestConvertCall(unittest.TestCase): # fallback mode will raise a InnerError, it's ok. @ast_only_test @@ -68,6 +73,7 @@ def callable_list(x, y): self.assertEqual(callable_list(1, 2), 3) +@dy2static_unittest class TestConvertShapeCompare(unittest.TestCase): def test_non_variable(self): self.assertEqual( @@ -204,6 +210,7 @@ def forward(self, x): return out +@dy2static_unittest class TestChooseShapeAttrOrApiWithLayer(unittest.TestCase): @test_and_compare_with_new_ir(False) def test_tensor_shape(self): @@ -214,6 +221,7 @@ def test_tensor_shape(self): np.testing.assert_array_equal(out.numpy(), x.numpy()) +@dy2static_unittest class TestIfElseNoValue(unittest.TestCase): @test_and_compare_with_new_ir(False) def test_else_ret_none(self): diff --git a/test/dygraph_to_static/test_cpu_cuda_to_tensor.py b/test/dygraph_to_static/test_cpu_cuda_to_tensor.py index f5d6c833d16c1..b6e55b8900c1e 100644 --- a/test/dygraph_to_static/test_cpu_cuda_to_tensor.py +++ b/test/dygraph_to_static/test_cpu_cuda_to_tensor.py @@ -25,6 +25,7 @@ import paddle +@dy2static_unittest class TestCpuCuda(unittest.TestCase): def test_cpu_cuda(self): def func(x): @@ -38,6 +39,7 @@ def func(x): # print(paddle.jit.to_static(func)(x)) +@dy2static_unittest class TestToTensor(unittest.TestCase): @test_and_compare_with_new_ir(False) def test_to_tensor_with_variable_list(self): diff --git a/test/dygraph_to_static/test_cycle_gan.py b/test/dygraph_to_static/test_cycle_gan.py index 3484b27d5fac5..fb06a52407ec6 100644 --- a/test/dygraph_to_static/test_cycle_gan.py +++ b/test/dygraph_to_static/test_cycle_gan.py @@ -12,16 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License import os import random @@ -36,7 +26,10 @@ # Use GPU:0 to elimate the influence of other tasks. 
os.environ["CUDA_VISIBLE_DEVICES"] = "1" -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle from paddle.base.dygraph import to_variable @@ -686,6 +679,7 @@ def train(args, to_static): return np.array(loss_data) +@dy2static_unittest class TestCycleGANModel(unittest.TestCase): def setUp(self): self.args = Args() diff --git a/test/dygraph_to_static/test_declarative.py b/test/dygraph_to_static/test_declarative.py index 9d3e1e54b0ebb..12b098cc10ac5 100644 --- a/test/dygraph_to_static/test_declarative.py +++ b/test/dygraph_to_static/test_declarative.py @@ -17,7 +17,11 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + ast_only_test, + test_and_compare_with_new_ir, +) from test_basic_api_transformation import dyfunc_to_variable import paddle @@ -31,8 +35,6 @@ from paddle.nn import Layer from paddle.static import InputSpec -os.environ['ENABLE_FALL_BACK'] = "False" # NOTE: ast only - class SimpleNet(Layer): def __init__(self): @@ -89,7 +91,7 @@ def func_with_list_dict(self, dl): return z -class TestStaticFunctionInstance(unittest.TestCase): +class TestStaticFunctionInstance(Dy2StTestBase): def test_instance_same_class(self): with base.dygraph.guard(base.CPUPlace()): net_1 = SimpleNet() @@ -106,7 +108,7 @@ def test_instance_same_class(self): self.assertTrue(len(net_2.forward.program_cache) == 0) -class TestInputSpec(unittest.TestCase): +class TestInputSpec(Dy2StTestBase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() self.model_path = os.path.join(self.temp_dir.name, 'simple_net') @@ -115,6 +117,7 @@ def tearDown(self): self.temp_dir.cleanup() @test_and_compare_with_new_ir(False) + @ast_only_test def test_with_input_spec(self): with base.dygraph.guard(base.CPUPlace()): x = to_variable(np.ones([4, 10]).astype('float32')) @@ -175,6 +178,7 @@ def test_with_error(self): ) net.add_func(x, y) + @ast_only_test def test_concrete_program(self): with base.dygraph.guard(base.CPUPlace()): x = to_variable(np.ones([4, 10]).astype('float32')) @@ -210,11 +214,12 @@ def foo_func(a, b, c=1, d=2): return z -class TestDifferentInputSpecCacheProgram(unittest.TestCase): +class TestDifferentInputSpecCacheProgram(Dy2StTestBase): def setUp(self): paddle.jit.enable_to_static(True) @test_and_compare_with_new_ir(False) + @ast_only_test def test_with_different_input(self): with base.dygraph.guard(base.CPUPlace()): x_data = np.ones([16, 10]).astype('float32') @@ -260,6 +265,7 @@ def test_with_different_input(self): recent_program = foo.program_cache.last() self.assertTrue(first_program == recent_program) + @ast_only_test def test_get_concrete_program(self): foo = to_static(foo_func) @@ -301,6 +307,7 @@ def test_get_concrete_program(self): ) @test_and_compare_with_new_ir(False) + @ast_only_test def test_concrete_program(self): with base.dygraph.guard(base.CPUPlace()): # usage 1 @@ -324,7 +331,7 @@ def test_concrete_program(self): foo_3.concrete_program # noqa: B018 -class TestInputDefaultName(unittest.TestCase): +class TestInputDefaultName(Dy2StTestBase): def setUp(self): paddle.disable_static() self.net = SimpleNet() @@ -348,7 +355,8 @@ def test_nest_input(self): self.assert_default_name('func_with_list_dict', ['dl_0', 'x', 'y']) -class TestDeclarativeAPI(unittest.TestCase): +class TestDeclarativeAPI(Dy2StTestBase): + @ast_only_test def test_error(self): func = 
to_static(dyfunc_to_variable) @@ -366,19 +374,21 @@ def test_error(self): func(np.ones(5).astype("int32")) -class TestDecorateModelDirectly(unittest.TestCase): +class TestDecorateModelDirectly(Dy2StTestBase): def setUp(self): paddle.disable_static() paddle.jit.enable_to_static(True) self.x = to_variable(np.ones([4, 10]).astype('float32')) @test_and_compare_with_new_ir(False) + @ast_only_test def test_fake_input(self): net = SimpleNet() net = to_static(net) y = net(self.x) self.assertTrue(len(net.forward.program_cache) == 1) + @ast_only_test def test_input_spec(self): net = SimpleNet() net = to_static(net, input_spec=[InputSpec([None, 8, 10])]) @@ -393,7 +403,7 @@ def test_input_spec(self): self.assertListEqual(list(input_shape), [-1, 16, 10]) -class TestErrorWithInitFromStaticMode(unittest.TestCase): +class TestErrorWithInitFromStaticMode(Dy2StTestBase): def test_raise_error(self): # disable imperative paddle.enable_static() @@ -435,7 +445,7 @@ def func(self): return x -class TestCallNonForwardFunc(unittest.TestCase): +class TestCallNonForwardFunc(Dy2StTestBase): @test_and_compare_with_new_ir(False) def test_call_non_forward(self): paddle.disable_static() @@ -468,7 +478,7 @@ def forward(self): return self.b -class TestSetBuffers(unittest.TestCase): +class TestSetBuffers(Dy2StTestBase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() self.model_path = os.path.join(self.temp_dir.name, 'SetBuffersNet1') @@ -485,6 +495,7 @@ def test_set_buffers1(self): paddle.jit.save(net, self.model_path) paddle.enable_static() + @ast_only_test def test_set_buffers2(self): paddle.disable_static() net = SetBuffersNet2() @@ -498,7 +509,7 @@ def func(self, x): return x + 1 -class TestClassNoInheritLayer(unittest.TestCase): +class TestClassNoInheritLayer(Dy2StTestBase): def test_to_static(self): paddle.disable_static() net = ClassNoInheritLayer() diff --git a/test/dygraph_to_static/test_decorator_transform.py b/test/dygraph_to_static/test_decorator_transform.py index d0ddffdd40cbe..4f4096d607dc8 100644 --- a/test/dygraph_to_static/test_decorator_transform.py +++ b/test/dygraph_to_static/test_decorator_transform.py @@ -19,9 +19,9 @@ import decos import numpy as np -from dygraph_to_static_util import ( +from dygraph_to_static_utils_new import ( + Dy2StTestBase, ast_only_test, - dy2static_unittest, test_and_compare_with_new_ir, ) @@ -185,8 +185,7 @@ def deco_with_paddle_api(): return fun10() -@dy2static_unittest -class TestDecoratorTransform(unittest.TestCase): +class TestDecoratorTransform(Dy2StTestBase): @test_and_compare_with_new_ir(False) def test_deco_transform(self): outs = paddle.jit.to_static(forward)() diff --git a/test/dygraph_to_static/test_deepcopy.py b/test/dygraph_to_static/test_deepcopy.py index 0959d74dbc1fb..d291927b73ddd 100644 --- a/test/dygraph_to_static/test_deepcopy.py +++ b/test/dygraph_to_static/test_deepcopy.py @@ -16,15 +16,23 @@ from copy import deepcopy import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + IrMode, + ToStaticMode, + disable_test_case, + test_and_compare_with_new_ir, +) from test_rollback import Net, foo import paddle from paddle.jit.dy2static.program_translator import StaticFunction -class TestDeepCopy(unittest.TestCase): +# @dy2static_unittest +class TestDeepCopy(Dy2StTestBase): @test_and_compare_with_new_ir(False) + @disable_test_case((ToStaticMode.SOT, IrMode.PIR)) def test_net(self): net = Net() net = paddle.jit.to_static(net) diff --git 
a/test/dygraph_to_static/test_dict.py b/test/dygraph_to_static/test_dict.py index 80180b522cf54..99364c1343a7d 100644 --- a/test/dygraph_to_static/test_dict.py +++ b/test/dygraph_to_static/test_dict.py @@ -15,7 +15,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle from paddle import base @@ -116,6 +119,7 @@ def update_cache(cache): return cache +@dy2static_unittest class TestNetWithDict(unittest.TestCase): """ TestCase for the transformation from control flow `if/else` @@ -169,6 +173,7 @@ def test_dic_pop_2(x): return out +@dy2static_unittest class TestDictPop(unittest.TestCase): def setUp(self): self.input = np.random.random(3).astype('int32') @@ -249,6 +254,7 @@ def test_ast_to_func(self): ) +@dy2static_unittest class TestDictCmpInFor(unittest.TestCase): def test_with_for(self): def func(): diff --git a/test/dygraph_to_static/test_drop_path.py b/test/dygraph_to_static/test_drop_path.py index a9ea20be04c38..aad752007ceb0 100644 --- a/test/dygraph_to_static/test_drop_path.py +++ b/test/dygraph_to_static/test_drop_path.py @@ -15,7 +15,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle @@ -36,6 +39,7 @@ def forward(self, x): return drop_path(x, self.training) +@dy2static_unittest class TestTrainEval(unittest.TestCase): def setUp(self): self.model = DropPath() diff --git a/test/dygraph_to_static/test_duplicate_output.py b/test/dygraph_to_static/test_duplicate_output.py index 7e4220899d5ef..add3a7262446a 100644 --- a/test/dygraph_to_static/test_duplicate_output.py +++ b/test/dygraph_to_static/test_duplicate_output.py @@ -15,7 +15,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle @@ -38,6 +41,7 @@ def forward(self, x): return x, x +@dy2static_unittest class TestDuplicateOutput(unittest.TestCase): """ TestCase for the transformation from control flow `if/else` diff --git a/test/dygraph_to_static/test_error.py b/test/dygraph_to_static/test_error.py index 8c6f74d75c4e0..c12dc3887f23d 100644 --- a/test/dygraph_to_static/test_error.py +++ b/test/dygraph_to_static/test_error.py @@ -23,8 +23,6 @@ from paddle.jit.dy2static import error from paddle.jit.dy2static.origin_info import unwrap -os.environ['ENABLE_FALL_BACK'] = "False" # NOTE: ast only - def inner_func(): paddle.tensor.fill_constant(shape=[1, 2], value=9, dtype="int") @@ -257,9 +255,9 @@ def set_exception_type(self): def set_message(self): self.expected_message = [ - f'File "{self.filepath}", line 37, in func_error_in_compile_time', + f'File "{self.filepath}", line 35, in func_error_in_compile_time', 'inner_func()', - f'File "{self.filepath}", line 30, in inner_func', + f'File "{self.filepath}", line 28, in inner_func', 'def inner_func():', 'paddle.tensor.fill_constant(shape=[1, 2], value=9, dtype="int")', '<--- HERE', @@ -286,7 +284,7 @@ def set_exception_type(self): def set_message(self): self.expected_message = [ - f'File "{self.filepath}", line 48, in func_error_in_compile_time_2', + f'File "{self.filepath}", line 46, in func_error_in_compile_time_2', 'def func_error_in_compile_time_2(x):', 'x = base.dygraph.to_variable(x)', 'x = paddle.reshape(x, 
shape=[1, 2])', @@ -310,7 +308,7 @@ def set_exception_type(self): def set_message(self): self.expected_message = [ - f'File "{self.filepath}", line 93, in forward', + f'File "{self.filepath}", line 91, in forward', '@paddle.jit.to_static', 'def forward(self):', 'self.test_func()', @@ -334,7 +332,7 @@ def set_exception_type(self): def set_message(self): self.expected_message = [ - f'File "{self.filepath}", line 56, in func_error_in_runtime', + f'File "{self.filepath}", line 54, in func_error_in_runtime', 'x = base.dygraph.to_variable(x)', 'two = paddle.tensor.fill_constant(shape=[1], value=2, dtype="int32")', 'x = paddle.reshape(x, shape=[1, two])', @@ -349,7 +347,7 @@ def set_func(self): def set_message(self): self.expected_message = [ - 'File "{}", line 108, in func_error_in_runtime_with_empty_line'.format( + 'File "{}", line 106, in func_error_in_runtime_with_empty_line'.format( self.filepath ), 'two = paddle.tensor.fill_constant(shape=[1], value=2, dtype="int32")', @@ -372,7 +370,7 @@ def set_exception_type(self): def set_message(self): self.expected_message = [ - f'File "{self.filepath}", line 82, in forward', + f'File "{self.filepath}", line 80, in forward', 'def forward(self, x):', 'y = self._linear(x)', 'z = paddle.tensor.fill_constant(shape=[1, 2], value=9, dtype="int")', diff --git a/test/dygraph_to_static/test_fallback.py b/test/dygraph_to_static/test_fallback.py index b641f8b22233a..58394feda2a68 100644 --- a/test/dygraph_to_static/test_fallback.py +++ b/test/dygraph_to_static/test_fallback.py @@ -16,7 +16,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test +from dygraph_to_static_util import ast_only_test, dy2static_unittest import paddle @@ -51,6 +51,7 @@ def forward(self, x): return unsupport_func(x - 1) +@dy2static_unittest class TestFallback(unittest.TestCase): def setUp(self): self.x = paddle.to_tensor([2]).astype('int') diff --git a/test/dygraph_to_static/test_fetch_feed.py b/test/dygraph_to_static/test_fetch_feed.py index 0834f2ec4a315..b44578fad2c9e 100644 --- a/test/dygraph_to_static/test_fetch_feed.py +++ b/test/dygraph_to_static/test_fetch_feed.py @@ -15,7 +15,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle from paddle import base @@ -62,6 +65,7 @@ def forward(self, x): return pre, loss +@dy2static_unittest class TestPool2D(unittest.TestCase): def setUp(self): self.dygraph_class = Pool2D diff --git a/test/dygraph_to_static/test_for_enumerate.py b/test/dygraph_to_static/test_for_enumerate.py index bbb64e8756ea3..dc9505a5cf6fc 100644 --- a/test/dygraph_to_static/test_for_enumerate.py +++ b/test/dygraph_to_static/test_for_enumerate.py @@ -17,6 +17,7 @@ import unittest import numpy as np +from dygraph_to_static_util import dy2static_unittest import paddle from paddle import base @@ -353,6 +354,7 @@ def tensor_array_slice_in_enumerate(): return feat_n2 +@dy2static_unittest class TestTransformBase(unittest.TestCase): def setUp(self): self.place = ( @@ -556,6 +558,7 @@ def test_transformed_result_compare(self): self.transformed_result_compare() +@dy2static_unittest class TestForZip(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() diff --git a/test/dygraph_to_static/test_full_name_usage.py b/test/dygraph_to_static/test_full_name_usage.py index 0332480891e16..39a80acb566ea 100644 --- a/test/dygraph_to_static/test_full_name_usage.py +++ 
b/test/dygraph_to_static/test_full_name_usage.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test +from dygraph_to_static_util import ast_only_test, dy2static_unittest import paddle from paddle import base @@ -58,6 +58,7 @@ def double_decorated_func2(self, x): return jit_decorated_func(x) +@dy2static_unittest class TestFullNameDecorator(unittest.TestCase): @ast_only_test def test_run_success(self): diff --git a/test/dygraph_to_static/test_grad.py b/test/dygraph_to_static/test_grad.py index e542d87efc90c..ceca09e789548 100644 --- a/test/dygraph_to_static/test_grad.py +++ b/test/dygraph_to_static/test_grad.py @@ -65,6 +65,7 @@ def forward(self, x): return out +@dy2static_unittest class TestGrad(unittest.TestCase): def setUp(self): self.func = paddle.jit.to_static(GradLayer()) diff --git a/test/dygraph_to_static/test_gradient_aggregation.py b/test/dygraph_to_static/test_gradient_aggregation.py index ab7effba5b16c..4172fb87197df 100644 --- a/test/dygraph_to_static/test_gradient_aggregation.py +++ b/test/dygraph_to_static/test_gradient_aggregation.py @@ -15,7 +15,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle @@ -37,6 +40,7 @@ def forward(self, x): # return [out2, out1] # gradients are correct +@dy2static_unittest class TestGradientAggregationInDy2Static(unittest.TestCase): @test_and_compare_with_new_ir(False) def test_to_static(self): diff --git a/test/dygraph_to_static/test_grid_generator.py b/test/dygraph_to_static/test_grid_generator.py index ea1eafb5c1fa9..7c1a9189366e0 100644 --- a/test/dygraph_to_static/test_grid_generator.py +++ b/test/dygraph_to_static/test_grid_generator.py @@ -15,7 +15,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + test_and_compare_with_new_ir, +) import paddle from paddle import ParamAttr, nn @@ -126,7 +129,7 @@ def get_expand_tensor(self, batch_C_prime): return batch_C_ex_part_tensor -class TestGridGenerator(unittest.TestCase): +class TestGridGenerator(Dy2StTestBase): def setUp(self): self.x = paddle.uniform(shape=[1, 20, 2], dtype='float32') diff --git a/test/dygraph_to_static/test_ifelse.py b/test/dygraph_to_static/test_ifelse.py index 6e2dc6f8ffe6d..6c141aed8ff13 100644 --- a/test/dygraph_to_static/test_ifelse.py +++ b/test/dygraph_to_static/test_ifelse.py @@ -72,6 +72,7 @@ def test_error(self): paddle.jit.enable_to_static(False) +@dy2static_unittest class TestDygraphIfElse(unittest.TestCase): """ TestCase for the transformation from control flow `if/else` @@ -238,6 +239,7 @@ def setUp(self): self.dyfunc = if_tensor_case +@dy2static_unittest class TestDygraphIfElseNet(unittest.TestCase): """ TestCase for the transformation from control flow `if/else` @@ -350,6 +352,7 @@ def forward(self, x, y): raise ValueError('Illegal mode') +@dy2static_unittest class TestDiffModeNet(unittest.TestCase): """ TestCase for the net with different modes @@ -392,6 +395,7 @@ def init_net(self): self.Net = DiffModeNet2 +@dy2static_unittest class TestNewVarCreateInOneBranch(unittest.TestCase): def test_var_used_in_another_for(self): def case_func(training): @@ -497,6 +501,7 @@ def forward(self, a, b, c): return b +@dy2static_unittest class TestDy2StIfElseBackward(unittest.TestCase): def test_run_backward(self): a = paddle.randn((4, 3), dtype='float32') diff --git 
a/test/dygraph_to_static/test_isinstance.py b/test/dygraph_to_static/test_isinstance.py index e3557dc32658f..7dfd05989dabe 100644 --- a/test/dygraph_to_static/test_isinstance.py +++ b/test/dygraph_to_static/test_isinstance.py @@ -26,7 +26,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle from paddle import nn @@ -85,6 +88,7 @@ def train(model, to_static): return out.numpy() +@dy2static_unittest class TestIsinstance(unittest.TestCase): def test_isinstance_simple_return_layer(self): model = IsInstanceLayer(SimpleReturnLayer()) diff --git a/test/dygraph_to_static/test_jit_property_save.py b/test/dygraph_to_static/test_jit_property_save.py index f25c128e265d7..965168dedc6ea 100644 --- a/test/dygraph_to_static/test_jit_property_save.py +++ b/test/dygraph_to_static/test_jit_property_save.py @@ -14,9 +14,12 @@ import unittest +from dygraph_to_static_util import dy2static_unittest + import paddle +@dy2static_unittest class TestPropertySave(unittest.TestCase): """test jit property save""" diff --git a/test/dygraph_to_static/test_jit_setitem.py b/test/dygraph_to_static/test_jit_setitem.py index 59841ed431f08..219e6a6c9de74 100644 --- a/test/dygraph_to_static/test_jit_setitem.py +++ b/test/dygraph_to_static/test_jit_setitem.py @@ -16,11 +16,13 @@ import unittest import numpy as np +from dygraph_to_static_util import dy2static_unittest import paddle import paddle.nn.functional as F +@dy2static_unittest class TestSetItemBase(unittest.TestCase): def setUp(self) -> None: pass diff --git a/test/dygraph_to_static/test_lac.py b/test/dygraph_to_static/test_lac.py index 522eb81cf5a7a..461b03fe7a5ed 100644 --- a/test/dygraph_to_static/test_lac.py +++ b/test/dygraph_to_static/test_lac.py @@ -22,6 +22,8 @@ os.environ["CUDA_VISIBLE_DEVICES"] = "2" +from dygraph_to_static_util import dy2static_unittest + import paddle from paddle import _legacy_C_ops, base from paddle.base.dygraph import to_variable @@ -513,6 +515,7 @@ def create_dataloader(reader, place): return data_loader +@dy2static_unittest class TestLACModel(unittest.TestCase): def setUp(self): self.args = Args() diff --git a/test/dygraph_to_static/test_lambda.py b/test/dygraph_to_static/test_lambda.py index c1ff57147564c..add572cb6dfcf 100644 --- a/test/dygraph_to_static/test_lambda.py +++ b/test/dygraph_to_static/test_lambda.py @@ -15,6 +15,7 @@ import unittest import numpy as np +from dygraph_to_static_util import dy2static_unittest import paddle import paddle.nn.functional as F @@ -79,6 +80,7 @@ def call_lambda_with_ifExpr2(x): return out +@dy2static_unittest class TestLambda(unittest.TestCase): def setUp(self): self.x = np.random.random([10, 16]).astype('float32') diff --git a/test/dygraph_to_static/test_layer_hook.py b/test/dygraph_to_static/test_layer_hook.py index bf679cf8dcc2e..d19b9ea9abfc9 100644 --- a/test/dygraph_to_static/test_layer_hook.py +++ b/test/dygraph_to_static/test_layer_hook.py @@ -17,7 +17,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle @@ -56,6 +59,7 @@ def forward(self, x): return out +@dy2static_unittest class TestNestLayerHook(unittest.TestCase): def setUp(self): paddle.seed(2022) diff --git a/test/dygraph_to_static/test_len.py b/test/dygraph_to_static/test_len.py index e2cee7c4dc8b4..340ba86ff50c2 
100644 --- a/test/dygraph_to_static/test_len.py +++ b/test/dygraph_to_static/test_len.py @@ -15,6 +15,7 @@ import unittest import numpy as np +from dygraph_to_static_util import dy2static_unittest import paddle from paddle import base @@ -42,6 +43,7 @@ def len_with_lod_tensor_array(x): return arr_len +@dy2static_unittest class TestLen(unittest.TestCase): def setUp(self): self.place = ( @@ -113,6 +115,7 @@ def len_with_selected_rows(place): return result +@dy2static_unittest class TestLenWithSelectedRows(unittest.TestCase): def setUp(self): self.place = ( diff --git a/test/dygraph_to_static/test_list.py b/test/dygraph_to_static/test_list.py index 9ad646de8818c..51b28ce3fe38a 100644 --- a/test/dygraph_to_static/test_list.py +++ b/test/dygraph_to_static/test_list.py @@ -16,6 +16,7 @@ import unittest import numpy as np +from dygraph_to_static_util import dy2static_unittest import paddle from paddle import base @@ -207,6 +208,7 @@ def test_list_pop_in_while_loop(x, iter_num): return a[0], b[2] +@dy2static_unittest class TestListWithoutControlFlow(unittest.TestCase): def setUp(self): self.place = ( @@ -354,6 +356,7 @@ def forward(self, x, index, *args): return z +@dy2static_unittest class TestListWithCondGradInferVarType(unittest.TestCase): def test_to_static(self): net = ListWithCondNet() diff --git a/test/dygraph_to_static/test_load_transformer.py b/test/dygraph_to_static/test_load_transformer.py index 95e06a51f3c69..1e36145537f43 100644 --- a/test/dygraph_to_static/test_load_transformer.py +++ b/test/dygraph_to_static/test_load_transformer.py @@ -16,7 +16,13 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + IrMode, + ToStaticMode, + disable_test_case, + test_and_compare_with_new_ir, +) import paddle @@ -41,11 +47,12 @@ def forward(self, x): return t -class TestFallback(unittest.TestCase): +class TestFallback(Dy2StTestBase): def setUp(self): self.x = paddle.to_tensor(1.0).astype('int') @test_and_compare_with_new_ir(False) + @disable_test_case((ToStaticMode.SOT, IrMode.PIR)) def test_name_load(self): net_dy = Net() net_st = Net() @@ -54,8 +61,9 @@ def test_name_load(self): np.testing.assert_allclose(output_dy.numpy(), output_st.numpy()) -class TestLoad2(unittest.TestCase): +class TestLoad2(Dy2StTestBase): @test_and_compare_with_new_ir(False) + @disable_test_case((ToStaticMode.SOT, IrMode.PIR)) def test_name_load_nograd(self): @paddle.no_grad() def func(x): diff --git a/test/dygraph_to_static/test_logical.py b/test/dygraph_to_static/test_logical.py index 9e0f1d12bd9b4..a05f91b7c0493 100644 --- a/test/dygraph_to_static/test_logical.py +++ b/test/dygraph_to_static/test_logical.py @@ -18,6 +18,7 @@ import unittest import numpy as np +from dygraph_to_static_util import dy2static_unittest import paddle from paddle import base @@ -167,6 +168,7 @@ def test_shape_not_equal(x): return paddle.ones([1, 2, 3]) +@dy2static_unittest class TestLogicalBase(unittest.TestCase): def setUp(self): self.input = np.array([3]).astype('int32') @@ -262,6 +264,7 @@ def _set_test_func(self): self.dygraph_func = test_shape_not_equal +@dy2static_unittest class TestCmpopNodeToStr(unittest.TestCase): def test_exception(self): with self.assertRaises(KeyError): diff --git a/test/dygraph_to_static/test_loop.py b/test/dygraph_to_static/test_loop.py index 77f568e2c5eec..422508d6cd97e 100644 --- a/test/dygraph_to_static/test_loop.py +++ b/test/dygraph_to_static/test_loop.py @@ -16,6 +16,7 @@ import unittest import numpy 
as np +from dygraph_to_static_util import dy2static_unittest import paddle import paddle.nn.functional as F @@ -229,6 +230,7 @@ def for_loop_dufunc_with_listcomp(array): return res +@dy2static_unittest class TestNameVisitor(unittest.TestCase): def setUp(self): self.loop_funcs = [ @@ -299,6 +301,7 @@ def test_nested_loop_vars(self): i += 1 +@dy2static_unittest class TestTransformWhileLoop(unittest.TestCase): def setUp(self): self.place = ( @@ -378,6 +381,7 @@ def _init_dyfunc(self): self.dyfunc = loop_var_contains_property +@dy2static_unittest class TestTransformForLoop(unittest.TestCase): def setUp(self): self.place = ( @@ -460,6 +464,7 @@ def forward(self, x): return out +@dy2static_unittest class TestForLoopMeetDict(unittest.TestCase): def test_start(self): net = Net() diff --git a/test/dygraph_to_static/test_mnist.py b/test/dygraph_to_static/test_mnist.py index 9641a9225cee7..984176a83afe0 100644 --- a/test/dygraph_to_static/test_mnist.py +++ b/test/dygraph_to_static/test_mnist.py @@ -18,7 +18,11 @@ from time import time import numpy as np -from dygraph_to_static_util import ast_only_test, test_and_compare_with_new_ir +from dygraph_to_static_util import ( + ast_only_test, + dy2static_unittest, + test_and_compare_with_new_ir, +) from predictor_utils import PredictorTools import paddle @@ -126,6 +130,7 @@ def inference(self, inputs): return x +@dy2static_unittest class TestMNIST(unittest.TestCase): def setUp(self): self.epoch_num = 1 diff --git a/test/dygraph_to_static/test_mobile_net.py b/test/dygraph_to_static/test_mobile_net.py index 5536a14e695c4..cca77999d5e7d 100644 --- a/test/dygraph_to_static/test_mobile_net.py +++ b/test/dygraph_to_static/test_mobile_net.py @@ -19,7 +19,7 @@ import unittest import numpy as np -from dygraph_to_static_util import test_with_new_ir +from dygraph_to_static_util import dy2static_unittest, test_with_new_ir from predictor_utils import PredictorTools import paddle @@ -656,6 +656,7 @@ def predict_analysis_inference(args, data): return out +@dy2static_unittest class TestMobileNet(unittest.TestCase): def setUp(self): self.args = Args() diff --git a/test/dygraph_to_static/test_multi_forward.py b/test/dygraph_to_static/test_multi_forward.py index 039db089b5c86..2cf8e592f3fa0 100644 --- a/test/dygraph_to_static/test_multi_forward.py +++ b/test/dygraph_to_static/test_multi_forward.py @@ -14,7 +14,10 @@ import unittest -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle @@ -33,6 +36,7 @@ def forward(self, x): return self.linear(x) +@dy2static_unittest class TestBackward(unittest.TestCase): @test_and_compare_with_new_ir(False) def test_order_0(self): diff --git a/test/dygraph_to_static/test_new_ir_selectedrows.py b/test/dygraph_to_static/test_new_ir_selectedrows.py index 7d87a48fe7858..e403cbd6089a1 100644 --- a/test/dygraph_to_static/test_new_ir_selectedrows.py +++ b/test/dygraph_to_static/test_new_ir_selectedrows.py @@ -15,10 +15,7 @@ import random import unittest -from dygraph_to_static_util import ( - enable_fallback_guard, - test_and_compare_with_new_ir, -) +from dygraph_to_static_util import test_and_compare_with_new_ir import paddle from paddle.jit.api import to_static @@ -104,5 +101,4 @@ def test_dygraph_static_same_loss(self): if __name__ == '__main__': - with enable_fallback_guard("False"): - unittest.main() + unittest.main() diff --git a/test/dygraph_to_static/test_op_attr.py b/test/dygraph_to_static/test_op_attr.py index 
17394df88dd07..6aaf1cdbf2138 100644 --- a/test/dygraph_to_static/test_op_attr.py +++ b/test/dygraph_to_static/test_op_attr.py @@ -14,7 +14,7 @@ import unittest -from dygraph_to_static_util import ast_only_test +from dygraph_to_static_util import ast_only_test, dy2static_unittest import paddle from paddle.static import InputSpec @@ -52,6 +52,7 @@ def with_cond(self, x): return out +@dy2static_unittest class CheckOpAttr(unittest.TestCase): def setUp(self): self.in_num = 16 diff --git a/test/dygraph_to_static/test_origin_info.py b/test/dygraph_to_static/test_origin_info.py index e2925d4fa1a4b..be38650b750c2 100644 --- a/test/dygraph_to_static/test_origin_info.py +++ b/test/dygraph_to_static/test_origin_info.py @@ -16,6 +16,8 @@ import sys import unittest +from dygraph_to_static_util import dy2static_unittest + from paddle.jit.api import to_static from paddle.jit.dy2static import DygraphToStaticAst from paddle.jit.dy2static.origin_info import ( @@ -54,6 +56,7 @@ def decorated_func2(x): return x +@dy2static_unittest class TestOriginInfo(unittest.TestCase): def setUp(self): self.set_test_func() diff --git a/test/dygraph_to_static/test_param_guard.py b/test/dygraph_to_static/test_param_guard.py index b8edaf50dfced..c6787db58fc89 100644 --- a/test/dygraph_to_static/test_param_guard.py +++ b/test/dygraph_to_static/test_param_guard.py @@ -15,7 +15,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle from paddle.jit import to_static @@ -50,6 +53,7 @@ def forward(self, x): return out +@dy2static_unittest class TestParameterList(unittest.TestCase): def setUp(self): self.seed = 2021 @@ -102,6 +106,7 @@ def forward(self, x): return out +@dy2static_unittest class TestRawParameterList(unittest.TestCase): def setUp(self): self.seed = 2021 diff --git a/test/dygraph_to_static/test_params_no_grad.py b/test/dygraph_to_static/test_params_no_grad.py index f7bf87888f49c..3b3f3949fad57 100644 --- a/test/dygraph_to_static/test_params_no_grad.py +++ b/test/dygraph_to_static/test_params_no_grad.py @@ -14,6 +14,8 @@ import unittest +from dygraph_to_static_util import dy2static_unittest + import paddle import paddle.distributed as dist from paddle import nn @@ -52,6 +54,7 @@ def train(): print(loss) +@dy2static_unittest class TestParamsNoGrad(unittest.TestCase): def test_two_card(self): if ( diff --git a/test/dygraph_to_static/test_partial_program.py b/test/dygraph_to_static/test_partial_program.py index db4a7c21e4010..f742a0cdb5337 100644 --- a/test/dygraph_to_static/test_partial_program.py +++ b/test/dygraph_to_static/test_partial_program.py @@ -15,9 +15,12 @@ import unittest import numpy as np -from dygraph_to_static_util import ( +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + IrMode, + ToStaticMode, ast_only_test, - dy2static_unittest, + disable_test_case, test_and_compare_with_new_ir, ) from test_fetch_feed import Linear @@ -57,8 +60,7 @@ def fake_data(shape): return base.dygraph.to_variable(x_data) -@dy2static_unittest -class TestWithNestedInput(unittest.TestCase): +class TestWithNestedInput(Dy2StTestBase): def setUp(self): self.x = None self.y = None @@ -89,14 +91,14 @@ def _run(self, to_static): return out.numpy() @test_and_compare_with_new_ir(False) + @disable_test_case((ToStaticMode.SOT, IrMode.PIR)) def test_nest(self): dygraph_res = self._run(to_static=False) static_res = self._run(to_static=True) 
np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05) -@dy2static_unittest -class TestWithNestedOutput(unittest.TestCase): +class TestWithNestedOutput(Dy2StTestBase): def setUp(self): self.x = None self.y = None @@ -133,8 +135,7 @@ def test_nest(self): self.assertTrue(dy_var, st_var) -@dy2static_unittest -class TestWithTrainAndEval(unittest.TestCase): +class TestWithTrainAndEval(Dy2StTestBase): @ast_only_test @test_and_compare_with_new_ir(False) def test_switch_eval_and_train(self): @@ -167,8 +168,7 @@ def test_switch_eval_and_train(self): ) -@dy2static_unittest -class TestWithNoGrad(unittest.TestCase): +class TestWithNoGrad(Dy2StTestBase): @ast_only_test @test_and_compare_with_new_ir(False) def test_with_no_grad(self): @@ -204,8 +204,7 @@ def forward(self, x): return x1 -@dy2static_unittest -class TestPruneUnusedParamInProgram(unittest.TestCase): +class TestPruneUnusedParamInProgram(Dy2StTestBase): @test_and_compare_with_new_ir(False) def test_prune(self): input_ids = np.array([[15, 11, 6, 3, 18, 13]]).astype("float32") diff --git a/test/dygraph_to_static/test_partial_program_hook.py b/test/dygraph_to_static/test_partial_program_hook.py index cb177862692d3..c10194f6187ad 100644 --- a/test/dygraph_to_static/test_partial_program_hook.py +++ b/test/dygraph_to_static/test_partial_program_hook.py @@ -15,11 +15,14 @@ import os import unittest +from dygraph_to_static_util import dy2static_unittest + import paddle from paddle.base import core from paddle.jit.dy2static import partial_program, program_translator +@dy2static_unittest class TestPartiaProgramLayerHook(unittest.TestCase): def setUp(self): os.environ["ENABLE_FALL_BACK"] = "False" @@ -35,6 +38,7 @@ def test_after_infer(self): self.assertIsNone(self._hook.after_infer(None)) +@dy2static_unittest class TestPrimHook(unittest.TestCase): def setUp(self): os.environ["ENABLE_FALL_BACK"] = "False" diff --git a/test/dygraph_to_static/test_place.py b/test/dygraph_to_static/test_place.py index 2ed904a0b5490..f1cb7e80589a3 100644 --- a/test/dygraph_to_static/test_place.py +++ b/test/dygraph_to_static/test_place.py @@ -14,9 +14,12 @@ import unittest +from dygraph_to_static_util import dy2static_unittest + import paddle +@dy2static_unittest class TestPlace(unittest.TestCase): def test_place(self): paddle.enable_static() diff --git a/test/dygraph_to_static/test_print.py b/test/dygraph_to_static/test_print.py index d7fe1f5a882c0..251bca776e700 100644 --- a/test/dygraph_to_static/test_print.py +++ b/test/dygraph_to_static/test_print.py @@ -15,7 +15,10 @@ import unittest import numpy -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle from paddle import base @@ -84,6 +87,7 @@ def dyfunc_print_with_kwargs(x): print("Tensor", x_t, end='\n\n', sep=': ') +@dy2static_unittest class TestPrintBase(unittest.TestCase): def setUp(self): self.input = numpy.ones(5).astype("int32") diff --git a/test/dygraph_to_static/test_program_translator.py b/test/dygraph_to_static/test_program_translator.py index 25cf316dd7e91..d2909d07a50b2 100644 --- a/test/dygraph_to_static/test_program_translator.py +++ b/test/dygraph_to_static/test_program_translator.py @@ -18,7 +18,7 @@ import astor import numpy as np -from dygraph_to_static_util import ast_only_test +from dygraph_to_static_util import ast_only_test, dy2static_unittest from ifelse_simple_func import ( dyfunc_with_if_else_early_return1, dyfunc_with_if_else_early_return2, @@ -212,6 +212,7 @@ def 
forward(self, x): return y +@dy2static_unittest class TestEnableDeclarative(unittest.TestCase): def setUp(self): self.x = np.random.randn(30, 10, 32).astype('float32') @@ -267,6 +268,7 @@ def switch_mode_function(): return True +@dy2static_unittest class TestFunctionTrainEvalMode(unittest.TestCase): @ast_only_test def test_switch_mode(self): @@ -297,6 +299,7 @@ def test_raise_error(self): net.foo.train() +@dy2static_unittest class TestIfElseEarlyReturn(unittest.TestCase): def test_ifelse_early_return1(self): answer = np.zeros([2, 2]) + 1 @@ -311,6 +314,7 @@ def test_ifelse_early_return2(self): np.testing.assert_allclose(answer, out[0].numpy(), rtol=1e-05) +@dy2static_unittest class TestRemoveCommentInDy2St(unittest.TestCase): def func_with_comment(self): # Comment1 @@ -352,6 +356,7 @@ def func1(x): return func1(data) +@dy2static_unittest class TestParameterRecorder(unittest.TestCase): def test_recorder(self): """function calls nn.Layer case.""" diff --git a/test/dygraph_to_static/test_ptb_lm.py b/test/dygraph_to_static/test_ptb_lm.py index 2c94d6b343d3a..76a35d57ac9ba 100644 --- a/test/dygraph_to_static/test_ptb_lm.py +++ b/test/dygraph_to_static/test_ptb_lm.py @@ -17,7 +17,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle from paddle import base @@ -321,6 +324,7 @@ def train_static(place): return train(place) +@dy2static_unittest class TestPtb(unittest.TestCase): def setUp(self): self.place = ( diff --git a/test/dygraph_to_static/test_ptb_lm_v2.py b/test/dygraph_to_static/test_ptb_lm_v2.py index 3694d50396536..92d4d43d9d4ea 100644 --- a/test/dygraph_to_static/test_ptb_lm_v2.py +++ b/test/dygraph_to_static/test_ptb_lm_v2.py @@ -17,6 +17,7 @@ import unittest import numpy as np +from dygraph_to_static_util import dy2static_unittest import paddle @@ -322,6 +323,7 @@ def train_static(place): return train(place) +@dy2static_unittest class TestPtb(unittest.TestCase): def setUp(self): self.place = ( diff --git a/test/dygraph_to_static/test_pylayer.py b/test/dygraph_to_static/test_pylayer.py index c36bc1a14d5d1..0e083a67b0e94 100644 --- a/test/dygraph_to_static/test_pylayer.py +++ b/test/dygraph_to_static/test_pylayer.py @@ -26,6 +26,7 @@ import unittest import numpy as np +from dygraph_to_static_util import dy2static_unittest from test_jit_save_load import train import paddle @@ -262,6 +263,7 @@ def forward(self, x): return out +@dy2static_unittest class TestPyLayerBase(unittest.TestCase): def setUp(self): self.place = "gpu" if paddle.is_compiled_with_cuda() else "cpu" @@ -512,6 +514,7 @@ def test_pylayer_net_with_no_grad(self): self._run_and_compare(input1, input2) +@dy2static_unittest class PyLayerTrainHelper(unittest.TestCase): def setUp(self): self.place = "gpu" if paddle.is_compiled_with_cuda() else "cpu" @@ -583,6 +586,7 @@ def test_pylayer_net_no_grad(self): ) +@dy2static_unittest class TestPyLayerJitSaveLoad(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() diff --git a/test/dygraph_to_static/test_reinforcement_learning.py b/test/dygraph_to_static/test_reinforcement_learning.py index 2a792ebcda733..ffbd0e315229d 100644 --- a/test/dygraph_to_static/test_reinforcement_learning.py +++ b/test/dygraph_to_static/test_reinforcement_learning.py @@ -18,7 +18,10 @@ import gym import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + 
dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle import paddle.nn.functional as F @@ -203,6 +206,7 @@ def finish_episode(): return np.array(loss_data) +@dy2static_unittest class TestDeclarative(unittest.TestCase): def setUp(self): self.place = ( diff --git a/test/dygraph_to_static/test_resnet.py b/test/dygraph_to_static/test_resnet.py index a99999c4e7447..cb57ce234b263 100644 --- a/test/dygraph_to_static/test_resnet.py +++ b/test/dygraph_to_static/test_resnet.py @@ -19,7 +19,7 @@ import unittest import numpy as np -from dygraph_to_static_util import test_with_new_ir +from dygraph_to_static_util import dy2static_unittest, test_with_new_ir from predictor_utils import PredictorTools import paddle @@ -386,6 +386,7 @@ def predict_analysis_inference(self, data): return out +@dy2static_unittest class TestResnet(unittest.TestCase): def setUp(self): self.resnet_helper = ResNetHelper() diff --git a/test/dygraph_to_static/test_resnet_amp.py b/test/dygraph_to_static/test_resnet_amp.py index 60a30db707be4..0255c0c00db3b 100644 --- a/test/dygraph_to_static/test_resnet_amp.py +++ b/test/dygraph_to_static/test_resnet_amp.py @@ -16,7 +16,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) from test_resnet import SEED, ResNet, optimizer_setting import paddle @@ -111,6 +114,7 @@ def train(to_static, build_strategy=None): return total_loss.numpy() +@dy2static_unittest class TestResnet(unittest.TestCase): def train(self, to_static): paddle.jit.enable_to_static(to_static) diff --git a/test/dygraph_to_static/test_resnet_pure_fp16.py b/test/dygraph_to_static/test_resnet_pure_fp16.py index 1eb6a8ac9b3a5..771f9033f99d7 100644 --- a/test/dygraph_to_static/test_resnet_pure_fp16.py +++ b/test/dygraph_to_static/test_resnet_pure_fp16.py @@ -16,7 +16,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) from test_resnet import SEED, ResNet, optimizer_setting import paddle @@ -112,6 +115,7 @@ def train(to_static, build_strategy=None): return loss_data +@dy2static_unittest class TestResnet(unittest.TestCase): def train(self, to_static): paddle.jit.enable_to_static(to_static) diff --git a/test/dygraph_to_static/test_resnet_v2.py b/test/dygraph_to_static/test_resnet_v2.py index cf941effd2c28..0f5d804427ca6 100644 --- a/test/dygraph_to_static/test_resnet_v2.py +++ b/test/dygraph_to_static/test_resnet_v2.py @@ -19,7 +19,7 @@ import unittest import numpy as np -from dygraph_to_static_util import test_with_new_ir +from dygraph_to_static_util import dy2static_unittest, test_with_new_ir from predictor_utils import PredictorTools import paddle @@ -242,6 +242,7 @@ def __len__(self): return len(self.img) +@dy2static_unittest class TestResnet(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() diff --git a/test/dygraph_to_static/test_return.py b/test/dygraph_to_static/test_return.py index 41c622e9ed03a..0cd14b94267cd 100644 --- a/test/dygraph_to_static/test_return.py +++ b/test/dygraph_to_static/test_return.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test +from dygraph_to_static_util import ast_only_test, dy2static_unittest from ifelse_simple_func import dyfunc_with_if_else import paddle @@ -264,6 +264,7 @@ def func(): return func() 
+@dy2static_unittest class TestReturnBase(unittest.TestCase): def setUp(self): self.input = np.ones(1).astype('int32') diff --git a/test/dygraph_to_static/test_rollback.py b/test/dygraph_to_static/test_rollback.py index 0efb2147f2076..7ee3456747b51 100644 --- a/test/dygraph_to_static/test_rollback.py +++ b/test/dygraph_to_static/test_rollback.py @@ -71,6 +71,7 @@ def foo(x, flag=False): return out +@dy2static_unittest class TestRollBackPlainFunction(unittest.TestCase): def setUp(self): paddle.set_device("cpu") diff --git a/test/dygraph_to_static/test_save_inference_model.py b/test/dygraph_to_static/test_save_inference_model.py index c6a01d38e7d86..468541cfde39e 100644 --- a/test/dygraph_to_static/test_save_inference_model.py +++ b/test/dygraph_to_static/test_save_inference_model.py @@ -17,7 +17,11 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test, test_and_compare_with_new_ir +from dygraph_to_static_util import ( + ast_only_test, + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle from paddle import base @@ -73,6 +77,7 @@ def forward(self, x): return loss, out +@dy2static_unittest class TestDyToStaticSaveInferenceModel(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() @@ -223,6 +228,7 @@ def load_and_run_inference( return np.array(results[0]) +@dy2static_unittest class TestPartialProgramRaiseError(unittest.TestCase): @ast_only_test @test_and_compare_with_new_ir(False) diff --git a/test/dygraph_to_static/test_save_load.py b/test/dygraph_to_static/test_save_load.py index 1c7b34435d7ac..92965aea2ccc2 100644 --- a/test/dygraph_to_static/test_save_load.py +++ b/test/dygraph_to_static/test_save_load.py @@ -17,7 +17,11 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test, test_and_compare_with_new_ir +from dygraph_to_static_util import ( + ast_only_test, + dy2static_unittest, + test_and_compare_with_new_ir, +) from test_fetch_feed import Linear import paddle @@ -55,6 +59,7 @@ def forward_post_hook_for_prim_net(layer, input, output): return output * 2 +@dy2static_unittest class TestDyToStaticSaveLoad(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() diff --git a/test/dygraph_to_static/test_se_resnet.py b/test/dygraph_to_static/test_se_resnet.py index c12990b53659d..15c021d29ad87 100644 --- a/test/dygraph_to_static/test_se_resnet.py +++ b/test/dygraph_to_static/test_se_resnet.py @@ -20,7 +20,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test +from dygraph_to_static_util import ast_only_test, dy2static_unittest from predictor_utils import PredictorTools import paddle @@ -346,6 +346,7 @@ def forward(self, inputs, label): return out, avg_loss, acc_top1, acc_top5 +@dy2static_unittest class TestSeResnet(unittest.TestCase): def setUp(self): self.train_reader = paddle.batch( diff --git a/test/dygraph_to_static/test_sentiment.py b/test/dygraph_to_static/test_sentiment.py index 22bb980cd437f..60d3678a5a72b 100644 --- a/test/dygraph_to_static/test_sentiment.py +++ b/test/dygraph_to_static/test_sentiment.py @@ -15,7 +15,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) from test_lac import DynamicGRU import paddle @@ -369,6 +372,7 @@ def train(args, to_static): return loss_data +@dy2static_unittest class TestSentiment(unittest.TestCase): def setUp(self): self.args = 
Args() diff --git a/test/dygraph_to_static/test_seq2seq.py b/test/dygraph_to_static/test_seq2seq.py index 85de170c3f06c..b97752d4c57cb 100644 --- a/test/dygraph_to_static/test_seq2seq.py +++ b/test/dygraph_to_static/test_seq2seq.py @@ -18,6 +18,7 @@ import unittest import numpy as np +from dygraph_to_static_util import dy2static_unittest from seq2seq_dygraph_model import AttentionModel, BaseModel from seq2seq_utils import Seq2SeqModelHyperParams, get_data_iter @@ -174,6 +175,7 @@ def infer(args, attn_model=False): return outputs.numpy() +@dy2static_unittest class TestSeq2seq(unittest.TestCase): def setUp(self): self.args = Seq2SeqModelHyperParams diff --git a/test/dygraph_to_static/test_simnet.py b/test/dygraph_to_static/test_simnet.py index 7d6cad6d03381..90dce27f87eef 100644 --- a/test/dygraph_to_static/test_simnet.py +++ b/test/dygraph_to_static/test_simnet.py @@ -17,7 +17,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) from simnet_dygraph_model import BOW, HingeLoss import paddle @@ -176,8 +179,9 @@ def train(conf_dict, to_static): return losses +@dy2static_unittest class TestSimnet(unittest.TestCase): - @test_and_compare_with_new_ir(True) + @test_and_compare_with_new_ir(False) def test_dygraph_static_same_loss(self): if base.is_compiled_with_cuda(): base.set_flags({"FLAGS_cudnn_deterministic": True}) diff --git a/test/dygraph_to_static/test_simnet_v2.py b/test/dygraph_to_static/test_simnet_v2.py index a54cfe14dcbf8..16fccfd731be0 100644 --- a/test/dygraph_to_static/test_simnet_v2.py +++ b/test/dygraph_to_static/test_simnet_v2.py @@ -17,7 +17,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) from simnet_dygraph_model_v2 import BOW, HingeLoss import paddle @@ -176,8 +179,9 @@ def train(conf_dict, to_static): return losses +@dy2static_unittest class TestSimnet(unittest.TestCase): - @test_and_compare_with_new_ir(True) + @test_and_compare_with_new_ir(False) def test_dygraph_static_same_loss(self): if paddle.is_compiled_with_cuda(): paddle.base.set_flags({"FLAGS_cudnn_deterministic": True}) diff --git a/test/dygraph_to_static/test_slice.py b/test/dygraph_to_static/test_slice.py index e66080a2c687f..3bd4c5f8a2c83 100644 --- a/test/dygraph_to_static/test_slice.py +++ b/test/dygraph_to_static/test_slice.py @@ -17,7 +17,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test +from dygraph_to_static_util import ast_only_test, dy2static_unittest import paddle from paddle.static import InputSpec @@ -108,6 +108,7 @@ def forward(self, x): return x +@dy2static_unittest class TestSliceWithoutControlFlow(unittest.TestCase): def setUp(self): self.init_input() @@ -169,6 +170,7 @@ def init_dygraph_func(self): self.dygraph_func = test_set_value +@dy2static_unittest class TestSetValueWithLayerAndSave(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() @@ -189,6 +191,7 @@ def test_set_value_with_save(self): ) +@dy2static_unittest class TestSliceSupplementSpecialCase(unittest.TestCase): # unittest for slice index which abs(step)>0. 
eg: x[::2] def test_static_slice_step(self): @@ -232,6 +235,7 @@ def func(inps): ) +@dy2static_unittest class TestPaddleStridedSlice(unittest.TestCase): def test_compare_paddle_strided_slice_with_numpy(self): paddle.disable_static() @@ -293,6 +297,7 @@ def slice_zero_shape_tensor(x): return y +@dy2static_unittest class TestSliceZeroShapeTensor(unittest.TestCase): def test_slice(self): paddle.disable_static() diff --git a/test/dygraph_to_static/test_spec_names.py b/test/dygraph_to_static/test_spec_names.py index 86fe69c507631..72ffdc845134a 100644 --- a/test/dygraph_to_static/test_spec_names.py +++ b/test/dygraph_to_static/test_spec_names.py @@ -14,8 +14,9 @@ import unittest -from dygraph_to_static_util import ( - enable_fallback_guard, +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + ast_only_test, test_and_compare_with_new_ir, ) @@ -40,7 +41,7 @@ def forward(self, x, y, m, n): return paddle.sum(out) -class TestArgsSpecName(unittest.TestCase): +class TestArgsSpecName(Dy2StTestBase): def read_from_dataset(self): self.x = paddle.randn([4, 2, 8]) self.y = paddle.randn([4, 2, 8]) @@ -48,6 +49,7 @@ def read_from_dataset(self): self.n = paddle.randn([4, 2, 8]) @test_and_compare_with_new_ir(False) + @ast_only_test def test_spec_name_hash(self): net = Net() net = paddle.jit.to_static(net) @@ -90,5 +92,4 @@ def run_test(self, net, inputs, trace_count, mode): if __name__ == '__main__': - with enable_fallback_guard("False"): - unittest.main() + unittest.main() diff --git a/test/dygraph_to_static/test_tensor_hook.py b/test/dygraph_to_static/test_tensor_hook.py index fc53fefc95ae6..06b1b288ad899 100644 --- a/test/dygraph_to_static/test_tensor_hook.py +++ b/test/dygraph_to_static/test_tensor_hook.py @@ -15,12 +15,14 @@ import unittest import numpy as np +from dygraph_to_static_util import dy2static_unittest import paddle from paddle import nn from paddle.jit import to_static +@dy2static_unittest class TestStaticAnalysis(unittest.TestCase): def test_hook_for_different_parameter(self): def f(x): diff --git a/test/dygraph_to_static/test_tensor_methods.py b/test/dygraph_to_static/test_tensor_methods.py index 6e1ae1a3ffc0e..65981d65825a4 100644 --- a/test/dygraph_to_static/test_tensor_methods.py +++ b/test/dygraph_to_static/test_tensor_methods.py @@ -15,7 +15,11 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test, test_and_compare_with_new_ir +from dygraph_to_static_util import ( + ast_only_test, + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle @@ -27,6 +31,7 @@ def tensor_clone(x): return y +@dy2static_unittest class TestTensorClone(unittest.TestCase): def _run(self, to_static): paddle.jit.enable_to_static(to_static) @@ -48,6 +53,7 @@ def tensor_numpy(x): return x +@dy2static_unittest class TestTensorDygraphOnlyMethodError(unittest.TestCase): def _run(self, to_static): paddle.jit.enable_to_static(to_static) @@ -71,6 +77,7 @@ def tensor_item(x): return y.item() +@dy2static_unittest class TestTensorItem(unittest.TestCase): def _run(self, to_static): paddle.jit.enable_to_static(to_static) @@ -95,6 +102,7 @@ def tensor_size(x): return y +@dy2static_unittest class TestTensorSize(unittest.TestCase): def _run(self, to_static): paddle.jit.enable_to_static(to_static) @@ -120,6 +128,7 @@ def true_div(x, y): return z +@dy2static_unittest class TestTrueDiv(unittest.TestCase): def _run(self, to_static): paddle.jit.enable_to_static(to_static) diff --git a/test/dygraph_to_static/test_tensor_shape.py b/test/dygraph_to_static/test_tensor_shape.py index 
ad85daf7b0f78..d8c13cff35193 100644 --- a/test/dygraph_to_static/test_tensor_shape.py +++ b/test/dygraph_to_static/test_tensor_shape.py @@ -15,9 +15,9 @@ import unittest import numpy as np -from dygraph_to_static_util import ( +from dygraph_to_static_utils_new import ( + Dy2StTestBase, ast_only_test, - dy2static_unittest, test_and_compare_with_new_ir, ) @@ -235,8 +235,7 @@ def dyfunc_dict_assign_shape(): # 1. Basic tests without control flow -@dy2static_unittest -class TestTensorShapeBasic(unittest.TestCase): +class TestTensorShapeBasic(Dy2StTestBase): def setUp(self): self.input = np.ones(5).astype("int32") self.place = ( @@ -495,7 +494,7 @@ def _set_expected_op_num(self): # 5. Test op num for negative dim -class TestOpNumBasicWithTensorShape(unittest.TestCase): +class TestOpNumBasicWithTensorShape(Dy2StTestBase): def setUp(self): self._set_input_spec() self._set_test_func() @@ -617,7 +616,7 @@ def dyfunc_with_static_convert_var_shape(x): return res -class TestFindStatiConvertVarShapeSuffixVar(unittest.TestCase): +class TestFindStatiConvertVarShapeSuffixVar(Dy2StTestBase): @ast_only_test def test(self): x_spec = paddle.static.InputSpec(shape=[None, 10]) diff --git a/test/dygraph_to_static/test_to_tensor.py b/test/dygraph_to_static/test_to_tensor.py index ee33d56187efa..b211e09254ede 100644 --- a/test/dygraph_to_static/test_to_tensor.py +++ b/test/dygraph_to_static/test_to_tensor.py @@ -96,6 +96,10 @@ def case8(x): return a +def case_to_tensor_default_dtype(): + return paddle.to_tensor(1) + + @dy2static_unittest class TestToTensorReturnVal(unittest.TestCase): def test_to_tensor_badreturn(self): @@ -150,6 +154,13 @@ def test_to_tensor_badreturn(self): self.assertTrue(a.stop_gradient == b.stop_gradient) self.assertTrue(a.place._equals(b.place)) + def test_to_tensor_default_dtype(self): + a = paddle.jit.to_static(case_to_tensor_default_dtype)() + b = case_to_tensor_default_dtype() + self.assertTrue(a.dtype == b.dtype) + self.assertTrue(a.stop_gradient == b.stop_gradient) + self.assertTrue(a.place._equals(b.place)) + def test_to_tensor_err_log(self): paddle.disable_static() x = paddle.to_tensor([3]) @@ -162,6 +173,7 @@ def test_to_tensor_err_log(self): ) +@dy2static_unittest class TestStatic(unittest.TestCase): def test_static(self): paddle.enable_static() diff --git a/test/dygraph_to_static/test_transformer.py b/test/dygraph_to_static/test_transformer.py index 073535371ccde..29dda3916f3ab 100644 --- a/test/dygraph_to_static/test_transformer.py +++ b/test/dygraph_to_static/test_transformer.py @@ -20,7 +20,10 @@ import numpy as np import transformer_util as util -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) from transformer_dygraph_model import ( CrossEntropyCriterion, Transformer, @@ -527,6 +530,7 @@ def predict_static(args, batch_generator): return seq_ids, seq_scores +@dy2static_unittest class TestTransformer(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() diff --git a/test/dygraph_to_static/test_tsm.py b/test/dygraph_to_static/test_tsm.py index e68406bd4c9ab..2cef9e7df4ded 100644 --- a/test/dygraph_to_static/test_tsm.py +++ b/test/dygraph_to_static/test_tsm.py @@ -19,7 +19,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) from tsm_config_utils import merge_configs, parse_config, print_configs 
import paddle @@ -384,6 +387,7 @@ def train(args, fake_data_reader, to_static): return ret +@dy2static_unittest class TestTsm(unittest.TestCase): @test_and_compare_with_new_ir(False) def test_dygraph_static_same_loss(self): diff --git a/test/dygraph_to_static/test_typehint.py b/test/dygraph_to_static/test_typehint.py index b37a3539e2254..563db1d7a1df0 100644 --- a/test/dygraph_to_static/test_typehint.py +++ b/test/dygraph_to_static/test_typehint.py @@ -15,7 +15,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle from paddle import base @@ -33,6 +36,7 @@ def function(x: A) -> A: return 2 * x +@dy2static_unittest class TestTransformWhileLoop(unittest.TestCase): def setUp(self): self.place = ( diff --git a/test/dygraph_to_static/test_unuseful_inputs.py b/test/dygraph_to_static/test_unuseful_inputs.py index 603ffe9eba12d..8f83f015db431 100644 --- a/test/dygraph_to_static/test_unuseful_inputs.py +++ b/test/dygraph_to_static/test_unuseful_inputs.py @@ -15,7 +15,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle from paddle import nn @@ -62,6 +65,7 @@ def forward(self, x): return val +@dy2static_unittest class TestDuplicateOutput(unittest.TestCase): """ TestCase for the transformation from control flow `if/else` diff --git a/test/dygraph_to_static/test_utils.py b/test/dygraph_to_static/test_utils.py index 3361a866feb54..180078c144829 100644 --- a/test/dygraph_to_static/test_utils.py +++ b/test/dygraph_to_static/test_utils.py @@ -15,9 +15,12 @@ import types import unittest +from dygraph_to_static_util import dy2static_unittest + from paddle.jit.dy2static.utils import index_in_list, is_paddle_func +@dy2static_unittest class TestIndexInList(unittest.TestCase): def test_index_in_list(self): list_to_test = [1, 2, 3, 4, 5] @@ -49,6 +52,7 @@ def dyfunc_assign(input): y = n +@dy2static_unittest class TestIsPaddle(unittest.TestCase): def fake_module(self): return types.ModuleType('paddlenlp') diff --git a/test/dygraph_to_static/test_variable_trans_func.py b/test/dygraph_to_static/test_variable_trans_func.py index f2395fa517793..0ca73fbf9dd75 100644 --- a/test/dygraph_to_static/test_variable_trans_func.py +++ b/test/dygraph_to_static/test_variable_trans_func.py @@ -14,10 +14,13 @@ import unittest +from dygraph_to_static_util import dy2static_unittest + from paddle.jit.dy2static.utils import ast_to_source_code from paddle.jit.dy2static.variable_trans_func import create_fill_constant_node +@dy2static_unittest class TestVariableTransFunc(unittest.TestCase): def test_create_fill_constant_node(self): node = create_fill_constant_node("a", 1.0) diff --git a/test/dygraph_to_static/test_word2vec.py b/test/dygraph_to_static/test_word2vec.py index 85edea2093d82..0f16f5b2a9d23 100644 --- a/test/dygraph_to_static/test_word2vec.py +++ b/test/dygraph_to_static/test_word2vec.py @@ -17,7 +17,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle from paddle import base @@ -318,6 +321,7 @@ def train(to_static): return np.array(ret) +@dy2static_unittest class TestWord2Vec(unittest.TestCase): @test_and_compare_with_new_ir(False) def 
test_dygraph_static_same_loss(self): diff --git a/test/dygraph_to_static/test_yolov3.py b/test/dygraph_to_static/test_yolov3.py index 3f31b666c7f31..12830ca7bce55 100644 --- a/test/dygraph_to_static/test_yolov3.py +++ b/test/dygraph_to_static/test_yolov3.py @@ -17,7 +17,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) from yolov3 import YOLOv3, cfg import paddle @@ -165,6 +168,7 @@ def train(to_static): return np.array(ret) +@dy2static_unittest class TestYolov3(unittest.TestCase): @test_and_compare_with_new_ir(False) def test_dygraph_static_same_loss(self): diff --git a/test/sot/extract_errors.py b/test/sot/extract_errors.py new file mode 100644 index 0000000000000..b9d9e505724ef --- /dev/null +++ b/test/sot/extract_errors.py @@ -0,0 +1,30 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import sys + +runtime_error_msg = sys.stdin.read() + +pattern = r'File "?(.*?)"?, line (\d+),.*\n(.*?)\n(.*?)$' +for match in re.finditer(pattern, runtime_error_msg, re.MULTILINE): + file = match.group(1) + if file.startswith("./"): + file = f"tests/{file[2:]}" + line = match.group(2) + error_info = match.group(4) + if "AssertionError" not in error_info: + # error_info = match.group(3) + '\n' + match.group(4) + output = f"::error file={file},line={line}::Error" + print(output) diff --git a/test/sot/test_01_basic.py b/test/sot/test_01_basic.py new file mode 100644 index 0000000000000..8a03ea9fd3ae5 --- /dev/null +++ b/test/sot/test_01_basic.py @@ -0,0 +1,55 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
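+
+# Smoke test for the SOT executor. `assert_results` (provided by
+# `TestCaseBase` in test_case_base) runs the target function both in plain
+# dygraph and through the opcode translator, then checks that the outputs
+# match. `numpy_add` leaves the traced graph via `Tensor.numpy()`, so its
+# test runs under `strict_mode_guard(0)`, which is expected to tolerate the
+# resulting break graph instead of treating it as a failure.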
+ +import unittest + +from test_case_base import TestCaseBase, strict_mode_guard + +import paddle + + +def foo(x: int, y: paddle.Tensor): + return x + y + + +class TestExecutor(TestCaseBase): + def test_simple(self): + self.assert_results(foo, 1, paddle.to_tensor(2)) + + +def numpy_add(x, y): + out = paddle.to_tensor(x.numpy() + y.numpy()) + return out + + +class TestNumpyAdd(TestCaseBase): + @strict_mode_guard(0) + def test_numpy_add(self): + x = paddle.to_tensor([2]) + y = paddle.to_tensor([3]) + self.assert_results(numpy_add, x, y) + + +if __name__ == "__main__": + unittest.main() + + +# Instructions: +# LOAD_FAST +# BINARY_ADD +# RETURN_VALUE + +# Variables: +# ConstantVariable +# TensorVariable diff --git a/test/sot/test_02_store_inplace.py b/test/sot/test_02_store_inplace.py new file mode 100644 index 0000000000000..3c9b4df4602a0 --- /dev/null +++ b/test/sot/test_02_store_inplace.py @@ -0,0 +1,47 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def foo(x: int, y: paddle.Tensor): + x = x + 1 + y = y + 1 + x += y + return x + + +class TestStoreInplace(TestCaseBase): + def test_simple(self): + self.assert_results(foo, 1, paddle.to_tensor(2)) + + +if __name__ == "__main__": + unittest.main() + + +# Instructions: +# LOAD_FAST +# BINARY_ADD +# STORE_FAST (new) +# INPLACE_ADD (new) +# RETURN_VALUE + +# Variables: +# ConstantVariable +# TensorVariable diff --git a/test/sot/test_03_tuple.py b/test/sot/test_03_tuple.py new file mode 100644 index 0000000000000..797d54384714d --- /dev/null +++ b/test/sot/test_03_tuple.py @@ -0,0 +1,91 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
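+
+# `check_no_breakgraph` (from paddle.jit.sot.psdb) asserts that the decorated
+# function is translated as a single graph, i.e. without falling back to
+# eager execution. The Tensor variants of `count`/`index` are deliberately
+# left undecorated, since comparing Tensors inside these methods may trigger
+# a break graph.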
+ +# New Supported Instructions: +# BUILD_TUPLE +# BINARY_SUBSCR + +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + +import paddle +from paddle.jit.sot.psdb import check_no_breakgraph + + +@check_no_breakgraph +def build_tuple(x: int, y: paddle.Tensor): + x = (x, y) + return x[1] + 1 + + +@check_no_breakgraph +def build_tuple_with_slice_subscript(x: int, y: paddle.Tensor): + z = (x, y, 3, 4) + return z[0:5:1] + + +@check_no_breakgraph +def build_tuple_with_int_subscript(x: int, y: paddle.Tensor): + z = (x, y) + return z[0] + + +@check_no_breakgraph +def tuple_count_int(x: int, y: paddle.Tensor): + z = (x, x, 2, 1) + return z.count(x) + + +def tuple_count_tensor(x: paddle.Tensor, y: tuple[paddle.Tensor]): + return y.count(x) + + +@check_no_breakgraph +def tuple_index_int(x: int, y: paddle.Tensor): + z = (x, y, x, y, y) + return z.index(x) + + +def tuple_index_tensor(x: paddle.Tensor, y: tuple[paddle.Tensor]): + return y.index(x) + + +class TestBuildTuple(TestCaseBase): + def test_build_tuple(self): + self.assert_results(build_tuple, 1, paddle.to_tensor(2)) + self.assert_results( + build_tuple_with_slice_subscript, 1, paddle.to_tensor(2) + ) + self.assert_results( + build_tuple_with_int_subscript, 1, paddle.to_tensor(2) + ) + + +class TestTupleMethods(TestCaseBase): + def test_tuple_methods_int(self): + self.assert_results(tuple_count_int, 1, paddle.to_tensor(2)) + self.assert_results(tuple_index_int, 1, paddle.to_tensor(2)) + + def test_tuple_methods_tensor(self): + a = paddle.to_tensor(1) + b = paddle.to_tensor(2) + self.assert_results(tuple_count_tensor, a, (a, b, a, b)) + self.assert_results(tuple_index_tensor, b, (b, b, b, a)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_04_list.py b/test/sot/test_04_list.py new file mode 100644 index 0000000000000..d8b0823a279c2 --- /dev/null +++ b/test/sot/test_04_list.py @@ -0,0 +1,327 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
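The opcode names listed in these file headers can be reproduced with the standard dis module; names differ across CPython versions (BINARY_ADD and friends fold into BINARY_OP on 3.11+), so treat the comments as indicative rather than exact.

import dis

def build_tuple_demo(x, y):
    z = (x, y)
    return z[1] + 1

# On CPython 3.8-3.10 this listing contains BUILD_TUPLE and
# BINARY_SUBSCR, the instructions named at the top of test_03_tuple.py.
dis.dis(build_tuple_demo)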
+
+# New Supported Instructions:
+# BUILD_LIST (new)
+# BINARY_SUBSCR
+# DELETE_SUBSCR
+
+from __future__ import annotations
+
+import unittest
+
+from test_case_base import TestCaseBase
+
+import paddle
+from paddle.jit.sot.psdb import check_no_breakgraph
+
+
+@check_no_breakgraph
+def list_getitem_int(x: int, y: paddle.Tensor):
+    x = [x, y]
+    return x[0] + 1
+
+
+@check_no_breakgraph
+def list_getitem_tensor(x: int, y: paddle.Tensor):
+    x = [x, y]
+    return x[1] + 1
+
+
+@check_no_breakgraph
+def list_setitem_int(x: int, y: paddle.Tensor):
+    z = [x, y]
+    z[0] = 3
+    return z
+
+
+def list_setitem_tensor(x: int, y: paddle.Tensor):
+    z = [x, y]
+    z[1] = paddle.to_tensor(3)
+    return z
+
+
+@check_no_breakgraph
+def list_delitem_int(x: int, y: paddle.Tensor):
+    z = [x, y]
+    del z[0]
+    return z
+
+
+@check_no_breakgraph
+def list_delitem_tensor(x: int, y: paddle.Tensor):
+    z = [x, y]
+    del z[1]
+    return z
+
+
+@check_no_breakgraph
+def list_construct_from_list(x: int, y: paddle.Tensor):
+    z = [x, y]
+    return z
+
+
+@check_no_breakgraph
+def list_append_int(x: int, y: paddle.Tensor):
+    z = [x, y]
+    z.append(3)
+    return z
+
+
+@check_no_breakgraph
+def list_append_tensor(x: int, y: paddle.Tensor):
+    z = [x, y]
+    z.append(y)
+    return z
+
+
+@check_no_breakgraph
+def list_clear(x: int, y: paddle.Tensor):
+    z = [x, y]
+    z.clear()
+    return z
+
+
+@check_no_breakgraph
+def list_copy(x: int, y: paddle.Tensor):
+    z = [x, y]
+    a = z.copy()
+    z[0] = 3
+    z[1] = y + 1
+    return (a, z)
+
+
+@check_no_breakgraph
+def list_count_int(x: int, y: paddle.Tensor):
+    z = [x, x, 2, 3, 1]
+    return z.count(x)
+
+
+def list_count_tensor(x: paddle.Tensor, y: list[paddle.Tensor]):
+    return y.count(x)
+
+
+@check_no_breakgraph
+def list_extend(x: int, y: paddle.Tensor):
+    z = [x, y]
+    a = [y, x]
+    b = (x, y)
+    z.extend(a)
+    z.extend(b)
+    return z
+
+
+@check_no_breakgraph
+def list_index_int(x: int, y: paddle.Tensor):
+    z = [x, x, 1, 2]
+    return z.index(x)
+
+
+def list_index_tensor(x: paddle.Tensor, y: list[paddle.Tensor]):
+    return y.index(x)
+
+
+@check_no_breakgraph
+def list_insert(x: int, y: paddle.Tensor):
+    z = [x, y]
+    z.insert(0, x)
+    z.insert(3, y)
+    return z
+
+
+@check_no_breakgraph
+def list_pop(x: int, y: paddle.Tensor):
+    z = [x, y]
+    a = z.pop()
+    b = z.pop()
+    return (z, a, b)
+
+
+@check_no_breakgraph
+def list_remove(x: int, y: paddle.Tensor):
+    z = [x, x, y, y]
+    z.remove(x)
+    z.remove(y)
+    return z
+
+
+@check_no_breakgraph
+def list_reverse(x: int, y: paddle.Tensor):
+    z = [x, x, y, y]
+    z.reverse()
+    return z
+
+
+@check_no_breakgraph
+def list_default_sort(x: int, y: paddle.Tensor):
+    z = [x + 2, x, x + 1]
+    z.sort()
+    return z
+
+
+@check_no_breakgraph
+def list_key_sort(x: int, y: paddle.Tensor):
+    z = [x + 2, x, x + 1]
+    z.sort(key=lambda x: x)
+    return z
+
+
+@check_no_breakgraph
+def list_reverse_sort(x: int, y: paddle.Tensor):
+    z = [x + 2, x, x + 1]
+    z.sort(reverse=True)
+    return z
+
+
+@check_no_breakgraph
+def list_tensor_sort(x: int, y: paddle.Tensor):
+    z = [y + 2, y, y + 1]
+    z.sort()
+    return z
+
+
+@check_no_breakgraph
+def list_max(x: paddle.Tensor | int, y: paddle.Tensor | int):
+    z = [x, x, y]
+    return max(z)
+
+
+@check_no_breakgraph
+def list_tensor_max_api(x: paddle.Tensor):
+    return x.max()
+
+
+@check_no_breakgraph
+def list_min(x: paddle.Tensor | int, y: paddle.Tensor | int):
+    z = [x, x, y]
+    return min(z)
+
+
+@check_no_breakgraph
+def list_tensor_min_api(x: paddle.Tensor):
+    return x.min()
+
+
+@check_no_breakgraph
+def list_no_arguments():
+    l1 = list()  # noqa: C408
+    l1.append(1)
+    l2 = list()  # noqa: C408
+    l2.append(2)
+    return l1[0] + l2[0]
+
+
+class TestListBasic(TestCaseBase):
+    def test_list_basic(self):
+        self.assert_results(list_getitem_int, 1, paddle.to_tensor(2))
+        self.assert_results(list_getitem_tensor, 1, paddle.to_tensor(2))
+        self.assert_results_with_side_effects(
+            list_setitem_int, 1, paddle.to_tensor(2)
+        )
+
+
+class TestListMethods(TestCaseBase):
+    def test_list_setitem(self):
+        self.assert_results_with_side_effects(
+            list_setitem_tensor, 1, paddle.to_tensor(2)
+        )
+
+    def test_list_count_and_index(self):
+        self.assert_results(list_count_int, 1, paddle.to_tensor(2))
+        self.assert_results(list_index_int, 1, paddle.to_tensor(2))
+        a = paddle.to_tensor(1)
+        b = paddle.to_tensor(2)
+        self.assert_results(list_count_tensor, a, [a, b, a, b, a, b])
+        self.assert_results(list_index_tensor, b, [a, b, a, b, a, b])
+
+    def test_list_delitem(self):
+        self.assert_results_with_side_effects(
+            list_delitem_int, 1, paddle.to_tensor(2)
+        )
+        self.assert_results_with_side_effects(
+            list_delitem_tensor, 1, paddle.to_tensor(2)
+        )
+
+    def test_list_append(self):
+        self.assert_results_with_side_effects(
+            list_append_int, 1, paddle.to_tensor(2)
+        )
+        self.assert_results_with_side_effects(
+            list_append_tensor, 1, paddle.to_tensor(2)
+        )
+
+    def test_list_clear(self):
+        self.assert_results_with_side_effects(
+            list_clear, 1, paddle.to_tensor(2)
+        )
+
+    def test_list_copy(self):
+        self.assert_results_with_side_effects(list_copy, 1, paddle.to_tensor(2))
+
+    def test_list_extend(self):
+        self.assert_results_with_side_effects(
+            list_extend, 1, paddle.to_tensor(2)
+        )
+
+    def test_list_insert(self):
+        self.assert_results_with_side_effects(
+            list_insert, 1, paddle.to_tensor(2)
+        )
+
+    def test_list_pop(self):
+        self.assert_results_with_side_effects(list_pop, 1, paddle.to_tensor(2))
+
+    def test_list_remove(self):
+        self.assert_results_with_side_effects(
+            list_remove, 1, paddle.to_tensor(2)
+        )
+
+    def test_list_reverse(self):
+        self.assert_results_with_side_effects(
+            list_reverse, 1, paddle.to_tensor(2)
+        )
+        self.assert_results_with_side_effects(
+            list_reverse, 1, paddle.to_tensor(2)
+        )
+
+    def test_list_sort(self):
+        self.assert_results_with_side_effects(
+            list_default_sort, 1, paddle.to_tensor(2)
+        )
+        # TODO: Not currently supported
+        # self.assert_results_with_side_effects(
+        #     list_tensor_sort, 1, paddle.to_tensor(2)
+        # )
+        # self.assert_results_with_side_effects(
+        #     list_key_sort, 1, paddle.to_tensor(2)
+        # )
+        # self.assert_results_with_side_effects(
+        #     list_reverse_sort, 1, paddle.to_tensor(2)
+        # )
+
+    def test_list_construct_from_list(self):
+        self.assert_results(list_construct_from_list, 1, paddle.to_tensor(2))
+
+    def test_list_max_min(self):
+        self.assert_results(list_max, 1, 2)
+        self.assert_results(list_min, 1, 2)
+        self.assert_results(list_tensor_max_api, paddle.to_tensor([1, 2, 3]))
+        self.assert_results(list_tensor_min_api, paddle.to_tensor([1, 2, 3]))
+
+    def test_list_noargs(self):
+        self.assert_results(list_no_arguments)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/sot/test_05_dict.py b/test/sot/test_05_dict.py
new file mode 100644
index 0000000000000..7014a71746798
--- /dev/null
+++ b/test/sot/test_05_dict.py
@@ -0,0 +1,264 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# New Supported Instructions: +# BUILD_MAP (new) +# BUILD_CONST_KEY_MAP (new) + +import unittest + +from test_case_base import TestCaseBase + +import paddle +from paddle.jit.sot.psdb import check_no_breakgraph + + +@check_no_breakgraph +def build_map(x: int, y: paddle.Tensor): + z = {x: y} + return z[x] + 1 + + +@check_no_breakgraph +def build_const_key_map(x: int, y: paddle.Tensor): + z = {1: y, 2: y + 1} + return z[x] + 1 + + +@check_no_breakgraph +def dict_get_item(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1} + return (z.get(1), z.get(2)) + + +@check_no_breakgraph +def dict_get_item_default(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1} + return (z.get(3, 2), z.get(4, y)) + + +@check_no_breakgraph +def dict_set_item_int(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1} + z[1] = x * 2 + return z[1] + + +@check_no_breakgraph +def dict_set_item_tensor(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1} + z[2] = y + return z[1] + + +@check_no_breakgraph +def dict_update_item1(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1} + z.update({1: x * 2, 2: y, 3: y + 2}) + return z + + +@check_no_breakgraph +def dict_update_item2(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1} + z.update({1: x * 2, 2: y, 3: z[2] + 2}) + return z + + +@check_no_breakgraph +def dict_del_item_int(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1} + del z[1] + return z + + +@check_no_breakgraph +def dict_del_item_tensor(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1} + del z[2] + return z + + +@check_no_breakgraph +def dict_clear(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1} + z.clear() + return z + + +@check_no_breakgraph +def dict_copy(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1} + z2 = z.copy() + z[1] = 2 + return z2 + + +@check_no_breakgraph +def dict_setdefault_int(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1} + a = z.setdefault(4) + b = z.setdefault(1, 2) + c = z.setdefault(3, 4) + return (z, a, b, c) + + +@check_no_breakgraph +def dict_pop(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1, 3: y} + a = z.pop(1) + b = z.pop(2, 3) + c = z.pop(4, 3) + d = z.pop(5, y) + return (z, a, b, c, d) + + +@check_no_breakgraph +def dict_popitem(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1, 3: y} + a = z.popitem() + return (z, a) + + +@check_no_breakgraph +def dict_construct_from_dict(): + x = {1: 2, 3: 4} + d = dict(x) + return d + + +@check_no_breakgraph +def dict_construct_from_list(): + x = [[1, 2], [3, 4]] + d = dict(x) + return d + + +@check_no_breakgraph +def dict_construct_from_tuple(): + x = ((1, 2), (3, 4)) + d = dict(x) + return d + + +@check_no_breakgraph +def dict_construct_from_comprehension(): + z = {1: 2, 3: 4} + d = {k: v + 1 for k, v in z.items()} + return d + + +@check_no_breakgraph +def dict_no_arguments(): + d1 = dict() # noqa: C408 + d1.update({1: 2}) + d2 = dict() # noqa: C408 + d2.update({3: 4}) + return d1[1] + d2[3] + + +@check_no_breakgraph +def dict_test_fromkeys(x): + d = dict.fromkeys(x) + return d + + +@check_no_breakgraph +def dict_test_fromkeys_defalut(x, y): + d = dict.fromkeys(x, y) + return d + + +class TestBuildDict(TestCaseBase): + def 
test_build_map(self): + self.assert_results(build_map, 1, paddle.to_tensor(2)) + + def test_build_const_key_map(self): + self.assert_results(build_const_key_map, 1, paddle.to_tensor(2)) + + +class TestDictMethods(TestCaseBase): + def test_dict_get_item(self): + self.assert_results(dict_get_item, 1, paddle.to_tensor(2)) + self.assert_results(dict_get_item_default, 1, paddle.to_tensor(2)) + + def test_dict_set_item(self): + self.assert_results_with_side_effects( + dict_set_item_int, 1, paddle.to_tensor(2) + ) + self.assert_results_with_side_effects( + dict_set_item_tensor, 1, paddle.to_tensor(2) + ) + + def test_dict_copy(self): + self.assert_results_with_side_effects(dict_copy, 1, paddle.to_tensor(2)) + + def test_dict_update(self): + self.assert_results_with_side_effects( + dict_update_item1, 1, paddle.to_tensor(2) + ) + self.assert_results_with_side_effects( + dict_update_item2, 1, paddle.to_tensor(2) + ) + + def test_dict_setdefault(self): + self.assert_results_with_side_effects( + dict_setdefault_int, 1, paddle.to_tensor(2) + ) + + def test_dict_del_item(self): + self.assert_results_with_side_effects( + dict_del_item_int, 1, paddle.to_tensor(2) + ) + self.assert_results_with_side_effects( + dict_del_item_tensor, 1, paddle.to_tensor(2) + ) + + def test_dict_clear(self): + self.assert_results_with_side_effects( + dict_clear, 1, paddle.to_tensor(2) + ) + + def test_dict_pop(self): + self.assert_results_with_side_effects(dict_pop, 1, paddle.to_tensor(2)) + + def test_dict_popitem(self): + self.assert_results_with_side_effects( + dict_popitem, 1, paddle.to_tensor(2) + ) + + def test_construct(self): + self.assert_results(dict_construct_from_dict) + self.assert_results(dict_construct_from_list) + self.assert_results(dict_construct_from_tuple) + self.assert_results(dict_construct_from_comprehension) + + def test_dict_noargs(self): + self.assert_results(dict_no_arguments) + + def test_dict_fromkeys(self): + self.assert_results(dict_test_fromkeys, (1, 2, 3, 4)) + self.assert_results(dict_test_fromkeys, [1, 2, 3, 4]) + self.assert_results(dict_test_fromkeys_defalut, (1, 2, 3, 4), 1) + self.assert_results( + dict_test_fromkeys_defalut, (1, 2, 3, 4), paddle.to_tensor(1) + ) + self.assert_results(dict_test_fromkeys_defalut, [1, 2, 3, 4], 1) + self.assert_results( + dict_test_fromkeys_defalut, [1, 2, 3, 4], paddle.to_tensor(1) + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_06_call_function.py b/test/sot/test_06_call_function.py new file mode 100644 index 0000000000000..4358afe6ca985 --- /dev/null +++ b/test/sot/test_06_call_function.py @@ -0,0 +1,153 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
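One behavior worth keeping in mind when reading dict_test_fromkeys_defalut above: dict.fromkeys stores the same value object under every key, so a mutable default such as a tensor is shared rather than copied.

import paddle

d = dict.fromkeys([1, 2, 3], paddle.to_tensor(1))
# Every key maps to one shared tensor object.
assert d[1] is d[2] is d[3]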
+ +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def add(x, y): + return x + y + + +def sub(x, y): + return x - y + + +def foo_1(x: paddle.Tensor): + m = x + 1 + y = add(m * 3, m * 2) + return y + + +def foo_2(x: paddle.Tensor): + m = x + 1 + y = sub(m * 3, m * 2) + return y + + +def foo_3(x: paddle.Tensor): + m = x + 1 + y = sub(m * 3, m * 2) + y = sub(y, y) + y = sub(y, y) + return y + + +def nest_2(x): + return x + 1 + + +def nest_1(x): + return (x - 1) * 2 + + +def foo_4(x: paddle.Tensor): + m = x + 1 + m = nest_1(m) + return m + + +def fn_with_varargs_and_kwargs(x, *args, **kwargs): + return ( + x + + args[0] + + args[1] + - args[2] + + kwargs['a'] * kwargs['b'] / kwargs['c'] + ) + + +def foo_5(x: paddle.Tensor): + m = x + 1 + m = fn_with_varargs_and_kwargs( + m, x + 1, x + 2, x + 3, a=x + 4, b=x + 5, c=x + 6 + ) + return m + + +def fn_with_default_value(x, y=1, z=2): + return x + y + z + + +def foo_6(x: paddle.Tensor): + m = x + 1 + m = fn_with_default_value(m, m + 10) + m = fn_with_default_value(m + 42) + return m + + +def fn_with_default_value_and_varargs_kwargs(x, y=1, *args, **kwargs): + return x + y + args[0] + kwargs['a'] + + +def foo_7(x: paddle.Tensor): + m = x + 1 + m = fn_with_default_value_and_varargs_kwargs(m, m + 1, m + 2, a=m + 3) + return m + + +def fn_with_default_value_and_varargs_kwargs_kwonly_1( + x, y=1, *args, z, **kwargs +): + return x + y + args[0] + kwargs['a'] + z + + +def fn_with_default_value_and_varargs_kwargs_kwonly_2( + x, y=1, *args, z=10, **kwargs +): + return x + y + args[0] + kwargs['a'] + z + + +def foo_8(x: paddle.Tensor): + m = x + 1 + m = fn_with_default_value_and_varargs_kwargs_kwonly_1( + m, m + 1, m + 2, a=m + 3, z=m + 4 + ) + m = fn_with_default_value_and_varargs_kwargs_kwonly_2( + m, m + 1, m + 2, a=m + 3 + ) + return m + + +class TestCall(TestCaseBase): + def test_call1(self): + self.assert_results(foo_1, paddle.to_tensor(2)) + + def test_call2(self): + self.assert_results(foo_2, paddle.to_tensor(3)) + + def test_call3(self): + self.assert_results(foo_3, paddle.to_tensor(4)) + + def test_call4(self): + self.assert_results(foo_4, paddle.to_tensor(5)) + + def test_call5(self): + self.assert_results(foo_5, paddle.to_tensor(6)) + + def test_call6(self): + self.assert_results(foo_6, paddle.to_tensor(7)) + + def test_call7(self): + self.assert_results(foo_7, paddle.to_tensor(8)) + + def test_call8(self): + self.assert_results(foo_8, paddle.to_tensor(9)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_07_unpack.py b/test/sot/test_07_unpack.py new file mode 100644 index 0000000000000..f04a185294b6f --- /dev/null +++ b/test/sot/test_07_unpack.py @@ -0,0 +1,70 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
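foo_8 above combines positional, variadic, keyword-only, and default parameters; the binding CPython performs can be checked with inspect.signature. The function below mirrors the test's signature shape but is defined here only for illustration.

import inspect

def kwonly(x, y=1, *args, z, **kwargs):
    return x + y + args[0] + kwargs['a'] + z

bound = inspect.signature(kwonly).bind(10, 11, 12, a=13, z=14)
bound.apply_defaults()
print(bound.arguments)
# {'x': 10, 'y': 11, 'args': (12,), 'z': 14, 'kwargs': {'a': 13}}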
+ +# New Supported Instructions: +# UNPACK_SEQUENCE (new) + +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def unpack_tuple(x: tuple[int, paddle.Tensor]): + y, z = x + return z + 1 + + +def unpack_tensor(x: paddle.Tensor): + a, b = x + return (a, b) + + +def unpack_ex_tuple(x: tuple[int, int, paddle.Tensor]): + *y, z = x + return z + 1 + + +def unpack_ex_tensor(x: paddle.Tensor): + a, b, *c = x + return (a, b) + + +def unpack_ex_tensor_2(x: paddle.Tensor): + a, *b, c, d = x + return (a, c) + + +class TestUnpack(TestCaseBase): + def test_unpack_tuple(self): + self.assert_results(unpack_tuple, (1, paddle.to_tensor(2))) + + def test_unpack_tensor(self): + self.assert_results(unpack_tensor, paddle.to_tensor([2, 3])) + + def test_unpack_ex_tuple(self): + self.assert_results(unpack_ex_tuple, (1, 1, paddle.to_tensor(2))) + + def test_unpack_ex_tensor(self): + self.assert_results(unpack_ex_tensor, paddle.to_tensor([2, 3, 3, 3])) + + def test_unpack_ex_tensor_2(self): + self.assert_results(unpack_ex_tensor_2, paddle.to_tensor([2, 3, 3, 3])) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_08_rot.py b/test/sot/test_08_rot.py new file mode 100644 index 0000000000000..2d9146e3ff3ba --- /dev/null +++ b/test/sot/test_08_rot.py @@ -0,0 +1,97 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
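The starred targets in test_07_unpack.py compile to UNPACK_EX rather than UNPACK_SEQUENCE; in plain Python terms, the starred name collects whatever the fixed targets leave over:

seq = [2, 3, 3, 3]
a, *b, c, d = seq
# The starred name receives a list with the leftover middle elements.
assert (a, b, c, d) == (2, [3], 3, 3)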
+ +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def rot_two_return_a(a: paddle.Tensor, b: paddle.Tensor): + b, a = a, b + return a + 1 + + +def rot_two_return_b(a: paddle.Tensor, b: paddle.Tensor): + b, a = a, b + return b + 2 + + +def rot_three_return_a(a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor): + a, b, c = c, b, a + return a + 1 + + +def rot_three_return_b(a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor): + a, b, c = c, b, a + return b + 1 + + +def rot_three_return_c(a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor): + a, b, c = c, b, a + return c + 1 + + +def rot_four_return_a( + a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor, d: paddle.Tensor +): + a, b, c, d = d, c, b, a + return a + 1 + + +def rot_four_return_b( + a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor, d: paddle.Tensor +): + a, b, c, d = d, c, b, a + return b + 1 + + +def rot_four_return_c( + a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor, d: paddle.Tensor +): + a, b, c, d = d, c, b, a + return c + 1 + + +def rot_four_return_d( + a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor, d: paddle.Tensor +): + a, b, c, d = d, c, b, a + return d + 1 + + +class TestExecutor(TestCaseBase): + def test_simple(self): + a = paddle.to_tensor(1) + b = paddle.to_tensor(2) + c = paddle.to_tensor(3) + d = paddle.to_tensor(4) + self.assert_results(rot_two_return_a, a, b) + self.assert_results(rot_two_return_b, a, b) + + self.assert_results(rot_three_return_a, a, b, c) + self.assert_results(rot_three_return_b, a, b, c) + self.assert_results(rot_three_return_c, a, b, c) + + self.assert_results(rot_four_return_a, a, b, c, d) + self.assert_results(rot_four_return_b, a, b, c, d) + self.assert_results(rot_four_return_c, a, b, c, d) + self.assert_results(rot_four_return_d, a, b, c, d) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_09_f_string.py b/test/sot/test_09_f_string.py new file mode 100644 index 0000000000000..c2a3b8144605b --- /dev/null +++ b/test/sot/test_09_f_string.py @@ -0,0 +1,41 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# FORMAT_VALUE (new) +# BUILD_STRING (new) +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + +import paddle +from paddle.jit.sot.psdb import assert_true + + +def foo(x: paddle.Tensor): + whilespace = 123 + hello_world = f"Hello {whilespace} World" + z = assert_true(hello_world == "Hello 123 World") + x = x + 1 + return x + + +class TestFString(TestCaseBase): + def test_fstring(self): + self.assert_results(foo, paddle.to_tensor(1)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_10_build_unpack.py b/test/sot/test_10_build_unpack.py new file mode 100644 index 0000000000000..0b35c46901863 --- /dev/null +++ b/test/sot/test_10_build_unpack.py @@ -0,0 +1,97 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# BUILD_TUPLE_UNPACK (new) +# BUILD_LIST_UNPACK (new) +# BUILD_TUPLE_UNPACK_WITH_CALL (new) +# CALL_FUNCTION_EX (new) +# BUILD_MAP_UNPACK (new) +# LIST_EXTEND (new) +# LIST_TO_TUPLE (new) +# DICT_UPDATE (new) +# DICT_MERGE (new) + +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def build_tuple_unpack(x: tuple[paddle.Tensor], y: tuple[paddle.Tensor]): + z = (*x, *y) + + return z[0] + 1 + + +def build_list_unpack(x: list[paddle.Tensor], y: list[paddle.Tensor]): + z = [*x, *y] + return z[0] + 1 + + +def build_tuple_unpack_with_call( + x: tuple[paddle.Tensor], y: tuple[paddle.Tensor] +): + z = build_tuple_unpack_with_call_inner(*x, *y) + return z[0] + 1 + + +def build_tuple_unpack_with_call_inner( + a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor, d: paddle.Tensor +): + z = (a, b, c, d) + return z + + +def build_map_unpack(x: dict[str, paddle.Tensor], y: dict[str, paddle.Tensor]): + z = {**x, **y} + return z["a"] + 1 + + +def build_map_unpack_with_call_inner( + a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor, d: paddle.Tensor +): + z = {"a": a, "b": b, "c": c, "d": d} + return z + + +def build_map_unpack_with_call( + x: dict[str, paddle.Tensor], y: dict[str, paddle.Tensor] +): + z = build_map_unpack_with_call_inner(**x, **y) + return z["a"] + 1 + + +class TestExecutor(TestCaseBase): + def test_simple(self): + a = paddle.to_tensor(1) + b = paddle.to_tensor(2) + c = paddle.to_tensor(3) + d = paddle.to_tensor(4) + + self.assert_results(build_tuple_unpack, (a, b), (c, d)) + self.assert_results(build_list_unpack, [a, b], [c, d]) + self.assert_results(build_tuple_unpack_with_call, (a, b), (c, d)) + self.assert_results( + build_map_unpack, {"a": a, "b": b}, {"c": c, "d": d} + ) + self.assert_results( + build_map_unpack_with_call, {"a": a, "b": b}, {"c": c, "d": d} + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_11_jumps.py b/test/sot/test_11_jumps.py new file mode 100644 index 0000000000000..80fa1f4a4eb02 --- /dev/null +++ b/test/sot/test_11_jumps.py @@ -0,0 +1,118 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
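A detail the DICT_UPDATE/DICT_MERGE pair exercised above must preserve: duplicate keys are legal in a dict display but are rejected when unpacking into a call.

x, y = {"a": 1}, {"a": 2}
print({**x, **y})  # {'a': 2} -- the later value wins
try:
    (lambda **kw: kw)(**x, **y)
except TypeError as e:
    print(e)  # got multiple values for keyword argument 'a'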
+ +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + +import paddle +from paddle.jit.sot.psdb import check_no_breakgraph + + +@check_no_breakgraph +def pop_jump_if_false(x: bool, y: paddle.Tensor): + if x: + y += 1 + else: + y -= 1 + return y + + +@check_no_breakgraph +def pop_jump_if_true(x: bool, y: bool, z: paddle.Tensor): + return (x or y) and z + + +@check_no_breakgraph +def jump_if_false_or_pop(x: bool, y: paddle.Tensor): + return x and (y + 1) + + +@check_no_breakgraph +def jump_if_true_or_pop(x: bool, y: paddle.Tensor): + return x or (y + 1) + + +@check_no_breakgraph +def jump_absolute(x: int, y: paddle.Tensor): + while x > 0: + y += 1 + x -= 1 + return y + + +@check_no_breakgraph +def pop_jump_if_none(x: bool, y: paddle.Tensor): + if x is not None: + y += 1 + else: + y -= 1 + return y + + +@check_no_breakgraph +def pop_jump_if_not_none(x: bool, y: paddle.Tensor): + if x is None: + y += 1 + else: + y -= 1 + return y + + +a = paddle.to_tensor(1) +b = paddle.to_tensor(2) +c = paddle.to_tensor(3) +d = paddle.to_tensor(4) + +true_tensor = paddle.to_tensor(True) +false_tensor = paddle.to_tensor(False) + + +class TestExecutor(TestCaseBase): + def test_simple(self): + self.assert_results(jump_absolute, 5, a) + + self.assert_results(pop_jump_if_false, True, a) + self.assert_results(pop_jump_if_false, False, a) + self.assert_results(jump_if_false_or_pop, True, a) + self.assert_results(jump_if_false_or_pop, False, a) + self.assert_results(jump_if_true_or_pop, True, a) + self.assert_results(jump_if_true_or_pop, False, a) + self.assert_results(pop_jump_if_true, True, False, a) + self.assert_results(pop_jump_if_true, False, False, a) + + self.assert_results(pop_jump_if_none, None, a) + self.assert_results(pop_jump_if_none, True, a) + self.assert_results(pop_jump_if_not_none, None, a) + self.assert_results(pop_jump_if_not_none, True, a) + + def test_breakgraph(self): + self.assert_results(pop_jump_if_false, true_tensor, a) + self.assert_results(jump_if_false_or_pop, true_tensor, a) + self.assert_results(jump_if_true_or_pop, false_tensor, a) + self.assert_results(pop_jump_if_true, true_tensor, false_tensor, a) + self.assert_results(jump_absolute, 5, a) + self.assert_results(pop_jump_if_false, false_tensor, a) + self.assert_results(jump_if_false_or_pop, false_tensor, a) + self.assert_results(jump_if_true_or_pop, false_tensor, a) + self.assert_results(pop_jump_if_true, true_tensor, false_tensor, a) + + self.assert_results(pop_jump_if_none, true_tensor, a) + self.assert_results(pop_jump_if_not_none, true_tensor, a) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_12_for_loop.py b/test/sot/test_12_for_loop.py new file mode 100644 index 0000000000000..63e3fedace4bf --- /dev/null +++ b/test/sot/test_12_for_loop.py @@ -0,0 +1,298 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
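jump_if_true_or_pop above relies on Python's short-circuiting: when the left operand of `or` is truthy, the right operand is never evaluated, which is exactly what the jump instruction implements.

def right_side(y):
    print("evaluated")  # never printed below
    return y + 1

result = True or right_side(0)
print(result)  # True, and "evaluated" was not printed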
+ +# GET_ITER (new) +# FOR_ITER (new) + +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase, strict_mode_guard + +import paddle +from paddle.jit import sot +from paddle.jit.sot import symbolic_translate +from paddle.jit.sot.opcode_translator.executor.executor_cache import ( + OpcodeExecutorCache, +) + + +def gener(): + yield 1 + yield 2 + yield 3 + + +def for_list_1(x: paddle.Tensor): + for i in [1, 2, 3]: + x += i + + if x > 2: + x += 1 + else: + x -= 1 + return x + + +def for_list_2(x: paddle.Tensor): + for i in [1, 2, 3]: + x += i + + if i > 2: + x += 1 + else: + x -= 1 + return x + + +def for_dict(x: paddle.Tensor): + map = {1: 2, 3: 4} + for k in map.keys(): + x += k + + for v in map.values(): + x += v + + for k, v in map.items(): + x += k + x += v + + return x + + +def for_iter(x, it): + for item in it: + x += item + return x + + +def for_for_fallback(x, it): + for i in [1, 2, 3]: + for item in it: + x += item + return x + + +def for_break(x: paddle.Tensor, it): + for i in [1, 2, 3]: + x += i + if i == 2: + break + for i in it: + x += i + if i == 2: + break + return x + + +def for_continue(x: paddle.Tensor, it): + for i in [1, 2, 3]: + if i == 2: + continue + x += i + + for i in it: + if i == 2: + continue + x += i + return x + + +def for_enumerate_var_with_nested_range(x_array): + x = paddle.tensor.fill_constant([1], 'int32', 0) + x_array = paddle.to_tensor(x_array) + for i, num in enumerate(x_array): + for idx in range(num): + x = x + num + return x + + +def for_create_tmp_in_loop(x, it): + s = x + for i in it: + tmp = i + s += tmp + return s, tmp + + +def for_without_zero_iter(self_res_dict, output): + res_dict = {"logits": output} + for res_key in list(self_res_dict): + res_dict[res_key] = self_res_dict.pop(res_key) + return res_dict + + +@sot.psdb.check_no_fallback +def for_reconstruct_range_iter(): + for i in range(3): + sot.psdb.breakgraph() + + +global_var_name = None + + +def for_tmp_var_with_same_name_as_global_var(): + total = 0 + for i in range(3): + global_var_name = i + 3 + sot.psdb.breakgraph() + total += global_var_name + return total + + +def for_layer_list(layer_list, x): + for net in layer_list: + x = net(x) + return x + + +class TestForLoop(TestCaseBase): + def test_list(self): + a = paddle.to_tensor(1) + self.assert_results(for_list_1, a) + + def test_list_with_fallback(self): + a = paddle.to_tensor(1) + self.assert_results(for_list_2, a) + + def test_dict(self): + a = paddle.to_tensor(1) + self.assert_results(for_dict, a) + + def test_fallback(self): + a = paddle.to_tensor(1) + + sym_output = symbolic_translate(for_iter)(a, gener()) + paddle_output = for_iter(a, gener()) + self.assert_nest_match(sym_output, paddle_output) + + def test_for_for_fallback(self): + a = paddle.to_tensor(1) + + sym_output = symbolic_translate(for_iter)(a, gener()) + paddle_output = for_iter(a, gener()) + self.assert_nest_match(sym_output, paddle_output) + + def test_for_break(self): + a = paddle.to_tensor(1) + sym_output = symbolic_translate(for_break)(a, gener()) + paddle_output = for_break(a, gener()) + self.assert_nest_match(sym_output, paddle_output) + + def test_for_continue(self): + a = paddle.to_tensor(1) + sym_output = symbolic_translate(for_continue)(a, gener()) + paddle_output = for_continue(a, gener()) + self.assert_nest_match(sym_output, paddle_output) + + # TODO(zmh): support range for tensor + # def test_resume_stack(self): + # a = [1, 2, 3] + # self.assert_results(for_enumerate_var_with_nested_range, a) + + def 
test_create_var_in_loop(self): + x = paddle.to_tensor(1, dtype="float32") + a = [1, 2, 3] + self.assert_results(for_create_tmp_in_loop, x, a) + + sym_output = symbolic_translate(for_create_tmp_in_loop)(x, iter(a)) + paddle_output = for_create_tmp_in_loop(x, iter(a)) + self.assert_nest_match(sym_output, paddle_output) + + def test_create_var_in_loop_with_same_name_as_global(self): + self.assert_results(for_tmp_var_with_same_name_as_global_var) + + def test_for_without_zero_iter(self): + self_res_dict = {} + output = paddle.to_tensor(2) + self.assert_results(for_without_zero_iter, self_res_dict, output) + + def test_reconstruct_range_iter(self): + self.assert_results(for_reconstruct_range_iter) + + def test_layer_list(self): + layers = paddle.nn.LayerList() + for i in range(5): + layers.append(paddle.nn.Linear(5, 5)) + x = paddle.rand([5], dtype="float32") + self.assert_results(for_layer_list, layers, x) + + +def run_list_comp(x): + out = [s.chunk(2, axis=1) for s in x] + return out + + +class TestListComp(TestCaseBase): + def test_list_comp(self): + x = [paddle.randn([1, 4]), paddle.randn([1, 4])] + self.assert_results(run_list_comp, x) + + +def for_enumerate_cache(func_list, x): + out = None + for idx, func in enumerate(func_list): + out = func(x[idx]) + return out + + +class TestEnumerateCache(TestCaseBase): + def test_run(self): + func_list = [ + paddle.nn.Linear(10, 10), + ] + x = [ + paddle.randn([5, 10]), + ] + + out = symbolic_translate(for_enumerate_cache)(func_list, x) + out = symbolic_translate(for_enumerate_cache)(func_list, x) + self.assert_nest_match(OpcodeExecutorCache().translate_count, 1) + + +# after_loop_fn need zzz, and zzz is created as UndefinedVar when generating loop body +# do not set zzz as UndefinedVar again +def undefined_var_case_0(): + for i in [1, 2]: + sot.psdb.breakgraph() + zzz = i + + zzz = zzz + 1 + return zzz + + +# after_loop_fn need create zzz as UndefinedVar +def undefined_var_case_1(): + for i in [1, 2]: + sot.psdb.breakgraph() + aaa = i + + for i in [1, 3]: + zzz = i + zzz = zzz + 1 + return zzz + + +class TestUndefinedVarInRiskyCodes(TestCaseBase): + def test_undefined_var_case_0(self): + self.assert_results(undefined_var_case_0) + + def test_undefined_var_case_1(self): + self.assert_results(undefined_var_case_1) + + +if __name__ == "__main__": + with strict_mode_guard(0): + unittest.main() diff --git a/test/sot/test_13_make_function.py b/test/sot/test_13_make_function.py new file mode 100644 index 0000000000000..9784d7ffad385 --- /dev/null +++ b/test/sot/test_13_make_function.py @@ -0,0 +1,39 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
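The fallback tests above construct a fresh gener() for each of the two runs because generators are single-use; reusing one would hand the second run an exhausted iterator.

def gener():
    yield 1
    yield 2
    yield 3

g = gener()
print(list(g))  # [1, 2, 3]
print(list(g))  # [] -- the generator is already exhausted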
+ +# MAKE_FUNCTION +# CALL_FUNCTION_KW +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def make_fn(x: paddle.Tensor): + def fn(a, b=2, c=3, d=4): + return a + b + c + d + + return fn(1) + fn(2, c=5) + x + + +class TestExecutor(TestCaseBase): + def test_simple(self): + self.assert_results(make_fn, paddle.to_tensor(1)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_14_operators.py b/test/sot/test_14_operators.py new file mode 100644 index 0000000000000..fc403ae3ef665 --- /dev/null +++ b/test/sot/test_14_operators.py @@ -0,0 +1,387 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import operator +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def unary_positive(x: int): + y = +x + return y + + +def unary_negative(x: paddle.Tensor): + y = -x + return y + + +def unary_not(x: paddle.Tensor): + y = not x + return y + + +def unary_invert(x: paddle.Tensor): + y = ~x + return y + + +def binary_power(x: paddle.Tensor, y: paddle.Tensor): + z = x**y + return z + + +def binary_multiply(x: paddle.Tensor, y: paddle.Tensor): + z = x * y + return z + + +def binary_matrix_multiply(x: paddle.Tensor, y: paddle.Tensor): + z = x @ y + return z + + +def binary_floor_divide(x: paddle.Tensor, y: paddle.Tensor): + z = x // y + return z + + +def binary_true_divide(x: paddle.Tensor, y: paddle.Tensor): + z = x / y + return z + + +def binary_modulo(x: paddle.Tensor, y: paddle.Tensor): + z = x % y + return z + + +def binary_add(x: paddle.Tensor, y: paddle.Tensor): + z = x + y + return z + + +def binary_subtract(x: paddle.Tensor, y: paddle.Tensor): + z = x - y + return z + + +def binary_lshift(x: int, y: int): + z = x << y + return z + + +def binary_rshift(x: int, y: int): + z = x >> y + return z + + +def binary_and(x: paddle.Tensor, y: paddle.Tensor): + z = x & y + return z + + +def binary_or(x: paddle.Tensor, y: paddle.Tensor): + z = x | y + return z + + +def binary_xor(x: paddle.Tensor, y: paddle.Tensor): + z = x ^ y + return z + + +def inplace_power(x: paddle.Tensor, y: paddle.Tensor): + x **= y + return x + + +def inplace_multiply(x: paddle.Tensor, y: paddle.Tensor): + x *= y + return x + + +def inplace_matrix_multiply(x: paddle.Tensor, y: paddle.Tensor): + x @= y + return x + + +def inplace_floor_divide(x: paddle.Tensor, y: paddle.Tensor): + x //= y + return x + + +def inplace_true_divide(x: paddle.Tensor, y: paddle.Tensor): + x /= y + return x + + +def inplace_modulo(x: paddle.Tensor, y: paddle.Tensor): + x %= y + return x + + +def inplace_add(x: paddle.Tensor, y: paddle.Tensor): + x += y + return x + + +def inplace_subtract(x: paddle.Tensor, y: paddle.Tensor): + x -= y + return x + + +def inplace_lshift(x: paddle.Tensor, y: int): + x <<= y + return x + + +def inplace_rshift(x: paddle.Tensor, y: int): + x >>= y + return x + + +def inplace_and(x: paddle.Tensor, y: paddle.Tensor): + x &= y + return x + + +def 
inplace_or(x: paddle.Tensor, y: paddle.Tensor): + x |= y + return x + + +def inplace_xor(x: paddle.Tensor, y: paddle.Tensor): + x ^= y + return x + + +def list_getitem(x: int, y: paddle.Tensor): + z = [x, y] + return operator.getitem(z, 1) + 1 + + +def list_getitem_slice(x: int, y: paddle.Tensor): + z = [x, y] + return operator.getitem(z, slice(0, 2)) + + +def list_setitem_int(x: int, y: paddle.Tensor): + z = [x, y] + operator.setitem(z, 0, 3) + return z + + +def list_setitem_tensor(x: int, y: paddle.Tensor): + z = [x, y] + operator.setitem(z, 1, paddle.to_tensor(3)) + return z + + +def list_delitem_int(x: int, y: paddle.Tensor): + z = [x, y] + operator.delitem(z, 0) + return z + + +def list_delitem_tensor(x: int, y: paddle.Tensor): + z = [x, y] + operator.delitem(z, 1) + return z + + +def dict_getitem_int(x: int, y: paddle.Tensor): + z = {1: y, 2: y + 1} + return operator.getitem(z, 1) + + +def dict_getitem_tensor(x: int, y: paddle.Tensor): + z = {1: y, 2: y + 1} + return operator.getitem(z, 2) + + +def dict_setitem_int(x: int, y: paddle.Tensor): + z = {'x': x, 'y': y} + operator.setitem(z, 'x', 2) + return z + + +def dict_setitem_tensor(x: int, y: paddle.Tensor): + z = {'x': x, 'y': y} + operator.setitem(z, 'y', paddle.to_tensor(3)) + return z + + +def dict_delitem_int(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1} + operator.delitem(z, 1) + return z + + +def dict_delitem_tensor(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1} + operator.delitem(z, 2) + return z + + +def tuple_getitem_int(x: int, y: paddle.Tensor): + x = (x, y) + return operator.getitem(x, 0) + + +def tuple_getitem_tensor(x: int, y: paddle.Tensor): + x = (x, y) + return operator.getitem(x, 1) + + +def tuple_getitem_slice(x: int, y: paddle.Tensor): + x = (x, y, 1) + return operator.getitem(x, slice(0, 2)) + + +def operator_add(x: int, y: paddle.Tensor): + return operator.add(x, y) + + +def operator_mul(x: int, y: paddle.Tensor): + return operator.mul(x, y) + + +def operator_truth(y: paddle.Tensor): + return operator.truth(y) + + +def operator_is_(x: paddle.Tensor, y: paddle.Tensor): + return (operator.is_(x, x), operator.is_(x, y)) + + +def operator_in_(x: int, y: list): + return x in y + + +def operator_not_in_(x: int, y: list): + return x not in y + + +def operator_is_not(x: paddle.Tensor, y: paddle.Tensor): + return (operator.is_not(x, x), operator.is_not(x, y)) + + +def operator_pos(y: int): + return operator.pos(+y) + + +class TestExecutor(TestCaseBase): + def test_simple(self): + a = paddle.to_tensor(1) + b = paddle.to_tensor(True) + c = paddle.to_tensor(3) + d = paddle.to_tensor(4) + e = paddle.to_tensor([[1, 2], [3, 4], [5, 6]], dtype='float32') + f = paddle.to_tensor([[1, 2, 3], [4, 5, 6]], dtype='float32') + g = paddle.to_tensor(False) + + self.assert_results(unary_positive, 1) + self.assert_results(unary_negative, a) + self.assert_results(unary_not, b) + self.assert_results(unary_invert, b) + + self.assert_results(binary_power, c, d) + self.assert_results(binary_multiply, c, d) + self.assert_results(binary_matrix_multiply, e, f) + self.assert_results(binary_floor_divide, c, d) + self.assert_results(binary_true_divide, c, d) + self.assert_results(binary_modulo, c, d) + self.assert_results(binary_add, c, d) + self.assert_results(binary_subtract, c, d) + self.assert_results(binary_lshift, 10, 2) + self.assert_results(binary_rshift, 10, 1) + self.assert_results(binary_and, b, g) + self.assert_results(binary_or, b, g) + self.assert_results(binary_xor, b, g) + + self.assert_results(inplace_power, c, d) + 
self.assert_results(inplace_multiply, c, d) + self.assert_results(inplace_matrix_multiply, e, f) + self.assert_results(inplace_floor_divide, c, d) + self.assert_results(inplace_true_divide, c, d) + self.assert_results(inplace_modulo, c, d) + self.assert_results(inplace_add, c, d) + self.assert_results(inplace_subtract, c, d) + self.assert_results(inplace_lshift, 10, 2) + self.assert_results(inplace_rshift, 10, 1) + self.assert_results(inplace_and, b, g) + self.assert_results(inplace_or, b, g) + self.assert_results(inplace_xor, b, g) + + def test_operator_simple(self): + self.assert_results(operator_add, 1, paddle.to_tensor(2)) + self.assert_results(operator_mul, 1, paddle.to_tensor(2)) + self.assert_results(operator_truth, paddle.to_tensor(2)) + self.assert_results( + operator_is_, paddle.to_tensor(2), paddle.to_tensor(3) + ) + self.assert_results( + operator_is_not, paddle.to_tensor(2), paddle.to_tensor(3) + ) + self.assert_results(operator_pos, 1) + self.assert_results(operator_in_, 12, [1, 2, 12]) + self.assert_results(operator_in_, 12, [1, 2, 3]) + self.assert_results(operator_not_in_, 12, [1, 2, 3]) + self.assert_results(operator_not_in_, 12, [1, 2, 3]) + + def test_operator_list(self): + self.assert_results(list_getitem, 1, paddle.to_tensor(2)) + self.assert_results(list_getitem_slice, 1, paddle.to_tensor(2)) + self.assert_results(list_setitem_int, 1, paddle.to_tensor(2)) + self.assert_results_with_side_effects( + list_setitem_tensor, 1, paddle.to_tensor(2) + ) + self.assert_results(list_delitem_int, 1, paddle.to_tensor(2)) + self.assert_results(list_delitem_tensor, 1, paddle.to_tensor(2)) + + def test_operator_dict(self): + self.assert_results(dict_getitem_int, 1, paddle.to_tensor(2)) + self.assert_results(dict_getitem_tensor, 1, paddle.to_tensor(2)) + self.assert_results(dict_setitem_int, 1, paddle.to_tensor(2)) + self.assert_results_with_side_effects( + dict_setitem_tensor, 1, paddle.to_tensor(2) + ) + self.assert_results(dict_delitem_int, 1, paddle.to_tensor(2)) + self.assert_results(dict_delitem_tensor, 1, paddle.to_tensor(2)) + + def test_operator_tuple(self): + self.assert_results(tuple_getitem_int, 1, paddle.to_tensor(2)) + self.assert_results(tuple_getitem_tensor, 1, paddle.to_tensor(2)) + self.assert_results(tuple_getitem_slice, 1, paddle.to_tensor(2)) + + +def run_not_eq(x: paddle.Tensor, y: int): + out = paddle.reshape(x, [1, -1]) != y + out = out.astype('float32') + return out + + +class TestNotEq(TestCaseBase): + def test_not_eq(self): + x = paddle.to_tensor([2]) + y = 3 + self.assert_results(run_not_eq, x, y) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_15_slice.py b/test/sot/test_15_slice.py new file mode 100644 index 0000000000000..b2ee00526f25b --- /dev/null +++ b/test/sot/test_15_slice.py @@ -0,0 +1,137 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
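The operator-module calls exercised in test_14_operators.py above are one-to-one functional spellings of the subscript syntax:

import operator

z = [1, 2, 3]
assert operator.getitem(z, slice(0, 2)) == z[0:2]
operator.setitem(z, 0, 9)  # same as z[0] = 9
operator.delitem(z, 1)     # same as del z[1]
assert z == [9, 3]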
+ +# BUILD_SLICE (new) + +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + +import paddle +from paddle.jit.sot.psdb import check_no_breakgraph + + +def build_list_slice(x: list, y: paddle.Tensor): + x[2:4] = [0, 1] + return x[0] + y + + +def build_list_slice_with_step(x: list, y: paddle.Tensor): + x[1:5:2] = [0, 1] + return x[0] + y + + +def build_tuple_slice(x: list, y: paddle.Tensor): + x[2:4] = (0, 1) + return x[0] + y + + +def build_tuple_slice_with_step(x: list, y: paddle.Tensor): + x[1:5:2] = (0, 1) + return x[0] + y + + +def tensor_subscript_ellipsis(x: paddle.Tensor, y: paddle.Tensor): + return x[...] + y[...] + + +@check_no_breakgraph +def tensor_subscript_tensor(x: paddle.Tensor): + d0, d1 = paddle.shape(x) + return x[: d0 // 2, d1 // 2 : d1] + + +class TestSlice(TestCaseBase): + def test_simple(self): + x = list(range(10)) + y = paddle.arange(10) + self.assert_results_with_side_effects(build_list_slice, x, y) + self.assert_results_with_side_effects(build_list_slice_with_step, x, y) + self.assert_results_with_side_effects(build_tuple_slice, x, y) + self.assert_results_with_side_effects(build_tuple_slice_with_step, x, y) + + +class MyLayer(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.linears = paddle.nn.LayerList( + [paddle.nn.Linear(10, 10) for i in range(10)] + ) + + def forward(self, x): + for i, l in enumerate(self.linears): + x = self.linears[i // 2](x) + l(x) + return x + + +def layer_list_slice(layer, x): + out = layer(x) + return out + + +class TestLayerList(TestCaseBase): + def test_layer_list_slice(self): + layer = MyLayer() + x = paddle.randn([5, 10]) + self.assert_results(layer_list_slice, layer, x) + + +def tensor_slice(x: paddle.Tensor): + return x[1, 1, 1] + 1 + + +class TestTensorSlice(TestCaseBase): + def test_tensor_slice(self): + x = paddle.randn([4, 3, 10]) + self.assert_results(tensor_slice, x) + + +class TestTensorEllipsis(TestCaseBase): + def test_tensor_subscript_ellipsis(self): + x = paddle.rand((10,)) + y = paddle.rand((10, 10)) + self.assert_results(tensor_subscript_ellipsis, x, y) + + +class TestTensorSubscriptTensor(TestCaseBase): + def test_tensor_subscript_tensor(self): + x = paddle.rand((10, 10)) + self.assert_results(tensor_subscript_tensor, x) + + +class LayerListNet(paddle.nn.Layer): + def __init__(self) -> None: + super().__init__() + self.layer_list = paddle.nn.LayerList( + [paddle.nn.Linear(5, 5), paddle.nn.Linear(5, 5)] + ) + + def forward(self, x): + out = self.layer_list[0](x) + for layer in self.layer_list[1:]: + out = layer(out) + return out + + +class TestLayerListSlice(TestCaseBase): + def test_layer_list_slice(self): + x = paddle.randn([2, 5]) + net = LayerListNet() + self.assert_results(layer_list_slice, net, x) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_16_paddle_api.py b/test/sot/test_16_paddle_api.py new file mode 100644 index 0000000000000..9f6e05fa48b2f --- /dev/null +++ b/test/sot/test_16_paddle_api.py @@ -0,0 +1,60 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from test_case_base import TestCaseBase + +import paddle +from paddle.nn.functional import relu + + +def paddle_api_method_call(x: paddle.Tensor): + m = x + 2 + m = paddle.nn.functional.relu(m) + return m + + +def paddle_api_function_call(x: paddle.Tensor): + m = x + 2 + m = relu(m) + return m + + +def paddle_api_function_call_concat( + x: paddle.Tensor, y: paddle.Tensor, axis: int +): + return paddle.concat([x, y], axis=axis) + + +class TestPaddleApiCall(TestCaseBase): + def test_paddle_api_method_call(self): + self.assert_results(paddle_api_method_call, paddle.to_tensor(2.0)) + self.assert_results(paddle_api_method_call, paddle.to_tensor(-5.0)) + self.assert_results(paddle_api_method_call, paddle.to_tensor(0.0)) + + def test_paddle_api_function_call(self): + self.assert_results(paddle_api_function_call, paddle.to_tensor(2.0)) + self.assert_results(paddle_api_function_call, paddle.to_tensor(-5.0)) + self.assert_results(paddle_api_function_call, paddle.to_tensor(0.0)) + + def test_paddle_api_function_call_concat(self): + a = paddle.to_tensor([[1, 2], [3, 4]]) + b = paddle.to_tensor([[5, 6], [7, 8]]) + self.assert_results(paddle_api_function_call_concat, a, b, 0) + self.assert_results(paddle_api_function_call_concat, a, b, 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_17_paddle_layer.py b/test/sot/test_17_paddle_layer.py new file mode 100644 index 0000000000000..58b7dfb9fa301 --- /dev/null +++ b/test/sot/test_17_paddle_layer.py @@ -0,0 +1,94 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
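For reference, the concat case in test_16_paddle_api.py checks both axes; in eager mode the axis argument selects which dimension grows:

import paddle

a = paddle.to_tensor([[1, 2], [3, 4]])
b = paddle.to_tensor([[5, 6], [7, 8]])
print(paddle.concat([a, b], axis=0).shape)  # [4, 2]
print(paddle.concat([a, b], axis=1).shape)  # [2, 4]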
+ +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +class SimpleNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.linear1 = paddle.nn.Linear(10, 1) + + def forward(self, x): + out1 = self.linear1(x) + return out1 + + +class SimpleNet_bound(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.linear1 = paddle.nn.Linear(10, 1) + + def add(self, x): + return x + 1 + + def forward(self, x): + x = self.add(x) + out1 = self.linear1(x) + return out1 + + +def net_call(x: paddle.Tensor, net): + return net(x) + + +def net_call_passed_by_user(x: paddle.Tensor, net_forward): + return net_forward(x) + + +class SimpleNetWithSequenital(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.seq = paddle.nn.Sequential( + paddle.nn.Linear(10, 10), + paddle.nn.Linear(10, 10), + paddle.nn.Linear(10, 1), + ) + + def forward(self, x): + out1 = self.seq(x) + return out1 + + +class TestLayer(TestCaseBase): + def test_layer(self): + x = paddle.rand((10,)) + y = paddle.rand((10, 10)) + net = SimpleNet() + self.assert_results(net_call, x, net) + self.assert_results(net_call, y, net) + self.assert_results(net_call_passed_by_user, x, net.forward) + + def test_layer_with_sequential(self): + x = paddle.rand((10,)) + y = paddle.rand((10, 10)) + net = SimpleNetWithSequenital() + self.assert_results(net_call, x, net) + self.assert_results(net_call, y, net) + self.assert_results(net_call_passed_by_user, x, net.forward) + + def test_bound(self): + x = paddle.rand((10,)) + y = paddle.rand((10, 10)) + net = SimpleNet_bound() + self.assert_results(net_call, x, net) + self.assert_results(net_call, y, net) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_18_tensor_method.py b/test/sot/test_18_tensor_method.py new file mode 100644 index 0000000000000..2591db1f748d9 --- /dev/null +++ b/test/sot/test_18_tensor_method.py @@ -0,0 +1,90 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
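net_call_passed_by_user in test_17_paddle_layer.py leans on ordinary Python method binding: net.forward is already a bound method carrying its instance, so the translator receives a callable with self filled in.

import inspect

import paddle

net = paddle.nn.Linear(10, 1)
fwd = net.forward
assert inspect.ismethod(fwd) and fwd.__self__ is net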
+
+import unittest
+
+from test_case_base import TestCaseBase
+
+import paddle
+
+
+def tensor_method_call_1(x: paddle.Tensor):
+    y = x + 1
+    return y.mean()
+
+
+def tensor_method_call_2(a: paddle.Tensor, b: paddle.Tensor):
+    c = a.add(b)
+    d = c.multiply(a)
+    e = d.subtract(b)
+    f = e.divide(a)
+    g = f.pow(2) + f.abs().sqrt()
+    h = (g.abs() + 1).log() - (g / g.max()).exp()
+    i = h.sin() + h.cos()
+    return i
+
+
+def tensor_method_passed_by_user(a: paddle.Tensor, func):
+    return func(a)
+
+
+def tensor_method_property(a: paddle.Tensor, b: paddle.Tensor):
+    return (
+        a.name,
+        str(a.place),
+        a.persistable,
+        a.dtype,
+        a.type,
+        a.is_tensor(),
+        a.clear_gradient(),
+        a @ b.T + len(a.shape) + b.size + a.ndim + a.dim() + a.rank(),
+    )
+
+
+def middle_tensor_name(a: paddle.Tensor, b: paddle.Tensor):
+    c = a + b
+    return c.name
+
+
+class TestTensorMethod(TestCaseBase):
+    def test_tensor_method_1(self):
+        x = paddle.rand([10])
+        y = paddle.rand([2, 4, 6])
+        self.assert_results(tensor_method_call_1, x)
+        self.assert_results(tensor_method_call_1, y)
+
+    def test_tensor_method_2(self):
+        x = paddle.rand([42])
+        y = paddle.rand([42])
+        self.assert_results(tensor_method_call_2, x, y)
+
+    def test_tensor_method_passed_by_user(self):
+        x = paddle.rand([42])
+        y = paddle.rand([42])
+        self.assert_results(tensor_method_passed_by_user, x, y.add)
+
+    def test_tensor_method_property(self):
+        x = paddle.rand([42, 24], dtype='float64')
+        y = paddle.rand([42, 24], dtype='float32')
+        self.assert_results(tensor_method_property, x, y)
+
+    @unittest.skip("TODO: dynamic tensor name is different")
+    def test_middle_tensor_name(self):
+        x = paddle.rand([42, 24])
+        y = paddle.rand([42, 24])
+        self.assert_results(middle_tensor_name, x, y)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/sot/test_19_closure.py b/test/sot/test_19_closure.py
new file mode 100644
index 0000000000000..6191141e07f39
--- /dev/null
+++ b/test/sot/test_19_closure.py
@@ -0,0 +1,260 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import inspect +import unittest + +from test_case_base import TestCaseBase, strict_mode_guard + +import paddle + + +def foo(x: int, y: paddle.Tensor): + z = 3 + + def local(a, b=5): + return a + x + z + b + y + + return local(4) + z + + +def foo2(y: paddle.Tensor, x=1): + """ + Test strip default value + """ + z = 3 + + def local(a, b=5): + return a + x + z + b + y + + return local(4) + + +def foo3(y: paddle.Tensor, x=1): + """ + Test Closure Band Default + """ + z = 3 + + def local(a, b=5): + nonlocal z + z = 4 + return a + x + z + b + y + + return local(4) + + +global_z = 3 + + +def test_global(y: paddle.Tensor): + """ + Test Global variable + """ + + def local(a, b=5): + global global_z + global_z += 1 + return a + global_z + b + y + + return local(1) + + +def multi(c): + return c + 2 + + +def wrapper_function(func): + a = 2 + + def inner(): + return func(a) + + return inner + + +wrapped_multi = wrapper_function(multi) + + +def foo5(y: paddle.Tensor): + """ + Test incoming closures + """ + a = wrapped_multi() + return a + + +def outwrapper(func): + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + +def foo6(y: paddle.Tensor): + """ + Test Decorator + """ + + @outwrapper + def load_1(a, b=5): + return a + b + + return load_1(1) + + +import numpy as np + + +def numpy_sum(m): + """ + Test loop call + + Example: a->b->c->a + """ + a = np.array([1, 2, 3]) + tmp = np.sum(a) + return m + 1 + + +def lambda_closure(x, m): + """ + lambda closure. + """ + + def break_graph_closure(): + print("yes") + return x + m + + return break_graph_closure() + + +# motivated by python builtin decorator +def kwargs_wrapper(func): + sig = inspect.signature(func) + + def inner(*args, **kwargs): + return func(*args, **kwargs) + + inner.__signature__ = sig + return inner + + +@kwargs_wrapper +def func7(a, b): + return a + b + + +def foo7(): + return func7(3, 5) + + +def create_closure(): + x = 1 + + def closure(): + return x + 1 + + return closure + + +class TestExecutor(TestCaseBase): + def test_closure(self): + self.assert_results(foo, 1, paddle.to_tensor(2)) + self.assert_results(foo2, paddle.to_tensor(2)) + self.assert_results(foo3, paddle.to_tensor(2)) + self.assert_results_with_global_check( + test_global, ["global_z"], paddle.to_tensor(2) + ) + self.assert_results(foo5, paddle.to_tensor(2)) + self.assert_results(foo6, paddle.to_tensor(2)) + self.assert_results(numpy_sum, paddle.to_tensor(1)) + with strict_mode_guard(0): + self.assert_results( + lambda_closure, paddle.to_tensor(2), paddle.to_tensor(1) + ) + + +class TestExecutor2(TestCaseBase): + def test_closure(self): + self.assert_results(foo7) + + +# Side Effect. +def test_slice_in_for_loop(x, iter_num=3): + x = paddle.to_tensor(x) + a = [] + # Use `paddle.full` so that static analysis can analyze the type of iter_num is Tensor + iter_num = paddle.full( + shape=[1], fill_value=iter_num, dtype="int32" + ) # TODO(liym27): Delete it if the type of parameter iter_num can be resolved + + for i in range(iter_num): + a.append(x) + + for i in range(iter_num): + a[i] = x + out = a[2] + return out + + +class TestExecutor3(TestCaseBase): + def test_closure(self): + tx = paddle.to_tensor([1.0, 2.0, 3.0]) + # need side effect of list. 
# self.assert_results(test_slice_in_for_loop, tx)
+
+
+def non_local_test(t: paddle.Tensor):
+    a = 1
+
+    def func1():
+        nonlocal a
+        t = a
+        a = 2
+        return t
+
+    def func2():
+        nonlocal a
+        a = 1
+        return a
+
+    t += func1()  # add 2
+    t += func2()  # add 1
+    t += a  # add 1
+    return t
+
+
+class TestExecutor4(TestCaseBase):
+    def test_closure(self):
+        tx = paddle.to_tensor([1.0])
+        self.assert_results(non_local_test, tx)
+
+
+class TestCreateClosure(TestCaseBase):
+    def test_create_closure(self):
+        closure = create_closure()
+        self.assert_results(closure)
+
+
+if __name__ == "__main__":
+    unittest.main()
+
+# Instructions:
+# LOAD_CLOSURE
+# LOAD_DEREF
+# LOAD_CLASSDEREF
+# STORE_DEREF
+# DELETE_DEREF
+# STORE_GLOBAL
diff --git a/test/sot/test_20_string.py b/test/sot/test_20_string.py
new file mode 100644
index 0000000000000..5e628b795afdd
--- /dev/null
+++ b/test/sot/test_20_string.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+from test_case_base import TestCaseBase
+
+import paddle
+from paddle.jit.sot.psdb import assert_true, check_no_breakgraph
+
+
+def string_format(x: paddle.Tensor):
+    whitespace = 123
+    hello_world = f"Hello {whitespace} World"
+    z = assert_true(hello_world == "Hello 123 World")
+    hello_world2 = f"Hello {whitespace}{whitespace} World"
+    z = assert_true(hello_world2 == "Hello 123123 World")
+    hello_world_lower = "Hello World".lower()
+    z = assert_true(hello_world_lower == "hello world")
+    return x + 1
+
+
+def string_lower(x: paddle.Tensor):
+    hello_world_lower = "Hello World".lower()
+    z = assert_true(hello_world_lower == "hello world")
+    return x + 1
+
+
+@check_no_breakgraph
+def str_startswith():
+    s = "Hello World"
+    a1 = s.startswith("Hello")
+    a2 = s.startswith("World")
+    a3 = s.startswith("Hello World")
+    a4 = s.startswith("Hello World!")
+    a5 = s.startswith("Hello", 5)
+    a6 = s.startswith("Hello", 1, 4)
+    a7 = s.startswith("Hello", 0, 11)
+    return (a1, a2, a3, a4, a5, a6, a7)
+
+
+@check_no_breakgraph
+def str_endswith():
+    s = "Hello World"
+    a1 = s.endswith("Hello")
+    a2 = s.endswith("World")
+    a3 = s.endswith("Hello World")
+    a4 = s.endswith("Hello World!")
+    a5 = s.endswith("Hello", 5)
+    a6 = s.endswith("Hello", 0, 4)
+    a7 = s.endswith("Hello", 1, 11)
+    return (a1, a2, a3, a4, a5, a6, a7)
+
+
+class TestExecutor(TestCaseBase):
+    def test_string_format(self):
+        self.assert_results(string_format, paddle.to_tensor(1))
+
+    def test_string_lower(self):
+        self.assert_results(string_lower, paddle.to_tensor(1))
+
+    def test_str_startswith(self):
+        self.assert_results(str_startswith)
+
+    def test_str_endswith(self):
+        self.assert_results(str_endswith)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/sot/test_21_global.py b/test/sot/test_21_global.py
new file mode 100644
index 0000000000000..131f9c7e367f9
--- /dev/null
+++ b/test/sot/test_21_global.py
@@ -0,0 +1,175 @@
+# Copyright (c) 2023 PaddlePaddle
Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + +import paddle +from paddle.jit import sot + +global_x = 1 +global_y = paddle.to_tensor(2) +global_z = None +global_del_val = 1 +global_dict = {} +global_list = [1, 2] +global_inline = 0 + + +def global_func_int(): + global global_x + global_x = global_x + 1 + return global_x + + +def global_func_int_add(): + global global_x + global_x = global_x + global_x + return global_x + global_x + + +def global_func_tensor_int_add(tensor_y: paddle.Tensor): + global global_x + global_x += 1 + return global_x + tensor_y + + +def global_multiple_update(): + global global_x + global_x = 999 + global_x = 888 + global_x = 777 + return global_x - 1 + + +def global_func_tensor(): + global global_y + global_y = global_y + global_y + return global_y + + +def global_func_tensor_add(): + global global_y + global_y = global_y + global_y + return global_y + global_y + + +def global_func(): + global global_x + global global_y + global global_z + + global_z = global_x + global_y + return global_z + + +def global_del_global(): + global global_del_val + + del global_del_val + + +def global_func_dict(): + global global_dict + global_dict["key"] = "value" + global_dict.update({"test_key1": "test_value2"}) + return global_dict + + +def global_func_control1(): + global global_dict + if "key" in global_dict: + del global_dict["key"] + return global_dict + + +def global_func_control2(): + global global_list + for i in range(len(global_list)): + global_list[i] = global_list[i] + 1 + return global_list + + +def global_func_inline_inner_1(): + global global_inline + global_func_inline_inner_2() + global_inline += 1 + + +def global_func_inline_inner_2(): + global global_inline + global_inline += 1 + + +def global_func_inline(): + global_func_inline_inner_1() + global global_inline + return global_inline + + +class TestGlobal(TestCaseBase): + def test_global_func_int(self): + global global_x + self.assert_results_with_global_check(global_func_int, ["global_x"]) + global_x += 1 + self.assert_results_with_global_check(global_func_int, ["global_x"]) + self.assert_results_with_global_check(global_func_int_add, ["global_x"]) + + def test_global_multiple_update(self): + self.assert_results_with_global_check( + global_multiple_update, ["global_x"] + ) + + def test_global_func_tensor_int_add(self): + self.assert_results_with_global_check( + global_func_tensor_int_add, ["global_x"], paddle.to_tensor(1) + ) + + def test_global_func_tensor(self): + self.assert_results_with_global_check(global_func_tensor, ["global_y"]) + self.assert_results_with_global_check( + global_func_tensor_add, ["global_y"] + ) + + def test_global_func(self): + self.assert_results_with_global_check(global_func, ["global_z"]) + self.assertIn("global_del_val", global_del_global.__globals__) + sot.symbolic_translate(global_del_global)() + self.assertNotIn("global_del_val", 
global_del_global.__globals__) + + def test_global_func_dict(self): + self.assert_results_with_global_check(global_func_dict, ["global_dict"]) + self.assert_results_with_global_check( + global_func_control1, ["global_dict"] + ) + + def test_global_func_list(self): + self.assert_results_with_global_check( + global_func_control2, ["global_list"] + ) + + def test_global_func_inline(self): + global global_inline + global_inline = 0 + sot.symbolic_translate(global_func_inline)() + self.assertEqual(global_inline, 2) + sot.symbolic_translate(global_func_inline)() + self.assertEqual(global_inline, 4) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_analysis_inputs.py b/test/sot/test_analysis_inputs.py new file mode 100644 index 0000000000000..20b32c2225324 --- /dev/null +++ b/test/sot/test_analysis_inputs.py @@ -0,0 +1,249 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import inspect +import sys +import unittest + +import paddle +from paddle.jit.sot.opcode_translator.instruction_utils import ( + analysis_inputs, + calc_offset_from_bytecode_offset, + get_instructions, +) + + +def assert_inputs_equals(instruction_offset: int, expected_inputs: set[str]): + current_frame = inspect.currentframe() + assert current_frame is not None + test_frame = current_frame.f_back + assert test_frame is not None + + instructions = get_instructions(test_frame.f_code) + current_instr_idx = calc_offset_from_bytecode_offset( + test_frame.f_lasti + 2, instructions + ) + actual_inputs = analysis_inputs( + instructions, current_instr_idx + instruction_offset + ) + assert ( + set(actual_inputs) == expected_inputs + ), f"actual_inputs: {actual_inputs}, expected_inputs: {expected_inputs}" + + +def case1(x): + m = x + 1 + n = x + 2 + assert_inputs_equals(0, {"x", "n"}) + y = x + 2 + assert_inputs_equals(0, {"n"}) + return n + + +def case2(x): + x = x + 1 + assert_inputs_equals(0, {"x"}) + y = x + 3 + z = x + y + assert_inputs_equals(0, {"x"}) + x += 1 + m = x + 1 + n = x + m + assert_inputs_equals(0, set()) + return 1 + + +def case3(x): + y = x + 1 + + assert_inputs_equals(0, {"x"}) + if x: + z = 1 + else: + z = 2 + return z + + +def case4(x): + y = x + 1 + + assert_inputs_equals(0, {"x", "y"}) + if x: + z = y + else: + z = x + return z + + +def case5(x): + y = x + 1 + z = x + 2 + + assert_inputs_equals(0, {"z"}) + if z: + a = 1 + else: + b = 2 + return z + + +def case6(x): + y = x + 1 + z = x + 2 + + assert_inputs_equals(0, {"a", "z"}) + if z: + a = 1 + else: + a += 1 + return z + + +def case7(x): + y = x + 1 + z = x + 2 + + assert_inputs_equals(0, {"a", "z"}) + if not z: + a += 1 # noqa: F821 + else: + a = 1 + return z + + +def breakgraph_api(x): + return x + + +def normal_api(x): + return x + + +def case8(x): + x = normal_api(x) + assert_inputs_equals(0, {"x"}) + for i in range(10): + x += 1 + if i > 5: + continue + x += 10086 + x += i + return x + + +case9_offset = -9 if sys.version_info >= 
(3, 11) else -7 + + +def case9(x): + x = breakgraph_api(x) + assert_inputs_equals( + case9_offset, set() + ) # analysis when call breakgraph api (CALL_FUNCTION) + for i in range(10): + x += 1 + if i > 5: + continue + x += 10086 + x += i + return x + + +def case10(x): + assert_inputs_equals(0, {"x", "y"}) + # if x == 0, y will be read before assignment + for i in range(x): + y = i + z = y + + return y + 1 + + +def case11(x): + y = x + 1 + z = x + 2 + + assert_inputs_equals(0, {"a", "y", "z"}) + if z: + if not y: + a += 1 # noqa: F821 + else: + a = 2 + else: + if y: + a = 1 + else: + a += 1 + return z + + +def case12(x): + y = x + 1 + z = x + 2 + + assert_inputs_equals(0, {"a", "y", "z"}) + if z: + if y: + a = 2 + else: + a += 2 + else: + if y: + a += 1 + else: + a = 1 + return z + + +class TestAnalysisInputs(unittest.TestCase): + def test_case1(self): + case1(paddle.to_tensor([1])) + + def test_case2(self): + case2(paddle.to_tensor([2])) + + def test_case3(self): + case3(paddle.to_tensor([3])) + + def test_case4(self): + case4(paddle.to_tensor([4])) + + def test_case5(self): + case5(paddle.to_tensor([5])) + + def test_case6(self): + case6(paddle.to_tensor([6])) + + def test_case7(self): + case7(paddle.to_tensor([7])) + + def test_case8(self): + case8(paddle.to_tensor([8])) + + def test_case9(self): + case9(paddle.to_tensor([9])) + + def test_case10(self): + case10(paddle.to_tensor([10])) + + def test_case11(self): + case11(paddle.to_tensor([11])) + + def test_case12(self): + case12(paddle.to_tensor([12])) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_break_graph.py b/test/sot/test_break_graph.py new file mode 100644 index 0000000000000..532f1c7a4c497 --- /dev/null +++ b/test/sot/test_break_graph.py @@ -0,0 +1,157 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +from test_case_base import TestCaseBase + +import paddle +from paddle.jit.sot.utils.paddle_api_config import add_break_graph_apis + + +def ifelse_func(x, y): + if x > 0: + y = y + 1 + else: + y = y + 2 + return y + + +class TestIfElse(TestCaseBase): + def test_simple(self): + x = paddle.to_tensor([1.0]) + y = paddle.to_tensor([2.0]) + self.assert_results(ifelse_func, x, y) + + +def multi_output(x: paddle.Tensor): + m = x + 1 + if x > 0: + return m + else: + return 2 * m + + +class TestExecutor(TestCaseBase): + def test_simple(self): + x = paddle.to_tensor(2) + self.assert_results(multi_output, x) + x = paddle.to_tensor(-2) + self.assert_results(multi_output, x) + + +def print_break_graph(x, y): + z = x + y + print(x, z) + out = y * z * 2 + return out + + +class TestPrint(TestCaseBase): + def test_simple(self): + x = paddle.to_tensor(2) + y = paddle.to_tensor(3) + self.assert_results(print_break_graph, x, y) + + +def to_tensor_break_graph(x, y): + z = x + y + out = y * paddle.to_tensor(2) * z + return out + + +class TestToTensor(TestCaseBase): + def test_simple(self): + add_break_graph_apis([paddle.to_tensor]) + x = paddle.to_tensor(2) + y = paddle.to_tensor(3) + self.assert_results(to_tensor_break_graph, x, y) + + +def tensor_clear_gradient(x): + x = paddle.to_tensor(x) + x.clear_gradient() + return x + + +class TestBreakGraphInResumeFn(TestCaseBase): + def test_simple(self): + x = paddle.to_tensor(2) + self.assert_results(tensor_clear_gradient, x) + + +def inner_fn(a, b, c, d): + return a + b * c - d + + +def multi_stack_args(a, b, c): + out = inner_fn(a, b, c, paddle.to_tensor(4)) + return out + + +class TestMultiStackArgs(TestCaseBase): + def test_simple(self): + a = paddle.to_tensor(1) + b = paddle.to_tensor(2) + c = paddle.to_tensor(3) + self.assert_results(multi_stack_args, a, b, c) + + +def break_graph_in_call_method(x): + out = paddle.nn.functional.relu(paddle.to_tensor([4.0])) + return x + out + + +def numpy_break_graph(): + a = paddle.to_tensor([1, 2]) + b = np.sum(a.numpy()) + print(b) + return b + + +class TestBreakGraphInCallMethod(TestCaseBase): + def test_simple(self): + x = paddle.to_tensor([1.0]) + break_graph_in_call_method(x) + x = paddle.to_tensor([2.0]) + break_graph_in_call_method(x) + + x = paddle.to_tensor([3.0]) + self.assert_results(break_graph_in_call_method, x) + + def test_numpy(self): + self.assert_results(numpy_break_graph) + + +def test_break_graph_repeat(x): + out = paddle.to_tensor( + paddle.to_tensor(paddle.to_tensor(paddle.to_tensor([1.0]))) + ) + return x + out + + +class TestBreakGraphRepeat(TestCaseBase): + def test_simple(self): + x = paddle.to_tensor([1.0]) + test_break_graph_repeat(x) + x = paddle.to_tensor([2.0]) + test_break_graph_repeat(x) + + x = paddle.to_tensor([3.0]) + self.assert_results(test_break_graph_repeat, x) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_builtin_dispatch.py b/test/sot/test_builtin_dispatch.py new file mode 100644 index 0000000000000..e4a1ee5fb2999 --- /dev/null +++ b/test/sot/test_builtin_dispatch.py @@ -0,0 +1,329 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import math
+import operator
+import unittest
+import weakref
+
+from test_case_base import (
+    TestCaseBase,
+    test_instruction_translator_cache_context,
+)
+
+import paddle
+from paddle.jit.sot.psdb import check_no_breakgraph
+
+
+def dispatch_len(x: paddle.Tensor):
+    return len(x.shape)
+
+
+def dispatch_tensor_len(x: paddle.Tensor):
+    return len(x)
+
+
+def dispatch_reversed(x: paddle.Tensor | int, y: paddle.Tensor | int):
+    return list(reversed([x + 1, y - 1, x * 10, y + 1000]))
+
+
+def dispatch_bool(x: paddle.Tensor):
+    return operator.truth(x.shape) and bool(x.shape)
+
+
+def dispatch_ceil(x: paddle.Tensor | float):
+    return math.ceil(x) + 1
+
+
+def dispatch_floor(x: paddle.Tensor | float):
+    return math.floor(x) + 1
+
+
+def test_sum_tuple(x: paddle.Tensor | int, y: paddle.Tensor | int):
+    return sum((x, y))
+
+
+def test_sum_tuple2(
+    x: paddle.Tensor | int | list[int] | list[paddle.Tensor],
+    y: paddle.Tensor | int | list[int] | list[paddle.Tensor],
+):
+    return sum((x, y), x)
+
+
+def test_sum_tuple3(x):
+    return sum((), x)
+
+
+def test_sum_list(x: paddle.Tensor | int, y: paddle.Tensor | int):
+    return sum([x, y])
+
+
+def test_sum_list2(
+    x: paddle.Tensor | int | list[int] | list[paddle.Tensor],
+    y: paddle.Tensor | int | list[int] | list[paddle.Tensor],
+):
+    return sum([x, y], x)
+
+
+def test_sum_list3(x):
+    return sum([], x)
+
+
+def test_tensor_sum(x: paddle.Tensor):
+    return sum(x)
+
+
+def test_tensor_sum_api(x: paddle.Tensor):
+    return x.sum()
+
+
+def test_pow(x: paddle.Tensor | int, y: paddle.Tensor | int):
+    return pow(x, y)
+
+
+def test_pow2(x: paddle.Tensor | int, y: paddle.Tensor | int):
+    return pow(x, y, 1)
+
+
+def test_tensor_pow_api(x: paddle.Tensor, y: paddle.Tensor | int):
+    return x.pow(y)
+
+
+def test_math_pow(x: int, y: int):
+    return math.pow(x, y)
+
+
+def test_chr(x: int | paddle.Tensor):
+    return chr(x)
+
+
+def test_ord(x: str):
+    return ord(x)
+
+
+@check_no_breakgraph
+def test_sqrt(x: int):
+    return math.sqrt(x)
+
+
+class TestBuiltinDispatch(TestCaseBase):
+    def test_dispatch_len(self):
+        self.assert_results(dispatch_len, paddle.to_tensor([1, 2, 3]))
+
+    def test_dispatch_bool(self):
+        self.assert_results(dispatch_bool, paddle.to_tensor([1, 2, 3]))
+
+    def test_dispatch_tensor_len(self):
+        with test_instruction_translator_cache_context() as ctx:
+            self.assert_results(
+                dispatch_tensor_len, paddle.to_tensor([1, 2, 3])
+            )
+            self.assertEqual(ctx.translate_count, 1)
+            self.assert_results(
+                dispatch_tensor_len, paddle.to_tensor([4, 5, 6])
+            )
+            self.assertEqual(ctx.translate_count, 1)
+
+    def test_dispatch_list_reversed(self):
+        self.assert_results(dispatch_reversed, paddle.to_tensor(1), 2)
+        self.assert_results(dispatch_reversed, 2, paddle.to_tensor(1))
+
+    def test_dispatch_tensor_reversed(self):
+        self.assert_results(
+            dispatch_reversed,
+            paddle.to_tensor([1, 2]),
+            paddle.to_tensor([3, 4]),
+        )
+
+    def test_not_dispatch_tensor_ceil(self):
+        # ceil should break the graph, since it returns an int rather than a tensor
+        self.assert_results(dispatch_ceil, paddle.to_tensor(1.2))
+
+    def 
test_dispatch_float_ceil(self):
+        self.assert_results(dispatch_ceil, 1.2)
+
+    def test_not_dispatch_tensor_floor(self):
+        # floor should break the graph, since it returns an int rather than a tensor
+        self.assert_results(dispatch_floor, paddle.to_tensor(1.2))
+
+    def test_dispatch_float_floor(self):
+        self.assert_results(dispatch_floor, 1.2)
+
+    def test_dispatch_sum(self):
+        self.assert_results(test_sum_tuple, 1, 1)
+        self.assert_results(test_sum_tuple, paddle.to_tensor(1), 1)
+        self.assert_results(
+            test_sum_tuple, paddle.to_tensor(1), paddle.to_tensor(1)
+        )
+        self.assert_results(
+            test_sum_tuple, paddle.to_tensor([1, 2]), paddle.to_tensor(1)
+        )
+        self.assert_results(
+            test_sum_tuple, paddle.to_tensor([1, 2]), paddle.to_tensor([1, 3])
+        )
+        self.assert_results(test_sum_tuple2, 1, 1)
+        self.assert_results(test_sum_tuple2, [1, 2], [3, 4])
+        self.assert_results(test_sum_tuple2, paddle.to_tensor(1), 1)
+        self.assert_results(
+            test_sum_tuple2, paddle.to_tensor(1), paddle.to_tensor(1)
+        )
+        self.assert_results(
+            test_sum_tuple2,
+            [paddle.to_tensor(1), paddle.to_tensor(2)],
+            [paddle.to_tensor(3), paddle.to_tensor(4)],
+        )
+        self.assert_results(
+            test_sum_tuple2, paddle.to_tensor([1, 2]), paddle.to_tensor(1)
+        )
+        self.assert_results(
+            test_sum_tuple2, paddle.to_tensor([1, 2]), paddle.to_tensor([1, 3])
+        )
+        self.assert_results(test_sum_tuple3, 1)
+        self.assert_results(test_sum_tuple3, paddle.to_tensor(1))
+        self.assert_results(test_sum_list, 1, 1)
+        self.assert_results(test_sum_list, paddle.to_tensor(1), 1)
+        self.assert_results(
+            test_sum_list, paddle.to_tensor(1), paddle.to_tensor(1)
+        )
+        self.assert_results(
+            test_sum_list, paddle.to_tensor([1, 2]), paddle.to_tensor(1)
+        )
+        self.assert_results(
+            test_sum_list, paddle.to_tensor([1, 2]), paddle.to_tensor([1, 3])
+        )
+        self.assert_results(test_sum_list2, 1, 1)
+        self.assert_results(test_sum_list2, [1, 2], [3, 4])
+        self.assert_results(test_sum_list2, paddle.to_tensor(1), 1)
+        self.assert_results(
+            test_sum_list2, paddle.to_tensor(1), paddle.to_tensor(1)
+        )
+        self.assert_results(
+            test_sum_list2,
+            [paddle.to_tensor(1), paddle.to_tensor(2)],
+            [paddle.to_tensor(3), paddle.to_tensor(4)],
+        )
+        self.assert_results(
+            test_sum_list2, paddle.to_tensor([1, 2]), paddle.to_tensor(1)
+        )
+        self.assert_results(
+            test_sum_list2, paddle.to_tensor([1, 2]), paddle.to_tensor([1, 3])
+        )
+        self.assert_results(test_sum_list3, 1)
+        self.assert_results(test_sum_list3, paddle.to_tensor(1))
+        self.assert_results(test_tensor_sum, paddle.to_tensor([1, 2]))
+        self.assert_results(test_tensor_sum, paddle.to_tensor((1, 2)))
+        self.assert_results(test_tensor_sum_api, paddle.to_tensor([1, 2]))
+        self.assert_results(test_tensor_sum_api, paddle.to_tensor((1, 2)))
+
+    def test_dispatch_pow(self):
+        self.assert_results(test_pow, 2, 3)
+        self.assert_results(test_pow, paddle.to_tensor(2), 3)
+        self.assert_results(test_pow, paddle.to_tensor(2), paddle.to_tensor(3))
+        self.assert_results(test_pow2, 2, 3)
+        self.assert_results(test_math_pow, 2, 3)
+        self.assert_results(test_tensor_pow_api, paddle.to_tensor(2), 3)
+        self.assert_results(
+            test_tensor_pow_api, paddle.to_tensor(2), paddle.to_tensor(3)
+        )
+
+    def test_dispatch_chr(self):
+        self.assert_results(test_chr, 65)
+        self.assert_results(test_chr, 0x41)
+        self.assert_results(test_chr, paddle.to_tensor(65))
+        self.assert_results(test_chr, paddle.to_tensor(0x41))
+
+    def test_dispatch_ord(self):
+        self.assert_results(test_ord, "a")
+
+    def test_dispatch_sqrt(self):
+        self.assert_results(test_sqrt, 9)
+
+
+def 
run_getattr(x: paddle.Tensor):
+    attr = 'dtype'
+    out = getattr(x, attr)
+    return out
+
+
+class TestGetattr(TestCaseBase):
+    def test_getattr(self):
+        x = paddle.to_tensor(4)
+        self.assert_results(run_getattr, x)
+
+
+def tensor_hasattr(x: paddle.Tensor):
+    return (
+        hasattr(x, "dtype"),
+        hasattr(x, "stop_gradient"),
+        hasattr(x, "abs"),
+        hasattr(x, "non_tensor_attr"),
+    )
+
+
+class ObjectHasattr:
+    def __init__(self):
+        self.attr1 = 1
+        self.attr2 = "2"
+        self.attr3 = [3]
+
+
+def object_hasattr(x: ObjectHasattr):
+    return (
+        hasattr(x, "attr1"),
+        hasattr(x, "attr2"),
+        hasattr(x, "attr3"),
+        hasattr(x, "non_obj_attr"),
+    )
+
+
+def layer_hasattr(layer: paddle.nn.Layer):
+    return (
+        hasattr(layer, "parameters"),
+        hasattr(layer, "sublayers"),
+        hasattr(layer, "non_layer_attr"),
+    )
+
+
+class TestHasattr(TestCaseBase):
+    def test_tensor_hasattr(self):
+        x = paddle.to_tensor(4)
+        self.assert_results(tensor_hasattr, x)
+
+    def test_object_hasattr(self):
+        x = ObjectHasattr()
+        self.assert_results(object_hasattr, x)
+
+    def test_layer_hasattr(self):
+        x = paddle.nn.Layer()
+        self.assert_results(layer_hasattr, x)
+
+
+class WeakrefableObject:
+    ...
+
+
+def weakref_breakgraph(obj):
+    return weakref.ref(obj)
+
+
+class TestWeakref(TestCaseBase):
+    def test_weakref_breakgraph(self):
+        obj = WeakrefableObject()
+        self.assert_results(weakref_breakgraph, obj)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/sot/test_call_object.py b/test/sot/test_call_object.py
new file mode 100644
index 0000000000000..486f3591f4326
--- /dev/null
+++ b/test/sot/test_call_object.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import unittest + +from test_case_base import TestCaseBase + +import paddle + +patched = lambda self, x: x * self.a + +patched2 = lambda self, x: x * self.a + 3 + + +class A: + def __init__(self, a): + self.a = a + + def __call__(self, x): + return self.add(x) + + def add(self, x): + return x + self.a + + multi = patched + + +class B: + def __init__(self, a): + self.a = A(a) + + def __call__(self, x, func): + return getattr(self.a, func)(x) + + def self_call(self, x, func): + return getattr(self.a, func)(self.a, x) + + +def foo_1(a, x): + return a(x) + + +def foo_2(a, x): + return a.multi(x) + + +def foo_3(b, x): + return b(x, "multi") + + +def foo_4(b, x): + return b(x, "add") + + +def foo_5(b, x): + return b.self_call(x, "multi") + + +class TestExecutor(TestCaseBase): + def test_simple(self): + c = B(13) + c.a.multi = patched2 + self.assert_results(foo_1, A(13), paddle.to_tensor(2)) + self.assert_results(foo_2, A(13), paddle.to_tensor(2)) + self.assert_results(foo_3, B(13), paddle.to_tensor(2)) + self.assert_results(foo_4, B(13), paddle.to_tensor(2)) + self.assert_results(foo_5, c, paddle.to_tensor(2)) + self.assert_results(foo_4, c, paddle.to_tensor(2)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_case_base.py b/test/sot/test_case_base.py new file mode 100644 index 0000000000000..03ce3c98227e8 --- /dev/null +++ b/test/sot/test_case_base.py @@ -0,0 +1,158 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import contextlib +import copy +import inspect +import os +import types +import unittest + +import numpy as np + +import paddle +from paddle.jit.sot import symbolic_translate +from paddle.jit.sot.opcode_translator.executor.executor_cache import ( + OpcodeExecutorCache, +) + + +@contextlib.contextmanager +def test_instruction_translator_cache_context(): + cache = OpcodeExecutorCache() + cache.clear() + yield cache + cache.clear() + + +def github_action_error_msg(msg: str): + if 'GITHUB_ACTIONS' in os.environ: + frame = inspect.currentframe() + if frame is not None: + # find the first frame that is in the test folder + while frame.f_back is not None: + filename = frame.f_code.co_filename + if filename.startswith("./"): + filename = f"tests/{filename[2:]}" + lineno = frame.f_lineno + output = f"\n::error file={filename},line={lineno}::{msg}" + return output + frame = frame.f_back + return None + + +class TestCaseBase(unittest.TestCase): + def assertIs(self, x, y, msg=None): + super().assertIs(x, y, msg=msg) + if msg is None: + msg = f"Assert Is, x is {x}, y is {y}" + msg = github_action_error_msg(msg) + if msg is not None: + print(msg) + + def assertEqual(self, x, y, msg=None): + super().assertEqual(x, y, msg=msg) + if msg is None: + msg = f"Assert Equal, x is {x}, y is {y}" + msg = github_action_error_msg(msg) + if msg is not None: + print(msg) + + def assert_nest_match(self, x, y): + cls_x = type(x) + cls_y = type(y) + msg = f"type mismatch, x is {cls_x}, y is {cls_y}" + self.assertIs(cls_x, cls_y, msg=msg) + + container_types = (tuple, list, dict, set) + if cls_x in container_types: + msg = f"length mismatch, x is {len(x)}, y is {len(y)}" + self.assertEqual( + len(x), + len(y), + msg=msg, + ) + if cls_x in (tuple, list): + for x_item, y_item in zip(x, y): + self.assert_nest_match(x_item, y_item) + elif cls_x is dict: + for x_key, y_key in zip(x.keys(), y.keys()): + self.assert_nest_match(x_key, y_key) + self.assert_nest_match(x[x_key], y[y_key]) + elif cls_x is set: + # TODO: Nested set is not supported yet + self.assertEqual(x, y) + elif cls_x in (np.ndarray, paddle.Tensor): + # TODO: support assert_allclose github error log + np.testing.assert_allclose(x, y) + else: + self.assertEqual(x, y) + + def assert_results(self, func, *inputs): + sym_output = symbolic_translate(func)(*inputs) + paddle_output = func(*inputs) + self.assert_nest_match(sym_output, paddle_output) + + def assert_results_with_side_effects(self, func, *inputs): + sym_inputs = copy.deepcopy(inputs) + sym_output = symbolic_translate(func)(*sym_inputs) + paddle_inputs = copy.deepcopy(inputs) + paddle_output = func(*paddle_inputs) + self.assert_nest_match(sym_inputs, paddle_inputs) + self.assert_nest_match(sym_output, paddle_output) + + def assert_results_with_global_check( + self, func, global_keys: list[str], *inputs + ): + def copy_fn(fn): + return types.FunctionType( + code=fn.__code__, + globals=copy.copy(fn.__globals__), + name=fn.__name__, + argdefs=fn.__defaults__, + closure=fn.__closure__, + ) + + sym_copied_fn = copy_fn(func) + sym_fn = symbolic_translate(sym_copied_fn) + paddle_fn = copy_fn(func) + sym_output = sym_fn(*inputs) + paddle_output = paddle_fn(*inputs) + for key in global_keys: + self.assert_nest_match( + sym_copied_fn.__globals__[key], paddle_fn.__globals__[key] + ) + self.assert_nest_match(sym_output, paddle_output) + + +@contextlib.contextmanager +def strict_mode_guard(value): + if "STRICT_MODE" not in os.environ: + os.environ["STRICT_MODE"] = "0" + old_value = 
os.environ["STRICT_MODE"] + os.environ["STRICT_MODE"] = str(value) + yield + os.environ["STRICT_MODE"] = old_value + + +@contextlib.contextmanager +def cost_model_guard(value): + if "COST_MODEL" not in os.environ: + os.environ["COST_MODEL"] = "True" + old_value = os.environ["COST_MODEL"] + os.environ["COST_MODEL"] = str(value) + yield + os.environ["COST_MODEL"] = old_value diff --git a/test/sot/test_code_status.py b/test/sot/test_code_status.py new file mode 100644 index 0000000000000..9fec5712c2293 --- /dev/null +++ b/test/sot/test_code_status.py @@ -0,0 +1,154 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from test_case_base import TestCaseBase, strict_mode_guard + +import paddle +from paddle.jit import sot +from paddle.jit.sot.opcode_translator.skip_files import skip_function +from paddle.jit.sot.utils.code_status import CodeState, CodeStatus + + +class SimpleNet1(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.layers = paddle.nn.LayerList( + [paddle.nn.Linear(10, 10) for _ in range(30)] + ) + + def forward(self, x): + for i in range(len(self.layers)): + sot.psdb.breakgraph() + x = self.layers[i](x) + x = self.layers[i](x) + x = self.layers[i](x) + x = self.layers[i](x) + return x + + +class SimpleNet2(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.layers = paddle.nn.LayerList( + [paddle.nn.Linear(10, 10) for _ in range(30)] + ) + + def forward(self, x): + sot.psdb.fallback() + for i in range(len(self.layers)): + x = self.layers[i](x) + x = self.layers[i](x) + x = self.layers[i](x) + x = self.layers[i](x) + return x + + +def run_net(net, x): + for i in range(20): + x = net(x) + return x + + +class TestCodeInfo(TestCaseBase): + def test_case_1(self): + CodeStatus().clear() + net = SimpleNet1() + inp = paddle.rand((10, 10)) + self.assert_results(run_net, net, inp) + code_map = CodeStatus().code_map + states = [] + for k, v in code_map.items(): + if k.co_name.startswith("#") or k.co_name.startswith("$"): + states.append(v) + elif k in CodeStatus().WITH_GRAPH_API: + assert v.state == CodeState.WITH_GRAPH + else: + assert v.state == CodeState.WITHOUT_GRAPH + # run_net, forward, loop body, resumed part2 in loop body + assert len([v for v in states if v.state == CodeState.WITH_GRAPH]) == 4 + # resumed part1 in loop body + assert ( + len([v for v in states if v.state == CodeState.WITHOUT_GRAPH]) == 1 + ) + + def test_case_2(self): + with strict_mode_guard(0): + CodeStatus().clear() + net = SimpleNet2() + inp = paddle.rand((10, 10)) + self.assert_results(run_net, net, inp) + code_map = CodeStatus().code_map + states = [] + for k, v in code_map.items(): + if k.co_name.startswith("#") or k.co_name.startswith("$"): + states.append(v) + elif k in CodeStatus().WITH_GRAPH_API: + assert v.state == CodeState.WITH_GRAPH + else: + assert v.state == CodeState.WITHOUT_GRAPH + # no graph found because fallback (paddle api will not enter simulate) + assert ( + len([v for v in states 
if v.state == CodeState.WITH_GRAPH]) == 0 + ) + + +def no_skip_func_0(x): + return x + 1 + + +def skipped_func_0(): + pass + + +def skipped_func_1(x): + return x + 1 + + +def skipped_func_2(x): + return no_skip_func_0(x) + + +def call_skipped_func_0(x): + for i in range(15): + skipped_func_0() + x = skipped_func_1(x) + x = skipped_func_2(x) + return x + + +skip_function(skipped_func_0) +skip_function(skipped_func_1) +skip_function(skipped_func_2) +skip_function(call_skipped_func_0) + + +class TestDisableSkippedFrame(TestCaseBase): + def test_case_0(self): + CodeStatus().clear() + x = paddle.to_tensor([1]) + self.assert_results(call_skipped_func_0, x) + code_map = CodeStatus().code_map + assert ( + code_map[skipped_func_0.__code__].state == CodeState.WITHOUT_GRAPH + ) + assert ( + code_map[skipped_func_1.__code__].state == CodeState.WITHOUT_GRAPH + ) + assert code_map[skipped_func_2.__code__].state == CodeState.WITH_GRAPH + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_constant_graph.py b/test/sot/test_constant_graph.py new file mode 100644 index 0000000000000..970f9f4902413 --- /dev/null +++ b/test/sot/test_constant_graph.py @@ -0,0 +1,54 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# New Supported Instructions: +# BUILD_MAP (new) +# BUILD_CONST_KEY_MAP (new) + +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def func_1(format_str, tensor): + str = format_str.format(xx=12) + a = "{xx} = 12".format + ttt = f"{10} = 12" + a(xx=12) + tensor = tensor + 1 + return str, tensor + + +def func_2(format_str, tensor): + str = format_str % 10 + tensor = tensor + 1 + return str, tensor + + +class TestConstantGraph(TestCaseBase): + def test_case_1(self): + x = "{xx} is xx" + tensor = paddle.to_tensor(1) + self.assert_results(func_1, x, tensor) + + def test_case_2(self): + x = "%s is xx" + tensor = paddle.to_tensor(1) + self.assert_results(func_2, x, tensor) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_cost_model.py b/test/sot/test_cost_model.py new file mode 100644 index 0000000000000..07899a03efbfd --- /dev/null +++ b/test/sot/test_cost_model.py @@ -0,0 +1,114 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import time +import unittest + +from test_case_base import TestCaseBase, cost_model_guard + +import paddle +from paddle.jit.sot import psdb, symbolic_translate +from paddle.jit.sot.utils import StepInfoManager, StepState + + +def dyn_fast(x, net, iter_): + for i in iter_: + x = net(x) + return x + + +def sot_fast_with_single_graph(x, net): + if not psdb.in_sot(): + time.sleep(0.1) + return x + 1 + + +def sot_fast_with_multi_graph(x, net): + if not psdb.in_sot(): + time.sleep(0.1) + x = x + 1 + psdb.breakgraph() + x = x + 2 + return x + + +class Net(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.linear = paddle.nn.Linear(10, 10) + + def forward(self, x): + if not psdb.in_sot(): + time.sleep(0.1) + x = x / 3 + x = x + 5 + x = self.linear(x) + return x + + +class TestCostModel(TestCaseBase): + @cost_model_guard("True") + def test_dyn_fast(self): + x = paddle.rand([10]) + net = paddle.nn.Linear(10, 10) + sot_fn = symbolic_translate(dyn_fast) + for i in range(60): + sot_fn(x, net, iter(range(10))) + + state = StepInfoManager().step_record[dyn_fast.__code__].state + assert state == StepState.RUN_DYN + + @cost_model_guard("True") + def test_sot_fast_with_multi_graph(self): + x = paddle.rand([10]) + net = paddle.nn.Linear(10, 10) + sot_fn = symbolic_translate(sot_fast_with_multi_graph) + for i in range(30): + sot_fn(x, net) + + state = ( + StepInfoManager() + .step_record[sot_fast_with_multi_graph.__code__] + .state + ) + assert state == StepState.RUN_SOT + + @cost_model_guard("True") + def test_sot_fast_with_single_graph(self): + x = paddle.rand([10]) + net = paddle.nn.Linear(10, 10) + for i in range(30): + symbolic_translate(sot_fast_with_single_graph)(x, net) + + state = ( + StepInfoManager() + .step_record[sot_fast_with_single_graph.__code__] + .state + ) + assert state == StepState.RUN_SOT + + @cost_model_guard("True") + def test_net(self): + x = paddle.rand([10]) + net = Net() + net = paddle.jit.to_static(net, enable_fallback=True) + for i in range(30): + x = net(x) + + state = StepInfoManager().step_record[Net.forward.__code__].state + assert state == StepState.RUN_SOT + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_delete_fast.py b/test/sot/test_delete_fast.py new file mode 100644 index 0000000000000..9dca7d4ea1b14 --- /dev/null +++ b/test/sot/test_delete_fast.py @@ -0,0 +1,38 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def test_delete_fast(a): + a = a + 2 + t = a * 3 + del t + return a + + +class TestExecutor(TestCaseBase): + def test_simple(self): + a = paddle.to_tensor(1) + self.assert_results(test_delete_fast, a) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_dup_top.py b/test/sot/test_dup_top.py new file mode 100644 index 0000000000000..5cb28a2dc6cea --- /dev/null +++ b/test/sot/test_dup_top.py @@ -0,0 +1,49 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def func_dup_top_1(): + return True == True != False + + +def func_dup_top_2(x): + y = x + 1 + return True == True != False + + +def func_dup_top_two(x: list[paddle.Tensor]): + x[0] += x[1] + return x + + +class TestDupTop(TestCaseBase): + def test_dup_top(self): + self.assert_results(func_dup_top_1) + self.assert_results(func_dup_top_2, paddle.to_tensor(1.0)) + # TODO: fix this after we support side effect + # self.assert_results( + # func_dup_top_two, [paddle.to_tensor(1.0), paddle.to_tensor(2.0)] + # ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_enumerate.py b/test/sot/test_enumerate.py new file mode 100644 index 0000000000000..f81a451da55c9 --- /dev/null +++ b/test/sot/test_enumerate.py @@ -0,0 +1,116 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +from test_case_base import TestCaseBase, strict_mode_guard + +import paddle + + +def test_enumerate_1(x: int, y: int): + for id, val in enumerate(range(x)): + if id % 2 == 0: + y += val + return y + + +def test_enumerate_2(x: list): + return list(enumerate(x)) + + +def test_enumerate_3(x: list): + return tuple(enumerate(x)) + + +def test_enumerate_4(x: paddle.Tensor): + sum = 0 + for idx, val in enumerate(x): + sum += val + return sum + + +# TODO(zmh): support range for tensor +def test_enumerate_5(x: paddle.Tensor): + sum = 0 + + for idx, val in enumerate(x): + for i in range(val): + sum += val + return sum + + +def test_enumerate_6(x: paddle.Tensor): + sum = 0 + + for idx, val in enumerate(x): + for i in range(idx): + sum += val + return sum + + +def test_enumerate_7(x: paddle.Tensor): + sum = 0 + x = x.flatten() + for idx, val in enumerate(x): + sum += val + return sum + + +# TODO(zmh): support -1 +def test_enumerate_8(x: paddle.Tensor): + sum = 0 + x = paddle.nonzero(x, as_tuple=False) + for idx, val in enumerate(x): + sum += val + return sum + + +def test_enumerate_10(layer_list, x): + sum = 0 + for idx, layer in enumerate(layer_list): + sum += layer(x) + return sum + + +class TestExecutor(TestCaseBase): + def test_cases(self): + x = 8 + y = 5 + ty = paddle.randn((10, 10)) + layer_list = paddle.nn.LayerList( + [paddle.nn.Linear(10, 10) for _ in range(3)] + ) + + self.assert_results(test_enumerate_1, x, y) + self.assert_results(test_enumerate_2, [2, 4, 6, 8, 10]) + self.assert_results(test_enumerate_3, [2, 4, 6, 8, 10]) + + self.assert_results(test_enumerate_4, ty) + # TODO(zmh): support range for tensor + + with strict_mode_guard(0): + self.assert_results(test_enumerate_5, paddle.to_tensor([1, 2, 3])) + self.assert_results(test_enumerate_6, paddle.to_tensor([1, 2, 3])) + self.assert_results(test_enumerate_7, ty) + # TODO(zmh): support -1 + + with strict_mode_guard(0): + self.assert_results(test_enumerate_8, ty) + + self.assert_results(test_enumerate_10, layer_list, paddle.randn((10,))) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_error_handling.py b/test/sot/test_error_handling.py new file mode 100644 index 0000000000000..c74436f0d44f4 --- /dev/null +++ b/test/sot/test_error_handling.py @@ -0,0 +1,39 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +from test_case_base import TestCaseBase, strict_mode_guard + +from paddle.jit import sot + + +def fn_with_try_except(): + sot.psdb.breakgraph() + sot.psdb.fallback() + try: + raise ValueError("ValueError") + except ValueError: + print("catch ValueError") + return True + + +class TestErrorHandling(TestCaseBase): + @strict_mode_guard(0) + def test_fn_with_try_except(self): + self.assert_results(fn_with_try_except) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_exception.py b/test/sot/test_exception.py new file mode 100644 index 0000000000000..26e0f55044379 --- /dev/null +++ b/test/sot/test_exception.py @@ -0,0 +1,94 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import re +import unittest + +import paddle +from paddle.jit.sot import symbolic_translate + + +def case1(x): + return n # noqa: F821 + + +def case2(x): + x = x + 1 + return x @ x + + +def case3(x): + y = x.undefined_attr + return y + + +def case4_inner(x): + y = x * 2 + print() + y = y + 1 + return y.undefined_attr + + +def case4(x): + return case4_inner(x) + + +def case5_inner3(x): + x += 1 + print(x) + z = x + 1 + return z + + +def case5_inner2(x): + x += 1 + z = case5_inner3(1 / 0) + return z + 1 + + +def case5_inner1(x): + return case5_inner2(x) + + +def case5(x): + y = case5_inner3(x) + return case5_inner1(y) + 1 + + +class TestException(unittest.TestCase): + def catch_error(self, func, inputs, error_lines: int | list[int]): + if isinstance(error_lines, int): + error_lines = [error_lines] + try: + symbolic_translate(func)(inputs) + except Exception as e: + match_results = re.compile(r'File ".*", line (\d+)').findall(str(e)) + match_results = list(map(int, match_results)) + assert ( + match_results == error_lines + ), f"{match_results} is not equal {error_lines}" + + def test_all_case(self): + self.catch_error(case1, paddle.rand([2, 1]), 25) + # TODO: support runtime error, such as x[111], x@x + # self.catch_error(case2, paddle.rand([2, 1]), 30) + self.catch_error(case3, paddle.rand([2, 1]), 34) + self.catch_error(case4, paddle.rand([2, 1]), 42) + self.catch_error(case5, paddle.rand([3, 1]), [68, 63, 58]) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_execution_base.py b/test/sot/test_execution_base.py new file mode 100644 index 0000000000000..8c16b89ec4cf1 --- /dev/null +++ b/test/sot/test_execution_base.py @@ -0,0 +1,62 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from test_case_base import TestCaseBase + +import paddle +from paddle.jit.sot import symbolic_translate +from paddle.static import BuildStrategy + + +def func(x, y): + ret = 2 * x + ret = paddle.nn.functional.relu(ret) + ret = ret + y + return ret + + +def simple(x): + ret = 2 * x + return ret + + +class TestExecutor(TestCaseBase): + def test_simple(self): + x = paddle.to_tensor([1.0]) + y = paddle.to_tensor([2.0]) + self.assert_results(simple, x) + self.assert_results(simple, y) + + +def foo(x): + out = x + 1 + out = out * 2 + out = paddle.nn.functional.relu(out) + return out + + +class TestBackend(TestCaseBase): + def test_backend(self): + x = paddle.randn([2, 3]) + dy_out = foo(x) + sot_out = symbolic_translate( + foo, build_strategy=BuildStrategy(), backend='CINN' + )(x) + self.assert_nest_match(dy_out, sot_out) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_guard_outputs.py b/test/sot/test_guard_outputs.py new file mode 100644 index 0000000000000..c717eb8190e5f --- /dev/null +++ b/test/sot/test_guard_outputs.py @@ -0,0 +1,78 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +from test_case_base import ( + TestCaseBase, + test_instruction_translator_cache_context, +) + +import paddle + + +def non_operator_related_fn(x: int, y: int): + return x + y + + +def partial_non_operator_related_fn(x: paddle.Tensor, y: paddle.Tensor, z: int): + a = x + y + return [a, z + z] + + +def guard_inputs(x: int, y: int, z: int): + return x + y + z + + +class TestGuardOutputs(TestCaseBase): + def test_non_operator_related_fn(self): + with test_instruction_translator_cache_context() as ctx: + self.assert_results(non_operator_related_fn, 1, 2) + self.assertEqual(ctx.translate_count, 1) + self.assert_results(non_operator_related_fn, 3, 4) + self.assertEqual(ctx.translate_count, 2) + + def test_partial_non_operator_related_fn(self): + with test_instruction_translator_cache_context() as ctx: + self.assert_results( + partial_non_operator_related_fn, + paddle.to_tensor(1), + paddle.to_tensor(2), + 3, + ) + self.assertEqual(ctx.translate_count, 1) + self.assert_results( + partial_non_operator_related_fn, + paddle.to_tensor(4), + paddle.to_tensor(5), + 6, + ) + self.assertEqual(ctx.translate_count, 2) + + def test_guard_inputs(self): + with test_instruction_translator_cache_context() as ctx: + self.assert_results(guard_inputs, 1, 2, 3) + self.assertEqual(ctx.translate_count, 1) + self.assert_results(guard_inputs, 0, 2, 3) + self.assertEqual(ctx.translate_count, 2) + self.assert_results(guard_inputs, 1, 0, 3) + self.assertEqual(ctx.translate_count, 3) + self.assert_results(guard_inputs, 1, 2, 0) + self.assertEqual(ctx.translate_count, 4) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_guard_user_defined_fn.py b/test/sot/test_guard_user_defined_fn.py new file mode 100644 index 0000000000000..193164b06f58d --- /dev/null +++ b/test/sot/test_guard_user_defined_fn.py @@ -0,0 +1,88 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
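+
+# These tests check guards on user-defined callables passed as arguments:
+# a different function (or None) should trigger a new translation, while
+# repeating the same callable instance should reuse the cached code.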
+ +from __future__ import annotations + +import unittest + +from test_case_base import ( + TestCaseBase, + test_instruction_translator_cache_context, +) + +import paddle + + +def test_guard_fn(fn, inp): + if fn is None: + return 0 + else: + return fn(inp) + + +class TestGuardOutputs(TestCaseBase): + def test_non_operator_related_fn(self): + with test_instruction_translator_cache_context() as ctx: + self.assert_results( + test_guard_fn, + paddle.nn.functional.relu, + paddle.to_tensor([1.0, -1.0]), + ) + self.assertEqual(ctx.translate_count, 1) + self.assert_results( + test_guard_fn, + paddle.nn.functional.gelu, + paddle.to_tensor([1.0, -1.0]), + ) + self.assertEqual(ctx.translate_count, 2) + self.assert_results( + test_guard_fn, + paddle.nn.functional.relu, + paddle.to_tensor([-1.0, -1.0]), + ) + self.assertEqual(ctx.translate_count, 2) + self.assert_results( + test_guard_fn, None, paddle.to_tensor([-1.0, -1.0]) + ) + self.assertEqual(ctx.translate_count, 3) + + deleted_cnt = 0 + + class Callable: + def __call__(self, var): + return paddle.nn.functional.relu(var) + + def __del__(self): + nonlocal deleted_cnt + deleted_cnt += 1 + + fn1 = Callable() + fn2 = Callable() + with test_instruction_translator_cache_context() as ctx: + self.assert_results( + test_guard_fn, fn1, paddle.to_tensor([1.0, -1.0]) + ) + self.assertEqual(ctx.translate_count, 1) + self.assert_results( + test_guard_fn, fn2, paddle.to_tensor([1.0, -1.0]) + ) + self.assertEqual(ctx.translate_count, 2) + self.assert_results( + test_guard_fn, fn2, paddle.to_tensor([1.0, -1.0]) + ) + self.assertEqual(ctx.translate_count, 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_inplace_api.py b/test/sot/test_inplace_api.py new file mode 100644 index 0000000000000..767368e9fe7dd --- /dev/null +++ b/test/sot/test_inplace_api.py @@ -0,0 +1,147 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
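+
+# These tests exercise in-place mutation (list and tensor __setitem__)
+# inside traced code, including in if branches and loops, and verify that
+# gradients still flow through in-place slice assignment.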
+ +import unittest + +from test_case_base import TestCaseBase + +import paddle +from paddle.jit.sot import symbolic_translate + + +def simple(x, y): + x[0] = 3.0 + z = [y] + y[1] = 5.0 + return x[0] + x[1] + z[0][1] + y[0] + y[1] + + +def inplace_in_if(x, y, z): + if z: + x[0] = 3.0 + z = [y] + y[1] = 5.0 + ret = x[0] + x[1] + z[0][1] + y[0] + y[1] + return ret + else: + return None + + +def inplace_in_if_fallback(x, y, z): + if z > 0: + x[0] = 3.0 + z = [y] + y[1] = 5.0 + ret = x[0] + x[1] + z[0][1] + y[0] + y[1] + return ret + else: + return None + + +def inplace_in_loop(x, y): + ret = 0 + for i in range(10): + x[0] = 1 + z = [y] + y[1] = 2 * i + 1 + ret += x[0] + x[1] + z[0][1] + y[0] + y[1] + return ret + + +def inplace_in_loop_fallback(x, y, it): + ret = 0 + for i in it: + x[0] = 1 + z = [y] + y[1] = 2 * i + 1 + ret += x[0] + x[1] + z[0][1] + y[0] + y[1] + return ret + + +def inplace_case_0(x): + x[:] = 1.0 + return x + + +def inplace_case_1(x): + x[0][0, 0::2] = 1.0 + return x + + +def inplace_case_2(x): + t = x[0] + t[:, 0::2] = t[:, 0::2] * 0 + t[:, 1::2] = t[:, 1::2] + 2 + return x + + +class TestExecutor(TestCaseBase): + def test_case(self): + self.assert_results(inplace_case_0, paddle.randn((1, 4))) + self.assert_results(inplace_case_1, [paddle.randn((1, 4))]) + self.assert_results(inplace_case_2, [paddle.randn((1, 4))]) + + def test_backward(self): + @symbolic_translate + def func(x): + m = x * 2 + n = x * 3 + y = m + y[:] = n + return y + + x = paddle.ones((1, 4)) * 4 + x.stop_gradient = False + y = func(x) + y.sum().backward() + assert (x.grad.numpy() == 3).all() + + def test_simple(self): + self.assert_results( + simple, paddle.to_tensor([1.0, 2.0]), paddle.to_tensor([3.0, 4.0]) + ) + + def test_if(self): + self.assert_results( + inplace_in_if, + paddle.to_tensor([1.0, 2.0]), + paddle.to_tensor([3.0, 4.0]), + True, + ) + self.assert_results( + inplace_in_if_fallback, + paddle.to_tensor([1.0, 2.0]), + paddle.to_tensor([3.0, 4.0]), + paddle.to_tensor(1), + ) + + def test_loop(self): + self.assert_results( + inplace_in_loop, + paddle.to_tensor([1.0, 2.0]), + paddle.to_tensor([3.0, 4.0]), + ) + + a = range(10) + sym_output = symbolic_translate(inplace_in_loop_fallback)( + paddle.to_tensor([1.0, 2.0]), paddle.to_tensor([3.0, 4.0]), iter(a) + ) + paddle_output = inplace_in_loop_fallback( + paddle.to_tensor([1.0, 2.0]), paddle.to_tensor([3.0, 4.0]), iter(a) + ) + self.assert_nest_match(sym_output, paddle_output) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_instruction_translator_cache.py b/test/sot/test_instruction_translator_cache.py new file mode 100644 index 0000000000000..6ee1b33ebbc15 --- /dev/null +++ b/test/sot/test_instruction_translator_cache.py @@ -0,0 +1,165 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
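+
+# These tests drive OpcodeExecutorCache directly with fake frames and a
+# mocked start_translate, covering a cache hit, a miss on unknown code,
+# and a miss caused by a guard check that always fails.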
+ +from __future__ import annotations + +import inspect +import random +import types +import unittest +from unittest.mock import patch + +from test_case_base import ( + TestCaseBase, + test_instruction_translator_cache_context, +) + +from paddle.jit.sot.opcode_translator.custom_code import CustomCode +from paddle.jit.sot.opcode_translator.executor.executor_cache import ( + OpcodeExecutorCache, +) + + +def fake_frames() -> ( + tuple[ + types.FrameType, + types.FrameType, + types.FrameType, + types.FrameType, + types.FrameType, + ] +): + def fake_inner_fn_1(): + frame = inspect.currentframe() + assert frame is not None + return frame + + def fake_inner_fn_2(): + frame = inspect.currentframe() + assert frame is not None + return frame + + def fake_inner_fn_3(): + frame = inspect.currentframe() + assert frame is not None + return frame + + def fake_inner_fn_4(): + frame = inspect.currentframe() + assert frame is not None + return frame + + def fake_inner_fn_5(): + frame = inspect.currentframe() + assert frame is not None + return frame + + return ( + fake_inner_fn_1(), + fake_inner_fn_2(), + fake_inner_fn_3(), + fake_inner_fn_4(), + fake_inner_fn_5(), + ) + + +( + FRAME_1, + FRAME_2, + FRAME_3, + FRAME_4, + FRAME_5, +) = fake_frames() + + +def mock_start_translate(frame: types.FrameType, **kwargs): + translate_map = { + FRAME_1: (CustomCode(FRAME_2.f_code, False), lambda frame: True), + FRAME_3: ( + CustomCode(FRAME_4.f_code, False), + lambda frame: False, + ), # Always re-compile + FRAME_5: (CustomCode(None, False), lambda frame: True), + } + return translate_map[frame] + + +class TestOpcodeExecutorCache(unittest.TestCase): + def reset(self): + global translate_count + translate_count = 0 + OpcodeExecutorCache().clear() + + @patch( + "paddle.jit.sot.opcode_translator.executor.executor_cache.start_translate", + mock_start_translate, + ) + def test_cache_hit(self): + with test_instruction_translator_cache_context() as ctx: + translated_code_1 = OpcodeExecutorCache()(FRAME_1) + assert translated_code_1 is not None + self.assertEqual(translated_code_1.code, FRAME_2.f_code) + self.assertEqual(ctx.translate_count, 1) + # cache hit + translated_code_2 = OpcodeExecutorCache()(FRAME_1) + assert translated_code_2 is not None + self.assertEqual(translated_code_2.code, FRAME_2.f_code) + self.assertEqual(ctx.translate_count, 1) + + @patch( + "paddle.jit.sot.opcode_translator.executor.executor_cache.start_translate", + mock_start_translate, + ) + def test_cache_miss_due_to_unknown_code(self): + with test_instruction_translator_cache_context() as ctx: + translated_code_1 = OpcodeExecutorCache()(FRAME_1) + assert translated_code_1 is not None + self.assertEqual(translated_code_1.code, FRAME_2.f_code) + self.assertEqual(ctx.translate_count, 1) + # cache miss + translated_code_2 = OpcodeExecutorCache()(FRAME_3) + assert translated_code_2 is not None + self.assertEqual(translated_code_2.code, FRAME_4.f_code) + self.assertEqual(ctx.translate_count, 2) + + @patch( + "paddle.jit.sot.opcode_translator.executor.executor_cache.start_translate", + mock_start_translate, + ) + def test_cache_miss_due_to_check_failed(self): + with test_instruction_translator_cache_context() as ctx: + translated_code_1 = OpcodeExecutorCache()(FRAME_3) + assert translated_code_1 is not None + self.assertEqual(translated_code_1.code, FRAME_4.f_code) + self.assertEqual(ctx.translate_count, 1) + # cache miss + translated_code_2 = OpcodeExecutorCache()(FRAME_3) + assert translated_code_2 is not None + self.assertEqual(translated_code_2.code, 
FRAME_4.f_code) + self.assertEqual(ctx.translate_count, 2) + + +def foo(x): + return x + 1 + + +class TestCacheExceedLimit(TestCaseBase): + def test_cache_exceed_limit(self): + for _ in range(30): + input = random.random() + self.assert_results(foo, input) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/sot/test_map.py b/test/sot/test_map.py new file mode 100644 index 0000000000000..812ab36673be4 --- /dev/null +++ b/test/sot/test_map.py @@ -0,0 +1,124 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from typing import Iterable + +from test_case_base import TestCaseBase, strict_mode_guard + +from paddle.jit import sot +from paddle.jit.sot.psdb import check_no_breakgraph + + +def double_num(num: float | int): + return num * 2 + + +def double_num_with_breakgraph(num: float | int): + sot.psdb.breakgraph() + return num * 2 + + +@check_no_breakgraph +def test_map_list(x: list): + return list(map(double_num, x)) + + +@check_no_breakgraph +def test_map_list_comprehension(x: list): + return [i for i in map(double_num, x)] # noqa: C416 + + +@check_no_breakgraph +def test_map_tuple(x: tuple): + return tuple(map(double_num, x)) + + +@check_no_breakgraph +def test_map_tuple_comprehension(x: tuple): + return [i for i in map(double_num, x)] # noqa: C416 + + +@check_no_breakgraph +def test_map_range(x: Iterable): + return list(map(double_num, x)) + + +@check_no_breakgraph +def test_map_range_comprehension(x: Iterable): + return [i for i in map(double_num, x)] # noqa: C416 + + +def add_dict_prefix(key: str): + return f"dict_{key}" + + +@check_no_breakgraph +def test_map_dict(x: dict): + return list(map(add_dict_prefix, x)) + + +@check_no_breakgraph +def test_map_dict_comprehension(x: dict): + return [i for i in map(add_dict_prefix, x)] # noqa: C416 + + +def test_map_list_with_breakgraph(x: list): + return list(map(double_num_with_breakgraph, x)) + + +@check_no_breakgraph +def test_map_unpack(x: list): + a, b, c, d = map(double_num, x) + return a, b, c, d + + +@check_no_breakgraph +def test_map_for_loop(x: list): + res = 0 + for i in map(double_num, x): + res += i + return res + + +class TestMap(TestCaseBase): + def test_map(self): + self.assert_results(test_map_list, [1, 2, 3, 4]) + self.assert_results(test_map_tuple, (1, 2, 3, 4)) + self.assert_results(test_map_range, range(5)) + self.assert_results(test_map_dict, {"a": 1, "b": 2, "c": 3}) + + def test_map_comprehension(self): + self.assert_results(test_map_list_comprehension, [1, 2, 3, 4]) + self.assert_results(test_map_tuple_comprehension, (1, 2, 3, 4)) + self.assert_results(test_map_range_comprehension, range(5)) + self.assert_results( + test_map_dict_comprehension, {"a": 1, "b": 2, "c": 3} + ) + + def test_map_with_breakgraph(self): + with strict_mode_guard(0): + self.assert_results(test_map_list_with_breakgraph, [1, 2, 3, 4]) + + def test_map_unpack(self): + self.assert_results(test_map_unpack, [1, 2, 3, 4]) + + 
def test_map_for_loop(self): + self.assert_results(test_map_for_loop, [7, 8, 9, 10]) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_multiple_args.py b/test/sot/test_multiple_args.py new file mode 100644 index 0000000000000..7d5bf6b59205c --- /dev/null +++ b/test/sot/test_multiple_args.py @@ -0,0 +1,35 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def foo(x, y): + ret = x + y + return ret + + +class TestMultipleArgs(TestCaseBase): + def test_multiple_args(self): + x = paddle.to_tensor([1.0]) + y = paddle.to_tensor([2.0]) + self.assert_results(foo, x, y) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_mutable_data.py b/test/sot/test_mutable_data.py new file mode 100644 index 0000000000000..2cedee2d8529f --- /dev/null +++ b/test/sot/test_mutable_data.py @@ -0,0 +1,354 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from paddle.jit.sot.opcode_translator.executor.mutable_data import ( + MutableData, + MutableDictLikeData, + MutableListLikeData, +) + + +class VariableBase: + def __init__(self): + ... 
+ + +class ConstVariable(VariableBase): + def __init__(self, value): + self.value = value + + def __repr__(self): + return f"ConstVariable({self.value})" + + def __eq__(self, other): + if not isinstance(other, ConstVariable): + return False + return self.value == other.value + + +class DictVariable(VariableBase): + def __init__(self, data): + self.data = data + self.proxy = MutableDictLikeData(data, DictVariable.proxy_getter) + + @staticmethod + def proxy_getter(proxy, key): + if key not in proxy.original_data: + return MutableData.Empty() + return ConstVariable(proxy.original_data[key]) + + def getitem(self, key): + res = self.proxy.get(key) + if isinstance(res, MutableData.Empty): + raise KeyError(f"Key {key} not found") + return res + + def setitem(self, key, value): + self.proxy.set(key, value) + + def delitem(self, key): + self.proxy.delete(key) + + +class ListVariable(VariableBase): + def __init__(self, data): + self.data = data + self.proxy = MutableListLikeData(data, ListVariable.proxy_getter) + + @staticmethod + def proxy_getter(proxy, key): + if key < 0 or key >= len(proxy.original_data): + return MutableData.Empty() + return ConstVariable(proxy.original_data[key]) + + def getitem(self, key): + if isinstance(key, int): + res = self.proxy.get(key) + if isinstance(res, MutableData.Empty): + raise IndexError(f"Index {key} out of range") + return res + elif isinstance(key, slice): + return self.proxy.get_all()[key] + else: + raise TypeError(f"Invalid key type {type(key)}") + + def __getitem__(self, key): + return self.getitem(key) + + def setitem(self, key, value): + if isinstance(key, int): + self.proxy.set(key, value) + elif isinstance(key, slice): + start, end, step = key.indices(self.proxy.length) + indices = list(range(start, end, step)) + if step == 1: + # replace a continuous range + for i, idx in enumerate(indices): + self.proxy.delete(idx - i) + for i, item in enumerate(value): + self.proxy.insert(start + i, item) + else: + # replace some elements + if len(indices) != len(value): + raise ValueError( + f"Attempt to replace {len(indices)} items with {len(value)}" + ) + for i, idx in enumerate(indices): + self.proxy.set(idx, value[i]) + + def delitem(self, key): + self.proxy.delete(key) + + def insert(self, index, value): + self.proxy.insert(index, value) + + def append(self, value): + self.proxy.insert(self.proxy.length, value) + + def extend(self, value): + for item in value: + self.append(item) + + def pop(self, index=-1): + res = self.getitem(index) + self.delitem(index) + return res + + def clear(self): + for i in range(self.proxy.length): + self.delitem(0) + + def remove(self, value): + for i in range(self.proxy.length): + if self.getitem(i) == value: + self.delitem(i) + return + raise ValueError(f"Value {value} not found") + + def sort(self, key=None, reverse=False): + if key is None: + key = lambda x: x + permutation = list(range(self.proxy.length)) + permutation.sort( + key=lambda x: key(self.getitem(x).value), reverse=reverse + ) + self.proxy.permutate(permutation) + + def reverse(self): + permutation = list(range(self.proxy.length)) + permutation.reverse() + self.proxy.permutate(permutation) + + +class TestMutableDictLikeVariable(unittest.TestCase): + def test_getitem(self): + data = {"a": 1, "b": 2} + var = DictVariable(data) + self.assertEqual(var.getitem("a"), ConstVariable(1)) + self.assertEqual(var.getitem("b"), ConstVariable(2)) + + def test_setitem(self): + data = {"a": 1, "b": 2} + var = DictVariable(data) + var.setitem("a", ConstVariable(3)) + 
self.assertEqual(var.getitem("a"), ConstVariable(3)) + var.setitem("c", ConstVariable(4)) + self.assertEqual(var.getitem("c"), ConstVariable(4)) + + def test_delitem(self): + data = {"a": 1, "b": 2} + var = DictVariable(data) + var.delitem("a") + with self.assertRaises(KeyError): + var.getitem("a") + + def test_keys(self): + data = {"a": 1, "b": 2} + var = DictVariable(data) + self.assertEqual(list(var.proxy.get_all().keys()), ["a", "b"]) + + +class TestMutableListLikeVariable(unittest.TestCase): + def test_getitem(self): + data = [1, 2, 3] + var = ListVariable(data) + self.assertEqual(var.getitem(0), ConstVariable(1)) + self.assertEqual(var.getitem(1), ConstVariable(2)) + self.assertEqual(var.getitem(2), ConstVariable(3)) + + def test_getitem_slice_1(self): + data = [1, 2, 3, 4, 5, 6, 7] + var = ListVariable(data) + self.assertEqual( + var.getitem(slice(0, 3)), + [ConstVariable(1), ConstVariable(2), ConstVariable(3)], + ) + self.assertEqual( + var.getitem(slice(4, 1, -1)), + [ConstVariable(5), ConstVariable(4), ConstVariable(3)], + ) + self.assertEqual( + var.getitem(slice(1, 5, 2)), + [ConstVariable(2), ConstVariable(4)], + ) + + def test_getitem_slice_2(self): + data = [1, 2, 3, 4, 5, 6, 7] + var = ListVariable(data) + self.assertEqual( + var[0:3], + [ConstVariable(1), ConstVariable(2), ConstVariable(3)], + ) + self.assertEqual( + var[4:1:-1], + [ConstVariable(5), ConstVariable(4), ConstVariable(3)], + ) + self.assertEqual( + var[1:5:2], + [ConstVariable(2), ConstVariable(4)], + ) + + def test_setitem(self): + data = [1, 2, 3] + var = ListVariable(data) + var.setitem(0, ConstVariable(4)) + self.assertEqual(var.getitem(0), ConstVariable(4)) + var.append(ConstVariable(5)) + self.assertEqual(var.getitem(3), ConstVariable(5)) + + def test_setitem_slice_1(self): + data = [1, 2, 3, 4, 5, 6, 7] + var = ListVariable(data) + var.setitem(slice(0, 3), [ConstVariable(4), ConstVariable(5)]) + self.assertEqual( + [var.getitem(i) for i in range(var.proxy.length)], + [ConstVariable(n) for n in [4, 5, 4, 5, 6, 7]], + ) + var.setitem( + slice(4, 1, -1), + [ConstVariable(8), ConstVariable(9), ConstVariable(10)], + ) + self.assertEqual( + [var.getitem(i) for i in range(var.proxy.length)], + [ConstVariable(n) for n in [4, 5, 10, 9, 8, 7]], + ) + + def test_setitem_slice_2(self): + data = [1, 2, 3, 4, 5, 6, 7] + var = ListVariable(data) + var.setitem(slice(2, 5, 2), [ConstVariable(8), ConstVariable(9)]) + self.assertEqual( + [var.getitem(i) for i in range(var.proxy.length)], + [ConstVariable(n) for n in [1, 2, 8, 4, 9, 6, 7]], + ) + + def test_delitem(self): + data = [1, 2, 3] + var = ListVariable(data) + var.delitem(0) + with self.assertRaises(IndexError): + var.getitem(2) + var.pop() + with self.assertRaises(IndexError): + var.getitem(1) + + def test_insert(self): + data = [1, 2, 3] + var = ListVariable(data) + var.insert(0, ConstVariable(4)) + self.assertEqual( + [var.getitem(i) for i in range(var.proxy.length)], + [ConstVariable(n) for n in [4, 1, 2, 3]], + ) + var.insert(2, ConstVariable(5)) + self.assertEqual( + [var.getitem(i) for i in range(var.proxy.length)], + [ConstVariable(n) for n in [4, 1, 5, 2, 3]], + ) + + def test_append(self): + data = [1, 2, 3] + var = ListVariable(data) + var.append(ConstVariable(4)) + self.assertEqual(var.getitem(3), ConstVariable(4)) + + def test_extend(self): + data = [1, 2, 3] + var = ListVariable(data) + var.extend([ConstVariable(4), ConstVariable(5)]) + self.assertEqual(var.getitem(3), ConstVariable(4)) + self.assertEqual(var.getitem(4), ConstVariable(5)) + + def 
test_pop(self): + data = [1, 2, 3] + var = ListVariable(data) + self.assertEqual(var.pop(), ConstVariable(3)) + self.assertEqual(var.pop(0), ConstVariable(1)) + + def test_clear(self): + data = [1, 2, 3] + var = ListVariable(data) + var.clear() + self.assertEqual(var.proxy.length, 0) + + def test_remove(self): + data = [1, 2, 3] + var = ListVariable(data) + var.remove(ConstVariable(2)) + self.assertEqual(var.getitem(0), ConstVariable(1)) + self.assertEqual(var.getitem(1), ConstVariable(3)) + with self.assertRaises(ValueError): + var.remove(ConstVariable(2)) + + def test_sort(self): + data = [2, 3, 0, 4, 1, 5] + var = ListVariable(data) + var.sort() + self.assertEqual( + [var.getitem(i) for i in range(var.proxy.length)], + [ConstVariable(n) for n in [0, 1, 2, 3, 4, 5]], + ) + + def test_sort_with_key(self): + data = [-1, -4, 2, 0, 5, -3] + var = ListVariable(data) + var.sort(key=lambda x: x**2) + self.assertEqual( + [var.getitem(i) for i in range(var.proxy.length)], + [ConstVariable(n) for n in [0, -1, 2, -3, -4, 5]], + ) + + def test_sort_reverse(self): + data = [2, 3, 0, 4, 1, 5] + var = ListVariable(data) + var.sort(reverse=True) + self.assertEqual( + [var.getitem(i) for i in range(var.proxy.length)], + [ConstVariable(n) for n in [5, 4, 3, 2, 1, 0]], + ) + + def test_reverse(self): + data = [2, 3, 0, 4, 1, 5] + var = ListVariable(data) + var.reverse() + self.assertEqual( + [var.getitem(i) for i in range(var.proxy.length)], + [ConstVariable(n) for n in [5, 1, 4, 0, 3, 2]], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_numpy.py b/test/sot/test_numpy.py new file mode 100644 index 0000000000000..3600d4df7cc45 --- /dev/null +++ b/test/sot/test_numpy.py @@ -0,0 +1,44 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from test_case_base import TestCaseBase, strict_mode_guard + +import paddle + + +def foo(x, y): + ret = x + y + return ret + + +class TestNumpy(TestCaseBase): + def test_tensor_add_numpy_number(self): + x = paddle.to_tensor([1.0]) + y = np.int64(2) + self.assert_results(foo, x, y) + self.assert_results(foo, y, x) + + @strict_mode_guard(0) + def test_tensor_add_numpy_array(self): + x = paddle.to_tensor([1.0]) + y = np.array(2.0) + self.assert_results(foo, x, y) + self.assert_results(foo, y, x) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_numpy_var_if.py b/test/sot/test_numpy_var_if.py new file mode 100644 index 0000000000000..9d7c4a7048e25 --- /dev/null +++ b/test/sot/test_numpy_var_if.py @@ -0,0 +1,53 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest + +import numpy as np +from test_case_base import TestCaseBase + +import paddle +from paddle.jit.sot.psdb import check_no_breakgraph, check_no_fallback + +os.environ['MIN_GRAPH_SIZE'] = '-1' + + +@check_no_breakgraph +@check_no_fallback +def forward(x, y): + if x == 0: + return y + 2 + else: + return y * 2 + + +@check_no_breakgraph +@check_no_fallback +def forward2(x, y): + if x == x: # numpy == numpy + return y + 2 + else: + return y * 2 + + +class TestJumpWithNumpy(TestCaseBase): + def test_jump(self): + self.assert_results(forward, np.array([1]), paddle.to_tensor(2)) + self.assert_results(forward, np.array([0]), paddle.to_tensor(2)) + self.assert_results(forward2, np.array([0]), paddle.to_tensor(2)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_output_restoration.py b/test/sot/test_output_restoration.py new file mode 100644 index 0000000000000..9c2cf268e9087 --- /dev/null +++ b/test/sot/test_output_restoration.py @@ -0,0 +1,95 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
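+
+# These tests check that output structures (constants, lists, dicts, and
+# nested containers mixing tensors with Python values) are restored
+# correctly after translation.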
+ +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def output_identity(x): + return x + + +def output_const(): + return 42 + + +def output_list(x: paddle.Tensor, y: paddle.Tensor, z: int): + a = x + 1 + b = z + 1 + l = [1, a, b, y] + return l + + +def output_dict(x: paddle.Tensor, y: paddle.Tensor, z: int): + a = x + 1 + b = z + 1 + l = {1: a, b: y} + return l + + +def output_dict_const_key(x: paddle.Tensor, y: paddle.Tensor, z: int): + a = x + 1 + b = z + 1 + l = {1: a, 2: y} + return l + + +def output_nest_struct(x: paddle.Tensor, y: paddle.Tensor, z: int): + a = x + y + z + b = z + 1 + l = [1 + 1, (z, a), [b]] + return l + + +class TestOutputRestoration(TestCaseBase): + def test_output_identity(self): + self.assert_results(output_identity, 1) + self.assert_results(output_identity, 2) + self.assert_results(output_identity, paddle.to_tensor(1)) + + def test_output_const(self): + self.assert_results(output_const) + + def test_output_list(self): + a = paddle.to_tensor(1) + b = paddle.to_tensor(2) + + self.assert_results(output_list, a, b, 3) + + def test_output_dict(self): + a = paddle.to_tensor(1) + b = paddle.to_tensor(2) + + self.assert_results(output_dict, a, b, 3) + + def test_output_dict_const_key(self): + a = paddle.to_tensor(2) + b = paddle.to_tensor(3) + + self.assert_results(output_dict_const_key, a, b, 4) + + def test_output_nest_struct(self): + a = paddle.to_tensor(1) + b = paddle.to_tensor(2) + + self.assert_results(output_nest_struct, a, b, 3) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_range.py b/test/sot/test_range.py new file mode 100644 index 0000000000000..3a7e85fb0951d --- /dev/null +++ b/test/sot/test_range.py @@ -0,0 +1,92 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
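+
+# These tests cover range() inside traced code: construction with one to
+# three arguments, indexing, conversion to list, and iteration in plain
+# and nested for loops (including loops that accumulate into a tensor).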
+ +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def test_range_1(stop: int): + return range(stop) + + +def test_range_2(start: int, stop: int): + return range(start, stop) + + +def test_range_3(start: int, stop: int, step: int): + return range(start, stop, step) + + +def test_range_4(stop: int, index: int): + return range(stop)[index] + + +def test_range_5(stop: int): + return list(range(stop)) + + +def test_range_6(stop: int, index: int): + return list(range(stop))[index] + + +def test_range_7(index: int, tensor: paddle.Tensor): + return list(range(len(tensor.shape)))[index] + + +def test_range_8(stop: int): + sum = 0 + for i in range(stop): + sum += i + return sum + + +def test_range_9(stop: int, tensor: paddle.Tensor): + for i in range(stop): + tensor += i + return tensor + + +def test_range_10(stop: int, tensor: paddle.Tensor): + for i in range(stop): + for j in range(stop + 1): + tensor += j + return tensor + + +class TestExecutor(TestCaseBase): + def test_cases(self): + start = 3 + stop = 10 + step = 2 + index = 1 + tensor = paddle.randn((10, 10)) + + self.assert_results(test_range_1, stop) + self.assert_results(test_range_2, start, stop) + self.assert_results(test_range_3, start, stop, step) + self.assert_results(test_range_4, stop, index) + self.assert_results(test_range_5, stop) + self.assert_results(test_range_6, stop, index) + self.assert_results(test_range_7, index, tensor) + self.assert_results(test_range_8, stop) + + self.assert_results(test_range_9, stop, paddle.randn((10,))) + self.assert_results(test_range_10, stop, paddle.randn((10,))) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_resnet.py b/test/sot/test_resnet.py new file mode 100644 index 0000000000000..cc9a47252c559 --- /dev/null +++ b/test/sot/test_resnet.py @@ -0,0 +1,59 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
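+
+# These tests run resnet18 through the translator and check cache behavior:
+# a repeated call is a cache hit, while toggling between train and eval
+# mode invalidates the guard and forces a re-translation.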
+ +import unittest + +from test_case_base import ( + TestCaseBase, + test_instruction_translator_cache_context, +) + +import paddle +from paddle.vision.models.resnet import resnet18 + + +def resnet_call(x: paddle.Tensor, net: paddle.nn.Layer): + return net(x) + + +class TestResNet(TestCaseBase): + def test_resnet_eval(self): + x = paddle.rand((10, 3, 224, 224)) + net = resnet18(pretrained=False) + net.eval() + with test_instruction_translator_cache_context() as ctx: + self.assert_results(resnet_call, x, net) + self.assertEqual(ctx.translate_count, 1) + self.assert_results(resnet_call, x, net) # cache hit + self.assertEqual(ctx.translate_count, 1) + net.train() + self.assert_results(resnet_call, x, net) # cache miss + self.assertEqual(ctx.translate_count, 2) + + def test_resnet_train(self): + x = paddle.rand((10, 3, 224, 224)) + net = resnet18(pretrained=False) + net.train() + with test_instruction_translator_cache_context() as ctx: + self.assert_results(resnet_call, x, net) + self.assertEqual(ctx.translate_count, 1) + self.assert_results(resnet_call, x, net) # cache hit + self.assertEqual(ctx.translate_count, 1) + net.eval() + self.assert_results(resnet_call, x, net) # cache miss + self.assertEqual(ctx.translate_count, 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_resnet50_backward.py b/test/sot/test_resnet50_backward.py new file mode 100644 index 0000000000000..bd5aac0025e80 --- /dev/null +++ b/test/sot/test_resnet50_backward.py @@ -0,0 +1,107 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
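+
+# This test trains resnet50 for a few SGD steps under dygraph, symbolic
+# translation, and to_static, and checks that the resulting losses match.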
+
+import os
+
+os.environ["FLAGS_cudnn_deterministic"] = "True"
+
+import random
+import unittest
+
+import numpy as np
+from numpy.testing import assert_array_equal
+
+import paddle
+from paddle.jit.sot import symbolic_translate
+from paddle.jit.sot.utils.utils import execute_time
+from paddle.vision import resnet50
+
+
+def resnet_call(net: paddle.nn.Layer, x: paddle.Tensor):
+    return net(x)
+
+
+def run_dygraph_optimizer(inp):
+    """dygraph train + SGD optimizer"""
+    paddle.seed(2021)
+    np.random.seed(2021)
+    random.seed(2021)
+    net = resnet50()
+    optimizer = paddle.optimizer.SGD(
+        learning_rate=0.03, parameters=net.parameters()
+    )
+    for i in range(5):
+        optimizer.clear_grad()
+        loss = execute_time(net)(inp)
+        loss.backward()
+        optimizer.step()
+    return loss
+
+
+def run_symbolic_optimizer(inp):
+    """symbolic_translate train + SGD optimizer"""
+    paddle.seed(2021)
+    np.random.seed(2021)
+    random.seed(2021)
+    net = resnet50()
+    net_wrapper = symbolic_translate(resnet_call)
+    optimizer = paddle.optimizer.SGD(
+        learning_rate=0.03, parameters=net.parameters()
+    )
+    for i in range(5):
+        optimizer.clear_grad()
+        loss = execute_time(net_wrapper)(net, inp)
+        loss.backward()
+        optimizer.step()
+    return loss
+
+
+def run_to_static_optimizer(inp):
+    """to_static train + SGD optimizer"""
+    paddle.seed(2021)
+    np.random.seed(2021)
+    random.seed(2021)
+    net = resnet50()
+    net = paddle.jit.to_static(net, enable_fallback=False)
+    optimizer = paddle.optimizer.SGD(
+        learning_rate=0.03, parameters=net.parameters()
+    )
+    for i in range(5):
+        optimizer.clear_grad()
+        loss = execute_time(net)(inp)
+        loss.backward()
+        optimizer.step()
+    return loss
+
+
+class TestBackward(unittest.TestCase):
+    def test(self):
+        # TODO(xiongkun) add cache to speedup !
+        paddle.seed(2021)
+        np.random.seed(2021)
+        random.seed(2021)
+        inp = paddle.rand((3, 3, 255, 255))
+        print("Start Run SymbolicTranslate:")
+        out2 = run_symbolic_optimizer(inp)[0].numpy()
+        print("Start Run Dygraph:")
+        out1 = run_dygraph_optimizer(inp)[0].numpy()
+        print("Start Run To Static:")
+        out3 = run_to_static_optimizer(inp)[0].numpy()
+        assert_array_equal(
+            out1, out2, "dygraph and symbolic translate results differ", True
+        )
+        assert_array_equal(
+            out1, out3, "dygraph and to_static results differ", True
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/sot/test_segment_linear.py b/test/sot/test_segment_linear.py
new file mode 100644
index 0000000000000..ee3b7d70f8d36
--- /dev/null
+++ b/test/sot/test_segment_linear.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
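+
+# This test runs a small segmentation-style head whose reshape depends on
+# paddle.shape() of an intermediate result, with SimpleNet.forward
+# explicitly skipped via sot.skip_function.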
+
+import unittest
+
+from test_case_base import TestCaseBase
+
+import paddle
+from paddle import nn
+from paddle.jit import sot
+
+
+class Head(nn.Layer):
+    def __init__(self):
+        super().__init__()
+        self.head = nn.Linear(10, 150)
+
+    def forward(self, x, patch_embed_size):
+        masks = self.head(x)
+        # [b, (h w), c] -> [b, c, h, w]
+        h, w = patch_embed_size[0], patch_embed_size[1]
+        masks = masks.reshape((1, h, w, paddle.shape(masks)[-1]))
+        masks = masks.transpose((0, 3, 1, 2))
+        return masks
+
+
+class SimpleNet(nn.Layer):
+    def __init__(self):
+        super().__init__()
+        self.tmp = nn.Linear(1, 1024 * 10)
+        self.tmp2 = nn.Linear(1, 1 * 10 * 32 * 32)
+        self.head = Head()
+
+    def getshape(self, x):
+        x = self.tmp2(x.mean().reshape([1])).reshape([1, 10, 32, 32])
+        x = paddle.shape(x)
+        return x
+
+    def forward(self, x):
+        shape = self.getshape(x)
+        feat = self.tmp(x.mean().reshape([1])).reshape([1, 1024, 10])
+        logits = self.head(feat, shape[2:])
+        return logits
+
+
+class TestExecutor(TestCaseBase):
+    def test_simple(self):
+        sot.skip_function(SimpleNet.forward)
+        x = paddle.randn((1, 8, 8))
+        net = SimpleNet()
+        net = paddle.jit.to_static(
+            net
+        )  # has no effect yet; we still need to fetch the SOT PR in Paddle.
+        loss = net(x)
+        loss = loss.sum()
+        loss.backward()
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/sot/test_side_effects.py b/test/sot/test_side_effects.py
new file mode 100644
index 0000000000000..46bed6e8d3c4e
--- /dev/null
+++ b/test/sot/test_side_effects.py
@@ -0,0 +1,333 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
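+
+# These tests check that side effects on dicts, lists, and object
+# attributes (setitem/delitem, append/extend/pop/sort/..., attribute
+# set/del) are recorded during simulation and replayed on the real
+# objects afterwards.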
+ +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase, strict_mode_guard + +import paddle +from paddle.jit import sot +from paddle.jit.sot import symbolic_translate +from paddle.jit.sot.utils import InnerError + + +def dict_setitem(x): + x[0] = 1 + return x[0] + + +def dict_delitem(x): + del x[0] + return x + + +def dict_delitem_getitem(a): + b = a[0] + del a[0] + b[0] = 1 + return a, b + + +def dict_nested_1(x): + x[0][0] = 42 + x[1][0] = x[0][0] + x[0][1] + x[2] = {1: 2} + return x + + +def dict_nested_2(x): + a = x[0] + b = x[1] + del a[0] + a[1] = b[0] + a[2] = b[1] + x[1][0] = 42 + del a[1] + return a, b + + +def list_append_int(tensor_x, list_a): + tensor_x = tensor_x + 1 + list_a.append(12) + return tensor_x, list_a + + +def list_append_tensor(tensor_x, list_a): + tensor_x = tensor_x + 1 + list_a.append(tensor_x) + return tensor_x, list_a + + +def list_delitem(list_a): + del list_a[0] + return list_a[0] + + +def list_extend(list_a): + list_a.extend([1, 2, 3]) + return list_a[0] + + +def list_nested(list_a): + inner_list = [] + inner_list.append(list_a) + inner_list[-1].append(12) + return 12 + + +def list_insert(list_a): + list_a.insert(0, 1) + return list_a[0] + + +def list_remove(list_a): + list_a.remove(1) + return list_a[0] + + +def list_pop(list_a): + list_a.pop(0) + list_a.pop() + list_a.pop(1) + return list_a[0] + + +def list_clear(list_a): + list_a.clear() + return list_a + + +def list_sort(list_a): + list_a.sort() + return list_a + + +def list_reverse(list_a): + list_a.reverse() + return list_a + + +def slice_in_for_loop(x, iter_num=3): + x = paddle.to_tensor(x) + a = [] + + iter_num = paddle.full(shape=[1], fill_value=iter_num, dtype="int32") + + for i in range(iter_num): + a.append(x) + + for i in range(iter_num): + a[i] = x + out = a[2] + return out + + +# TODO: Object SideEffect +class CustomObject: + def __init__(self): + self.x = 2 + self.y = paddle.to_tensor(1) + + def object_attr_set2(self, x): + self.outputs = [] + self.outputs.append(x) + return self.outputs + + +@sot.psdb.check_no_breakgraph +def object_attr_set(cus_obj, t): + """object side effect.""" + t = t + 1 + cus_obj.x = t + return t, cus_obj.x + + +def object_attr_breakgraph(cus_obj, t): + t = t + 1 + sot.psdb.breakgraph() + cus_obj.x = t + sot.psdb.breakgraph() + return t, cus_obj.x + + +@sot.psdb.check_no_breakgraph +def object_attr_tensor_del(cus_obj): + del cus_obj.y + + +@sot.psdb.check_no_breakgraph +def object_attr_int_del(cus_obj): + del cus_obj.x + + +def slice_list_after_change(l): + l.reverse() + sum = 0 + for i, v in zip(range(2), l[2:]): + sum += v + return sum + + +class TestDictSideEffect(TestCaseBase): + def test_dict_setitem(self): + self.assert_results_with_side_effects( + dict_setitem, {0: paddle.to_tensor(0)} + ) + self.assert_results_with_side_effects( + dict_setitem, {0: paddle.to_tensor(1)} + ) + + def test_dict_delitem(self): + self.assert_results_with_side_effects( + dict_delitem, {0: paddle.to_tensor(0), 1: paddle.to_tensor(1)} + ) + self.assert_results_with_side_effects( + dict_delitem, {0: paddle.to_tensor(1), 2: paddle.to_tensor(2)} + ) + + def test_dict_delitem_getitem(self): + self.assert_results_with_side_effects( + dict_delitem_getitem, {0: {0: 1, 1: 2}} + ) + + def test_dict_nested_1(self): + self.assert_results_with_side_effects( + dict_nested_1, {0: {0: 1, 1: 2}, 1: {0: 1, 1: 2}} + ) + self.assert_results_with_side_effects( + dict_nested_1, {0: {0: 123, 1: 2}, 1: {0: 1, 1: 2}} + ) + + def test_dict_nested_2(self): + 
self.assert_results_with_side_effects( + dict_nested_2, {0: {0: 1, 1: 2}, 1: {0: 1, 1: 2}} + ) + self.assert_results_with_side_effects( + dict_nested_2, {0: {0: 123, 1: 2}, 1: {0: 1, 1: 2}} + ) + + +class TestListSideEffect(TestCaseBase): + def test_list_append(self): + self.assert_results_with_side_effects( + list_append_int, paddle.to_tensor(1), [1, 2, 3] + ) + self.assert_results_with_side_effects( + list_append_tensor, paddle.to_tensor(2), [1, 2, 3] + ) + + def test_list_delitem(self): + self.assert_results_with_side_effects(list_delitem, [1, 2, 3]) + + def test_list_extend(self): + self.assert_results_with_side_effects( + list_extend, [1, 2, 3, 4, 5, 6, 7, 8, 9] + ) + + def test_list_insert(self): + self.assert_results_with_side_effects(list_insert, [1, 2, 3]) + self.assert_results_with_side_effects( + list_insert, [-1, 2, -3, 4, -5, 6, -7, 8, -9] + ) + + def test_list_remove(self): + self.assert_results_with_side_effects(list_remove, [1, 1, 1]) + self.assert_results_with_side_effects(list_remove, [0, 1, 2]) + with self.assertRaises(InnerError): + symbolic_translate(list_remove)([0, 2, 4]) + + def test_list_pop(self): + self.assert_results_with_side_effects(list_pop, [1, 2, 3, 4, 5]) + self.assert_results_with_side_effects( + list_pop, [-1, 2, -3, 4, -5, 6, -7, 8, -9] + ) + + def test_list_clear(self): + self.assert_results_with_side_effects(list_clear, [1, 2, 3, 4, 5]) + self.assert_results_with_side_effects( + list_clear, [-1, 2, -3, 4, -5, 6, -7, 8, -9] + ) + + def test_list_sort(self): + self.assert_results_with_side_effects(list_sort, [2, 1, 7, 3, 4, 6]) + self.assert_results_with_side_effects( + list_sort, [-1, 2, -3, 4, -5, 6, -7, 8, -9] + ) + + def test_list_reverse(self): + self.assert_results_with_side_effects(list_reverse, [1, 2, 3, 4, 5]) + self.assert_results_with_side_effects( + list_reverse, [-1, 2, -3, 4, -5, 6, -7, 8, -9] + ) + + def test_slice_in_for_loop(self): + x = 2 + with strict_mode_guard(0): + self.assert_results_with_side_effects(slice_in_for_loop, x) + + def test_list_nested(self): + self.assert_results_with_side_effects(list_nested, [1, 2, 3]) + + +class TestSliceAfterChange(TestCaseBase): + def test_slice_list_after_change(self): + self.assert_results_with_side_effects( + slice_list_after_change, [1, 2, 3, 4] + ) + self.assert_results_with_side_effects( + slice_list_after_change, [7, 8, 9, 10] + ) + + +class TestAttrSideEffect(TestCaseBase): + def attr_check(self, func, attr_keys: list[str], cls, *inputs): + cus_obj1 = cls() + cus_obj2 = cls() + sym_output = symbolic_translate(func)(cus_obj1, *inputs) + paddle_output = func(cus_obj2, *inputs) + for key in attr_keys: + self.assert_nest_match( + getattr(cus_obj1, key, f"__MISS_KEY__{key}"), + getattr(cus_obj2, key, f"__MISS_KEY__{key}"), + ) + self.assert_nest_match(sym_output, paddle_output) + + def test_attr_set(self): + self.attr_check(object_attr_set, ["x"], CustomObject, 5) + self.attr_check( + CustomObject.object_attr_set2, ["outputs"], CustomObject, 6 + ) + self.attr_check( + CustomObject.object_attr_set2, + ["outputs"], + CustomObject, + paddle.to_tensor(5), + ) + self.attr_check( + object_attr_set, ["x"], CustomObject, paddle.to_tensor(5) + ) + + def test_attr_del(self): + self.attr_check(object_attr_tensor_del, ["y"], CustomObject) + self.attr_check(object_attr_int_del, ["x"], CustomObject) + + def test_attr_set_breakgraph(self): + self.attr_check(object_attr_breakgraph, ["x"], CustomObject, 100) + self.attr_check(object_attr_breakgraph, ["x"], CustomObject, 1000) + + +if __name__ == "__main__": + 
unittest.main() diff --git a/test/sot/test_simulate_initialize.py b/test/sot/test_simulate_initialize.py new file mode 100644 index 0000000000000..495e06ac1dbda --- /dev/null +++ b/test/sot/test_simulate_initialize.py @@ -0,0 +1,51 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from test_case_base import TestCaseBase + +import paddle +from paddle import nn +from paddle.jit.sot import symbolic_translate + + +class A: + def __init__(self, vals): + vals.append(1) + + +def foo(x, y): + out = nn.Softmax()(paddle.to_tensor([x, y], dtype="float32")) + return out + + +def bar(x): + a = A(x) + t = paddle.to_tensor(x) + return t.mean() + + +class TestInit(TestCaseBase): + def test_init_paddle_layer(self): + self.assert_results(foo, 1, 2) + + def test_init_python_object(self): + sot_output = symbolic_translate(bar)([1.0, 2.0]) + dyn_output = bar([1.0, 2.0]) + self.assert_nest_match(sot_output, dyn_output) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_sir_rollback.py b/test/sot/test_sir_rollback.py new file mode 100644 index 0000000000000..ddb7792651e4d --- /dev/null +++ b/test/sot/test_sir_rollback.py @@ -0,0 +1,88 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
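+
+# These tests check statement IR rollback: after saving a memo of the
+# function graph, recording extra statements, and restoring the memo, the
+# statement list should return to its original length. Side effects across
+# a breakgraph are exercised as well.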
+
+from __future__ import annotations
+
+import inspect
+import operator
+import unittest
+
+from test_case_base import TestCaseBase
+
+import paddle
+from paddle.jit.sot.opcode_translator.executor.function_graph import (
+    FunctionGraph,
+)
+from paddle.jit.sot.opcode_translator.executor.tracker import (
+    DanglingTracker,
+    LocalTracker,
+)
+from paddle.jit.sot.opcode_translator.executor.variables import (
+    BuiltinVariable,
+    VariableFactory,
+)
+
+
+def compute(x, y):
+    ret = BuiltinVariable(operator.add, x.graph, DanglingTracker())(x, y)
+    return BuiltinVariable(operator.mul, x.graph, DanglingTracker())(ret, x)
+
+
+def try_add(x, y):
+    return BuiltinVariable(operator.add, x.graph, DanglingTracker())(x, y)
+
+
+class TestRollback(TestCaseBase):
+    def test_rollback(self):
+        frame = inspect.currentframe()
+        graph = FunctionGraph(frame)
+        a = paddle.to_tensor(1.0)
+        b = paddle.to_tensor(2.0)
+        a = VariableFactory().from_value(a, graph, LocalTracker("a"))
+        b = VariableFactory().from_value(b, graph, LocalTracker("b"))
+        out = compute(a, b)
+        original_length = len(graph.sir_ctx.TOS.statements)
+        memo = graph.save_memo()
+        try_add(out, out)
+
+        assert len(graph.sir_ctx.TOS.statements) != len(
+            memo.stmt_ir.statements
+        ), "After add, new statements must be recorded in the statement IR."
+        graph.restore_memo(memo)
+
+        assert len(graph.sir_ctx.TOS.statements) == original_length
+
+
+def fn_with_side_effects_inner(x, y):
+    x[0] += 10
+    x[1] += 20
+    x[2] -= 10
+    print(y)  # print will cause breakgraph
+
+
+def fn_with_side_effects(x, y):
+    x[0] += 1
+    fn_with_side_effects_inner(x, y)
+    return x[0] + y
+
+
+class TestSideEffectRollback(TestCaseBase):
+    def test_side_effect_rollback(self):
+        self.assert_results_with_side_effects(
+            fn_with_side_effects, [1, 2, 3], paddle.to_tensor(42)
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/sot/test_stack.py b/test/sot/test_stack.py
new file mode 100644
index 0000000000000..e29610b2c837c
--- /dev/null
+++ b/test/sot/test_stack.py
@@ -0,0 +1,56 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
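+
+# These tests cover the VariableStack helper: push/pop, pop_n, and the
+# peek accessor, which reads and writes elements counted from the top of
+# the stack.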
+ +import unittest + +from paddle.jit.sot.opcode_translator.executor.variable_stack import ( + VariableStack, +) + + +class TestVariableStack(unittest.TestCase): + def test_basic(self): + stack = VariableStack([1, 2, 3]) + self.assertEqual(str(stack), "[1, 2, 3]") + self.assertEqual(len(stack), 3) + self.assertEqual(str(stack.copy()), str(stack)) + + def test_peek(self): + stack = VariableStack([1, 2, 3]) + self.assertEqual(stack.peek(), 3) + self.assertEqual(stack.top, 3) + self.assertEqual(stack.peek(1), 3) + stack.peek[1] = 4 + stack.peek[2] = 3 + self.assertEqual(stack.peek[1], 4) + self.assertEqual(stack.peek[:1], [4]) + self.assertEqual(stack.peek[:2], [3, 4]) + stack.top = 5 + self.assertEqual(stack.peek[:2], [3, 5]) + + def test_push_pop(self): + stack = VariableStack() + stack.push(1) + stack.push(2) + self.assertEqual(stack.pop(), 2) + self.assertEqual(stack.pop(), 1) + + def test_pop_n(self): + stack = VariableStack([1, 2, 3, 4]) + self.assertEqual(stack.pop_n(2), [3, 4]) + self.assertEqual(stack.pop_n(2), [1, 2]) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_str_format.py b/test/sot/test_str_format.py new file mode 100644 index 0000000000000..34bbd6e31f3dd --- /dev/null +++ b/test/sot/test_str_format.py @@ -0,0 +1,37 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + + +# copy from python library _distutils_hack/__init__.py +def find_spec(self, fullname, path, target=None): + method_name = 'spec_for_{fullname}'.format( + **{'self': self, 'fullname': fullname} + ) + method = getattr(self, method_name, lambda: None) + return method() + + +class TestExecutor(TestCaseBase): + def test_simple(self): + self.assert_results(find_spec, "self", "fullname", "path", None) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_tensor_dtype_in_guard.py b/test/sot/test_tensor_dtype_in_guard.py new file mode 100644 index 0000000000000..d5d001b7038d0 --- /dev/null +++ b/test/sot/test_tensor_dtype_in_guard.py @@ -0,0 +1,76 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
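+
+# These tests check that tensor dtypes used in guards (here under amp
+# auto_cast) do not cause spurious re-translations: translate_count should
+# stay at 1 across repeated calls, whether the dtype comes from a tensor
+# or is passed in directly.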
+
+import unittest
+
+from test_case_base import (
+    TestCaseBase,
+    test_instruction_translator_cache_context,
+)
+
+import paddle
+from paddle.jit import sot
+
+
+def foo(x, y):
+    if x.dtype == paddle.float32:
+        out = x + y
+    else:
+        out = x - y
+    return out
+
+
+@sot.skip_function
+def dtype_in_guard(x, y):
+    with paddle.amp.auto_cast(level='O2'):
+        for i in range(10):
+            z = foo(x, y)
+            x = z
+        return x
+
+
+def bar(x, y):
+    if x == paddle.float32:
+        return y + 1
+    else:
+        return y - 1
+
+
+@sot.skip_function
+def dtype_as_input(x, y):
+    with paddle.amp.auto_cast(level='O2'):
+        for i in range(10):
+            z = bar(x, y)
+            y = z
+        return y
+
+
+class TestDtypeInGuard(TestCaseBase):
+    def test_dtype_in_guard(self):
+        with test_instruction_translator_cache_context() as ctx:
+            x = paddle.to_tensor([2], dtype="float32")
+            y = paddle.to_tensor([3], dtype="float32")
+            self.assert_results(dtype_in_guard, x, y)
+            self.assertEqual(ctx.translate_count, 1)
+
+    def test_input_dtype_in_guard(self):
+        with test_instruction_translator_cache_context() as ctx:
+            x = paddle.float32
+            y = paddle.to_tensor([3], dtype="float32")
+            self.assert_results(dtype_as_input, x, y)
+            self.assertEqual(ctx.translate_count, 1)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/sot/test_tensor_slice.py b/test/sot/test_tensor_slice.py
new file mode 100644
index 0000000000000..32c52759da438
--- /dev/null
+++ b/test/sot/test_tensor_slice.py
@@ -0,0 +1,33 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from test_case_base import TestCaseBase
+
+import paddle
+
+
+def foo(x: paddle.Tensor):
+    return x[:, 0]
+
+
+class TestExecutor(TestCaseBase):
+    def test_tensor_slice(self):
+        x = paddle.randn((10, 10))
+        self.assert_results(foo, x)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/sot/test_trace_list_arg.py b/test/sot/test_trace_list_arg.py
new file mode 100644
index 0000000000000..8a82406a11f75
--- /dev/null
+++ b/test/sot/test_trace_list_arg.py
@@ -0,0 +1,63 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
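+
+# Note: the translate_count expectations below assume that guards for list
+# arguments track the metadata (e.g. shape and dtype) of the elements the
+# traced code actually reads, rather than their concrete values: swapping
+# same-shaped tensors is a cache hit, while a differently shaped tensor or a
+# changed int index forces a retranslation.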
+
+from __future__ import annotations
+
+import unittest
+
+from test_case_base import (
+    TestCaseBase,
+    test_instruction_translator_cache_context,
+)
+
+import paddle
+
+
+def foo(x: list[paddle.Tensor], y: list[paddle.Tensor]):
+    return x[0] + y[0]
+
+
+def bar(x: list[paddle.Tensor], y: int, z: int):
+    return x[y + z] + 1
+
+
+class TestTraceListArg(TestCaseBase):
+    def test_foo(self):
+        a = paddle.to_tensor(1)
+        b = paddle.to_tensor(2)
+        c = paddle.to_tensor([3, 4])
+
+        with test_instruction_translator_cache_context() as cache:
+            self.assert_results(foo, [a], [b])
+            self.assertEqual(cache.translate_count, 1)
+            self.assert_results(foo, [b], [a])  # Cache hit
+            self.assertEqual(cache.translate_count, 1)
+            self.assert_results(foo, [a], [c])  # Cache miss
+            self.assertEqual(cache.translate_count, 2)
+
+    def test_bar(self):
+        a = [paddle.to_tensor(1), paddle.to_tensor(2), paddle.to_tensor(3)]
+        b = [paddle.to_tensor([2, 3]), paddle.to_tensor(4), paddle.to_tensor(5)]
+
+        with test_instruction_translator_cache_context() as cache:
+            self.assert_results(bar, a, 1, 1)
+            self.assertEqual(cache.translate_count, 1)
+            self.assert_results(bar, a, 2, 0)  # Cache miss
+            self.assertEqual(cache.translate_count, 2)
+            self.assert_results(bar, b, 1, 1)  # Cache hit
+            self.assertEqual(cache.translate_count, 2)
+
+
+if __name__ == "__main__":
+    unittest.main()