-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
16 changed files
with
487 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
indent = 2 | ||
margin = 120 | ||
always_for_in = true | ||
always_use_return = false | ||
whitespace_typedefs = false | ||
whitespace_in_kwargs = false | ||
whitespace_ops_in_indices = true | ||
remove_extra_newlines = true | ||
trailing_comma = false | ||
normalize_line_endings = "unix" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
name: FormatPR | ||
on: | ||
push: | ||
branches: | ||
- master | ||
jobs: | ||
build: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- name: Install JuliaFormatter and format | ||
run: | | ||
julia -e 'import Pkg; Pkg.add("JuliaFormatter")' | ||
julia -e 'using JuliaFormatter; format(".")' | ||
- name: Create Pull Request | ||
id: cpr | ||
uses: peter-evans/create-pull-request@v5 | ||
with: | ||
token: ${{ secrets.GITHUB_TOKEN }} | ||
commit-message: ":robot: Format .jl files" | ||
title: '[AUTO] JuliaFormatter.jl run' | ||
branch: auto-juliaformatter-pr | ||
delete-branch: true | ||
labels: formatting, automated pr, no changelog | ||
- name: Check outputs | ||
run: | | ||
echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}" | ||
echo "Pull Request URL - ${{ steps.cpr.outputs.pull-request-url }}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
*.jl.*.cov | ||
*.jl.cov | ||
*.jl.mem | ||
/Manifest.toml | ||
Manifest.toml | ||
.vscode |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,26 @@ | ||
name = "GeoStatsValidation" | ||
uuid = "36f43c0d-3673-45fc-9557-6860e708e7aa" | ||
authors = ["Elias Carvalho <eliascarvdev@gmail.com> and contributors"] | ||
version = "1.0.0-DEV" | ||
version = "0.1.0" | ||
|
||
[deps] | ||
ColumnSelectors = "9cc86067-7e36-4c61-b350-1ac9833d277f" | ||
DensityRatioEstimation = "ab46fb84-d57c-11e9-2f65-6f72e4a7229f" | ||
GeoStatsBase = "323cb8eb-fbf6-51c0-afd0-f8fba70507b2" | ||
GeoStatsTransforms = "725d9659-360f-4996-9c94-5f19c7e4a8a6" | ||
GeoTables = "e502b557-6362-48c1-8219-d30d308dcdb0" | ||
LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" | ||
Meshes = "eacbb407-ea5a-433e-ab97-5258b1ca43fa" | ||
StatsLearnModels = "c146b59d-1589-421c-8e09-a22e554fd05c" | ||
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" | ||
|
||
[compat] | ||
ColumnSelectors = "0.1" | ||
DensityRatioEstimation = "1.2" | ||
GeoStatsBase = "0.42" | ||
GeoStatsTransforms = "0.2" | ||
GeoTables = "1.14" | ||
LossFunctions = "0.11" | ||
StatsLearnModels = "0.2" | ||
Transducers = "0.4" | ||
julia = "1.9" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
# GeoStatsValidation | ||
# GeoStatsValidation.jl | ||
|
||
[![Build Status](https://github.com/JuliaEarth/GeoStatsValidation.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/JuliaEarth/GeoStatsValidation.jl/actions/workflows/CI.yml?query=branch%3Amain) | ||
[![Coverage](https://codecov.io/gh/JuliaEarth/GeoStatsValidation.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/JuliaEarth/GeoStatsValidation.jl) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,21 @@ | ||
module GeoStatsValidation

# packages loaded with all of their exported names
using Meshes
using GeoTables
using Transducers
using DensityRatioEstimation

# transform types used to select the `cverror` method
using StatsLearnModels: Learn
using GeoStatsTransforms: Interpolate, InterpolateNeighbors

# column selection
using ColumnSelectors: selector

# weighting and folding methods plus related helpers
using GeoStatsBase: WeightingMethod, UniformWeighting, DensityRatioWeighting
using GeoStatsBase: FoldingMethod, OneFolding, BallFolding, BlockFolding, UniformFolding
using GeoStatsBase: weight, folds, defaultloss, mean

# type of the values stored in the `loss` dictionaries
using LossFunctions.Traits: SupervisedLoss

include("cverror.jl")

export cverror,
  LeaveOneOut,
  LeaveBallOut,
  KFoldValidation,
  BlockValidation,
  WeightedValidation,
  DensityRatioValidation

end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# ------------------------------------------------------------------ | ||
# Licensed under the MIT License. See LICENSE in the project root. | ||
# ------------------------------------------------------------------ | ||
|
||
""" | ||
ErrorMethod | ||
A method for estimating cross-validatory error. | ||
""" | ||
abstract type ErrorMethod end | ||
|
||
"""
    LearnSetup(model, input, output)

Setup for the `Learn` variant of `cverror`. It stores the `model`
together with the names of the `input` and `output` columns.
"""
struct LearnSetup{M}
  model::M
  input::Vector{Symbol}
  output::Vector{Symbol}
end
|
||
"""
    InterpSetup{I,M}(model)

Setup for the interpolation variants of `cverror`. It stores the
`model` of type `M`. The parameter `I` does not appear in any field,
so it must be given explicitly when constructing the setup.
"""
struct InterpSetup{I,M}
  model::M
end
|
||
""" | ||
cverror(Lean, model, incols => outcols, geotable, method) | ||
cverror(Interpolate, model, geotable, method) | ||
cverror(InterpolateNeighbors, model, geotable, method) | ||
Estimate error of `model` in a given `geotable` with | ||
error estimation `method`. | ||
""" | ||
function cverror end | ||
|
||
# Learn variant: resolve the column selections against the columns of the
# geotable and forward to the method that takes a `LearnSetup`
function cverror(::Type{Learn}, model, (incols, outcols)::Pair, geotable::AbstractGeoTable, method::ErrorMethod)
  # every property of the geotable except the geometry column
  vars = setdiff(propertynames(geotable), (:geometry,))
  setup = LearnSetup(model, selector(incols)(vars), selector(outcols)(vars))
  cverror(setup, geotable, method)
end
|
||
# transform types handled by the interpolation variant of `cverror`
const Interp = Union{Interpolate,InterpolateNeighbors}

# interpolation variant: record the transform type `I` and the model type `M`
# in the setup and forward to the method that takes an `InterpSetup`
function cverror(::Type{I}, model::M, geotable::AbstractGeoTable, method::ErrorMethod) where {I<:Interp,M}
  setup = InterpSetup{I,M}(model)
  cverror(setup, geotable, method)
end
|
||
# ---------------- | ||
# IMPLEMENTATIONS | ||
# ---------------- | ||
|
||
include("cverrors/loo.jl") | ||
include("cverrors/lbo.jl") | ||
include("cverrors/kfv.jl") | ||
include("cverrors/bcv.jl") | ||
include("cverrors/wcv.jl") | ||
include("cverrors/drv.jl") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# ------------------------------------------------------------------ | ||
# Licensed under the MIT License. See LICENSE in the project root. | ||
# ------------------------------------------------------------------ | ||
|
||
""" | ||
BlockValidation(sides; loss=Dict()) | ||
Cross-validation with blocks of given `sides`. Optionally, | ||
specify `loss` function from `LossFunctions.jl` for some | ||
of the variables. If only one side is provided, then blocks | ||
become cubes. | ||
## References | ||
* Roberts et al. 2017. [Cross-validation strategies for data with | ||
temporal, spatial, hierarchical, or phylogenetic structure] | ||
(https://onlinelibrary.wiley.com/doi/10.1111/ecog.02881) | ||
* Pohjankukka et al. 2017. [Estimating the prediction performance | ||
of spatial models via spatial k-fold cross-validation] | ||
(https://www.tandfonline.com/doi/full/10.1080/13658816.2017.1346255) | ||
""" | ||
struct BlockValidation{S} <: ErrorMethod | ||
sides::S | ||
loss::Dict{Symbol,SupervisedLoss} | ||
end | ||
|
||
BlockValidation(sides; loss=Dict()) = BlockValidation{typeof(sides)}(sides, loss) | ||
|
||
# block cross-validation is expressed as weighted cross-validation
# with uniform weights, block folds and `lambda=1`
function cverror(setup, geotable, method::BlockValidation)
  cverror(
    setup,
    geotable,
    WeightedValidation(UniformWeighting(), BlockFolding(method.sides), lambda=1, loss=method.loss)
  )
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# ------------------------------------------------------------------ | ||
# Licensed under the MIT License. See LICENSE in the project root. | ||
# ------------------------------------------------------------------ | ||
|
||
""" | ||
DensityRatioValidation(k; [parameters]) | ||
Density ratio validation where weights are first obtained with density | ||
ratio estimation, and then used in `k`-fold weighted cross-validation. | ||
## Parameters | ||
* `shuffle` - Shuffle the data before folding (default to `true`) | ||
* `estimator` - Density ratio estimator (default to `LSIF()`) | ||
* `optlib` - Optimization library (default to `default_optlib(estimator)`) | ||
* `lambda` - Power of density ratios (default to `1.0`) | ||
Please see [DensityRatioEstimation.jl] | ||
(https://github.com/JuliaEarth/DensityRatioEstimation.jl) | ||
for a list of supported estimators. | ||
## References | ||
* Hoffimann et al. 2020. [Geostatistical Learning: Challenges and Opportunities] | ||
(https://arxiv.org/abs/2102.08791) | ||
""" | ||
struct DensityRatioValidation{T,E,O} <: ErrorMethod | ||
k::Int | ||
shuffle::Bool | ||
lambda::T | ||
dre::E | ||
optlib::O | ||
loss::Dict{Symbol,SupervisedLoss} | ||
end | ||
|
||
# Keyword constructor for `DensityRatioValidation`.
#
# `k` is the number of folds and must be positive; `lambda` must lie in [0,1].
# The `estimator` is stored in the `dre` field. Throws `AssertionError` when
# `k` or `lambda` is out of range.
function DensityRatioValidation(
  k::Int;
  shuffle=true,
  lambda=1.0,
  loss=Dict(),
  estimator=LSIF(),
  optlib=default_optlib(estimator)
)
  # validate with explicit throws instead of `@assert`, which is documented as
  # possibly disabled at some optimization levels; the exception type and the
  # messages are the same as the ones `@assert` produced
  k > 0 || throw(AssertionError("number of folds must be positive"))
  0 ≤ lambda ≤ 1 || throw(AssertionError("lambda must lie in [0,1]"))

  T = typeof(lambda)
  E = typeof(estimator)
  O = typeof(optlib)
  DensityRatioValidation{T,E,O}(k, shuffle, lambda, estimator, optlib, loss)
end
|
||
# density ratio validation is expressed as weighted cross-validation with
# density-ratio weights computed on the input columns and random folds
function cverror(setup::LearnSetup, geotable, method::DensityRatioValidation)
  dweights = DensityRatioWeighting(geotable, setup.input, estimator=method.dre, optlib=method.optlib)
  kfolds = UniformFolding(method.k, method.shuffle)
  cverror(setup, geotable, WeightedValidation(dweights, kfolds, lambda=method.lambda, loss=method.loss))
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# ------------------------------------------------------------------ | ||
# Licensed under the MIT License. See LICENSE in the project root. | ||
# ------------------------------------------------------------------ | ||
|
||
""" | ||
KFoldValidation(k; shuffle=true, loss=Dict()) | ||
`k`-fold cross-validation. Optionally, `shuffle` the | ||
data, and specify `loss` function from `LossFunctions.jl` | ||
for some of the variables. | ||
## References | ||
* Geisser, S. 1975. [The predictive sample reuse method with applications] | ||
(https://www.jstor.org/stable/2285815) | ||
* Burman, P. 1989. [A comparative study of ordinary cross-validation, v-fold | ||
cross-validation and the repeated learning-testing methods] | ||
(https://www.jstor.org/stable/2336116) | ||
""" | ||
struct KFoldValidation <: ErrorMethod | ||
k::Int | ||
shuffle::Bool | ||
loss::Dict{Symbol,SupervisedLoss} | ||
end | ||
|
||
KFoldValidation(k::Int; shuffle=true, loss=Dict()) = KFoldValidation(k, shuffle, loss) | ||
|
||
# k-fold cross-validation is expressed as weighted cross-validation
# with uniform weights, random folds and `lambda=1`
function cverror(setup, geotable, method::KFoldValidation)
  cverror(
    setup,
    geotable,
    WeightedValidation(UniformWeighting(), UniformFolding(method.k, method.shuffle), lambda=1, loss=method.loss)
  )
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# ------------------------------------------------------------------ | ||
# Licensed under the MIT License. See LICENSE in the project root. | ||
# ------------------------------------------------------------------ | ||
|
||
""" | ||
LeaveBallOut(ball; loss=Dict()) | ||
Leave-`ball`-out (a.k.a. spatial leave-one-out) validation. | ||
Optionally, specify `loss` function from the | ||
[LossFunctions.jl](https://github.com/JuliaML/LossFunctions.jl) | ||
package for some of the variables. | ||
LeaveBallOut(radius; loss=Dict()) | ||
By default, use Euclidean ball of given `radius` in space. | ||
## References | ||
* Le Rest et al. 2014. [Spatial leave-one-out cross-validation | ||
for variable selection in the presence of spatial autocorrelation] | ||
(https://onlinelibrary.wiley.com/doi/full/10.1111/geb.12161) | ||
""" | ||
struct LeaveBallOut{B<:MetricBall} <: ErrorMethod | ||
ball::B | ||
loss::Dict{Symbol,SupervisedLoss} | ||
end | ||
|
||
LeaveBallOut(ball; loss=Dict()) = LeaveBallOut{typeof(ball)}(ball, loss) | ||
|
||
LeaveBallOut(radius::Number; loss=Dict()) = LeaveBallOut(MetricBall(radius), loss=loss) | ||
|
||
# leave-ball-out validation is expressed as weighted cross-validation
# with uniform weights, ball folds and `lambda=1`
function cverror(setup, geotable, method::LeaveBallOut)
  cverror(
    setup,
    geotable,
    WeightedValidation(UniformWeighting(), BallFolding(method.ball), lambda=1, loss=method.loss)
  )
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# ------------------------------------------------------------------ | ||
# Licensed under the MIT License. See LICENSE in the project root. | ||
# ------------------------------------------------------------------ | ||
|
||
""" | ||
LeaveOneOut(; loss=Dict()) | ||
Leave-one-out validation. Optionally, specify `loss` function | ||
from `LossFunctions.jl` for some of the variables. | ||
## References | ||
* Stone. 1974. [Cross-Validatory Choice and Assessment of Statistical Predictions] | ||
(https://rss.onlinelibrary.wiley.com/doi/abs/10.1111/j.2517-6161.1974.tb00994.x) | ||
""" | ||
struct LeaveOneOut <: ErrorMethod | ||
loss::Dict{Symbol,SupervisedLoss} | ||
end | ||
|
||
LeaveOneOut(; loss=Dict()) = LeaveOneOut(loss) | ||
|
||
# leave-one-out validation is expressed as weighted cross-validation
# with uniform weights, point folds and `lambda=1`
function cverror(setup, geotable, method::LeaveOneOut)
  cverror(
    setup,
    geotable,
    WeightedValidation(UniformWeighting(), OneFolding(), lambda=1, loss=method.loss)
  )
end
Oops, something went wrong.