From 53bdb08f7c9bc254e7eaf7eaaf5504fffc122586 Mon Sep 17 00:00:00 2001 From: Elias Carvalho Date: Fri, 15 Dec 2023 15:28:42 -0300 Subject: [PATCH] Initial commit --- .JuliaFormatter.toml | 10 ++++ .github/workflows/FormatPR.yml | 28 ++++++++++ .gitignore | 3 +- Project.toml | 21 +++++++- README.md | 2 +- src/GeoStatsValidation.jl | 18 ++++++- src/cverror.jl | 53 +++++++++++++++++++ src/cverrors/bcv.jl | 39 ++++++++++++++ src/cverrors/drv.jl | 64 ++++++++++++++++++++++ src/cverrors/kfv.jl | 38 +++++++++++++ src/cverrors/lbo.jl | 42 +++++++++++++++ src/cverrors/loo.jl | 32 +++++++++++ src/cverrors/wcv.jl | 97 ++++++++++++++++++++++++++++++++++ test/Manifest.toml | 34 ------------ test/Project.toml | 6 +++ test/runtests.jl | 39 +++++++++++++- 16 files changed, 487 insertions(+), 39 deletions(-) create mode 100644 .JuliaFormatter.toml create mode 100644 .github/workflows/FormatPR.yml create mode 100644 src/cverror.jl create mode 100644 src/cverrors/bcv.jl create mode 100644 src/cverrors/drv.jl create mode 100644 src/cverrors/kfv.jl create mode 100644 src/cverrors/lbo.jl create mode 100644 src/cverrors/loo.jl create mode 100644 src/cverrors/wcv.jl delete mode 100644 test/Manifest.toml diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml new file mode 100644 index 0000000..eb77a5b --- /dev/null +++ b/.JuliaFormatter.toml @@ -0,0 +1,10 @@ +indent = 2 +margin = 120 +always_for_in = true +always_use_return = false +whitespace_typedefs = false +whitespace_in_kwargs = false +whitespace_ops_in_indices = true +remove_extra_newlines = true +trailing_comma = false +normalize_line_endings = "unix" diff --git a/.github/workflows/FormatPR.yml b/.github/workflows/FormatPR.yml new file mode 100644 index 0000000..2a3051d --- /dev/null +++ b/.github/workflows/FormatPR.yml @@ -0,0 +1,28 @@ +name: FormatPR +on: + push: + branches: + - master +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install JuliaFormatter and format + run: | + julia -e 'import Pkg; Pkg.add("JuliaFormatter")' + julia -e 'using JuliaFormatter; format(".")' + - name: Create Pull Request + id: cpr + uses: peter-evans/create-pull-request@v5 + with: + token: ${{ secrets.GITHUB_TOKEN }} + commit-message: ":robot: Format .jl files" + title: '[AUTO] JuliaFormatter.jl run' + branch: auto-juliaformatter-pr + delete-branch: true + labels: formatting, automated pr, no changelog + - name: Check outputs + run: | + echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}" + echo "Pull Request URL - ${{ steps.cpr.outputs.pull-request-url }}" diff --git a/.gitignore b/.gitignore index 0f84bed..dfb1780 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ *.jl.*.cov *.jl.cov *.jl.mem -/Manifest.toml +Manifest.toml +.vscode diff --git a/Project.toml b/Project.toml index be08a0d..1dfb654 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,26 @@ name = "GeoStatsValidation" uuid = "36f43c0d-3673-45fc-9557-6860e708e7aa" authors = ["Elias Carvalho and contributors"] -version = "1.0.0-DEV" +version = "0.1.0" + +[deps] +ColumnSelectors = "9cc86067-7e36-4c61-b350-1ac9833d277f" +DensityRatioEstimation = "ab46fb84-d57c-11e9-2f65-6f72e4a7229f" +GeoStatsBase = "323cb8eb-fbf6-51c0-afd0-f8fba70507b2" +GeoStatsTransforms = "725d9659-360f-4996-9c94-5f19c7e4a8a6" +GeoTables = "e502b557-6362-48c1-8219-d30d308dcdb0" +LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" +Meshes = "eacbb407-ea5a-433e-ab97-5258b1ca43fa" +StatsLearnModels = "c146b59d-1589-421c-8e09-a22e554fd05c" +Transducers = 
"28d57a85-8fef-5791-bfe6-a80928e7c999" [compat] +ColumnSelectors = "0.1" +DensityRatioEstimation = "1.2" +GeoStatsBase = "0.42" +GeoStatsTransforms = "0.2" +GeoTables = "1.14" +LossFunctions = "0.11" +StatsLearnModels = "0.2" +Transducers = "0.4" julia = "1.9" diff --git a/README.md b/README.md index 80c6c47..49878bd 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# GeoStatsValidation +# GeoStatsValidation.jl [![Build Status](https://github.com/JuliaEarth/GeoStatsValidation.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/JuliaEarth/GeoStatsValidation.jl/actions/workflows/CI.yml?query=branch%3Amain) [![Coverage](https://codecov.io/gh/JuliaEarth/GeoStatsValidation.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/JuliaEarth/GeoStatsValidation.jl) diff --git a/src/GeoStatsValidation.jl b/src/GeoStatsValidation.jl index 0cf0787..9fc1b61 100644 --- a/src/GeoStatsValidation.jl +++ b/src/GeoStatsValidation.jl @@ -1,5 +1,21 @@ module GeoStatsValidation -# Write your package code here. +using Meshes +using GeoTables +using Transducers +using DensityRatioEstimation + +using StatsLearnModels: Learn +using GeoStatsTransforms: Interpolate, InterpolateNeighbors + +using ColumnSelectors: selector +using GeoStatsBase: WeightingMethod, DensityRatioWeighting, UniformWeighting +using GeoStatsBase: FoldingMethod, BallFolding, BlockFolding, OneFolding, UniformFolding +using GeoStatsBase: weight, folds, defaultloss, mean +using LossFunctions.Traits: SupervisedLoss + +include("cverror.jl") + +export cverror, LeaveOneOut, LeaveBallOut, KFoldValidation, BlockValidation, WeightedValidation, DensityRatioValidation end diff --git a/src/cverror.jl b/src/cverror.jl new file mode 100644 index 0000000..7bcaf40 --- /dev/null +++ b/src/cverror.jl @@ -0,0 +1,53 @@ +# ------------------------------------------------------------------ +# Licensed under the MIT License. See LICENSE in the project root. +# ------------------------------------------------------------------ + +""" + ErrorMethod + +A method for estimating cross-validatory error. +""" +abstract type ErrorMethod end + +struct LearnSetup{M} + model::M + input::Vector{Symbol} + output::Vector{Symbol} +end + +struct InterpSetup{I,M} + model::M +end + +""" + cverror(Lean, model, incols => outcols, geotable, method) + cverror(Interpolate, model, geotable, method) + cverror(InterpolateNeighbors, model, geotable, method) + +Estimate error of `model` in a given `geotable` with +error estimation `method`. +""" +function cverror end + +function cverror(::Type{Learn}, model, (incols, outcols)::Pair, geotable::AbstractGeoTable, method::ErrorMethod) + names = setdiff(propertynames(geotable), [:geometry]) + input = selector(incols)(names) + output = selector(outcols)(names) + cverror(LearnSetup(model, input, output), geotable, method) +end + +const Interp = Union{Interpolate,InterpolateNeighbors} + +cverror(::Type{I}, model::M, geotable::AbstractGeoTable, method::ErrorMethod) where {I<:Interp,M} = + cverror(InterpSetup{I,M}(model), geotable, method) + +# ---------------- +# IMPLEMENTATIONS +# ---------------- + +include("cverrors/loo.jl") +include("cverrors/lbo.jl") +include("cverrors/kfv.jl") +include("cverrors/bcv.jl") +include("cverrors/wcv.jl") +include("cverrors/drv.jl") diff --git a/src/cverrors/bcv.jl b/src/cverrors/bcv.jl new file mode 100644 index 0000000..73ffd99 --- /dev/null +++ b/src/cverrors/bcv.jl @@ -0,0 +1,39 @@ +# ------------------------------------------------------------------ +# Licensed under the MIT License. 
See LICENSE in the project root. +# ------------------------------------------------------------------ + +""" + BlockValidation(sides; loss=Dict()) + +Cross-validation with blocks of given `sides`. Optionally, +specify `loss` function from `LossFunctions.jl` for some +of the variables. If only one side is provided, then blocks +become cubes. + +## References + +* Roberts et al. 2017. [Cross-validation strategies for data with + temporal, spatial, hierarchical, or phylogenetic structure] + (https://onlinelibrary.wiley.com/doi/10.1111/ecog.02881) +* Pohjankukka et al. 2017. [Estimating the prediction performance + of spatial models via spatial k-fold cross-validation] + (https://www.tandfonline.com/doi/full/10.1080/13658816.2017.1346255) +""" +struct BlockValidation{S} <: ErrorMethod + sides::S + loss::Dict{Symbol,SupervisedLoss} +end + +BlockValidation(sides; loss=Dict()) = BlockValidation{typeof(sides)}(sides, loss) + +function cverror(setup, geotable, method::BlockValidation) + # uniform weights + weighting = UniformWeighting() + + # block folds + folding = BlockFolding(method.sides) + + wcv = WeightedValidation(weighting, folding, lambda=1, loss=method.loss) + + cverror(setup, geotable, wcv) +end diff --git a/src/cverrors/drv.jl b/src/cverrors/drv.jl new file mode 100644 index 0000000..5b16b2f --- /dev/null +++ b/src/cverrors/drv.jl @@ -0,0 +1,64 @@ +# ------------------------------------------------------------------ +# Licensed under the MIT License. See LICENSE in the project root. +# ------------------------------------------------------------------ + +""" + DensityRatioValidation(k; [parameters]) + +Density ratio validation where weights are first obtained with density +ratio estimation, and then used in `k`-fold weighted cross-validation. + +## Parameters + +* `shuffle` - Shuffle the data before folding (default to `true`) +* `estimator` - Density ratio estimator (default to `LSIF()`) +* `optlib` - Optimization library (default to `default_optlib(estimator)`) +* `lambda` - Power of density ratios (default to `1.0`) + +Please see [DensityRatioEstimation.jl] +(https://github.com/JuliaEarth/DensityRatioEstimation.jl) +for a list of supported estimators. + +## References + +* Hoffimann et al. 2020. 
[Geostatistical Learning: Challenges and Opportunities] + (https://arxiv.org/abs/2102.08791) +""" +struct DensityRatioValidation{T,E,O} <: ErrorMethod + k::Int + shuffle::Bool + lambda::T + dre::E + optlib::O + loss::Dict{Symbol,SupervisedLoss} +end + +function DensityRatioValidation( + k::Int; + shuffle=true, + lambda=1.0, + loss=Dict(), + estimator=LSIF(), + optlib=default_optlib(estimator) +) + @assert k > 0 "number of folds must be positive" + @assert 0 ≤ lambda ≤ 1 "lambda must lie in [0,1]" + T = typeof(lambda) + E = typeof(estimator) + O = typeof(optlib) + DensityRatioValidation{T,E,O}(k, shuffle, lambda, estimator, optlib, loss) +end + +function cverror(setup::LearnSetup, geotable, method::DensityRatioValidation) + vars = setup.input + + # density-ratio weights + weighting = DensityRatioWeighting(geotable, vars, estimator=method.dre, optlib=method.optlib) + + # random folds + folding = UniformFolding(method.k, method.shuffle) + + wcv = WeightedValidation(weighting, folding, lambda=method.lambda, loss=method.loss) + + cverror(setup, geotable, wcv) +end diff --git a/src/cverrors/kfv.jl b/src/cverrors/kfv.jl new file mode 100644 index 0000000..0f2e3d3 --- /dev/null +++ b/src/cverrors/kfv.jl @@ -0,0 +1,38 @@ +# ------------------------------------------------------------------ +# Licensed under the MIT License. See LICENSE in the project root. +# ------------------------------------------------------------------ + +""" + KFoldValidation(k; shuffle=true, loss=Dict()) + +`k`-fold cross-validation. Optionally, `shuffle` the +data, and specify `loss` function from `LossFunctions.jl` +for some of the variables. + +## References + +* Geisser, S. 1975. [The predictive sample reuse method with applications] + (https://www.jstor.org/stable/2285815) +* Burman, P. 1989. [A comparative study of ordinary cross-validation, v-fold + cross-validation and the repeated learning-testing methods] + (https://www.jstor.org/stable/2336116) +""" +struct KFoldValidation <: ErrorMethod + k::Int + shuffle::Bool + loss::Dict{Symbol,SupervisedLoss} +end + +KFoldValidation(k::Int; shuffle=true, loss=Dict()) = KFoldValidation(k, shuffle, loss) + +function cverror(setup, geotable, method::KFoldValidation) + # uniform weights + weighting = UniformWeighting() + + # random folds + folding = UniformFolding(method.k, method.shuffle) + + wcv = WeightedValidation(weighting, folding, lambda=1, loss=method.loss) + + cverror(setup, geotable, wcv) +end diff --git a/src/cverrors/lbo.jl b/src/cverrors/lbo.jl new file mode 100644 index 0000000..6f79cbf --- /dev/null +++ b/src/cverrors/lbo.jl @@ -0,0 +1,42 @@ +# ------------------------------------------------------------------ +# Licensed under the MIT License. See LICENSE in the project root. +# ------------------------------------------------------------------ + +""" + LeaveBallOut(ball; loss=Dict()) + +Leave-`ball`-out (a.k.a. spatial leave-one-out) validation. +Optionally, specify `loss` function from the +[LossFunctions.jl](https://github.com/JuliaML/LossFunctions.jl) +package for some of the variables. + + LeaveBallOut(radius; loss=Dict()) + +By default, use Euclidean ball of given `radius` in space. + +## References + +* Le Rest et al. 2014. 
[Spatial leave-one-out cross-validation + for variable selection in the presence of spatial autocorrelation] + (https://onlinelibrary.wiley.com/doi/full/10.1111/geb.12161) +""" +struct LeaveBallOut{B<:MetricBall} <: ErrorMethod + ball::B + loss::Dict{Symbol,SupervisedLoss} +end + +LeaveBallOut(ball; loss=Dict()) = LeaveBallOut{typeof(ball)}(ball, loss) + +LeaveBallOut(radius::Number; loss=Dict()) = LeaveBallOut(MetricBall(radius), loss=loss) + +function cverror(setup, geotable, method::LeaveBallOut) + # uniform weights + weighting = UniformWeighting() + + # ball folds + folding = BallFolding(method.ball) + + wcv = WeightedValidation(weighting, folding, lambda=1, loss=method.loss) + + cverror(setup, geotable, wcv) +end diff --git a/src/cverrors/loo.jl b/src/cverrors/loo.jl new file mode 100644 index 0000000..a33789b --- /dev/null +++ b/src/cverrors/loo.jl @@ -0,0 +1,32 @@ +# ------------------------------------------------------------------ +# Licensed under the MIT License. See LICENSE in the project root. +# ------------------------------------------------------------------ + +""" + LeaveOneOut(; loss=Dict()) + +Leave-one-out validation. Optionally, specify `loss` function +from `LossFunctions.jl` for some of the variables. + +## References + +* Stone. 1974. [Cross-Validatory Choice and Assessment of Statistical Predictions] + (https://rss.onlinelibrary.wiley.com/doi/abs/10.1111/j.2517-6161.1974.tb00994.x) +""" +struct LeaveOneOut <: ErrorMethod + loss::Dict{Symbol,SupervisedLoss} +end + +LeaveOneOut(; loss=Dict()) = LeaveOneOut(loss) + +function cverror(setup, geotable, method::LeaveOneOut) + # uniform weights + weighting = UniformWeighting() + + # point folds + folding = OneFolding() + + wcv = WeightedValidation(weighting, folding, lambda=1, loss=method.loss) + + cverror(setup, geotable, wcv) +end diff --git a/src/cverrors/wcv.jl b/src/cverrors/wcv.jl new file mode 100644 index 0000000..1cc7e32 --- /dev/null +++ b/src/cverrors/wcv.jl @@ -0,0 +1,97 @@ +# ------------------------------------------------------------------ +# Licensed under the MIT License. See LICENSE in the project root. +# ------------------------------------------------------------------ + +""" + WeightedValidation(weighting, folding; lambda=1.0, loss=Dict()) + +An error estimation method which samples are weighted with +`weighting` method and split into folds with `folding` method. +Weights are raised to `lambda` power in `[0,1]`. Optionally, +specify `loss` function from `LossFunctions.jl` for some of +the variables. + +## References + +* Sugiyama et al. 2006. [Importance-weighted cross-validation for + covariate shift](https://link.springer.com/chapter/10.1007/11861898_36) +* Sugiyama et al. 2007. 
[Covariate shift adaptation by importance weighted + cross validation](http://www.jmlr.org/papers/volume8/sugiyama07a/sugiyama07a.pdf) +""" +struct WeightedValidation{W<:WeightingMethod,F<:FoldingMethod,T<:Real} <: ErrorMethod + weighting::W + folding::F + lambda::T + loss::Dict{Symbol,SupervisedLoss} + + function WeightedValidation{W,F,T}(weighting, folding, lambda, loss) where {W,F,T} + @assert 0 ≤ lambda ≤ 1 "lambda must lie in [0,1]" + new(weighting, folding, lambda, loss) + end +end + +WeightedValidation(weighting::W, folding::F; lambda::T=one(T), loss=Dict()) where {W,F,T} = + WeightedValidation{W,F,T}(weighting, folding, lambda, loss) + +function cverror(setup, geotable, method::WeightedValidation) + # retrieve problem info + ovars = _outputvars(setup, geotable) + loss = method.loss + for var in ovars + if var ∉ keys(loss) + v = getproperty(geotable, var) + loss[var] = defaultloss(v[1]) + end + end + + # weight all samples + ws = weight(geotable, method.weighting) .^ method.lambda + + # folds for cross-validation + fs = folds(geotable, method.folding) + + # error for a fold + function ε(f) + # solve sub-problem + solution = _solution(setup, geotable, f) + + # holdout set + holdout = view(geotable, f[2]) + + # holdout weights + 𝓌 = view(ws, f[2]) + + # loss for each variable + losses = map(ovars) do var + ℒ = loss[var] + ŷ = getproperty(solution, var) + y = getproperty(holdout, var) + var => mean(ℒ, ŷ, y, 𝓌, normalize=false) + end + + Dict(losses) + end + + # compute error for each fold in parallel + εs = foldxt(vcat, Map(ε), fs) + + # combine error from different folds + Dict(var => mean(get.(εs, var, 0)) for var in ovars) +end + +# output variables of the problem +_outputvars(::InterpSetup, gtb) = setdiff(propertynames(gtb), [:geometry]) +_outputvars(s::LearnSetup, gtb) = s.output + +# solution for a given fold +function _solution(s::InterpSetup{I}, geotable, f) where {I} + sdat = view(geotable, f[1]) + sdom = view(domain(geotable), f[2]) + sdat |> I(sdom, s.model) +end + +function _solution(s::LearnSetup, geotable, f) + source = view(geotable, f[1]) + target = view(geotable, f[2]) + target |> Learn(source, s.model, s.input => s.output) +end diff --git a/test/Manifest.toml b/test/Manifest.toml deleted file mode 100644 index 4ad1915..0000000 --- a/test/Manifest.toml +++ /dev/null @@ -1,34 +0,0 @@ -# This file is machine-generated - editing it directly is not advised - -julia_version = "1.9.3" -manifest_format = "2.0" -project_hash = "71d91126b5a1fb1020e1098d9d492de2a4438fd2" - -[[deps.Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[deps.InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[deps.Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[deps.Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[deps.Random]] -deps = ["SHA", "Serialization"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[deps.SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" -version = "0.7.0" - -[[deps.Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[deps.Test]] -deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/Project.toml b/test/Project.toml index 0c36332..5dd24af 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,2 +1,8 @@ [deps] +GeoStatsModels = "ad987403-13c5-47b5-afee-0a48f6ac4f12" +GeoStatsTransforms = "725d9659-360f-4996-9c94-5f19c7e4a8a6" +GeoTables = "e502b557-6362-48c1-8219-d30d308dcdb0" 
+Meshes = "eacbb407-ea5a-433e-ab97-5258b1ca43fa" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StatsLearnModels = "c146b59d-1589-421c-8e09-a22e554fd05c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/runtests.jl b/test/runtests.jl index 27682f4..5b6535b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,43 @@ using GeoStatsValidation +using StatsLearnModels +using GeoStatsTransforms +using GeoStatsModels +using GeoTables +using Meshes +using Random using Test @testset "GeoStatsValidation.jl" begin - # Write your tests here. + Random.seed!(123) + + @testset "Learning" begin + x = rand(1:2, 1000) + y = rand(1:2, 1000) + X = rand(2, 1000) + gtb = georef((; x, y), X) + model = DecisionTreeClassifier() + + # dummy classifier → 0.5 misclassification rate + for method in [LeaveOneOut(), LeaveBallOut(0.1), KFoldValidation(10), BlockValidation(0.1), DensityRatioValidation(10)] + e = cverror(Learn, model, :x => :y, gtb, method) + @test isapprox(e[:y], 0.5, atol=0.06) + end + end + + @testset "Interpolation" begin + gtb₁ = georef((z=rand(50, 50),)) + gtb₂ = georef((z=100rand(50, 50),)) + sgtb₁ = sample(gtb₁, UniformSampling(100, replace=false)) + sgtb₂ = sample(gtb₂, UniformSampling(100, replace=false)) + model = NN() + + # low variance + dummy (mean) estimator → low error + # high variance + dummy (mean) estimator → high error + for method in [LeaveOneOut(), LeaveBallOut(0.1), KFoldValidation(10), BlockValidation(0.1)] + e₁ = cverror(Interpolate, model, sgtb₁, method) + e₂ = cverror(Interpolate, model, sgtb₂, method) + @test e₁[:z] < 1 + @test e₂[:z] > 1 + end + end end