-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
tensor.py
203 lines (159 loc) · 5.75 KB
/
tensor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tensor class for computation declaration."""
# pylint: disable=invalid-name
import tvm._ffi
from tvm.runtime import Object, ObjectGeneric, convert_to_object
from tvm.tir import expr as _expr, DataProducer
from . import _ffi_api
class TensorSlice(ObjectGeneric, _expr.ExprOp):
    """Auxiliary data structure for enabling slicing syntax on a tensor.

    Subscripting a tensor (``A[i]``) creates a ``TensorSlice``; further
    subscripts (``A[i][j]``) append to the collected indices. ``asobject``
    finally calls the underlying tensor with every collected index.
    """

    def __init__(self, tensor, indices):
        # Wrap a single (non-tuple) index so ``self.indices`` is always a tuple.
        self.tensor = tensor
        self.indices = indices if isinstance(indices, tuple) else (indices,)

    def __getitem__(self, indices):
        """Return a new slice with ``indices`` appended to the collected ones.

        Parameters
        ----------
        indices : object or tuple
            A single index or a tuple of indices to append.

        Returns
        -------
        TensorSlice
            A new slice over the same tensor; ``self`` is not modified.
        """
        extra = indices if isinstance(indices, tuple) else (indices,)
        return TensorSlice(self.tensor, self.indices + extra)

    def asobject(self):
        """Convert the slice to an object.

        Returns
        -------
        The result of ``self.tensor(*self.indices)``; for a ``Tensor`` this is
        the expression built by ``Tensor.__call__``.
        """
        return self.tensor(*self.indices)

    @property
    def dtype(self):
        """Data type of the sliced tensor (forwarded from ``self.tensor``)."""
        return self.tensor.dtype
@tvm._ffi.register_object
class TensorIntrinCall(Object):
    """Intermediate structure for calling a tensor intrinsic.

    Declares no fields or methods of its own; its attributes come from the
    ``Object`` base.
    """
@tvm._ffi.register_object
class Tensor(DataProducer, _expr.ExprOp):
    """Tensor object; to construct one, see function.Tensor.

    A tensor is the ``value_index``-th output of the operation ``op``.
    Calling it with indices (``A(i, j)``) builds a ``ProducerLoad``
    expression, while subscripting it (``A[i][j]``) builds a
    ``TensorSlice`` that collects indices.
    """

    def __call__(self, *indices):
        """Build an expression that reads this tensor at ``indices``.

        Parameters
        ----------
        *indices : PrimExpr or IterVar
            Exactly one index per dimension. An ``IterVar`` index is replaced
            by its ``var``.

        Returns
        -------
        ProducerLoad
            ``_expr.ProducerLoad(self, args)`` for the converted indices.

        Raises
        ------
        ValueError
            If the number of indices differs from ``ndim``, or if a converted
            index is neither a ``PrimExpr`` nor an ``IterVar``.
        """
        ndim = self.ndim
        if len(indices) != ndim:
            raise ValueError("Need to provide %d index in tensor slice" % ndim)
        # Normalize the indices through convert_to_object before type-checking.
        indices = convert_to_object(indices)
        args = []
        for x in indices:
            if isinstance(x, _expr.PrimExpr):
                args.append(x)
            elif isinstance(x, _expr.IterVar):
                # Index with the iteration variable underlying the IterVar.
                args.append(x.var)
            else:
                raise ValueError("The indices must be expression")
        return _expr.ProducerLoad(self, args)

    def __getitem__(self, indices):
        """Start a ``TensorSlice`` so that ``A[i][j]`` syntax can be used."""
        return TensorSlice(self, indices)

    def __hash__(self):
        """Hash the tensor via ``_ffi_api.TensorHash``.

        Defined explicitly because a class that overrides ``__eq__`` without
        ``__hash__`` would otherwise be unhashable.
        """
        return _ffi_api.TensorHash(self)

    def __eq__(self, other):
        """Compare with another tensor, or build an equality expression.

        Returns
        -------
        * ``_expr.EqualOp(self, other)`` if ``other`` is an ``ExprOp`` but
          not a ``Tensor``;
        * ``False`` if ``other`` is neither;
        * otherwise the result of ``_ffi_api.TensorEqual(self, other)``.

        Raises
        ------
        ValueError
            If both tensors are rank-0, where ``==`` is ambiguous between
            content equality and reference identity.
        """
        if not isinstance(other, Tensor):
            if isinstance(other, _expr.ExprOp):
                return _expr.EqualOp(self, other)
            return False
        if self.ndim == 0 and other.ndim == 0:
            raise ValueError(
                "Equal == comparison among rank-0 tensor is ambiguous, "
                "use Tensor.equal for content expression equivalence, "
                "use Tensor.same_as for exact reference comparison"
            )
        return _ffi_api.TensorEqual(self, other)

    @property
    def ndim(self):
        """Dimension of the tensor (the length of ``shape``)."""
        return len(self.shape)

    @property
    def axis(self):
        """Axis of the tensor."""
        return self.__getattr__("axis")

    @property
    def op(self):
        """The corresponding :py:class:`Operation`."""
        return self.__getattr__("op")

    @property
    def value_index(self):
        """The output value index the tensor corresponds to."""
        return self.__getattr__("value_index")

    @property
    def shape(self):
        """The output shape of the tensor."""
        return self.__getattr__("shape")

    @property
    def name(self):
        """Name of the tensor.

        This is the producing op's name when the op has a single output, and
        ``"<op name>.v<value_index>"`` otherwise.
        """
        op = self.op
        if op.num_outputs == 1:
            return op.name
        return "%s.v%d" % (op.name, self.value_index)
class Operation(Object):
    """Represent an operation that generates a tensor."""

    def output(self, index):
        """Get the index-th output of the operation.

        Parameters
        ----------
        index : int
            Position of the output to retrieve (presumably
            ``0 <= index < num_outputs`` -- confirm against the C++ side).

        Returns
        -------
        out : Tensor
            The index-th output tensor of this operation.
        """
        return _ffi_api.OpGetOutput(self, index)

    @property
    def num_outputs(self):
        """Number of outputs from this op."""
        return _ffi_api.OpNumOutputs(self)

    @property
    def input_tensors(self):
        """List of input tensors to this op."""
        return _ffi_api.OpInputTensors(self)
@tvm._ffi.register_object
class PlaceholderOp(Operation):
    """Placeholder operation.

    A concrete :py:class:`Operation` subclass that adds no Python-side
    members of its own.
    """
@tvm._ffi.register_object
class BaseComputeOp(Operation):
    """Compute operation.

    Common base of :py:class:`ComputeOp` and :py:class:`TensorComputeOp`;
    exposes the ``axis`` and ``reduce_axis`` fields.
    """

    @property
    def axis(self):
        """IterVar axes of the operation; defined for a ComputeOp."""
        # Use the object's field accessor directly: getattr(self, "axis")
        # would resolve back to this property.
        return self.__getattr__("axis")

    @property
    def reduce_axis(self):
        """Reduction axes of the operation; only defined for a ComputeOp."""
        return self.__getattr__("reduce_axis")
@tvm._ffi.register_object
class ComputeOp(BaseComputeOp):
    """Scalar compute operation.

    Inherits ``axis`` and ``reduce_axis`` from :py:class:`BaseComputeOp`.
    """
@tvm._ffi.register_object
class TensorComputeOp(BaseComputeOp):
    """Tensor compute operation.

    Inherits ``axis`` and ``reduce_axis`` from :py:class:`BaseComputeOp`.
    """
@tvm._ffi.register_object
class ScanOp(Operation):
    """Scan operation.

    Exposes the ``scan_axis`` field in addition to the
    :py:class:`Operation` interface.
    """

    @property
    def scan_axis(self):
        """The axis being scanned over; only defined for a ScanOp."""
        return self.__getattr__("scan_axis")
@tvm._ffi.register_object
class ExternOp(Operation):
    """External operation.

    A concrete :py:class:`Operation` subclass that adds no Python-side
    members of its own.
    """
@tvm._ffi.register_object
class HybridOp(Operation):
    """Hybrid operation.

    Exposes an ``axis`` field, like :py:class:`BaseComputeOp`, even though it
    derives directly from :py:class:`Operation`.
    """

    @property
    def axis(self):
        """IterVar axes of the operation; also defined for a HybridOp."""
        # Field accessor, not getattr(): the latter would hit this property again.
        return self.__getattr__("axis")