
Large inference error after resizeSession on an MNN model #2871

Open
younghuvee opened this issue May 16, 2024 · 11 comments

@younghuvee commented May 16, 2024

Platform (include the target platform as well if cross-compiling):

x86, Ubuntu 22.04

GitHub Version:

MNN 2.8.1

After converting a PyTorch model to ONNX (with dynamic inputs configured) and then converting it to MNN for inference: the model has three inputs, two of which stay fixed while the size of the third grows on every step, so the model is run in a loop. The first inference matches the original model's output, but after the second resizeSession the inference results no longer match expectations. The code is as follows:
```cpp
// Element types of src/src_mask are assumptions reconstructed from the
// memcpy/host<T>() usage below; the issue's formatting ate the template arguments.
int Decoder::deInfer(const std::vector<float> src, const std::vector<int> src_mask,
                     int input_h, std::vector<std::vector<int32_t>> ids,
                     float* next_token_logits) {
#if MODULE
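// Module (Express) API path: build VARP inputs and run Module::onForward.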

int input_ids_size = (int)(ids[0].size());
std::vector<int32_t> idxx;
for (int i=0; i<ids.size(); i++){
    for (int j=0; j<ids[i].size(); j++){
        idxx.push_back(ids[i][j]);
    }
}

LOG(INFO) << "idxx.size(): " << idxx.size();
for (int i=0; i<idxx.size(); i++){
    std::cout<< idxx[i] <<" ";
}
std::cout<<std::endl;

LOG(INFO) << "input_h: " << input_h << " input_ids_size: " << input_ids_size;
auto input_0 = MNN::Express::_Input({2, 13, input_h, 256}, MNN::Express::NCHW, halide_type_of<float>());
auto input_1 = MNN::Express::_Input({2, 13, input_h}, MNN::Express::NCHW, halide_type_of<int>());
auto input_2 = MNN::Express::_Input({2, input_ids_size}, MNN::Express::NCHW, halide_type_of<int32_t>());

::memcpy(input_0->writeMap<float>(), src.data(), src.size() * sizeof(float));
::memcpy(input_1->writeMap<bool>(), src_mask.data(), src_mask.size() * sizeof(bool));
::memcpy(input_2->writeMap<int32_t>(), idxx.data(), idxx.size() * sizeof(int32_t));

std::vector<VARP> outputs;

try {
    std::ostringstream fileNameOs;
    outputs = module->onForward({input_0, input_1, input_2});
    std::ostringstream dimInfo;
    auto info = outputs[0]->getInfo();
    for (int d=0; d<info->dim.size(); ++d) {
        dimInfo << info->dim[d] << "_";
    }
    auto fileName = fileNameOs.str();
    MNN_PRINT("Output Name: %s, Dim: %s\n", fileName.c_str(), dimInfo.str().c_str());
    // module->traceOrOptimize(MNN::Interpreter::Session_Resize_Fix);
    auto ptr = outputs[0]->readMap<float>();

    for (int i=0; i<input_ids_size*2; i++){
        LOG(INFO) << "ptr[0 + "<<i<<"*46]: " << ptr[0 + i*46] << " ptr[1 + "<<i<<"*46]: " << ptr[1 + i*46] << " ptr[2 + "<<i<<"*46]: " << ptr[2 + i*46] << " ptr[3 + "<<i<<"*46]: " << ptr[3 + i*46] << " ptr[4 + i*46]: " << ptr[4 + i*46];
    }
    memcpy(next_token_logits, ptr+(input_ids_size-1)*46, 46*sizeof(float));
    memcpy(next_token_logits+46, ptr+(input_ids_size*2-1)*46, 46*sizeof(float));
}
catch (std::exception const &e) {
    LOG(ERROR) << "Error when run decoder onnx forword: " << (e.what());
}

#else
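// Session API path: resize the input tensors, then resizeSession + runSession.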
if (!m_mnnNet_decoder) {
    printf("error: Decoder::deInfer(), m_mnnNet_decoder is null.\n");
    return -1;
}

int input_ids_size = (int)(ids[0].size());

m_mnnNet_decoder->resizeTensor(input_img, {2, 13, input_h, 256});
m_mnnNet_decoder->resizeTensor(input_mask, {2, 13, input_h});
m_mnnNet_decoder->resizeTensor(input_ids, {2, input_ids_size});
m_mnnNet_decoder->resizeSession(m_mnnSession_decoder);
m_mnnNet_decoder->resizeTensor(output_vector, {2, input_ids_size, 46});



int i_modelW2 = input_img->width();
int i_modelH2 = input_img->height();
int i_modelC2 = input_img->channel();
int i_modelB2 = input_img->batch();
int i2_modelW2 = input_mask->width();
int i2_modelH2 = input_mask->height();
int i2_modelC2 = input_mask->channel();
int i2_modelB2 = input_mask->batch();
int m_modelW2 = input_ids->width();
int m_modelH2 = input_ids->height();
int m_modelC2 = input_ids->channel();
int m_modelB2 = input_ids->batch();
int o_modelW2 = output_vector->width();
int o_modelH2 = output_vector->height();
int o_modelC2 = output_vector->channel();
int o_modelB2 = output_vector->batch();



LOG(INFO) << i_modelB2 << " " << i_modelC2 << " " << i_modelH2 << " " << i_modelW2;
LOG(INFO) << i2_modelB2 << " " << i2_modelC2 << " " << i2_modelH2 << " " << i2_modelW2;
LOG(INFO) << m_modelB2 << " " << m_modelC2 << " " << m_modelH2 << " " << m_modelW2;
LOG(INFO) << o_modelB2 << " " << o_modelC2 << " " << o_modelH2 << " " << o_modelW2;


auto input_img_buffer = new Tensor(input_img, Tensor::CAFFE);
for (int i=0; i<2*13*input_h*256; i++){
    input_img_buffer->host<float>()[i] = src[i];
    // input_img_buffer->host<float>()[i] = 0.0;
}
input_img->copyFromHostTensor(input_img_buffer);
// auto input_imgxx = new Tensor(input_img, Tensor::CAFFE);
// input_img->copyToHostTensor(input_imgxx);
// auto dataid1 = input_imgxx->host<float>();
// LOG(INFO) <<"---------------------------------------------";
// for (int i=0; i<5; i++){
//     LOG(INFO) << dataid1[i];
// }

// input mask
auto input_mask_buffer = new Tensor(input_mask, Tensor::CAFFE);
for (int i=0; i<2*13*input_h; i++){
    input_mask_buffer->host<int>()[i] = src_mask[i];
    // input_mask_buffer->host<bool>()[i] = false;
}
input_mask->copyFromHostTensor(input_mask_buffer);

// auto input_maskxx = new Tensor(input_ids, Tensor::CAFFE);
// input_mask->copyToHostTensor(input_maskxx);
// auto dataid2 = input_maskxx->host<int>();
// LOG(INFO) <<"---------------------------------------------";
// for (int i=0; i<5; i++){
//     LOG(INFO) << dataid2[i];
// }


// input ids   
auto input_ids_buffer = new Tensor(input_ids, Tensor::CAFFE);
for (int i=0; i<ids[0].size(); i++){
    input_ids_buffer->host<int32_t>()[i] = ids[0][i];
}
for (int i=0; i<ids[1].size(); i++){
    input_ids_buffer->host<int32_t>()[i+ids[0].size()] = ids[1][i];
}
input_ids->copyFromHostTensor(input_ids_buffer);

// auto input_idxx = new Tensor(input_ids, Tensor::CAFFE);
// input_ids->copyToHostTensor(input_idxx);
// auto dataid = input_idxx->host<int>();

// LOG(INFO) <<"---------------------------------------------";
// for (int i=0; i<input_ids_size*2; i++){
//     LOG(INFO) << dataid[i];
// }


m_mnnNet_decoder->runSession(m_mnnSession_decoder);

auto nchwTensor_feature = new Tensor(output_vector, Tensor::CAFFE);
output_vector->copyToHostTensor(nchwTensor_feature);
auto data_feature = nchwTensor_feature->host<float>();

for (int i=0; i<input_ids_size*2; i++){
    LOG(INFO) << "data_feature[0 + "<<i<<"*46]: " << data_feature[0 + i*46] << " data_feature[1 + "<<i<<"*46]: " << data_feature[1 + i*46] << " data_feature[2 + "<<i<<"*46]: " << data_feature[2 + i*46] << " data_feature[3 + "<<i<<"*46]: " << data_feature[3 + i*46] << " data_feature[4 + i*46]: " << data_feature[4 + i*46];
}

memcpy(next_token_logits, data_feature+(input_ids_size-1)*46, 46*sizeof(float));
memcpy(next_token_logits+46, data_feature+(input_ids_size*2-1)*46, 46*sizeof(float));
delete input_img_buffer;
delete input_mask_buffer;
delete input_ids_buffer;
delete nchwTensor_feature;

#endif
return 0;
}

```

The results are wrong with both the Session and the Module inference APIs.
The ONNX export code is as follows:
```python
torch.onnx.export(
    decoder,
    (src[0].to(device), src_mask[0].to(device), input_ids.to(device)),
    "decoder_0515.onnx",
    input_names=["input1", "input2", "input3"],
    output_names=["output"],
    dynamic_axes={"input1": {2: "input_width"},
                  "input2": {2: "input_width"},
                  "input3": {1: "length"}},
    verbose=True,
    opset_version=19,
)
```

@jxt1234 (Collaborator) commented May 16, 2024

Don't resize the output:

`m_mnnNet_decoder->resizeTensor(output_vector, {2, input_ids_size, 46});`
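For reference, a minimal sketch of the suggested flow, assuming the tensor and session variables from the code above: resize only the inputs and let resizeSession recompute the output shape.

```cpp
// Resize only the inputs; resizeSession propagates the new shapes.
m_mnnNet_decoder->resizeTensor(input_img,  {2, 13, input_h, 256});
m_mnnNet_decoder->resizeTensor(input_mask, {2, 13, input_h});
m_mnnNet_decoder->resizeTensor(input_ids,  {2, input_ids_size});
m_mnnNet_decoder->resizeSession(m_mnnSession_decoder);

// No resizeTensor on the output: after resizeSession its shape has already
// been recomputed (here to {2, input_ids_size, 46}) and can simply be read.
auto* output_vector = m_mnnNet_decoder->getSessionOutput(m_mnnSession_decoder, nullptr);
```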

@younghuvee (Author) commented:

> Don't resize the output:
> `m_mnnNet_decoder->resizeTensor(output_vector, {2, input_ids_size, 46});`

Tried that; nothing changed.

@younghuvee (Author) commented:

If the ONNX model is exported directly at the required size, without setting dynamic_size, the inference results are correct.

@jxt1234 (Collaborator) commented May 16, 2024

How do the results look if you export the ONNX with dynamic_size set and then test it with the specified inputs using testMNNFromOnnx.py?

@jxt1234 (Collaborator) commented May 16, 2024

```cpp
int i_modelW2 = input_img->width();
int i_modelH2 = input_img->height();
int i_modelC2 = input_img->channel();
int i_modelB2 = input_img->batch();
int i2_modelW2 = input_mask->width();
int i2_modelH2 = input_mask->height();
int i2_modelC2 = input_mask->channel();
int i2_modelB2 = input_mask->batch();
int m_modelW2 = input_ids->width();
int m_modelH2 = input_ids->height();
int m_modelC2 = input_ids->channel();
int m_modelB2 = input_ids->batch();
int o_modelW2 = output_vector->width();
int o_modelH2 = output_vector->height();
int o_modelC2 = output_vector->channel();
int o_modelB2 = output_vector->batch();
```

This part is problematic: don't use width()/height() and the like on tensors that aren't 4D; use length(0), length(1), length(2) instead.
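A hedged sketch of the same shape logging done with length(), which works for any rank (variable names as in the snippet above):

```cpp
// width()/height()/channel()/batch() assume a 4D layout; length(d) queries
// each dimension directly and also works for the 3D and 2D tensors here.
auto logShape = [](const char* name, const MNN::Tensor* t) {
    std::ostringstream os;
    for (int d = 0; d < t->dimensions(); ++d) {
        if (d > 0) os << " x ";
        os << t->length(d);
    }
    LOG(INFO) << name << ": " << os.str();
};
logShape("input_img",     input_img);
logShape("input_mask",    input_mask);
logShape("input_ids",     input_ids);
logShape("output_vector", output_vector);
```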

@jxt1234 (Collaborator) commented May 16, 2024

`::memcpy(input_1->writeMap<bool>(), src_mask.data(), src_mask.size() * sizeof(bool));`

Change these bools to int32_t.
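A minimal sketch of the int32_t version, assuming the mask is carried as (or converted to) a std::vector<int32_t> upstream:

```cpp
// Declare the mask input as int32_t and copy int32_t-sized elements, so the
// copied byte count matches the buffer (sizeof(bool) copies only one byte
// per element and leaves the rest of the int buffer uninitialized).
std::vector<int32_t> mask32(src_mask.begin(), src_mask.end());  // hypothetical conversion
auto input_1 = MNN::Express::_Input({2, 13, input_h}, MNN::Express::NCHW,
                                    halide_type_of<int32_t>());
::memcpy(input_1->writeMap<int32_t>(), mask32.data(),
         mask32.size() * sizeof(int32_t));
```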

@younghuvee (Author) commented:

> How do the results look if you export the ONNX with dynamic_size set and then test it with the specified inputs using testMNNFromOnnx.py?

[screenshot: testMNNFromOnnx.py output]

Results are in the screenshot; this error looks acceptable and is not large, but the error in C++ is very large.

@younghuvee (Author) commented:

Could the error be introduced by resizeSession?

@younghuvee (Author) commented:

> Results are in the screenshot; this error looks acceptable and is not large, but the error in C++ is very large.

But note that MNN's inference results are compared against PyTorch here, while this test compares MNN against ONNX.

@younghuvee (Author) commented:

1

@jxt1234 (Collaborator) commented May 22, 2024

> Results are in the screenshot; this error looks acceptable and is not large, but the error in C++ is very large.

That error is actually quite large. Update to 2.9.0 and test again; if the problem persists, please send the ONNX model.
