Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dicp/dicp/dynamo_bridge/compile_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,8 @@ def compile_fx_210(
def fw_compiler_base(model: torch.fx.GraphModule, example_inputs, is_inference):
if is_inference:
# partition_fn won't be called
joint_graph_passes(model)
# joint_graph_passes(model)
pass

fixed = len(example_inputs) - num_example_inputs
return inner_compile(
Expand Down
38 changes: 38 additions & 0 deletions dicp/dicp/dynamo_bridge/decompositions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from collections import defaultdict
from typing import Callable, Dict, Sequence, Union

import torch
from torch._decomp import register_decomposition
from torch._ops import OpOverload, OpOverloadPacket

dicp_decomposition_table = {}
aten = torch.ops.aten


def register_decomposition_for_dicp(fn):
return register_decomposition(fn, registry=dicp_decomposition_table)


@register_decomposition_for_dicp(aten.count_nonzero.default)
def count_nonzero_default(x, dim=None):
cond = x != 0
dim = [] if dim is None else dim
return aten.sum.dim_IntList(cond, dim=dim, keepdim=False, dtype=torch.int64)


def get_decompositions(
aten_ops: Sequence[Union[OpOverload, OpOverloadPacket]],
target_decomposition_table: Dict[OpOverload, Callable] = None,
) -> Dict[OpOverload, Callable]:
registry = dicp_decomposition_table
packets_to_overloads = defaultdict(list)
for opo in registry:
packets_to_overloads[opo.overloadpacket].append(opo)
decompositions = target_decomposition_table if target_decomposition_table else {}
for op in aten_ops:
if isinstance(op, OpOverloadPacket) and op in packets_to_overloads:
for op_overload in packets_to_overloads[op]:
decompositions[op_overload] = registry[op_overload]
elif isinstance(op, OpOverload) and op in registry:
decompositions[op] = registry[op]
return decompositions
95 changes: 88 additions & 7 deletions dicp/dicp/vendor/AscendGraph/ascend_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

aten = torch.ops.aten


def negative_in_shape(shape):
for elem in shape:
if elem < 0:
Expand Down Expand Up @@ -43,12 +44,12 @@ def __init__(self):

def infer_result(self, x, shape):
x, x_shape, _, x_dtype = get_fake_tensor_meta_val(x)
if isinstance(shape, torch._subclasses.fake_tensor.FakeTensor): # case1: shape is a fakeTensor, like conversion for 'scatter' and 'where'
if isinstance(shape, torch._subclasses.fake_tensor.FakeTensor): # case1: shape is a fakeTensor, like conversion for 'scatter' and 'where'
shape, shape_shape, _, _ = get_fake_tensor_meta_val(shape)
shape = shape_shape
elif isinstance(shape, Tuple): # case2: shape is tuple from 'Const' , like conversion for 'lt'
shape, _, _, _ =get_op_const_arg_kwarg(shape)
else: # other cases, unsupported yet
elif isinstance(shape, Tuple): # case2: shape is tuple from 'Const' , like conversion for 'lt'
shape, _, _, _ = get_op_const_arg_kwarg(shape)
else: # other cases, unsupported yet
assert False, self.__class__.__name__ + "unsupported 'shape' input type!"

out_shape = get_broadcast_res_two_shape(x_shape, shape)
Expand Down Expand Up @@ -97,7 +98,7 @@ def __init__(self):
class MatMul(Operator):
def __init__(self):
super().__init__("MatMul")

def infer_result(self, x1, x2, adj_x1=False, adj_x2=False):
attr = acl.op.create_attr()
check_ret("acl.op.set_attr_bool", acl.op.set_attr_bool(attr, "transpose_x1", adj_x1))
Expand Down Expand Up @@ -290,6 +291,14 @@ def infer_result(self, x, dims, keepdim):
return reduce_op_infer(x, dims, keepdim)


class ReduceSum(Operator):
def __init__(self):
super().__init__("ReduceSum")

def infer_result(self, x, dims, keepdim):
return reduce_op_infer(x, dims, keepdim)


class Unsqueeze(Operator):
def __init__(self):
super().__init__("Unsqueeze")
Expand Down Expand Up @@ -628,7 +637,7 @@ def infer_result(self, x, index, orig_index):

# assume not none index, and replace prefix x_shape dims
len_idx_shape = len(orig_index)
assert(len_idx_shape > 0)
assert (len_idx_shape > 0)
bcast_index_shape = list(orig_index[0].shape)
x_shape = bcast_index_shape + list(x_shape[len_idx_shape:])
return torch.empty(x_shape, dtype=x_dtype, memory_format=get_memory_format(x))
Expand Down Expand Up @@ -962,6 +971,14 @@ def infer_result(self, x1, x2):
return common_binary_op_infer(x1, x2, torch.bool)


class LogicalNot(Operator):
def __init__(self):
super().__init__("LogicalNot")

def infer_result(self, x):
return common_binary_op_infer(x, torch.bool)


class Tril(Operator):
def __init__(self):
super().__init__("Tril")
Expand Down Expand Up @@ -1023,7 +1040,7 @@ def infer_result(
output_batch_var = torch.empty(
[channel_size], dtype=torch.float32, memory_format=torch.contiguous_format
)
return [output_y,output_mean,output_var,output_batch_mean,output_batch_var]
return [output_y, output_mean, output_var, output_batch_mean, output_batch_var]


class TileWithAxis(Operator):
Expand All @@ -1032,6 +1049,38 @@ def __init__(self):
self.torch_op = aten.repeat_interleave.self_int


class RotaryMul(Operator):
def __init__(self):
super().__init__("RotaryMul")

def infer_result(self, x, cos, sin):
return torch.empty_like(x)


class RmsNorm(Operator):
def __init__(self):
super().__init__("RmsNorm")

def infer_result(self, x, weight, eps):
return torch.empty_like(x)


class PromptFlashAttention(Operator):
def __init__(self):
super().__init__("PromptFlashAttention")

def infer_result(self, q, k, v, num_head, seqlen, mask, head_dim):
return torch.empty_like(q)


class IncreFlashAttention(Operator):
def __init__(self):
super().__init__("IncreFlashAttention")

def infer_result(self, q, k, v, head_num):
return torch.empty_like(q)


class TensorScatterUpdate(Operator):
def __init__(self):
super().__init__("TensorScatterUpdate")
Expand All @@ -1054,6 +1103,38 @@ def infer_result(self, x, indices, updates):
return torch.empty(x_shape, dtype=x_dtype, memory_format=get_memory_format(x))


class ExpandDims(Operator):
def __init__(self):
super().__init__("ExpandDims")

def infer_result(self, x, axis):
return torch.unsqueeze(x, axis)


class MaskedScatter(Operator):
def __init__(self):
super().__init__("MaskedScatter")

def infer_result(self, x, mask, updates):
return x


class ViewCopy(Operator):
def __init__(self):
super().__init__("ViewCopy")

def infer_result(self, dst, dst_size, dst_stride, dst_storage_offset, src, src_size, src_stride, src_storage_offset):
return dst


class ScatterNdUpdate(Operator):
def __init__(self):
super().__init__("ScatterNdUpdate")

def infer_result(self, x, indices, updates):
return x


def ret_triple(a, b, c) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return a, b, c

Expand Down
Loading