Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue 1677 #1686

Merged
merged 2 commits on Jun 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions source/tnn/device/cpu/acc/cpu_gathernd_layer_acc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,39 +28,40 @@ Status CpuGatherNDLayerAcc::Forward(const std::vector<Blob *> &inputs, const std
auto layer_param = dynamic_cast<GatherNDLayerParam*>(param_);
CHECK_PARAM_NULL(layer_param);
int batch_dims = layer_param->batch_dims;

if (batch_dims != 0) {
return Status(TNNERR_PARAM_ERR, "GatherNDLayerParam has invalid param batch_dims");
}

auto input_data_dims = (*(inputs.begin()))->GetBlobDesc().dims;
auto input_data_ptr = (char*)(*(inputs.begin()))->GetHandle().base + (*(inputs.begin()))->GetHandle().bytes_offset;
auto output_data_ptr = (char*)(*(outputs.begin()))->GetHandle().base + (*(outputs.begin()))->GetHandle().bytes_offset;
auto input_stride = DimsFunctionUtils::StrideOfShape(input_data_dims);

auto indices_dims = (*(inputs.rbegin()))->GetBlobDesc().dims;
int *indices_data_ptr = (int *)(*(inputs.rbegin()))->GetHandle().base;

if (indices_dims[indices_dims.size()-1] != input_data_dims.size()) {
if (indices_dims[indices_dims.size() - 1] > input_data_dims.size()) {
return Status(TNNERR_PARAM_ERR, "GatherNDLayerParam has invalid param indices_dims");
}

const int slice_index_size = indices_dims[indices_dims.size()-1];
const int ele_size = DataTypeUtils::GetBytesSize(outputs[0]->GetBlobDesc().data_type);

const int ele_count =
DimsVectorUtils::Count(input_data_dims, input_data_dims.size() - indices_dims[indices_dims.size() - 1], -1);
const int output_slice_count = DimsVectorUtils::Count(indices_dims, 0, (int)indices_dims.size()-1);
for (int i=0; i<output_slice_count; i++) {
auto output_index = i;

int *indices_ptr = indices_data_ptr + i * slice_index_size;

int input_index = 0;
for (int ii=0; ii<slice_index_size; ii++) {
input_index += indices_ptr[ii] *input_stride[ii];
}
memcpy(output_data_ptr + output_index*ele_size,
input_data_ptr + input_index*ele_size,
1 * ele_size);
ele_count * ele_size);
}
return TNN_OK;
}
Expand Down
2 changes: 1 addition & 1 deletion source/tnn/utils/split_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ Status SplitUtils::SplitStr(const char *str, str_arr &subs_array, const char spl
return TNN_OK;
}

const int subs_length = 2048;
const int subs_length = 4096;
char *subs = (char *)calloc(subs_length, sizeof(char));

for (int i = 0, cursor = 0;; i += step) {
Expand Down
12 changes: 11 additions & 1 deletion tools/convert2tnn/utils/align_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ def get_input_from_file(path: str) -> dict:
data.append(float(f.readline().strip('\n')))
np_data = np.reshape(np.array(data).astype(np.float32), dims)
input_dict.update({input_name: np_data})
elif data_type == 2:
# bool
for j in range(count):
data.append(int(f.readline().strip('\n')))
np_data = np.array(data).astype(np.bool).reshape(dims)
input_dict.update({input_name: np_data})
elif data_type == 3:
#int32
for j in range(count):
Expand Down Expand Up @@ -106,6 +112,8 @@ def run_onnx(model_path: str, input_path: str, input_info: dict) -> str:
data_type = np.int64
elif item.type == "tensor(int32)":
data_type = np.int32
elif item.type == "tensor(bool)":
data_type = np.bool
input_data_dict[item.name] = input_data_dict[item.name].astype(data_type)

output_info = session.get_outputs()
Expand Down Expand Up @@ -245,7 +253,9 @@ def get_input_shape_from_onnx(onnx_path) -> dict:
data_type = 0
if ip.type == 'tensor(float)':
data_type = 0
elif ip.type == 'tensor(int64)':
elif ip.type == 'tensor(bool)':
data_type = 2
elif ip.type == 'tensor(int64)' or ip.type == 'tensor(int32)':
data_type = 3
else:
logging.error("Do not support input date type")
Expand Down
2 changes: 1 addition & 1 deletion tools/convert2tnn/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def gene_random_data(input_info: dict) -> str:
if data_type == 0:
data[name] = np.random.rand(*shape)
np.savetxt(data_file, data[name].reshape(-1), fmt="%0.6f")
elif data_type == 3:
elif data_type == 2 or data_type == 3:
# range [low, high)
data[name] = np.random.randint(low=0, high=2, size=shape)
np.savetxt(data_file, data[name].reshape(-1), fmt="%i")
Expand Down
4 changes: 3 additions & 1 deletion tools/convert2tnn/utils/run_onnx_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ def get_src_model_input_information(self) -> dict:
data_type = 0
if ip.type == 'tensor(float)':
data_type = 0
elif ip.type == 'tensor(int64)':
elif ip.type == 'tensor(int64)' or ip.type == 'tensor(int32)':
data_type = 3
elif ip.type == 'tensor(bool)':
data_type = 2
else:
logging.error("Do not support input date type")
if type(shape[0]) is not int:
Expand Down
2 changes: 1 addition & 1 deletion tools/convert2tnn/utils/run_src_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def generate_input_data(self, input_information: dict, tnn_model_input_informati
if data_type == 0:
self.input_data[name] = np.random.rand(*shape).astype(np.float32)
np.savetxt(data_file, self.input_data[name].reshape(-1), fmt="%0.6f")
elif data_type == 3:
elif data_type == 2 or data_type == 3:
# range [low, high)
self.input_data[name] = np.random.randint(low=0, high=2, size=shape).astype(np.int64)
np.savetxt(data_file, self.input_data[name].reshape(-1), fmt="%i")
Expand Down
6 changes: 5 additions & 1 deletion tools/onnx2tnn/src/core/layer/onnx_converter_gather.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,11 @@ int OnnxOpConverterGather::WriteTNNModel(Serializer *net_writer, NodeProto &node
auto indices_iter = net_info.weights_map.find(node.input(1));
if (data_iter != net_info.weights_map.end()) {
net_writer->PutInt(1);
WriteTensorData(data_iter->second, net_writer, net_info.data_type);
DataType dst_data_type = net_info.data_type;
if (data_iter->second.data_type() == onnx::TensorProto_DataType_INT32) {
dst_data_type = DATA_TYPE_INT32;
}
WriteTensorData(data_iter->second, net_writer, dst_data_type);
} else {
net_writer->PutInt(0);
}
Expand Down
12 changes: 11 additions & 1 deletion tools/onnx2tnn/src/core/layer/onnx_converter_onehot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,23 @@
#include "onnx_op_converter.h"
#include "onnx_utility.h"

DECLARE_OP_CONVERTER(OneHot);
DECLARE_OP_CONVERTER_WITH_FUNC(OneHot, virtual std::vector<std::string> GetValidInputNames(NodeProto &node,
OnnxNetInfo &net_info););

string OnnxOpConverterOneHot::TNNOpType(NodeProto &node,
OnnxNetInfo &net_info) {
return "OneHot";
}

std::vector<std::string> OnnxOpConverterOneHot::GetValidInputNames(NodeProto &node, OnnxNetInfo &net_info) {
const int input_size = node.input_size();
std::vector<std::string> inputs(input_size);
for (int i = 0; i < input_size; i++) {
inputs[i] = node.input(i);
}
return inputs;
}

string OnnxOpConverterOneHot::TNNLayerParam(NodeProto &node,
OnnxNetInfo &net_info) {
const std::string &onnx_op = node.op_type();
Expand Down