From 19ccd7a8c944500e4f8ed953f6c6a37cc01f3b32 Mon Sep 17 00:00:00 2001 From: Reinerzhou <1768552509@qq.com> Date: Fri, 8 Mar 2024 06:57:17 +0000 Subject: [PATCH 01/17] support ascendgraph. --- dicp/dicp/dynamo_bridge/compile_fx.py | 5 +- dicp/dicp/dynamo_bridge/decompositions.py | 38 +++++ dicp/dicp/dynamo_bridge/graph.py | 6 +- dicp/dicp/vendor/AscendGraph/ascend_op.py | 8 + .../dicp/vendor/AscendGraph/codegen/ascend.py | 19 ++- .../AscendGraph/codegen/fusion_switch.cfg | 6 +- .../vendor/AscendGraph/codegen/graph_utils.h | 36 ++++- .../AscendGraph/codegen/load_and_run.py | 150 +++++++++--------- dicp/dicp/vendor/AscendGraph/compile_job.py | 6 +- dicp/dicp/vendor/AscendGraph/config.py | 16 +- dicp/dicp/vendor/AscendGraph/conversion.py | 95 ++++++----- .../vendor/AscendGraph/pattern_replacement.py | 34 +++- dicp/dicp/vendor/AscendGraph/torch_ext.py | 27 ++++ 13 files changed, 315 insertions(+), 131 deletions(-) create mode 100644 dicp/dicp/dynamo_bridge/decompositions.py create mode 100644 dicp/dicp/vendor/AscendGraph/torch_ext.py diff --git a/dicp/dicp/dynamo_bridge/compile_fx.py b/dicp/dicp/dynamo_bridge/compile_fx.py index 5eaee327c..118244517 100644 --- a/dicp/dicp/dynamo_bridge/compile_fx.py +++ b/dicp/dicp/dynamo_bridge/compile_fx.py @@ -208,7 +208,8 @@ def compile_fx_210( def fw_compiler_base(model: torch.fx.GraphModule, example_inputs, is_inference): if is_inference: # partition_fn won't be called - joint_graph_passes(model) + # joint_graph_passes(model) + pass fixed = len(example_inputs) - num_example_inputs return inner_compile( @@ -223,7 +224,7 @@ def fw_compiler_base(model: torch.fx.GraphModule, example_inputs, is_inference): inference_compiler = functools.partial(fw_compiler_base, is_inference=True) def partition_fn(graph, joint_inputs, **kwargs): - joint_graph_passes(graph) + # joint_graph_passes(graph) return min_cut_rematerialization_partition( graph, joint_inputs, **kwargs, compiler="inductor" ) diff --git a/dicp/dicp/dynamo_bridge/decompositions.py b/dicp/dicp/dynamo_bridge/decompositions.py new file mode 100644 index 000000000..f821abbcc --- /dev/null +++ b/dicp/dicp/dynamo_bridge/decompositions.py @@ -0,0 +1,38 @@ +from collections import defaultdict +from typing import Callable, Dict, Sequence, Union + +import torch +from torch._decomp import register_decomposition +from torch._ops import OpOverload, OpOverloadPacket + +dicp_decomposition_table = {} +aten = torch.ops.aten + + +def register_decomposition_for_dicp(fn): + return register_decomposition(fn, registry=dicp_decomposition_table) + + +@register_decomposition_for_dicp(aten.count_nonzero.default) +def count_nonzero_default(x, dim=None): + cond = x != 0 + dim = [] if dim is None else dim + return aten.sum.dim_IntList(cond, dim=dim, keepdim=False, dtype=torch.int64) + + +def get_decompositions( + aten_ops: Sequence[Union[OpOverload, OpOverloadPacket]], + target_decomposition_table: Dict[OpOverload, Callable] = None, +) -> Dict[OpOverload, Callable]: + registry = dicp_decomposition_table + packets_to_overloads = defaultdict(list) + for opo in registry: + packets_to_overloads[opo.overloadpacket].append(opo) + decompositions = target_decomposition_table if target_decomposition_table else {} + for op in aten_ops: + if isinstance(op, OpOverloadPacket) and op in packets_to_overloads: + for op_overload in packets_to_overloads[op]: + decompositions[op_overload] = registry[op_overload] + elif isinstance(op, OpOverload) and op in registry: + decompositions[op] = registry[op] + return decompositions diff --git a/dicp/dicp/dynamo_bridge/graph.py b/dicp/dicp/dynamo_bridge/graph.py index 2bfdc0b2b..852657a83 100644 --- a/dicp/dicp/dynamo_bridge/graph.py +++ b/dicp/dicp/dynamo_bridge/graph.py @@ -46,7 +46,11 @@ def make_tensor_meta(x) -> Optional[TensorMetadata]: for n in self.gm.graph.nodes: fake_value = None if n.op == 'call_function': - fake_value = (n.target(*n.args, **n.kwargs)) + try: + fake_value = (n.target(*n.args, **n.kwargs)) + except Exception as e: + import pdb;pdb.set_trace() + pass elif n.op == 'get_attr': target_atoms = n.target.split('.') attr_itr = self.gm diff --git a/dicp/dicp/vendor/AscendGraph/ascend_op.py b/dicp/dicp/vendor/AscendGraph/ascend_op.py index 9d129e638..c8af38e48 100644 --- a/dicp/dicp/vendor/AscendGraph/ascend_op.py +++ b/dicp/dicp/vendor/AscendGraph/ascend_op.py @@ -290,6 +290,14 @@ def infer_result(self, x, dims, keepdim): return reduce_op_infer(x, dims, keepdim) +class ReduceSum(Operator): + def __init__(self): + super().__init__("ReduceSum") + + def infer_result(self, x, dims, keepdim): + return reduce_op_infer(x, dims, keepdim) + + class Unsqueeze(Operator): def __init__(self): super().__init__("Unsqueeze") diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index 9b9fc24f4..db82369d6 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -63,6 +63,9 @@ def __init__(self, graph, aten_graph=None, folder=None, graph_key=None): self.folder = folder self.graph_key = graph_key + # aten_graph.print_readable() + # graph.print_readable() + self.sym_to_inputs = {} self.sym_in_args = {} @@ -77,7 +80,7 @@ def placeholder(self, name, target, args, kwargs): self.input_args.append(self.cur_node) fake_tensor = self.cur_node.meta['val'] - format = "NCHW" + format = "ND" index = -1 if isinstance(fake_tensor, torch.SymInt): @@ -551,6 +554,7 @@ def set_input(self, name, value): "name": name, "value": value, }) + def set_output_desc(self, name, shape, format, data_type): self.outputs.append({ @@ -956,6 +960,14 @@ def SoftmaxV2(name, x, dim): op.set_attr_list_int("axes", dim) return op.to_node() + @staticmethod + def ReduceSum(name, x, axes, keep_dims): + op = OP(name, "ReduceSum") + op.set_input("x", x) + op.set_input("axes", axes) + op.set_attr_bool("keep_dims", keep_dims) + return op.to_node() + @staticmethod def ReduceSumD(name, x, axes, keep_dims): op = OP(name, "ReduceSumD") @@ -1047,7 +1059,10 @@ def CastToCpu(name, x, ascend_dtype, device=None): def Const(name, x, dtype, dims=None, format="ND"): if not isinstance(x, list): x = [x] - assert len(x) > 0 + if len(x) <= 0: + import pdb;pdb.set_trace() + pass + # assert len(x) > 0 ascend_dtype = get_ascend_dtype(dtype) cpp_dtype = get_cpp_dtype(dtype) const_op = OP(name, "Const") diff --git a/dicp/dicp/vendor/AscendGraph/codegen/fusion_switch.cfg b/dicp/dicp/vendor/AscendGraph/codegen/fusion_switch.cfg index 71834659c..4a699c252 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/fusion_switch.cfg +++ b/dicp/dicp/vendor/AscendGraph/codegen/fusion_switch.cfg @@ -1,10 +1,12 @@ { "Switch":{ "GraphFusion":{ - "ALL":"on" + "IncreFlashAttentionQuantDeployPass": "on", + "RefreshInt64ToInt32FusionPass": "on", + "ALL":"off" }, "UBFusion":{ - "ALL":"on" + "ALL":"off" } } } diff --git a/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h b/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h index 25f3a1629..4f45b3b88 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h +++ b/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h @@ -32,7 +32,7 @@ using json = nlohmann::json; using namespace ge; static std::unordered_set op_with_dynamic_inputs_outputs = { - "ConcatD", "IdentityN", "Pack", "SplitD"}; + "ConcatD", "IdentityN", "Pack", "SplitD", "IncreFlashAttention"}; void check_op(std::unordered_map& op_map, const std::string& op_name) { @@ -91,7 +91,7 @@ class AclgraphBuilder { {AscendString(ge::ir_option::SOC_VERSION), AscendString(kSocVersion)}, {AscendString(ge::ir_option::FUSION_SWITCH_FILE), AscendString(_fusion_switch_file.c_str())}, - {AscendString(ge::ir_option::PRECISION_MODE), "allow_fp32_to_fp16"}, + // {AscendString(ge::ir_option::PRECISION_MODE), "allow_fp32_to_fp16"}, }; auto status = aclgrphBuildInitialize(global_options); if (status != GRAPH_SUCCESS) { @@ -188,6 +188,34 @@ void parseDynamicInput(std::unordered_map& op_map, } } +void parseIncreFlashAttentionDynamicInput(std::unordered_map& op_map, + op::IncreFlashAttention& op, const json& node) { + if (node.contains("dynamic_inputs")) { + for (const auto& i : node["dynamic_inputs"]) { + auto num = i["num"].get(); + auto name = i["name"].get(); + if (name == "key" || name == "value") { + if (name == "key") { + op.create_dynamic_input_key(num); + } else { + op.create_dynamic_input_value(num); + } + for (const auto& item : i["value"]) { + auto index = item["index"].get(); + auto value = op_map[item["value"].get()]; + if (name == "key") { + op.set_dynamic_input_key(index, value); + } else { + op.set_dynamic_input_value(index, value); + } + } + } else { + throw std::runtime_error("invalid dynamic input name"); + } + } + } +} + template void parseDynamicOutput(T& op, const json& node) { if (node.contains("dynamic_outputs")) { @@ -220,6 +248,10 @@ ge::Operator genDynamicOperator( auto op = genDynamicOp(op_name); parseDynamicInput(op_map, op, node); return op; + } else if (op_type == "IncreFlashAttention") { + auto op = genDynamicOp(op_name); + parseIncreFlashAttentionDynamicInput(op_map, op, node); + return op; } else if (op_type == "SplitD") { auto op = genDynamicOp(op_name); parseDynamicOutput(op, node); diff --git a/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py b/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py index cdf31ddf6..4c2e7f00a 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py @@ -121,7 +121,7 @@ def __init__(self): def init_work_weight_ptr(self): if self.work_ptr is None: - self.work_size = 18 * 1024 * 1024 * 1024 + self.work_size = 26 * 1024 * 1024 * 1024 self.work_ptr, ret = acl.rt.malloc(self.work_size, ACL_MEM_MALLOC_HUGE_FIRST) check_ret("acl.rt.malloc", ret) @@ -178,56 +178,62 @@ def release_resource(self): self.weight_ptr = None def load_model(self): - work_size, weight_size, ret = acl.mdl.query_size(self.model_path) - check_ret("acl.mdl.query_size", ret) - if work_size == 0: - work_size = memory_pool.work_size - elif work_size > memory_pool.work_size: - free, _, ret = acl.rt.get_mem_info(ACL_HBM_MEM) - check_ret("acl.rt.get_mem_info", ret) - # If free < work_size, means that memory is insufficient. - # Just ignore and continue, it may be work. - if free > work_size: - memory_pool.work_size = work_size - memory_pool.release_memory() - print("Adjust memory pool allocation.") - memory_pool.work_ptr, ret = acl.rt.malloc(work_size, - ACL_MEM_MALLOC_HUGE_FIRST) - check_ret("acl.rt.malloc", ret) - - self.weight_ptr, ret = acl.rt.malloc(weight_size, - ACL_MEM_MALLOC_HUGE_FIRST) - check_ret("acl.rt.malloc", ret) - config_handle = acl.mdl.create_config_handle() - ret = acl.mdl.set_config_opt(config_handle, ACL_MDL_LOAD_TYPE_SIZET, 2) - check_ret("set_config_opt", ret) - - ret = acl.mdl.set_config_opt( - config_handle, ACL_MDL_PATH_PTR, self.model_path) - check_ret("set_config_opt", ret) - - ret = acl.mdl.set_config_opt( - config_handle, ACL_MDL_WEIGHT_ADDR_PTR, self.weight_ptr) - check_ret("set_config_opt", ret) - - ret = acl.mdl.set_config_opt( - config_handle, ACL_MDL_WEIGHT_SIZET, weight_size) - check_ret("set_config_opt", ret) - - ret = acl.mdl.set_config_opt( - config_handle, ACL_MDL_WORKSPACE_ADDR_PTR, memory_pool.work_ptr) - check_ret("set_config_opt", ret) - - ret = acl.mdl.set_config_opt( - config_handle, ACL_MDL_WORKSPACE_SIZET, memory_pool.work_size) - check_ret("set_config_opt", ret) - - ret = acl.mdl.set_config_opt( - config_handle, ACL_MDL_WORKSPACE_MEM_OPTIMIZE, 1) - check_ret("set_config_opt", ret) - - self.model_id, ret = acl.mdl.load_with_config(config_handle) - check_ret("acl.mdl.load_with_config", ret) + # work_size, weight_size, ret = acl.mdl.query_size(self.model_path) + # check_ret("acl.mdl.query_size", ret) + # if work_size == 0: + # work_size = memory_pool.work_size + # elif work_size > memory_pool.work_size: + # free, _, ret = acl.rt.get_mem_info(ACL_HBM_MEM) + # check_ret("acl.rt.get_mem_info", ret) + # # If free < work_size, means that memory is insufficient. + # # Just ignore and continue, it may be work. + # if free > work_size: + # memory_pool.work_size = work_size + # memory_pool.release_memory() + # import pdb;pdb.set_trace() + # print("Adjust memory pool allocation.") + # memory_pool.work_ptr, ret = acl.rt.malloc(work_size, + # ACL_MEM_MALLOC_HUGE_FIRST) + # check_ret("acl.rt.malloc", ret) + + # self.weight_ptr, ret = acl.rt.malloc(weight_size, + # ACL_MEM_MALLOC_HUGE_FIRST) + # check_ret("acl.rt.malloc", ret) + # config_handle = acl.mdl.create_config_handle() + # ret = acl.mdl.set_config_opt(config_handle, ACL_MDL_LOAD_TYPE_SIZET, 2) + # check_ret("set_config_opt", ret) + + # ret = acl.mdl.set_config_opt( + # config_handle, ACL_MDL_PATH_PTR, self.model_path) + # check_ret("set_config_opt", ret) + + # ret = acl.mdl.set_config_opt( + # config_handle, ACL_MDL_WEIGHT_ADDR_PTR, self.weight_ptr) + # check_ret("set_config_opt", ret) + + # ret = acl.mdl.set_config_opt( + # config_handle, ACL_MDL_WEIGHT_SIZET, weight_size) + # check_ret("set_config_opt", ret) + + # ret = acl.mdl.set_config_opt( + # config_handle, ACL_MDL_WORKSPACE_ADDR_PTR, memory_pool.work_ptr) + # check_ret("set_config_opt", ret) + + # ret = acl.mdl.set_config_opt( + # config_handle, ACL_MDL_WORKSPACE_SIZET, memory_pool.work_size) + # check_ret("set_config_opt", ret) + + # ret = acl.mdl.set_config_opt( + # config_handle, ACL_MDL_WORKSPACE_MEM_OPTIMIZE, 1) + # check_ret("set_config_opt", ret) + + # self.model_id, ret = acl.mdl.load_with_config(config_handle) + # check_ret("acl.mdl.load_with_config", ret) + # print("model_id:{}".format(self.model_id)) + + + self.model_id, ret = acl.mdl.load_from_file(self.model_path) + check_ret("acl.mdl.load_from_file", ret) print("model_id:{}".format(self.model_id)) self.model_desc = acl.mdl.create_desc() @@ -336,29 +342,31 @@ def _prepare_dynamic_output(self, output_tensor, output_shape, out_stride, out_s self.output_data_buffers[i], item.data_ptr(), self.output_size[i]) check_ret("acl.update_data_buffer", ret) - @record_function('load_and_run_run') + # @record_function(f'load_and_run_run') def run(self, images, dims=None, output_shape=None, out_stride=None, out_storage_offset=None, allocated_output=None): - assert len(images) > 0 - input = [x.to(dipu_device_str) if isinstance(x, torch.Tensor) - and x.device.type != dipu_device_str else x for x in images] - allocated_output_tensor = None - if allocated_output: - allocated_output_tensor = {} - for output_index, input_index in allocated_output.items(): - allocated_output_tensor[output_index] = input[input_index] - self._prepare_input(input, dims) - output = [] - if output_shape: - self._prepare_dynamic_output( - output, output_shape, out_stride, out_storage_offset, allocated_output_tensor) - else: - self._prepare_output( - output, output_shape, out_stride, out_storage_offset, allocated_output_tensor) - self.forward() - self._destroy_databuffer() - return output + # print('### load_and_run: model_id:', self.model_id) + with record_function(f'load_and_run_run_{self.model_id}'): + assert len(images) > 0 + input = [x.to(dipu_device_str) if isinstance(x, torch.Tensor) + and x.device.type != dipu_device_str else x for x in images] + allocated_output_tensor = None + if allocated_output: + allocated_output_tensor = {} + for output_index, input_index in allocated_output.items(): + allocated_output_tensor[output_index] = input[input_index] + self._prepare_input(input, dims) + output = [] + if output_shape: + self._prepare_dynamic_output( + output, output_shape, out_stride, out_storage_offset, allocated_output_tensor) + else: + self._prepare_output( + output, output_shape, out_stride, out_storage_offset, allocated_output_tensor) + self.forward() + self._destroy_databuffer() + return output @record_function('load_and_run_forward') def forward(self): diff --git a/dicp/dicp/vendor/AscendGraph/compile_job.py b/dicp/dicp/vendor/AscendGraph/compile_job.py index 8c35fef4e..4eb1379fd 100644 --- a/dicp/dicp/vendor/AscendGraph/compile_job.py +++ b/dicp/dicp/vendor/AscendGraph/compile_job.py @@ -29,7 +29,6 @@ def __init__(self, source_code) -> None: 'local_rank' + str(self._local_rank) + code_hash(compile_file_code) ) self._output_graph_path = self._input_path[:-5] + '/graph' - print('output_path: ', self._output_graph_path) self._model_path = [f'{self._output_graph_path}.om', f'{self._output_graph_path}_linux_x86_64.om'] self._lib_path = "/tmp/dicp_ascend/graph_compile" @@ -66,6 +65,7 @@ def _compile(self): os.system("mkdir -p /tmp/dicp_ascend") start = time.time() try: + print(' '.join(self._cmd)) subprocess.check_output(self._cmd, stderr=subprocess.STDOUT) except subprocess.CalledProcessError as e: raise exc.CppCompileError(self._cmd, e.output) from e @@ -78,7 +78,9 @@ def build_graph(self, output_path, graph_path): self._compile() cmd = [self._lib_path, output_path, graph_path, self.fusion_switch_file] try: - subprocess.check_output(cmd, stderr=subprocess.STDOUT) + print(' '.join(cmd)) + out = subprocess.check_output(cmd, stderr=subprocess.STDOUT) + print(out.decode('utf-8')) except subprocess.CalledProcessError as e: raise exc.CppCompileError(cmd, e.output) from e diff --git a/dicp/dicp/vendor/AscendGraph/config.py b/dicp/dicp/vendor/AscendGraph/config.py index d4bb64fa3..528fb4cab 100644 --- a/dicp/dicp/vendor/AscendGraph/config.py +++ b/dicp/dicp/vendor/AscendGraph/config.py @@ -1,14 +1,16 @@ -import torch - -from torch._decomp import get_decompositions +import math +import torch -aten = torch.ops.aten -decomp_keys = [] - +from dicp.dynamo_bridge.decompositions import register_decomposition_for_dicp, get_decompositions def get_decomp(): - return get_decompositions(decomp_keys) + aten = torch.ops.aten + return get_decompositions( + [ + aten.count_nonzero.default, + ] + ) decomp = get_decomp() diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 11e72be2e..173dcf08b 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -22,6 +22,8 @@ from dicp.dynamo_bridge.conversion import register_conversion_impl from dicp.dynamo_bridge.op_transformer import SingleOpTransformer +# from dicp_ext_ops import lightllm + aten = torch.ops.aten prims = torch.ops.prims @@ -364,43 +366,26 @@ def empty(self, size, dtype=torch.int64, layout=torch.strided, device='cpu', mem return self.get_proxy(ascend_op.Empty, (shape_op, dtype, layout, device, memory_format)) @register_conversion(aten.empty_like.default) - def empty_like(self, x, dtype=torch.float32, layout=torch.strided, - device='cpu', pin_memory=False, memory_format=torch.preserve_format): - dtype = x.node.meta['val'].dtype - shape = list(x.node.meta['val'].shape) - shape_op = self.get_proxy( - ascend_op.Const, (shape, torch.int32, [len(shape)])) - new_memory_format=x.node.meta['tensor_meta'].memory_format if memory_format is torch.preserve_format else memory_format - return self.get_proxy(ascend_op.Empty, (shape_op, dtype, layout, device, new_memory_format)) + def empty_like(self, x, dtype=None, layout=None, + device=None, pin_memory=None, memory_format=None): + if dtype is None: + dtype = x.node.meta['val'].dtype + if layout is not None and (layout != torch.strided): + raise NotImplementedError("torch.ops.aten.empty_like.default is " + "only supported on dense tensor now.") + if memory_format is not None and memory_format != torch.contiguous_format \ + and memory_format != torch.preserve_format: + raise NotImplementedError("torch.ops.aten.empty_like.default is only supported " + "contiguous_format and preserve_format now.") + shape = self.get_proxy(ascend_op.Shape, (x,)) + return self.get_proxy(ascend_op.Empty, (shape, dtype)) @register_conversion(aten.select.int) def select(self, x, dim, index): - x_shape = list(x.node.meta['val'].shape) - y_shape = list(fx_traceback.get_current_meta()['val'].shape) - dim = int(dim) - index = int(index) - assert dim >= 0 and dim < len(x_shape) - start = index if index >= 0 else index + x_shape[dim] - end = start + 1 - offset = [0] * len(x_shape) - offset[dim] = start - size = [] - for i, v in enumerate(x_shape): - if i != dim: - size.append(v - offset[i]) - else: - size.append(end - offset[i]) - offset = self.get_shape_proxy(offset) - size = self.get_shape_proxy(size) - slice = self.get_proxy(ascend_op.Slice, (x, offset, size)) - y_shape = self.get_shape_proxy(y_shape) - Reshape_kw = { - "ori_op": "Select", - "params_passed": { - "sel_dim": dim, - }, - } - return self.get_proxy(ascend_op.Reshape, (slice, y_shape), Reshape_kw) + axis = self.get_const_proxy(dim, torch.int32) + if not isinstance(index, torch.fx.proxy.Proxy): + index = self.get_const_proxy(index, torch.int32) + return self.get_proxy(ascend_op.GatherV2, (x, index, axis)) @register_conversion(_operator.add) def inadd(self, x, y): @@ -517,6 +502,21 @@ def lt(self, x, y): y = self.get_proxy(ascend_op.BroadcastTo, (y, out_shape)) return self.get_proxy(ascend_op.Less, (x, y)) + + # y_shape = [1] + # if isinstance(y, torch.fx.proxy.Proxy): + # y_shape = list(y.node.meta['val'].shape) + # x_shape = list(x.node.meta['val'].shape) + # out = list(fx_traceback.get_current_meta()['val'].shape) + # out_shape = self.get_shape_proxy(out) + # x, y = self.binary_cmp_cast_input(x, y) + + # # if self.shape_prod(x_shape) < self.shape_prod(out): + # # x = self.get_proxy(ascend_op.BroadcastTo, (x, out_shape)) + # # if self.shape_prod(y_shape) < self.shape_prod(out): + # # y = self.get_proxy(ascend_op.BroadcastTo, (y, out_shape)) + # return self.get_proxy(ascend_op.Less, (x, y)) + @register_conversion(aten.masked_fill.Scalar) def masked_fill(self, x, mask, value): if str(value) == "-inf": @@ -868,6 +868,17 @@ def index_put_default(self, x, indices, values): # tf.tensor_scatter_nd_update param 'indices' is different from # indices in torch.ops.aten.index_put.default, we use broadcast and # stack to construct param 'indices' in tf.tensor_scatter_nd_update + x_shape = list(x.node.meta['val'].shape) + index = indices[0] + if len(indices) == 1 and index.node.meta['val'].dtype == torch.bool: + index_shape = list(index.node.meta['val'].shape) + if len(index_shape) == len(x_shape): + return self.masked_fill(x, index, values) + reshape_shape = index_shape + [1] * (len(x_shape) - len(index_shape)) + reshape_op = self.get_const_proxy(reshape_shape, torch.int32) + index = self.get_proxy(ascend_op.Reshape, (index, reshape_op)) + return self.masked_fill(x, index, values) + stacked_indices, indices_broadcast_shape, stacked_indices_last_dim = \ self.compute_stacked_indices(indices, x.node.meta['val'].shape) values_broadcast_shape = indices_broadcast_shape + x_shape[stacked_indices_last_dim:] # batch_shape + inner_shape @@ -1304,14 +1315,15 @@ def sum(self, a): return self.sumdim(a) @register_conversion(torch.ops.aten.sum.dim_IntList) - def sumdim(self, x, dims=[], keepdim=False, dtype=None): + def sumdim(self, x, dim=[], keepdim=False, dtype=None): x_dtype = x.node.meta['val'].dtype - if not isinstance(dims, list): - dims = [dims] - if dtype is None or x_dtype == dtype: - return self.get_proxy(ascend_op.ReduceSumD, (x, dims, keepdim)) - sum = self.get_proxy(ascend_op.ReduceSumD, (x, dims, keepdim)) - return self.get_proxy(ascend_op.Cast, (sum, get_ascend_dtype(dtype))) + x_shape = x.node.meta['val'].shape + if len(dim) == 0: + dim = list(range(len(x_shape))) + axes = self.get_const_proxy(dim, torch.int32) + if dtype and x_dtype != dtype: + x = self.get_proxy(ascend_op.Cast, (x, get_ascend_dtype(dtype))) + return self.get_proxy(ascend_op.ReduceSum, (x, axes, keepdim)) @register_conversion(torch.ops.aten.amax) def amax(self, x, dims, keepdim=False): @@ -1514,3 +1526,4 @@ def SliceScatter(self, operand, src, dim=0, start=None, end=None, step=1): @register_conversion(torch.ops.aten.scalar_tensor.default) def scalar_tensor(self, x, dtype=None, layout=None, device=None, pin_memory=None): return self.get_const_proxy(x, dtype) + diff --git a/dicp/dicp/vendor/AscendGraph/pattern_replacement.py b/dicp/dicp/vendor/AscendGraph/pattern_replacement.py index 6e737f085..d568e475d 100644 --- a/dicp/dicp/vendor/AscendGraph/pattern_replacement.py +++ b/dicp/dicp/vendor/AscendGraph/pattern_replacement.py @@ -54,8 +54,18 @@ def replacement(self, repeat, dim): Permute = torch.fx.wrap(ascend_op.Permute.get_singleton()) MatMul = torch.fx.wrap(ascend_op.MatMul.get_singleton()) +Pow = torch.fx.wrap(ascend_op.Pow.get_singleton()) +ReduceMeanD = torch.fx.wrap(ascend_op.ReduceMeanD.get_singleton()) +Adds = torch.fx.wrap(ascend_op.Adds.get_singleton()) +Rsqrt = torch.fx.wrap(ascend_op.Rsqrt.get_singleton()) +ZerosLike = torch.fx.wrap(ascend_op.ZerosLike.get_singleton()) +Less = torch.fx.wrap(ascend_op.Less.get_singleton()) +Select = torch.fx.wrap(ascend_op.Select.get_singleton()) +Mul = torch.fx.wrap(ascend_op.Mul.get_singleton()) +Div = torch.fx.wrap(ascend_op.Div.get_singleton()) +RmsNorm = torch.fx.wrap(ascend_op.RmsNorm.get_singleton()) -@register_ascend_pattern +# @register_ascend_pattern class FuseBmmTransposeRhsPattern(BackendPatternBase): @staticmethod def pattern(x1, x2, dtype): @@ -93,6 +103,28 @@ def replacement(x1, x2, c1, c2): muls = Muls(reshape, 0.3535533905932738) return BatchMatMul(x1, muls, adj_x1=False, adj_x2=True, keep_dtype=0) +@register_ascend_pattern +class FuseLightLLMRmsNorm(BackendPatternBase): + @staticmethod + def pattern(arg0_1, arg1_1): + const = Const([2], torch.float32) + pow_1 = Pow(arg0_1, const) + reduce_mean_d = ReduceMeanD(pow_1, [-1], True, False) + adds = Adds(reduce_mean_d, 0.001) + rsqrt = Rsqrt(adds) + zeros_like = ZerosLike(adds) + div = Div(zeros_like, zeros_like) + less = Less(adds, zeros_like) + select = Select(less, div, rsqrt) + mul = Mul(arg0_1, select) + mul_1 = Mul(mul, arg1_1) + return mul_1 + + @staticmethod + def replacement(arg0_1, arg1_1): + rms_norm = RmsNorm(arg0_1, arg1_1, 0.001) + return Identity(rms_norm, 0) + # @pandaoxin negotiate with @tangzhiyi # another submit would implement diff --git a/dicp/dicp/vendor/AscendGraph/torch_ext.py b/dicp/dicp/vendor/AscendGraph/torch_ext.py new file mode 100644 index 000000000..bc38f8577 --- /dev/null +++ b/dicp/dicp/vendor/AscendGraph/torch_ext.py @@ -0,0 +1,27 @@ +import torch +import torch._dynamo as dynamo + +from torch import Tensor + +from dicp.dynamo_bridge.decompositions import register_decomposition_for_dicp, get_decompositions + +# for lightllm rotary_emb +@torch._custom_op.impl.custom_op('ascend::lightllm_rotary_emb') +def lightllm_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + ... + +@lightllm_rotary_emb.impl_abstract() +def lightllm_rotary_emb_abstract(x, cos, sin): + return torch.empty_like(x) + +@lightllm_rotary_emb.impl(['cpu', 'cuda']) +def lightllm_rotary_emb_impl(x, cos, sin): + seq_len, h, dim = x.shape + x0 = x[:, :, 0: dim // 2] + x1 = x[:, :, dim // 2: dim] + cos = cos.view((seq_len, 1, dim // 2)) + sin = sin.view((seq_len, 1, dim // 2)) + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + return torch.cat((o0, o1), dim=-1) + From 020afa4471a0529fbfd2a3716c447091d0cab5db Mon Sep 17 00:00:00 2001 From: Reinerzhou <1768552509@qq.com> Date: Fri, 8 Mar 2024 06:59:05 +0000 Subject: [PATCH 02/17] add prompt_attention op. --- dicp/dicp/vendor/AscendGraph/ascend_op.py | 24 ++++++++++ .../dicp/vendor/AscendGraph/codegen/ascend.py | 45 ++++++++++++++++++- dicp/dicp/vendor/AscendGraph/conversion.py | 30 +++++++++++++ 3 files changed, 98 insertions(+), 1 deletion(-) diff --git a/dicp/dicp/vendor/AscendGraph/ascend_op.py b/dicp/dicp/vendor/AscendGraph/ascend_op.py index c8af38e48..4876db5ea 100644 --- a/dicp/dicp/vendor/AscendGraph/ascend_op.py +++ b/dicp/dicp/vendor/AscendGraph/ascend_op.py @@ -1040,6 +1040,30 @@ def __init__(self): self.torch_op = aten.repeat_interleave.self_int +class RotaryMul(Operator): + def __init__(self): + super().__init__("RotaryMul") + + def infer_result(self, x, cos, sin): + return torch.empty_like(x) + + +class RmsNorm(Operator): + def __init__(self): + super().__init__("RmsNorm") + + def infer_result(self, x, weight, eps): + return torch.empty_like(x) + + +class PromptFlashAttention(Operator): + def __init__(self): + super().__init__("PromptFlashAttention") + + def infer_result(self, q, k, v, num_head, seqlen): + return torch.empty_like(q) + + class TensorScatterUpdate(Operator): def __init__(self): super().__init__("TensorScatterUpdate") diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index db82369d6..db15a6e9f 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -527,6 +527,7 @@ def __init__(self, op_name: str, op_type: str): self.op_name = op_name self.op_type = op_type self.inputs = [] + self.optional_inputs = [] self.outputs = [] self.attrs = [] self.dynamic_inputs = [] @@ -539,6 +540,8 @@ def to_node(self): } if len(self.inputs) > 0: node["inputs"] = self.inputs + if len(self.optional_inputs) > 0: + node["optional_inputs"] = self.optional_inputs if len(self.outputs) > 0: node["outputs"] = self.outputs if len(self.attrs) > 0: @@ -554,7 +557,12 @@ def set_input(self, name, value): "name": name, "value": value, }) - + + def set_optional_input(self, name, value): + self.optional_inputs.append({ + "name": name, + "value": value, + }) def set_output_desc(self, name, shape, format, data_type): self.outputs.append({ @@ -1610,3 +1618,38 @@ def TensorScatterUpdate(name, x, indices, updates): op.set_input("indices", indices) op.set_input("updates", updates) return op.to_node() + + @staticmethod + def RotaryMul(name, x, cos, sin): + op = OP(name, "RotaryMul") + op.set_input("x", x) + op.set_input("r1", cos) + op.set_input("r2", sin) + return op.to_node() + + @staticmethod + def RmsNorm(name, x, weight, eps): + op = OP(name, "RmsNorm") + op.set_input("x", x) + op.set_input("gamma", weight) + op.set_attr_float("epsilon", float(eps)) + return op.to_node() + + + @staticmethod + def PromptFlashAttention(name, q, k, v, head_num, seqlen): + op = OP(name, "PromptFlashAttention") + op.set_input("query", q) + op.set_input("key", k) + op.set_input("value", v) + op.set_optional_input("actual_seq_lengths", seqlen) + # op.set_optional_input("actual_seq_lengths_kv", seqlen) + op.set_attr_int("num_heads", head_num) + # op.set_attr_float("scale_value", float(1 / 0.0078125))】 + # op.set_attr_float("scale_value", float(1 / 2)) + op.set_attr_float("scale_value", float(1 / 11.313708498984761)) + # op.set_attr_float("scale_value", float(1 / 5.656854249492381)) + op.set_attr_str("input_layout", "BSH") + # op.set_attr_int("num_key_value_heads", head_num) + + return op.to_node() diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 173dcf08b..87d89d5ce 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -1527,3 +1527,33 @@ def SliceScatter(self, operand, src, dim=0, start=None, end=None, step=1): def scalar_tensor(self, x, dtype=None, layout=None, device=None, pin_memory=None): return self.get_const_proxy(x, dtype) + @register_conversion(torch.ops.lightllm.rotary_emb.default) + def lightllm_rotary_emb(self, x, cos, sin): + x_shape = list(x.node.meta['val'].shape) + assert len(x_shape) == 3 + + seq_len = x_shape[0] + dim = x_shape[2] + + cos_sin_shape = self.get_const_proxy([seq_len, 1, dim // 2], torch.int32) + cos = self.get_proxy(ascend_op.Reshape, (cos, cos_sin_shape)) + sin = self.get_proxy(ascend_op.Reshape, (sin, cos_sin_shape)) + + x = self.get_proxy(ascend_op.Unsqueeze, (x, [0])) + cos = self.get_proxy(ascend_op.Tile, (cos, [1, 1, 1, 2])) + sin = self.get_proxy(ascend_op.Tile, (sin, [1, 1, 1, 2])) + + out = self.get_proxy(ascend_op.RotaryMul, (x, cos, sin)) + return self.get_proxy(ascend_op.Squeeze, (out, [0])) + + @register_conversion(torch.ops.lightllm.rms_norm.default) + def lightllm_rms_norm(self, x, weight, eps): + out = self.get_proxy(ascend_op.RmsNorm, (x, weight, eps)) + return self.get_proxy(ascend_op.Identity, (out, 0)) + + + @register_conversion(torch.ops.lightllm.prompt_attention_inference.default) + def prompt_attention_inference(self, q, k, v, num_head, seqlen): + fa = self.get_proxy(ascend_op.PromptFlashAttention, (q, k, v, num_head, seqlen)) + # fa = self.get_proxy(ascend_op.Identity, (fa, 3)) + return fa From 0917803a50c0c8429b0bbf5fe5222d7866abf5f5 Mon Sep 17 00:00:00 2001 From: Reinerzhou <1768552509@qq.com> Date: Fri, 19 Apr 2024 07:25:30 +0000 Subject: [PATCH 03/17] add op pattern match and ext_ops. --- dicp/dicp/vendor/AscendGraph/ascend_op.py | 69 ++++++++- .../dicp/vendor/AscendGraph/codegen/ascend.py | 138 +++++++++++++++--- .../vendor/AscendGraph/codegen/graph_utils.h | 28 ++-- .../AscendGraph/codegen/load_and_run.py | 80 +++++----- dicp/dicp/vendor/AscendGraph/codegen/utils.py | 2 + dicp/dicp/vendor/AscendGraph/conversion.py | 126 ++++++++++++++-- dicp/dicp/vendor/AscendGraph/ext_ops.py | 136 +++++++++++++++++ dicp/dicp/vendor/AscendGraph/opset_convert.py | 9 +- .../vendor/AscendGraph/pattern_replacement.py | 17 +++ 9 files changed, 515 insertions(+), 90 deletions(-) create mode 100644 dicp/dicp/vendor/AscendGraph/ext_ops.py diff --git a/dicp/dicp/vendor/AscendGraph/ascend_op.py b/dicp/dicp/vendor/AscendGraph/ascend_op.py index 4876db5ea..e234f80ac 100644 --- a/dicp/dicp/vendor/AscendGraph/ascend_op.py +++ b/dicp/dicp/vendor/AscendGraph/ascend_op.py @@ -970,6 +970,14 @@ def infer_result(self, x1, x2): return common_binary_op_infer(x1, x2, torch.bool) +class LogicalNot(Operator): + def __init__(self): + super().__init__("LogicalNot") + + def infer_result(self, x): + return common_binary_op_infer(x, torch.bool) + + class Tril(Operator): def __init__(self): super().__init__("Tril") @@ -1060,7 +1068,15 @@ class PromptFlashAttention(Operator): def __init__(self): super().__init__("PromptFlashAttention") - def infer_result(self, q, k, v, num_head, seqlen): + def infer_result(self, q, k, v, num_head, seqlen, mask, head_dim): + return torch.empty_like(q) + + +class IncreFlashAttention(Operator): + def __init__(self): + super().__init__("IncreFlashAttention") + + def infer_result(self, q, k, v, head_num): return torch.empty_like(q) @@ -1086,6 +1102,57 @@ def infer_result(self, x, indices, updates): return torch.empty(x_shape, dtype=x_dtype, memory_format=get_memory_format(x)) +class Gather(Operator): + def __init__(self): + super().__init__("Gather") + + def infer_result(self, x, index): + x, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x) + idx, idx_shape, idx_dim, idx_dtype = get_fake_tensor_meta_val(index) + idx_shape = list(idx_shape) + idx_shape.append(x_shape[-1]) + return torch.empty(idx_shape, dtype=x_dtype, memory_format=get_memory_format(x)) + + +class InplaceCopyWithOffset(Operator): + def __init__(self): + super().__init__("InplaceCopyWithOffset") + + def infer_result(self, x, src, dim, offset): + return src + + +class ExpandDims(Operator): + def __init__(self): + super().__init__("ExpandDims") + + def infer_result(self, x, axis): + return torch.unsqueeze(x, axis) + + +class MaskedScatter(Operator): + def __init__(self): + super().__init__("MaskedScatter") + + def infer_result(self, x, mask, updates): + return x + +class ViewCopy(Operator): + def __init__(self): + super().__init__("ViewCopy") + + def infer_result(self, dst, dst_size, dst_stride, dst_storage_offset, src, src_size, src_stride, src_storage_offset): + return x + + +class ScatterNdUpdate(Operator): + def __init__(self): + super().__init__("ScatterNdUpdate") + + def infer_result(self, x, indices, updates): + return x + + def ret_triple(a, b, c) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return a, b, c diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index db15a6e9f..1899ef75f 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -1,9 +1,11 @@ import json import os +import math import uuid import torch from typing import Any, List from torch.fx.node import Node +from torch.utils._pytree import tree_map_only from torch._inductor.utils import IndentedBuffer from dicp.dynamo_bridge.utils import symint_in_shape from dicp.vendor.AscendGraph.codegen.utils import ( @@ -12,6 +14,8 @@ get_ascend_dtype_num ) +need_profile = False + graph_id = 0 precision_check = bool(os.environ.get("DICP_ASCEND_PRECISION_CHECK", False)) @@ -72,6 +76,7 @@ def __init__(self, graph, aten_graph=None, folder=None, graph_key=None): # for modified args return self.assign_args = [] self.cpu_tensor = [] + self.assign_with_offset_args = [] super().__init__(graph) @@ -135,6 +140,8 @@ def call_function(self, name, target, args, kwargs): self.cpu_tensor.append(self.cur_node.meta['prop']['cpu_tensor']) if 'prop' in self.cur_node.meta and 'assign_args' in self.cur_node.meta['prop']: self.assign_args.append(self.cur_node.meta['prop']['assign_args']) + if 'prop' in self.cur_node.meta and 'assign_with_offset_args' in self.cur_node.meta['prop']: + self.assign_with_offset_args.append(self.cur_node.meta['prop']['assign_with_offset_args']) _, args_list = AscendOverrides.gen_args( self.args_dict[name], self.args_dict, args) @@ -197,6 +204,10 @@ def parse_outputs(self): if len(self.assign_args) > 0: self.graph_output_names.extend(list(zip(*self.assign_args))[0]) + current_index = len(self.graph_output_names) + for i in range(len(self.assign_with_offset_args)): + self.assign_with_offset_args[i]['output_index'] = current_index + current_index += 1 def gen_import_code(self): self.import_code.splice( @@ -374,7 +385,15 @@ def gen_call_func(self): output_index = self.graph_output_names.index(item[0]) allocated_output[output_index] = input_index call_body.writeline(f'allocated_output= {allocated_output}') - call_str = ['output_tensor = kernel_cpp_0(args, dims, output_shape, out_stride, out_storage_offset, allocated_output)'] + + allocated_output_with_offset = {} + for item in self.assign_with_offset_args: + # input_index = item['input_index'] + output_index = item['output_index'] + # offset = item['offset'] + allocated_output_with_offset[output_index] = item + call_body.writeline(f'allocated_with_offset_output= {self.assign_with_offset_args}') + call_str = ['output_tensor = kernel_cpp_0(args, dims, output_shape, out_stride, out_storage_offset, allocated_output, allocated_with_offset_output)'] if precision_check and self.aten_graph is not None: # import aten graph @@ -492,10 +511,14 @@ def gen_graph_json(self): self.parse_outputs() self.gen_build_options() has_dynamic_shape = False if len(self.sym_in_args) == 0 and len(self.sym_to_inputs) == 0 else True + with_copy_inplace = self.graph_output_names + for i in self.assign_with_offset_args: + with_copy_inplace.append(i['name']) graph = { "name": "graph", "input_names": self.graph_input_names, - "output_names": self.graph_output_names, + # "output_names": self.graph_output_names, + "output_names": with_copy_inplace, "has_dynamic_shape": has_dynamic_shape, "build_options": self.build_options, "data_nodes": self.data_nodes, @@ -527,8 +550,8 @@ def __init__(self, op_name: str, op_type: str): self.op_name = op_name self.op_type = op_type self.inputs = [] - self.optional_inputs = [] self.outputs = [] + self.optional_inputs = [] self.attrs = [] self.dynamic_inputs = [] self.dynamic_outputs = [] @@ -540,10 +563,10 @@ def to_node(self): } if len(self.inputs) > 0: node["inputs"] = self.inputs - if len(self.optional_inputs) > 0: - node["optional_inputs"] = self.optional_inputs if len(self.outputs) > 0: node["outputs"] = self.outputs + if len(self.optional_inputs) > 0: + node["optional_inputs"] = self.optional_inputs if len(self.attrs) > 0: node["attrs"] = self.attrs if len(self.dynamic_inputs) > 0: @@ -701,11 +724,13 @@ class AscendOverrides: def gen_args(op_var, args_dict, args): src_code = IndentedBuffer() args_str = [op_var] - for i in range(len(args)): - if isinstance(args[i], Node): - args_str.append(args_dict[args[i].name]) - else: - args_str.append(args[i]) + args_str.extend(tree_map_only(Node, lambda x: args_dict[x.name], args)) + + # for i in range(len(args)): + # if isinstance(args[i], Node): + # args_str.append(args_dict[args[i].name]) + # else: + # args_str.append(args[i]) return src_code, args_str @staticmethod @@ -1067,9 +1092,6 @@ def CastToCpu(name, x, ascend_dtype, device=None): def Const(name, x, dtype, dims=None, format="ND"): if not isinstance(x, list): x = [x] - if len(x) <= 0: - import pdb;pdb.set_trace() - pass # assert len(x) > 0 ascend_dtype = get_ascend_dtype(dtype) cpp_dtype = get_cpp_dtype(dtype) @@ -1399,7 +1421,8 @@ def Pack(name, x, axis): x_name = [] for elem in x: if elem is not None: - x_name.append(elem.name) + # x_name.append(elem.name) + x_name.append(elem) op = OP(name, "Pack") op.set_dynamic_input("x", len(x_name), x_name) @@ -1603,6 +1626,12 @@ def LogicalOr(name, x, y): op.set_input("x2", y) return op.to_node() + @staticmethod + def LogicalNot(name, x): + op = OP(name, "LogicalNot") + op.set_input("x", x) + return op.to_node() + @staticmethod def TileWithAxis(name, x, axis, tiles): op = OP(name, "TileWithAxis") @@ -1635,21 +1664,86 @@ def RmsNorm(name, x, weight, eps): op.set_attr_float("epsilon", float(eps)) return op.to_node() - @staticmethod - def PromptFlashAttention(name, q, k, v, head_num, seqlen): + def PromptFlashAttention(name, q, k, v, head_num, seqlen, mask, head_dim): + # import pdb; pdb.set_trace() op = OP(name, "PromptFlashAttention") op.set_input("query", q) op.set_input("key", k) op.set_input("value", v) - op.set_optional_input("actual_seq_lengths", seqlen) + op.set_input("atten_mask", mask) + # op.set_optional_input("atten_mask", mask) + # op.set_optional_input("padding_mask", mask) + # op.set_input("actual_seq_lengths", seqlen) # op.set_optional_input("actual_seq_lengths_kv", seqlen) op.set_attr_int("num_heads", head_num) - # op.set_attr_float("scale_value", float(1 / 0.0078125))】 - # op.set_attr_float("scale_value", float(1 / 2)) - op.set_attr_float("scale_value", float(1 / 11.313708498984761)) - # op.set_attr_float("scale_value", float(1 / 5.656854249492381)) - op.set_attr_str("input_layout", "BSH") + # op.set_attr_float("scale_value", float(1 / 11.313708498984761)) + op.set_attr_float("scale_value", float(1 / math.sqrt(head_dim))) + # op.set_attr_int("pre_tokens", 214748647) + # op.set_attr_int("next_tokens", 0) # op.set_attr_int("num_key_value_heads", head_num) + op.set_attr_str("input_layout", "BSH") + # op.set_attr_int("num_key_value_heads", 0) + return op.to_node() + + + @staticmethod + def IncreFlashAttention(name, q, k_list, v_list, kv_input_num, head_num, kv_head_num, dim, input_layout="BSH"): + op = OP(name, "IncreFlashAttention") + op.set_input("query", q) + op.set_dynamic_input("key", kv_input_num, k_list) + op.set_dynamic_input("value", kv_input_num, v_list) + op.set_attr_int("num_heads", head_num) + op.set_attr_float("scale_value", float(1 / math.sqrt(dim))) + op.set_attr_int("num_key_value_heads", kv_head_num) + op.set_attr_str("input_layout", input_layout) + return op.to_node() + + @staticmethod + def Gather(name, x, indices): + gather_op = OP(name, "Gather") + gather_op.set_input("x", x) + gather_op.set_input("indices", indices) + return gather_op.to_node() + + @staticmethod + def ExpandDims(name, x, axis): + gather_op = OP(name, "ExpandDims") + gather_op.set_input("x", x) + gather_op.set_input("axis", axis) + return gather_op.to_node() + + @staticmethod + def InplaceCopyWithOffset(name, x, src, dim, offset): + op = OP(name, "Identity") + op.set_input("x", src) + return op.to_node() + @staticmethod + def MaskedScatter(name, x, mask, updates): + op = OP(name, "MaskedScatter") + op.set_input("x", x) + op.set_input("mask", mask) + op.set_input("updates", updates) + return op.to_node() + + @staticmethod + def ViewCopy(name, dst, dst_size, dst_stride, dst_storage_offset, src, src_size, src_stride, src_storage_offset): + op = OP(name, "ViewCopy") + op.set_input("dst", dst) + op.set_input("dst_size", dst_size) + op.set_input("dst_stride", dst_stride) + op.set_input("dst_storage_offset", dst_storage_offset) + op.set_input("src", src) + op.set_input("src_size", src_size) + op.set_input("src_stride", src_stride) + op.set_input("src_storage_offset", src_storage_offset) + return op.to_node() + + @staticmethod + def ScatterNdUpdate(name, x, indices, updates): + op = OP(name, "ScatterNdUpdate") + op.set_input("var", x) + op.set_input("indices", indices) + op.set_input("updates", updates) return op.to_node() diff --git a/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h b/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h index 4f45b3b88..6dbedafb6 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h +++ b/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h @@ -91,7 +91,8 @@ class AclgraphBuilder { {AscendString(ge::ir_option::SOC_VERSION), AscendString(kSocVersion)}, {AscendString(ge::ir_option::FUSION_SWITCH_FILE), AscendString(_fusion_switch_file.c_str())}, - // {AscendString(ge::ir_option::PRECISION_MODE), "allow_fp32_to_fp16"}, + // {AscendString(ge::ir_option::PRECISION_MODE_V2), "mixed_float16"}, + // {AscendString(ge::ir_option::PRECISION_MODE_V2), "fp16"}, }; auto status = aclgrphBuildInitialize(global_options); if (status != GRAPH_SUCCESS) { @@ -191,23 +192,28 @@ void parseDynamicInput(std::unordered_map& op_map, void parseIncreFlashAttentionDynamicInput(std::unordered_map& op_map, op::IncreFlashAttention& op, const json& node) { if (node.contains("dynamic_inputs")) { + int kv_inputs_num = 0; for (const auto& i : node["dynamic_inputs"]) { auto num = i["num"].get(); auto name = i["name"].get(); - if (name == "key" || name == "value") { - if (name == "key") { - op.create_dynamic_input_key(num); - } else { - op.create_dynamic_input_value(num); + if (name == "key") { + kv_inputs_num = static_cast(num); + op.create_dynamic_input_byindex_key(num, 1); + for (const auto& item : i["value"]) { + auto index = item["index"].get(); + auto value = op_map[item["value"].get()]; + op.set_dynamic_input_key(index, value); } + } else if (name == "value") { + if (kv_inputs_num == 0 && num == kv_inputs_num) { + throw std::runtime_error("need first set dynamic key input for IncreFlashAttention Op" + "and kv_inputs_num == num !!"); + } + op.create_dynamic_input_byindex_value(num, 1 + num); for (const auto& item : i["value"]) { auto index = item["index"].get(); auto value = op_map[item["value"].get()]; - if (name == "key") { - op.set_dynamic_input_key(index, value); - } else { - op.set_dynamic_input_value(index, value); - } + op.set_dynamic_input_value(index, value); } } else { throw std::runtime_error("invalid dynamic input name"); diff --git a/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py b/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py index 4c2e7f00a..c4cd27883 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py @@ -178,8 +178,9 @@ def release_resource(self): self.weight_ptr = None def load_model(self): - # work_size, weight_size, ret = acl.mdl.query_size(self.model_path) - # check_ret("acl.mdl.query_size", ret) + work_size, weight_size, ret = acl.mdl.query_size(self.model_path) + check_ret("acl.mdl.query_size", ret) + # print('### query_size:', work_size) # if work_size == 0: # work_size = memory_pool.work_size # elif work_size > memory_pool.work_size: @@ -196,45 +197,40 @@ def load_model(self): # ACL_MEM_MALLOC_HUGE_FIRST) # check_ret("acl.rt.malloc", ret) - # self.weight_ptr, ret = acl.rt.malloc(weight_size, - # ACL_MEM_MALLOC_HUGE_FIRST) - # check_ret("acl.rt.malloc", ret) - # config_handle = acl.mdl.create_config_handle() - # ret = acl.mdl.set_config_opt(config_handle, ACL_MDL_LOAD_TYPE_SIZET, 2) - # check_ret("set_config_opt", ret) + self.weight_ptr, ret = acl.rt.malloc(weight_size, + ACL_MEM_MALLOC_HUGE_FIRST) + check_ret("acl.rt.malloc", ret) + config_handle = acl.mdl.create_config_handle() + ret = acl.mdl.set_config_opt(config_handle, ACL_MDL_LOAD_TYPE_SIZET, 2) + check_ret("set_config_opt", ret) - # ret = acl.mdl.set_config_opt( - # config_handle, ACL_MDL_PATH_PTR, self.model_path) - # check_ret("set_config_opt", ret) + ret = acl.mdl.set_config_opt( + config_handle, ACL_MDL_PATH_PTR, self.model_path) + check_ret("set_config_opt", ret) - # ret = acl.mdl.set_config_opt( - # config_handle, ACL_MDL_WEIGHT_ADDR_PTR, self.weight_ptr) - # check_ret("set_config_opt", ret) + ret = acl.mdl.set_config_opt( + config_handle, ACL_MDL_WEIGHT_ADDR_PTR, self.weight_ptr) + check_ret("set_config_opt", ret) - # ret = acl.mdl.set_config_opt( - # config_handle, ACL_MDL_WEIGHT_SIZET, weight_size) - # check_ret("set_config_opt", ret) + ret = acl.mdl.set_config_opt( + config_handle, ACL_MDL_WEIGHT_SIZET, weight_size) + check_ret("set_config_opt", ret) - # ret = acl.mdl.set_config_opt( - # config_handle, ACL_MDL_WORKSPACE_ADDR_PTR, memory_pool.work_ptr) - # check_ret("set_config_opt", ret) + ret = acl.mdl.set_config_opt( + config_handle, ACL_MDL_WORKSPACE_ADDR_PTR, memory_pool.work_ptr) + check_ret("set_config_opt", ret) - # ret = acl.mdl.set_config_opt( - # config_handle, ACL_MDL_WORKSPACE_SIZET, memory_pool.work_size) - # check_ret("set_config_opt", ret) + ret = acl.mdl.set_config_opt( + config_handle, ACL_MDL_WORKSPACE_SIZET, memory_pool.work_size) + check_ret("set_config_opt", ret) - # ret = acl.mdl.set_config_opt( - # config_handle, ACL_MDL_WORKSPACE_MEM_OPTIMIZE, 1) - # check_ret("set_config_opt", ret) + ret = acl.mdl.set_config_opt( + config_handle, ACL_MDL_WORKSPACE_MEM_OPTIMIZE, 1) + check_ret("set_config_opt", ret) - # self.model_id, ret = acl.mdl.load_with_config(config_handle) - # check_ret("acl.mdl.load_with_config", ret) + self.model_id, ret = acl.mdl.load_with_config(config_handle) + check_ret("acl.mdl.load_with_config", ret) # print("model_id:{}".format(self.model_id)) - - - self.model_id, ret = acl.mdl.load_from_file(self.model_path) - check_ret("acl.mdl.load_from_file", ret) - print("model_id:{}".format(self.model_id)) self.model_desc = acl.mdl.create_desc() ret = acl.mdl.get_desc(self.model_desc, self.model_id) @@ -303,7 +299,7 @@ def _prepare_input(self, images, dims): assert (dataset == self.input_dataset) @record_function('load_and_run_prepare_output') - def _prepare_output(self, output_tensor, output_shape, out_stride, out_storage_offset, allocated_output): + def _prepare_output(self, output_tensor, output_shape, out_stride, out_storage_offset, allocated_output, allocated_output_with_offset_tensor): for i in range(self.num_outputs): if allocated_output and i in allocated_output.keys(): item = allocated_output[i] @@ -345,8 +341,7 @@ def _prepare_dynamic_output(self, output_tensor, output_shape, out_stride, out_s # @record_function(f'load_and_run_run') def run(self, images, dims=None, output_shape=None, out_stride=None, out_storage_offset=None, - allocated_output=None): - # print('### load_and_run: model_id:', self.model_id) + allocated_output=None, allocated_with_offset_output=None): with record_function(f'load_and_run_run_{self.model_id}'): assert len(images) > 0 input = [x.to(dipu_device_str) if isinstance(x, torch.Tensor) @@ -356,6 +351,13 @@ def run(self, images, dims=None, output_shape=None, allocated_output_tensor = {} for output_index, input_index in allocated_output.items(): allocated_output_tensor[output_index] = input[input_index] + + allocated_output_with_offset_tensor = None + if allocated_with_offset_output: + allocated_output_with_offset_tensor = {} + for item in allocated_with_offset_output: + allocated_output_with_offset_tensor[item['output_index']] = {'input_tensor': input[item['input_index']], 'offset': item['offset']} + self._prepare_input(input, dims) output = [] if output_shape: @@ -363,7 +365,7 @@ def run(self, images, dims=None, output_shape=None, output, output_shape, out_stride, out_storage_offset, allocated_output_tensor) else: self._prepare_output( - output, output_shape, out_stride, out_storage_offset, allocated_output_tensor) + output, output_shape, out_stride, out_storage_offset, allocated_output_tensor, allocated_output_with_offset_tensor) self.forward() self._destroy_databuffer() return output @@ -388,8 +390,8 @@ def __init__(self, device_id, model_path) -> None: self.exe = AscendExecutor(device_id, model_path) def run(self, images, dims=None, output_shape=None, - out_stride=None, out_storage_offset=None, allocated_output=None): - return self.exe.run(images, dims, output_shape, out_stride, out_storage_offset, allocated_output) + out_stride=None, out_storage_offset=None, allocated_output=None, allocated_with_offset_output=None): + return self.exe.run(images, dims, output_shape, out_stride, out_storage_offset, allocated_output, allocated_with_offset_output) def cleanup(self): if hasattr(self, 'exe'): diff --git a/dicp/dicp/vendor/AscendGraph/codegen/utils.py b/dicp/dicp/vendor/AscendGraph/codegen/utils.py index 88b13357e..a871fd37b 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/utils.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/utils.py @@ -162,6 +162,8 @@ def get_cpp_dtype(dtype: torch.dtype) -> str: return "INT32" elif dtype == torch.float16: return "FLOAT16" + elif dtype == torch.bool: + return "BOOL" else: raise RuntimeError(f"unknow torch data type ({dtype}) in get_cpp_dtype!") diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 87d89d5ce..fd820779f 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -21,9 +21,7 @@ ) from dicp.dynamo_bridge.conversion import register_conversion_impl from dicp.dynamo_bridge.op_transformer import SingleOpTransformer - -# from dicp_ext_ops import lightllm - +from dicp.vendor.AscendGraph import ext_ops aten = torch.ops.aten prims = torch.ops.prims @@ -53,7 +51,7 @@ def try_to_get_dtype(x): def is_dicp_cpp_support_dtype(dtype): - if dtype in [torch.float32, torch.float, torch.float16, torch.int32, torch.int64]: + if dtype in [torch.float32, torch.float, torch.float16, torch.int32, torch.int64, torch.bool]: return True return False @@ -136,14 +134,14 @@ def generate_sym_int(elem): # concat all ops return self.get_proxy(ascend_op.ConcatD, (x_names, 0)) - def get_shape_proxy(self, shape): + def get_shape_proxy(self, shape, dtype= torch.int32): if isinstance(shape, torch.fx.proxy.Proxy) or isinstance(shape, FakeTensor): return shape elif isinstance(shape, list) and symint_in_shape(shape): return self.process_dynamic_shape(shape) else: return self.get_proxy( - ascend_op.Const, (shape, torch.int32, [len(shape)])) + ascend_op.Const, (shape, dtype, [len(shape)])) def get_const_proxy(self, param, dtype, format=None, target_shape=None): if not isinstance(param, torch.fx.proxy.Proxy) and not isinstance(param, FakeTensor): @@ -154,6 +152,8 @@ def get_const_proxy(self, param, dtype, format=None, target_shape=None): shape = target_shape param = param if isinstance(param, list) else [param] if is_dicp_cpp_support_dtype(dtype): + if isinstance(param, list) and len(param) == 1: + param = param[0] param = self.get_proxy( ascend_op.Const, (param, dtype, shape, format)) else: @@ -242,12 +242,12 @@ def mul(self, x, y): def add(self, x, y, alpha: Optional[Number] = 1): out_dtype = fx_traceback.get_current_meta()['val'].dtype if not isinstance(y, torch.fx.proxy.Proxy): - y = y * alpha + y = y * alpha if alpha != 1 else y if out_dtype in [torch.float, torch.float16]: return self.get_proxy(ascend_op.Adds, (x, float(y)), {}) y = self.get_const_proxy(y, out_dtype) else: - y = self.mul(y, alpha) + y = self.mul(y, alpha) if alpha != 1 else y x, y = self.promote_dtype(x, y, target_dtype=out_dtype) return self.get_proxy(ascend_op.AddV2, (x, y), {}) @@ -333,6 +333,7 @@ def slice(self, x, dim=0, start=None, end=None, step=1): # TODO(tangzhiyi): miss step parameter x_shape = list(x.node.meta['val'].shape) y_shape = list(fx_traceback.get_current_meta()['val'].shape) + # y_shape = fx_traceback.get_current_meta()['val'].shape dim = int(dim) start = int(start) if start is not None else 0 start = start if start >= 0 else x_shape[dim] + start @@ -340,6 +341,8 @@ def slice(self, x, dim=0, start=None, end=None, step=1): assert start is None or start >= 0 and start < x_shape[dim] offset = [0] * len(x_shape) offset[dim] = start + # import pdb; pdb.set_trace() + offset = self.get_shape_proxy(offset) size = self.get_shape_proxy(y_shape) return self.get_proxy(ascend_op.Slice, (x, offset, size)) @@ -360,7 +363,7 @@ def NewEmptyStrided(self, x, size, stride, dtype=torch.float32, layout=torch.str return self.empty_like(x) @register_conversion(aten.empty) - def empty(self, size, dtype=torch.int64, layout=torch.strided, device='cpu', memory_format=torch.contiguous_format): + def empty(self, size, dtype=torch.int64, layout=torch.strided, device='cpu', memory_format=torch.contiguous_format, pin_memory=False): shape_op = self.get_proxy( ascend_op.Const, (size, torch.int32, [len(size)])) return self.get_proxy(ascend_op.Empty, (shape_op, dtype, layout, device, memory_format)) @@ -870,7 +873,7 @@ def index_put_default(self, x, indices, values): # stack to construct param 'indices' in tf.tensor_scatter_nd_update x_shape = list(x.node.meta['val'].shape) index = indices[0] - if len(indices) == 1 and index.node.meta['val'].dtype == torch.bool: + if len(indices) == 1 and ('val' in index.node.meta.keys()) and index.node.meta['val'].dtype == torch.bool: index_shape = list(index.node.meta['val'].shape) if len(index_shape) == len(x_shape): return self.masked_fill(x, index, values) @@ -1487,6 +1490,10 @@ def Ge(self, x, y): def LogicalOr(self, x, y): return self.get_proxy(ascend_op.LogicalOr, (x, y)) + @register_conversion(torch.ops.aten.logical_not.default) + def LogicalNot(self, x): + return self.get_proxy(ascend_op.LogicalNot, (x,)) + @register_conversion(torch.ops.aten.slice_scatter.default) def SliceScatter(self, operand, src, dim=0, start=None, end=None, step=1): # modified from torchair @@ -1531,10 +1538,10 @@ def scalar_tensor(self, x, dtype=None, layout=None, device=None, pin_memory=None def lightllm_rotary_emb(self, x, cos, sin): x_shape = list(x.node.meta['val'].shape) assert len(x_shape) == 3 - + seq_len = x_shape[0] dim = x_shape[2] - + cos_sin_shape = self.get_const_proxy([seq_len, 1, dim // 2], torch.int32) cos = self.get_proxy(ascend_op.Reshape, (cos, cos_sin_shape)) sin = self.get_proxy(ascend_op.Reshape, (sin, cos_sin_shape)) @@ -1551,9 +1558,96 @@ def lightllm_rms_norm(self, x, weight, eps): out = self.get_proxy(ascend_op.RmsNorm, (x, weight, eps)) return self.get_proxy(ascend_op.Identity, (out, 0)) - @register_conversion(torch.ops.lightllm.prompt_attention_inference.default) - def prompt_attention_inference(self, q, k, v, num_head, seqlen): - fa = self.get_proxy(ascend_op.PromptFlashAttention, (q, k, v, num_head, seqlen)) - # fa = self.get_proxy(ascend_op.Identity, (fa, 3)) + def prompt_attention_inference(self, q, k, v, seqlen, num_head, head_dim): + q_shape = list(q.node.meta['val'].shape) + seq_len = q_shape[1] + shape = [seq_len, seq_len] + shape = self.get_proxy(ascend_op.Const, (shape, torch.int32, [len(shape)])) + mask = self.get_proxy(ascend_op.Empty, (shape, torch.bool)) + mask = self.get_proxy(ascend_op.OnesLike, (mask,)) + mask = self.get_proxy(ascend_op.Tril, (mask,)) + mask = self.get_proxy(ascend_op.LogicalNot, (mask,)) + fa = self.get_proxy(ascend_op.PromptFlashAttention, (q, k, v, num_head, seqlen, mask, head_dim)) return fa + + def incre_flash_attention(self, q, k, v, head_num, kv_head_num, dim): + k_list = [] + v_list = [] + if not isinstance(k, list): + k_list.append(k) + else: + k_list = k + if not isinstance(v, list): + v_list.append(v) + else: + v_list = v + assert len(k_list) == len(v_list) + kv_input_num = len(k_list) + out = self.get_proxy(ascend_op.IncreFlashAttention, (q, k_list, v_list, kv_input_num, kv_head_num, head_num, dim, "BSH")) + return out + + @register_conversion(aten.select_scatter.default) + def select_scatter(self, x, src, dim, index): + if not isinstance(index, torch.fx.proxy.Proxy): + index = self.get_const_proxy(index, torch.int32) + input_sizes = self.get_proxy(ascend_op.Shape, (x,)) + index = self.get_proxy(ascend_op.BroadcastTo, (index, input_sizes)) + dim_op = self.get_const_proxy(dim, torch.int32) + src = self.get_proxy(ascend_op.ExpandDims, (src, dim_op)) + src = self.get_proxy(ascend_op.BroadcastTo, (src, input_sizes)) + + return self.get_proxy(ascend_op.ScatterElements, (x, index, src, dim)) + + @register_conversion(torch.ops.lightllm.copy_with_offset.default) + def copy_with_offset2(self, x, src, start_dim, end_dim): + dims = [x for x in range(start_dim, end_dim)] + dims = self.get_const_proxy(dims, torch.int32, target_shape=[len(dims), 1]) + return self.get_proxy(ascend_op.ScatterNdUpdate, (x, dims, src)) + + @register_conversion(torch.ops.lightllm.flash_attention_inference.default) + def flash_attention_inference2(self, q, all_k, all_v, current_len, max_len): + q_shape = list(q.node.meta['val'].shape) + batch, head, dim = q_shape[0], q_shape[1], q_shape[2] + + res = [] + compute_batch = 1 + select_axis = self.get_const_proxy(0, torch.int32) + + for i in range(batch): + current_len = current_len[i] + select_index = self.get_const_proxy(i, torch.int32) + xq = self.get_proxy(ascend_op.GatherV2, (q, select_index, select_axis)) + + kv_start_index = self.get_const_proxy([i * max_len, 0, 0], torch.int32) + kv_end_index = self.get_const_proxy([i * max_len + current_len, head, dim], torch.int32) + kv_seq_len = current_len + + kv_gather_shape = self.get_shape_proxy([compute_batch, kv_seq_len, head, dim]) + kv_compute_shape = self.get_shape_proxy([compute_batch, kv_seq_len, head * dim]) + + # fetch k + k = self.get_proxy(ascend_op.Slice, (all_k, kv_start_index, kv_end_index)) + k = self.get_proxy(ascend_op.Reshape, (k, kv_gather_shape)) + k = self.get_proxy(ascend_op.Reshape, (k, kv_compute_shape)) + + # fetch v + v = self.get_proxy(ascend_op.Slice, (all_v, kv_start_index, kv_end_index)) + v = self.get_proxy(ascend_op.Reshape, (v, kv_gather_shape)) + v = self.get_proxy(ascend_op.Reshape, (v, kv_compute_shape)) + + # k,v shape: batch, kv_seq_len, head, dim + q_shape = self.get_shape_proxy([compute_batch, 1, head, dim]) + q_compute_shape = self.get_shape_proxy([compute_batch, 1, head * dim]) + xq = self.get_proxy(ascend_op.Reshape, (xq, q_shape)) + xq = self.get_proxy(ascend_op.Reshape, (xq, q_compute_shape)) + + out = self.incre_flash_attention(xq, k, v, head, head, dim) # q shape is BSH + out_shape = self.get_shape_proxy([compute_batch, 1, head, dim]) + out_shape2 = self.get_shape_proxy([compute_batch, head, dim]) + out = self.get_proxy(ascend_op.Reshape, (out, out_shape)) + out = self.get_proxy(ascend_op.Reshape, (out, out_shape2)) + res.append(out) + + res = self.get_proxy(ascend_op.ConcatD, (res, 0)) + return res diff --git a/dicp/dicp/vendor/AscendGraph/ext_ops.py b/dicp/dicp/vendor/AscendGraph/ext_ops.py new file mode 100644 index 000000000..9a16a1539 --- /dev/null +++ b/dicp/dicp/vendor/AscendGraph/ext_ops.py @@ -0,0 +1,136 @@ +import math + +import torch +import torch_dipu +import torch._dynamo as dynamo +import torch.nn.functional as F + +from torch import Tensor +from typing import Sequence + +torch._dynamo.config.suppress_errors = False + +# rotary_emb +@torch._custom_op.impl.custom_op('lightllm::rotary_emb') +def rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + ... + +@rotary_emb.impl_abstract() +def lightllm_rotary_emb_abstract(x, cos, sin): + return torch.empty_like(x) + +@rotary_emb.impl(['cpu', 'cuda']) +def lightllm_rotary_emb_impl(x, cos, sin): + seq_len, h, dim = x.shape + cos = cos.view((seq_len, 1, dim // 2)) + sin = sin.view((seq_len, 1, dim // 2)) + x0 = x[:, :, 0: dim // 2] + x1 = x[:, :, dim // 2: dim] + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + return torch.cat((o0, o1), dim=-1) + +# rms_norm +@torch._custom_op.impl.custom_op('lightllm::rms_norm') +def rms_norm(x: Tensor, weight: Tensor, eps: float) -> Tensor: + ... + +@rms_norm.impl_abstract() +def lightllm_rms_norm_abstract(x, weight, eps): + return torch.empty_like(x) + +@rms_norm.impl(['cpu', 'cuda']) +def lightllm_rms_norm_impl(x, weight, eps): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight + + +@torch._custom_op.impl.custom_op('lightllm::prompt_attention_inference') +def prompt_attention_inference(q: Tensor, k: Tensor, v: Tensor, seqlen: Tensor, num_head: int, head_dim: int) -> Tensor: + ... + +@prompt_attention_inference.impl_abstract() +def lightllm_prompt_attention_inference_abstract(q: Tensor, k: Tensor, v: Tensor, seqlen: Tensor, num_head: int, head_dim: int): + return torch.empty_like(q) + +@prompt_attention_inference.impl(['cpu', 'cuda']) +def lightllm_prompt_attention_inference_impl(q, k, v, seqlen, num_head, head_dim): + # prompt attention just support bs=1 for now. + assert q.shape[0] == 1 + bs = q.shape[0] + seqlen = seqlen.item() + + xq = q.view(bs, seqlen, num_head, head_dim) + xk = k.view(bs, seqlen, num_head, head_dim) + xv = v.view(bs, seqlen, num_head, head_dim) + + mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0) + mask = mask.masked_fill(mask == 0., -999999999999.0) + mask = mask.masked_fill(mask == 1., 0.0) + mask = mask.repeat(bs, num_head, 1, 1) + + keys = xk + values = xv + xq = xq.transpose(1, 2).float() + keys = xk.transpose(1, 2).float() + values = xv.transpose(1, 2).float() + + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim) + scores = F.softmax(scores + mask.float(), dim=-1).type_as(xq) + output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim) + + return output + +@torch._custom_op.impl.custom_op('lightllm::flash_attention_inference') +def flash_attention_inference(q: Tensor, all_k: Tensor, all_v: Tensor, currnet_lens: Sequence[int], max_len: int) -> Tensor: + ... + +@flash_attention_inference.impl_abstract() +def lightllm_flash_attention_inference_abstract(q: Tensor, all_k: Tensor, all_v: Tensor, currnet_lens: Sequence[int], max_len: int): + return torch.empty_like(q) + +@flash_attention_inference.impl(['cpu', 'cuda']) +def lightllm_flash_attention_inference_impl(q, all_k, all_v, current_lens, max_len): + # q: batch, head, dim + batch = q.shape[0] + head = q.shape[1] + dim = q.shape[2] + res = [] + compute_batch = 1 + for i in range(batch): + current_len = current_lens[i] + kv_seq_len = current_len + + k = all_k[:current_len].reshape(compute_batch, kv_seq_len, head, dim) + v = all_v[:current_len].reshape(compute_batch, kv_seq_len, head, dim) + + xq = q[i].view(compute_batch, 1, head, dim).transpose(1, 2).transpose(0, 1) # shape: head, batch, 1, dim + bmm_xq = xq.reshape(head * compute_batch, 1, dim) + bmm_xk = k.transpose(1, 2).transpose(0, 1).transpose(2, 3).reshape(head * compute_batch, dim, kv_seq_len) + + + # q @ k + out = torch.bmm(bmm_xq, bmm_xk) / math.sqrt(dim) + out = out.reshape(head, compute_batch, 1, -1).reshape(head, compute_batch, -1) + + # softmax + out = out.softmax(-1).reshape(head, compute_batch, 1, kv_seq_len).transpose(0, 1) # shape: batch head 1 seq_len + xv = v.transpose(1, 2) # shape: batch head, seq_len, dim + out = torch.bmm(out.reshape(compute_batch * head, 1, kv_seq_len), xv.reshape(compute_batch * head, kv_seq_len, dim)) + + out = out.reshape(compute_batch, head, 1, dim).view(compute_batch, head, dim) + res.append(out) + res = torch.cat(res) + return res + +@torch._custom_op.impl.custom_op('lightllm::copy_with_offset') +def copy_with_offset(x: Tensor, src: Tensor, start_dim: int, end_dim: int) -> Tensor: + ... + +@copy_with_offset.impl_abstract() +def lightllm_copy_with_offset_abstract(x: Tensor, src: Tensor, start_dim: int, end_dim: int) -> Tensor: + return x + +@copy_with_offset.impl(['cpu', 'cuda']) +def lightllm_copy_with_offset_impl(x, src, start_dim, end_dim) -> Tensor: + x[start_dim:end_dim] = src + return x diff --git a/dicp/dicp/vendor/AscendGraph/opset_convert.py b/dicp/dicp/vendor/AscendGraph/opset_convert.py index e9a44216b..1843b7ac9 100644 --- a/dicp/dicp/vendor/AscendGraph/opset_convert.py +++ b/dicp/dicp/vendor/AscendGraph/opset_convert.py @@ -1,7 +1,7 @@ import torch import torch_dipu from dicp.dynamo_bridge.compile_fx import is_torch_210 -from dicp.vendor.AscendGraph.ascend_op import MatMul, CastToCpu, IdentityInp +from dicp.vendor.AscendGraph.ascend_op import MatMul, CastToCpu, IdentityInp, InplaceCopyWithOffset from dicp.vendor.AscendGraph.conversion import AtenToAscendTransformer from ...dynamo_bridge.graph import GraphTransformer @@ -32,6 +32,7 @@ class OutputMarkPass: def __init__(self): self.assign_args = [] self.cpu_tensor = [] + self.assign_with_offset_args = {} def transform(self, gm: torch.fx.graph_module): # dynamic shape feature @@ -45,6 +46,10 @@ def transform(self, gm: torch.fx.graph_module): continue if type(n.target) == CastToCpu: self.cpu_tensor.append(n.name) + elif type(n.target) == InplaceCopyWithOffset: + input_index = input_names.index(str(n.args[0])) + offset = int(n.args[-1]) + self.assign_with_offset_args[n.name] = {'name': n.name, 'input_index': input_index, 'offset': offset} elif type(n.target) == IdentityInp: if len(n.args) == 2 and n.args[1] is not None and str(n.args[1]) in input_names: self.assign_args.append((n.name, input_names.index(str(n.args[1])))) @@ -59,6 +64,8 @@ def transform(self, gm: torch.fx.graph_module): if len(self.assign_args) > 0 and n.name in list(zip(*self.assign_args))[0]: idx = list(zip(*self.assign_args))[0].index(n.name) prop.update({'assign_args' : (self.assign_args[idx][0], self.assign_args[idx][1])}) + if n.name in self.assign_with_offset_args.keys(): + prop['assign_with_offset_args'] = self.assign_with_offset_args[n.name] n.meta['prop'] = prop return gm diff --git a/dicp/dicp/vendor/AscendGraph/pattern_replacement.py b/dicp/dicp/vendor/AscendGraph/pattern_replacement.py index d568e475d..aebf79a12 100644 --- a/dicp/dicp/vendor/AscendGraph/pattern_replacement.py +++ b/dicp/dicp/vendor/AscendGraph/pattern_replacement.py @@ -27,6 +27,7 @@ def replacement(input, dims): varVal = torch.ops.aten.var(input, dims, correction=1, keepdim=True) return ascend_op.ret_tuple(varVal, meanVal) + @register_aten_pattern class FusedRepeatInterleaveSelfInt(BackendPatternBase): @staticmethod @@ -43,6 +44,22 @@ def pattern(self, repeat, dim, input_shape, empty_device, view_1_shape, @staticmethod def replacement(self, repeat, dim): return torch.ops.aten.repeat_interleave.self_int(self, repeat, dim) + + +@register_aten_pattern +class ReplaceAtenSliceScatter(BackendPatternBase): + @staticmethod + def pattern(arg0, arg1, start_index, end_index): + slice = torch.ops.aten.slice.Tensor(arg0, 0, start_index, end_index) + copy = torch.ops.aten.copy.default(slice, arg1) + slice_scatter = torch.ops.aten.slice_scatter.default(arg0, copy, 0, start_index, end_index) + copy_ = torch.ops.aten.copy_.default(slice_scatter, arg0) + return slice_scatter + + @staticmethod + def replacement(arg0, arg1, start_index, end_index): + slice_scatter = torch.ops.lightllm.copy_with_offset.default(arg0, arg1, start_index, end_index) + return slice_scatter Muls = torch.fx.wrap(ascend_op.Muls.get_singleton()) Shape = torch.fx.wrap(ascend_op.Shape.get_singleton()) From 5ee8611696552ad91fc956c4f018e8aa3bfdd9a6 Mon Sep 17 00:00:00 2001 From: Reinerzhou <1768552509@qq.com> Date: Sun, 21 Apr 2024 12:17:24 +0000 Subject: [PATCH 04/17] clean code. --- dicp/dicp/dynamo_bridge/compile_fx.py | 5 ++- dicp/dicp/dynamo_bridge/graph.py | 6 +--- .../dicp/vendor/AscendGraph/codegen/ascend.py | 25 +------------ .../AscendGraph/codegen/load_and_run.py | 35 +++++++++---------- dicp/dicp/vendor/AscendGraph/compile_job.py | 6 ++-- dicp/dicp/vendor/AscendGraph/conversion.py | 19 +--------- dicp/dicp/vendor/AscendGraph/ext_ops.py | 5 ++- dicp/dicp/vendor/AscendGraph/torch_ext.py | 27 -------------- 8 files changed, 26 insertions(+), 102 deletions(-) delete mode 100644 dicp/dicp/vendor/AscendGraph/torch_ext.py diff --git a/dicp/dicp/dynamo_bridge/compile_fx.py b/dicp/dicp/dynamo_bridge/compile_fx.py index 118244517..5eaee327c 100644 --- a/dicp/dicp/dynamo_bridge/compile_fx.py +++ b/dicp/dicp/dynamo_bridge/compile_fx.py @@ -208,8 +208,7 @@ def compile_fx_210( def fw_compiler_base(model: torch.fx.GraphModule, example_inputs, is_inference): if is_inference: # partition_fn won't be called - # joint_graph_passes(model) - pass + joint_graph_passes(model) fixed = len(example_inputs) - num_example_inputs return inner_compile( @@ -224,7 +223,7 @@ def fw_compiler_base(model: torch.fx.GraphModule, example_inputs, is_inference): inference_compiler = functools.partial(fw_compiler_base, is_inference=True) def partition_fn(graph, joint_inputs, **kwargs): - # joint_graph_passes(graph) + joint_graph_passes(graph) return min_cut_rematerialization_partition( graph, joint_inputs, **kwargs, compiler="inductor" ) diff --git a/dicp/dicp/dynamo_bridge/graph.py b/dicp/dicp/dynamo_bridge/graph.py index 852657a83..2bfdc0b2b 100644 --- a/dicp/dicp/dynamo_bridge/graph.py +++ b/dicp/dicp/dynamo_bridge/graph.py @@ -46,11 +46,7 @@ def make_tensor_meta(x) -> Optional[TensorMetadata]: for n in self.gm.graph.nodes: fake_value = None if n.op == 'call_function': - try: - fake_value = (n.target(*n.args, **n.kwargs)) - except Exception as e: - import pdb;pdb.set_trace() - pass + fake_value = (n.target(*n.args, **n.kwargs)) elif n.op == 'get_attr': target_atoms = n.target.split('.') attr_itr = self.gm diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index 1899ef75f..e78327bf7 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -14,7 +14,6 @@ get_ascend_dtype_num ) -need_profile = False graph_id = 0 @@ -67,8 +66,6 @@ def __init__(self, graph, aten_graph=None, folder=None, graph_key=None): self.folder = folder self.graph_key = graph_key - # aten_graph.print_readable() - # graph.print_readable() self.sym_to_inputs = {} self.sym_in_args = {} @@ -551,7 +548,6 @@ def __init__(self, op_name: str, op_type: str): self.op_type = op_type self.inputs = [] self.outputs = [] - self.optional_inputs = [] self.attrs = [] self.dynamic_inputs = [] self.dynamic_outputs = [] @@ -565,8 +561,6 @@ def to_node(self): node["inputs"] = self.inputs if len(self.outputs) > 0: node["outputs"] = self.outputs - if len(self.optional_inputs) > 0: - node["optional_inputs"] = self.optional_inputs if len(self.attrs) > 0: node["attrs"] = self.attrs if len(self.dynamic_inputs) > 0: @@ -581,12 +575,6 @@ def set_input(self, name, value): "value": value, }) - def set_optional_input(self, name, value): - self.optional_inputs.append({ - "name": name, - "value": value, - }) - def set_output_desc(self, name, shape, format, data_type): self.outputs.append({ "output_name": name, @@ -1092,7 +1080,7 @@ def CastToCpu(name, x, ascend_dtype, device=None): def Const(name, x, dtype, dims=None, format="ND"): if not isinstance(x, list): x = [x] - # assert len(x) > 0 + assert len(x) > 0 ascend_dtype = get_ascend_dtype(dtype) cpp_dtype = get_cpp_dtype(dtype) const_op = OP(name, "Const") @@ -1421,7 +1409,6 @@ def Pack(name, x, axis): x_name = [] for elem in x: if elem is not None: - # x_name.append(elem.name) x_name.append(elem) op = OP(name, "Pack") @@ -1666,24 +1653,14 @@ def RmsNorm(name, x, weight, eps): @staticmethod def PromptFlashAttention(name, q, k, v, head_num, seqlen, mask, head_dim): - # import pdb; pdb.set_trace() op = OP(name, "PromptFlashAttention") op.set_input("query", q) op.set_input("key", k) op.set_input("value", v) op.set_input("atten_mask", mask) - # op.set_optional_input("atten_mask", mask) - # op.set_optional_input("padding_mask", mask) - # op.set_input("actual_seq_lengths", seqlen) - # op.set_optional_input("actual_seq_lengths_kv", seqlen) op.set_attr_int("num_heads", head_num) - # op.set_attr_float("scale_value", float(1 / 11.313708498984761)) op.set_attr_float("scale_value", float(1 / math.sqrt(head_dim))) - # op.set_attr_int("pre_tokens", 214748647) - # op.set_attr_int("next_tokens", 0) - # op.set_attr_int("num_key_value_heads", head_num) op.set_attr_str("input_layout", "BSH") - # op.set_attr_int("num_key_value_heads", 0) return op.to_node() diff --git a/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py b/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py index c4cd27883..270b7d9a8 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py @@ -180,22 +180,21 @@ def release_resource(self): def load_model(self): work_size, weight_size, ret = acl.mdl.query_size(self.model_path) check_ret("acl.mdl.query_size", ret) - # print('### query_size:', work_size) - # if work_size == 0: - # work_size = memory_pool.work_size - # elif work_size > memory_pool.work_size: - # free, _, ret = acl.rt.get_mem_info(ACL_HBM_MEM) - # check_ret("acl.rt.get_mem_info", ret) - # # If free < work_size, means that memory is insufficient. - # # Just ignore and continue, it may be work. - # if free > work_size: - # memory_pool.work_size = work_size - # memory_pool.release_memory() - # import pdb;pdb.set_trace() - # print("Adjust memory pool allocation.") - # memory_pool.work_ptr, ret = acl.rt.malloc(work_size, - # ACL_MEM_MALLOC_HUGE_FIRST) - # check_ret("acl.rt.malloc", ret) + if work_size == 0: + work_size = memory_pool.work_size + elif work_size > memory_pool.work_size: + free, _, ret = acl.rt.get_mem_info(ACL_HBM_MEM) + check_ret("acl.rt.get_mem_info", ret) + # If free < work_size, means that memory is insufficient. + # Just ignore and continue, it may be work. + if free > work_size: + memory_pool.work_size = work_size + memory_pool.release_memory() + import pdb;pdb.set_trace() + print("Adjust memory pool allocation.") + memory_pool.work_ptr, ret = acl.rt.malloc(work_size, + ACL_MEM_MALLOC_HUGE_FIRST) + check_ret("acl.rt.malloc", ret) self.weight_ptr, ret = acl.rt.malloc(weight_size, ACL_MEM_MALLOC_HUGE_FIRST) @@ -230,7 +229,7 @@ def load_model(self): self.model_id, ret = acl.mdl.load_with_config(config_handle) check_ret("acl.mdl.load_with_config", ret) - # print("model_id:{}".format(self.model_id)) + print("model_id:{}".format(self.model_id)) self.model_desc = acl.mdl.create_desc() ret = acl.mdl.get_desc(self.model_desc, self.model_id) @@ -338,7 +337,7 @@ def _prepare_dynamic_output(self, output_tensor, output_shape, out_stride, out_s self.output_data_buffers[i], item.data_ptr(), self.output_size[i]) check_ret("acl.update_data_buffer", ret) - # @record_function(f'load_and_run_run') + @record_function(f'load_and_run_run') def run(self, images, dims=None, output_shape=None, out_stride=None, out_storage_offset=None, allocated_output=None, allocated_with_offset_output=None): diff --git a/dicp/dicp/vendor/AscendGraph/compile_job.py b/dicp/dicp/vendor/AscendGraph/compile_job.py index 4eb1379fd..8c35fef4e 100644 --- a/dicp/dicp/vendor/AscendGraph/compile_job.py +++ b/dicp/dicp/vendor/AscendGraph/compile_job.py @@ -29,6 +29,7 @@ def __init__(self, source_code) -> None: 'local_rank' + str(self._local_rank) + code_hash(compile_file_code) ) self._output_graph_path = self._input_path[:-5] + '/graph' + print('output_path: ', self._output_graph_path) self._model_path = [f'{self._output_graph_path}.om', f'{self._output_graph_path}_linux_x86_64.om'] self._lib_path = "/tmp/dicp_ascend/graph_compile" @@ -65,7 +66,6 @@ def _compile(self): os.system("mkdir -p /tmp/dicp_ascend") start = time.time() try: - print(' '.join(self._cmd)) subprocess.check_output(self._cmd, stderr=subprocess.STDOUT) except subprocess.CalledProcessError as e: raise exc.CppCompileError(self._cmd, e.output) from e @@ -78,9 +78,7 @@ def build_graph(self, output_path, graph_path): self._compile() cmd = [self._lib_path, output_path, graph_path, self.fusion_switch_file] try: - print(' '.join(cmd)) - out = subprocess.check_output(cmd, stderr=subprocess.STDOUT) - print(out.decode('utf-8')) + subprocess.check_output(cmd, stderr=subprocess.STDOUT) except subprocess.CalledProcessError as e: raise exc.CppCompileError(cmd, e.output) from e diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index fd820779f..aa93127dc 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -152,8 +152,6 @@ def get_const_proxy(self, param, dtype, format=None, target_shape=None): shape = target_shape param = param if isinstance(param, list) else [param] if is_dicp_cpp_support_dtype(dtype): - if isinstance(param, list) and len(param) == 1: - param = param[0] param = self.get_proxy( ascend_op.Const, (param, dtype, shape, format)) else: @@ -505,21 +503,6 @@ def lt(self, x, y): y = self.get_proxy(ascend_op.BroadcastTo, (y, out_shape)) return self.get_proxy(ascend_op.Less, (x, y)) - - # y_shape = [1] - # if isinstance(y, torch.fx.proxy.Proxy): - # y_shape = list(y.node.meta['val'].shape) - # x_shape = list(x.node.meta['val'].shape) - # out = list(fx_traceback.get_current_meta()['val'].shape) - # out_shape = self.get_shape_proxy(out) - # x, y = self.binary_cmp_cast_input(x, y) - - # # if self.shape_prod(x_shape) < self.shape_prod(out): - # # x = self.get_proxy(ascend_op.BroadcastTo, (x, out_shape)) - # # if self.shape_prod(y_shape) < self.shape_prod(out): - # # y = self.get_proxy(ascend_op.BroadcastTo, (y, out_shape)) - # return self.get_proxy(ascend_op.Less, (x, y)) - @register_conversion(aten.masked_fill.Scalar) def masked_fill(self, x, mask, value): if str(value) == "-inf": @@ -1606,7 +1589,7 @@ def copy_with_offset2(self, x, src, start_dim, end_dim): return self.get_proxy(ascend_op.ScatterNdUpdate, (x, dims, src)) @register_conversion(torch.ops.lightllm.flash_attention_inference.default) - def flash_attention_inference2(self, q, all_k, all_v, current_len, max_len): + def flash_attention_inference(self, q, all_k, all_v, current_len, max_len): q_shape = list(q.node.meta['val'].shape) batch, head, dim = q_shape[0], q_shape[1], q_shape[2] diff --git a/dicp/dicp/vendor/AscendGraph/ext_ops.py b/dicp/dicp/vendor/AscendGraph/ext_ops.py index 9a16a1539..b4e592f63 100644 --- a/dicp/dicp/vendor/AscendGraph/ext_ops.py +++ b/dicp/dicp/vendor/AscendGraph/ext_ops.py @@ -54,11 +54,10 @@ def lightllm_prompt_attention_inference_abstract(q: Tensor, k: Tensor, v: Tensor @prompt_attention_inference.impl(['cpu', 'cuda']) def lightllm_prompt_attention_inference_impl(q, k, v, seqlen, num_head, head_dim): - # prompt attention just support bs=1 for now. - assert q.shape[0] == 1 + assert q.shape[0] == 1, "prompt attention just support bs=1 for now." bs = q.shape[0] seqlen = seqlen.item() - + xq = q.view(bs, seqlen, num_head, head_dim) xk = k.view(bs, seqlen, num_head, head_dim) xv = v.view(bs, seqlen, num_head, head_dim) diff --git a/dicp/dicp/vendor/AscendGraph/torch_ext.py b/dicp/dicp/vendor/AscendGraph/torch_ext.py deleted file mode 100644 index bc38f8577..000000000 --- a/dicp/dicp/vendor/AscendGraph/torch_ext.py +++ /dev/null @@ -1,27 +0,0 @@ -import torch -import torch._dynamo as dynamo - -from torch import Tensor - -from dicp.dynamo_bridge.decompositions import register_decomposition_for_dicp, get_decompositions - -# for lightllm rotary_emb -@torch._custom_op.impl.custom_op('ascend::lightllm_rotary_emb') -def lightllm_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: - ... - -@lightllm_rotary_emb.impl_abstract() -def lightllm_rotary_emb_abstract(x, cos, sin): - return torch.empty_like(x) - -@lightllm_rotary_emb.impl(['cpu', 'cuda']) -def lightllm_rotary_emb_impl(x, cos, sin): - seq_len, h, dim = x.shape - x0 = x[:, :, 0: dim // 2] - x1 = x[:, :, dim // 2: dim] - cos = cos.view((seq_len, 1, dim // 2)) - sin = sin.view((seq_len, 1, dim // 2)) - o0 = x0 * cos - x1 * sin - o1 = x0 * sin + x1 * cos - return torch.cat((o0, o1), dim=-1) - From 2459fceb2c28c06074027654c1ba4a788d5e4d5a Mon Sep 17 00:00:00 2001 From: Reinerzhou <1768552509@qq.com> Date: Mon, 22 Apr 2024 03:07:35 +0000 Subject: [PATCH 05/17] format code for ci check. --- dicp/dicp/vendor/AscendGraph/ascend_op.py | 16 ++--- .../dicp/vendor/AscendGraph/codegen/ascend.py | 41 ++++--------- .../vendor/AscendGraph/codegen/graph_utils.h | 10 ++-- .../AscendGraph/codegen/load_and_run.py | 9 ++- dicp/dicp/vendor/AscendGraph/codegen/utils.py | 3 +- dicp/dicp/vendor/AscendGraph/compile_job.py | 2 +- dicp/dicp/vendor/AscendGraph/config.py | 5 +- dicp/dicp/vendor/AscendGraph/conversion.py | 60 +++++++------------ dicp/dicp/vendor/AscendGraph/ext_ops.py | 26 +++++--- dicp/dicp/vendor/AscendGraph/opset_convert.py | 13 ++-- .../vendor/AscendGraph/pattern_replacement.py | 11 ++-- 11 files changed, 89 insertions(+), 107 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/ascend_op.py b/dicp/dicp/vendor/AscendGraph/ascend_op.py index e234f80ac..cadc8798d 100644 --- a/dicp/dicp/vendor/AscendGraph/ascend_op.py +++ b/dicp/dicp/vendor/AscendGraph/ascend_op.py @@ -14,6 +14,7 @@ aten = torch.ops.aten + def negative_in_shape(shape): for elem in shape: if elem < 0: @@ -43,12 +44,12 @@ def __init__(self): def infer_result(self, x, shape): x, x_shape, _, x_dtype = get_fake_tensor_meta_val(x) - if isinstance(shape, torch._subclasses.fake_tensor.FakeTensor): # case1: shape is a fakeTensor, like conversion for 'scatter' and 'where' + if isinstance(shape, torch._subclasses.fake_tensor.FakeTensor): # case1: shape is a fakeTensor, like conversion for 'scatter' and 'where' shape, shape_shape, _, _ = get_fake_tensor_meta_val(shape) shape = shape_shape - elif isinstance(shape, Tuple): # case2: shape is tuple from 'Const' , like conversion for 'lt' - shape, _, _, _ =get_op_const_arg_kwarg(shape) - else: # other cases, unsupported yet + elif isinstance(shape, Tuple): # case2: shape is tuple from 'Const' , like conversion for 'lt' + shape, _, _, _ = get_op_const_arg_kwarg(shape) + else: # other cases, unsupported yet assert False, self.__class__.__name__ + "unsupported 'shape' input type!" out_shape = get_broadcast_res_two_shape(x_shape, shape) @@ -97,7 +98,7 @@ def __init__(self): class MatMul(Operator): def __init__(self): super().__init__("MatMul") - + def infer_result(self, x1, x2, adj_x1=False, adj_x2=False): attr = acl.op.create_attr() check_ret("acl.op.set_attr_bool", acl.op.set_attr_bool(attr, "transpose_x1", adj_x1)) @@ -636,7 +637,7 @@ def infer_result(self, x, index, orig_index): # assume not none index, and replace prefix x_shape dims len_idx_shape = len(orig_index) - assert(len_idx_shape > 0) + assert (len_idx_shape > 0) bcast_index_shape = list(orig_index[0].shape) x_shape = bcast_index_shape + list(x_shape[len_idx_shape:]) return torch.empty(x_shape, dtype=x_dtype, memory_format=get_memory_format(x)) @@ -1039,7 +1040,7 @@ def infer_result( output_batch_var = torch.empty( [channel_size], dtype=torch.float32, memory_format=torch.contiguous_format ) - return [output_y,output_mean,output_var,output_batch_mean,output_batch_var] + return [output_y, output_mean, output_var, output_batch_mean, output_batch_var] class TileWithAxis(Operator): @@ -1137,6 +1138,7 @@ def __init__(self): def infer_result(self, x, mask, updates): return x + class ViewCopy(Operator): def __init__(self): super().__init__("ViewCopy") diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index e78327bf7..80fcac26d 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -1,7 +1,6 @@ import json import os import math -import uuid import torch from typing import Any, List from torch.fx.node import Node @@ -66,7 +65,6 @@ def __init__(self, graph, aten_graph=None, folder=None, graph_key=None): self.folder = folder self.graph_key = graph_key - self.sym_to_inputs = {} self.sym_in_args = {} @@ -96,7 +94,7 @@ def placeholder(self, name, target, args, kwargs): for idx, dim in enumerate(fake_tensor.shape): if isinstance(dim, torch.SymInt): st = dim.node.str() - if not st in self.sym_in_args: + if st not in self.sym_in_args: self.sym_in_args[st] = (name, idx) # deal with dynamic shape -1 @@ -289,8 +287,8 @@ def gen_call_func(self): # dynamic shape feature if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0: - args = ['_' if not arg in shape_symint and not arg in self.sym_to_inputs.values() else arg for arg in self.args] - call_body.writeline(f"({','.join(args)}) = args") + args = ['_' if arg not in shape_symint and arg not in self.sym_to_inputs.values() else arg for arg in self.args] + call_body.writeline("({','.join(args)}) = args") # generate input dims if len(self.dynamic_inputs) > 0: @@ -307,14 +305,14 @@ def gen_call_func(self): dims = dims[:-1] + '}' call_body.writeline(dims) else: - call_body.writeline(f'''dims = None''') + call_body.writeline('''dims = None''') # generate output shapes # dynamic shape feature extra_stride_str = '' extra_storage_offset_str = '' if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0: - shape_str = f'''output_shape = [''' + shape_str = '''output_shape = [''' for elem in self.output_args: if hasattr(elem, 'meta'): elem = elem.meta['val'] @@ -340,11 +338,11 @@ def gen_call_func(self): stride = [self.process_sym_name(str(dim)) for dim in stride] extra_stride_str += '[' + ','.join(map(str, stride)) + '],' extra_storage_offset_str += str(self.input_args[elem[1]].meta['val'].storage_offset()) + ',' - shape_str = shape_str[:-1] + f''']''' + shape_str = shape_str[:-1] + ''']''' call_body.writeline(shape_str) else: call_body.writeline('''output_shape = None''') - + # add stride & storage_offset info out_strides = [] out_storage_offsets = [] @@ -355,7 +353,7 @@ def gen_call_func(self): out_strides.append('[1]') out_storage_offsets.append('0') continue - if elem.dim()==0: # temporary solution for sum.default(a) whose result is a scalar(no dim no stride) + if elem.dim() == 0: # temporary solution for sum.default(a) whose result is a scalar(no dim no stride) out_strides.append('[1]') out_storage_offsets.append('0') continue @@ -382,7 +380,7 @@ def gen_call_func(self): output_index = self.graph_output_names.index(item[0]) allocated_output[output_index] = input_index call_body.writeline(f'allocated_output= {allocated_output}') - + allocated_output_with_offset = {} for item in self.assign_with_offset_args: # input_index = item['input_index'] @@ -394,11 +392,11 @@ def gen_call_func(self): if precision_check and self.aten_graph is not None: # import aten graph - call_str.append(f"import sys") + call_str.append("import sys") call_str.append(f"if '{self.folder}' not in sys.path:") call_str.append(f" sys.path.insert(0, '{self.folder}')") call_str.append(f"from {self.graph_key[:4]} import {self.graph_key} as graph_module") - call_str.append(f"aten_call = graph_module()") + call_str.append("aten_call = graph_module()") call_str.append('aten_args = list(map(lambda x: x.to("cpu"), args))') call_str.append('for idx in modified:') @@ -1555,7 +1553,7 @@ def DropOutDoMaskV3(name, x, mask, keep_prob): op.set_input("mask", mask) op.set_input("keep_prob", keep_prob) return op.to_node() - + @staticmethod def GatherElements(name, x, index, dim): op = OP(name, "GatherElements") @@ -1563,20 +1561,6 @@ def GatherElements(name, x, index, dim): op.set_input("index", index) op.set_attr_int("dim", dim) return op.to_node() - - @staticmethod - def AdaptiveAvgPool2D(name, x, output_size): - op = OP(name, "AdaptiveAvgPool2d") - op.set_input("x", x) - op.set_attr_list_int("output_size", output_size) - return op.to_node() - - @staticmethod - def AdaptiveAvgPool2DGrad(name, input_grad, orig_input_shape): - op = OP(name, "AdaptiveAvgPool2dGrad") - op.set_input("input_grad", input_grad) - op.set_attr_list_int("orig_input_shape", orig_input_shape) - return op.to_node() @staticmethod def AdaptiveAvgPool2D(name, x, output_size): @@ -1663,7 +1647,6 @@ def PromptFlashAttention(name, q, k, v, head_num, seqlen, mask, head_dim): op.set_attr_str("input_layout", "BSH") return op.to_node() - @staticmethod def IncreFlashAttention(name, q, k_list, v_list, kv_input_num, head_num, kv_head_num, dim, input_layout="BSH"): op = OP(name, "IncreFlashAttention") diff --git a/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h b/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h index 6dbedafb6..1b9e1bc8e 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h +++ b/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h @@ -189,8 +189,9 @@ void parseDynamicInput(std::unordered_map& op_map, } } -void parseIncreFlashAttentionDynamicInput(std::unordered_map& op_map, - op::IncreFlashAttention& op, const json& node) { +void parseIncreFlashAttentionDynamicInput( + std::unordered_map& op_map, + op::IncreFlashAttention& op, const json& node) { if (node.contains("dynamic_inputs")) { int kv_inputs_num = 0; for (const auto& i : node["dynamic_inputs"]) { @@ -206,8 +207,9 @@ void parseIncreFlashAttentionDynamicInput(std::unordered_map work_size: memory_pool.work_size = work_size memory_pool.release_memory() - import pdb;pdb.set_trace() print("Adjust memory pool allocation.") memory_pool.work_ptr, ret = acl.rt.malloc(work_size, - ACL_MEM_MALLOC_HUGE_FIRST) + ACL_MEM_MALLOC_HUGE_FIRST) check_ret("acl.rt.malloc", ret) self.weight_ptr, ret = acl.rt.malloc(weight_size, @@ -265,7 +265,6 @@ def init_resource(self): _, ret = acl.mdl.add_dataset_buffer(self.output_dataset, data_buf) check_ret("acl.add_dataset_buffer", ret) - @record_function('load_and_run_prepare_input') def _prepare_input(self, images, dims): assert self.num_inputs == len(images) @@ -337,14 +336,14 @@ def _prepare_dynamic_output(self, output_tensor, output_shape, out_stride, out_s self.output_data_buffers[i], item.data_ptr(), self.output_size[i]) check_ret("acl.update_data_buffer", ret) - @record_function(f'load_and_run_run') + @record_function('load_and_run_run') def run(self, images, dims=None, output_shape=None, out_stride=None, out_storage_offset=None, allocated_output=None, allocated_with_offset_output=None): with record_function(f'load_and_run_run_{self.model_id}'): assert len(images) > 0 input = [x.to(dipu_device_str) if isinstance(x, torch.Tensor) - and x.device.type != dipu_device_str else x for x in images] + and x.device.type != dipu_device_str else x for x in images] allocated_output_tensor = None if allocated_output: allocated_output_tensor = {} diff --git a/dicp/dicp/vendor/AscendGraph/codegen/utils.py b/dicp/dicp/vendor/AscendGraph/codegen/utils.py index a871fd37b..404d827f9 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/utils.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/utils.py @@ -56,7 +56,7 @@ def get_acl_format(x) -> int: return AclFormat.ACL_FORMAT_ND.value -def get_acl_dtype(dtype: torch.dtype) ->int: +def get_acl_dtype(dtype: torch.dtype) -> int: if dtype == torch.bool: return AclDataType.ACL_BOOL.value elif dtype == torch.int64: @@ -166,4 +166,3 @@ def get_cpp_dtype(dtype: torch.dtype) -> str: return "BOOL" else: raise RuntimeError(f"unknow torch data type ({dtype}) in get_cpp_dtype!") - diff --git a/dicp/dicp/vendor/AscendGraph/compile_job.py b/dicp/dicp/vendor/AscendGraph/compile_job.py index 8c35fef4e..f9720bc31 100644 --- a/dicp/dicp/vendor/AscendGraph/compile_job.py +++ b/dicp/dicp/vendor/AscendGraph/compile_job.py @@ -26,7 +26,7 @@ def __init__(self, source_code) -> None: source_code.strip(), "json", extra=cpp_compile_command("i", "o", vec_isa=picked_vec_isa) + - 'local_rank' + str(self._local_rank) + code_hash(compile_file_code) + 'local_rank' + str(self._local_rank) + code_hash(compile_file_code) ) self._output_graph_path = self._input_path[:-5] + '/graph' print('output_path: ', self._output_graph_path) diff --git a/dicp/dicp/vendor/AscendGraph/config.py b/dicp/dicp/vendor/AscendGraph/config.py index 528fb4cab..f89495ba4 100644 --- a/dicp/dicp/vendor/AscendGraph/config.py +++ b/dicp/dicp/vendor/AscendGraph/config.py @@ -1,8 +1,7 @@ -import math - import torch -from dicp.dynamo_bridge.decompositions import register_decomposition_for_dicp, get_decompositions +from dicp.dynamo_bridge.decompositions import get_decompositions + def get_decomp(): aten = torch.ops.aten diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index aa93127dc..5d957da16 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -16,8 +16,7 @@ import dicp.vendor.AscendGraph.ascend_op as ascend_op from dicp.dynamo_bridge.utils import symint_in_shape from dicp.vendor.AscendGraph.codegen.utils import ( - get_ascend_dtype, - get_cpp_dtype + get_ascend_dtype ) from dicp.dynamo_bridge.conversion import register_conversion_impl from dicp.dynamo_bridge.op_transformer import SingleOpTransformer @@ -134,7 +133,7 @@ def generate_sym_int(elem): # concat all ops return self.get_proxy(ascend_op.ConcatD, (x_names, 0)) - def get_shape_proxy(self, shape, dtype= torch.int32): + def get_shape_proxy(self, shape, dtype=torch.int32): if isinstance(shape, torch.fx.proxy.Proxy) or isinstance(shape, FakeTensor): return shape elif isinstance(shape, list) and symint_in_shape(shape): @@ -318,7 +317,7 @@ def div(self, x, y): @register_conversion(aten.split.Tensor) def split(self, x, split_size, dim=0): - splitD_kw = { "from_view_complex": False } + splitD_kw = {"from_view_complex": False} shape = list(x.node.meta['val'].shape) if dim < 0: dim += len(shape) @@ -340,7 +339,7 @@ def slice(self, x, dim=0, start=None, end=None, step=1): offset = [0] * len(x_shape) offset[dim] = start # import pdb; pdb.set_trace() - + offset = self.get_shape_proxy(offset) size = self.get_shape_proxy(y_shape) return self.get_proxy(ascend_op.Slice, (x, offset, size)) @@ -373,11 +372,11 @@ def empty_like(self, x, dtype=None, layout=None, dtype = x.node.meta['val'].dtype if layout is not None and (layout != torch.strided): raise NotImplementedError("torch.ops.aten.empty_like.default is " - "only supported on dense tensor now.") + "only supported on dense tensor now.") if memory_format is not None and memory_format != torch.contiguous_format \ and memory_format != torch.preserve_format: raise NotImplementedError("torch.ops.aten.empty_like.default is only supported " - "contiguous_format and preserve_format now.") + "contiguous_format and preserve_format now.") shape = self.get_proxy(ascend_op.Shape, (x,)) return self.get_proxy(ascend_op.Empty, (shape, dtype)) @@ -441,7 +440,7 @@ def view(self, x, size): return self.get_proxy(ascend_op.IdentityN, (real_reshape, imag_reshape)) else: return self.get_proxy(ascend_op.Reshape, (x, shape)) - + @register_conversion(torch.ops.aten.where) def where(self, condition, x1, x2): # TODO(tangzhiyi): need to process scalars @@ -459,9 +458,9 @@ def arange(self, end, start=0, step=1, dtype=None, device='xpu', layout=None, pi assert isinstance(end, torch.fx.proxy.Proxy) or type(end) in [int, float] assert isinstance(step, torch.fx.proxy.Proxy) or type(step) in [int, float] - if not isinstance(start, torch.fx.proxy.Proxy): # scalar const + if not isinstance(start, torch.fx.proxy.Proxy): # scalar const start = self.get_const_proxy(start, out_dtype) - elif start.node.meta['val'] != out_dtype: # align tensor dtype + elif start.node.meta['val'] != out_dtype: # align tensor dtype start = self.get_proxy(ascend_op.Cast, (start, get_ascend_dtype(out_dtype)), {}) if not isinstance(end, torch.fx.proxy.Proxy): end = self.get_const_proxy(end, out_dtype) @@ -583,7 +582,7 @@ def view_as_complex(self, x): assert x_val.dtype == torch.float32 assert x_shape[-1] == 2 dim = len(x_shape) - 1 - splitD_kw = { "from_view_complex": True } + splitD_kw = {"from_view_complex": True} return self.get_proxy(ascend_op.SplitD, (x, dim, 2, 2), splitD_kw) @register_conversion(torch.ops.aten.full.default) @@ -867,7 +866,7 @@ def index_put_default(self, x, indices, values): stacked_indices, indices_broadcast_shape, stacked_indices_last_dim = \ self.compute_stacked_indices(indices, x.node.meta['val'].shape) - values_broadcast_shape = indices_broadcast_shape + x_shape[stacked_indices_last_dim:] # batch_shape + inner_shape + values_broadcast_shape = indices_broadcast_shape + x_shape[stacked_indices_last_dim:] # batch_shape + inner_shape values_broadcast_shape_op = self.get_const_proxy(values_broadcast_shape, torch.int32) broadcasted_values = self.get_proxy(ascend_op.BroadcastTo, (values, values_broadcast_shape_op)) return self.get_proxy(ascend_op.TensorScatterUpdate, (x, stacked_indices, broadcasted_values)) @@ -1227,7 +1226,7 @@ def repeat_interleave(self, x, repeats, dim): transpose_perm_proxy = self.get_shape_proxy(transpose_perm) transpose_proxy = self.get_proxy(ascend_op.Transpose, (reshape_proxy, transpose_perm_proxy)) - result_reshape = x_shape[:dim] + [x_shape[dim] * repeats] + x_shape[dim+1:] + result_reshape = x_shape[:dim] + [x_shape[dim] * repeats] + x_shape[dim + 1:] result_reshape_shape_proxy = self.get_const_proxy(result_reshape, torch.int32) return self.get_proxy(ascend_op.Reshape, (transpose_proxy, result_reshape_shape_proxy)) @@ -1418,34 +1417,17 @@ def NativeDropoutBackward(self, grad_output, mask, scale): p = 1. - scale prob_op = self.get_const_proxy(float(p), dtype) return self.get_proxy(ascend_op.DropOutDoMaskV3, (grad_output, mask, prob_op)) - - @register_conversion([torch.ops.aten._adaptive_avg_pool2d.default]) - def adaptiveavgpool2d(self, x, output_size): - assert isinstance(output_size, int) or ( len(output_size) in range(1,3) and any(output_size) ) - if not isinstance(output_size, list): - if isinstance(output_size, tuple): - output_size = list(output_size) - elif isinstance(output_size, int): - output_size = [output_size, output_size] - else: - raise RuntimeError("not supported output type!") - return self.get_proxy(ascend_op.AdaptiveAvgPool2D, (x, output_size)) - - @register_conversion([torch.ops.aten._adaptive_avg_pool2d_backward.default]) - def adaptiveavgpool2dBackward(self, grad, input): - input_shape = list(input.node.meta['val'].shape) - return self.get_proxy(ascend_op.AdaptiveAvgPool2DGrad, (grad, input_shape)) @register_conversion([torch.ops.aten._adaptive_avg_pool2d.default]) def adaptiveavgpool2d(self, x, output_size): - assert isinstance(output_size, int) or ( len(output_size) in range(1,3) and any(output_size) ) + assert isinstance(output_size, int) or (len(output_size) in range(1, 3) and any(output_size)) if not isinstance(output_size, list): if isinstance(output_size, tuple): output_size = list(output_size) elif isinstance(output_size, int): output_size = [output_size, output_size] else: - raise RuntimeError("not supported output size!") + raise RuntimeError("not supported output type!") return self.get_proxy(ascend_op.AdaptiveAvgPool2D, (x, output_size)) @register_conversion([torch.ops.aten._adaptive_avg_pool2d_backward.default]) @@ -1521,10 +1503,10 @@ def scalar_tensor(self, x, dtype=None, layout=None, device=None, pin_memory=None def lightllm_rotary_emb(self, x, cos, sin): x_shape = list(x.node.meta['val'].shape) assert len(x_shape) == 3 - + seq_len = x_shape[0] dim = x_shape[2] - + cos_sin_shape = self.get_const_proxy([seq_len, 1, dim // 2], torch.int32) cos = self.get_proxy(ascend_op.Reshape, (cos, cos_sin_shape)) sin = self.get_proxy(ascend_op.Reshape, (sin, cos_sin_shape)) @@ -1579,7 +1561,7 @@ def select_scatter(self, x, src, dim, index): dim_op = self.get_const_proxy(dim, torch.int32) src = self.get_proxy(ascend_op.ExpandDims, (src, dim_op)) src = self.get_proxy(ascend_op.BroadcastTo, (src, input_sizes)) - + return self.get_proxy(ascend_op.ScatterElements, (x, index, src, dim)) @register_conversion(torch.ops.lightllm.copy_with_offset.default) @@ -1592,7 +1574,7 @@ def copy_with_offset2(self, x, src, start_dim, end_dim): def flash_attention_inference(self, q, all_k, all_v, current_len, max_len): q_shape = list(q.node.meta['val'].shape) batch, head, dim = q_shape[0], q_shape[1], q_shape[2] - + res = [] compute_batch = 1 select_axis = self.get_const_proxy(0, torch.int32) @@ -1608,7 +1590,7 @@ def flash_attention_inference(self, q, all_k, all_v, current_len, max_len): kv_gather_shape = self.get_shape_proxy([compute_batch, kv_seq_len, head, dim]) kv_compute_shape = self.get_shape_proxy([compute_batch, kv_seq_len, head * dim]) - + # fetch k k = self.get_proxy(ascend_op.Slice, (all_k, kv_start_index, kv_end_index)) k = self.get_proxy(ascend_op.Reshape, (k, kv_gather_shape)) @@ -1624,13 +1606,13 @@ def flash_attention_inference(self, q, all_k, all_v, current_len, max_len): q_compute_shape = self.get_shape_proxy([compute_batch, 1, head * dim]) xq = self.get_proxy(ascend_op.Reshape, (xq, q_shape)) xq = self.get_proxy(ascend_op.Reshape, (xq, q_compute_shape)) - + out = self.incre_flash_attention(xq, k, v, head, head, dim) # q shape is BSH out_shape = self.get_shape_proxy([compute_batch, 1, head, dim]) out_shape2 = self.get_shape_proxy([compute_batch, head, dim]) out = self.get_proxy(ascend_op.Reshape, (out, out_shape)) out = self.get_proxy(ascend_op.Reshape, (out, out_shape2)) res.append(out) - + res = self.get_proxy(ascend_op.ConcatD, (res, 0)) return res diff --git a/dicp/dicp/vendor/AscendGraph/ext_ops.py b/dicp/dicp/vendor/AscendGraph/ext_ops.py index b4e592f63..324d2a9b2 100644 --- a/dicp/dicp/vendor/AscendGraph/ext_ops.py +++ b/dicp/dicp/vendor/AscendGraph/ext_ops.py @@ -2,7 +2,6 @@ import torch import torch_dipu -import torch._dynamo as dynamo import torch.nn.functional as F from torch import Tensor @@ -10,15 +9,18 @@ torch._dynamo.config.suppress_errors = False + # rotary_emb @torch._custom_op.impl.custom_op('lightllm::rotary_emb') def rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: ... + @rotary_emb.impl_abstract() def lightllm_rotary_emb_abstract(x, cos, sin): return torch.empty_like(x) + @rotary_emb.impl(['cpu', 'cuda']) def lightllm_rotary_emb_impl(x, cos, sin): seq_len, h, dim = x.shape @@ -30,15 +32,18 @@ def lightllm_rotary_emb_impl(x, cos, sin): o1 = x0 * sin + x1 * cos return torch.cat((o0, o1), dim=-1) + # rms_norm @torch._custom_op.impl.custom_op('lightllm::rms_norm') def rms_norm(x: Tensor, weight: Tensor, eps: float) -> Tensor: ... + @rms_norm.impl_abstract() def lightllm_rms_norm_abstract(x, weight, eps): return torch.empty_like(x) + @rms_norm.impl(['cpu', 'cuda']) def lightllm_rms_norm_impl(x, weight, eps): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight @@ -48,16 +53,18 @@ def lightllm_rms_norm_impl(x, weight, eps): def prompt_attention_inference(q: Tensor, k: Tensor, v: Tensor, seqlen: Tensor, num_head: int, head_dim: int) -> Tensor: ... + @prompt_attention_inference.impl_abstract() def lightllm_prompt_attention_inference_abstract(q: Tensor, k: Tensor, v: Tensor, seqlen: Tensor, num_head: int, head_dim: int): return torch.empty_like(q) + @prompt_attention_inference.impl(['cpu', 'cuda']) def lightllm_prompt_attention_inference_impl(q, k, v, seqlen, num_head, head_dim): assert q.shape[0] == 1, "prompt attention just support bs=1 for now." bs = q.shape[0] seqlen = seqlen.item() - + xq = q.view(bs, seqlen, num_head, head_dim) xk = k.view(bs, seqlen, num_head, head_dim) xv = v.view(bs, seqlen, num_head, head_dim) @@ -79,14 +86,17 @@ def lightllm_prompt_attention_inference_impl(q, k, v, seqlen, num_head, head_dim return output + @torch._custom_op.impl.custom_op('lightllm::flash_attention_inference') def flash_attention_inference(q: Tensor, all_k: Tensor, all_v: Tensor, currnet_lens: Sequence[int], max_len: int) -> Tensor: ... + @flash_attention_inference.impl_abstract() def lightllm_flash_attention_inference_abstract(q: Tensor, all_k: Tensor, all_v: Tensor, currnet_lens: Sequence[int], max_len: int): return torch.empty_like(q) + @flash_attention_inference.impl(['cpu', 'cuda']) def lightllm_flash_attention_inference_impl(q, all_k, all_v, current_lens, max_len): # q: batch, head, dim @@ -98,37 +108,39 @@ def lightllm_flash_attention_inference_impl(q, all_k, all_v, current_lens, max_l for i in range(batch): current_len = current_lens[i] kv_seq_len = current_len - + k = all_k[:current_len].reshape(compute_batch, kv_seq_len, head, dim) v = all_v[:current_len].reshape(compute_batch, kv_seq_len, head, dim) xq = q[i].view(compute_batch, 1, head, dim).transpose(1, 2).transpose(0, 1) # shape: head, batch, 1, dim bmm_xq = xq.reshape(head * compute_batch, 1, dim) bmm_xk = k.transpose(1, 2).transpose(0, 1).transpose(2, 3).reshape(head * compute_batch, dim, kv_seq_len) - # q @ k out = torch.bmm(bmm_xq, bmm_xk) / math.sqrt(dim) out = out.reshape(head, compute_batch, 1, -1).reshape(head, compute_batch, -1) # softmax - out = out.softmax(-1).reshape(head, compute_batch, 1, kv_seq_len).transpose(0, 1) # shape: batch head 1 seq_len - xv = v.transpose(1, 2) # shape: batch head, seq_len, dim + out = out.softmax(-1).reshape(head, compute_batch, 1, kv_seq_len).transpose(0, 1) # shape: batch head 1 seq_len + xv = v.transpose(1, 2) # shape: batch head, seq_len, dim out = torch.bmm(out.reshape(compute_batch * head, 1, kv_seq_len), xv.reshape(compute_batch * head, kv_seq_len, dim)) - + out = out.reshape(compute_batch, head, 1, dim).view(compute_batch, head, dim) res.append(out) res = torch.cat(res) return res + @torch._custom_op.impl.custom_op('lightllm::copy_with_offset') def copy_with_offset(x: Tensor, src: Tensor, start_dim: int, end_dim: int) -> Tensor: ... + @copy_with_offset.impl_abstract() def lightllm_copy_with_offset_abstract(x: Tensor, src: Tensor, start_dim: int, end_dim: int) -> Tensor: return x + @copy_with_offset.impl(['cpu', 'cuda']) def lightllm_copy_with_offset_impl(x, src, start_dim, end_dim) -> Tensor: x[start_dim:end_dim] = src diff --git a/dicp/dicp/vendor/AscendGraph/opset_convert.py b/dicp/dicp/vendor/AscendGraph/opset_convert.py index 1843b7ac9..040efb0e0 100644 --- a/dicp/dicp/vendor/AscendGraph/opset_convert.py +++ b/dicp/dicp/vendor/AscendGraph/opset_convert.py @@ -1,7 +1,7 @@ import torch import torch_dipu from dicp.dynamo_bridge.compile_fx import is_torch_210 -from dicp.vendor.AscendGraph.ascend_op import MatMul, CastToCpu, IdentityInp, InplaceCopyWithOffset +from dicp.vendor.AscendGraph.ascend_op import CastToCpu, IdentityInp, InplaceCopyWithOffset from dicp.vendor.AscendGraph.conversion import AtenToAscendTransformer from ...dynamo_bridge.graph import GraphTransformer @@ -44,13 +44,13 @@ def transform(self, gm: torch.fx.graph_module): for n in gm.graph.nodes: if n.op != 'call_function': continue - if type(n.target) == CastToCpu: + if isinstance(n.target, CastToCpu): self.cpu_tensor.append(n.name) - elif type(n.target) == InplaceCopyWithOffset: + elif isinstance(n.target, InplaceCopyWithOffset): input_index = input_names.index(str(n.args[0])) offset = int(n.args[-1]) self.assign_with_offset_args[n.name] = {'name': n.name, 'input_index': input_index, 'offset': offset} - elif type(n.target) == IdentityInp: + elif isinstance(n.target, IdentityInp): if len(n.args) == 2 and n.args[1] is not None and str(n.args[1]) in input_names: self.assign_args.append((n.name, input_names.index(str(n.args[1])))) else: @@ -60,10 +60,10 @@ def transform(self, gm: torch.fx.graph_module): if n.op == 'call_function': prop = {} if n.name in self.cpu_tensor: - prop.update({'cpu_tensor' : n.name}) + prop.update({'cpu_tensor': n.name}) if len(self.assign_args) > 0 and n.name in list(zip(*self.assign_args))[0]: idx = list(zip(*self.assign_args))[0].index(n.name) - prop.update({'assign_args' : (self.assign_args[idx][0], self.assign_args[idx][1])}) + prop.update({'assign_args': (self.assign_args[idx][0], self.assign_args[idx][1])}) if n.name in self.assign_with_offset_args.keys(): prop['assign_with_offset_args'] = self.assign_with_offset_args[n.name] n.meta['prop'] = prop @@ -84,6 +84,7 @@ def symint_in_inputs(nodes): return True return False + def ascendgraph_opset_convert( gm: torch.fx.GraphModule, ): diff --git a/dicp/dicp/vendor/AscendGraph/pattern_replacement.py b/dicp/dicp/vendor/AscendGraph/pattern_replacement.py index aebf79a12..ec89bfb6c 100644 --- a/dicp/dicp/vendor/AscendGraph/pattern_replacement.py +++ b/dicp/dicp/vendor/AscendGraph/pattern_replacement.py @@ -33,18 +33,18 @@ class FusedRepeatInterleaveSelfInt(BackendPatternBase): @staticmethod def pattern(self, repeat, dim, input_shape, empty_device, view_1_shape, expand_1_shape, repeat_interleave_output_size): - empty = torch.ops.aten.empty.memory_format(input_shape, dtype = torch.int64, layout = torch.strided, device=empty_device) + empty = torch.ops.aten.empty.memory_format(input_shape, dtype=torch.int64, layout=torch.strided, device=empty_device) fill = torch.ops.aten.fill.Scalar(empty, repeat) view_1 = torch.ops.aten.view.default(fill, view_1_shape) expand_1 = torch.ops.aten.expand.default(view_1, expand_1_shape) - repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(expand_1, output_size = repeat_interleave_output_size) + repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(expand_1, output_size=repeat_interleave_output_size) index_select = torch.ops.aten.index_select.default(self, dim, repeat_interleave) return index_select @staticmethod def replacement(self, repeat, dim): return torch.ops.aten.repeat_interleave.self_int(self, repeat, dim) - + @register_aten_pattern class ReplaceAtenSliceScatter(BackendPatternBase): @@ -61,6 +61,7 @@ def replacement(arg0, arg1, start_index, end_index): slice_scatter = torch.ops.lightllm.copy_with_offset.default(arg0, arg1, start_index, end_index) return slice_scatter + Muls = torch.fx.wrap(ascend_op.Muls.get_singleton()) Shape = torch.fx.wrap(ascend_op.Shape.get_singleton()) Const = torch.fx.wrap(ascend_op.Const.get_singleton()) @@ -82,6 +83,7 @@ def replacement(arg0, arg1, start_index, end_index): Div = torch.fx.wrap(ascend_op.Div.get_singleton()) RmsNorm = torch.fx.wrap(ascend_op.RmsNorm.get_singleton()) + # @register_ascend_pattern class FuseBmmTransposeRhsPattern(BackendPatternBase): @staticmethod @@ -120,10 +122,11 @@ def replacement(x1, x2, c1, c2): muls = Muls(reshape, 0.3535533905932738) return BatchMatMul(x1, muls, adj_x1=False, adj_x2=True, keep_dtype=0) + @register_ascend_pattern class FuseLightLLMRmsNorm(BackendPatternBase): @staticmethod - def pattern(arg0_1, arg1_1): + def pattern(arg0_1, arg1_1): const = Const([2], torch.float32) pow_1 = Pow(arg0_1, const) reduce_mean_d = ReduceMeanD(pow_1, [-1], True, False) From 8cfcdad7baf33491b8a237b798cda522b2297e33 Mon Sep 17 00:00:00 2001 From: Reinerzhou <1768552509@qq.com> Date: Mon, 22 Apr 2024 05:46:08 +0000 Subject: [PATCH 06/17] adjust code. --- dicp/dicp/dynamo_bridge/compile_fx.py | 3 ++- dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h | 3 +-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dicp/dicp/dynamo_bridge/compile_fx.py b/dicp/dicp/dynamo_bridge/compile_fx.py index 5eaee327c..42465d27d 100644 --- a/dicp/dicp/dynamo_bridge/compile_fx.py +++ b/dicp/dicp/dynamo_bridge/compile_fx.py @@ -208,7 +208,8 @@ def compile_fx_210( def fw_compiler_base(model: torch.fx.GraphModule, example_inputs, is_inference): if is_inference: # partition_fn won't be called - joint_graph_passes(model) + # joint_graph_passes(model) + pass fixed = len(example_inputs) - num_example_inputs return inner_compile( diff --git a/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h b/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h index 1b9e1bc8e..1f6bc0519 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h +++ b/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h @@ -91,8 +91,7 @@ class AclgraphBuilder { {AscendString(ge::ir_option::SOC_VERSION), AscendString(kSocVersion)}, {AscendString(ge::ir_option::FUSION_SWITCH_FILE), AscendString(_fusion_switch_file.c_str())}, - // {AscendString(ge::ir_option::PRECISION_MODE_V2), "mixed_float16"}, - // {AscendString(ge::ir_option::PRECISION_MODE_V2), "fp16"}, + {AscendString(ge::ir_option::PRECISION_MODE_V2), "fp16"}, }; auto status = aclgrphBuildInitialize(global_options); if (status != GRAPH_SUCCESS) { From f3f7d8337d3a28dde70385830921c96e5f7dcd9e Mon Sep 17 00:00:00 2001 From: Reinerzhou <1768552509@qq.com> Date: Mon, 22 Apr 2024 08:11:08 +0000 Subject: [PATCH 07/17] adjust code. --- dicp/dicp/vendor/AscendGraph/codegen/fusion_switch.cfg | 6 ++---- dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h | 3 +-- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/codegen/fusion_switch.cfg b/dicp/dicp/vendor/AscendGraph/codegen/fusion_switch.cfg index 4a699c252..71834659c 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/fusion_switch.cfg +++ b/dicp/dicp/vendor/AscendGraph/codegen/fusion_switch.cfg @@ -1,12 +1,10 @@ { "Switch":{ "GraphFusion":{ - "IncreFlashAttentionQuantDeployPass": "on", - "RefreshInt64ToInt32FusionPass": "on", - "ALL":"off" + "ALL":"on" }, "UBFusion":{ - "ALL":"off" + "ALL":"on" } } } diff --git a/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h b/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h index 1b9e1bc8e..8cee35353 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h +++ b/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h @@ -91,8 +91,7 @@ class AclgraphBuilder { {AscendString(ge::ir_option::SOC_VERSION), AscendString(kSocVersion)}, {AscendString(ge::ir_option::FUSION_SWITCH_FILE), AscendString(_fusion_switch_file.c_str())}, - // {AscendString(ge::ir_option::PRECISION_MODE_V2), "mixed_float16"}, - // {AscendString(ge::ir_option::PRECISION_MODE_V2), "fp16"}, + {AscendString(ge::ir_option::PRECISION_MODE), "allow_fp32_to_fp16"}, }; auto status = aclgrphBuildInitialize(global_options); if (status != GRAPH_SUCCESS) { From c9d7822b6f5e5995ed4da2b1d94972bda8afcbf7 Mon Sep 17 00:00:00 2001 From: tangzhiyi Date: Mon, 22 Apr 2024 08:19:59 +0000 Subject: [PATCH 08/17] update code --- dicp/dicp/vendor/AscendGraph/ascend_op.py | 10 +------- .../dicp/vendor/AscendGraph/codegen/ascend.py | 23 +------------------ .../AscendGraph/codegen/load_and_run.py | 16 ++++--------- dicp/dicp/vendor/AscendGraph/conversion.py | 18 ++++++++------- dicp/dicp/vendor/AscendGraph/opset_convert.py | 8 +------ 5 files changed, 18 insertions(+), 57 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/ascend_op.py b/dicp/dicp/vendor/AscendGraph/ascend_op.py index cadc8798d..45a2010bc 100644 --- a/dicp/dicp/vendor/AscendGraph/ascend_op.py +++ b/dicp/dicp/vendor/AscendGraph/ascend_op.py @@ -1115,14 +1115,6 @@ def infer_result(self, x, index): return torch.empty(idx_shape, dtype=x_dtype, memory_format=get_memory_format(x)) -class InplaceCopyWithOffset(Operator): - def __init__(self): - super().__init__("InplaceCopyWithOffset") - - def infer_result(self, x, src, dim, offset): - return src - - class ExpandDims(Operator): def __init__(self): super().__init__("ExpandDims") @@ -1144,7 +1136,7 @@ def __init__(self): super().__init__("ViewCopy") def infer_result(self, dst, dst_size, dst_stride, dst_storage_offset, src, src_size, src_stride, src_storage_offset): - return x + return dst class ScatterNdUpdate(Operator): diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index 80fcac26d..112446f93 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -71,7 +71,6 @@ def __init__(self, graph, aten_graph=None, folder=None, graph_key=None): # for modified args return self.assign_args = [] self.cpu_tensor = [] - self.assign_with_offset_args = [] super().__init__(graph) @@ -135,8 +134,6 @@ def call_function(self, name, target, args, kwargs): self.cpu_tensor.append(self.cur_node.meta['prop']['cpu_tensor']) if 'prop' in self.cur_node.meta and 'assign_args' in self.cur_node.meta['prop']: self.assign_args.append(self.cur_node.meta['prop']['assign_args']) - if 'prop' in self.cur_node.meta and 'assign_with_offset_args' in self.cur_node.meta['prop']: - self.assign_with_offset_args.append(self.cur_node.meta['prop']['assign_with_offset_args']) _, args_list = AscendOverrides.gen_args( self.args_dict[name], self.args_dict, args) @@ -199,10 +196,6 @@ def parse_outputs(self): if len(self.assign_args) > 0: self.graph_output_names.extend(list(zip(*self.assign_args))[0]) - current_index = len(self.graph_output_names) - for i in range(len(self.assign_with_offset_args)): - self.assign_with_offset_args[i]['output_index'] = current_index - current_index += 1 def gen_import_code(self): self.import_code.splice( @@ -380,15 +373,7 @@ def gen_call_func(self): output_index = self.graph_output_names.index(item[0]) allocated_output[output_index] = input_index call_body.writeline(f'allocated_output= {allocated_output}') - - allocated_output_with_offset = {} - for item in self.assign_with_offset_args: - # input_index = item['input_index'] - output_index = item['output_index'] - # offset = item['offset'] - allocated_output_with_offset[output_index] = item - call_body.writeline(f'allocated_with_offset_output= {self.assign_with_offset_args}') - call_str = ['output_tensor = kernel_cpp_0(args, dims, output_shape, out_stride, out_storage_offset, allocated_output, allocated_with_offset_output)'] + call_str = ['output_tensor = kernel_cpp_0(args, dims, output_shape, out_stride, out_storage_offset, allocated_output)'] if precision_check and self.aten_graph is not None: # import aten graph @@ -1673,12 +1658,6 @@ def ExpandDims(name, x, axis): gather_op.set_input("axis", axis) return gather_op.to_node() - @staticmethod - def InplaceCopyWithOffset(name, x, src, dim, offset): - op = OP(name, "Identity") - op.set_input("x", src) - return op.to_node() - @staticmethod def MaskedScatter(name, x, mask, updates): op = OP(name, "MaskedScatter") diff --git a/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py b/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py index d8e188c20..8c14e1a44 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py @@ -297,7 +297,7 @@ def _prepare_input(self, images, dims): assert (dataset == self.input_dataset) @record_function('load_and_run_prepare_output') - def _prepare_output(self, output_tensor, output_shape, out_stride, out_storage_offset, allocated_output, allocated_output_with_offset_tensor): + def _prepare_output(self, output_tensor, output_shape, out_stride, out_storage_offset, allocated_output): for i in range(self.num_outputs): if allocated_output and i in allocated_output.keys(): item = allocated_output[i] @@ -339,7 +339,7 @@ def _prepare_dynamic_output(self, output_tensor, output_shape, out_stride, out_s @record_function('load_and_run_run') def run(self, images, dims=None, output_shape=None, out_stride=None, out_storage_offset=None, - allocated_output=None, allocated_with_offset_output=None): + allocated_output=None): with record_function(f'load_and_run_run_{self.model_id}'): assert len(images) > 0 input = [x.to(dipu_device_str) if isinstance(x, torch.Tensor) @@ -350,12 +350,6 @@ def run(self, images, dims=None, output_shape=None, for output_index, input_index in allocated_output.items(): allocated_output_tensor[output_index] = input[input_index] - allocated_output_with_offset_tensor = None - if allocated_with_offset_output: - allocated_output_with_offset_tensor = {} - for item in allocated_with_offset_output: - allocated_output_with_offset_tensor[item['output_index']] = {'input_tensor': input[item['input_index']], 'offset': item['offset']} - self._prepare_input(input, dims) output = [] if output_shape: @@ -363,7 +357,7 @@ def run(self, images, dims=None, output_shape=None, output, output_shape, out_stride, out_storage_offset, allocated_output_tensor) else: self._prepare_output( - output, output_shape, out_stride, out_storage_offset, allocated_output_tensor, allocated_output_with_offset_tensor) + output, output_shape, out_stride, out_storage_offset, allocated_output_tensor) self.forward() self._destroy_databuffer() return output @@ -388,8 +382,8 @@ def __init__(self, device_id, model_path) -> None: self.exe = AscendExecutor(device_id, model_path) def run(self, images, dims=None, output_shape=None, - out_stride=None, out_storage_offset=None, allocated_output=None, allocated_with_offset_output=None): - return self.exe.run(images, dims, output_shape, out_stride, out_storage_offset, allocated_output, allocated_with_offset_output) + out_stride=None, out_storage_offset=None, allocated_output=None): + return self.exe.run(images, dims, output_shape, out_stride, out_storage_offset, allocated_output) def cleanup(self): if hasattr(self, 'exe'): diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 5d957da16..761436a63 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -1574,7 +1574,9 @@ def copy_with_offset2(self, x, src, start_dim, end_dim): def flash_attention_inference(self, q, all_k, all_v, current_len, max_len): q_shape = list(q.node.meta['val'].shape) batch, head, dim = q_shape[0], q_shape[1], q_shape[2] - + k_shape = list(all_k.node.meta['val'].shape) + kvhead = k_shape[1] + res = [] compute_batch = 1 select_axis = self.get_const_proxy(0, torch.int32) @@ -1585,12 +1587,12 @@ def flash_attention_inference(self, q, all_k, all_v, current_len, max_len): xq = self.get_proxy(ascend_op.GatherV2, (q, select_index, select_axis)) kv_start_index = self.get_const_proxy([i * max_len, 0, 0], torch.int32) - kv_end_index = self.get_const_proxy([i * max_len + current_len, head, dim], torch.int32) + kv_end_index = self.get_const_proxy([i * max_len + current_len, kvhead, dim], torch.int32) kv_seq_len = current_len - kv_gather_shape = self.get_shape_proxy([compute_batch, kv_seq_len, head, dim]) - kv_compute_shape = self.get_shape_proxy([compute_batch, kv_seq_len, head * dim]) - + kv_gather_shape = self.get_shape_proxy([compute_batch, kv_seq_len, kvhead, dim]) + kv_compute_shape = self.get_shape_proxy([compute_batch, kv_seq_len, kvhead * dim]) + # fetch k k = self.get_proxy(ascend_op.Slice, (all_k, kv_start_index, kv_end_index)) k = self.get_proxy(ascend_op.Reshape, (k, kv_gather_shape)) @@ -1606,13 +1608,13 @@ def flash_attention_inference(self, q, all_k, all_v, current_len, max_len): q_compute_shape = self.get_shape_proxy([compute_batch, 1, head * dim]) xq = self.get_proxy(ascend_op.Reshape, (xq, q_shape)) xq = self.get_proxy(ascend_op.Reshape, (xq, q_compute_shape)) - - out = self.incre_flash_attention(xq, k, v, head, head, dim) # q shape is BSH + + out = self.incre_flash_attention(xq, k, v, kvhead, head, dim) # q shape is BSH out_shape = self.get_shape_proxy([compute_batch, 1, head, dim]) out_shape2 = self.get_shape_proxy([compute_batch, head, dim]) out = self.get_proxy(ascend_op.Reshape, (out, out_shape)) out = self.get_proxy(ascend_op.Reshape, (out, out_shape2)) res.append(out) - + res = self.get_proxy(ascend_op.ConcatD, (res, 0)) return res diff --git a/dicp/dicp/vendor/AscendGraph/opset_convert.py b/dicp/dicp/vendor/AscendGraph/opset_convert.py index 040efb0e0..39534710a 100644 --- a/dicp/dicp/vendor/AscendGraph/opset_convert.py +++ b/dicp/dicp/vendor/AscendGraph/opset_convert.py @@ -1,7 +1,7 @@ import torch import torch_dipu from dicp.dynamo_bridge.compile_fx import is_torch_210 -from dicp.vendor.AscendGraph.ascend_op import CastToCpu, IdentityInp, InplaceCopyWithOffset +from dicp.vendor.AscendGraph.ascend_op import CastToCpu, IdentityInp from dicp.vendor.AscendGraph.conversion import AtenToAscendTransformer from ...dynamo_bridge.graph import GraphTransformer @@ -46,10 +46,6 @@ def transform(self, gm: torch.fx.graph_module): continue if isinstance(n.target, CastToCpu): self.cpu_tensor.append(n.name) - elif isinstance(n.target, InplaceCopyWithOffset): - input_index = input_names.index(str(n.args[0])) - offset = int(n.args[-1]) - self.assign_with_offset_args[n.name] = {'name': n.name, 'input_index': input_index, 'offset': offset} elif isinstance(n.target, IdentityInp): if len(n.args) == 2 and n.args[1] is not None and str(n.args[1]) in input_names: self.assign_args.append((n.name, input_names.index(str(n.args[1])))) @@ -64,8 +60,6 @@ def transform(self, gm: torch.fx.graph_module): if len(self.assign_args) > 0 and n.name in list(zip(*self.assign_args))[0]: idx = list(zip(*self.assign_args))[0].index(n.name) prop.update({'assign_args': (self.assign_args[idx][0], self.assign_args[idx][1])}) - if n.name in self.assign_with_offset_args.keys(): - prop['assign_with_offset_args'] = self.assign_with_offset_args[n.name] n.meta['prop'] = prop return gm From dfd778ded25163572a06daab0193e178cfe8eb96 Mon Sep 17 00:00:00 2001 From: Reinerzhou <1768552509@qq.com> Date: Mon, 22 Apr 2024 08:40:45 +0000 Subject: [PATCH 09/17] remove unused record_function. --- .../AscendGraph/codegen/load_and_run.py | 41 +++++++++---------- dicp/dicp/vendor/AscendGraph/conversion.py | 8 ++-- 2 files changed, 24 insertions(+), 25 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py b/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py index 8c14e1a44..51588a9e5 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py @@ -340,27 +340,26 @@ def _prepare_dynamic_output(self, output_tensor, output_shape, out_stride, out_s def run(self, images, dims=None, output_shape=None, out_stride=None, out_storage_offset=None, allocated_output=None): - with record_function(f'load_and_run_run_{self.model_id}'): - assert len(images) > 0 - input = [x.to(dipu_device_str) if isinstance(x, torch.Tensor) - and x.device.type != dipu_device_str else x for x in images] - allocated_output_tensor = None - if allocated_output: - allocated_output_tensor = {} - for output_index, input_index in allocated_output.items(): - allocated_output_tensor[output_index] = input[input_index] - - self._prepare_input(input, dims) - output = [] - if output_shape: - self._prepare_dynamic_output( - output, output_shape, out_stride, out_storage_offset, allocated_output_tensor) - else: - self._prepare_output( - output, output_shape, out_stride, out_storage_offset, allocated_output_tensor) - self.forward() - self._destroy_databuffer() - return output + assert len(images) > 0 + input = [x.to(dipu_device_str) if isinstance(x, torch.Tensor) + and x.device.type != dipu_device_str else x for x in images] + allocated_output_tensor = None + if allocated_output: + allocated_output_tensor = {} + for output_index, input_index in allocated_output.items(): + allocated_output_tensor[output_index] = input[input_index] + + self._prepare_input(input, dims) + output = [] + if output_shape: + self._prepare_dynamic_output( + output, output_shape, out_stride, out_storage_offset, allocated_output_tensor) + else: + self._prepare_output( + output, output_shape, out_stride, out_storage_offset, allocated_output_tensor) + self.forward() + self._destroy_databuffer() + return output @record_function('load_and_run_forward') def forward(self): diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 761436a63..d0a1a7321 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -1576,7 +1576,7 @@ def flash_attention_inference(self, q, all_k, all_v, current_len, max_len): batch, head, dim = q_shape[0], q_shape[1], q_shape[2] k_shape = list(all_k.node.meta['val'].shape) kvhead = k_shape[1] - + res = [] compute_batch = 1 select_axis = self.get_const_proxy(0, torch.int32) @@ -1592,7 +1592,7 @@ def flash_attention_inference(self, q, all_k, all_v, current_len, max_len): kv_gather_shape = self.get_shape_proxy([compute_batch, kv_seq_len, kvhead, dim]) kv_compute_shape = self.get_shape_proxy([compute_batch, kv_seq_len, kvhead * dim]) - + # fetch k k = self.get_proxy(ascend_op.Slice, (all_k, kv_start_index, kv_end_index)) k = self.get_proxy(ascend_op.Reshape, (k, kv_gather_shape)) @@ -1608,13 +1608,13 @@ def flash_attention_inference(self, q, all_k, all_v, current_len, max_len): q_compute_shape = self.get_shape_proxy([compute_batch, 1, head * dim]) xq = self.get_proxy(ascend_op.Reshape, (xq, q_shape)) xq = self.get_proxy(ascend_op.Reshape, (xq, q_compute_shape)) - + out = self.incre_flash_attention(xq, k, v, kvhead, head, dim) # q shape is BSH out_shape = self.get_shape_proxy([compute_batch, 1, head, dim]) out_shape2 = self.get_shape_proxy([compute_batch, head, dim]) out = self.get_proxy(ascend_op.Reshape, (out, out_shape)) out = self.get_proxy(ascend_op.Reshape, (out, out_shape2)) res.append(out) - + res = self.get_proxy(ascend_op.ConcatD, (res, 0)) return res From 910a4108d81815db2e1c731028b59eef031f16ce Mon Sep 17 00:00:00 2001 From: Reinerzhou <1768552509@qq.com> Date: Mon, 22 Apr 2024 09:36:58 +0000 Subject: [PATCH 10/17] clean code. --- dicp/dicp/vendor/AscendGraph/ascend_op.py | 1 + dicp/dicp/vendor/AscendGraph/codegen/ascend.py | 12 +----------- dicp/dicp/vendor/AscendGraph/conversion.py | 2 +- dicp/dicp/vendor/AscendGraph/opset_convert.py | 1 - 4 files changed, 3 insertions(+), 13 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/ascend_op.py b/dicp/dicp/vendor/AscendGraph/ascend_op.py index 45a2010bc..e078a071c 100644 --- a/dicp/dicp/vendor/AscendGraph/ascend_op.py +++ b/dicp/dicp/vendor/AscendGraph/ascend_op.py @@ -1110,6 +1110,7 @@ def __init__(self): def infer_result(self, x, index): x, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x) idx, idx_shape, idx_dim, idx_dtype = get_fake_tensor_meta_val(index) + # compute idx_shape for some special cases. idx_shape = list(idx_shape) idx_shape.append(x_shape[-1]) return torch.empty(idx_shape, dtype=x_dtype, memory_format=get_memory_format(x)) diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index 112446f93..45678accc 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -491,14 +491,10 @@ def gen_graph_json(self): self.parse_outputs() self.gen_build_options() has_dynamic_shape = False if len(self.sym_in_args) == 0 and len(self.sym_to_inputs) == 0 else True - with_copy_inplace = self.graph_output_names - for i in self.assign_with_offset_args: - with_copy_inplace.append(i['name']) graph = { "name": "graph", "input_names": self.graph_input_names, - # "output_names": self.graph_output_names, - "output_names": with_copy_inplace, + "output_names": self.graph_output_names, "has_dynamic_shape": has_dynamic_shape, "build_options": self.build_options, "data_nodes": self.data_nodes, @@ -696,12 +692,6 @@ def gen_args(op_var, args_dict, args): src_code = IndentedBuffer() args_str = [op_var] args_str.extend(tree_map_only(Node, lambda x: args_dict[x.name], args)) - - # for i in range(len(args)): - # if isinstance(args[i], Node): - # args_str.append(args_dict[args[i].name]) - # else: - # args_str.append(args[i]) return src_code, args_str @staticmethod diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index d0a1a7321..ef238b85a 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -1565,7 +1565,7 @@ def select_scatter(self, x, src, dim, index): return self.get_proxy(ascend_op.ScatterElements, (x, index, src, dim)) @register_conversion(torch.ops.lightllm.copy_with_offset.default) - def copy_with_offset2(self, x, src, start_dim, end_dim): + def copy_with_offset(self, x, src, start_dim, end_dim): dims = [x for x in range(start_dim, end_dim)] dims = self.get_const_proxy(dims, torch.int32, target_shape=[len(dims), 1]) return self.get_proxy(ascend_op.ScatterNdUpdate, (x, dims, src)) diff --git a/dicp/dicp/vendor/AscendGraph/opset_convert.py b/dicp/dicp/vendor/AscendGraph/opset_convert.py index 39534710a..1db1f8b28 100644 --- a/dicp/dicp/vendor/AscendGraph/opset_convert.py +++ b/dicp/dicp/vendor/AscendGraph/opset_convert.py @@ -32,7 +32,6 @@ class OutputMarkPass: def __init__(self): self.assign_args = [] self.cpu_tensor = [] - self.assign_with_offset_args = {} def transform(self, gm: torch.fx.graph_module): # dynamic shape feature From 4c660bf64aeca90a0e42bd2546017e3ba7124b6d Mon Sep 17 00:00:00 2001 From: Reinerzhou <1768552509@qq.com> Date: Mon, 22 Apr 2024 09:38:19 +0000 Subject: [PATCH 11/17] revert select conversion. --- dicp/dicp/vendor/AscendGraph/conversion.py | 34 +++++++++++++++++++--- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index ef238b85a..4f22f3834 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -382,10 +382,36 @@ def empty_like(self, x, dtype=None, layout=None, @register_conversion(aten.select.int) def select(self, x, dim, index): - axis = self.get_const_proxy(dim, torch.int32) - if not isinstance(index, torch.fx.proxy.Proxy): - index = self.get_const_proxy(index, torch.int32) - return self.get_proxy(ascend_op.GatherV2, (x, index, axis)) + x_shape = list(x.node.meta['val'].shape) + y_shape = list(fx_traceback.get_current_meta()['val'].shape) + dim = int(dim) + index = int(index) + assert dim >= 0 and dim < len(x_shape) + start = index if index >= 0 else index + x_shape[dim] + end = start + 1 + offset = [0] * len(x_shape) + offset[dim] = start + size = [] + for i, v in enumerate(x_shape): + if i != dim: + size.append(v - offset[i]) + else: + size.append(end - offset[i]) + offset = self.get_shape_proxy(offset) + size = self.get_shape_proxy(size) + slice = self.get_proxy(ascend_op.Slice, (x, offset, size)) + y_shape = self.get_shape_proxy(y_shape) + Reshape_kw = { + "ori_op": "Select", + "params_passed": { + "sel_dim": dim, + }, + } + return self.get_proxy(ascend_op.Reshape, (slice, y_shape), Reshape_kw) + # axis = self.get_const_proxy(dim, torch.int32) + # if not isinstance(index, torch.fx.proxy.Proxy): + # index = self.get_const_proxy(index, torch.int32) + # return self.get_proxy(ascend_op.GatherV2, (x, index, axis)) @register_conversion(_operator.add) def inadd(self, x, y): From 78c9fa73bd3d5624bc5c3e13790ff52c4e455c85 Mon Sep 17 00:00:00 2001 From: Reinerzhou <1768552509@qq.com> Date: Mon, 22 Apr 2024 10:35:38 +0000 Subject: [PATCH 12/17] revert default format and remove joint_graph_passes. --- dicp/dicp/dynamo_bridge/compile_fx.py | 2 +- dicp/dicp/vendor/AscendGraph/codegen/ascend.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dicp/dicp/dynamo_bridge/compile_fx.py b/dicp/dicp/dynamo_bridge/compile_fx.py index 42465d27d..118244517 100644 --- a/dicp/dicp/dynamo_bridge/compile_fx.py +++ b/dicp/dicp/dynamo_bridge/compile_fx.py @@ -224,7 +224,7 @@ def fw_compiler_base(model: torch.fx.GraphModule, example_inputs, is_inference): inference_compiler = functools.partial(fw_compiler_base, is_inference=True) def partition_fn(graph, joint_inputs, **kwargs): - joint_graph_passes(graph) + # joint_graph_passes(graph) return min_cut_rematerialization_partition( graph, joint_inputs, **kwargs, compiler="inductor" ) diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index 45678accc..fc2647b05 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -79,7 +79,7 @@ def placeholder(self, name, target, args, kwargs): self.input_args.append(self.cur_node) fake_tensor = self.cur_node.meta['val'] - format = "ND" + format = "NCHW" index = -1 if isinstance(fake_tensor, torch.SymInt): From 9151cfee60617de28649760ac37d5f8d555835e5 Mon Sep 17 00:00:00 2001 From: Reinerzhou <1768552509@qq.com> Date: Mon, 22 Apr 2024 11:42:09 +0000 Subject: [PATCH 13/17] uncomment a joint_graph_passes. --- dicp/dicp/dynamo_bridge/compile_fx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dicp/dicp/dynamo_bridge/compile_fx.py b/dicp/dicp/dynamo_bridge/compile_fx.py index 118244517..42465d27d 100644 --- a/dicp/dicp/dynamo_bridge/compile_fx.py +++ b/dicp/dicp/dynamo_bridge/compile_fx.py @@ -224,7 +224,7 @@ def fw_compiler_base(model: torch.fx.GraphModule, example_inputs, is_inference): inference_compiler = functools.partial(fw_compiler_base, is_inference=True) def partition_fn(graph, joint_inputs, **kwargs): - # joint_graph_passes(graph) + joint_graph_passes(graph) return min_cut_rematerialization_partition( graph, joint_inputs, **kwargs, compiler="inductor" ) From ffc239c214c4dc5b8c9fbad2d004d153203d8137 Mon Sep 17 00:00:00 2001 From: Reinerzhou <1768552509@qq.com> Date: Mon, 22 Apr 2024 11:49:24 +0000 Subject: [PATCH 14/17] cancel test_hf temporarily. --- dicp/test/ascend_scripts/models/run_test_models.sh | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/dicp/test/ascend_scripts/models/run_test_models.sh b/dicp/test/ascend_scripts/models/run_test_models.sh index c480044af..3cae94d35 100755 --- a/dicp/test/ascend_scripts/models/run_test_models.sh +++ b/dicp/test/ascend_scripts/models/run_test_models.sh @@ -34,9 +34,9 @@ else exit 1 fi -PYTESTCODE=$? -if [ "$PYTESTCODE" -eq 0 ]; then - python ${TEST_MODEL_DIR}/test_hf.py -else - exit 1 -fi +# PYTESTCODE=$? +# if [ "$PYTESTCODE" -eq 0 ]; then +# python ${TEST_MODEL_DIR}/test_hf.py +# else +# exit 1 +# fi From 289734cd14d0003d1d6a80d75bb36f23c502f580 Mon Sep 17 00:00:00 2001 From: Pan Daoxin Date: Mon, 22 Apr 2024 13:06:06 +0000 Subject: [PATCH 15/17] Fix test_hf case. --- dicp/dicp/vendor/AscendGraph/codegen/ascend.py | 2 +- dicp/test/ascend_scripts/models/run_test_models.sh | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index fc2647b05..918a558b9 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -281,7 +281,7 @@ def gen_call_func(self): # dynamic shape feature if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0: args = ['_' if arg not in shape_symint and arg not in self.sym_to_inputs.values() else arg for arg in self.args] - call_body.writeline("({','.join(args)}) = args") + call_body.writeline(f"({','.join(args)}) = args") # generate input dims if len(self.dynamic_inputs) > 0: diff --git a/dicp/test/ascend_scripts/models/run_test_models.sh b/dicp/test/ascend_scripts/models/run_test_models.sh index 3cae94d35..c480044af 100755 --- a/dicp/test/ascend_scripts/models/run_test_models.sh +++ b/dicp/test/ascend_scripts/models/run_test_models.sh @@ -34,9 +34,9 @@ else exit 1 fi -# PYTESTCODE=$? -# if [ "$PYTESTCODE" -eq 0 ]; then -# python ${TEST_MODEL_DIR}/test_hf.py -# else -# exit 1 -# fi +PYTESTCODE=$? +if [ "$PYTESTCODE" -eq 0 ]; then + python ${TEST_MODEL_DIR}/test_hf.py +else + exit 1 +fi From 61f17406f199646f4fac8ccdc57f7a2f79ff234d Mon Sep 17 00:00:00 2001 From: Reinerzhou <1768552509@qq.com> Date: Tue, 23 Apr 2024 08:33:41 +0000 Subject: [PATCH 16/17] clean code. --- .../AscendGraph/codegen/load_and_run.py | 2 +- dicp/dicp/vendor/AscendGraph/conversion.py | 4 --- .../vendor/AscendGraph/pattern_replacement.py | 34 ------------------- 3 files changed, 1 insertion(+), 39 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py b/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py index 51588a9e5..12c7bb193 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py @@ -342,7 +342,7 @@ def run(self, images, dims=None, output_shape=None, allocated_output=None): assert len(images) > 0 input = [x.to(dipu_device_str) if isinstance(x, torch.Tensor) - and x.device.type != dipu_device_str else x for x in images] + and x.device.type != dipu_device_str else x for x in images] allocated_output_tensor = None if allocated_output: allocated_output_tensor = {} diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 4f22f3834..a56674fba 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -408,10 +408,6 @@ def select(self, x, dim, index): }, } return self.get_proxy(ascend_op.Reshape, (slice, y_shape), Reshape_kw) - # axis = self.get_const_proxy(dim, torch.int32) - # if not isinstance(index, torch.fx.proxy.Proxy): - # index = self.get_const_proxy(index, torch.int32) - # return self.get_proxy(ascend_op.GatherV2, (x, index, axis)) @register_conversion(_operator.add) def inadd(self, x, y): diff --git a/dicp/dicp/vendor/AscendGraph/pattern_replacement.py b/dicp/dicp/vendor/AscendGraph/pattern_replacement.py index ec89bfb6c..d2ef21be1 100644 --- a/dicp/dicp/vendor/AscendGraph/pattern_replacement.py +++ b/dicp/dicp/vendor/AscendGraph/pattern_replacement.py @@ -72,17 +72,6 @@ def replacement(arg0, arg1, start_index, end_index): Permute = torch.fx.wrap(ascend_op.Permute.get_singleton()) MatMul = torch.fx.wrap(ascend_op.MatMul.get_singleton()) -Pow = torch.fx.wrap(ascend_op.Pow.get_singleton()) -ReduceMeanD = torch.fx.wrap(ascend_op.ReduceMeanD.get_singleton()) -Adds = torch.fx.wrap(ascend_op.Adds.get_singleton()) -Rsqrt = torch.fx.wrap(ascend_op.Rsqrt.get_singleton()) -ZerosLike = torch.fx.wrap(ascend_op.ZerosLike.get_singleton()) -Less = torch.fx.wrap(ascend_op.Less.get_singleton()) -Select = torch.fx.wrap(ascend_op.Select.get_singleton()) -Mul = torch.fx.wrap(ascend_op.Mul.get_singleton()) -Div = torch.fx.wrap(ascend_op.Div.get_singleton()) -RmsNorm = torch.fx.wrap(ascend_op.RmsNorm.get_singleton()) - # @register_ascend_pattern class FuseBmmTransposeRhsPattern(BackendPatternBase): @@ -123,29 +112,6 @@ def replacement(x1, x2, c1, c2): return BatchMatMul(x1, muls, adj_x1=False, adj_x2=True, keep_dtype=0) -@register_ascend_pattern -class FuseLightLLMRmsNorm(BackendPatternBase): - @staticmethod - def pattern(arg0_1, arg1_1): - const = Const([2], torch.float32) - pow_1 = Pow(arg0_1, const) - reduce_mean_d = ReduceMeanD(pow_1, [-1], True, False) - adds = Adds(reduce_mean_d, 0.001) - rsqrt = Rsqrt(adds) - zeros_like = ZerosLike(adds) - div = Div(zeros_like, zeros_like) - less = Less(adds, zeros_like) - select = Select(less, div, rsqrt) - mul = Mul(arg0_1, select) - mul_1 = Mul(mul, arg1_1) - return mul_1 - - @staticmethod - def replacement(arg0_1, arg1_1): - rms_norm = RmsNorm(arg0_1, arg1_1, 0.001) - return Identity(rms_norm, 0) - - # @pandaoxin negotiate with @tangzhiyi # another submit would implement # @register_ascend_pattern From bc35a28f388ab949a632749f898e2655d6ea18e6 Mon Sep 17 00:00:00 2001 From: Reinerzhou <1768552509@qq.com> Date: Tue, 23 Apr 2024 10:47:19 +0000 Subject: [PATCH 17/17] remove gather temporarily. --- dicp/dicp/vendor/AscendGraph/ascend_op.py | 13 ------------- dicp/dicp/vendor/AscendGraph/codegen/ascend.py | 7 ------- 2 files changed, 20 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/ascend_op.py b/dicp/dicp/vendor/AscendGraph/ascend_op.py index e078a071c..201187101 100644 --- a/dicp/dicp/vendor/AscendGraph/ascend_op.py +++ b/dicp/dicp/vendor/AscendGraph/ascend_op.py @@ -1103,19 +1103,6 @@ def infer_result(self, x, indices, updates): return torch.empty(x_shape, dtype=x_dtype, memory_format=get_memory_format(x)) -class Gather(Operator): - def __init__(self): - super().__init__("Gather") - - def infer_result(self, x, index): - x, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x) - idx, idx_shape, idx_dim, idx_dtype = get_fake_tensor_meta_val(index) - # compute idx_shape for some special cases. - idx_shape = list(idx_shape) - idx_shape.append(x_shape[-1]) - return torch.empty(idx_shape, dtype=x_dtype, memory_format=get_memory_format(x)) - - class ExpandDims(Operator): def __init__(self): super().__init__("ExpandDims") diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index 918a558b9..a7930c4b7 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -1634,13 +1634,6 @@ def IncreFlashAttention(name, q, k_list, v_list, kv_input_num, head_num, kv_head op.set_attr_str("input_layout", input_layout) return op.to_node() - @staticmethod - def Gather(name, x, indices): - gather_op = OP(name, "Gather") - gather_op.set_input("x", x) - gather_op.set_input("indices", indices) - return gather_op.to_node() - @staticmethod def ExpandDims(name, x, axis): gather_op = OP(name, "ExpandDims")