Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor interface #1

Merged
merged 3 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ version = "0.1.0"
ColumnSelectors = "9cc86067-7e36-4c61-b350-1ac9833d277f"
DensityRatioEstimation = "ab46fb84-d57c-11e9-2f65-6f72e4a7229f"
GeoStatsBase = "323cb8eb-fbf6-51c0-afd0-f8fba70507b2"
GeoStatsModels = "ad987403-13c5-47b5-afee-0a48f6ac4f12"
GeoStatsTransforms = "725d9659-360f-4996-9c94-5f19c7e4a8a6"
GeoTables = "e502b557-6362-48c1-8219-d30d308dcdb0"
LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
Expand All @@ -18,9 +19,11 @@ Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
ColumnSelectors = "0.1"
DensityRatioEstimation = "1.2"
GeoStatsBase = "0.42"
GeoStatsModels = "0.2"
GeoStatsTransforms = "0.2"
GeoTables = "1.14"
LossFunctions = "0.11"
StatsLearnModels = "0.2"
Meshes = "0.37"
StatsLearnModels = "0.3"
Transducers = "0.4"
julia = "1.9"
4 changes: 3 additions & 1 deletion src/GeoStatsValidation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ using GeoTables
using Transducers
using DensityRatioEstimation

using StatsLearnModels: Learn
using GeoStatsModels: GeoStatsModel
using StatsLearnModels: StatsLearnModel
using StatsLearnModels: Learn, input, output
using GeoStatsTransforms: Interpolate, InterpolateNeighbors

using ColumnSelectors: selector
Expand Down
41 changes: 28 additions & 13 deletions src/cverror.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,37 +9,52 @@ A method for estimating cross-validatory error.
"""
abstract type ErrorMethod end

struct LearnSetup{M}
abstract type ErrorSetup end

struct LearnSetup{M} <: ErrorSetup
model::M
input::Vector{Symbol}
output::Vector{Symbol}
end

struct InterpSetup{I,M}
struct InterpSetup{I,M,K} <: ErrorSetup
model::M
kwargs::K
end

"""
cverror(Learn, model, incols => outcols, geotable, method)
cverror(Interpolate, model, geotable, method)
cverror(InterpolateNeighbors, model, geotable, method)
cverror(model::GeoStatsModel, geotable, method; kwargs...)

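Estimate error of `model` in a given `geotable` with
error estimation `method`. The `kwargs` are forwarded to the
interpolation transform: `InterpolateNeighbors` is used when any of
`minneighbors`, `maxneighbors`, `neighborhood` or `distance` is passed,
and `Interpolate` is used otherwise.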
juliohm marked this conversation as resolved.
Show resolved Hide resolved

cverror(model::StatsLearnModel, geotable, method)
cverror((model, invars => outvars), geotable, method)

Estimate error of `model` in a given `geotable` with
error estimation `method`.
error estimation `method` using the `Learn` transform.
"""
function cverror end

function cverror(::Type{Learn}, model, (incols, outcols)::Pair, geotable::AbstractGeoTable, method::ErrorMethod)
cverror((model, cols)::Tuple{Any,Pair}, geotable::AbstractGeoTable, method::ErrorMethod) =
cverror(StatsLearnModel(model, first(cols), last(cols)), geotable, method)

function cverror(model::StatsLearnModel, geotable::AbstractGeoTable, method::ErrorMethod)
names = setdiff(propertynames(geotable), [:geometry])
input = selector(incols)(names)
output = selector(outcols)(names)
cverror(LearnSetup(model, input, output), geotable, method)
invars = input(model)(names)
outvars = output(model)(names)
setup = LearnSetup(model, invars, outvars)
cverror(setup, geotable, method)
end

const Interp = Union{Interpolate,InterpolateNeighbors}
const INTERPNEIGHBORS = (:minneighbors, :maxneighbors, :neighborhood, :distance)

cverror(::Type{I}, model::M, geotable::AbstractGeoTable, method::ErrorMethod) where {I<:Interp,M} =
cverror(InterpSetup{I,M}(model), geotable, method)
function cverror(model::M, geotable::AbstractGeoTable, method::ErrorMethod; kwargs...) where {M<:GeoStatsModel}
I = any(∈(INTERPNEIGHBORS), keys(kwargs)) ? InterpolateNeighbors : Interpolate
setup = InterpSetup{I,M,typeof(kwargs)}(model, kwargs)
cverror(setup, geotable, method)
end

# ----------------
# IMPLEMENTATIONS
Expand Down
2 changes: 1 addition & 1 deletion src/cverrors/bcv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ end

BlockValidation(sides; loss=Dict()) = BlockValidation{typeof(sides)}(sides, loss)

function cverror(setup, geotable, method::BlockValidation)
function cverror(setup::ErrorSetup, geotable::AbstractGeoTable, method::BlockValidation)
# uniform weights
weighting = UniformWeighting()

Expand Down
2 changes: 1 addition & 1 deletion src/cverrors/drv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ function DensityRatioValidation(
DensityRatioValidation{T,E,O}(k, shuffle, lambda, estimator, optlib, loss)
end

function cverror(setup::LearnSetup, geotable, method::DensityRatioValidation)
function cverror(setup::LearnSetup, geotable::AbstractGeoTable, method::DensityRatioValidation)
vars = setup.input

# density-ratio weights
Expand Down
2 changes: 1 addition & 1 deletion src/cverrors/kfv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ end

KFoldValidation(k::Int; shuffle=true, loss=Dict()) = KFoldValidation(k, shuffle, loss)

function cverror(setup, geotable, method::KFoldValidation)
function cverror(setup::ErrorSetup, geotable::AbstractGeoTable, method::KFoldValidation)
# uniform weights
weighting = UniformWeighting()

Expand Down
2 changes: 1 addition & 1 deletion src/cverrors/lbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ LeaveBallOut(ball; loss=Dict()) = LeaveBallOut{typeof(ball)}(ball, loss)

LeaveBallOut(radius::Number; loss=Dict()) = LeaveBallOut(MetricBall(radius), loss=loss)

function cverror(setup, geotable, method::LeaveBallOut)
function cverror(setup::ErrorSetup, geotable::AbstractGeoTable, method::LeaveBallOut)
# uniform weights
weighting = UniformWeighting()

Expand Down
2 changes: 1 addition & 1 deletion src/cverrors/loo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ end

LeaveOneOut(; loss=Dict()) = LeaveOneOut(loss)

function cverror(setup, geotable, method::LeaveOneOut)
function cverror(setup::ErrorSetup, geotable::AbstractGeoTable, method::LeaveOneOut)
# uniform weights
weighting = UniformWeighting()

Expand Down
6 changes: 3 additions & 3 deletions src/cverrors/wcv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ end
WeightedValidation(weighting::W, folding::F; lambda::T=one(T), loss=Dict()) where {W,F,T} =
WeightedValidation{W,F,T}(weighting, folding, lambda, loss)

function cverror(setup, geotable, method::WeightedValidation)
function cverror(setup::ErrorSetup, geotable::AbstractGeoTable, method::WeightedValidation)
ovars = _outputs(setup, geotable)
loss = method.loss
for var in ovars
Expand Down Expand Up @@ -86,11 +86,11 @@ _outputs(s::LearnSetup, gtb) = s.output
function _prediction(s::InterpSetup{I}, geotable, f) where {I}
sdat = view(geotable, f[1])
sdom = view(domain(geotable), f[2])
sdat |> I(sdom, s.model)
sdat |> I(sdom, s.model; s.kwargs...)
end

function _prediction(s::LearnSetup, geotable, f)
source = view(geotable, f[1])
target = view(geotable, f[2])
target |> Learn(source, s.model, s.input => s.output)
target |> Learn(source, s.model)
end
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ using Test

# dummy classifier → 0.5 misclassification rate
for method in [LeaveOneOut(), LeaveBallOut(0.1), KFoldValidation(10), BlockValidation(0.1), DensityRatioValidation(10)]
e = cverror(Learn, model, :x => :y, gtb, method)
e = cverror((model, :x => :y), gtb, method)
@test isapprox(e[:y], 0.5, atol=0.06)
end
end
Expand All @@ -34,8 +34,8 @@ using Test
# low variance + dummy (mean) estimator → low error
# high variance + dummy (mean) estimator → high error
for method in [LeaveOneOut(), LeaveBallOut(0.1), KFoldValidation(10), BlockValidation(0.1)]
e₁ = cverror(Interpolate, model, sgtb₁, method)
e₂ = cverror(Interpolate, model, sgtb₂, method)
e₁ = cverror(model, sgtb₁, method)
e₂ = cverror(model, sgtb₂, method)
@test e₁[:z] < 1
@test e₂[:z] > 1
end
Expand Down