Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[draft] Constant folding for StridedSlice op #12713

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions compiler/circle2circle/src/Circle2Circle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ int entry(int argc, char **argv)
add_switch(arser, "--fold_gather", "This will fold Gather operator");
add_switch(arser, "--fold_shape", "This will fold Shape operator");
add_switch(arser, "--fold_sparse_to_dense", "This will fold SparseToDense operator");
add_switch(arser, "--fold_strided_slice", "This will fold StridedSlice operator");
add_switch(arser, "--forward_reshape_to_unaryop",
"This will move Reshape after UnaryOp for centain condition");
add_switch(arser, "--forward_transpose_op",
Expand Down Expand Up @@ -272,6 +273,8 @@ int entry(int argc, char **argv)
options->enable(Algorithms::FoldShape);
if (arser.get<bool>("--fold_sparse_to_dense"))
options->enable(Algorithms::FoldSparseToDense);
if (arser.get<bool>("--fold_strided_slice"))
options->enable(Algorithms::FoldStridedSlice);
if (arser.get<bool>("--forward_reshape_to_unaryop"))
options->enable(Algorithms::ForwardReshapeToUnaryOp);
if (arser.get<bool>("--forward_transpose_op"))
Expand Down
88 changes: 88 additions & 0 deletions compiler/luci-compute/include/luci_compute/StridedSlice.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __LUCI_COMPUTE_STRIDED_SLICE_H__
#define __LUCI_COMPUTE_STRIDED_SLICE_H__

#include "Types.h"

#include <loco/IR/TensorShape.h>

namespace luci
{
namespace compute
{

/**
 * @brief Holder for the operands and parameters of one StridedSlice evaluation.
 *
 * The setters only record the given shape and raw pointer; no tensor data is
 * copied, so every buffer handed in must stay alive for as long as this
 * object is used. prepare() and compute() are defined out of line.
 *
 * @tparam T element type of the buffers passed to the setters
 */
template <typename T> class StridedSlice
{
public:
  StridedSlice() = default;

  /// Mutable access to the parameter block forwarded to the compute kernel.
  StridedSliceParams &params(void) { return _params; }

  /// Register the shape and data pointer of the tensor to be sliced.
  void input(const loco::TensorShape &shape, const T *data)
  {
    _input_data = data;
    _input_shape = shape;
  }

  /// Register the shape and data pointer of the `begin` operand.
  void begin(const loco::TensorShape &shape, const T *data)
  {
    _begin_data = data;
    _begin_shape = shape;
  }

  /// Register the shape and data pointer of the `end` operand.
  void end(const loco::TensorShape &shape, const T *data)
  {
    _end_data = data;
    _end_shape = shape;
  }

  /// Register the shape and data pointer of the `strides` operand.
  void strides(const loco::TensorShape &shape, const T *data)
  {
    _strides_data = data;
    _strides_shape = shape;
  }

  /// Register the destination buffer that receives the sliced elements.
  void output(T *data) { _output_data = data; }

  /// Shape of the result tensor as currently stored in this object.
  const loco::TensorShape &output_shape(void) const { return _output_shape; }

  bool prepare(void);
  void compute(void);

private:
  // parameter block handed over to the compute kernel
  StridedSliceParams _params = {};

  // operand shapes, followed by the matching (non-owning) data pointers
  loco::TensorShape _input_shape;
  loco::TensorShape _begin_shape;
  loco::TensorShape _end_shape;
  loco::TensorShape _strides_shape;
  const T *_input_data = nullptr;
  const T *_begin_data = nullptr;
  const T *_end_data = nullptr;
  const T *_strides_data = nullptr;

  // result shape and (non-owning) destination buffer
  loco::TensorShape _output_shape;
  T *_output_data = nullptr;
};

} // namespace compute
} // namespace luci

#endif // __LUCI_COMPUTE_STRIDED_SLICE_H__
18 changes: 18 additions & 0 deletions compiler/luci-compute/include/luci_compute/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,24 @@ struct FullyConnectedParams
FullyConnectedWeightsFormat weights_format;
};

// from tflite as-is
//
// Parameter block for the StridedSlice kernel. StridedSlice<T>::prepare() and
// StridedSlice<T>::compute() copy these fields into a
// tflite::StridedSliceParams, and compute() static_asserts that both structs
// have the same size, so do not add, remove or reorder fields without
// checking the tflite counterpart.
struct StridedSliceParams
{
// Element count and per-axis values for begin; the count is forwarded as-is
// and the first input-rank array entries are copied.
int8_t start_indices_count;
int32_t start_indices[5];
// Same as above, for end.
int8_t stop_indices_count;
int32_t stop_indices[5];
// Same as above, for strides.
int8_t strides_count;
int32_t strides[5];

// Bit masks, forwarded to the tflite parameter block.
uint16_t begin_mask;
// prepare() throws std::runtime_error when this is non-zero.
uint16_t ellipsis_mask;
uint16_t end_mask;
// prepare() throws std::runtime_error when this is non-zero.
uint16_t new_axis_mask;
// Axis i is left out of the inferred output shape when bit (1 << i) is set.
uint16_t shrink_axis_mask;
// NOTE(review): this field is not copied into tflite::StridedSliceParams by
// StridedSlice<T>; confirm the intended semantics before relying on it.
bool offset;
};

// from luci as-is
enum class FusedActFunc
{
Expand Down
132 changes: 132 additions & 0 deletions compiler/luci-compute/src/StridedSlice.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "luci_compute/Types.h"
#include "luci_compute/StridedSlice.h"

#include "ConvertTypes.h"
#include "ConvertValues.h"

#include <tensorflow/lite/kernels/internal/reference/strided_slice.h>

#include <cassert>
#include <cmath>
#include <cstdint>
#include <stdexcept>
#include <vector>

namespace luci
{
namespace compute
{

// Explicit instantiation definitions for the supported element types.
// prepare() and compute() are defined in this translation unit only (see
// below), so any element type that must be usable from other translation
// units has to be listed here.
template bool StridedSlice<float>::prepare(void);
template bool StridedSlice<int32_t>::prepare(void);
template void StridedSlice<float>::compute(void);
template void StridedSlice<int32_t>::compute(void);

/**
 * @brief Validate the configuration and infer the output shape.
 *
 * Builds a tflite::StridedSliceParams from the stored parameter block,
 * resolves the effective [begin, end) range of every input axis with the
 * TFLite StartForAxis/StopForAxis helpers and stores the resulting shape in
 * _output_shape. Axes selected by shrink_axis_mask are dropped from the
 * output shape; every other axis gets max(0, ceil((end - begin) / stride)).
 *
 * @return true once the output shape has been stored; unsupported
 *         configurations are reported by exception instead
 * @throws std::runtime_error if ellipsis_mask, new_axis_mask or offset is set,
 *         or if a stride of an input axis is zero
 */
template <typename T> bool StridedSlice<T>::prepare(void)
{
  assert(_begin_shape.rank() == 1);
  assert(_end_shape.rank() == 1);
  assert(_strides_shape.rank() == 1);
  assert(_input_shape.rank() <= 4);
  if (_params.ellipsis_mask != 0)
  {
    throw std::runtime_error("ellipsis_mask is not implemented yet.");
  }
  if (_params.new_axis_mask != 0)
  {
    throw std::runtime_error("new_axis_mask is not implemented yet.");
  }
  if (_params.offset)
  {
    // offset is not forwarded to the TFLite parameter block, so accepting it
    // would silently produce a result computed without it.
    throw std::runtime_error("offset is not implemented yet.");
  }

  const uint32_t rank = _input_shape.rank();

  // Value-initialize: only the first `rank` array slots are written below and
  // the TFLite helpers receive the whole struct, so nothing may be left
  // indeterminate.
  tflite::StridedSliceParams params{};

  // clang-format off
  params.start_indices_count = _params.start_indices_count;
  params.stop_indices_count  = _params.stop_indices_count;
  params.strides_count       = _params.strides_count;
  for (uint32_t i = 0; i < rank; ++i)
  {
    params.start_indices[i] = _params.start_indices[i];
    params.stop_indices[i]  = _params.stop_indices[i];
    params.strides[i]       = _params.strides[i];
  }
  params.begin_mask       = _params.begin_mask;
  params.ellipsis_mask    = 0;
  params.end_mask         = _params.end_mask;
  params.new_axis_mask    = 0;
  params.shrink_axis_mask = _params.shrink_axis_mask;
  // clang-format on

  // The converted shape does not change per axis; build it once.
  const auto input_shape = tflite_shape(_input_shape);

  std::vector<int32_t> output_dims;
  output_dims.reserve(rank);
  for (uint32_t axis = 0; axis < rank; ++axis)
  {
    // Take the stride from the same parameter block StartForAxis/StopForAxis
    // and compute() use, so the inferred shape always matches what the kernel
    // will produce (the strides data pointer has element type T and may not
    // even be set).
    const int32_t stride = params.strides[axis];
    assert(stride != 0);
    if (stride == 0)
    {
      // Release builds skip the assert; dividing by zero below would make the
      // float-to-int conversion undefined.
      throw std::runtime_error("stride must not be zero.");
    }

    const bool shrink_axis = params.shrink_axis_mask & (1 << axis);
    if (shrink_axis)
    {
      // A shrunk axis contributes no dimension to the output shape.
      continue;
    }

    const auto begin = ::tflite::strided_slice::StartForAxis(params, input_shape, axis);
    const auto end = ::tflite::strided_slice::StopForAxis(params, input_shape, axis, begin);

    const auto extent = static_cast<int32_t>(std::ceil((end - begin) / static_cast<float>(stride)));
    // An empty or inverted range yields a zero-sized dimension.
    output_dims.emplace_back(extent < 0 ? 0 : extent);
  }

  _output_shape.rank(output_dims.size());
  for (uint32_t i = 0; i < output_dims.size(); ++i)
  {
    _output_shape.dim(i) = output_dims[i];
  }

  return true;
}

/**
 * @brief Run the TFLite reference StridedSlice kernel on the registered buffers.
 *
 * Reads from the input buffer and writes to the output buffer using the
 * stored parameter block and _output_shape. _output_shape is only filled in
 * by prepare(), so prepare() has to be called before this function.
 */
template <typename T> void StridedSlice<T>::compute(void)
{
  // NOTE if this fails, structure may have changed
  static_assert(sizeof(compute::StridedSliceParams) == sizeof(tflite::StridedSliceParams));

  const uint32_t rank = _input_shape.rank();

  // Value-initialize: only the first `rank` array slots are written below,
  // but the kernel receives the whole struct, so nothing may be left
  // indeterminate.
  tflite::StridedSliceParams params{};

  // clang-format off
  params.start_indices_count = _params.start_indices_count;
  params.stop_indices_count  = _params.stop_indices_count;
  params.strides_count       = _params.strides_count;
  for (uint32_t i = 0; i < rank; ++i)
  {
    params.start_indices[i] = _params.start_indices[i];
    params.stop_indices[i]  = _params.stop_indices[i];
    params.strides[i]       = _params.strides[i];
  }
  params.begin_mask       = _params.begin_mask;
  params.ellipsis_mask    = _params.ellipsis_mask;
  params.end_mask         = _params.end_mask;
  params.new_axis_mask    = _params.new_axis_mask;
  params.shrink_axis_mask = _params.shrink_axis_mask;
  // clang-format on

  tflite::reference_ops::StridedSlice(params, tflite_shape(_input_shape), _input_data,
                                      tflite_shape(_output_shape), _output_data);
}

} // namespace compute
} // namespace luci
Loading
Loading