From 38a66a1db41bec532a9b718bc731d49972c1b8da Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Thu, 3 Dec 2020 14:16:17 +0100 Subject: [PATCH] Save PyTorch frontend state in object While the functional approach is pretty neat, we ended up having global state (default frontend, dtype) and it'll be more soon (caching of inferred types, see #6900). To not have to pass around the state, this moves the op conversion into a class with instances having the state. --- python/tvm/relay/frontend/pytorch.py | 1461 ++++++++++---------------- 1 file changed, 540 insertions(+), 921 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 38478e27ff928..a9ea7dcc162e9 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -22,6 +22,7 @@ import logging import sys import math +import functools import numpy as np @@ -133,16 +134,22 @@ def _is_quantized_tensor(data, prelude): # operator implementation -def _elemwise(name): - def _impl(inputs, input_types): - data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2]) - return get_relay_op(name)(data0, data1) - return _impl +class PyTorchOpConverter: + """A helper class for holding PyTorch op converters.""" + + def __init__(self, prelude, default_dtype): + self.prelude = prelude + self.default_dtype = default_dtype -def _min_max_common(name_elemwise, name_reduce): - def _impl(inputs, input_types): + @staticmethod + def _elemwise(name, inputs, input_types): + data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2]) + return get_relay_op(name)(data0, data1) + + @staticmethod + def _min_max_common(name_elemwise, name_reduce, inputs, input_types): if len(inputs) == 1: data = _pytorch_promote_types(inputs[:1], input_types[:1]) return get_relay_op(name_reduce)(data[0]) @@ -156,38 +163,27 @@ def _impl(inputs, input_types): data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2]) return get_relay_op(name_elemwise)(data0, data1) - return _impl - - -def _max(): - return _min_max_common("maximum", "max") - + def _max(self, inputs, input_types): + return self._min_max_common("maximum", "max", inputs, input_types) -def _min(): - return _min_max_common("minimum", "min") + def _min(self, inputs, input_types): + return self._min_max_common("minimum", "min", inputs, input_types) - -def _unary(name): - def _impl(inputs, input_types): + @staticmethod + def _unary(name, inputs, input_types): # this is just to ensure tensor input (data,) = _pytorch_promote_types(inputs[:1], input_types[:1]) return get_relay_op(name)(data) - return _impl - - -def _log1p(): - def _impl(inputs, input_types): + @staticmethod + def _log1p(inputs, input_types): # 1_plus_log x = log(x + 1) (dtype,) = input_types one = _expr.const(1, dtype=dtype) return _op.log(inputs[0] + one) - return _impl - - -def _arange(): - def _impl(inputs, input_types): + @staticmethod + def _arange(inputs, input_types): def _get_value(val, dtype): # dtype is a tvm dtype if isinstance(val, _expr.Expr): @@ -235,11 +231,8 @@ def _get_type(val, inp_type): return _op.transform.arange(start=start, stop=stop, step=step, dtype=dtype) - return _impl - - -def _squeeze(): - def _impl(inputs, input_types): + @staticmethod + def _squeeze(inputs, input_types): data = inputs[0] if len(inputs) == 1: axis = None @@ -249,33 +242,28 @@ def _impl(inputs, input_types): return _op.transform.squeeze(data, axis) - return _impl - - -def _unsqueeze(): - def _impl(inputs, input_types): + @staticmethod + def _unsqueeze(inputs, input_types): data = inputs[0] axis = inputs[1] return _op.transform.expand_dims(data, int(axis), 1) - return _impl - - -def _concatenate(prelude): - def tensor_array_concat(lst, axis): - assert axis == 0, "Tensor array concat supported only for axis 0" - tensor_array, shape = _convert_to_tensor_array(lst, prelude) - concat_shape = (Any(),) + shape[1:] - concat = prelude.get_global_var_static("tensor_array_concat", "float32", shape) - concatenated = concat(tensor_array) - - static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", concat_shape) - static_tensor_array_ops.register() - get_tensor = prelude.get_global_var_static("tensor_get_data", "float32", concat_shape) - return get_tensor(concatenated) + def _concatenate(self, inputs, input_types): + def tensor_array_concat(lst, axis): + assert axis == 0, "Tensor array concat supported only for axis 0" + tensor_array, shape = _convert_to_tensor_array(lst, self.prelude) + concat_shape = (Any(),) + shape[1:] + concat = self.prelude.get_global_var_static("tensor_array_concat", "float32", shape) + concatenated = concat(tensor_array) + + static_tensor_array_ops = StaticTensorArrayOps(self.prelude, "float32", concat_shape) + static_tensor_array_ops.register() + get_tensor = self.prelude.get_global_var_static( + "tensor_get_data", "float32", concat_shape + ) + return get_tensor(concatenated) - def _impl(inputs, input_types): data = inputs[0] axis = inputs[1] @@ -287,11 +275,8 @@ def _impl(inputs, input_types): return _op.tensor.concatenate(data, int(axis)) - return _impl - - -def _slice(): - def _impl(inputs, input_types): + @staticmethod + def _slice(inputs, input_types): axis_dtype = "int64" index_size_limit = 2 ** 63 - 1 data = inputs[0] @@ -391,11 +376,8 @@ def _impl(inputs, input_types): data, begin=begin, end=end, strides=strides, slice_mode="end" ) - return _impl - - -def _split(): - def _impl(inputs, input_types): + @staticmethod + def _split(inputs, input_types): data = inputs[0] split_size = int(inputs[1]) dim = int(inputs[2]) @@ -408,11 +390,8 @@ def _impl(inputs, input_types): return _op.split(data, indices, dim) - return _impl - - -def _split_with_sizes(): - def _impl(inputs, input_types): + @staticmethod + def _split_with_sizes(inputs, input_types): data = inputs[0] sections = inputs[1] dim = int(inputs[2]) @@ -430,31 +409,22 @@ def _impl(inputs, input_types): return _op.split(data, indices, dim) - return _impl - - -def _select(): - def _impl(inputs, input_types): + @staticmethod + def _select(inputs, input_types): data = inputs[0] dim = int(inputs[1]) index = _wrap_const(inputs[2]) return _op.transform.take(data, index, axis=dim) - return _impl - - -def _take(): - def _impl(inputs, input_types): + @staticmethod + def _take(inputs, input_types): data = inputs[0] indices = _op.cast(inputs[1], "int32") return _op.transform.take(data, indices=indices) - return _impl - - -def _topk(): - def _impl(inputs, input_types): + @staticmethod + def _topk(inputs, input_types): data = inputs[0] axis = int(inputs[2]) is_ascend = not bool(inputs[3]) @@ -473,28 +443,19 @@ def _impl(inputs, input_types): return outs[0], outs[1] - return _impl - - -def _reciprocal(): - def _impl(inputs, input_types): + @staticmethod + def _reciprocal(inputs, input_types): data = inputs[0] return _expr.const(1.0, dtype=input_types[0]) / data - return _impl - - -def _repeat(): - def _impl(inputs, input_types): + @staticmethod + def _repeat(inputs, input_types): data = inputs[0] reps = inputs[1] return _op.transform.tile(data, reps=reps) - return _impl - - -def _repeat_interleave(): - def _impl(inputs, input_types): + @staticmethod + def _repeat_interleave(inputs, input_types): data = inputs[0] if isinstance(inputs[1], int): repeats = inputs[1] @@ -507,77 +468,63 @@ def _impl(inputs, input_types): axis = 0 return _op.transform.repeat(data, repeats=repeats, axis=axis) - return _impl - - -def _addcdiv(): - def _impl(inputs, input_types): + @staticmethod + def _addcdiv(inputs, input_types): data, t1, t2, c = _pytorch_promote_types(inputs[:4], input_types[:4]) return data + (c * (t1 / t2)) - return _impl - - -def _addcmul(): - def _impl(inputs, input_types): + @staticmethod + def _addcmul(inputs, input_types): data, t1, t2, c = _pytorch_promote_types(inputs[:4], input_types[:4]) return data + (c * (t1 * t2)) - return _impl - - -def _where(): - def _impl(inputs, input_types): + def _where(self, inputs, input_types): if len(inputs) == 1: - return _nonzero(False)([inputs[0], True], input_types) + return self._nonzero([inputs[0], True], input_types) cond = inputs[0] x, y = _pytorch_promote_types(inputs[1:3], input_types[1:3]) return _op.where(cond, x, y) - return _impl - - -def _full_impl(data, fill_value, dtype): - size = [] - need_reshape = False - new_shape = [] - for dim in data: - if isinstance(dim, _expr.Expr): - if isinstance(dim, _expr.Constant): - dim = int(dim.data.asnumpy()) - if isinstance(size, list): - size.append(dim) - new_shape.append(dim) - else: - dim, success = try_infer_value(dim, lambda ret: int(ret), lambda: 0) - new_shape.append(dim) - - if success: + @staticmethod + def _full_impl(data, fill_value, dtype): + size = [] + need_reshape = False + new_shape = [] + for dim in data: + if isinstance(dim, _expr.Expr): + if isinstance(dim, _expr.Constant): + dim = int(dim.data.asnumpy()) if isinstance(size, list): size.append(dim) + new_shape.append(dim) else: - size = None - need_reshape = True - else: - if isinstance(size, list): - size.append(dim) - new_shape.append(dim) + dim, success = try_infer_value(dim, lambda ret: int(ret), lambda: 0) + new_shape.append(dim) - if size is None: - tmp = [] - for dim in data: - tmp.append(_op.cast(_op.expand_dims(dim, axis=0), "int64")) - size = _op.concatenate(tmp, axis=0) + if success: + if isinstance(size, list): + size.append(dim) + else: + size = None + need_reshape = True + else: + if isinstance(size, list): + size.append(dim) + new_shape.append(dim) - out = _op.full(_expr.const(fill_value), size, dtype=dtype) - if need_reshape: - out = _op.reshape(out, new_shape) - return out + if size is None: + tmp = [] + for dim in data: + tmp.append(_op.cast(_op.expand_dims(dim, axis=0), "int64")) + size = _op.concatenate(tmp, axis=0) + out = _op.full(_expr.const(fill_value), size, dtype=dtype) + if need_reshape: + out = _op.reshape(out, new_shape) + return out -def _ones(default_dtype): - def _impl(inputs, input_types): + def _ones(self, inputs, input_types): data = inputs[0] import torch @@ -589,14 +536,10 @@ def _impl(inputs, input_types): if inputs[1] is not None: dtype = _convert_dtype_value(inputs[1]) else: - dtype = default_dtype - return _full_impl(data, 1, dtype) - - return _impl + dtype = self.default_dtype + return self._full_impl(data, 1, dtype) - -def _ones_like(default_dtype): - def _impl(inputs, input_types): + def _ones_like(self, inputs, input_types): data = inputs[0] out = _op.ones_like(data) @@ -604,17 +547,13 @@ def _impl(inputs, input_types): if inputs[1] is not None: dtype = _convert_dtype_value(inputs[1]) else: - dtype = default_dtype + dtype = self.default_dtype if input_types[0] != dtype: out = _op.cast(out, dtype) return out - return _impl - - -def _zeros(default_dtype): - def _impl(inputs, input_types): + def _zeros(self, inputs, input_types): data = inputs[0] import torch @@ -626,14 +565,10 @@ def _impl(inputs, input_types): if inputs[1] is not None: dtype = _convert_dtype_value(inputs[1]) else: - dtype = default_dtype - return _full_impl(data, 0, dtype) - - return _impl + dtype = self.default_dtype + return self._full_impl(data, 0, dtype) - -def _zeros_like(default_dtype): - def _impl(inputs, input_types): + def _zeros_like(self, inputs, input_types): data = inputs[0] out = _op.zeros_like(data) @@ -641,17 +576,13 @@ def _impl(inputs, input_types): if inputs[1] is not None: dtype = _convert_dtype_value(inputs[1]) else: - dtype = default_dtype + dtype = self.default_dtype if input_types[0] not in dtype: out = _op.cast(out, dtype) return out - return _impl - - -def _full(default_dtype): - def _impl(inputs, input_types): + def _full(self, inputs, input_types): data = inputs[0] fill_value = inputs[1] @@ -665,15 +596,11 @@ def _impl(inputs, input_types): dtype = _convert_dtype_value(inputs[2]) else: # if dtype is None, torch uses a global default set by torch.set_default_tensor_type() - dtype = default_dtype - - return _full_impl(data, fill_value, dtype) - - return _impl + dtype = self.default_dtype + return self._full_impl(data, fill_value, dtype) -def _full_like(default_dtype): - def _impl(inputs, input_types): + def _full_like(self, inputs, input_types): data = inputs[0] fill_value = inputs[1] @@ -684,17 +611,14 @@ def _impl(inputs, input_types): dtype = _convert_dtype_value(inputs[2]) else: # if dtype is None, torch uses a global default set by torch.set_default_tensor_type() - dtype = default_dtype + dtype = self.default_dtype if input_types[0] not in dtype: out = _op.cast(out, dtype) return out - return _impl - - -def _linspace(): - def _impl(inputs, input_types): + @staticmethod + def _linspace(inputs, input_types): start = inputs[0] stop = inputs[1] step = inputs[2] @@ -713,51 +637,35 @@ def _impl(inputs, input_types): return _op.transform.arange(start=start, stop=stop, step=step, dtype=dtype) - return _impl - - -def _relu(prelude): - def _impl(inputs, input_types): + def _relu(self, inputs, input_types): data = inputs[0] - if _is_quantized_tensor(data, prelude): + if _is_quantized_tensor(data, self.prelude): assert len(inputs) == 3, "Input quant param not found in op inputs" input_zero_point = _expr.const(inputs[2], dtype="int32") return qnn_torch.quantized_relu(data, input_zero_point) return _op.nn.relu(data) - return _impl - - -def _prelu(): - def _impl(inputs, input_types): + @staticmethod + def _prelu(inputs, input_types): data = inputs[0] alpha = inputs[1] return _op.nn.prelu(data, alpha) - return _impl - - -def _leaky_relu(): - def _impl(inputs, input_types): + @staticmethod + def _leaky_relu(inputs, input_types): data = inputs[0] alpha = float(inputs[1]) return _op.nn.leaky_relu(data, alpha) - return _impl - - -def _elu(): - def _impl(inputs, input_types): + @staticmethod + def _elu(inputs, input_types): data = inputs[0] dtype = input_types[0] alpha = _expr.const(float(inputs[1]), dtype=dtype) return alpha * _op.nn.relu(_expr.const(1, dtype=dtype) - _op.exp(data)) + _op.nn.relu(data) - return _impl - - -def _celu(): - def _impl(inputs, input_types): + @staticmethod + def _celu(inputs, input_types): data = inputs[0] dtype = input_types[0] alpha = _expr.const(float(inputs[1]), dtype=dtype) @@ -765,11 +673,8 @@ def _impl(inputs, input_types): _expr.const(1, dtype=dtype) - _op.exp(data / alpha) ) + _op.nn.relu(data) - return _impl - - -def _gelu(): - def _impl(inputs, input_types): + @staticmethod + def _gelu(inputs, input_types): data = inputs[0] dtype = input_types[0] # gelu is data * normcdf(data) @@ -781,11 +686,8 @@ def _impl(inputs, input_types): + _op.erf(data * _expr.const(0.5 ** 0.5, dtype=dtype)) * _expr.const(0.5, dtype=dtype) ) - return _impl - - -def _selu(): - def _impl(inputs, input_types): + @staticmethod + def _selu(inputs, input_types): data = inputs[0] # https://pytorch.org/docs/stable/nn.html#selu dtype = input_types[0] @@ -795,65 +697,46 @@ def _impl(inputs, input_types): alpha * _op.nn.relu(_expr.const(1.0, dtype=dtype) - _op.exp(data)) + _op.nn.relu(data) ) - return _impl - - -def _log_sigmoid(): - def _impl(inputs, input_types): + @staticmethod + def _log_sigmoid(inputs, input_types): data = inputs[0] return _op.log(_op.tensor.sigmoid(data)) - return _impl - - -def _adaptive_avg_pool_2d(prelude): - def _impl(inputs, input_types): + def _adaptive_avg_pool_2d(self, inputs, input_types): data = inputs[0] output_size = inputs[1] def func(x): return _op.nn.adaptive_avg_pool2d(x, output_size=output_size) - if _is_quantized_tensor(data, prelude): + if _is_quantized_tensor(data, self.prelude): return qnn_torch.apply_with_upcast(data, func) return func(data) - return _impl - - -def _adaptive_max_pool_2d(): - def _impl(inputs, input_types): + @staticmethod + def _adaptive_max_pool_2d(inputs, input_types): data = inputs[0] output_size = inputs[1] # returns dummy indices too return _op.nn.adaptive_max_pool2d(data, output_size=output_size), None - return _impl - - -def _adaptive_max_pool_3d(): - def _impl(inputs, input_types): + @staticmethod + def _adaptive_max_pool_3d(inputs, input_types): data = inputs[0] output_size = inputs[1] # returns dummy indices too return _op.nn.adaptive_max_pool3d(data, output_size=output_size), None - return _impl - - -def _adaptive_avg_pool_3d(): - def _impl(inputs, input_types): + @staticmethod + def _adaptive_avg_pool_3d(inputs, input_types): data = inputs[0] output_size = inputs[1] return _op.nn.adaptive_avg_pool3d(data, output_size=output_size) - return _impl - - -def _maxpool_2d(): - def _impl(inputs, input_types): + @staticmethod + def _maxpool_2d(inputs, input_types): data = inputs[0] pool_size = inputs[1] @@ -868,19 +751,12 @@ def _impl(inputs, input_types): return _op.nn.max_pool2d(data, pool_size, strides, padding, "NCHW", ceil_mode) - return _impl - - -def _maxpool_2d_with_indices(): - def _impl(inputs, input_types): + def _maxpool_2d_with_indices(self, inputs, input_types): # returns dummy indices too - return _maxpool_2d()(inputs, input_types), None - - return _impl + return self._maxpool_2d(inputs, input_types), None - -def _maxpool_1d(): - def _impl(inputs, input_types): + @staticmethod + def _maxpool_1d(inputs, input_types): data = inputs[0] pool_size = inputs[1] @@ -895,11 +771,8 @@ def _impl(inputs, input_types): return _op.nn.max_pool1d(data, pool_size, strides, padding, "NCW", ceil_mode) - return _impl - - -def _maxpool_3d(): - def _impl(inputs, input_types): + @staticmethod + def _maxpool_3d(inputs, input_types): data = inputs[0] pool_size = inputs[1] @@ -915,21 +788,15 @@ def _impl(inputs, input_types): data, pool_size=pool_size, strides=strides, padding=padding, ceil_mode=ceil_mode ) - return _impl - - -def _hardtanh(): - def _impl(inputs, input_types): + @staticmethod + def _hardtanh(inputs, input_types): a = inputs[0] tanh_min = float(inputs[1]) tanh_max = float(inputs[2]) return _op.tensor.clip(a, tanh_min, tanh_max) - return _impl - - -def _convolution(): - def _impl(inputs, input_types): + @staticmethod + def _convolution(inputs, input_types): # Use transpose or normal use_transpose = True if inputs[6] == 1 else False @@ -1018,11 +885,8 @@ def _impl(inputs, input_types): res = _op.squeeze(res, axis=[2]) return res - return _impl - - -def _softmax(): - def _impl(inputs, input_types): + @staticmethod + def _softmax(inputs, input_types): data = inputs[0] axis = inputs[1] if isinstance(axis, str): @@ -1030,27 +894,18 @@ def _impl(inputs, input_types): return _op.nn.softmax(data, axis=axis) - return _impl - - -def _threshold(): - def _impl(inputs, input_types): + @staticmethod + def _threshold(inputs, input_types): data = inputs[0] return _op.nn.relu(data) - return _impl - - -def _contiguous(): - def _impl(inputs, input_types): + @staticmethod + def _contiguous(inputs, input_types): data = inputs[0] return _op.tensor.copy(data) - return _impl - - -def _batch_norm(): - def _impl(inputs, input_types): + @staticmethod + def _batch_norm(inputs, input_types): data = inputs[0] data_type = input_types[0] @@ -1086,11 +941,8 @@ def _impl(inputs, input_types): scale=scale, )[0] - return _impl - - -def _instance_norm(): - def _impl(inputs, input_types): + @staticmethod + def _instance_norm(inputs, input_types): data = inputs[0] data_type = input_types[0] channels = _infer_shape(data) @@ -1114,28 +966,24 @@ def _impl(inputs, input_types): data, gamma, beta, axis=1, epsilon=epsilon, center=center, scale=scale ) - return _impl - - -def _get_dims(data): - import torch - - if isinstance(data, _expr.Expr): - dims = _infer_shape(data) - elif isinstance(data, list): - dims = data - elif isinstance(data, (torch.Tensor, np.ndarray)): - dims = data.shape - else: - msg = "Data type %s could not be parsed" % type(data) - raise AssertionError(msg) - return dims + @staticmethod + def _get_dims(data): + import torch + if isinstance(data, _expr.Expr): + dims = _infer_shape(data) + elif isinstance(data, list): + dims = data + elif isinstance(data, (torch.Tensor, np.ndarray)): + dims = data.shape + else: + msg = "Data type %s could not be parsed" % type(data) + raise AssertionError(msg) + return dims -def _layer_norm(): - def _impl(inputs, input_types): + def _layer_norm(self, inputs, input_types): data = inputs[0] - ndims = len(_get_dims(inputs[1])) + ndims = len(self._get_dims(inputs[1])) assert ndims == 1, "Support only normalization over last one dimension." return _op.nn.layer_norm( @@ -1148,11 +996,8 @@ def _impl(inputs, input_types): scale=True, ) - return _impl - - -def _group_norm(): - def _impl(inputs, input_types): + @staticmethod + def _group_norm(inputs, input_types): data = inputs[0] gamma = inputs[2] beta = inputs[3] @@ -1170,17 +1015,13 @@ def _impl(inputs, input_types): scale=True, ) - return _impl - - -def _transpose(prelude): - def _impl(inputs, input_types): + def _transpose(self, inputs, input_types): data = inputs[0] import torch if isinstance(data, _expr.Expr): - ndims = len(_infer_shape(data, prelude.mod)) + ndims = len(_infer_shape(data, self.prelude.mod)) elif isinstance(data, list): ndims = data elif isinstance(data, (torch.Tensor, np.ndarray)): @@ -1211,11 +1052,8 @@ def _impl(inputs, input_types): axes = inputs[1] return _op.transform.transpose(data, axes) - return _impl - - -def _flatten(): - def _impl(inputs, input_types): + @staticmethod + def _flatten(inputs, input_types): data = inputs[0] start = int(inputs[1]) end = int(inputs[2]) @@ -1237,11 +1075,8 @@ def _impl(inputs, input_types): out = _op.squeeze(out, axis=squeeze_axes) return out - return _impl - - -def _addmm(): - def _impl(inputs, input_types): + @staticmethod + def _addmm(inputs, input_types): input_mat = inputs[0] mat1 = inputs[1] data_type = input_types[1] @@ -1265,35 +1100,25 @@ def _impl(inputs, input_types): return dense_out + input_mat - return _impl - - -def _size(prelude): - def _impl_dynamic(inp, axis): - shape_dynamic = _op.shape_of(inp, dtype="int32") - if axis is not None: - return _op.take(shape_dynamic, _expr.const(axis), 0) - return shape_dynamic - - def _impl(inputs, input_types): - shape = _infer_shape(inputs[0], prelude.mod) + def _size(self, inputs, input_types): + shape = _infer_shape(inputs[0], self.prelude.mod) axis = None if len(inputs) > 1: axis = int(inputs[1]) if any(map(lambda s: isinstance(s, tvm.tir.expr.Any), shape)): if axis is None or isinstance(shape[axis], tvm.tir.expr.Any): - return _impl_dynamic(inputs[0], axis) + shape_dynamic = _op.shape_of(inputs[0], dtype="int32") + if axis is not None: + return _op.take(shape_dynamic, _expr.const(axis), 0) + return shape_dynamic if axis is not None: return _expr.const(shape[axis]) return _expr.const(shape) - return _impl - - -def _numtotensor(): - def _impl(inputs, input_types): + @staticmethod + def _numtotensor(inputs, input_types): val = inputs[0] dtype = input_types[0] @@ -1307,18 +1132,12 @@ def _impl(inputs, input_types): arr = val * np.ones([]).astype(dtype) return arr - return _impl - - -def _tensortonum(): - def _impl(inputs, input_types): + @staticmethod + def _tensortonum(inputs, input_types): return inputs[0] - return _impl - - -def _view(): - def _impl(inputs, input_types): + @staticmethod + def _view(inputs, input_types): data = inputs[0] if len(inputs) == 3: @@ -1336,11 +1155,8 @@ def _impl(inputs, input_types): return _op.transform.reshape(data, new_shape) - return _impl - - -def _reshape(): - def _impl(inputs, input_types): + @staticmethod + def _reshape(inputs, input_types): data = inputs[0] new_shape = inputs[1] @@ -1371,11 +1187,7 @@ def _impl(inputs, input_types): new_shape = tmp_shape return _op.transform.reshape(data, new_shape) - return _impl - - -def _pixel_shuffle(prelude): - def _impl(inputs, input_types): + def _pixel_shuffle(self, inputs, input_types): data = inputs[0] upscale_factor = inputs[1] upscale_squared = upscale_factor * upscale_factor @@ -1384,7 +1196,7 @@ def _impl(inputs, input_types): c % upscale_squared == 0 ), "input channel should be divisible by square of upscale_factor" - ndims = len(_infer_shape(data, prelude.mod)) + ndims = len(_infer_shape(data, self.prelude.mod)) axes = list(range(ndims)) num_inputs = len(inputs) oc = c // upscale_squared @@ -1402,46 +1214,30 @@ def _impl(inputs, input_types): data = _op.transform.transpose(data, axes) return _op.transform.reshape(data, out_shape) - return _impl - - -def _clone(): - def _impl(inputs, input_types): + @staticmethod + def _clone(inputs, input_types): data = inputs[0] return _op.tensor.copy(data) - return _impl - - -def _log_softmax(): - def _impl(inputs, input_types): + @staticmethod + def _log_softmax(inputs, input_types): data = inputs[0] axis = int(inputs[1]) return _op.nn.log_softmax(data, axis) - return _impl - - -def _sigmoid(): - def _impl(inputs, input_types): + @staticmethod + def _sigmoid(inputs, input_types): data = inputs[0] return _op.tensor.sigmoid(data) - return _impl - - -def _softplus(): - def _impl(inputs, input_types): + @staticmethod + def _softplus(inputs, input_types): data = inputs[0] dtype = input_types[0] beta = _expr.const(float(inputs[1]), dtype=dtype) return _op.log(_op.exp(inputs[0] * beta) + _expr.const(1.0, dtype=dtype)) / beta - return _impl - - -def _avg_pool2d(prelude): - def _impl(inputs, input_types): + def _avg_pool2d(self, inputs, input_types): data = inputs[0] pool_size = inputs[1] @@ -1460,16 +1256,13 @@ def func(x): count_include_pad=count_include_pad, ) - if _is_quantized_tensor(data, prelude): + if _is_quantized_tensor(data, self.prelude): return qnn_torch.apply_with_upcast(data, func) return func(data) - return _impl - - -def _avg_pool3d(): - def _impl(inputs, input_types): + @staticmethod + def _avg_pool3d(inputs, input_types): data = inputs[0] pool_size = inputs[1] @@ -1487,21 +1280,15 @@ def _impl(inputs, input_types): count_include_pad=count_include_pad, ) - return _impl - - -def _dropout(): - def _impl(inputs, input_types): + @staticmethod + def _dropout(inputs, input_types): data = inputs[0] rate = float(inputs[1]) return _op.nn.dropout(data, rate) - return _impl - - -def _reduce(name): - def _impl(inputs, input_types): + @staticmethod + def _reduce(name, inputs, input_types): data = inputs[0] axis = None keepdims = False @@ -1517,11 +1304,8 @@ def _impl(inputs, input_types): return get_relay_op(name)(data, axis=axis, keepdims=keepdims) - return _impl - - -def _norm(): - def _impl(inputs, input_types): + @staticmethod + def _norm(inputs, input_types): data = inputs[0] dtype = input_types[0] axis = None @@ -1543,11 +1327,8 @@ def _impl(inputs, input_types): reci_order, ) - return _impl - - -def _frobenius_norm(): - def _impl(inputs, input_types): + @staticmethod + def _frobenius_norm(inputs, input_types): data = inputs[0] axis = None keepdims = False @@ -1557,11 +1338,8 @@ def _impl(inputs, input_types): return _op.sqrt(_op.reduce.sum((data * data), axis=axis, keepdims=keepdims)) - return _impl - - -def _std(): - def _impl(inputs, input_types): + @staticmethod + def _std(inputs, input_types): data = inputs[0] if len(inputs) == 2: axis = None @@ -1574,11 +1352,8 @@ def _impl(inputs, input_types): return _op.reduce.std(data, axis=axis, keepdims=keepdims, unbiased=unbiased) - return _impl - - -def _variance(): - def _impl(inputs, input_types): + @staticmethod + def _variance(inputs, input_types): data = inputs[0] if len(inputs) == 2: axis = None @@ -1591,11 +1366,7 @@ def _impl(inputs, input_types): return _op.reduce.variance(data, axis=axis, keepdims=keepdims, unbiased=unbiased) - return _impl - - -def _mean(prelude): - def _impl(inputs, input_types): + def _mean(self, inputs, input_types): data = inputs[0] if inputs[1]: @@ -1615,7 +1386,7 @@ def _impl(inputs, input_types): def func(x): return _op.mean(x, axis, keepdims, exclude) - if _is_quantized_tensor(data, prelude): + if _is_quantized_tensor(data, self.prelude): assert len(inputs) == 6, "Input quant param not found in op inputs" input_scale = _expr.const(inputs[4]) input_zero_point = _expr.const(inputs[5]) @@ -1623,18 +1394,14 @@ def func(x): return func(data) - return _impl - - -def _chunk(prelude): - def _impl(inputs, input_types): + def _chunk(self, inputs, input_types): data = inputs[0] num_chunks = int(inputs[1]) axis = int(inputs[2]) if isinstance(data, _expr.Expr): - inferred_shape = _infer_shape(data, prelude.mod) + inferred_shape = _infer_shape(data, self.prelude.mod) shape = [] for infer in inferred_shape: @@ -1670,18 +1437,14 @@ def _impl(inputs, input_types): return chunks - return _impl - - -def _matmul(prelude): - def _impl(inputs, input_types): + def _matmul(self, inputs, input_types): inputs_0 = inputs[0] inputs_1 = inputs[1] # Need to check input shape as batch matmul must be supported. - a_shape = _infer_shape(inputs_0, prelude.mod) - b_shape = _infer_shape(inputs_1, prelude.mod) + a_shape = _infer_shape(inputs_0, self.prelude.mod) + b_shape = _infer_shape(inputs_1, self.prelude.mod) # When performing a batch matmul, we need to properly handle N-dim shapes. if len(a_shape) > 2 or len(b_shape) > 2: @@ -1689,8 +1452,8 @@ def _impl(inputs, input_types): a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]]) b = _op.reshape(inputs_1, [-1, b_shape[-2], b_shape[-1]]) # Broadcast b to match batch size of a - new_b_shape = list(_infer_shape(b, prelude.mod)) - new_a_shape = _infer_shape(a, prelude.mod) + new_b_shape = list(_infer_shape(b, self.prelude.mod)) + new_a_shape = _infer_shape(a, self.prelude.mod) if new_a_shape[0] > new_b_shape[0]: new_b_shape[0] = new_a_shape[0] b = _op.broadcast_to(b, new_b_shape) @@ -1714,11 +1477,8 @@ def _impl(inputs, input_types): return out - return _impl - - -def _expand(): - def _impl(inputs, input_types): + @staticmethod + def _expand(inputs, input_types): data_in = inputs[0] shape = list(_infer_shape(data_in)) @@ -1740,34 +1500,22 @@ def _impl(inputs, input_types): return out - return _impl - - -def _int(): - def _impl(inputs, input_types): + @staticmethod + def _int(inputs, input_types): if isinstance(inputs[0], _expr.Expr): return inputs[0] return int(inputs[0]) - return _impl - - -def _identity(): - def _impl(inputs, input_types): + @staticmethod + def _identity(inputs, input_types): return inputs[0] - return _impl - - -def _none(): - def _impl(inputs, input_types): + @staticmethod + def _none(inputs, input_types): return None - return _impl - - -def _pad(mode): - def _impl(inputs, input_types): + @staticmethod + def _pad(mode, inputs, input_types): data = inputs[0] if isinstance(inputs[1], list): pad_list = inputs[1] @@ -1804,21 +1552,15 @@ def _impl(inputs, input_types): else: return _op.nn.pad(data, const_paddings, pad_mode=mode) - return _impl - - -def _clamp(): - def _impl(inputs, input_types): + @staticmethod + def _clamp(inputs, input_types): data = inputs[0] amin = inputs[1] if inputs[1] else np.finfo(np.float32).min amax = inputs[2] if inputs[2] else np.finfo(np.float32).max return _op.clip(data, amin, amax) - return _impl - - -def _to(): - def _impl(inputs, input_types): + @staticmethod + def _to(inputs, input_types): data = inputs[0] dtype = inputs[1] if inputs[1] is not None and not isinstance(inputs[1], str) else inputs[2] # special handling for aten::to(data, 6, _, _, _) case @@ -1844,34 +1586,30 @@ def _impl(inputs, input_types): return ret - return _impl - - -def _get_upsample_out_size(inputs, method): - # This assumes a static shape - out_size = [] - if inputs[1] is not None: - for size in inputs[1]: - if not isinstance(size, int): - out_size.append(int(_infer_value(size, {}).asnumpy())) - else: - out_size.append(size) - else: - scale_index = 3 if method in ["bilinear", "trilinear"] else 2 - scales = inputs[scale_index] - assert scales is not None, "neither out size nor scale provided" - assert isinstance(scales, list) - ishape = _infer_shape(inputs[0]) - for i, scale in enumerate(scales): - out_size.append(int(math.floor(float(ishape[2 + i]) * scale))) - - return out_size + @staticmethod + def _get_upsample_out_size(inputs, method): + # This assumes a static shape + out_size = [] + if inputs[1] is not None: + for size in inputs[1]: + if not isinstance(size, int): + out_size.append(int(_infer_value(size, {}).asnumpy())) + else: + out_size.append(size) + else: + scale_index = 3 if method in ["bilinear", "trilinear"] else 2 + scales = inputs[scale_index] + assert scales is not None, "neither out size nor scale provided" + assert isinstance(scales, list) + ishape = _infer_shape(inputs[0]) + for i, scale in enumerate(scales): + out_size.append(int(math.floor(float(ishape[2 + i]) * scale))) + return out_size -def _upsample(method, prelude): - def _impl(inputs, input_types): + def _upsample(self, method, inputs, input_types): data = inputs[0] - out_size = _get_upsample_out_size(inputs, method) + out_size = self._get_upsample_out_size(inputs, method) if len(inputs) > 2 and method == "bilinear": align_corners = inputs[2] @@ -1888,7 +1626,7 @@ def _impl(inputs, input_types): def func(x): return _op.image.resize(x, out_size, "NCHW", method, coord_trans) - if _is_quantized_tensor(data, prelude): + if _is_quantized_tensor(data, self.prelude): # input qparams are manually appended by us assert isinstance(inputs[-2], float) assert isinstance(inputs[-1], int) @@ -1898,13 +1636,9 @@ def func(x): return func(data) - return _impl - - -def _upsample3d(method): - def _impl(inputs, input_types): + def _upsample3d(self, method, inputs, input_types): data = inputs[0] - out_size = _get_upsample_out_size(inputs, method) + out_size = self._get_upsample_out_size(inputs, method) if len(inputs) > 2 and method == "trilinear": align_corners = inputs[2] @@ -1920,11 +1654,8 @@ def _impl(inputs, input_types): return _op.image.resize3d(data, out_size, "NCDHW", method, coord_trans) - return _impl - - -def _expand_as(): - def _impl(inputs, input_types): + @staticmethod + def _expand_as(inputs, input_types): target = inputs[1] t0 = _infer_type(inputs[0]).checked_type.dtype t1 = _infer_type(inputs[1]).checked_type.dtype @@ -1932,34 +1663,22 @@ def _impl(inputs, input_types): target = _op.cast(target, t0) return _op.broadcast_to_like(inputs[0], target) - return _impl - - -def _Bool(): - def _impl(inputs, input_types): + @staticmethod + def _Bool(inputs, input_types): assert len(inputs) == 1 return inputs[0] - return _impl - - -def _Float(): - def _impl(inputs, input_types): + @staticmethod + def _Float(inputs, input_types): assert len(inputs) == 1 return _op.cast(inputs[0], "float32") - return _impl - - -def _mm(): - def _impl(inputs, input_types): + @staticmethod + def _mm(inputs, input_types): return _op.nn.dense(inputs[0], inputs[1]) - return _impl - - -def _bitwise_not(): - def _impl(inputs, input_types): + @staticmethod + def _bitwise_not(inputs, input_types): data = inputs[0] # The input tensor must be of integral or Boolean types. # For bool tensors, it computes the logical NOT @@ -1970,11 +1689,8 @@ def _impl(inputs, input_types): return out - return _impl - - -def _bitwise_xor(): - def _impl(inputs, input_types): + @staticmethod + def _bitwise_xor(inputs, input_types): lhs = inputs[0] rhs = inputs[1] lhs = _op.cast(lhs, "bool") if input_types[0] == "bool" else _op.cast(lhs, "int") @@ -1982,91 +1698,59 @@ def _impl(inputs, input_types): return _op.bitwise_xor(lhs, rhs) - return _impl - - -def _logical_not(): - def _impl(inputs, input_types): + @staticmethod + def _logical_not(inputs, input_types): data = _wrap_const(inputs[0]) return _op.logical_not(_op.cast(data, "bool")) - return _impl - - -def _logical_xor(): - def _impl(inputs, input_types): + @staticmethod + def _logical_xor(inputs, input_types): lhs = _op.cast(inputs[0], "bool") rhs = _op.cast(inputs[1], "bool") return _op.logical_xor(lhs, rhs) - return _impl - - -def _list_getitem(prelude): - def _impl(inputs, input_types): - return prelude.nth(inputs[0], _wrap_const(inputs[1])) - - return _impl - - -def _list_len(prelude): - def _impl(inputs, input_types): - return prelude.length(inputs[0]) + def _list_getitem(self, inputs, input_types): + return self.prelude.nth(inputs[0], _wrap_const(inputs[1])) - return _impl + def _list_len(self, inputs, input_types): + return self.prelude.length(inputs[0]) - -def _type_as(): - def _impl(inputs, input_types): + @staticmethod + def _type_as(inputs, input_types): assert len(inputs) == 2 assert len(input_types) == 2 return _op.cast(inputs[0], input_types[1]) - return _impl - - -def _gather(): - def _impl(inputs, input_types): + @staticmethod + def _gather(inputs, input_types): data = inputs[0] axis = inputs[1] indices = inputs[2] return _op.gather(data, axis, indices) - return _impl - - -def _add(prelude): - # add_ is overloaded for tensor add and list concat - def _impl(inputs, input_types): + def _add(self, inputs, input_types): + # add_ is overloaded for tensor add and list concat if input_types[0] == "ListType": - return prelude.concat(inputs[0], inputs[1]) - return _elemwise("add")(inputs, input_types) + return self.prelude.concat(inputs[0], inputs[1]) + return self._elemwise("add", inputs, input_types) - return _impl - - -def _tensor_array_stack(prelude): - def _impl(inputs, input_types): + def _tensor_array_stack(self, inputs, input_types): dim = inputs[1] assert dim == 0, "stacking on a dynamic tensor list only supported on a first axis" - tensor_array, shape = _convert_to_tensor_array(inputs[0], prelude) + tensor_array, shape = _convert_to_tensor_array(inputs[0], self.prelude) stacked_shape = (Any(),) + shape - stack = prelude.get_global_var_static("tensor_array_stack", "float32", shape) + stack = self.prelude.get_global_var_static("tensor_array_stack", "float32", shape) stacked = stack(tensor_array) - static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", stacked_shape) + static_tensor_array_ops = StaticTensorArrayOps(self.prelude, "float32", stacked_shape) static_tensor_array_ops.register() - get_tensor = prelude.get_global_var_static("tensor_get_data", "float32", stacked_shape) + get_tensor = self.prelude.get_global_var_static("tensor_get_data", "float32", stacked_shape) return get_tensor(stacked) - return _impl - - -def _stack(prelude): - def _impl(inputs, input_types): + def _stack(self, inputs, input_types): if isinstance(inputs[0], list): # a static python list of tensors dim = inputs[1] @@ -2074,17 +1758,14 @@ def _impl(inputs, input_types): else: # List ADT case assert isinstance(inputs[0], _expr.Expr) - ty = _infer_type_with_prelude(inputs[0], prelude) - list_ty = prelude.mod.get_global_type_var("List") + ty = _infer_type_with_prelude(inputs[0], self.prelude) + list_ty = self.prelude.mod.get_global_type_var("List") msg = "The input list is expected to be List ADT" assert isinstance(ty, tvm.ir.TypeCall) and ty.func == list_ty, msg - return _tensor_array_stack(prelude)(inputs, input_types) + return self._tensor_array_stack(inputs, input_types) - return _impl - - -def _rsub(): - def _impl(inputs, input_types): + @staticmethod + def _rsub(inputs, input_types): data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2]) # TODO (t-vi): should this also be part of the type promotion? @@ -2093,21 +1774,15 @@ def _impl(inputs, input_types): # note: rsub means data0 and data1 swap places return get_relay_op("subtract")(data1, alpha * data0) - return _impl - - -def _embedding(): - def _impl(inputs, input_types): + @staticmethod + def _embedding(inputs, input_types): weight = inputs[0] indices = inputs[1] return _op.take(weight, indices.astype("int32"), axis=0) - return _impl - - -def _one_hot(): - def _impl(inputs, input_types): + @staticmethod + def _one_hot(inputs, input_types): indices = inputs[0].astype("int32") num_classes = inputs[1] if num_classes == -1: @@ -2120,28 +1795,18 @@ def _impl(inputs, input_types): return _op.one_hot(indices, on_value, off_value, num_classes, -1, dtype) - return _impl - - -def _index(): - def _impl(inputs, input_types): + @staticmethod + def _index(inputs, input_types): data = inputs[0] indices = inputs[1] return _op.adv_index([data] + indices) - return _impl - - -def _meshgrid(): - def _impl(inputs, input_types): + @staticmethod + def _meshgrid(inputs, input_types): data = inputs[0] return _op.meshgrid(data, indexing="ij") - return _impl - - -def _nms(prelude): - def _impl(inputs, input_types): + def _nms(self, inputs, input_types): boxes = inputs[0] scores = inputs[1] iou_threshold = inputs[2] @@ -2187,11 +1852,8 @@ def _impl(inputs, input_types): # in torchvision, indices from nms are int64 return _op.cast(ret, "int64") - return _impl - - -def _logsumexp(): - def _impl(inputs, input_types): + @staticmethod + def _logsumexp(inputs, input_types): data = _pytorch_promote_types(inputs[:1], input_types[:1]) dim_list = inputs[1] keepdim = inputs[2] if len(inputs) > 2 else False @@ -2199,11 +1861,7 @@ def _impl(inputs, input_types): assert isinstance(dim_list, list), "dim is expected to be a list" return _op.logsumexp(data[0], axis=dim_list, keepdims=keepdim) - return _impl - - -def _roi_align(prelude): - def _impl(inputs, input_types): + def _roi_align(self, inputs, input_types): data = inputs[0] boxes = inputs[1] @@ -2217,16 +1875,13 @@ def _impl(inputs, input_types): return _op.vision.roi_align(data, boxes, output_size, spatial_scale, sample_ratio) - return _impl - - -def _unbind(): - def _impl(inputs, input_types): + @staticmethod + def _unbind(inputs, input_types): data = inputs[0] dim = int(inputs[1]) ishapes = _infer_shape(data) if dim >= len(ishapes): - msg = "Please check input dim, it shouldn't" "be greater than or equal to rank." + msg = "Please check input dim, it shouldn't be greater than or equal to rank." raise AttributeError(msg) selections = ishapes[dim] @@ -2239,13 +1894,9 @@ def _impl(inputs, input_types): ret = _expr.TupleWrapper(_expr.Tuple(ret), selections) return ret - return _impl - - -def _shape_as_tensor(prelude): - def _impl(inputs, input_types): + def _shape_as_tensor(self, inputs, input_types): is_symbolic_shape = False - input_shape = _infer_shape(inputs[0], prelude.mod) + input_shape = _infer_shape(inputs[0], self.prelude.mod) for axis in input_shape: if not isinstance(axis, (int, tvm.tir.IntImm)): is_symbolic_shape = True @@ -2258,45 +1909,30 @@ def _impl(inputs, input_types): return ret - return _impl - - -def _logical_and(): - def _impl(inputs, input_types): + @staticmethod + def _logical_and(inputs, input_types): lhs = _op.cast(inputs[0], "bool") rhs = _op.cast(inputs[1], "bool") return _op.logical_and(lhs, rhs) - return _impl - - -def _nonzero(is_numpy_style): - def _impl(inputs, input_types): + def _nonzero(self, inputs, input_types, is_numpy_style=False): data = inputs[0] ret = _op.transform.argwhere(data) - if is_numpy_style or (len(inputs) > 1 and inputs[1]): - return _unbind()([ret, 1], None) - + return self._unbind([ret, 1], None) return ret - return _impl - - -def _scatter(): - def _impl(inputs, input_types): + @staticmethod + def _scatter(inputs, input_types): data = inputs[0] axis = int(inputs[1]) index = inputs[2] src = inputs[3] return _op.transform.scatter(data, index, src, axis) - return _impl - - -def _scalar_tensor(): - def _impl(inputs, input_types): + @staticmethod + def _scalar_tensor(inputs, input_types): data = inputs[0] cast_map = { 6: "float32", @@ -2309,11 +1945,8 @@ def _impl(inputs, input_types): data = data.data.asnumpy().tolist() return _expr.const(data, cast_map[type_key]) - return _impl - - -def _interpolate(): - def _impl(inputs, input_types): + @staticmethod + def _interpolate(inputs, input_types): if isinstance(inputs[1], _expr.Expr): out_size = inputs[1] elif isinstance(inputs[1], list): @@ -2342,26 +1975,17 @@ def _impl(inputs, input_types): return _op.image.resize(data, out_size, "NCHW", method, coord_trans) - return _impl - - -def _numel(): - def _impl(inputs, input_types): + @staticmethod + def _numel(inputs, input_types): return _op.ndarray_size(inputs[0]) - return _impl - - -def _empty(): - def _impl(inputs, input_types): + @staticmethod + def _empty(inputs, input_types): shape = inputs[0] return _op.zeros(shape, _convert_dtype_value(inputs[1])) - return _impl - - -def _bincount(): - def _impl(inputs, input_types): + @staticmethod + def _bincount(inputs, input_types): data = inputs[0] weights = inputs[1] maximum = _op.max(data) @@ -2377,18 +2001,208 @@ def _impl(inputs, input_types): counts = _op.zeros(_op.reshape(dim, [1]), out_dtype) return _op.scatter_add(counts, data, updates, axis=0) - return _impl - - -def _scatter_add(): - def _impl(inputs, input_types): + @staticmethod + def _scatter_add(inputs, input_types): data = inputs[0] axis = inputs[1] index = inputs[2] src = inputs[3] return _op.scatter_add(data, index, src, axis=axis) - return _impl + # Operator mappings + def _get_convert_map(self): + convert_map = { + "aten::pixel_shuffle": self._pixel_shuffle, + "aten::device": self._none, + "prim::device": self._none, + "aten::sub": functools.partial(self._elemwise, "subtract"), + "aten::sub_": functools.partial(self._elemwise, "subtract"), + "aten::max": self._max, + "aten::min": self._min, + "aten::mul": functools.partial(self._elemwise, "multiply"), + "aten::mul_": functools.partial(self._elemwise, "multiply"), + "aten::pow": functools.partial(self._elemwise, "power"), + "aten::arange": self._arange, + "aten::meshgrid": self._meshgrid, + "aten::div": functools.partial(self._elemwise, "divide"), + "aten::div_": functools.partial(self._elemwise, "divide"), + "aten::floor_divide": functools.partial(self._elemwise, "floor_divide"), + "aten::true_divide": functools.partial(self._elemwise, "divide"), + "aten::addcdiv": self._addcdiv, + "aten::addcmul": self._addcmul, + "aten::ones": self._ones, + "aten::ones_like": self._ones_like, + "aten::zeros": self._zeros, + "aten::zeros_like": self._zeros_like, + "aten::full": self._full, + "aten::full_like": self._full_like, + "aten::linspace": self._linspace, + "aten::reciprocal": self._reciprocal, + "aten::repeat": self._repeat, + "aten::repeat_interleave": self._repeat_interleave, + "aten::to": self._to, + "aten::squeeze": self._squeeze, + "aten::unsqueeze": self._unsqueeze, + "aten::cat": self._concatenate, + "aten::slice": self._slice, + "aten::split": self._split, + "aten::split_with_sizes": self._split_with_sizes, + "aten::select": self._select, + "aten::take": self._take, + "aten::where": self._where, + "aten::topk": self._topk, + "aten::relu": self._relu, + "aten::relu_": self._relu, + "aten::prelu": self._prelu, + "aten::leaky_relu": self._leaky_relu, + "aten::leaky_relu_": self._leaky_relu, + "aten::elu": self._elu, + "aten::elu_": self._elu, + "aten::celu": self._celu, + "aten::gelu": self._gelu, + "aten::selu": self._selu, + "aten::log_sigmoid": self._log_sigmoid, + "aten::adaptive_avg_pool2d": self._adaptive_avg_pool_2d, + "aten::adaptive_max_pool2d": self._adaptive_max_pool_2d, + "aten::max_pool2d": self._maxpool_2d, + "aten::max_pool2d_with_indices": self._maxpool_2d_with_indices, + "aten::max_pool1d": self._maxpool_1d, + "aten::max_pool3d": self._maxpool_3d, + "aten::hardtanh": self._hardtanh, + "aten::hardtanh_": self._hardtanh, + "aten::_convolution": self._convolution, + "aten::softmax": self._softmax, + "aten::threshold": self._threshold, + "aten::threshold_": self._threshold, + "aten::contiguous": self._contiguous, + "aten::batch_norm": self._batch_norm, + "aten::instance_norm": self._instance_norm, + "aten::layer_norm": self._layer_norm, + "aten::group_norm": self._group_norm, + "aten::transpose": self._transpose, + "aten::transpose_": self._transpose, + "aten::t": self._transpose, + "aten::flatten": self._flatten, + "aten::addmm": self._addmm, + "aten::size": self._size, + "aten::view": self._view, + "aten::reshape": self._reshape, + "aten::clone": self._clone, + "aten::log_softmax": self._log_softmax, + "aten::sigmoid": self._sigmoid, + "aten::softplus": self._softplus, + "aten::avg_pool2d": self._avg_pool2d, + "aten::avg_pool3d": self._avg_pool3d, + "aten::dropout": self._dropout, + "aten::dropout_": self._dropout, + "aten::feature_dropout": self._dropout, + "aten::alpha_dropout": self._dropout, + "aten::mean": self._mean, + "aten::chunk": self._chunk, + "aten::matmul": self._matmul, + "aten::bmm": self._matmul, + "aten::expand": self._expand, + "aten::Int": self._int, + "prim::NumToTensor": self._numtotensor, + "prim::ImplicitTensorToNum": self._tensortonum, + "aten::ScalarImplicit": self._tensortonum, + "aten::constant_pad_nd": functools.partial(self._pad, "constant"), + "aten::reflection_pad1d": functools.partial(self._pad, "reflect"), + "aten::reflection_pad2d": functools.partial(self._pad, "reflect"), + "aten::replication_pad1d": functools.partial(self._pad, "edge"), + "aten::replication_pad2d": functools.partial(self._pad, "edge"), + "aten::replication_pad3d": functools.partial(self._pad, "edge"), + "aten::permute": self._transpose, + "aten::sum": functools.partial(self._reduce, "sum"), + "aten::prod": functools.partial(self._reduce, "prod"), + "aten::argmin": functools.partial(self._reduce, "argmin"), + "aten::argmax": functools.partial(self._reduce, "argmax"), + "aten::norm": self._norm, + "aten::frobenius_norm": self._frobenius_norm, + "aten::std": self._std, + "aten::var": self._variance, + "aten::abs": functools.partial(self._unary, "abs"), + "aten::neg": functools.partial(self._unary, "negative"), + "aten::cos": functools.partial(self._unary, "cos"), + "aten::cosh": functools.partial(self._unary, "cosh"), + "aten::sin": functools.partial(self._unary, "sin"), + "aten::sinh": functools.partial(self._unary, "sinh"), + "aten::tan": functools.partial(self._unary, "tan"), + "aten::tanh": functools.partial(self._unary, "tanh"), + "aten::acos": functools.partial(self._unary, "acos"), + "aten::asin": functools.partial(self._unary, "asin"), + "aten::atan": functools.partial(self._unary, "atan"), + "aten::log": functools.partial(self._unary, "log"), + "aten::log2": functools.partial(self._unary, "log2"), + "aten::log10": functools.partial(self._unary, "log10"), + "aten::log1p": self._log1p, + "aten::exp": functools.partial(self._unary, "exp"), + "aten::erf": functools.partial(self._unary, "erf"), + "aten::trunc": functools.partial(self._unary, "trunc"), + "aten::sign": functools.partial(self._unary, "sign"), + "aten::sqrt": functools.partial(self._unary, "sqrt"), + "aten::rsqrt": functools.partial(self._unary, "rsqrt"), + "aten::ceil": functools.partial(self._unary, "ceil"), + "aten::floor": functools.partial(self._unary, "floor"), + "aten::round": functools.partial(self._unary, "round"), + "aten::isfinite": functools.partial(self._unary, "isfinite"), + "aten::isinf": functools.partial(self._unary, "isinf"), + "aten::isnan": functools.partial(self._unary, "isnan"), + "aten::clamp": self._clamp, + "aten::clamp_": self._clamp, + "aten::detach": self._identity, + "aten::upsample_bilinear2d": functools.partial(self._upsample, "bilinear"), + "aten::upsample_nearest2d": functools.partial(self._upsample, "nearest_neighbor"), + "aten::upsample_trilinear3d": functools.partial(self._upsample3d, "trilinear"), + "aten::upsample_nearest3d": functools.partial(self._upsample3d, "nearest_neighbor"), + "aten::expand_as": self._expand_as, + "aten::lt": functools.partial(self._elemwise, "less"), + "aten::gt": functools.partial(self._elemwise, "greater"), + "aten::le": functools.partial(self._elemwise, "less_equal"), + "aten::ge": functools.partial(self._elemwise, "greater_equal"), + "aten::ne": functools.partial(self._elemwise, "not_equal"), + "aten::eq": functools.partial(self._elemwise, "equal"), + "aten::logical_not": self._logical_not, + "aten::logical_xor": self._logical_xor, + "aten::bitwise_not": self._bitwise_not, + "aten::bitwise_xor": self._bitwise_xor, + "aten::Bool": self._Bool, + "aten::Float": self._Float, + "aten::adaptive_avg_pool3d": self._adaptive_avg_pool_3d, + "aten::adaptive_max_pool3d": self._adaptive_max_pool_3d, + "aten::rsub": self._rsub, + "aten::embedding": self._embedding, + "aten::one_hot": self._one_hot, + "aten::mm": self._matmul, + "aten::add": self._add, + "aten::add_": self._add, + "aten::stack": self._stack, + "aten::__getitem__": self._list_getitem, + "aten::len": self._list_len, + "aten::type_as": self._type_as, + "aten::gather": self._gather, + "aten::index_select": self._select, + "aten::index": self._index, + "torchvision::nms": self._nms, + "aten::logsumexp": self._logsumexp, + "torchvision::roi_align": self._roi_align, + "aten::unbind": self._unbind, + "aten::__and__": self._logical_and, + "aten::_shape_as_tensor": self._shape_as_tensor, + "aten::nonzero": self._nonzero, + "aten::nonzero_numpy": functools.partial(self._nonzero, is_numpy_style=True), + "aten::scatter": self._scatter, + "aten::scalar_tensor": self._scalar_tensor, + "aten::__interpolate": self._interpolate, + "aten::IntImplicit": self._identity, + "aten::tensor": self._identity, # used for example in tensor(1.0) + "aten::numel": self._numel, + "aten::empty": self._empty, + "aten::bincount": self._bincount, + "aten::scatter_add": self._scatter_add, + "aten::__not__": self._logical_not, + } + return convert_map def _pytorch_result_type(dtypes, non_tensor_inputs): @@ -2544,202 +2358,6 @@ def _wrap_const(c): return c -# Operator mappings -def _get_convert_map(prelude, default_dtype): - convert_map = { - "aten::pixel_shuffle": _pixel_shuffle(prelude), - "aten::device": _none(), - "prim::device": _none(), - "aten::sub": _elemwise("subtract"), - "aten::sub_": _elemwise("subtract"), - "aten::max": _max(), - "aten::min": _min(), - "aten::mul": _elemwise("multiply"), - "aten::mul_": _elemwise("multiply"), - "aten::pow": _elemwise("power"), - "aten::arange": _arange(), - "aten::meshgrid": _meshgrid(), - "aten::div": _elemwise("divide"), - "aten::div_": _elemwise("divide"), - "aten::floor_divide": _elemwise("floor_divide"), - "aten::true_divide": _elemwise("divide"), - "aten::addcdiv": _addcdiv(), - "aten::addcmul": _addcmul(), - "aten::ones": _ones(default_dtype), - "aten::ones_like": _ones_like(default_dtype), - "aten::zeros": _zeros(default_dtype), - "aten::zeros_like": _zeros_like(default_dtype), - "aten::full": _full(default_dtype), - "aten::full_like": _full_like(default_dtype), - "aten::linspace": _linspace(), - "aten::reciprocal": _reciprocal(), - "aten::repeat": _repeat(), - "aten::repeat_interleave": _repeat_interleave(), - "aten::to": _to(), - "aten::squeeze": _squeeze(), - "aten::unsqueeze": _unsqueeze(), - "aten::cat": _concatenate(prelude), - "aten::slice": _slice(), - "aten::split": _split(), - "aten::split_with_sizes": _split_with_sizes(), - "aten::select": _select(), - "aten::take": _take(), - "aten::where": _where(), - "aten::topk": _topk(), - "aten::relu": _relu(prelude), - "aten::relu_": _relu(prelude), - "aten::prelu": _prelu(), - "aten::leaky_relu": _leaky_relu(), - "aten::leaky_relu_": _leaky_relu(), - "aten::elu": _elu(), - "aten::elu_": _elu(), - "aten::celu": _celu(), - "aten::gelu": _gelu(), - "aten::selu": _selu(), - "aten::log_sigmoid": _log_sigmoid(), - "aten::adaptive_avg_pool2d": _adaptive_avg_pool_2d(prelude), - "aten::adaptive_max_pool2d": _adaptive_max_pool_2d(), - "aten::max_pool2d": _maxpool_2d(), - "aten::max_pool2d_with_indices": _maxpool_2d_with_indices(), - "aten::max_pool1d": _maxpool_1d(), - "aten::max_pool3d": _maxpool_3d(), - "aten::hardtanh": _hardtanh(), - "aten::hardtanh_": _hardtanh(), - "aten::_convolution": _convolution(), - "aten::softmax": _softmax(), - "aten::threshold": _threshold(), - "aten::threshold_": _threshold(), - "aten::contiguous": _contiguous(), - "aten::batch_norm": _batch_norm(), - "aten::instance_norm": _instance_norm(), - "aten::layer_norm": _layer_norm(), - "aten::group_norm": _group_norm(), - "aten::transpose": _transpose(prelude), - "aten::transpose_": _transpose(prelude), - "aten::t": _transpose(prelude), - "aten::flatten": _flatten(), - "aten::addmm": _addmm(), - "aten::size": _size(prelude), - "aten::view": _view(), - "aten::reshape": _reshape(), - "aten::clone": _clone(), - "aten::log_softmax": _log_softmax(), - "aten::sigmoid": _sigmoid(), - "aten::softplus": _softplus(), - "aten::avg_pool2d": _avg_pool2d(prelude), - "aten::avg_pool3d": _avg_pool3d(), - "aten::dropout": _dropout(), - "aten::dropout_": _dropout(), - "aten::feature_dropout": _dropout(), - "aten::alpha_dropout": _dropout(), - "aten::mean": _mean(prelude), - "aten::chunk": _chunk(prelude), - "aten::matmul": _matmul(prelude), - "aten::bmm": _matmul(prelude), - "aten::expand": _expand(), - "aten::Int": _int(), - "prim::NumToTensor": _numtotensor(), - "prim::ImplicitTensorToNum": _tensortonum(), - "aten::ScalarImplicit": _tensortonum(), - "aten::constant_pad_nd": _pad("constant"), - "aten::reflection_pad1d": _pad("reflect"), - "aten::reflection_pad2d": _pad("reflect"), - "aten::replication_pad1d": _pad("edge"), - "aten::replication_pad2d": _pad("edge"), - "aten::replication_pad3d": _pad("edge"), - "aten::permute": _transpose(prelude), - "aten::sum": _reduce("sum"), - "aten::prod": _reduce("prod"), - "aten::argmin": _reduce("argmin"), - "aten::argmax": _reduce("argmax"), - "aten::norm": _norm(), - "aten::frobenius_norm": _frobenius_norm(), - "aten::std": _std(), - "aten::var": _variance(), - "aten::abs": _unary("abs"), - "aten::neg": _unary("negative"), - "aten::cos": _unary("cos"), - "aten::cosh": _unary("cosh"), - "aten::sin": _unary("sin"), - "aten::sinh": _unary("sinh"), - "aten::tan": _unary("tan"), - "aten::tanh": _unary("tanh"), - "aten::acos": _unary("acos"), - "aten::asin": _unary("asin"), - "aten::atan": _unary("atan"), - "aten::log": _unary("log"), - "aten::log2": _unary("log2"), - "aten::log10": _unary("log10"), - "aten::log1p": _log1p(), - "aten::exp": _unary("exp"), - "aten::erf": _unary("erf"), - "aten::trunc": _unary("trunc"), - "aten::sign": _unary("sign"), - "aten::sqrt": _unary("sqrt"), - "aten::rsqrt": _unary("rsqrt"), - "aten::ceil": _unary("ceil"), - "aten::floor": _unary("floor"), - "aten::round": _unary("round"), - "aten::isfinite": _unary("isfinite"), - "aten::isinf": _unary("isinf"), - "aten::isnan": _unary("isnan"), - "aten::clamp": _clamp(), - "aten::clamp_": _clamp(), - "aten::detach": _identity(), - "aten::upsample_bilinear2d": _upsample("bilinear", prelude), - "aten::upsample_nearest2d": _upsample("nearest_neighbor", prelude), - "aten::upsample_trilinear3d": _upsample3d("trilinear"), - "aten::upsample_nearest3d": _upsample3d("nearest_neighbor"), - "aten::expand_as": _expand_as(), - "aten::lt": _elemwise("less"), - "aten::gt": _elemwise("greater"), - "aten::le": _elemwise("less_equal"), - "aten::ge": _elemwise("greater_equal"), - "aten::ne": _elemwise("not_equal"), - "aten::eq": _elemwise("equal"), - "aten::logical_not": _logical_not(), - "aten::logical_xor": _logical_xor(), - "aten::bitwise_not": _bitwise_not(), - "aten::bitwise_xor": _bitwise_xor(), - "aten::Bool": _Bool(), - "aten::Float": _Float(), - "aten::adaptive_avg_pool3d": _adaptive_avg_pool_3d(), - "aten::adaptive_max_pool3d": _adaptive_max_pool_3d(), - "aten::rsub": _rsub(), - "aten::embedding": _embedding(), - "aten::one_hot": _one_hot(), - "aten::mm": _matmul(prelude), - "aten::add": _add(prelude), - "aten::add_": _add(prelude), - "aten::stack": _stack(prelude), - "aten::__getitem__": _list_getitem(prelude), - "aten::len": _list_len(prelude), - "aten::type_as": _type_as(), - "aten::gather": _gather(), - "aten::index_select": _select(), - "aten::index": _index(), - "torchvision::nms": _nms(prelude), - "aten::logsumexp": _logsumexp(), - "torchvision::roi_align": _roi_align(prelude), - "aten::unbind": _unbind(), - "aten::__and__": _logical_and(), - "aten::_shape_as_tensor": _shape_as_tensor(prelude), - "aten::nonzero": _nonzero(False), - "aten::nonzero_numpy": _nonzero(True), - "aten::scatter": _scatter(), - "aten::scalar_tensor": _scalar_tensor(), - "aten::__interpolate": _interpolate(), - "aten::IntImplicit": _identity(), - "aten::tensor": _identity(), # used for example in tensor(1.0) - "aten::numel": _numel(), - "aten::empty": _empty(), - "aten::bincount": _bincount(), - "aten::scatter_add": _scatter_add(), - "aten::__not__": _logical_not(), - } - return convert_map - - def _run_jit_passes(graph): """ The inline pass is necessary to unwrap prim::CallMethod """ # pylint: disable=c-extension-no-member @@ -3370,7 +2988,8 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt mod = tvm.IRModule() prelude = Prelude(mod) - convert_map = _get_convert_map(prelude, default_dtype) + converter = PyTorchOpConverter(prelude, default_dtype) + convert_map = converter._get_convert_map() graph = script_module.graph.copy() _run_jit_passes(graph)