Skip to content

Commit

Permalink
add kl-div registrition for binomial and poisson
Browse files Browse the repository at this point in the history
  • Loading branch information
NKNaN committed Oct 24, 2023
1 parent 4c1bfd4 commit 8a1f336
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions python/paddle/distribution/kl.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import paddle
from paddle.distribution.bernoulli import Bernoulli
from paddle.distribution.beta import Beta
from paddle.distribution.binomial import Binomial
from paddle.distribution.categorical import Categorical
from paddle.distribution.cauchy import Cauchy
from paddle.distribution.dirichlet import Dirichlet
Expand All @@ -26,6 +27,7 @@
from paddle.distribution.laplace import Laplace
from paddle.distribution.lognormal import LogNormal
from paddle.distribution.normal import Normal
from paddle.distribution.poisson import Poisson
from paddle.distribution.uniform import Uniform
from paddle.framework import in_dynamic_mode

Expand Down Expand Up @@ -151,6 +153,11 @@ def _kl_bernoulli_bernoulli(p, q):
return p.kl_divergence(q)


@register_kl(Binomial, Binomial)
def _kl_binomial_binomial(p, q):
return p.kl_divergence(q)


@register_kl(Beta, Beta)
def _kl_beta_beta(p, q):
return (
Expand Down Expand Up @@ -197,6 +204,11 @@ def _kl_normal_normal(p, q):
return p.kl_divergence(q)


@register_kl(Poisson, Poisson)
def _kl_poisson_poisson(p, q):
return p.kl_divergence(q)


@register_kl(Uniform, Uniform)
def _kl_uniform_uniform(p, q):
return p.kl_divergence(q)
Expand Down

0 comments on commit 8a1f336

Please sign in to comment.