Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dicp/dicp/dynamo_bridge/compile_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,8 @@ def compile_fx_210(
def fw_compiler_base(model: torch.fx.GraphModule, example_inputs, is_inference):
if is_inference:
# partition_fn won't be called
joint_graph_passes(model)
# joint_graph_passes(model)
pass

fixed = len(example_inputs) - num_example_inputs
return inner_compile(
Expand Down
10 changes: 1 addition & 9 deletions dicp/dicp/vendor/AscendGraph/ascend_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,14 +1115,6 @@ def infer_result(self, x, index):
return torch.empty(idx_shape, dtype=x_dtype, memory_format=get_memory_format(x))


class InplaceCopyWithOffset(Operator):
def __init__(self):
super().__init__("InplaceCopyWithOffset")

def infer_result(self, x, src, dim, offset):
return src


class ExpandDims(Operator):
def __init__(self):
super().__init__("ExpandDims")
Expand All @@ -1144,7 +1136,7 @@ def __init__(self):
super().__init__("ViewCopy")

def infer_result(self, dst, dst_size, dst_stride, dst_storage_offset, src, src_size, src_stride, src_storage_offset):
return x
return dst


class ScatterNdUpdate(Operator):
Expand Down
23 changes: 1 addition & 22 deletions dicp/dicp/vendor/AscendGraph/codegen/ascend.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def __init__(self, graph, aten_graph=None, folder=None, graph_key=None):
# for modified args return
self.assign_args = []
self.cpu_tensor = []
self.assign_with_offset_args = []

super().__init__(graph)

Expand Down Expand Up @@ -135,8 +134,6 @@ def call_function(self, name, target, args, kwargs):
self.cpu_tensor.append(self.cur_node.meta['prop']['cpu_tensor'])
if 'prop' in self.cur_node.meta and 'assign_args' in self.cur_node.meta['prop']:
self.assign_args.append(self.cur_node.meta['prop']['assign_args'])
if 'prop' in self.cur_node.meta and 'assign_with_offset_args' in self.cur_node.meta['prop']:
self.assign_with_offset_args.append(self.cur_node.meta['prop']['assign_with_offset_args'])

_, args_list = AscendOverrides.gen_args(
self.args_dict[name], self.args_dict, args)
Expand Down Expand Up @@ -199,10 +196,6 @@ def parse_outputs(self):

if len(self.assign_args) > 0:
self.graph_output_names.extend(list(zip(*self.assign_args))[0])
current_index = len(self.graph_output_names)
for i in range(len(self.assign_with_offset_args)):
self.assign_with_offset_args[i]['output_index'] = current_index
current_index += 1

def gen_import_code(self):
self.import_code.splice(
Expand Down Expand Up @@ -380,15 +373,7 @@ def gen_call_func(self):
output_index = self.graph_output_names.index(item[0])
allocated_output[output_index] = input_index
call_body.writeline(f'allocated_output= {allocated_output}')

allocated_output_with_offset = {}
for item in self.assign_with_offset_args:
# input_index = item['input_index']
output_index = item['output_index']
# offset = item['offset']
allocated_output_with_offset[output_index] = item
call_body.writeline(f'allocated_with_offset_output= {self.assign_with_offset_args}')
call_str = ['output_tensor = kernel_cpp_0(args, dims, output_shape, out_stride, out_storage_offset, allocated_output, allocated_with_offset_output)']
call_str = ['output_tensor = kernel_cpp_0(args, dims, output_shape, out_stride, out_storage_offset, allocated_output)']

if precision_check and self.aten_graph is not None:
# import aten graph
Expand Down Expand Up @@ -1673,12 +1658,6 @@ def ExpandDims(name, x, axis):
gather_op.set_input("axis", axis)
return gather_op.to_node()

@staticmethod
def InplaceCopyWithOffset(name, x, src, dim, offset):
op = OP(name, "Identity")
op.set_input("x", src)
return op.to_node()

@staticmethod
def MaskedScatter(name, x, mask, updates):
op = OP(name, "MaskedScatter")
Expand Down
16 changes: 5 additions & 11 deletions dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def _prepare_input(self, images, dims):
assert (dataset == self.input_dataset)

@record_function('load_and_run_prepare_output')
def _prepare_output(self, output_tensor, output_shape, out_stride, out_storage_offset, allocated_output, allocated_output_with_offset_tensor):
def _prepare_output(self, output_tensor, output_shape, out_stride, out_storage_offset, allocated_output):
for i in range(self.num_outputs):
if allocated_output and i in allocated_output.keys():
item = allocated_output[i]
Expand Down Expand Up @@ -339,7 +339,7 @@ def _prepare_dynamic_output(self, output_tensor, output_shape, out_stride, out_s
@record_function('load_and_run_run')
def run(self, images, dims=None, output_shape=None,
out_stride=None, out_storage_offset=None,
allocated_output=None, allocated_with_offset_output=None):
allocated_output=None):
with record_function(f'load_and_run_run_{self.model_id}'):
assert len(images) > 0
input = [x.to(dipu_device_str) if isinstance(x, torch.Tensor)
Expand All @@ -350,20 +350,14 @@ def run(self, images, dims=None, output_shape=None,
for output_index, input_index in allocated_output.items():
allocated_output_tensor[output_index] = input[input_index]

allocated_output_with_offset_tensor = None
if allocated_with_offset_output:
allocated_output_with_offset_tensor = {}
for item in allocated_with_offset_output:
allocated_output_with_offset_tensor[item['output_index']] = {'input_tensor': input[item['input_index']], 'offset': item['offset']}

self._prepare_input(input, dims)
output = []
if output_shape:
self._prepare_dynamic_output(
output, output_shape, out_stride, out_storage_offset, allocated_output_tensor)
else:
self._prepare_output(
output, output_shape, out_stride, out_storage_offset, allocated_output_tensor, allocated_output_with_offset_tensor)
output, output_shape, out_stride, out_storage_offset, allocated_output_tensor)
self.forward()
self._destroy_databuffer()
return output
Expand All @@ -388,8 +382,8 @@ def __init__(self, device_id, model_path) -> None:
self.exe = AscendExecutor(device_id, model_path)

def run(self, images, dims=None, output_shape=None,
out_stride=None, out_storage_offset=None, allocated_output=None, allocated_with_offset_output=None):
return self.exe.run(images, dims, output_shape, out_stride, out_storage_offset, allocated_output, allocated_with_offset_output)
out_stride=None, out_storage_offset=None, allocated_output=None):
return self.exe.run(images, dims, output_shape, out_stride, out_storage_offset, allocated_output)

def cleanup(self):
if hasattr(self, 'exe'):
Expand Down
18 changes: 10 additions & 8 deletions dicp/dicp/vendor/AscendGraph/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1574,7 +1574,9 @@ def copy_with_offset2(self, x, src, start_dim, end_dim):
def flash_attention_inference(self, q, all_k, all_v, current_len, max_len):
q_shape = list(q.node.meta['val'].shape)
batch, head, dim = q_shape[0], q_shape[1], q_shape[2]

k_shape = list(all_k.node.meta['val'].shape)
kvhead = k_shape[1]

res = []
compute_batch = 1
select_axis = self.get_const_proxy(0, torch.int32)
Expand All @@ -1585,12 +1587,12 @@ def flash_attention_inference(self, q, all_k, all_v, current_len, max_len):
xq = self.get_proxy(ascend_op.GatherV2, (q, select_index, select_axis))

kv_start_index = self.get_const_proxy([i * max_len, 0, 0], torch.int32)
kv_end_index = self.get_const_proxy([i * max_len + current_len, head, dim], torch.int32)
kv_end_index = self.get_const_proxy([i * max_len + current_len, kvhead, dim], torch.int32)
kv_seq_len = current_len

kv_gather_shape = self.get_shape_proxy([compute_batch, kv_seq_len, head, dim])
kv_compute_shape = self.get_shape_proxy([compute_batch, kv_seq_len, head * dim])

kv_gather_shape = self.get_shape_proxy([compute_batch, kv_seq_len, kvhead, dim])
kv_compute_shape = self.get_shape_proxy([compute_batch, kv_seq_len, kvhead * dim])
# fetch k
k = self.get_proxy(ascend_op.Slice, (all_k, kv_start_index, kv_end_index))
k = self.get_proxy(ascend_op.Reshape, (k, kv_gather_shape))
Expand All @@ -1606,13 +1608,13 @@ def flash_attention_inference(self, q, all_k, all_v, current_len, max_len):
q_compute_shape = self.get_shape_proxy([compute_batch, 1, head * dim])
xq = self.get_proxy(ascend_op.Reshape, (xq, q_shape))
xq = self.get_proxy(ascend_op.Reshape, (xq, q_compute_shape))

out = self.incre_flash_attention(xq, k, v, head, head, dim) # q shape is BSH
out = self.incre_flash_attention(xq, k, v, kvhead, head, dim) # q shape is BSH
out_shape = self.get_shape_proxy([compute_batch, 1, head, dim])
out_shape2 = self.get_shape_proxy([compute_batch, head, dim])
out = self.get_proxy(ascend_op.Reshape, (out, out_shape))
out = self.get_proxy(ascend_op.Reshape, (out, out_shape2))
res.append(out)

res = self.get_proxy(ascend_op.ConcatD, (res, 0))
return res
8 changes: 1 addition & 7 deletions dicp/dicp/vendor/AscendGraph/opset_convert.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torch_dipu
from dicp.dynamo_bridge.compile_fx import is_torch_210
from dicp.vendor.AscendGraph.ascend_op import CastToCpu, IdentityInp, InplaceCopyWithOffset
from dicp.vendor.AscendGraph.ascend_op import CastToCpu, IdentityInp
from dicp.vendor.AscendGraph.conversion import AtenToAscendTransformer
from ...dynamo_bridge.graph import GraphTransformer

Expand Down Expand Up @@ -46,10 +46,6 @@ def transform(self, gm: torch.fx.graph_module):
continue
if isinstance(n.target, CastToCpu):
self.cpu_tensor.append(n.name)
elif isinstance(n.target, InplaceCopyWithOffset):
input_index = input_names.index(str(n.args[0]))
offset = int(n.args[-1])
self.assign_with_offset_args[n.name] = {'name': n.name, 'input_index': input_index, 'offset': offset}
elif isinstance(n.target, IdentityInp):
if len(n.args) == 2 and n.args[1] is not None and str(n.args[1]) in input_names:
self.assign_args.append((n.name, input_names.index(str(n.args[1]))))
Expand All @@ -64,8 +60,6 @@ def transform(self, gm: torch.fx.graph_module):
if len(self.assign_args) > 0 and n.name in list(zip(*self.assign_args))[0]:
idx = list(zip(*self.assign_args))[0].index(n.name)
prop.update({'assign_args': (self.assign_args[idx][0], self.assign_args[idx][1])})
if n.name in self.assign_with_offset_args.keys():
prop['assign_with_offset_args'] = self.assign_with_offset_args[n.name]
n.meta['prop'] = prop
return gm

Expand Down