Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix KMM coefficient and add tests #19

Merged
merged 17 commits into from
Jan 12, 2020
Merged
2 changes: 1 addition & 1 deletion src/kmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ Kernel Mean Matching (KMM).
λ::T=0.0
end

default_optlib(dre::Type{<:KMM}) = JuMPLib
default_optlib(dre::Type{<:KMM}) = JuliaLib
xukai92 marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion src/kmm/julia.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@ function _densratio(x_nu, x_de, dre::KMM, optlib::Type{JuliaLib})
Kdenu = gaussian_gramian(x_de, x_nu, σ=σ)

# closed-form solution (without constraints)
(n_de / n_nu) * (Kdede + λ*I) \ vec(sum(Kdenu, dims=2))
(n_de / n_nu) * ((Kdede + λ*I) \ vec(sum(Kdenu, dims=2)))
xukai92 marked this conversation as resolved.
Show resolved Hide resolved
end
21 changes: 0 additions & 21 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,5 @@
G = DensityRatioEstimation.gaussian_gramian(x_nu, x_nu, σ=2.0)
@test issymmetric(G)
@test all(G .> 0)

# features can be any indexable
xukai92 marked this conversation as resolved.
Show resolved Hide resolved
x_nu = [(a=1.,b=2.),(a=3.,b=4.)]
x_de = [(a=1.,b=2.),(a=3.,b=4.),(a=5.,b=6.)]
G = DensityRatioEstimation.gaussian_gramian(x_nu, x_de)
@test size(G) == (2, 3)
@test all(G .> 0)
end

for (d_nu, d_de) in [pair₁, pair₂]
Random.seed!(123)
x_nu, x_de = rand(d_nu, 100), rand(d_de, 200)
@testset "$dre -- $optlib" for (dre, optlib) in [(KMM(), JuMPLib),
(KLIEP(), OptimLib),
(KLIEP(), ConvexLib)]

r = densratio(x_nu, x_de, dre, optlib=optlib)

# density ratios must be positive
@test all(r .> 0)
xukai92 marked this conversation as resolved.
Show resolved Hide resolved
end
xukai92 marked this conversation as resolved.
Show resolved Hide resolved
end
end
Binary file removed test/data/KLIEP-1.png
Binary file not shown.
Binary file removed test/data/KLIEP-2.png
Binary file not shown.
Binary file added test/data/KLIEP-ConvexLib-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/data/KLIEP-ConvexLib-2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/data/KLIEP-OptimLib-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/data/KLIEP-OptimLib-2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/data/KMM-JuMPLib-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/data/KMM-JuMPLib-2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/data/KMM-JuliaLib-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/data/KMM-JuliaLib-2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed test/data/KMM.png
Binary file not shown.
23 changes: 17 additions & 6 deletions test/kliep.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
@testset "KLIEP" begin
for (i, pair) in enumerate([pair₁, pair₂])
@testset "KLIEP -- $optlib" for optlib in [OptimLib, ConvexLib]
for (i, pair, rtol_correctness) in zip(1:2, [pair₁, pair₂], (2e-1, 4e-1))
xukai92 marked this conversation as resolved.
Show resolved Hide resolved
d_nu, d_de = pair
Random.seed!(123)
x_nu, x_de = rand(d_nu, 1000), rand(d_de, 500)
x_nu, x_de = rand(d_nu, 1_000), rand(d_de, 500)
xukai92 marked this conversation as resolved.
Show resolved Hide resolved

# estimated density ratio
r̂ = densratio(x_nu, x_de, KLIEP(σ=1.0, b=100))
σ, b = 1.0, 100
r̂ = densratio(x_nu, x_de, KLIEP(σ=σ, b=b), optlib=optlib)

# density ratios must be positive
@test all(r̂ .> 0)
xukai92 marked this conversation as resolved.
Show resolved Hide resolved

# simplex constraints
@test abs(mean(r̂) - 1) ≤ 1e-2
@test all(r̂ .≤ Inf)

r = pdf.(d_nu, x_de) ./ pdf.(d_de, x_de)
@test r ≈ r̂ rtol=rtol_correctness

if visualtests
gr(size=(800,800))
@plottest plot_d_nu(pair,x_de,r̂) joinpath(datadir,"KLIEP-$i.png") !istravis
gr(size=(800, 800))
@plottest plot_d_nu(pair, x_de, r̂) joinpath(datadir, "KLIEP-$optlib-$i.png") !istravis
end
end
end
36 changes: 23 additions & 13 deletions test/kmm.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,29 @@
@testset "KMM" begin
d_nu, d_de = pair₁
@testset "KMM -- $optlib" for optlib in [JuliaLib, JuMPLib]
for (i, pair) in enumerate([pair₁, pair₂])
juliohm marked this conversation as resolved.
Show resolved Hide resolved
d_nu, d_de = pair
Random.seed!(123)
x_nu, x_de = rand(d_nu, 2_000), rand(d_de, 1_000)

Random.seed!(123)
x_nu, x_de = rand(d_nu, 2000), rand(d_de, 1000)
# estimated density ratio
σ, B, ϵ, λ = 1.5, Inf, 0.01, 0.001
r̂ = densratio(x_nu, x_de, KMM(σ=σ, B=B, ϵ=ϵ, λ=λ), optlib=optlib)

# estimated density ratio
σ, B, ϵ = 1.0, Inf, 0.01
r̂ = densratio(x_nu, x_de, KMM(σ=σ, B=B, ϵ=ϵ))
# density ratios must be positive
optlib == JuMPLib && @test all(r̂ .> 0)
xukai92 marked this conversation as resolved.
Show resolved Hide resolved
xukai92 marked this conversation as resolved.
Show resolved Hide resolved

# simplex constraints
@test abs(mean(r̂) - 1) ≤ ϵ
@test all(r̂ .≤ B)
# simplex constraints
@test abs(mean(r̂) - 1) ≤ 1e-2
@test all(r̂ .≤ B)

if visualtests
gr(size=(800,800))
@plottest plot_d_nu(pair₁,x_de,r̂) joinpath(datadir,"KMM.png") !istravis
if i == 1 # FIXME: only check correctness for the Gaussian case now
juliohm marked this conversation as resolved.
Show resolved Hide resolved
juliohm marked this conversation as resolved.
Show resolved Hide resolved
# compare against true ratio
r = pdf.(d_nu, x_de) ./ pdf.(d_de, x_de)
@test r ≈ r̂ rtol=2e-1
end

if visualtests
gr(size=(800, 800))
@plottest plot_d_nu(pair, x_de, r̂) joinpath(datadir, "KMM-$optlib-$i.png") !istravis
end
end
end