diff --git a/dicp/dicp/dynamo_bridge/compile_fx.py b/dicp/dicp/dynamo_bridge/compile_fx.py index 5eaee327c..42465d27d 100644 --- a/dicp/dicp/dynamo_bridge/compile_fx.py +++ b/dicp/dicp/dynamo_bridge/compile_fx.py @@ -208,7 +208,8 @@ def compile_fx_210( def fw_compiler_base(model: torch.fx.GraphModule, example_inputs, is_inference): if is_inference: # partition_fn won't be called - joint_graph_passes(model) + # joint_graph_passes(model) + pass fixed = len(example_inputs) - num_example_inputs return inner_compile( diff --git a/dicp/dicp/vendor/AscendGraph/ascend_op.py b/dicp/dicp/vendor/AscendGraph/ascend_op.py index cadc8798d..45a2010bc 100644 --- a/dicp/dicp/vendor/AscendGraph/ascend_op.py +++ b/dicp/dicp/vendor/AscendGraph/ascend_op.py @@ -1115,14 +1115,6 @@ def infer_result(self, x, index): return torch.empty(idx_shape, dtype=x_dtype, memory_format=get_memory_format(x)) -class InplaceCopyWithOffset(Operator): - def __init__(self): - super().__init__("InplaceCopyWithOffset") - - def infer_result(self, x, src, dim, offset): - return src - - class ExpandDims(Operator): def __init__(self): super().__init__("ExpandDims") @@ -1144,7 +1136,7 @@ def __init__(self): super().__init__("ViewCopy") def infer_result(self, dst, dst_size, dst_stride, dst_storage_offset, src, src_size, src_stride, src_storage_offset): - return x + return dst class ScatterNdUpdate(Operator): diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index 80fcac26d..112446f93 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -71,7 +71,6 @@ def __init__(self, graph, aten_graph=None, folder=None, graph_key=None): # for modified args return self.assign_args = [] self.cpu_tensor = [] - self.assign_with_offset_args = [] super().__init__(graph) @@ -135,8 +134,6 @@ def call_function(self, name, target, args, kwargs): self.cpu_tensor.append(self.cur_node.meta['prop']['cpu_tensor']) if 'prop' in self.cur_node.meta and 'assign_args' in self.cur_node.meta['prop']: self.assign_args.append(self.cur_node.meta['prop']['assign_args']) - if 'prop' in self.cur_node.meta and 'assign_with_offset_args' in self.cur_node.meta['prop']: - self.assign_with_offset_args.append(self.cur_node.meta['prop']['assign_with_offset_args']) _, args_list = AscendOverrides.gen_args( self.args_dict[name], self.args_dict, args) @@ -199,10 +196,6 @@ def parse_outputs(self): if len(self.assign_args) > 0: self.graph_output_names.extend(list(zip(*self.assign_args))[0]) - current_index = len(self.graph_output_names) - for i in range(len(self.assign_with_offset_args)): - self.assign_with_offset_args[i]['output_index'] = current_index - current_index += 1 def gen_import_code(self): self.import_code.splice( @@ -380,15 +373,7 @@ def gen_call_func(self): output_index = self.graph_output_names.index(item[0]) allocated_output[output_index] = input_index call_body.writeline(f'allocated_output= {allocated_output}') - - allocated_output_with_offset = {} - for item in self.assign_with_offset_args: - # input_index = item['input_index'] - output_index = item['output_index'] - # offset = item['offset'] - allocated_output_with_offset[output_index] = item - call_body.writeline(f'allocated_with_offset_output= {self.assign_with_offset_args}') - call_str = ['output_tensor = kernel_cpp_0(args, dims, output_shape, out_stride, out_storage_offset, allocated_output, allocated_with_offset_output)'] + call_str = ['output_tensor = kernel_cpp_0(args, dims, output_shape, out_stride, out_storage_offset, allocated_output)'] if precision_check and self.aten_graph is not None: # import aten graph @@ -1673,12 +1658,6 @@ def ExpandDims(name, x, axis): gather_op.set_input("axis", axis) return gather_op.to_node() - @staticmethod - def InplaceCopyWithOffset(name, x, src, dim, offset): - op = OP(name, "Identity") - op.set_input("x", src) - return op.to_node() - @staticmethod def MaskedScatter(name, x, mask, updates): op = OP(name, "MaskedScatter") diff --git a/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py b/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py index d8e188c20..8c14e1a44 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py @@ -297,7 +297,7 @@ def _prepare_input(self, images, dims): assert (dataset == self.input_dataset) @record_function('load_and_run_prepare_output') - def _prepare_output(self, output_tensor, output_shape, out_stride, out_storage_offset, allocated_output, allocated_output_with_offset_tensor): + def _prepare_output(self, output_tensor, output_shape, out_stride, out_storage_offset, allocated_output): for i in range(self.num_outputs): if allocated_output and i in allocated_output.keys(): item = allocated_output[i] @@ -339,7 +339,7 @@ def _prepare_dynamic_output(self, output_tensor, output_shape, out_stride, out_s @record_function('load_and_run_run') def run(self, images, dims=None, output_shape=None, out_stride=None, out_storage_offset=None, - allocated_output=None, allocated_with_offset_output=None): + allocated_output=None): with record_function(f'load_and_run_run_{self.model_id}'): assert len(images) > 0 input = [x.to(dipu_device_str) if isinstance(x, torch.Tensor) @@ -350,12 +350,6 @@ def run(self, images, dims=None, output_shape=None, for output_index, input_index in allocated_output.items(): allocated_output_tensor[output_index] = input[input_index] - allocated_output_with_offset_tensor = None - if allocated_with_offset_output: - allocated_output_with_offset_tensor = {} - for item in allocated_with_offset_output: - allocated_output_with_offset_tensor[item['output_index']] = {'input_tensor': input[item['input_index']], 'offset': item['offset']} - self._prepare_input(input, dims) output = [] if output_shape: @@ -363,7 +357,7 @@ def run(self, images, dims=None, output_shape=None, output, output_shape, out_stride, out_storage_offset, allocated_output_tensor) else: self._prepare_output( - output, output_shape, out_stride, out_storage_offset, allocated_output_tensor, allocated_output_with_offset_tensor) + output, output_shape, out_stride, out_storage_offset, allocated_output_tensor) self.forward() self._destroy_databuffer() return output @@ -388,8 +382,8 @@ def __init__(self, device_id, model_path) -> None: self.exe = AscendExecutor(device_id, model_path) def run(self, images, dims=None, output_shape=None, - out_stride=None, out_storage_offset=None, allocated_output=None, allocated_with_offset_output=None): - return self.exe.run(images, dims, output_shape, out_stride, out_storage_offset, allocated_output, allocated_with_offset_output) + out_stride=None, out_storage_offset=None, allocated_output=None): + return self.exe.run(images, dims, output_shape, out_stride, out_storage_offset, allocated_output) def cleanup(self): if hasattr(self, 'exe'): diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 5d957da16..761436a63 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -1574,7 +1574,9 @@ def copy_with_offset2(self, x, src, start_dim, end_dim): def flash_attention_inference(self, q, all_k, all_v, current_len, max_len): q_shape = list(q.node.meta['val'].shape) batch, head, dim = q_shape[0], q_shape[1], q_shape[2] - + k_shape = list(all_k.node.meta['val'].shape) + kvhead = k_shape[1] + res = [] compute_batch = 1 select_axis = self.get_const_proxy(0, torch.int32) @@ -1585,12 +1587,12 @@ def flash_attention_inference(self, q, all_k, all_v, current_len, max_len): xq = self.get_proxy(ascend_op.GatherV2, (q, select_index, select_axis)) kv_start_index = self.get_const_proxy([i * max_len, 0, 0], torch.int32) - kv_end_index = self.get_const_proxy([i * max_len + current_len, head, dim], torch.int32) + kv_end_index = self.get_const_proxy([i * max_len + current_len, kvhead, dim], torch.int32) kv_seq_len = current_len - kv_gather_shape = self.get_shape_proxy([compute_batch, kv_seq_len, head, dim]) - kv_compute_shape = self.get_shape_proxy([compute_batch, kv_seq_len, head * dim]) - + kv_gather_shape = self.get_shape_proxy([compute_batch, kv_seq_len, kvhead, dim]) + kv_compute_shape = self.get_shape_proxy([compute_batch, kv_seq_len, kvhead * dim]) + # fetch k k = self.get_proxy(ascend_op.Slice, (all_k, kv_start_index, kv_end_index)) k = self.get_proxy(ascend_op.Reshape, (k, kv_gather_shape)) @@ -1606,13 +1608,13 @@ def flash_attention_inference(self, q, all_k, all_v, current_len, max_len): q_compute_shape = self.get_shape_proxy([compute_batch, 1, head * dim]) xq = self.get_proxy(ascend_op.Reshape, (xq, q_shape)) xq = self.get_proxy(ascend_op.Reshape, (xq, q_compute_shape)) - - out = self.incre_flash_attention(xq, k, v, head, head, dim) # q shape is BSH + + out = self.incre_flash_attention(xq, k, v, kvhead, head, dim) # q shape is BSH out_shape = self.get_shape_proxy([compute_batch, 1, head, dim]) out_shape2 = self.get_shape_proxy([compute_batch, head, dim]) out = self.get_proxy(ascend_op.Reshape, (out, out_shape)) out = self.get_proxy(ascend_op.Reshape, (out, out_shape2)) res.append(out) - + res = self.get_proxy(ascend_op.ConcatD, (res, 0)) return res diff --git a/dicp/dicp/vendor/AscendGraph/opset_convert.py b/dicp/dicp/vendor/AscendGraph/opset_convert.py index 040efb0e0..39534710a 100644 --- a/dicp/dicp/vendor/AscendGraph/opset_convert.py +++ b/dicp/dicp/vendor/AscendGraph/opset_convert.py @@ -1,7 +1,7 @@ import torch import torch_dipu from dicp.dynamo_bridge.compile_fx import is_torch_210 -from dicp.vendor.AscendGraph.ascend_op import CastToCpu, IdentityInp, InplaceCopyWithOffset +from dicp.vendor.AscendGraph.ascend_op import CastToCpu, IdentityInp from dicp.vendor.AscendGraph.conversion import AtenToAscendTransformer from ...dynamo_bridge.graph import GraphTransformer @@ -46,10 +46,6 @@ def transform(self, gm: torch.fx.graph_module): continue if isinstance(n.target, CastToCpu): self.cpu_tensor.append(n.name) - elif isinstance(n.target, InplaceCopyWithOffset): - input_index = input_names.index(str(n.args[0])) - offset = int(n.args[-1]) - self.assign_with_offset_args[n.name] = {'name': n.name, 'input_index': input_index, 'offset': offset} elif isinstance(n.target, IdentityInp): if len(n.args) == 2 and n.args[1] is not None and str(n.args[1]) in input_names: self.assign_args.append((n.name, input_names.index(str(n.args[1])))) @@ -64,8 +60,6 @@ def transform(self, gm: torch.fx.graph_module): if len(self.assign_args) > 0 and n.name in list(zip(*self.assign_args))[0]: idx = list(zip(*self.assign_args))[0].index(n.name) prop.update({'assign_args': (self.assign_args[idx][0], self.assign_args[idx][1])}) - if n.name in self.assign_with_offset_args.keys(): - prop['assign_with_offset_args'] = self.assign_with_offset_args[n.name] n.meta['prop'] = prop return gm