From e0a539e3332eb2e4927d78d96c4c8204447ba3ef Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Fri, 18 Aug 2023 18:14:11 +0200 Subject: [PATCH] Correctly handle cost = 0. --- ot/optim.py | 2 +- test/test_gromov.py | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/ot/optim.py b/ot/optim.py index 9e65e8141..61bc91e7f 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -290,7 +290,7 @@ def cost(G): loop = 0 abs_delta_cost_G = abs(cost_G - old_cost_G) - relative_delta_cost_G = abs_delta_cost_G / abs(cost_G) + relative_delta_cost_G = abs_delta_cost_G / abs(cost_G) if cost_G != 0. else np.nan if relative_delta_cost_G < stopThr or abs_delta_cost_G < stopThr2: loop = 0 diff --git a/test/test_gromov.py b/test/test_gromov.py index 13ff3fe99..1aeaf46e3 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -8,10 +8,12 @@ # License: MIT License import numpy as np +import pytest +import warnings + import ot from ot.backend import NumpyBackend from ot.backend import torch, tf -import pytest def test_gromov(nx): @@ -146,8 +148,10 @@ def test_gromov_dtype_device(nx): C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0, type_as=tp) - Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True) - gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False) + with warnings.catch_warnings(): + warnings.filterwarnings('error') + Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True) + gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False) nx.assert_same_dtype_device(C1b, Gb) nx.assert_same_dtype_device(C1b, gw_valb)