diff --git a/RELEASES.md b/RELEASES.md
index 9e5be471f..e5a8ac54b 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -3,6 +3,7 @@
 ## 0.9.5dev
 
 #### New features
+- Add feature `mass=True` for `nx.kl_div` (PR #654)
 
 #### Closed issues
 
diff --git a/ot/backend.py b/ot/backend.py
index 534c03293..819b91db5 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -944,16 +944,17 @@ def eigh(self, a):
         """
         raise NotImplementedError()
 
-    def kl_div(self, p, q, eps=1e-16):
+    def kl_div(self, p, q, mass=False, eps=1e-16):
         r"""
-        Computes the Kullback-Leibler divergence.
+        Computes the (Generalized) Kullback-Leibler divergence.
 
         This function follows the api from :any:`scipy.stats.entropy`.
 
         Parameter eps is used to avoid numerical errors and is added in the log.
 
         .. math::
-             KL(p,q) = \sum_i p(i) \log (\frac{p(i)}{q(i)}+\epsilon)
+             KL(p, q) = \langle \mathbf{p}, \log(\mathbf{p} / \mathbf{q} + \epsilon) \rangle
+             + \mathbb{1}_{mass=True} \langle \mathbf{q} - \mathbf{p}, \mathbf{1} \rangle
 
         See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html
         """
@@ -1352,8 +1353,11 @@ def sqrtm(self, a):
     def eigh(self, a):
         return np.linalg.eigh(a)
 
-    def kl_div(self, p, q, eps=1e-16):
-        return np.sum(p * np.log(p / q + eps))
+    def kl_div(self, p, q, mass=False, eps=1e-16):
+        value = np.sum(p * np.log(p / q + eps))
+        if mass:
+            value = value + np.sum(q - p)
+        return value
 
     def isfinite(self, a):
         return np.isfinite(a)
@@ -1751,8 +1755,11 @@ def sqrtm(self, a):
     def eigh(self, a):
         return jnp.linalg.eigh(a)
 
-    def kl_div(self, p, q, eps=1e-16):
-        return jnp.sum(p * jnp.log(p / q + eps))
+    def kl_div(self, p, q, mass=False, eps=1e-16):
+        value = jnp.sum(p * jnp.log(p / q + eps))
+        if mass:
+            value = value + jnp.sum(q - p)
+        return value
 
     def isfinite(self, a):
         return jnp.isfinite(a)
@@ -2238,8 +2245,11 @@ def sqrtm(self, a):
     def eigh(self, a):
         return torch.linalg.eigh(a)
 
-    def kl_div(self, p, q, eps=1e-16):
-        return torch.sum(p * torch.log(p / q + eps))
+    def kl_div(self, p, q, mass=False, eps=1e-16):
+        value = torch.sum(p * torch.log(p / q + eps))
+        if mass:
+            value = value + torch.sum(q - p)
+        return value
 
     def isfinite(self, a):
         return torch.isfinite(a)
@@ -2639,8 +2649,11 @@ def sqrtm(self, a):
     def eigh(self, a):
         return cp.linalg.eigh(a)
 
-    def kl_div(self, p, q, eps=1e-16):
-        return cp.sum(p * cp.log(p / q + eps))
+    def kl_div(self, p, q, mass=False, eps=1e-16):
+        value = cp.sum(p * cp.log(p / q + eps))
+        if mass:
+            value = value + cp.sum(q - p)
+        return value
 
     def isfinite(self, a):
         return cp.isfinite(a)
@@ -3063,8 +3076,11 @@ def sqrtm(self, a):
     def eigh(self, a):
         return tf.linalg.eigh(a)
 
-    def kl_div(self, p, q, eps=1e-16):
-        return tnp.sum(p * tnp.log(p / q + eps))
+    def kl_div(self, p, q, mass=False, eps=1e-16):
+        value = tnp.sum(p * tnp.log(p / q + eps))
+        if mass:
+            value = value + tnp.sum(q - p)
+        return value
 
     def isfinite(self, a):
         return tnp.isfinite(a)
diff --git a/ot/bregman/_barycenter.py b/ot/bregman/_barycenter.py
index abef77d63..5d90782fb 100644
--- a/ot/bregman/_barycenter.py
+++ b/ot/bregman/_barycenter.py
@@ -364,7 +364,7 @@ def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000,
         log = {'err': []}
 
     M = - M / reg
-    logA = nx.log(A + 1e-15)
+    logA = nx.log(A + 1e-16)
     log_KU, G = nx.zeros((2, *logA.shape), type_as=A)
     err = 1
     for ii in range(numItermax):
@@ -702,7 +702,7 @@ def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000,
         log = {'err': []}
 
     M = - M / reg
-    logA = nx.log(A + 1e-15)
+    logA = nx.log(A + 1e-16)
     log_KU, G = nx.zeros((2, *logA.shape), type_as=A)
     c = nx.zeros(dim, type_as=A)
     err = 1
diff --git a/ot/coot.py b/ot/coot.py
index 477529f45..4134e594c 100644
--- a/ot/coot.py
+++ b/ot/coot.py
@@ -139,10 +139,6 @@ def co_optimal_transport(X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat
         Advances in Neural Information Processing ny_sampstems, 33 (2020).
     """
 
-    def compute_kl(p, q):
-        kl = nx.sum(p * nx.log(p + 1.0 * (p == 0))) - nx.sum(p * nx.log(q))
-        return kl
-
     # Main function
     if method_sinkhorn not in ["sinkhorn", "sinkhorn_log"]:
@@ -245,9 +241,9 @@ def compute_kl(p, q):
         coot = coot + alpha_samp * nx.sum(M_samp * pi_samp)
         # Entropic part
         if eps_samp != 0:
-            coot = coot + eps_samp * compute_kl(pi_samp, wxy_samp)
+            coot = coot + eps_samp * nx.kl_div(pi_samp, wxy_samp)
         if eps_feat != 0:
-            coot = coot + eps_feat * compute_kl(pi_feat, wxy_feat)
+            coot = coot + eps_feat * nx.kl_div(pi_feat, wxy_feat)
         list_coot.append(coot)
 
         if err < tol_bcd or abs(list_coot[-2] - list_coot[-1]) < early_stopping_tol:
diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py
index c1b744333..9a8111453 100644
--- a/ot/gromov/_utils.py
+++ b/ot/gromov/_utils.py
@@ -109,7 +109,7 @@ def h2(b):
             return 2 * b
     elif loss_fun == 'kl_loss':
         def f1(a):
-            return a * nx.log(a + 1e-15) - a
+            return a * nx.log(a + 1e-16) - a
 
         def f2(b):
             return b
@@ -118,7 +118,7 @@ def h1(a):
             return a
 
         def h2(b):
-            return nx.log(b + 1e-15)
+            return nx.log(b + 1e-16)
     else:
         raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")
@@ -502,7 +502,7 @@ def h2(b):
             return 2 * b
     elif loss_fun == 'kl_loss':
         def f1(a):
-            return a * nx.log(a + 1e-15) - a
+            return a * nx.log(a + 1e-16) - a
 
         def f2(b):
             return b
@@ -511,7 +511,7 @@ def h1(a):
             return a
 
         def h2(b):
-            return nx.log(b + 1e-15)
+            return nx.log(b + 1e-16)
     else:
         raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")
diff --git a/test/test_utils.py b/test/test_utils.py
index 966cef989..5c5dab48d 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -620,3 +620,19 @@ def test_label_normalization(nx):
     # labels are shifted but the shift if expected
     y_normalized_start = ot.utils.label_normalization(y, start=1)
     np.testing.assert_array_equal(y, y_normalized_start)
+
+
+def test_kl_div(nx):
+    n = 10
+    rng = np.random.RandomState(0)
+    # test on non-negative tensors
+    x = rng.randn(n)
+    x = x - x.min() + 1e-5
+    y = rng.randn(n)
+    y = y - y.min() + 1e-5
+    xb = nx.from_numpy(x)
+    yb = nx.from_numpy(y)
+    kl = nx.kl_div(xb, yb)
+    kl_mass = nx.kl_div(xb, yb, True)
+    recovered_kl = kl_mass - nx.sum(yb - xb)
+    np.testing.assert_allclose(kl, recovered_kl)
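
Usage note (outside the diff): a minimal sketch of the behaviour added above, assuming the patched POT is installed; the backend is fetched with the existing helper `ot.backend.get_backend`. With `mass=True`, `kl_div` returns the generalized KL divergence, i.e. the plain divergence plus the mass difference sum(q - p), so both calls agree whenever p and q carry the same total mass.

    import numpy as np
    import ot

    p = np.array([0.2, 0.3, 0.5])
    q = np.array([0.4, 0.4, 0.4])        # deliberately not normalized
    nx = ot.backend.get_backend(p, q)    # NumPy backend in this case

    kl = nx.kl_div(p, q)                 # sum(p * log(p / q + eps))
    gkl = nx.kl_div(p, q, mass=True)     # adds sum(q - p)
    assert np.allclose(gkl, kl + np.sum(q - p))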