Add fake_quantize_op. #11359

Merged: 15 commits merged into PaddlePaddle:develop on Jul 11, 2018
Conversation

achao2013 (Contributor):
Add quant code for test.

@achao2013 changed the title from "Quant" to "quantization code" on Jun 11, 2018
@achao2013 changed the title from "quantization code" to "add quantization code" on Jun 11, 2018
@qingqing01 self-requested a review on Jun 11, 2018, 09:20
@qingqing01 changed the title from "add quantization code" to "Add fake_quantize_op." on Jun 11, 2018
// PADDLE_ENFORCE_EQ(ctx->Inputs("InScales")[0],
// ctx->Outputs("OutScales")[0],
// "Mean and MeanOut should share the same memory");
//}
qingqing01 (Contributor):
Please remove the commented lines.

achao2013 (Contributor Author):
The comment is for the Python test; the commented lines are used for training.

"Input(X) of FakeQuantizeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FakeQuantizeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("OutMovingScale"), "");
qingqing01 (Contributor):
Please add the error message.

achao2013 (Contributor Author):
done
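
For reference, a minimal sketch of the kind of message being asked for (the wording here is illustrative, not the exact text that was merged):

```cpp
// Illustrative wording only; the message merged in the PR may differ.
PADDLE_ENFORCE(ctx->HasOutput("OutMovingScale"),
               "Output(OutMovingScale) of FakeQuantizeOp should not be null.");
```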

ctx->SetOutputDim("OutMovingScale", ctx->GetInputDim("InMovingScale"));
//}
// if (ctx->HasInput("InScales")) {
PADDLE_ENFORCE(ctx->HasOutput("OutScales"), "");
qingqing01 (Contributor):
Please add the error message.

achao2013 (Contributor Author):
done

public:
void Make() override {
AddInput("X", "(Tensor) Input tensor of scale operator.");
AddInput("InScales", "(Tensor) scale buffer").AsDispensable();
qingqing01 (Contributor):
Please add more comments explaining why this argument is optional: when it is needed and when it is not. The same applies to the following arguments.

achao2013 (Contributor Author):
done
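
A hypothetical sketch of the kind of comment being requested (the wording is an assumption for illustration, not the merged text):

```cpp
// Illustrative only: spells out when the dispensable input is needed.
AddInput("InScales",
         "(Tensor, optional) 1-D buffer of scales from recent iterations, "
         "used as a sliding window when the quantize type keeps a running "
         "range; it can be omitted when the scale is recomputed from the "
         "current batch alone.")
    .AsDispensable();
```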

namespace operators {

template <typename T>
__global__ void find_abs_max_kernel(const int n, const T* in, T* out) {
qingqing01 (Contributor):
find_abs_max_kernel -> FindAbsMaxKernel

Please follow Google C++ code style: https://google.github.io/styleguide/cppguide.html#Function_Names

Please modify other code with the same problem.

achao2013 (Contributor Author):
done
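
For illustration, the rename under Google style looks roughly like this (the reduction body is elided; this is a sketch, not the merged code):

```cpp
// CamelCase per the Google C++ style guide; the kernel logic is unchanged.
template <typename T>
__global__ void FindAbsMaxKernel(const int n, const T* in, T* out) {
  // ... same block-wise abs-max reduction as before ...
}
```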

float find_abs_max_gpu(const platform::CUDADeviceContext& ctx,
const float* array, int length) {
float host_max;
int NUM_THREADS = 1024;
qingqing01 (Contributor):
NUM_THREADS -> kNumThreads. Please follow Google code style.

achao2013 (Contributor Author):
done

cudaMemcpy(&host_max, device_max, sizeof(float), cudaMemcpyDeviceToHost),
cudaSuccess, "cudaMemcpy failed");
return host_max;
}
qingqing01 (Contributor):
Maybe you can use thrust::reduce + thrust::max_element to find the maximum value more simply.

achao2013 (Contributor Author):
This will be slow.
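
For comparison, a minimal sketch of the reviewer's thrust-based alternative (the helper name FindAbsMaxThrust is hypothetical; whether it is actually slower would need benchmarking):

```cpp
#include <cmath>
#include <thrust/device_ptr.h>
#include <thrust/functional.h>
#include <thrust/transform_reduce.h>

struct AbsFunctor {
  __host__ __device__ float operator()(float x) const { return fabsf(x); }
};

// One transform_reduce call replaces the hand-written reduction kernel.
float FindAbsMaxThrust(const float* array, int length) {
  thrust::device_ptr<const float> ptr(array);
  return thrust::transform_reduce(ptr, ptr + length, AbsFunctor(), 0.0f,
                                  thrust::maximum<float>());
}
```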

int window_size = context.Attr<int>("window_size");
int bit_length = context.Attr<int>("bit_length");
int bin_cnt = std::pow(2, bit_length - 1) - 1;
LOG(ERROR) << "bin_cnt:" << bin_cnt;
qingqing01 (Contributor):
Remove this line.

achao2013 (Contributor Author):
done
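
For context: with bit_length = 8, for example, bin_cnt = 2^(8-1) - 1 = 127, the largest magnitude representable in a signed 8-bit integer.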

auto* scale_list = context.Output<framework::Tensor>("OutScales");
auto* saving_scale =
context.Output<framework::Tensor>("OutMovingScale");
scale = find_abs_max(const_cast<framework::Tensor*>(in), in->numel());
achao2013 (Contributor Author):
cwiseMax is an element-wise max operation; I need a reduce-max op here.

achao2013 (Contributor Author):
done
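
To make the distinction concrete, a small Eigen sketch (variable names are illustrative):

```cpp
#include <Eigen/Dense>

int main() {
  Eigen::VectorXf a(3), b(3);
  a << 1.0f, -5.0f, 2.0f;
  b << 0.0f, 3.0f, 4.0f;
  // cwiseMax is element-wise: it yields another vector, not a scalar.
  Eigen::VectorXf elemwise = a.cwiseMax(b);  // [1, 3, 4]
  // A reduce-max needs a reduction such as maxCoeff().
  float abs_max = a.cwiseAbs().maxCoeff();   // 5
  return 0;
}
```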

qingqing01 (Contributor) left a review:
Need unit testing.

}
}

apply_saturate(const_cast<framework::Tensor*>(in), tensor, -scale, scale);
achao2013 (Contributor Author):
done

AddComment(R"DOC(
FakeQuantize operator

$$Out = scale*X$$
qingqing01 (Contributor):
Please add comments explaining how the scale is calculated.

achao2013 (Contributor Author):
done
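
A plausible shape for the requested comment, assuming the abs-max scheme discussed elsewhere in this thread (the exact formula merged may differ): with $s = \max(|X|)$ and $bin\_cnt = 2^{bit\_length - 1} - 1$, the doc's $Out = scale \cdot X$ can be read with

$$scale = \frac{bin\_cnt}{s}, \qquad Out = \text{round}(scale \cdot X)$$

so that $Out$ lands in $[-bin\_cnt, bin\_cnt]$.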


$$Out = scale*X$$
)DOC");
AddAttr<std::string>("quantize_type",
qingqing01 (Contributor):
quantize_type -> scale_type: would that be more accurate?

achao2013 (Contributor Author):
If the quantization method is non-uniform, no scale is needed, so I think this should not be scale_type.

qingqing01 (Contributor):
My understanding is that quantize_type usually refers to different quantization schemes such as Abs-Max or Min-Max, whereas this attr is meant to indicate how the scale is computed, right?

achao2013 (Contributor Author):
For non-uniform quantization, the mapping from floating-point input to fixed-point output may be an arbitrary function or a discrete value mapping, so there is no scale operation at all.
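
To make the author's point concrete (notation is ours, not from the PR): uniform quantization applies a single scale,

$$q = \text{round}(x / s),$$

while a non-uniform scheme replaces the scale with an arbitrary mapping $q = f(x)$ (e.g. a lookup table), so no scale attribute applies.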

qingqing01 (Contributor):
@achao2013

[07:05:06]W: [Step 1/1] /paddle/paddle/fluid/operators/distributed/rpc_client.h:61:26: error: 'FLAGS_grpc_deadline' was not declared in this scope
[07:05:06]W: [Step 1/1] int64_t time_out = FLAGS_grpc_deadline) = 0;

CI did not pass; you need to update to the latest develop branch.

qingqing01 (Contributor):

[07:00:33]	294/359 Test #290: test_fake_quantize_op ................................***Failed    5.52 sec
[07:00:33]	test_fake_quantize_op failed
[07:00:33]	E
[07:00:33]	======================================================================
[07:00:33]	ERROR: test_check_output (test_fake_quantize_op.TestFakeQuantizeOp)
[07:00:33]	----------------------------------------------------------------------
[07:00:33]	Traceback (most recent call last):
[07:00:33]	  File "/paddle/build/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py", line 47, in test_check_output
[07:00:33]	    self.check_output()
[07:00:33]	  File "/paddle/build/python/paddle/fluid/tests/unittests/op_test.py", line 325, in check_output
[07:00:33]	    self.check_output_with_place(place, atol)
[07:00:33]	  File "/paddle/build/python/paddle/fluid/tests/unittests/op_test.py", line 263, in check_output_with_place
[07:00:33]	    outs, fetch_list = self._calc_output(place)
[07:00:33]	  File "/paddle/build/python/paddle/fluid/tests/unittests/op_test.py", line 226, in _calc_output
[07:00:33]	    inputs = self._get_inputs(block)
[07:00:33]	  File "/paddle/build/python/paddle/fluid/tests/unittests/op_test.py", line 211, in _get_inputs
[07:00:33]	    return self._get_io_vars(block, self.inputs)
[07:00:33]	  File "/paddle/build/python/paddle/fluid/tests/unittests/op_test.py", line 207, in _get_io_vars
[07:00:33]	    inputs[name] = block.var(name)
[07:00:33]	  File "/paddle/build/python/paddle/fluid/framework.py", line 900, in var
[07:00:33]	    raise ValueError("var %s not in this block" % name)
[07:00:33]	ValueError: var InCurrentScale not in this block
[07:00:33]	
[07:00:33]	----------------------------------------------------------------------
[07:00:33]	Ran 1 test in 0.002s
[07:00:33]	
[07:00:33]	FAILED (errors=1)

The unit test did not pass.

qingqing01 (Contributor) left a review:
Approved. @dangqingqing will refine and add more unit tests.

qingqing01 merged commit 8e4b225 into PaddlePaddle:develop on Jul 11, 2018
kuke pushed a commit to kuke/Paddle that referenced this pull request Aug 25, 2018
* Add a fake_quantize_op, which quantizes an input tensor to a tensor with lower bits.