Skip to content

Commit

Permalink
Merge pull request #550 from BUG1989/fix_interp
Browse files Browse the repository at this point in the history
fix the bug of interp nearest type
  • Loading branch information
daquexian committed Feb 14, 2021
2 parents 19c14d3 + 745ae59 commit 2e86073
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 68 deletions.
18 changes: 0 additions & 18 deletions src/dev/cpu/cpu_device.c
Expand Up @@ -3026,24 +3026,6 @@ static int run(struct nn_device* dev, struct subgraph* subgraph)
if (output_tensor->dim_num <= 5)
extract_feature_blob_f32("out", name, output_tensor);
}

//#define DUMP_NODE_OUTPUT
#ifdef DUMP_NODE_OUTPUT
/* dump the node output */
struct ir_node* ir_node = node->ir_node;
struct ir_graph* ir_graph = ir_node->graph;

for (int i = 0; i < ir_node->input_num; i++)
{
char fname[128];
struct ir_tensor* ir_tensor = get_ir_graph_tensor(ir_graph, ir_node->input_tensors[i]);

sprintf(fname, "/tmp/dump/node%s%d.%d", (ir_node->idx < 10 ? "0" : ""), ir_node->idx, i);

dump_float(fname, ir_tensor->data, ir_tensor->elem_num);
}

#endif
}

return 0;
Expand Down
172 changes: 123 additions & 49 deletions src/dev/cpu/op/interp/interp_ref.c
Expand Up @@ -159,43 +159,81 @@ void resize_bilinear_image(float* src, float* dst, float* alpha, int* xofs, floa

int ref_interp_fp32(struct ir_tensor* input_tensor, struct ir_tensor* output_tensor, struct interp_param* param)
{
float* input = input_tensor->data;
float* output = output_tensor->data;
if (param->resize_type == 1)
{
float* input = input_tensor->data;
float* output = output_tensor->data;

int batch = input_tensor->dims[0];
int channel = input_tensor->dims[1];
int in_h = input_tensor->dims[2];
int in_w = input_tensor->dims[3];
int out_h = output_tensor->dims[2];
int out_w = output_tensor->dims[3];
int batch = output_tensor->dims[0];
int channel = output_tensor->dims[1];
int output_h = output_tensor->dims[2];
int output_w = output_tensor->dims[3];
int input_h = input_tensor->dims[2];
int input_w = input_tensor->dims[3];

int in_channel_size = in_h * in_w;
int out_channel_size = out_h * out_w;
for (int n = 0; n < batch; ++n)
{
for (int c = 0; c < channel; c++)
{
for (int h = 0; h < output_h; h++)
{
for (int w = 0; w < output_w; w++)
{
int in_w = w / param->width_scale;
int in_h = h / param->height_scale;
int out_idx = n * channel * output_h * output_w + c * output_h * output_w + h * output_w + w;
int in_idx = n * channel * input_h * input_w + c * input_w * input_h + in_h * input_w + in_w;
output[out_idx] = input[in_idx];
}
}
}
}
}
else if (param->resize_type == 2)
{
float* input = input_tensor->data;
float* output = output_tensor->data;

int* buf = sys_malloc((param->output_width + param->output_height + param->output_width*2 + param->output_height*2)*sizeof(float));
int batch = input_tensor->dims[0];
int channel = input_tensor->dims[1];
int in_h = input_tensor->dims[2];
int in_w = input_tensor->dims[3];
int out_h = output_tensor->dims[2];
int out_w = output_tensor->dims[3];

if (buf == NULL)
{
printf("interp malloc failed!\n");
return -1;
}
int in_channel_size = in_h * in_w;
int out_channel_size = out_h * out_w;

int* xofs = buf;//new int[ow];
int* yofs = buf + param->output_width ;//new int[oh];
int* buf = sys_malloc((param->output_width + param->output_height + param->output_width*2 + param->output_height*2)*sizeof(float));

float* alpha = (float*)(buf + param->output_width + param->output_height);//new float[ow * 2];
float* beta = (float*)(buf + param->output_width + param->output_height + param->output_width*2);//new float[oh * 2];
if (buf == NULL)
{
fprintf(stderr,"interp malloc failed!\n");
return -1;
}

int* xofs = buf;//new int[ow];
int* yofs = buf + param->output_width ;//new int[oh];

float* alpha = (float*)(buf + param->output_width + param->output_height);//new float[ow * 2];
float* beta = (float*)(buf + param->output_width + param->output_height + param->output_width*2);//new float[oh * 2];

linear_coeffs(in_w, out_w, xofs, alpha);
linear_coeffs(in_h, out_h, yofs, beta);

linear_coeffs(in_w, out_w, xofs, alpha);
linear_coeffs(in_h, out_h, yofs, beta);
for (int q = 0; q < channel; ++q)
{
resize_bilinear_image(input+in_channel_size*q, output+out_channel_size*q, alpha, xofs, beta, yofs, out_h, out_w, in_h, in_w);
}

for (int q = 0; q < channel; ++q)
sys_free(buf);
}
else
{
resize_bilinear_image(input+in_channel_size*q, output+out_channel_size*q, alpha, xofs, beta, yofs, out_h, out_w, in_h, in_w);
fprintf(stderr,"interp resize type %d not support!\n", param->resize_type);
return -1;
}

sys_free(buf);

return 0;
}

Expand All @@ -221,36 +259,73 @@ int ref_interp_uint8(struct ir_tensor* input_tensor, struct ir_tensor* output_te
}

/* process */
int batch = input_tensor->dims[0];
int channel = input_tensor->dims[1];
int in_h = input_tensor->dims[2];
int in_w = input_tensor->dims[3];
int out_h = output_tensor->dims[2];
int out_w = output_tensor->dims[3];
if (param->resize_type == 1)
{
int batch = output_tensor->dims[0];
int channel = output_tensor->dims[1];
int output_h = output_tensor->dims[2];
int output_w = output_tensor->dims[3];
int input_h = input_tensor->dims[2];
int input_w = input_tensor->dims[3];

for (int n = 0; n < batch; ++n)
{
for (int c = 0; c < channel; c++)
{
for (int h = 0; h < output_h; h++)
{
for (int w = 0; w < output_w; w++)
{
int in_w = w / param->width_scale;
int in_h = h / param->height_scale;
int out_idx = n * channel * output_h * output_w + c * output_h * output_w + h * output_w + w;
int in_idx = n * channel * input_h * input_w + c * input_w * input_h + in_h * input_w + in_w;
output_fp32[out_idx] = input_fp32[in_idx];
}
}
}
}
}
else if (param->resize_type == 2)
{
int batch = input_tensor->dims[0];
int channel = input_tensor->dims[1];
int in_h = input_tensor->dims[2];
int in_w = input_tensor->dims[3];
int out_h = output_tensor->dims[2];
int out_w = output_tensor->dims[3];

int in_channel_size = in_h * in_w;
int out_channel_size = out_h * out_w;
int in_channel_size = in_h * in_w;
int out_channel_size = out_h * out_w;

int* buf = sys_malloc((param->output_width + param->output_height + param->output_width*2 + param->output_height*2)*sizeof(float));
int* buf = sys_malloc((param->output_width + param->output_height + param->output_width*2 + param->output_height*2)*sizeof(float));

if (buf == NULL)
{
printf("interp malloc failed!\n");
return -1;
}
if (buf == NULL)
{
fprintf(stderr,"interp malloc failed!\n");
return -1;
}

int* xofs = buf;//new int[ow];
int* yofs = buf + param->output_width ;//new int[oh];
int* xofs = buf;//new int[ow];
int* yofs = buf + param->output_width ;//new int[oh];

float* alpha = (float*)(buf + param->output_width + param->output_height);//new float[ow * 2];
float* beta = (float*)(buf + param->output_width + param->output_height + param->output_width*2);//new float[oh * 2];
float* alpha = (float*)(buf + param->output_width + param->output_height);//new float[ow * 2];
float* beta = (float*)(buf + param->output_width + param->output_height + param->output_width*2);//new float[oh * 2];

linear_coeffs(in_w, out_w, xofs, alpha);
linear_coeffs(in_h, out_h, yofs, beta);
linear_coeffs(in_w, out_w, xofs, alpha);
linear_coeffs(in_h, out_h, yofs, beta);

for (int q = 0; q < channel; ++q)
for (int q = 0; q < channel; ++q)
{
resize_bilinear_image(input_fp32+in_channel_size*q, output_fp32+out_channel_size*q, alpha, xofs, beta, yofs, out_h, out_w, in_h, in_w);
}

sys_free(buf);
}
else
{
resize_bilinear_image(input_fp32+in_channel_size*q, output_fp32+out_channel_size*q, alpha, xofs, beta, yofs, out_h, out_w, in_h, in_w);
fprintf(stderr,"interp resize type %d not support!\n", param->resize_type);
return -1;
}

/* quant */
Expand All @@ -264,7 +339,6 @@ int ref_interp_uint8(struct ir_tensor* input_tensor, struct ir_tensor* output_te
output_uint8[i] = udata;
}

sys_free(buf);
sys_free(input_fp32);
sys_free(output_fp32);

Expand Down
2 changes: 1 addition & 1 deletion src/op/interp_param.h
Expand Up @@ -27,7 +27,7 @@

struct interp_param
{
int resize_type;
int resize_type; // 1:nearest 2:bilinear or linear
int output_height;
int output_width;
float height_scale;
Expand Down

0 comments on commit 2e86073

Please sign in to comment.