-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix dtype not matching bug in log_prob and probs method of Distribution class #26767
fix dtype not matching bug in log_prob and probs method of Distribution class #26767
Conversation
Thanks for your contribution! |
18029cd
to
4da51a3
Compare
python/paddle/distribution.py
Outdated
raise TypeError( | ||
"Type of input args must be float, list, numpy.ndarray or Tensor, but received type {}". | ||
format(type(arg))) | ||
|
||
arg_np = np.array(arg) | ||
arg_dtype = arg_np.dtype | ||
if str(arg_dtype) not in ['float32']: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if str(arg_dtype) not in ['float32']: | |
if str(arg_dtype) != 'float32': |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And, the positive conditions is better than the negative conditions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
python/paddle/distribution.py
Outdated
"data type of argument only support float32, your argument will be convert to float32." | ||
) | ||
# "assign" op doesn't support float64. if dtype is float64, float32 variable will be generated and transformed to float64 later using "cast". | ||
if str(arg_dtype) not in ['float64']: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if str(arg_dtype) not in ['float64']: | |
if str(arg_dtype) != 'float64': |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
python/paddle/distribution.py
Outdated
if isinstance(arg, float): | ||
arg = np.zeros(1) + arg | ||
arg = [arg] | ||
elif not isinstance(arg, list) and not isinstance(arg, np.ndarray): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
elif not isinstance(arg, list) and not isinstance(arg, np.ndarray): | |
elif not isinstance(arg, (list, np.ndarray)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
30827ca
to
d3312c6
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
…on class (PaddlePaddle#26767) * fix _to_tensor method of Distribution class * Add unittest * let dtype be consistent with value in log_prob and probs * fix format * fix dtype problem and change unittest * fix dtype of Numpy class in unittest * add formula for entropy and kl * change formula * fix kl formula format * fix kl formula format 2 * change gt to np in unittest * optimize unittest format * delete dumplicate * delete dumplicate 2 * extract common function used to convert dtype value
* fix dtype not matching bug in log_prob and probs method of Distribution class (#26767) * fix _to_tensor method of Distribution class * Add unittest * let dtype be consistent with value in log_prob and probs * fix format * fix dtype problem and change unittest * fix dtype of Numpy class in unittest * add formula for entropy and kl * change formula * fix kl formula format * fix kl formula format 2 * change gt to np in unittest * optimize unittest format * delete dumplicate * delete dumplicate 2 * extract common function used to convert dtype value * cherry pick 27046
PR types
Bug fixes
PR changes
APIs
Describe
In
_to_tensor
method of Distribution class (refer to PR #26355 and PR #26535). Even if we want to support bothfloat32
andfloat64
dtype in Distribution classes, when parameters (low
andhigh
inUniform
,loc
andscale
inNormal
) arenumpy.ndarray
and dtypes arefloat64
, we can only set dtype to befloat32
usingassign
op to get the correspoding variable. Becaseassign
op doesn't supportfloat64
when input isnumpy.ndarray
.In
log_prob
andprobs
methods in Distribution class, the inputvalue
of these methods is a tensor. In users' view, it's reasonable that the dtype ofvalue
and parameters are same.The following is an example code:
We are going to let
assign
op supportfloat64
, but it will lose precision becauseAttr
don't supportfloat64
in framework.proto (refer to #26797). That is,assign op
can only supportfloat32
.Thus, in this PR, we use
cast
operation to convert dtype afterassign
op if dtype isfloat64
.If users define a
Uniform
distribution whoselow
andhigh
arefloat64
numpy.ndarray, we firstly useassign
op to getfloat32
variable. Then usecast
to getfloat64
variable.What's more,
probs
andlog_prob
methods have a variable input namedvalues
. If dtype ofvalues
is different withlow
inUniform
orloc
inNormal
, it will cause error.To solve this dtype conflict, we
cast
dtype ofvalues
to be the same as that oflow
orloc
. (in_check_values_dtype_in_probs
function)In Doc discribtion, we add formula for
entropy
andkl-divergence
methods. Formula forlog_prob
andprobs
have been given in doc of class, that is, thepdf
(probability density function) of the distribution.By the way, we rewrite unittest to make it more readable.