diff --git a/dicp/dicp/dynamo_bridge/compile_fx.py b/dicp/dicp/dynamo_bridge/compile_fx.py index 5eaee327c..42465d27d 100644 --- a/dicp/dicp/dynamo_bridge/compile_fx.py +++ b/dicp/dicp/dynamo_bridge/compile_fx.py @@ -208,7 +208,8 @@ def compile_fx_210( def fw_compiler_base(model: torch.fx.GraphModule, example_inputs, is_inference): if is_inference: # partition_fn won't be called - joint_graph_passes(model) + # joint_graph_passes(model) + pass fixed = len(example_inputs) - num_example_inputs return inner_compile( diff --git a/dicp/dicp/dynamo_bridge/decompositions.py b/dicp/dicp/dynamo_bridge/decompositions.py new file mode 100644 index 000000000..f821abbcc --- /dev/null +++ b/dicp/dicp/dynamo_bridge/decompositions.py @@ -0,0 +1,38 @@ +from collections import defaultdict +from typing import Callable, Dict, Sequence, Union + +import torch +from torch._decomp import register_decomposition +from torch._ops import OpOverload, OpOverloadPacket + +dicp_decomposition_table = {} +aten = torch.ops.aten + + +def register_decomposition_for_dicp(fn): + return register_decomposition(fn, registry=dicp_decomposition_table) + + +@register_decomposition_for_dicp(aten.count_nonzero.default) +def count_nonzero_default(x, dim=None): + cond = x != 0 + dim = [] if dim is None else dim + return aten.sum.dim_IntList(cond, dim=dim, keepdim=False, dtype=torch.int64) + + +def get_decompositions( + aten_ops: Sequence[Union[OpOverload, OpOverloadPacket]], + target_decomposition_table: Dict[OpOverload, Callable] = None, +) -> Dict[OpOverload, Callable]: + registry = dicp_decomposition_table + packets_to_overloads = defaultdict(list) + for opo in registry: + packets_to_overloads[opo.overloadpacket].append(opo) + decompositions = target_decomposition_table if target_decomposition_table else {} + for op in aten_ops: + if isinstance(op, OpOverloadPacket) and op in packets_to_overloads: + for op_overload in packets_to_overloads[op]: + decompositions[op_overload] = registry[op_overload] + elif 
isinstance(op, OpOverload) and op in registry: + decompositions[op] = registry[op] + return decompositions diff --git a/dicp/dicp/vendor/AscendGraph/ascend_op.py b/dicp/dicp/vendor/AscendGraph/ascend_op.py index 9d129e638..201187101 100644 --- a/dicp/dicp/vendor/AscendGraph/ascend_op.py +++ b/dicp/dicp/vendor/AscendGraph/ascend_op.py @@ -14,6 +14,7 @@ aten = torch.ops.aten + def negative_in_shape(shape): for elem in shape: if elem < 0: @@ -43,12 +44,12 @@ def __init__(self): def infer_result(self, x, shape): x, x_shape, _, x_dtype = get_fake_tensor_meta_val(x) - if isinstance(shape, torch._subclasses.fake_tensor.FakeTensor): # case1: shape is a fakeTensor, like conversion for 'scatter' and 'where' + if isinstance(shape, torch._subclasses.fake_tensor.FakeTensor): # case1: shape is a fakeTensor, like conversion for 'scatter' and 'where' shape, shape_shape, _, _ = get_fake_tensor_meta_val(shape) shape = shape_shape - elif isinstance(shape, Tuple): # case2: shape is tuple from 'Const' , like conversion for 'lt' - shape, _, _, _ =get_op_const_arg_kwarg(shape) - else: # other cases, unsupported yet + elif isinstance(shape, Tuple): # case2: shape is tuple from 'Const' , like conversion for 'lt' + shape, _, _, _ = get_op_const_arg_kwarg(shape) + else: # other cases, unsupported yet assert False, self.__class__.__name__ + "unsupported 'shape' input type!" 
out_shape = get_broadcast_res_two_shape(x_shape, shape) @@ -97,7 +98,7 @@ def __init__(self): class MatMul(Operator): def __init__(self): super().__init__("MatMul") - + def infer_result(self, x1, x2, adj_x1=False, adj_x2=False): attr = acl.op.create_attr() check_ret("acl.op.set_attr_bool", acl.op.set_attr_bool(attr, "transpose_x1", adj_x1)) @@ -290,6 +291,14 @@ def infer_result(self, x, dims, keepdim): return reduce_op_infer(x, dims, keepdim) +class ReduceSum(Operator): + def __init__(self): + super().__init__("ReduceSum") + + def infer_result(self, x, dims, keepdim): + return reduce_op_infer(x, dims, keepdim) + + class Unsqueeze(Operator): def __init__(self): super().__init__("Unsqueeze") @@ -628,7 +637,7 @@ def infer_result(self, x, index, orig_index): # assume not none index, and replace prefix x_shape dims len_idx_shape = len(orig_index) - assert(len_idx_shape > 0) + assert (len_idx_shape > 0) bcast_index_shape = list(orig_index[0].shape) x_shape = bcast_index_shape + list(x_shape[len_idx_shape:]) return torch.empty(x_shape, dtype=x_dtype, memory_format=get_memory_format(x)) @@ -962,6 +971,14 @@ def infer_result(self, x1, x2): return common_binary_op_infer(x1, x2, torch.bool) +class LogicalNot(Operator): + def __init__(self): + super().__init__("LogicalNot") + + def infer_result(self, x): + return common_binary_op_infer(x, torch.bool) + + class Tril(Operator): def __init__(self): super().__init__("Tril") @@ -1023,7 +1040,7 @@ def infer_result( output_batch_var = torch.empty( [channel_size], dtype=torch.float32, memory_format=torch.contiguous_format ) - return [output_y,output_mean,output_var,output_batch_mean,output_batch_var] + return [output_y, output_mean, output_var, output_batch_mean, output_batch_var] class TileWithAxis(Operator): @@ -1032,6 +1049,38 @@ def __init__(self): self.torch_op = aten.repeat_interleave.self_int +class RotaryMul(Operator): + def __init__(self): + super().__init__("RotaryMul") + + def infer_result(self, x, cos, sin): + return 
torch.empty_like(x) + + +class RmsNorm(Operator): + def __init__(self): + super().__init__("RmsNorm") + + def infer_result(self, x, weight, eps): + return torch.empty_like(x) + + +class PromptFlashAttention(Operator): + def __init__(self): + super().__init__("PromptFlashAttention") + + def infer_result(self, q, k, v, num_head, seqlen, mask, head_dim): + return torch.empty_like(q) + + +class IncreFlashAttention(Operator): + def __init__(self): + super().__init__("IncreFlashAttention") + + def infer_result(self, q, k, v, head_num): + return torch.empty_like(q) + + class TensorScatterUpdate(Operator): def __init__(self): super().__init__("TensorScatterUpdate") @@ -1054,6 +1103,38 @@ def infer_result(self, x, indices, updates): return torch.empty(x_shape, dtype=x_dtype, memory_format=get_memory_format(x)) +class ExpandDims(Operator): + def __init__(self): + super().__init__("ExpandDims") + + def infer_result(self, x, axis): + return torch.unsqueeze(x, axis) + + +class MaskedScatter(Operator): + def __init__(self): + super().__init__("MaskedScatter") + + def infer_result(self, x, mask, updates): + return x + + +class ViewCopy(Operator): + def __init__(self): + super().__init__("ViewCopy") + + def infer_result(self, dst, dst_size, dst_stride, dst_storage_offset, src, src_size, src_stride, src_storage_offset): + return dst + + +class ScatterNdUpdate(Operator): + def __init__(self): + super().__init__("ScatterNdUpdate") + + def infer_result(self, x, indices, updates): + return x + + def ret_triple(a, b, c) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return a, b, c diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index 9b9fc24f4..a7930c4b7 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -1,9 +1,10 @@ import json import os -import uuid +import math import torch from typing import Any, List from torch.fx.node import Node +from 
torch.utils._pytree import tree_map_only from torch._inductor.utils import IndentedBuffer from dicp.dynamo_bridge.utils import symint_in_shape from dicp.vendor.AscendGraph.codegen.utils import ( @@ -12,6 +13,7 @@ get_ascend_dtype_num ) + graph_id = 0 precision_check = bool(os.environ.get("DICP_ASCEND_PRECISION_CHECK", False)) @@ -91,7 +93,7 @@ def placeholder(self, name, target, args, kwargs): for idx, dim in enumerate(fake_tensor.shape): if isinstance(dim, torch.SymInt): st = dim.node.str() - if not st in self.sym_in_args: + if st not in self.sym_in_args: self.sym_in_args[st] = (name, idx) # deal with dynamic shape -1 @@ -278,7 +280,7 @@ def gen_call_func(self): # dynamic shape feature if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0: - args = ['_' if not arg in shape_symint and not arg in self.sym_to_inputs.values() else arg for arg in self.args] + args = ['_' if arg not in shape_symint and arg not in self.sym_to_inputs.values() else arg for arg in self.args] call_body.writeline(f"({','.join(args)}) = args") # generate input dims @@ -296,14 +298,14 @@ def gen_call_func(self): dims = dims[:-1] + '}' call_body.writeline(dims) else: - call_body.writeline(f'''dims = None''') + call_body.writeline('''dims = None''') # generate output shapes # dynamic shape feature extra_stride_str = '' extra_storage_offset_str = '' if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0: - shape_str = f'''output_shape = [''' + shape_str = '''output_shape = [''' for elem in self.output_args: if hasattr(elem, 'meta'): elem = elem.meta['val'] @@ -329,11 +331,11 @@ def gen_call_func(self): stride = [self.process_sym_name(str(dim)) for dim in stride] extra_stride_str += '[' + ','.join(map(str, stride)) + '],' extra_storage_offset_str += str(self.input_args[elem[1]].meta['val'].storage_offset()) + ',' - shape_str = shape_str[:-1] + f''']''' + shape_str = shape_str[:-1] + ''']''' call_body.writeline(shape_str) else: call_body.writeline('''output_shape = None''') - + # add 
stride & storage_offset info out_strides = [] out_storage_offsets = [] @@ -344,7 +346,7 @@ def gen_call_func(self): out_strides.append('[1]') out_storage_offsets.append('0') continue - if elem.dim()==0: # temporary solution for sum.default(a) whose result is a scalar(no dim no stride) + if elem.dim() == 0: # temporary solution for sum.default(a) whose result is a scalar(no dim no stride) out_strides.append('[1]') out_storage_offsets.append('0') continue @@ -375,11 +377,11 @@ def gen_call_func(self): if precision_check and self.aten_graph is not None: # import aten graph - call_str.append(f"import sys") + call_str.append("import sys") call_str.append(f"if '{self.folder}' not in sys.path:") call_str.append(f" sys.path.insert(0, '{self.folder}')") call_str.append(f"from {self.graph_key[:4]} import {self.graph_key} as graph_module") - call_str.append(f"aten_call = graph_module()") + call_str.append("aten_call = graph_module()") call_str.append('aten_args = list(map(lambda x: x.to("cpu"), args))') call_str.append('for idx in modified:') @@ -689,11 +691,7 @@ class AscendOverrides: def gen_args(op_var, args_dict, args): src_code = IndentedBuffer() args_str = [op_var] - for i in range(len(args)): - if isinstance(args[i], Node): - args_str.append(args_dict[args[i].name]) - else: - args_str.append(args[i]) + args_str.extend(tree_map_only(Node, lambda x: args_dict[x.name], args)) return src_code, args_str @staticmethod @@ -956,6 +954,14 @@ def SoftmaxV2(name, x, dim): op.set_attr_list_int("axes", dim) return op.to_node() + @staticmethod + def ReduceSum(name, x, axes, keep_dims): + op = OP(name, "ReduceSum") + op.set_input("x", x) + op.set_input("axes", axes) + op.set_attr_bool("keep_dims", keep_dims) + return op.to_node() + @staticmethod def ReduceSumD(name, x, axes, keep_dims): op = OP(name, "ReduceSumD") @@ -1376,7 +1382,7 @@ def Pack(name, x, axis): x_name = [] for elem in x: if elem is not None: - x_name.append(elem.name) + x_name.append(elem) op = OP(name, "Pack") 
op.set_dynamic_input("x", len(x_name), x_name) @@ -1522,7 +1528,7 @@ def DropOutDoMaskV3(name, x, mask, keep_prob): op.set_input("mask", mask) op.set_input("keep_prob", keep_prob) return op.to_node() - + @staticmethod def GatherElements(name, x, index, dim): op = OP(name, "GatherElements") @@ -1530,20 +1536,6 @@ def GatherElements(name, x, index, dim): op.set_input("index", index) op.set_attr_int("dim", dim) return op.to_node() - - @staticmethod - def AdaptiveAvgPool2D(name, x, output_size): - op = OP(name, "AdaptiveAvgPool2d") - op.set_input("x", x) - op.set_attr_list_int("output_size", output_size) - return op.to_node() - - @staticmethod - def AdaptiveAvgPool2DGrad(name, input_grad, orig_input_shape): - op = OP(name, "AdaptiveAvgPool2dGrad") - op.set_input("input_grad", input_grad) - op.set_attr_list_int("orig_input_shape", orig_input_shape) - return op.to_node() @staticmethod def AdaptiveAvgPool2D(name, x, output_size): @@ -1580,6 +1572,12 @@ def LogicalOr(name, x, y): op.set_input("x2", y) return op.to_node() + @staticmethod + def LogicalNot(name, x): + op = OP(name, "LogicalNot") + op.set_input("x", x) + return op.to_node() + @staticmethod def TileWithAxis(name, x, axis, tiles): op = OP(name, "TileWithAxis") @@ -1595,3 +1593,79 @@ def TensorScatterUpdate(name, x, indices, updates): op.set_input("indices", indices) op.set_input("updates", updates) return op.to_node() + + @staticmethod + def RotaryMul(name, x, cos, sin): + op = OP(name, "RotaryMul") + op.set_input("x", x) + op.set_input("r1", cos) + op.set_input("r2", sin) + return op.to_node() + + @staticmethod + def RmsNorm(name, x, weight, eps): + op = OP(name, "RmsNorm") + op.set_input("x", x) + op.set_input("gamma", weight) + op.set_attr_float("epsilon", float(eps)) + return op.to_node() + + @staticmethod + def PromptFlashAttention(name, q, k, v, head_num, seqlen, mask, head_dim): + op = OP(name, "PromptFlashAttention") + op.set_input("query", q) + op.set_input("key", k) + op.set_input("value", v) + 
op.set_input("atten_mask", mask) + op.set_attr_int("num_heads", head_num) + op.set_attr_float("scale_value", float(1 / math.sqrt(head_dim))) + op.set_attr_str("input_layout", "BSH") + return op.to_node() + + @staticmethod + def IncreFlashAttention(name, q, k_list, v_list, kv_input_num, head_num, kv_head_num, dim, input_layout="BSH"): + op = OP(name, "IncreFlashAttention") + op.set_input("query", q) + op.set_dynamic_input("key", kv_input_num, k_list) + op.set_dynamic_input("value", kv_input_num, v_list) + op.set_attr_int("num_heads", head_num) + op.set_attr_float("scale_value", float(1 / math.sqrt(dim))) + op.set_attr_int("num_key_value_heads", kv_head_num) + op.set_attr_str("input_layout", input_layout) + return op.to_node() + + @staticmethod + def ExpandDims(name, x, axis): + gather_op = OP(name, "ExpandDims") + gather_op.set_input("x", x) + gather_op.set_input("axis", axis) + return gather_op.to_node() + + @staticmethod + def MaskedScatter(name, x, mask, updates): + op = OP(name, "MaskedScatter") + op.set_input("x", x) + op.set_input("mask", mask) + op.set_input("updates", updates) + return op.to_node() + + @staticmethod + def ViewCopy(name, dst, dst_size, dst_stride, dst_storage_offset, src, src_size, src_stride, src_storage_offset): + op = OP(name, "ViewCopy") + op.set_input("dst", dst) + op.set_input("dst_size", dst_size) + op.set_input("dst_stride", dst_stride) + op.set_input("dst_storage_offset", dst_storage_offset) + op.set_input("src", src) + op.set_input("src_size", src_size) + op.set_input("src_stride", src_stride) + op.set_input("src_storage_offset", src_storage_offset) + return op.to_node() + + @staticmethod + def ScatterNdUpdate(name, x, indices, updates): + op = OP(name, "ScatterNdUpdate") + op.set_input("var", x) + op.set_input("indices", indices) + op.set_input("updates", updates) + return op.to_node() diff --git a/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h b/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h index 25f3a1629..8cee35353 
100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h +++ b/dicp/dicp/vendor/AscendGraph/codegen/graph_utils.h @@ -32,7 +32,7 @@ using json = nlohmann::json; using namespace ge; static std::unordered_set op_with_dynamic_inputs_outputs = { - "ConcatD", "IdentityN", "Pack", "SplitD"}; + "ConcatD", "IdentityN", "Pack", "SplitD", "IncreFlashAttention"}; void check_op(std::unordered_map& op_map, const std::string& op_name) { @@ -188,6 +188,41 @@ void parseDynamicInput(std::unordered_map& op_map, } } +void parseIncreFlashAttentionDynamicInput( + std::unordered_map& op_map, + op::IncreFlashAttention& op, const json& node) { + if (node.contains("dynamic_inputs")) { + int kv_inputs_num = 0; + for (const auto& i : node["dynamic_inputs"]) { + auto num = i["num"].get(); + auto name = i["name"].get(); + if (name == "key") { + kv_inputs_num = static_cast(num); + op.create_dynamic_input_byindex_key(num, 1); + for (const auto& item : i["value"]) { + auto index = item["index"].get(); + auto value = op_map[item["value"].get()]; + op.set_dynamic_input_key(index, value); + } + } else if (name == "value") { + if (kv_inputs_num == 0 && num == kv_inputs_num) { + throw std::runtime_error( + "need first set dynamic key input for IncreFlashAttention Op" + "and kv_inputs_num == num !!"); + } + op.create_dynamic_input_byindex_value(num, 1 + num); + for (const auto& item : i["value"]) { + auto index = item["index"].get(); + auto value = op_map[item["value"].get()]; + op.set_dynamic_input_value(index, value); + } + } else { + throw std::runtime_error("invalid dynamic input name"); + } + } + } +} + template void parseDynamicOutput(T& op, const json& node) { if (node.contains("dynamic_outputs")) { @@ -220,6 +255,10 @@ ge::Operator genDynamicOperator( auto op = genDynamicOp(op_name); parseDynamicInput(op_map, op, node); return op; + } else if (op_type == "IncreFlashAttention") { + auto op = genDynamicOp(op_name); + parseIncreFlashAttentionDynamicInput(op_map, op, node); + return op; } 
else if (op_type == "SplitD") { auto op = genDynamicOp(op_name); parseDynamicOutput(op, node); diff --git a/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py b/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py index cdf31ddf6..12c7bb193 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py @@ -65,6 +65,7 @@ ACL_HBM_MEM_P2P_HUGE = 8 ACL_HBM_MEM_P2P_NORMAL = 9 + def get_np_dtype(dtype): if dtype == ACL_FLOAT: return np.float32 @@ -121,7 +122,7 @@ def __init__(self): def init_work_weight_ptr(self): if self.work_ptr is None: - self.work_size = 18 * 1024 * 1024 * 1024 + self.work_size = 26 * 1024 * 1024 * 1024 self.work_ptr, ret = acl.rt.malloc(self.work_size, ACL_MEM_MALLOC_HUGE_FIRST) check_ret("acl.rt.malloc", ret) @@ -192,7 +193,7 @@ def load_model(self): memory_pool.release_memory() print("Adjust memory pool allocation.") memory_pool.work_ptr, ret = acl.rt.malloc(work_size, - ACL_MEM_MALLOC_HUGE_FIRST) + ACL_MEM_MALLOC_HUGE_FIRST) check_ret("acl.rt.malloc", ret) self.weight_ptr, ret = acl.rt.malloc(weight_size, @@ -264,7 +265,6 @@ def init_resource(self): _, ret = acl.mdl.add_dataset_buffer(self.output_dataset, data_buf) check_ret("acl.add_dataset_buffer", ret) - @record_function('load_and_run_prepare_input') def _prepare_input(self, images, dims): assert self.num_inputs == len(images) @@ -348,6 +348,7 @@ def run(self, images, dims=None, output_shape=None, allocated_output_tensor = {} for output_index, input_index in allocated_output.items(): allocated_output_tensor[output_index] = input[input_index] + self._prepare_input(input, dims) output = [] if output_shape: diff --git a/dicp/dicp/vendor/AscendGraph/codegen/utils.py b/dicp/dicp/vendor/AscendGraph/codegen/utils.py index 88b13357e..404d827f9 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/utils.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/utils.py @@ -56,7 +56,7 @@ def get_acl_format(x) -> int: return AclFormat.ACL_FORMAT_ND.value 
-def get_acl_dtype(dtype: torch.dtype) ->int: +def get_acl_dtype(dtype: torch.dtype) -> int: if dtype == torch.bool: return AclDataType.ACL_BOOL.value elif dtype == torch.int64: @@ -162,6 +162,7 @@ def get_cpp_dtype(dtype: torch.dtype) -> str: return "INT32" elif dtype == torch.float16: return "FLOAT16" + elif dtype == torch.bool: + return "BOOL" else: raise RuntimeError(f"unknow torch data type ({dtype}) in get_cpp_dtype!") - diff --git a/dicp/dicp/vendor/AscendGraph/compile_job.py b/dicp/dicp/vendor/AscendGraph/compile_job.py index 8c35fef4e..f9720bc31 100644 --- a/dicp/dicp/vendor/AscendGraph/compile_job.py +++ b/dicp/dicp/vendor/AscendGraph/compile_job.py @@ -26,7 +26,7 @@ def __init__(self, source_code) -> None: source_code.strip(), "json", extra=cpp_compile_command("i", "o", vec_isa=picked_vec_isa) + - 'local_rank' + str(self._local_rank) + code_hash(compile_file_code) + 'local_rank' + str(self._local_rank) + code_hash(compile_file_code) ) self._output_graph_path = self._input_path[:-5] + '/graph' print('output_path: ', self._output_graph_path) diff --git a/dicp/dicp/vendor/AscendGraph/config.py b/dicp/dicp/vendor/AscendGraph/config.py index d4bb64fa3..f89495ba4 100644 --- a/dicp/dicp/vendor/AscendGraph/config.py +++ b/dicp/dicp/vendor/AscendGraph/config.py @@ -1,14 +1,15 @@ import torch -from torch._decomp import get_decompositions - - -aten = torch.ops.aten -decomp_keys = [] +from dicp.dynamo_bridge.decompositions import get_decompositions def get_decomp(): - return get_decompositions(decomp_keys) + aten = torch.ops.aten + return get_decompositions( + [ + aten.count_nonzero.default, + ] + ) decomp = get_decomp() diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 11e72be2e..a56674fba 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -16,12 +16,11 @@ import dicp.vendor.AscendGraph.ascend_op as ascend_op from dicp.dynamo_bridge.utils import 
symint_in_shape from dicp.vendor.AscendGraph.codegen.utils import ( - get_ascend_dtype, - get_cpp_dtype + get_ascend_dtype ) from dicp.dynamo_bridge.conversion import register_conversion_impl from dicp.dynamo_bridge.op_transformer import SingleOpTransformer - +from dicp.vendor.AscendGraph import ext_ops aten = torch.ops.aten prims = torch.ops.prims @@ -51,7 +50,7 @@ def try_to_get_dtype(x): def is_dicp_cpp_support_dtype(dtype): - if dtype in [torch.float32, torch.float, torch.float16, torch.int32, torch.int64]: + if dtype in [torch.float32, torch.float, torch.float16, torch.int32, torch.int64, torch.bool]: return True return False @@ -134,14 +133,14 @@ def generate_sym_int(elem): # concat all ops return self.get_proxy(ascend_op.ConcatD, (x_names, 0)) - def get_shape_proxy(self, shape): + def get_shape_proxy(self, shape, dtype=torch.int32): if isinstance(shape, torch.fx.proxy.Proxy) or isinstance(shape, FakeTensor): return shape elif isinstance(shape, list) and symint_in_shape(shape): return self.process_dynamic_shape(shape) else: return self.get_proxy( - ascend_op.Const, (shape, torch.int32, [len(shape)])) + ascend_op.Const, (shape, dtype, [len(shape)])) def get_const_proxy(self, param, dtype, format=None, target_shape=None): if not isinstance(param, torch.fx.proxy.Proxy) and not isinstance(param, FakeTensor): @@ -240,12 +239,12 @@ def mul(self, x, y): def add(self, x, y, alpha: Optional[Number] = 1): out_dtype = fx_traceback.get_current_meta()['val'].dtype if not isinstance(y, torch.fx.proxy.Proxy): - y = y * alpha + y = y * alpha if alpha != 1 else y if out_dtype in [torch.float, torch.float16]: return self.get_proxy(ascend_op.Adds, (x, float(y)), {}) y = self.get_const_proxy(y, out_dtype) else: - y = self.mul(y, alpha) + y = self.mul(y, alpha) if alpha != 1 else y x, y = self.promote_dtype(x, y, target_dtype=out_dtype) return self.get_proxy(ascend_op.AddV2, (x, y), {}) @@ -318,7 +317,7 @@ def div(self, x, y): @register_conversion(aten.split.Tensor) def 
split(self, x, split_size, dim=0): - splitD_kw = { "from_view_complex": False } + splitD_kw = {"from_view_complex": False} shape = list(x.node.meta['val'].shape) if dim < 0: dim += len(shape) @@ -331,6 +330,7 @@ def slice(self, x, dim=0, start=None, end=None, step=1): # TODO(tangzhiyi): miss step parameter x_shape = list(x.node.meta['val'].shape) y_shape = list(fx_traceback.get_current_meta()['val'].shape) + # y_shape = fx_traceback.get_current_meta()['val'].shape dim = int(dim) start = int(start) if start is not None else 0 start = start if start >= 0 else x_shape[dim] + start @@ -338,6 +338,8 @@ def slice(self, x, dim=0, start=None, end=None, step=1): assert start is None or start >= 0 and start < x_shape[dim] offset = [0] * len(x_shape) offset[dim] = start + # import pdb; pdb.set_trace() + offset = self.get_shape_proxy(offset) size = self.get_shape_proxy(y_shape) return self.get_proxy(ascend_op.Slice, (x, offset, size)) @@ -358,20 +360,25 @@ def NewEmptyStrided(self, x, size, stride, dtype=torch.float32, layout=torch.str return self.empty_like(x) @register_conversion(aten.empty) - def empty(self, size, dtype=torch.int64, layout=torch.strided, device='cpu', memory_format=torch.contiguous_format): + def empty(self, size, dtype=torch.int64, layout=torch.strided, device='cpu', memory_format=torch.contiguous_format, pin_memory=False): shape_op = self.get_proxy( ascend_op.Const, (size, torch.int32, [len(size)])) return self.get_proxy(ascend_op.Empty, (shape_op, dtype, layout, device, memory_format)) @register_conversion(aten.empty_like.default) - def empty_like(self, x, dtype=torch.float32, layout=torch.strided, - device='cpu', pin_memory=False, memory_format=torch.preserve_format): - dtype = x.node.meta['val'].dtype - shape = list(x.node.meta['val'].shape) - shape_op = self.get_proxy( - ascend_op.Const, (shape, torch.int32, [len(shape)])) - new_memory_format=x.node.meta['tensor_meta'].memory_format if memory_format is torch.preserve_format else memory_format - return 
self.get_proxy(ascend_op.Empty, (shape_op, dtype, layout, device, new_memory_format)) + def empty_like(self, x, dtype=None, layout=None, + device=None, pin_memory=None, memory_format=None): + if dtype is None: + dtype = x.node.meta['val'].dtype + if layout is not None and (layout != torch.strided): + raise NotImplementedError("torch.ops.aten.empty_like.default is " + "only supported on dense tensor now.") + if memory_format is not None and memory_format != torch.contiguous_format \ + and memory_format != torch.preserve_format: + raise NotImplementedError("torch.ops.aten.empty_like.default is only supported " + "contiguous_format and preserve_format now.") + shape = self.get_proxy(ascend_op.Shape, (x,)) + return self.get_proxy(ascend_op.Empty, (shape, dtype)) @register_conversion(aten.select.int) def select(self, x, dim, index): @@ -455,7 +462,7 @@ def view(self, x, size): return self.get_proxy(ascend_op.IdentityN, (real_reshape, imag_reshape)) else: return self.get_proxy(ascend_op.Reshape, (x, shape)) - + @register_conversion(torch.ops.aten.where) def where(self, condition, x1, x2): # TODO(tangzhiyi): need to process scalars @@ -473,9 +480,9 @@ def arange(self, end, start=0, step=1, dtype=None, device='xpu', layout=None, pi assert isinstance(end, torch.fx.proxy.Proxy) or type(end) in [int, float] assert isinstance(step, torch.fx.proxy.Proxy) or type(step) in [int, float] - if not isinstance(start, torch.fx.proxy.Proxy): # scalar const + if not isinstance(start, torch.fx.proxy.Proxy): # scalar const start = self.get_const_proxy(start, out_dtype) - elif start.node.meta['val'] != out_dtype: # align tensor dtype + elif start.node.meta['val'] != out_dtype: # align tensor dtype start = self.get_proxy(ascend_op.Cast, (start, get_ascend_dtype(out_dtype)), {}) if not isinstance(end, torch.fx.proxy.Proxy): end = self.get_const_proxy(end, out_dtype) @@ -597,7 +604,7 @@ def view_as_complex(self, x): assert x_val.dtype == torch.float32 assert x_shape[-1] == 2 dim = len(x_shape) 
- 1 - splitD_kw = { "from_view_complex": True } + splitD_kw = {"from_view_complex": True} return self.get_proxy(ascend_op.SplitD, (x, dim, 2, 2), splitD_kw) @register_conversion(torch.ops.aten.full.default) @@ -868,9 +875,20 @@ def index_put_default(self, x, indices, values): # tf.tensor_scatter_nd_update param 'indices' is different from # indices in torch.ops.aten.index_put.default, we use broadcast and # stack to construct param 'indices' in tf.tensor_scatter_nd_update + x_shape = list(x.node.meta['val'].shape) + index = indices[0] + if len(indices) == 1 and ('val' in index.node.meta.keys()) and index.node.meta['val'].dtype == torch.bool: + index_shape = list(index.node.meta['val'].shape) + if len(index_shape) == len(x_shape): + return self.masked_fill(x, index, values) + reshape_shape = index_shape + [1] * (len(x_shape) - len(index_shape)) + reshape_op = self.get_const_proxy(reshape_shape, torch.int32) + index = self.get_proxy(ascend_op.Reshape, (index, reshape_op)) + return self.masked_fill(x, index, values) + stacked_indices, indices_broadcast_shape, stacked_indices_last_dim = \ self.compute_stacked_indices(indices, x.node.meta['val'].shape) - values_broadcast_shape = indices_broadcast_shape + x_shape[stacked_indices_last_dim:] # batch_shape + inner_shape + values_broadcast_shape = indices_broadcast_shape + x_shape[stacked_indices_last_dim:] # batch_shape + inner_shape values_broadcast_shape_op = self.get_const_proxy(values_broadcast_shape, torch.int32) broadcasted_values = self.get_proxy(ascend_op.BroadcastTo, (values, values_broadcast_shape_op)) return self.get_proxy(ascend_op.TensorScatterUpdate, (x, stacked_indices, broadcasted_values)) @@ -1230,7 +1248,7 @@ def repeat_interleave(self, x, repeats, dim): transpose_perm_proxy = self.get_shape_proxy(transpose_perm) transpose_proxy = self.get_proxy(ascend_op.Transpose, (reshape_proxy, transpose_perm_proxy)) - result_reshape = x_shape[:dim] + [x_shape[dim] * repeats] + x_shape[dim+1:] + result_reshape = 
x_shape[:dim] + [x_shape[dim] * repeats] + x_shape[dim + 1:] result_reshape_shape_proxy = self.get_const_proxy(result_reshape, torch.int32) return self.get_proxy(ascend_op.Reshape, (transpose_proxy, result_reshape_shape_proxy)) @@ -1304,14 +1322,15 @@ def sum(self, a): return self.sumdim(a) @register_conversion(torch.ops.aten.sum.dim_IntList) - def sumdim(self, x, dims=[], keepdim=False, dtype=None): + def sumdim(self, x, dim=[], keepdim=False, dtype=None): x_dtype = x.node.meta['val'].dtype - if not isinstance(dims, list): - dims = [dims] - if dtype is None or x_dtype == dtype: - return self.get_proxy(ascend_op.ReduceSumD, (x, dims, keepdim)) - sum = self.get_proxy(ascend_op.ReduceSumD, (x, dims, keepdim)) - return self.get_proxy(ascend_op.Cast, (sum, get_ascend_dtype(dtype))) + x_shape = x.node.meta['val'].shape + if len(dim) == 0: + dim = list(range(len(x_shape))) + axes = self.get_const_proxy(dim, torch.int32) + if dtype and x_dtype != dtype: + x = self.get_proxy(ascend_op.Cast, (x, get_ascend_dtype(dtype))) + return self.get_proxy(ascend_op.ReduceSum, (x, axes, keepdim)) @register_conversion(torch.ops.aten.amax) def amax(self, x, dims, keepdim=False): @@ -1420,34 +1439,17 @@ def NativeDropoutBackward(self, grad_output, mask, scale): p = 1. 
- scale prob_op = self.get_const_proxy(float(p), dtype) return self.get_proxy(ascend_op.DropOutDoMaskV3, (grad_output, mask, prob_op)) - - @register_conversion([torch.ops.aten._adaptive_avg_pool2d.default]) - def adaptiveavgpool2d(self, x, output_size): - assert isinstance(output_size, int) or ( len(output_size) in range(1,3) and any(output_size) ) - if not isinstance(output_size, list): - if isinstance(output_size, tuple): - output_size = list(output_size) - elif isinstance(output_size, int): - output_size = [output_size, output_size] - else: - raise RuntimeError("not supported output type!") - return self.get_proxy(ascend_op.AdaptiveAvgPool2D, (x, output_size)) - - @register_conversion([torch.ops.aten._adaptive_avg_pool2d_backward.default]) - def adaptiveavgpool2dBackward(self, grad, input): - input_shape = list(input.node.meta['val'].shape) - return self.get_proxy(ascend_op.AdaptiveAvgPool2DGrad, (grad, input_shape)) @register_conversion([torch.ops.aten._adaptive_avg_pool2d.default]) def adaptiveavgpool2d(self, x, output_size): - assert isinstance(output_size, int) or ( len(output_size) in range(1,3) and any(output_size) ) + assert isinstance(output_size, int) or (len(output_size) in range(1, 3) and any(output_size)) if not isinstance(output_size, list): if isinstance(output_size, tuple): output_size = list(output_size) elif isinstance(output_size, int): output_size = [output_size, output_size] else: - raise RuntimeError("not supported output size!") + raise RuntimeError("not supported output type!") return self.get_proxy(ascend_op.AdaptiveAvgPool2D, (x, output_size)) @register_conversion([torch.ops.aten._adaptive_avg_pool2d_backward.default]) @@ -1475,6 +1477,10 @@ def Ge(self, x, y): def LogicalOr(self, x, y): return self.get_proxy(ascend_op.LogicalOr, (x, y)) + @register_conversion(torch.ops.aten.logical_not.default) + def LogicalNot(self, x): + return self.get_proxy(ascend_op.LogicalNot, (x,)) + @register_conversion(torch.ops.aten.slice_scatter.default) def 
SliceScatter(self, operand, src, dim=0, start=None, end=None, step=1): # modified from torchair @@ -1514,3 +1520,123 @@ def SliceScatter(self, operand, src, dim=0, start=None, end=None, step=1): @register_conversion(torch.ops.aten.scalar_tensor.default) def scalar_tensor(self, x, dtype=None, layout=None, device=None, pin_memory=None): return self.get_const_proxy(x, dtype) + + @register_conversion(torch.ops.lightllm.rotary_emb.default) + def lightllm_rotary_emb(self, x, cos, sin): + x_shape = list(x.node.meta['val'].shape) + assert len(x_shape) == 3 + + seq_len = x_shape[0] + dim = x_shape[2] + + cos_sin_shape = self.get_const_proxy([seq_len, 1, dim // 2], torch.int32) + cos = self.get_proxy(ascend_op.Reshape, (cos, cos_sin_shape)) + sin = self.get_proxy(ascend_op.Reshape, (sin, cos_sin_shape)) + + x = self.get_proxy(ascend_op.Unsqueeze, (x, [0])) + cos = self.get_proxy(ascend_op.Tile, (cos, [1, 1, 1, 2])) + sin = self.get_proxy(ascend_op.Tile, (sin, [1, 1, 1, 2])) + + out = self.get_proxy(ascend_op.RotaryMul, (x, cos, sin)) + return self.get_proxy(ascend_op.Squeeze, (out, [0])) + + @register_conversion(torch.ops.lightllm.rms_norm.default) + def lightllm_rms_norm(self, x, weight, eps): + out = self.get_proxy(ascend_op.RmsNorm, (x, weight, eps)) + return self.get_proxy(ascend_op.Identity, (out, 0)) + + @register_conversion(torch.ops.lightllm.prompt_attention_inference.default) + def prompt_attention_inference(self, q, k, v, seqlen, num_head, head_dim): + q_shape = list(q.node.meta['val'].shape) + seq_len = q_shape[1] + shape = [seq_len, seq_len] + shape = self.get_proxy(ascend_op.Const, (shape, torch.int32, [len(shape)])) + mask = self.get_proxy(ascend_op.Empty, (shape, torch.bool)) + mask = self.get_proxy(ascend_op.OnesLike, (mask,)) + mask = self.get_proxy(ascend_op.Tril, (mask,)) + mask = self.get_proxy(ascend_op.LogicalNot, (mask,)) + fa = self.get_proxy(ascend_op.PromptFlashAttention, (q, k, v, num_head, seqlen, mask, head_dim)) + return fa + + def 
incre_flash_attention(self, q, k, v, head_num, kv_head_num, dim):
+        k_list = []
+        v_list = []
+        if not isinstance(k, list):
+            k_list.append(k)
+        else:
+            k_list = k
+        if not isinstance(v, list):
+            v_list.append(v)
+        else:
+            v_list = v
+        assert len(k_list) == len(v_list)
+        kv_input_num = len(k_list)
+        out = self.get_proxy(ascend_op.IncreFlashAttention, (q, k_list, v_list, kv_input_num, kv_head_num, head_num, dim, "BSH"))
+        return out
+
+    @register_conversion(aten.select_scatter.default)
+    def select_scatter(self, x, src, dim, index):
+        if not isinstance(index, torch.fx.proxy.Proxy):
+            index = self.get_const_proxy(index, torch.int32)
+        input_sizes = self.get_proxy(ascend_op.Shape, (x,))
+        index = self.get_proxy(ascend_op.BroadcastTo, (index, input_sizes))
+        dim_op = self.get_const_proxy(dim, torch.int32)
+        src = self.get_proxy(ascend_op.ExpandDims, (src, dim_op))
+        src = self.get_proxy(ascend_op.BroadcastTo, (src, input_sizes))
+
+        return self.get_proxy(ascend_op.ScatterElements, (x, index, src, dim))
+
+    @register_conversion(torch.ops.lightllm.copy_with_offset.default)
+    def copy_with_offset(self, x, src, start_dim, end_dim):
+        dims = [x for x in range(start_dim, end_dim)]
+        dims = self.get_const_proxy(dims, torch.int32, target_shape=[len(dims), 1])
+        return self.get_proxy(ascend_op.ScatterNdUpdate, (x, dims, src))
+
+    @register_conversion(torch.ops.lightllm.flash_attention_inference.default)
+    def flash_attention_inference(self, q, all_k, all_v, current_len, max_len):
+        q_shape = list(q.node.meta['val'].shape)
+        batch, head, dim = q_shape[0], q_shape[1], q_shape[2]
+        k_shape = list(all_k.node.meta['val'].shape)
+        kvhead = k_shape[1]
+
+        current_lens, res = current_len, []  # keep the length list; don't clobber it in the loop
+        compute_batch = 1
+        select_axis = self.get_const_proxy(0, torch.int32)
+
+        for i in range(batch):
+            current_len = current_lens[i]  # was `current_len[i]`: broke for batch > 1
+            select_index = self.get_const_proxy(i, torch.int32)
+            xq = self.get_proxy(ascend_op.GatherV2, (q, select_index, select_axis))
+
+            kv_start_index = self.get_const_proxy([i * max_len, 0, 0], 
torch.int32) + kv_end_index = self.get_const_proxy([i * max_len + current_len, kvhead, dim], torch.int32) + kv_seq_len = current_len + + kv_gather_shape = self.get_shape_proxy([compute_batch, kv_seq_len, kvhead, dim]) + kv_compute_shape = self.get_shape_proxy([compute_batch, kv_seq_len, kvhead * dim]) + + # fetch k + k = self.get_proxy(ascend_op.Slice, (all_k, kv_start_index, kv_end_index)) + k = self.get_proxy(ascend_op.Reshape, (k, kv_gather_shape)) + k = self.get_proxy(ascend_op.Reshape, (k, kv_compute_shape)) + + # fetch v + v = self.get_proxy(ascend_op.Slice, (all_v, kv_start_index, kv_end_index)) + v = self.get_proxy(ascend_op.Reshape, (v, kv_gather_shape)) + v = self.get_proxy(ascend_op.Reshape, (v, kv_compute_shape)) + + # k,v shape: batch, kv_seq_len, head, dim + q_shape = self.get_shape_proxy([compute_batch, 1, head, dim]) + q_compute_shape = self.get_shape_proxy([compute_batch, 1, head * dim]) + xq = self.get_proxy(ascend_op.Reshape, (xq, q_shape)) + xq = self.get_proxy(ascend_op.Reshape, (xq, q_compute_shape)) + + out = self.incre_flash_attention(xq, k, v, kvhead, head, dim) # q shape is BSH + out_shape = self.get_shape_proxy([compute_batch, 1, head, dim]) + out_shape2 = self.get_shape_proxy([compute_batch, head, dim]) + out = self.get_proxy(ascend_op.Reshape, (out, out_shape)) + out = self.get_proxy(ascend_op.Reshape, (out, out_shape2)) + res.append(out) + + res = self.get_proxy(ascend_op.ConcatD, (res, 0)) + return res diff --git a/dicp/dicp/vendor/AscendGraph/ext_ops.py b/dicp/dicp/vendor/AscendGraph/ext_ops.py new file mode 100644 index 000000000..324d2a9b2 --- /dev/null +++ b/dicp/dicp/vendor/AscendGraph/ext_ops.py @@ -0,0 +1,147 @@ +import math + +import torch +import torch_dipu +import torch.nn.functional as F + +from torch import Tensor +from typing import Sequence + +torch._dynamo.config.suppress_errors = False + + +# rotary_emb +@torch._custom_op.impl.custom_op('lightllm::rotary_emb') +def rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> 
Tensor: + ... + + +@rotary_emb.impl_abstract() +def lightllm_rotary_emb_abstract(x, cos, sin): + return torch.empty_like(x) + + +@rotary_emb.impl(['cpu', 'cuda']) +def lightllm_rotary_emb_impl(x, cos, sin): + seq_len, h, dim = x.shape + cos = cos.view((seq_len, 1, dim // 2)) + sin = sin.view((seq_len, 1, dim // 2)) + x0 = x[:, :, 0: dim // 2] + x1 = x[:, :, dim // 2: dim] + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + return torch.cat((o0, o1), dim=-1) + + +# rms_norm +@torch._custom_op.impl.custom_op('lightllm::rms_norm') +def rms_norm(x: Tensor, weight: Tensor, eps: float) -> Tensor: + ... + + +@rms_norm.impl_abstract() +def lightllm_rms_norm_abstract(x, weight, eps): + return torch.empty_like(x) + + +@rms_norm.impl(['cpu', 'cuda']) +def lightllm_rms_norm_impl(x, weight, eps): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight + + +@torch._custom_op.impl.custom_op('lightllm::prompt_attention_inference') +def prompt_attention_inference(q: Tensor, k: Tensor, v: Tensor, seqlen: Tensor, num_head: int, head_dim: int) -> Tensor: + ... + + +@prompt_attention_inference.impl_abstract() +def lightllm_prompt_attention_inference_abstract(q: Tensor, k: Tensor, v: Tensor, seqlen: Tensor, num_head: int, head_dim: int): + return torch.empty_like(q) + + +@prompt_attention_inference.impl(['cpu', 'cuda']) +def lightllm_prompt_attention_inference_impl(q, k, v, seqlen, num_head, head_dim): + assert q.shape[0] == 1, "prompt attention just support bs=1 for now." 
+ bs = q.shape[0] + seqlen = seqlen.item() + + xq = q.view(bs, seqlen, num_head, head_dim) + xk = k.view(bs, seqlen, num_head, head_dim) + xv = v.view(bs, seqlen, num_head, head_dim) + + mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0) + mask = mask.masked_fill(mask == 0., -999999999999.0) + mask = mask.masked_fill(mask == 1., 0.0) + mask = mask.repeat(bs, num_head, 1, 1) + + keys = xk + values = xv + xq = xq.transpose(1, 2).float() + keys = xk.transpose(1, 2).float() + values = xv.transpose(1, 2).float() + + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim) + scores = F.softmax(scores + mask.float(), dim=-1).type_as(xq) + output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim) + + return output + + +@torch._custom_op.impl.custom_op('lightllm::flash_attention_inference') +def flash_attention_inference(q: Tensor, all_k: Tensor, all_v: Tensor, currnet_lens: Sequence[int], max_len: int) -> Tensor: + ... 
+ + +@flash_attention_inference.impl_abstract() +def lightllm_flash_attention_inference_abstract(q: Tensor, all_k: Tensor, all_v: Tensor, currnet_lens: Sequence[int], max_len: int): + return torch.empty_like(q) + + +@flash_attention_inference.impl(['cpu', 'cuda']) +def lightllm_flash_attention_inference_impl(q, all_k, all_v, current_lens, max_len): + # q: batch, head, dim + batch = q.shape[0] + head = q.shape[1] + dim = q.shape[2] + res = [] + compute_batch = 1 + for i in range(batch): + current_len = current_lens[i] + kv_seq_len = current_len + + k = all_k[:current_len].reshape(compute_batch, kv_seq_len, head, dim) + v = all_v[:current_len].reshape(compute_batch, kv_seq_len, head, dim) + + xq = q[i].view(compute_batch, 1, head, dim).transpose(1, 2).transpose(0, 1) # shape: head, batch, 1, dim + bmm_xq = xq.reshape(head * compute_batch, 1, dim) + bmm_xk = k.transpose(1, 2).transpose(0, 1).transpose(2, 3).reshape(head * compute_batch, dim, kv_seq_len) + + # q @ k + out = torch.bmm(bmm_xq, bmm_xk) / math.sqrt(dim) + out = out.reshape(head, compute_batch, 1, -1).reshape(head, compute_batch, -1) + + # softmax + out = out.softmax(-1).reshape(head, compute_batch, 1, kv_seq_len).transpose(0, 1) # shape: batch head 1 seq_len + xv = v.transpose(1, 2) # shape: batch head, seq_len, dim + out = torch.bmm(out.reshape(compute_batch * head, 1, kv_seq_len), xv.reshape(compute_batch * head, kv_seq_len, dim)) + + out = out.reshape(compute_batch, head, 1, dim).view(compute_batch, head, dim) + res.append(out) + res = torch.cat(res) + return res + + +@torch._custom_op.impl.custom_op('lightllm::copy_with_offset') +def copy_with_offset(x: Tensor, src: Tensor, start_dim: int, end_dim: int) -> Tensor: + ... 
+ + +@copy_with_offset.impl_abstract() +def lightllm_copy_with_offset_abstract(x: Tensor, src: Tensor, start_dim: int, end_dim: int) -> Tensor: + return x + + +@copy_with_offset.impl(['cpu', 'cuda']) +def lightllm_copy_with_offset_impl(x, src, start_dim, end_dim) -> Tensor: + x[start_dim:end_dim] = src + return x diff --git a/dicp/dicp/vendor/AscendGraph/opset_convert.py b/dicp/dicp/vendor/AscendGraph/opset_convert.py index e9a44216b..1db1f8b28 100644 --- a/dicp/dicp/vendor/AscendGraph/opset_convert.py +++ b/dicp/dicp/vendor/AscendGraph/opset_convert.py @@ -1,7 +1,7 @@ import torch import torch_dipu from dicp.dynamo_bridge.compile_fx import is_torch_210 -from dicp.vendor.AscendGraph.ascend_op import MatMul, CastToCpu, IdentityInp +from dicp.vendor.AscendGraph.ascend_op import CastToCpu, IdentityInp from dicp.vendor.AscendGraph.conversion import AtenToAscendTransformer from ...dynamo_bridge.graph import GraphTransformer @@ -43,9 +43,9 @@ def transform(self, gm: torch.fx.graph_module): for n in gm.graph.nodes: if n.op != 'call_function': continue - if type(n.target) == CastToCpu: + if isinstance(n.target, CastToCpu): self.cpu_tensor.append(n.name) - elif type(n.target) == IdentityInp: + elif isinstance(n.target, IdentityInp): if len(n.args) == 2 and n.args[1] is not None and str(n.args[1]) in input_names: self.assign_args.append((n.name, input_names.index(str(n.args[1])))) else: @@ -55,10 +55,10 @@ def transform(self, gm: torch.fx.graph_module): if n.op == 'call_function': prop = {} if n.name in self.cpu_tensor: - prop.update({'cpu_tensor' : n.name}) + prop.update({'cpu_tensor': n.name}) if len(self.assign_args) > 0 and n.name in list(zip(*self.assign_args))[0]: idx = list(zip(*self.assign_args))[0].index(n.name) - prop.update({'assign_args' : (self.assign_args[idx][0], self.assign_args[idx][1])}) + prop.update({'assign_args': (self.assign_args[idx][0], self.assign_args[idx][1])}) n.meta['prop'] = prop return gm @@ -77,6 +77,7 @@ def symint_in_inputs(nodes): return 
True return False + def ascendgraph_opset_convert( gm: torch.fx.GraphModule, ): diff --git a/dicp/dicp/vendor/AscendGraph/pattern_replacement.py b/dicp/dicp/vendor/AscendGraph/pattern_replacement.py index 6e737f085..d2ef21be1 100644 --- a/dicp/dicp/vendor/AscendGraph/pattern_replacement.py +++ b/dicp/dicp/vendor/AscendGraph/pattern_replacement.py @@ -27,16 +27,17 @@ def replacement(input, dims): varVal = torch.ops.aten.var(input, dims, correction=1, keepdim=True) return ascend_op.ret_tuple(varVal, meanVal) + @register_aten_pattern class FusedRepeatInterleaveSelfInt(BackendPatternBase): @staticmethod def pattern(self, repeat, dim, input_shape, empty_device, view_1_shape, expand_1_shape, repeat_interleave_output_size): - empty = torch.ops.aten.empty.memory_format(input_shape, dtype = torch.int64, layout = torch.strided, device=empty_device) + empty = torch.ops.aten.empty.memory_format(input_shape, dtype=torch.int64, layout=torch.strided, device=empty_device) fill = torch.ops.aten.fill.Scalar(empty, repeat) view_1 = torch.ops.aten.view.default(fill, view_1_shape) expand_1 = torch.ops.aten.expand.default(view_1, expand_1_shape) - repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(expand_1, output_size = repeat_interleave_output_size) + repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(expand_1, output_size=repeat_interleave_output_size) index_select = torch.ops.aten.index_select.default(self, dim, repeat_interleave) return index_select @@ -44,6 +45,23 @@ def pattern(self, repeat, dim, input_shape, empty_device, view_1_shape, def replacement(self, repeat, dim): return torch.ops.aten.repeat_interleave.self_int(self, repeat, dim) + +@register_aten_pattern +class ReplaceAtenSliceScatter(BackendPatternBase): + @staticmethod + def pattern(arg0, arg1, start_index, end_index): + slice = torch.ops.aten.slice.Tensor(arg0, 0, start_index, end_index) + copy = torch.ops.aten.copy.default(slice, arg1) + slice_scatter = torch.ops.aten.slice_scatter.default(arg0, 
copy, 0, start_index, end_index) + copy_ = torch.ops.aten.copy_.default(slice_scatter, arg0) + return slice_scatter + + @staticmethod + def replacement(arg0, arg1, start_index, end_index): + slice_scatter = torch.ops.lightllm.copy_with_offset.default(arg0, arg1, start_index, end_index) + return slice_scatter + + Muls = torch.fx.wrap(ascend_op.Muls.get_singleton()) Shape = torch.fx.wrap(ascend_op.Shape.get_singleton()) Const = torch.fx.wrap(ascend_op.Const.get_singleton()) @@ -55,7 +73,7 @@ def replacement(self, repeat, dim): MatMul = torch.fx.wrap(ascend_op.MatMul.get_singleton()) -@register_ascend_pattern +# @register_ascend_pattern class FuseBmmTransposeRhsPattern(BackendPatternBase): @staticmethod def pattern(x1, x2, dtype):