[FxImporter] Fix constant bool tensor (llvm#3375)
Authored by penguin-wwy, committed by Branko Trifkovic on May 24, 2024
Parent: 3b8639a · Commit: 313dc33
Showing 2 changed files with 53 additions and 16 deletions.
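For context: before this change, the FX importer could not materialize a module-level constant bool tensor, so the corresponding e2e test was marked xfail. A minimal sketch of the kind of module involved (illustrative; the actual ConstantBoolParameterModule test body may differ):

    import torch

    class ConstantBoolModule(torch.nn.Module):  # hypothetical name
        def __init__(self):
            super().__init__()
            # A constant bool tensor carried by the module; importing this
            # literal is what the fix below enables.
            self.register_buffer("flags", torch.tensor([True, False, True]))

        def forward(self):
            return self.flags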
projects/pt1/e2e_testing/xfail_sets.py (1 change: 0 additions & 1 deletion)

@@ -367,7 +367,6 @@
     "BoolIntTrueModule_basic",
     "BroadcastDynamicDimModule_basic",
     "CeilFloatModule_basic",
-    "ConstantBoolParameterModule_basic",
     "ContainsIntList_False",
     "ContainsIntList_True",
     "Conv2dQInt8Module_basic",
python/torch_mlir/extras/fx_importer.py (68 changes: 53 additions & 15 deletions)

@@ -844,6 +844,10 @@ def _graph_to_function_meta(self, g: Graph) -> Tuple[FunctionType, Location]:
                 result_types.append(
                     IrType.parse("!torch.none", context=self._c)
                 )
+            elif isinstance(result_node, torch.Tensor):
+                result_types.append(
+                    self._cc.tensor_to_vtensor_type(result_node)
+                )
             else:
                 result_types.append(self._cc.node_val_to_type(result_node))
         return (
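The hunk above extends result-type inference: when an entry in the FX graph's output is a constant torch.Tensor rather than an fx.Node, the importer now derives the !torch.vtensor result type from the tensor itself instead of falling through to node_val_to_type, which only understands nodes.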
@@ -1002,9 +1006,14 @@ def dtype_to_type(self, dtype: TorchDtype) -> IrType:
         self._dtype_to_type[dtype] = t
         return t
 
+    def create_vtensor_type(self, dtype: torch.dtype, size: torch.Size) -> IrType:
+        dtype_asm = str(self.dtype_to_type(dtype))
+        return IrType.parse(
+            f"!torch.vtensor<{list(size)},{dtype_asm}>", context=self._c
+        )
+
     def tensor_to_vtensor_type(self, tensor: torch.Tensor) -> IrType:
-        dtype_asm = str(self.dtype_to_type(tensor.dtype))
-        return IrType.parse(f"!torch.vtensor<{list(tensor.size())},{dtype_asm}>")
+        return self.create_vtensor_type(tensor.dtype, tensor.size())
 
     def get_node_location(self, node: torch_fx.Node) -> Optional[Location]:
         stack_trace = node.meta.get("stack_trace")
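A quick sketch of what the refactored helpers yield (assuming an initialized ContextCache named cc; the output types follow the f-string above):

    t = torch.zeros(2, 3, dtype=torch.bool)
    cc.tensor_to_vtensor_type(t)                            # !torch.vtensor<[2,3],i1>
    cc.create_vtensor_type(torch.float32, torch.Size([4]))  # !torch.vtensor<[4],f32>

Splitting create_vtensor_type out lets callers build a vtensor type from a dtype/size pair alone, which the bool-literal path below needs once the original tensor has been rewritten to uint8.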
@@ -1513,37 +1522,58 @@ def _import_argument(
         ):
             # promote scalars to tensor types as appropriate
             argument_value = self._import_scalar_as_tensor(loc, arg)
-        else:
+        elif LITERAL_CONVERTER_MAP.lookup(type(arg)) is not None:
             with loc:
                 argument_value = self._import_literal(arg)
-            return self._convert_type(loc, argument_value, expected_jit_type)
+        else:
+            raise TypeError(f"Unsupported argument type {arg.__class__}")
+        with loc:
+            return self._convert_type(argument_value, expected_jit_type)
 
-    def _convert_type(self, loc: Location, val: Value, expected_jit_type):
+    def _convert_type(
+        self,
+        val: Value,
+        expected_type,
+        dtype: Optional[torch.dtype] = None,
+        size: Optional[torch.Size] = None,
+    ):
         """
         When the type of 'value' and the type in the schema do not match,
         attempt to perform automatic type conversion.
         example: test/python/fx_importer/basic_test.py::test_full
         """
+        if not expected_type:
+            return val
         op_name = None
         result_type = None
         # TODO: If additional types require conversion in the future,
         # consider implementing a table-driven approach.
+        operands = [val]
         if val.type == self._cc.torch_bool_type:
-            if isinstance(expected_jit_type, torch.FloatType):
+            if isinstance(expected_type, torch.FloatType):
                 op_name = "torch.aten.Float.bool"
                 result_type = self._cc.torch_float_type
-            elif isinstance(expected_jit_type, (torch.IntType, torch.NumberType)):
+            elif isinstance(expected_type, (torch.IntType, torch.NumberType)):
                 op_name = "torch.aten.Int.bool"
                 result_type = self._cc.torch_int_type
+        elif expected_type is torch.Tensor:
+            op_name = "torch.prims.convert_element_type"
+            result_type = self._cc.create_vtensor_type(dtype, size)
+            operands.append(
+                LITERAL_CONVERTER_MAP.lookup(torch.dtype)(dtype, self, self._cc)
+            )
         if op_name is None:
             return val
-        with loc:
-            return Operation.create(
-                name=op_name, results=[result_type], operands=[val]
-            ).result
+        return Operation.create(
+            name=op_name, results=[result_type], operands=operands
+        ).result
 
     def _import_literal(self, py_value: Any) -> Value:
+        orig_value = None
+        if isinstance(py_value, torch.Tensor) and py_value.dtype == torch.bool:
+            orig_value = py_value
+            py_value = py_value.to(torch.uint8)
         # Apply the conversion callback.
         user_value = self.fx_importer._hooks.resolve_literal(self, py_value)
         if user_value is not None:
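The new elif branch handles tensor-to-tensor dtype conversion: it emits torch.prims.convert_element_type, with the target dtype passed through the same LITERAL_CONVERTER_MAP used for dtype arguments (presumably lowering it to a dtype constant operand). A hypothetical call, mirroring what _import_literal does for a two-element bool tensor that was imported as uint8 (`importer` and `uint8_literal_value` are illustrative names):

    converted = importer._convert_type(
        uint8_literal_value,    # Value typed !torch.vtensor<[2],ui8>
        torch.Tensor,           # expected: a tensor type
        dtype=torch.bool,       # element type to restore
        size=torch.Size([2]),
    )
    # -> result typed !torch.vtensor<[2],i1>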
@@ -1556,7 +1586,12 @@ def _import_literal(self, py_value: Any) -> Value:
             raise TypeError(
                 f"Unsupported argument -> literal conversion for {py_value.__class__}"
             )
-        return converter(py_value, self, self._cc)
+        result = converter(py_value, self, self._cc)
+        if orig_value is not None:
+            result = self._convert_type(
+                result, torch.Tensor, orig_value.dtype, orig_value.size()
+            )
+        return result
 
     def _import_input(self, py_value: Any, info: InputInfo) -> Value:
         # Try the hook.
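Taken together, _import_literal and _convert_type implement a round trip for bool constants: the tensor is rewritten to uint8 before the dense literal is created (presumably to sidestep i1 element handling in the dense-attribute path), and convert_element_type then restores the bool element type in the IR. A sketch of the flow (shapes illustrative):

    t = torch.tensor([True, False])
    # 1. _import_literal: t -> t.to(torch.uint8), emit the uint8 literal
    # 2. _convert_type(result, torch.Tensor, t.dtype, t.size()):
    #    !torch.vtensor<[2],ui8> -> !torch.vtensor<[2],i1>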
@@ -1704,16 +1739,19 @@ def _make_constant_op(
     )
 
 
-def create_mlir_tensor_type(tensor: torch.Tensor) -> IrType:
+def _create_mlir_tensor_type(dtype: torch.dtype, size: torch.Size) -> IrType:
     try:
-        dtype = tensor.dtype
         element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]()
-        tensor_type = RankedTensorType.get(tuple(tensor.size()), element_type)
+        tensor_type = RankedTensorType.get(size, element_type)
         return tensor_type
     except KeyError:
         raise TypeError(f"Could not map Torch dtype {dtype} to an MLIR type")
 
 
+def create_mlir_tensor_type(tensor: torch.Tensor) -> IrType:
+    return _create_mlir_tensor_type(tensor.dtype, tensor.size())
+
+
 def _make_vtensor_literal_op(
     tensor: torch.Tensor, vtensor_type: IrType, py_attr_tracker: "RefTracker"
 ) -> Operation:
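Finally, create_mlir_tensor_type is split the same way as the vtensor helper: _create_mlir_tensor_type builds a builtin RankedTensorType from a dtype/size pair without needing a live tensor, and the original tensor-taking signature remains as a thin wrapper. A usage sketch (assumed, mirroring the helpers above):

    create_mlir_tensor_type(torch.ones(2, 3))               # tensor<2x3xf32>
    _create_mlir_tensor_type(torch.uint8, torch.Size([2]))  # tensor<2xui8>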
