Commit dac71f7

Merge branch 'fix_windows_x86_build' of https://github.com/Tencent/TNN into fix_windows_x86_build

Maosquerade committed Jun 28, 2022
2 parents d9edfab + e26df3f

Showing 13 changed files with 198 additions and 58 deletions.
3 changes: 2 additions & 1 deletion source/tnn/core/default_network.cc

@@ -115,8 +115,9 @@ Status DefaultNetwork::Init(NetworkConfig &net_config, ModelConfig &model_config

     /*
      * decode dynamic quantization model for const folder.
+     * coreml can not support this optimize
      */
-    if (runtime_model_ == RUNTIME_MODE_CONST_FOLD) {
+    if (runtime_model_ == RUNTIME_MODE_CONST_FOLD && net_config.network_type != NETWORK_TYPE_COREML) {
         std::unique_lock<std::mutex> lck(optimize_mtx_);
         auto optimizer = optimizer::NetOptimizerManager::GetNetOptimizerByName("net_optimizer_dynamic_range_dequant");
         ret = optimizer->Optimize(net_structure, net_resource);
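In effect, the const-fold runtime keeps running the dynamic-range dequant pass for every backend except CoreML, which cannot consume the dequantized graph. A minimal Python sketch of the added guard (illustrative rendering only; `NETWORK_TYPE_DEFAULT` is just a stand-in for any non-CoreML network type):

```python
# Illustrative rendering of the guard added above (not TNN source).
def runs_dynamic_range_dequant(runtime_model: str, network_type: str) -> bool:
    return (runtime_model == "RUNTIME_MODE_CONST_FOLD"
            and network_type != "NETWORK_TYPE_COREML")

assert runs_dynamic_range_dequant("RUNTIME_MODE_CONST_FOLD", "NETWORK_TYPE_DEFAULT")
assert not runs_dynamic_range_dequant("RUNTIME_MODE_CONST_FOLD", "NETWORK_TYPE_COREML")
```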
17 changes: 9 additions & 8 deletions source/tnn/device/cpu/acc/cpu_gathernd_layer_acc.cc

@@ -28,39 +28,40 @@ Status CpuGatherNDLayerAcc::Forward(const std::vector<Blob *> &inputs, const std
     auto layer_param = dynamic_cast<GatherNDLayerParam*>(param_);
     CHECK_PARAM_NULL(layer_param);
     int batch_dims = layer_param->batch_dims;

     if (batch_dims != 0) {
         return Status(TNNERR_PARAM_ERR, "GatherNDLayerParam has invalid param batch_dims");
     }

     auto input_data_dims = (*(inputs.begin()))->GetBlobDesc().dims;
     auto input_data_ptr = (char*)(*(inputs.begin()))->GetHandle().base + (*(inputs.begin()))->GetHandle().bytes_offset;
     auto output_data_ptr = (char*)(*(outputs.begin()))->GetHandle().base + (*(outputs.begin()))->GetHandle().bytes_offset;
     auto input_stride = DimsFunctionUtils::StrideOfShape(input_data_dims);

     auto indices_dims = (*(inputs.rbegin()))->GetBlobDesc().dims;
     int *indices_data_ptr = (int *)(*(inputs.rbegin()))->GetHandle().base;

-    if (indices_dims[indices_dims.size()-1] != input_data_dims.size()) {
+    if (indices_dims[indices_dims.size() - 1] > input_data_dims.size()) {
         return Status(TNNERR_PARAM_ERR, "GatherNDLayerParam has invalid param indices_dims");
     }

     const int slice_index_size = indices_dims[indices_dims.size()-1];
     const int ele_size = DataTypeUtils::GetBytesSize(outputs[0]->GetBlobDesc().data_type);

+    const int ele_count =
+        DimsVectorUtils::Count(input_data_dims, input_data_dims.size() - indices_dims[indices_dims.size() - 1], -1);
     const int output_slice_count = DimsVectorUtils::Count(indices_dims, 0, (int)indices_dims.size()-1);
     for (int i=0; i<output_slice_count; i++) {
         auto output_index = i;

         int *indices_ptr = indices_data_ptr + i * slice_index_size;

         int input_index = 0;
         for (int ii=0; ii<slice_index_size; ii++) {
             input_index += indices_ptr[ii] * input_stride[ii];
         }
         memcpy(output_data_ptr + output_index*ele_size,
                input_data_ptr + input_index*ele_size,
-               1 * ele_size);
+               ele_count * ele_size);
     }
     return TNN_OK;
 }
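For context on what the fixed kernel computes: GatherND with `batch_dims == 0` and an index tuple of length k ≤ input rank gathers whole trailing slices, so each copy moves `ele_count` elements rather than a single one. A NumPy sketch of those semantics (illustrative only, not TNN code):

```python
import numpy as np

def gather_nd(data: np.ndarray, indices: np.ndarray) -> np.ndarray:
    # NumPy model of the CPU kernel above (batch_dims == 0 only).
    # indices has shape [..., k] with k <= data.ndim; each length-k tuple
    # addresses a slice of data, so one gather copies the product of the
    # trailing (data.ndim - k) dims -- the "ele_count" introduced in the fix.
    k = indices.shape[-1]
    assert k <= data.ndim, "invalid indices_dims"  # the relaxed check above
    out_shape = indices.shape[:-1] + data.shape[k:]
    slices = [data[tuple(idx)] for idx in indices.reshape(-1, k)]
    return np.stack(slices).reshape(out_shape)

data = np.arange(24).reshape(2, 3, 4)
idx = np.array([[0, 1], [1, 2]])   # k = 2 < data.ndim = 3
print(gather_nd(data, idx))        # two slices of 4 elements, not two scalars
```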
24 changes: 23 additions & 1 deletion source/tnn/device/metal/metal_blob_converter.metal

@@ -282,6 +282,17 @@ kernel void data_converter_nc4hw4_2_nchw_float_v2(
     data_converter_nc4hw4_2_nchw_v2<ftype, ftype4, float, float4>(dst, src, params, scale, bias, gid);
 }

+kernel void data_converter_nc4hw4_2_ngray_v2(
+    device uchar *dst [[buffer(0)]],
+    const device ftype4 *src [[buffer(1)]],
+    constant MetalImageConverterParams& params [[buffer(2)]],
+    const device float *scale [[buffer(3)]],
+    const device float *bias [[buffer(4)]],
+    uint3 gid [[thread_position_in_grid]])
+{
+    data_converter_nc4hw4_2_nchw_v2<ftype, ftype4, uchar, uchar4>(dst, src, params, scale, bias, gid);
+}
+
 kernel void data_converter_nc4hw4_2_nchw_half_v2(
     device half *dst [[buffer(0)]],
     const device ftype4 *src [[buffer(1)]],

@@ -428,6 +439,17 @@ kernel void data_converter_nchw_2_nc4hw4_float_v2(
     data_converter_nchw_2_nc4hw4_v2<float, float4, ftype, ftype4>(dst, src, params, scale, bias, gid);
 }

+kernel void data_converter_ngray_2_nc4hw4_float_v2(
+    device ftype4 *dst [[buffer(0)]],
+    const device uchar *src [[buffer(1)]],
+    constant MetalImageConverterParams& params [[buffer(2)]],
+    const device float *scale [[buffer(3)]],
+    const device float *bias [[buffer(4)]],
+    uint3 gid [[thread_position_in_grid]])
+{
+    data_converter_nchw_2_nc4hw4_v2<uchar, uchar4, ftype, ftype4>(dst, src, params, scale, bias, gid);
+}
+
 kernel void data_converter_nchw_2_nc4hw4_half_v2(
     device ftype4 *dst [[buffer(0)]],
     const device half *src [[buffer(1)]],

@@ -532,4 +554,4 @@ kernel void data_converter_nchw_int(device int *dst [[buffer(0)]],
                                     uint3 gid [[thread_position_in_grid]])
 {
     data_converter_nchw_copy_type<int, int>(dst, src, params, gid);
-}
\ No newline at end of file
+}
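NGRAY is a single-channel 8-bit grayscale Mat, so both new kernels simply reuse the existing nchw↔nc4hw4 templates with `uchar` as the CPU-side element type. A hypothetical NumPy model of the mat-to-blob direction, assuming the usual NC4HW4 padding convention (the lone channel occupies lane 0 of a 4-wide group) and `y = x * scale + bias` converter semantics:

```python
import numpy as np

def ngray_to_nc4hw4(mat: np.ndarray, scale: float, bias: float) -> np.ndarray:
    # Hypothetical model of what data_converter_ngray_2_nc4hw4_float_v2 produces:
    # a uint8 grayscale Mat (N, 1, H, W) mapped into a float NC4HW4 blob, with
    # scale/bias applied on the way in. Layout assumption: N, ceil(C/4), H, W, 4.
    n, c, h, w = mat.shape
    assert c == 1, "NGRAY carries exactly one channel"
    out = np.zeros((n, 1, h, w, 4), dtype=np.float32)
    out[:, 0, :, :, 0] = mat[:, 0].astype(np.float32) * scale + bias
    return out

gray = np.random.randint(0, 256, size=(1, 1, 4, 4), dtype=np.uint8)
blob = ngray_to_nc4hw4(gray, scale=1.0 / 255.0, bias=0.0)
print(blob.shape)  # (1, 1, 4, 4, 4)
```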
26 changes: 20 additions & 6 deletions source/tnn/device/metal/metal_blob_converter.mm

@@ -63,7 +63,7 @@ Status AllocateComputePipeline(MatConvertParam param, Mat *mat, Blob *blob, bool
     bool device_supported = (device_type == DEVICE_METAL || device_type == DEVICE_ARM ||
                              device_type == DEVICE_X86 || device_type == DEVICE_NAIVE);

-    bool mat_supported = (mat_type == N8UC4 || mat_type == N8UC3 ||
+    bool mat_supported = (mat_type == N8UC4 || mat_type == N8UC3 || mat_type == NGRAY ||
                           mat_type == NCHW_FLOAT || mat_type == RESERVED_BFP16_TEST || mat_type == NC_INT32);

     return device_supported && mat_supported;

@@ -93,7 +93,7 @@ Status AllocateComputePipeline(MatConvertParam param, Mat *mat, Blob *blob, bool
         bias_texture_buffer = is_mat_to_blob ? 1.0 : 1.0 / 255.0f;
     }

-    if (mat_type == NCHW_FLOAT || mat_type == RESERVED_BFP16_TEST || mat_type == NC_INT32) {
+    if (mat_type == NCHW_FLOAT || mat_type == NGRAY || mat_type == RESERVED_BFP16_TEST || mat_type == NC_INT32) {
         // scale and bias should at least have channel elements, so we use another buffer instead of metal_param
         if (param.scale.size() < metal_param.channel || param.bias.size() < metal_param.channel) {
             // invalid scale and bias

@@ -213,6 +213,18 @@ Status AllocateComputePipeline(MatConvertParam param, Mat *mat, Blob *blob, bool
                 LOGD("image_converter_buffer_nc4hw4_2_buffer_bgr\n");
             }
         }
+    } else if (mat_type == NGRAY) {
+        if (is_mat_to_blob) {
+            if (blob_data_format == DATA_FORMAT_NC4HW4) {
+                func_name = @"data_converter_ngray_2_nc4hw4_float_v2";
+                LOGD("data_converter_ngray_2_nc4hw4_float_v2\n");
+            }
+        } else {
+            if (blob_data_format == DATA_FORMAT_NC4HW4) {
+                func_name = @"data_converter_nc4hw4_2_ngray_v2";
+                LOGD("data_converter_nc4hw4_2_ngray_v2\n");
+            }
+        }
     } else if (mat_type == NCHW_FLOAT) {
         if (is_mat_to_blob) {
             if (blob_data_format == DATA_FORMAT_NCHW) {

@@ -415,12 +427,13 @@ Status AllocateComputePipeline(MatConvertParam param, Mat *mat, Blob *blob, bool

         [command_buffer waitUntilCompleted];
         memcpy(output_mat.GetData(), output_mtl_buffer.contents, count * bytes_size);
-    } else if (mat_type == NCHW_FLOAT || mat_type == RESERVED_BFP16_TEST) {
+    } else if (mat_type == NGRAY || mat_type == NCHW_FLOAT || mat_type == RESERVED_BFP16_TEST) {
         auto input_buffer_blob = dynamic_cast<Blob *>(input_blob);
         id<MTLBuffer> output_mtl_buffer = nil;

         int count = DimsVectorUtils::Count(dims);
-        const auto bytes_size = (mat_type == NCHW_FLOAT) ? sizeof(float) : sizeof(fp16_t);
+        const auto bytes_size = (mat_type == NCHW_FLOAT) ? sizeof(float) : ((mat_type == NGRAY) ? sizeof(unsigned char) : sizeof(fp16_t));
+
         if (output_mat_device == DEVICE_METAL) {
             output_mtl_buffer = (__bridge id<MTLBuffer>)(output_mat.GetData());
         } else if (output_mat_device == DEVICE_ARM || output_mat_device == DEVICE_NAIVE || mat_device_type == DEVICE_X86) {

@@ -710,11 +723,12 @@ Status AllocateComputePipeline(MatConvertParam param, Mat *mat, Blob *blob, bool
         [command_buffer waitUntilScheduled];
     }
     return TNN_OK;
-    } else if (mat_type == NCHW_FLOAT || mat_type == RESERVED_BFP16_TEST) {
+    } else if (mat_type == NGRAY || mat_type == NCHW_FLOAT || mat_type == RESERVED_BFP16_TEST) {
         // For Buffer input

         id<MTLBuffer> input_buffer = nil;
-        const auto bytes_size = (mat_type == NCHW_FLOAT) ? sizeof(float) : sizeof(fp16_t);
+        const auto bytes_size = (mat_type == NCHW_FLOAT) ? sizeof(float) : ((mat_type == NGRAY) ? sizeof(unsigned char) : sizeof(fp16_t));
+
         if (mat_device_type == DEVICE_METAL) {
             input_buffer = (__bridge id<MTLBuffer>)(input_mat.GetData());
         } else if (mat_device_type == DEVICE_NAIVE || mat_device_type == DEVICE_ARM || mat_device_type == DEVICE_X86) {
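On the Objective-C++ side the change is mostly routing: NGRAY joins the buffer-backed mat types, resolves to the two new kernels for NC4HW4 blobs, and contributes a one-byte element size to the `bytes_size` ternaries. A condensed, hypothetical Python rendering of that dispatch:

```python
# Condensed, hypothetical rendering of the dispatch logic added above.
KERNELS = {
    ("NGRAY", True,  "DATA_FORMAT_NC4HW4"): "data_converter_ngray_2_nc4hw4_float_v2",
    ("NGRAY", False, "DATA_FORMAT_NC4HW4"): "data_converter_nc4hw4_2_ngray_v2",
}

def element_bytes(mat_type: str) -> int:
    # Mirrors the ternary in the hunks: float mats use 4 bytes, NGRAY 1, else fp16.
    return 4 if mat_type == "NCHW_FLOAT" else (1 if mat_type == "NGRAY" else 2)

print(KERNELS[("NGRAY", True, "DATA_FORMAT_NC4HW4")])
print(element_bytes("NGRAY"))  # 1 == sizeof(unsigned char)
```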
103 changes: 81 additions & 22 deletions source/tnn/optimizer/net_optimizer_dynamic_range_dequant.cc

@@ -115,13 +115,21 @@ namespace optimizer {
        auto buffer_name = layer->inputs.at(idx);
        auto scale_name = buffer_name + DynamicRangeQuantScaleSuffix;
        auto buffer = resource->constant_map[buffer_name];
-       auto scale = resource->constant_map[scale_name];
        if (buffer->GetDataType() != DATA_TYPE_INT8) {
-           LOGD("dynamic range dequantize layer(%s) weight data type is not int8_t."
-               "This weight might have been dequantized before.\n", layer->name.c_str());
+           LOGD(
+               "dynamic range dequantize layer(%s) weight data type is not int8_t."
+               "This weight might have been dequantized before.\n",
+               layer->name.c_str());
            return TNN_OK;
        }

+       if (resource->constant_map.count(scale_name) == 0) {
+           LOGE("scale is not found in constant map, its name is %s\n", scale_name.c_str());
+           return Status(TNNERR_PARAM_ERR, "scale is not found in constant map");
+       }
+
+       auto scale = resource->constant_map[scale_name];
+
        const int data_size = buffer->GetDataCount();
        auto weight_ptr = buffer->force_to<int8_t *>();
        auto scale_value = scale->force_to<float *>()[0];

@@ -150,29 +158,80 @@ namespace optimizer {
    Status NetOptimizerDynamicRangeDequant::DequantMatMul(std::shared_ptr<LayerInfo> &layer, NetStructure *structure,
                                                          NetResource *resource) {
        auto layer_name = layer->name;
+       auto matmul_param = std::dynamic_pointer_cast<MatMulLayerParam>(layer->param);
        auto matmul_resource = std::dynamic_pointer_cast<MatMulLayerResource>(resource->resource_map[layer_name]);
-       auto scale_handle = matmul_resource->scale_handle;
-       if (matmul_resource->weight.GetDataType() != DATA_TYPE_INT8) {
-           LOGD("Dynamic range dequantize layer(%s) weight data type is not int8_t."
-               "This weight might have been dequantized before.\n", layer_name.c_str());
-           return TNN_OK;
-       }
+       if (matmul_param->weight_position == 1) {
+           auto scale_handle = matmul_resource->scale_handle;
+           if (matmul_resource->weight.GetDataType() != DATA_TYPE_INT8) {
+               LOGD(
+                   "Dynamic range dequantize layer(%s) weight data type is not int8_t."
+                   "This weight might have been dequantized before.\n",
+                   layer_name.c_str());
+               return TNN_OK;
+           }

-       const int data_size = matmul_resource->weight.GetDataCount();
-       auto weight_ptr = matmul_resource->weight.force_to<int8_t *>();
-       auto scale_value = scale_handle.force_to<float *>()[0];
-       std::vector<float> weight_data(data_size, 0);
-       for (int i = 0; i < data_size; i++) {
-           weight_data[i] = scale_value * (float)(weight_ptr[i]);
-       }
+           const int data_size = matmul_resource->weight.GetDataCount();
+           auto weight_ptr = matmul_resource->weight.force_to<int8_t *>();
+           auto scale_value = scale_handle.force_to<float *>()[0];
+           std::vector<float> weight_data(data_size, 0);
+           for (int i = 0; i < data_size; i++) {
+               weight_data[i] = scale_value * (float)(weight_ptr[i]);
+           }

-       RawBuffer weight_buf(data_size * sizeof(float));
-       memcpy(weight_buf.force_to<float *>(), weight_data.data(), data_size * sizeof(float));
-       weight_buf.SetDataType(DATA_TYPE_FLOAT);
-       weight_buf.SetBufferDims(matmul_resource->weight.GetBufferDims());
+           RawBuffer weight_buf(data_size * sizeof(float));
+           memcpy(weight_buf.force_to<float *>(), weight_data.data(), data_size * sizeof(float));
+           weight_buf.SetDataType(DATA_TYPE_FLOAT);
+           weight_buf.SetBufferDims(matmul_resource->weight.GetBufferDims());

-       matmul_resource->weight = weight_buf;
-       layer->param->dynamic_range_quantized = false;
+           matmul_resource->weight = weight_buf;
+           layer->param->dynamic_range_quantized = false;
+       } else if (matmul_param->weight_position == -1) {
+           auto input0_iter = resource->constant_map.find(layer->inputs[0]);
+           auto input1_iter = resource->constant_map.find(layer->inputs[1]);
+           if (input0_iter == resource->constant_map.end() && input1_iter == resource->constant_map.end()) {
+               return TNN_OK;
+           }
+
+           auto buffer_name = input0_iter != resource->constant_map.end() ? layer->inputs[0] : layer->inputs[1];
+           auto scale_name = buffer_name + DynamicRangeQuantScaleSuffix;
+           auto buffer = resource->constant_map[buffer_name];
+           if (buffer->GetDataType() != DATA_TYPE_INT8) {
+               LOGD(
+                   "dynamic range dequantize layer(%s) weight data type is not int8_t."
+                   "This weight might have been dequantized before.\n",
+                   layer->name.c_str());
+               return TNN_OK;
+           }
+           if (resource->constant_map.count(scale_name) == 0) {
+               LOGE("scale is not found in constant map, its name is %s\n", scale_name.c_str());
+               return Status(TNNERR_PARAM_ERR, "scale is not found in constant map");
+           }
+
+           auto scale = resource->constant_map[scale_name];
+
+           const int data_size = buffer->GetDataCount();
+           auto weight_ptr = buffer->force_to<int8_t *>();
+           auto scale_value = scale->force_to<float *>()[0];
+           std::vector<float> weight_data(data_size, 0);
+           for (int i = 0; i < data_size; i++) {
+               weight_data[i] = scale_value * (float)(weight_ptr[i]);
+           }
+
+           auto weight_buf = std::make_shared<RawBuffer>(data_size * sizeof(float));
+           memcpy(weight_buf->force_to<float *>(), weight_data.data(), data_size * sizeof(float));
+           weight_buf->SetDataType(DATA_TYPE_FLOAT);
+           weight_buf->SetBufferDims(buffer->GetBufferDims());
+
+           resource->constant_map[buffer_name] = weight_buf;
+
+           // delete scale buffer
+           if (resource->constant_map.count(scale_name)) {
+               resource->constant_map.erase(scale_name);
+           }
+
+           layer->param->dynamic_range_quantized = false;
+       }
+
        return TNN_OK;
    }
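Both MatMul branches perform the same reconstruction: a dynamic-range-quantized weight is int8 data plus one float scale, and dequantization is an elementwise multiply. The `weight_position == 1` branch rewrites the weight held in the layer resource, while the new `weight_position == -1` branch rewrites the constant-map entry and erases its companion scale buffer. A NumPy sketch of the shared inner loop (illustrative only):

```python
import numpy as np

def dequantize_dynamic_range(w_int8: np.ndarray, scale: float) -> np.ndarray:
    # Mirrors the loops above: weight_data[i] = scale_value * (float)weight_ptr[i].
    return scale * w_int8.astype(np.float32)

w = np.array([-128, -1, 0, 1, 127], dtype=np.int8)
print(dequantize_dynamic_range(w, 0.02))  # [-2.56 -0.02  0.    0.02  2.54]
```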
2 changes: 1 addition & 1 deletion source/tnn/utils/split_utils.cc

@@ -127,7 +127,7 @@ Status SplitUtils::SplitStr(const char *str, str_arr &subs_array, const char spl
        return TNN_OK;
    }

-   const int subs_length = 2048;
+   const int subs_length = 4096;
    char *subs = (char *)calloc(subs_length, sizeof(char));

    for (int i = 0, cursor = 0;; i += step) {
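`SplitStr` accumulates each token into a fixed scratch buffer allocated with `calloc`, so `subs_length` bounds the longest token the splitter can hold; the change doubles that bound. A sketch of the constraint (an assumption about the overflow condition, reserving one byte for a NUL terminator):

```python
SUBS_LENGTH_OLD, SUBS_LENGTH_NEW = 2048, 4096

def token_fits(token: str, subs_length: int) -> bool:
    # Assumption: the scratch buffer must hold the token plus a terminating NUL.
    return len(token) + 1 <= subs_length

long_token = "x" * 2500  # e.g. one very long field in a converted model file
print(token_fits(long_token, SUBS_LENGTH_OLD))  # False
print(token_fits(long_token, SUBS_LENGTH_NEW))  # True
```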
12 changes: 11 additions & 1 deletion tools/convert2tnn/utils/align_model.py

@@ -72,6 +72,12 @@ def get_input_from_file(path: str) -> dict:
                 data.append(float(f.readline().strip('\n')))
             np_data = np.reshape(np.array(data).astype(np.float32), dims)
             input_dict.update({input_name: np_data})
+        elif data_type == 2:
+            # bool
+            for j in range(count):
+                data.append(int(f.readline().strip('\n')))
+            np_data = np.array(data).astype(np.bool).reshape(dims)
+            input_dict.update({input_name: np_data})
         elif data_type == 3:
             # int32
             for j in range(count):

@@ -106,6 +112,8 @@ def run_onnx(model_path: str, input_path: str, input_info: dict) -> str:
             data_type = np.int64
         elif item.type == "tensor(int32)":
             data_type = np.int32
+        elif item.type == "tensor(bool)":
+            data_type = np.bool
         input_data_dict[item.name] = input_data_dict[item.name].astype(data_type)

     output_info = session.get_outputs()

@@ -245,7 +253,9 @@ def get_input_shape_from_onnx(onnx_path) -> dict:
         data_type = 0
         if ip.type == 'tensor(float)':
             data_type = 0
-        elif ip.type == 'tensor(int64)':
+        elif ip.type == 'tensor(bool)':
+            data_type = 2
+        elif ip.type == 'tensor(int64)' or ip.type == 'tensor(int32)':
             data_type = 3
         else:
             logging.error("Do not support input date type")
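Across the convert2tnn scripts, input dtypes travel as small integer codes: 0 for float32, 3 for int64/int32, and now 2 for bool, as read from the branches above. A sketch of the mapping (the committed code uses `np.bool`, an alias NumPy has since deprecated; `np.bool_` below is the equivalent dtype):

```python
import numpy as np

# dtype codes used across the convert2tnn scripts in this commit:
#   0 -> float32, 2 -> bool, 3 -> int32/int64
ONNX_TYPE_TO_CODE = {
    "tensor(float)": 0,
    "tensor(bool)": 2,
    "tensor(int32)": 3,
    "tensor(int64)": 3,
}
CODE_TO_NUMPY = {0: np.float32, 2: np.bool_, 3: np.int64}

for onnx_type, code in ONNX_TYPE_TO_CODE.items():
    print(f"{onnx_type} -> code {code} -> {np.dtype(CODE_TO_NUMPY[code])}")
```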
2 changes: 1 addition & 1 deletion tools/convert2tnn/utils/data.py

@@ -40,7 +40,7 @@ def gene_random_data(input_info: dict) -> str:
         if data_type == 0:
             data[name] = np.random.rand(*shape)
             np.savetxt(data_file, data[name].reshape(-1), fmt="%0.6f")
-        elif data_type == 3:
+        elif data_type == 2 or data_type == 3:
             # range [low, high)
             data[name] = np.random.randint(low=0, high=2, size=shape)
             np.savetxt(data_file, data[name].reshape(-1), fmt="%i")
4 changes: 3 additions & 1 deletion tools/convert2tnn/utils/run_onnx_model.py

@@ -34,8 +34,10 @@ def get_src_model_input_information(self) -> dict:
             data_type = 0
             if ip.type == 'tensor(float)':
                 data_type = 0
-            elif ip.type == 'tensor(int64)':
+            elif ip.type == 'tensor(int64)' or ip.type == 'tensor(int32)':
                 data_type = 3
+            elif ip.type == 'tensor(bool)':
+                data_type = 2
             else:
                 logging.error("Do not support input date type")
             if type(shape[0]) is not int:
2 changes: 1 addition & 1 deletion tools/convert2tnn/utils/run_src_model.py

@@ -166,7 +166,7 @@ def generate_input_data(self, input_information: dict, tnn_model_input_informati
         if data_type == 0:
             self.input_data[name] = np.random.rand(*shape).astype(np.float32)
             np.savetxt(data_file, self.input_data[name].reshape(-1), fmt="%0.6f")
-        elif data_type == 3:
+        elif data_type == 2 or data_type == 3:
             # range [low, high)
             self.input_data[name] = np.random.randint(low=0, high=2, size=shape).astype(np.int64)
             np.savetxt(data_file, self.input_data[name].reshape(-1), fmt="%i")
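With this commit, data.py and run_src_model.py share one convention for generated inputs: float inputs are uniform in [0, 1), while bool and integer inputs are drawn from {0, 1}. A sketch of that shared rule (illustrative, using the dtype codes above):

```python
import numpy as np

def gene_random_input(shape, code):
    # Sketch of the convention in data.py / run_src_model.py after this commit.
    if code == 0:                 # float32
        return np.random.rand(*shape).astype(np.float32)
    if code in (2, 3):            # bool or int: values in {0, 1}
        return np.random.randint(low=0, high=2, size=shape)
    raise ValueError(f"unsupported data type code: {code}")

print(gene_random_input((2, 3), 2))
```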