Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Histograms #10

Merged
merged 4 commits into from
Sep 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
ScikitLearn = "3646fa90-6ef7-5e7e-9f22-8aca16db6324"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand Down
12 changes: 12 additions & 0 deletions src/ConvenienceFunctions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@

const RPAD = 25

function name(name::AbstractString)
return rpad(name * ":", RPAD)
end

function warn(name::AbstractString)
return rpad("WARNING (" * name * "):", RPAD)
end


15 changes: 2 additions & 13 deletions src/GPR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ using Parameters # lets you have defaults for fields
using EllipsisNotation # adds '..' to refer to the rest of array
import ScikitLearn
import StatsBase
include("ConvenienceFunctions.jl")

const sklearn = ScikitLearn

sklearn.@sk_import gaussian_process : GaussianProcessRegressor
Expand Down Expand Up @@ -324,19 +326,6 @@ function plot_fit(gprw::Wrap, plt; plot_95 = false, label = nothing)
end
end

################################################################################
# convenience functions ########################################################
################################################################################
const RPAD = 25

function name(name::AbstractString)
return rpad(name * ":", RPAD)
end

function warn(name::AbstractString)
return rpad("WARNING (" * name * "):", RPAD)
end

end # module


80 changes: 80 additions & 0 deletions src/Histograms.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
module Histograms
"""
This module is mostly a convenient wrapper of Python functions (numpy, scipy).

Functions in this module:
- W1 (2 methods)

"""

import PyCall
include("ConvenienceFunctions.jl")

scsta = PyCall.pyimport("scipy.stats")

################################################################################
# distance functions ###########################################################
################################################################################
"""
Compute the Wasserstein-1 distance between two distributions from their samples

Parameters:
- u_samples: array-like; samples from the 1st distribution
- v_samples: array-like; samples from the 2nd distribution
- normalize: boolean; whether to normalize the distance by 1/(max-min)

Returns:
- w1_uv: number; the Wasserstein-1 distance
"""
function W1(u_samples::AbstractVector, v_samples::AbstractVector;
normalize = true)
L = maximum([u_samples; v_samples]) - minimum([u_samples; v_samples])
return if !normalize
scsta.wasserstein_distance(u_samples, v_samples)
else
scsta.wasserstein_distance(u_samples, v_samples) / L
end
end

"""
Compute the pairwise Wasserstein-1 distances between two sets of distributions
from their samples

Parameters:
- U_samples: matrix-like; samples from distributions (u1, u2, ...)
- V_samples: matrix-like; samples from distributions (v1, v2, ...)
- normalize: boolean; whether to normalize the distances by 1/(max-min)

`U_samples` and `V_samples` should have samples in the 2nd dimension (along
rows) and have the same 1st dimension (same number of rows). If not, the minimum
of the two (minimum number of rows) will be taken.

`normalize` induces *pairwise* normalization, i.e. it max's and min's are
computed for each pair (u_j, v_j) individually.

Returns:
- w1_UV: array-like; the pairwise Wasserstein-1 distances:
w1(u1, v1)
w1(u2, v2)
...
w1(u_K, v_K)
"""
function W1(U_samples::AbstractMatrix, V_samples::AbstractMatrix;
normalize = true)
if size(U_samples, 1) != size(V_samples, 1)
println(warn("W1"), "sizes of U_samples & V_samples don't match; ",
"will use the minimum of the two")
end
K = min(size(U_samples, 1), size(V_samples, 1))
w1_UV = zeros(K)
U_sorted = sort(U_samples[1:K, 1:end], dims = 2)
V_sorted = sort(V_samples[1:K, 1:end], dims = 2)
for k in 1:K
w1_UV[k] = W1(U_sorted[k, 1:end], V_sorted[k, 1:end]; normalize = normalize)
end
return w1_UV
end

end # module


19 changes: 19 additions & 0 deletions test/ConvenienceFunctions/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using Test

include("../../src/ConvenienceFunctions.jl")

################################################################################
# unit testing #################################################################
################################################################################
@testset "unit testing" begin
@test isdefined(Main, :RPAD)
@test length(name("a")) == RPAD
@test length(name("a" ^ RPAD)) == (RPAD + 1)
@test length(warn("a")) == RPAD
@test length(warn("a" ^ RPAD)) == (RPAD + 11)
@test isa(name("a"), String)
@test isa(warn("a"), String)
end
println("")


Binary file added test/Histograms/data/x1_bal.npy
Binary file not shown.
Binary file added test/Histograms/data/x1_dns.npy
Binary file not shown.
Binary file added test/Histograms/data/x1_onl.npy
Binary file not shown.
36 changes: 36 additions & 0 deletions test/Histograms/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
using Test
import NPZ

include("../../src/Histograms.jl")
const Hgm = Histograms

const data_dir = joinpath(@__DIR__, "data")
const x1_bal = NPZ.npzread(joinpath(data_dir, "x1_bal.npy"))
const x1_dns = NPZ.npzread(joinpath(data_dir, "x1_dns.npy"))
const x1_onl = NPZ.npzread(joinpath(data_dir, "x1_onl.npy"))
const w1_dns_bal = 0.03755967829782972
const w1_dns_onl = 0.004489688974663949
const w1_bal_onl = 0.037079734072606625
const w1_dns_bal_unnorm = 0.8190688772401341

################################################################################
# unit testing #################################################################
################################################################################
@testset "unit testing" begin
arr1 = [1, 1, 1, 2, 3, 4, 4, 4]
arr2 = [1, 1, 2, 2, 3, 3, 4, 4, 4]
@test Hgm.W1(arr1, arr2, normalize = false) == 0.25
@test Hgm.W1(arr2, arr1, normalize = false) == 0.25
@test Hgm.W1(arr1, arr2) == Hgm.W1(arr2, arr1)

@test isapprox(Hgm.W1(x1_dns, x1_bal), w1_dns_bal)
@test isapprox(Hgm.W1(x1_dns, x1_onl), w1_dns_onl)
@test isapprox(Hgm.W1(x1_bal, x1_onl), w1_bal_onl)
@test isapprox(Hgm.W1(x1_dns, x1_bal, normalize = false), w1_dns_bal_unnorm)

@test size(Hgm.W1(rand(3,100), rand(3,100))) == (3,)
@test size(Hgm.W1(rand(9,100), rand(3,100))) == (3,)
end
println("")


2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ include("neki.jl")

for submodule in ["L96m",
"GPR",
"Histograms",
"ConvenienceFunctions",
]

println("Starting tests for $submodule")
Expand Down