Skip to content

Commit

Permalink
add sinkhorn_divergence to runtests (#146)
Browse files Browse the repository at this point in the history
* add sinkhorn_divergence to runtests

* fix issues with julia 1.0 compat

* add more compat tag

* remove usage of eachcol
  • Loading branch information
zsteve committed Sep 22, 2021
1 parent 803274d commit 59223df
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Expand Up @@ -27,8 +27,8 @@ PythonOT = "3c485715-4278-42b2-9b5f-8f00e43c12ef"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Distances", "ForwardDiff", "ReverseDiff", "Pkg", "PythonOT", "Random", "SafeTestsets", "Test", "StatsBase"]
5 changes: 4 additions & 1 deletion src/utils.jl
Expand Up @@ -25,7 +25,10 @@ end
dot_matwise(x::AbstractMatrix, y::AbstractArray) = dot_matwise(y, x)

function dot_vecwise(x::AbstractMatrix, y::AbstractMatrix)
return [dot(u, v) for (u, v) in zip(eachcol(x), eachcol(y))]
return [
dot(u, v) for (u, v) in
zip((view(x, :, i) for i in axes(x, 2)), (view(y, :, i) for i in axes(y, 2)))
]
end

dot_vecwise(x::AbstractMatrix, y::AbstractVector) = x' * y
Expand Down
15 changes: 7 additions & 8 deletions test/entropic/sinkhorn_divergence.jl
Expand Up @@ -5,7 +5,6 @@ using ForwardDiff
using ReverseDiff
using LogExpFunctions
using PythonOT: PythonOT
using StatsBase
using LinearAlgebra
using Random
using Test
Expand Down Expand Up @@ -54,18 +53,18 @@ Random.seed!(100)
for reg in (true, false)
loss_batch = sinkhorn_divergence(μ, ν, C, ε; regularization=reg)
@test loss_batch [
sinkhorn_divergence(x, y, C, ε; regularization=reg) for
(x, y) in zip(eachcol(μ), eachcol(ν))
sinkhorn_divergence(μ[:, i], ν[:, i], C, ε; regularization=reg) for
i in 1:M
]
loss_batch_μ = sinkhorn_divergence(μ, ν[:, 1], C, ε; regularization=reg)
@test loss_batch_μ [
sinkhorn_divergence(x, ν[:, 1], C, ε; regularization=reg) for
x in eachcol(μ)
sinkhorn_divergence(μ[:, i], ν[:, 1], C, ε; regularization=reg) for
i in 1:M
]
loss_batch_ν = sinkhorn_divergence(μ[:, 1], ν, C, ε; regularization=reg)
@test loss_batch_ν [
sinkhorn_divergence(μ[:, 1], y, C, ε; regularization=reg) for
y in eachcol(ν)
sinkhorn_divergence(μ[:, 1], ν[:, i], C, ε; regularization=reg) for
i in 1:M
]
end
end
Expand Down Expand Up @@ -98,7 +97,7 @@ Random.seed!(100)
Cμν = pairwise(SqEuclidean(), μ_spt', ν_spt'; dims=2)
= pairwise(SqEuclidean(), μ_spt'; dims=2)
= pairwise(SqEuclidean(), ν_spt'; dims=2)
ε = 0.1 * max(mean(Cμν), mean(Cμ), mean(Cν))
ε = 1.0

@testset "basic" begin
for reg in (true, false)
Expand Down
3 changes: 3 additions & 0 deletions test/runtests.jl
Expand Up @@ -28,6 +28,9 @@ const GROUP = get(ENV, "GROUP", "All")
@safetestset "Sinkhorn barycenter" begin
include(joinpath("entropic", "sinkhorn_barycenter.jl"))
end
@safetestset "Sinkhorn divergence" begin
include(joinpath("entropic", "sinkhorn_divergence.jl"))
end
end

@safetestset "Quadratically regularized OT" begin
Expand Down

0 comments on commit 59223df

Please sign in to comment.