From 59223df8db5a6e776d87244b7e50eb57d60e89ae Mon Sep 17 00:00:00 2001 From: Stephen Zhang Date: Wed, 22 Sep 2021 20:35:24 +1000 Subject: [PATCH] add sinkhorn_divergence to runtests (#146) * add sinkhorn_divergence to runtests * fix issues with julia 1.0 compat * add more compat tag * remove usage of eachcol --- Project.toml | 2 +- src/utils.jl | 5 ++++- test/entropic/sinkhorn_divergence.jl | 15 +++++++-------- test/runtests.jl | 3 +++ 4 files changed, 15 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index 1919a7df..7a224926 100644 --- a/Project.toml +++ b/Project.toml @@ -27,8 +27,8 @@ PythonOT = "3c485715-4278-42b2-9b5f-8f00e43c12ef" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] test = ["Distances", "ForwardDiff", "ReverseDiff", "Pkg", "PythonOT", "Random", "SafeTestsets", "Test", "StatsBase"] diff --git a/src/utils.jl b/src/utils.jl index 8110cd72..c73214fe 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -25,7 +25,10 @@ end dot_matwise(x::AbstractMatrix, y::AbstractArray) = dot_matwise(y, x) function dot_vecwise(x::AbstractMatrix, y::AbstractMatrix) - return [dot(u, v) for (u, v) in zip(eachcol(x), eachcol(y))] + return [ + dot(u, v) for (u, v) in + zip((view(x, :, i) for i in axes(x, 2)), (view(y, :, i) for i in axes(y, 2))) + ] end dot_vecwise(x::AbstractMatrix, y::AbstractVector) = x' * y diff --git a/test/entropic/sinkhorn_divergence.jl b/test/entropic/sinkhorn_divergence.jl index 16bc9d13..bebe5709 100644 --- a/test/entropic/sinkhorn_divergence.jl +++ b/test/entropic/sinkhorn_divergence.jl @@ -5,7 +5,6 @@ using ForwardDiff using ReverseDiff using LogExpFunctions using PythonOT: PythonOT -using StatsBase using LinearAlgebra using Random using Test @@ -54,18 +53,18 @@ Random.seed!(100) for reg in (true, false) loss_batch = sinkhorn_divergence(μ, ν, C, ε; regularization=reg) @test loss_batch ≈ [ - sinkhorn_divergence(x, y, C, ε; regularization=reg) for - (x, y) in zip(eachcol(μ), eachcol(ν)) + sinkhorn_divergence(μ[:, i], ν[:, i], C, ε; regularization=reg) for + i in 1:M ] loss_batch_μ = sinkhorn_divergence(μ, ν[:, 1], C, ε; regularization=reg) @test loss_batch_μ ≈ [ - sinkhorn_divergence(x, ν[:, 1], C, ε; regularization=reg) for - x in eachcol(μ) + sinkhorn_divergence(μ[:, i], ν[:, 1], C, ε; regularization=reg) for + i in 1:M ] loss_batch_ν = sinkhorn_divergence(μ[:, 1], ν, C, ε; regularization=reg) @test loss_batch_ν ≈ [ - sinkhorn_divergence(μ[:, 1], y, C, ε; regularization=reg) for - y in eachcol(ν) + sinkhorn_divergence(μ[:, 1], ν[:, i], C, ε; regularization=reg) for + i in 1:M ] end end @@ -98,7 +97,7 @@ Random.seed!(100) Cμν = pairwise(SqEuclidean(), μ_spt', ν_spt'; dims=2) Cμ = pairwise(SqEuclidean(), μ_spt'; dims=2) Cν = pairwise(SqEuclidean(), ν_spt'; dims=2) - ε = 0.1 * max(mean(Cμν), mean(Cμ), mean(Cν)) + ε = 1.0 @testset "basic" begin for reg in (true, false) diff --git a/test/runtests.jl b/test/runtests.jl index 4aba626f..66314dfa 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -28,6 +28,9 @@ const GROUP = get(ENV, "GROUP", "All") @safetestset "Sinkhorn barycenter" begin include(joinpath("entropic", "sinkhorn_barycenter.jl")) end + @safetestset "Sinkhorn divergence" begin + include(joinpath("entropic", "sinkhorn_divergence.jl")) + end end @safetestset "Quadratically regularized OT" begin