
add multinomial probability distribution #38820

Merged · 3 commits · Jan 30, 2022

Conversation

@cxxly (Contributor) commented on Jan 9, 2022

PR types

New features

PR changes

APIs

Describe

  • Add Multinomial distribution with mean, variance, sample, entropy, prob, and log_prob methods (see the usage sketch below).
  • Update Beta, Dirichlet, and exponential family docs.
  • Fix Categorical entropy and sample bugs.
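
A minimal usage sketch of the new API, assuming the class is exposed as paddle.distribution.Multinomial and that the method names match the list above (exact signatures may differ):

    import paddle
    from paddle.distribution import Multinomial

    # 10 trials over 3 categories; the last axis of probs indexes categories.
    dist = Multinomial(10, paddle.to_tensor([0.2, 0.3, 0.5]))

    value = paddle.to_tensor([2., 3., 5.])  # counts summing to total_count
    print(dist.mean)             # expected counts per category
    print(dist.variance)         # per-category variance
    print(dist.prob(value))      # probability mass of value
    print(dist.log_prob(value))  # log of the above
    print(dist.entropy())        # entropy of the distribution
    print(dist.sample((2,)))     # 2 draws, each a vector of category counts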

paddle-bot-old (bot) commented on Jan 9, 2022

Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

paddle-bot-old (bot) commented:
Sorry to inform you that it has been more than 7 days since the CIs for 7e37d2b passed. To prevent PR conflicts, you need to re-run all CIs manually.

@cxxly force-pushed the distribution-multinomial branch 2 times, most recently from 3f921de to 154a928, on January 20, 2022 10:51
@@ -68,6 +68,11 @@ def kl_divergence(p, q):
def register_kl(cls_p, cls_q):
"""Decorator for register a KL divergence implemention function.

when call ``kl_divergence(p, q)`` , will search concrete implemention


Watch the grammar.

@cxxly (Contributor, Author) replied:

updated

@@ -37,8 +44,15 @@ class Beta(ExponentialFamily):


Args:
alpha (float|Tensor): alpha parameter of beta distribution, positive(>0).
alpha (float|Tensor): alpha parameter of beta distribution,
positive(>0), support broadcast semantic. when the parameter is


Mind the English syntax: capitalize the first letter of sentences, etc.

@cxxly (Contributor, Author) replied:

updated

@@ -200,7 +208,8 @@ def kl_divergence(self, other):
if not in_dygraph_mode():
check_type(other, 'other', Categorical, 'kl_divergence')

logits = self.logits - nn.reduce_max(self.logits, dim=-1, keep_dim=True)
logits = self.logits - \
nn.reduce_max(self.logits, dim=-1, keep_dim=True)


It is recommended to use paddle.max here rather than the functions in nn.
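
A sketch of what that replacement might look like (paddle.max supports axis and keepdim; this is a suggestion, not the final patch):

    # stabilize logits by subtracting the per-row maximum
    logits = self.logits - paddle.max(self.logits, axis=-1, keepdim=True)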

@cxxly (Contributor, Author) replied:

This is legacy code from a previous contributor. I'll update this part first; the remaining legacy code is scheduled for a unified update later.


Args:
concentration (Tensor): concentration parameter of dirichlet
distribution
distribution, also called :math:`\alpha`. when concentration over
@iclementine commented on Jan 21, 2022:

Capitalize the first letter of each sentence. In English sentences, use English commas followed by a space.


Args:
total_count (int): Number of trials.
probs (Tensor): Probability of a trail falling into each category. Last
@iclementine commented on Jan 21, 2022:

Mind the spelling: trial vs. trail.

@cxxly (Contributor, Author) replied:

updated

samples = self._dist.sample(sample_shape)
sample_mean = samples.mean(axis=0)
np.testing.assert_allclose(
sample_mean, self._dist.mean, atol=0, rtol=0.20)


A tolerance like this may be considered unreasonable by the CI system. Could you document the reason for choosing it?
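
One way to make the loose tolerance defensible, sketched below with illustrative numbers, is to tie it to the Monte Carlo error of the sample mean and say so next to the assertion:

    n = 50000
    samples = self._dist.sample((n,))
    sample_mean = samples.mean(axis=0)
    # Stochastic test: the sample mean converges at O(1/sqrt(n)), so a loose
    # relative tolerance is used on purpose to keep false CI failures rare.
    np.testing.assert_allclose(
        sample_mean, self._dist.mean, atol=0, rtol=0.05)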

@cxxly (Contributor, Author) replied:

updated

('value-int', 10, np.array([0.2, 0.3, 0.5]), np.array([2, 3, 5])),
('value-multi-dim', 10, np.array([[0.3, 0.7], [0.5, 0.5]]),
np.array([[4., 6], [8, 2]])),
# ('value-sum-non-n', 10, np.array([0.5, 0.2, 0.3]), np.array([4,5,2])),
@iclementine commented on Jan 21, 2022:

Could you add a case that draws multiple samples? For example, batch shape () with sample shape (2,), or something similar.
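
A sketch of such a check (class path and shape conventions are assumptions based on the PR description):

    # batch_shape is (), event_shape is (3,); with sample_shape (2,) the draw
    # should have shape (2, 3).
    dist = paddle.distribution.Multinomial(10, paddle.to_tensor([0.2, 0.3, 0.5]))
    samples = dist.sample((2,))
    self.assertEqual(tuple(samples.shape), (2, 3))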

def setUp(self):
self.prog = paddle.static.Program()
self.exe = paddle.static.Executor()
with paddle.static.program_guard(prog):


It is better to use two programs here, even though no parameters are created in this test.
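
A sketch of the suggested structure, using the usual paddle.static idiom with separate main and startup programs:

    def setUp(self):
        # Two programs, even though this test creates no parameters.
        self.main_program = paddle.static.Program()
        self.startup_program = paddle.static.Program()
        self.executor = paddle.static.Executor()
        with paddle.static.program_guard(self.main_program, self.startup_program):
            ...  # build the static graph under test here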

@cxxly force-pushed the distribution-multinomial branch 2 times, most recently from 1547cc0 to efdea2e, on January 25, 2022 03:05
@cxxly force-pushed the distribution-multinomial branch 4 times, most recently from 359cdd6 to 8654cf4, on January 27, 2022 05:58
function registered by ``register_kl``, according to multi-dispatch pattern.
If find the implemention function, it will return the result, or not will
raise ``NotImplementError`` exception. User can register implemention
funciton by the decorator.


  1. implementation functions (plural form);
  2. "If an implementation function is found";
  3. otherwise, it will raise a NotImplementedError exception;
  4. Users;
  5. functions.
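
For context, a minimal sketch of how the decorator described by this docstring is used; the distribution classes here are placeholders, and the imports assume register_kl and kl_divergence are exposed under paddle.distribution:

    from paddle.distribution import Distribution, kl_divergence, register_kl

    class MyP(Distribution):  # hypothetical distributions, for illustration only
        pass

    class MyQ(Distribution):
        pass

    @register_kl(MyP, MyQ)
    def _kl_myp_myq(p, q):
        # Looked up by kl_divergence(p, q) via multi-dispatch on (type(p), type(q));
        # when no implementation is registered, NotImplementedError is raised.
        ...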

@@ -167,7 +170,7 @@ def _kl_uniform_uniform(p, q):

@register_kl(ExponentialFamily, ExponentialFamily)
def _kl_expfamily_expfamily(p, q):
"""compute kl-divergence using `Bregman divergences`
"""Compute kl-divergence using `Bregman divergences`
https://www.lix.polytechnique.fr/~nielsen/EntropyEF-ICIP2010.pdf


Use the correct reST hyperlink format.
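
The reST form being asked for would look roughly like this inside the docstring (URL copied from the diff):

    """Compute kl-divergence using `Bregman divergences
    <https://www.lix.polytechnique.fr/~nielsen/EntropyEF-ICIP2010.pdf>`_.
    """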

@@ -205,5 +208,5 @@ def _kl_expfamily_expfamily(p, q):


def _sum_rightmost(value, n):
"""sum value along rightmost n dim"""
"""Sum value along rightmost n dim"""


elements ...(plural form)
dimensions.
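
For reference, a plausible shape of the helper being documented (a sketch, not the code under review):

    def _sum_rightmost(value, n):
        """Sum elements of ``value`` along its rightmost ``n`` dimensions."""
        return value.sum(list(range(-n, 0))) if n > 0 else value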

@@ -37,8 +44,14 @@ class Beta(ExponentialFamily):


Args:
alpha (float|Tensor): alpha parameter of beta distribution, positive(>0).
beta (float|Tensor): beta parameter of beta distribution, positive(>0).
alpha (float|Tensor): Alpha parameter. It support broadcast semantic.


It supports
Semantics.
is a tensor
represents
distributions
a
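
A small sketch of the broadcast semantics the docstring describes, assuming the class is paddle.distribution.Beta with alpha and beta accepting floats or tensors:

    from paddle.distribution import Beta

    # Scalar parameters: a single Beta distribution.
    b1 = Beta(alpha=0.5, beta=0.5)

    # Tensor parameters broadcast against each other: a batch of 3 distributions.
    b3 = Beta(alpha=paddle.to_tensor([0.5, 1.0, 2.0]),
              beta=paddle.to_tensor(0.5))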

concentration (Tensor): concentration parameter of dirichlet
distribution
concentration (Tensor): "Concentration" parameter of dirichlet
distribution, also called :math:`\alpha`. When It's over one


When it's.

concentration (Tensor): "Concentration" parameter of dirichlet
distribution, also called :math:`\alpha`. When It's over one
dimension, the last axis is parameter of distribution,
``event_shape=concentration.shape[-1:]`` , other axes is batch


are


Axes other than the last are considered batch dimensions.
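
A sketch of the shape convention being described (class path assumed to be paddle.distribution.Dirichlet):

    from paddle.distribution import Dirichlet

    # concentration has shape (2, 3): the last axis parameterizes each
    # distribution (event_shape == (3,)); the remaining axes, here (2,),
    # are batch dimensions.
    d = Dirichlet(paddle.to_tensor([[1., 2., 3.],
                                    [4., 5., 6.]]))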

Args:
total_count (int): Number of trials.
probs (Tensor): Probability of a trial falling into each category. Last
axis of probs indexes over categories, other axes index over batches.


The last axis
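
To illustrate "the last axis of probs indexes over categories, other axes index over batches" (values are illustrative):

    # probs with shape (2, 3): a batch of 2 multinomials over 3 categories each;
    # a single draw then has shape (2, 3), one count vector per batch element.
    probs = paddle.to_tensor([[0.2, 0.3, 0.5],
                              [0.1, 0.6, 0.3]])
    dist = paddle.distribution.Multinomial(10, probs)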

@iclementine previously approved these changes on Jan 27, 2022 and left a comment:

LGTM

TODO: refine documentation in next PR!

@jeff41404 (Contributor) left a comment:

approve

@XiaoguangHu01 (Contributor) left a comment:

LGTM

@iclementine iclementine merged commit 01f606b into PaddlePaddle:develop Jan 30, 2022