Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
eliascarv committed Dec 15, 2023
1 parent aadd22c commit 53bdb08
Show file tree
Hide file tree
Showing 16 changed files with 487 additions and 39 deletions.
10 changes: 10 additions & 0 deletions .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
indent = 2
margin = 120
always_for_in = true
always_use_return = false
whitespace_typedefs = false
whitespace_in_kwargs = false
whitespace_ops_in_indices = true
remove_extra_newlines = true
trailing_comma = false
normalize_line_endings = "unix"
28 changes: 28 additions & 0 deletions .github/workflows/FormatPR.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
name: FormatPR
on:
  push:
    branches:
      - master
jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Install JuliaFormatter and format
        run: |
          julia -e 'import Pkg; Pkg.add("JuliaFormatter")'
          julia -e 'using JuliaFormatter; format(".")'
      - name: Create Pull Request
        id: cpr
        uses: peter-evans/create-pull-request@v5
        with:
          token: ${{ secrets.GITHUB_TOKEN }}
          commit-message: ":robot: Format .jl files"
          title: '[AUTO] JuliaFormatter.jl run'
          branch: auto-juliaformatter-pr
          delete-branch: true
          labels: formatting, automated pr, no changelog
      - name: Check outputs
        run: |
          echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}"
          echo "Pull Request URL - ${{ steps.cpr.outputs.pull-request-url }}"
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
*.jl.*.cov
*.jl.cov
*.jl.mem
/Manifest.toml
Manifest.toml
.vscode
21 changes: 20 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,26 @@
name = "GeoStatsValidation"
uuid = "36f43c0d-3673-45fc-9557-6860e708e7aa"
authors = ["Elias Carvalho <eliascarvdev@gmail.com> and contributors"]
version = "1.0.0-DEV"
version = "0.1.0"

[deps]
ColumnSelectors = "9cc86067-7e36-4c61-b350-1ac9833d277f"
DensityRatioEstimation = "ab46fb84-d57c-11e9-2f65-6f72e4a7229f"
GeoStatsBase = "323cb8eb-fbf6-51c0-afd0-f8fba70507b2"
GeoStatsTransforms = "725d9659-360f-4996-9c94-5f19c7e4a8a6"
GeoTables = "e502b557-6362-48c1-8219-d30d308dcdb0"
LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
Meshes = "eacbb407-ea5a-433e-ab97-5258b1ca43fa"
StatsLearnModels = "c146b59d-1589-421c-8e09-a22e554fd05c"
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"

[compat]
ColumnSelectors = "0.1"
DensityRatioEstimation = "1.2"
GeoStatsBase = "0.42"
GeoStatsTransforms = "0.2"
GeoTables = "1.14"
LossFunctions = "0.11"
StatsLearnModels = "0.2"
Transducers = "0.4"
julia = "1.9"
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# GeoStatsValidation
# GeoStatsValidation.jl

[![Build Status](https://github.com/JuliaEarth/GeoStatsValidation.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/JuliaEarth/GeoStatsValidation.jl/actions/workflows/CI.yml?query=branch%3Amain)
[![Coverage](https://codecov.io/gh/JuliaEarth/GeoStatsValidation.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/JuliaEarth/GeoStatsValidation.jl)
18 changes: 17 additions & 1 deletion src/GeoStatsValidation.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,21 @@
module GeoStatsValidation

# packages imported wholesale (all of their exported names)
using Meshes
using GeoTables
using Transducers
using DensityRatioEstimation

# types accepted as the first argument of `cverror`
using StatsLearnModels: Learn
using GeoStatsTransforms: Interpolate, InterpolateNeighbors

# individual names imported explicitly from the remaining dependencies
using ColumnSelectors: selector
using GeoStatsBase: WeightingMethod, DensityRatioWeighting, UniformWeighting
using GeoStatsBase: FoldingMethod, BallFolding, BlockFolding, OneFolding, UniformFolding
using GeoStatsBase: weight, folds, defaultloss, mean
using LossFunctions.Traits: SupervisedLoss

# `cverror` API; this file in turn includes one file per error method
include("cverror.jl")

export cverror, LeaveOneOut, LeaveBallOut, KFoldValidation, BlockValidation, WeightedValidation, DensityRatioValidation

end
53 changes: 53 additions & 0 deletions src/cverror.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# ------------------------------------------------------------------
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------

"""
    ErrorMethod

A method for estimating cross-validatory error.
"""
abstract type ErrorMethod end

"""
    LearnSetup(model, input, output)

Internal setup that bundles a learning `model` with the names
of the `input` and `output` columns it is validated on.
"""
struct LearnSetup{M}
  model::M
  input::Vector{Symbol}
  output::Vector{Symbol}
end

"""
    InterpSetup{I,M}(model)

Internal setup that stores an interpolation `model` of type `M`.
The parameter `I` is carried in the type only; `cverror` sets it
to the type it was invoked with.
"""
struct InterpSetup{I,M}
  model::M
end

"""
    cverror(Learn, model, incols => outcols, geotable, method)
    cverror(Interpolate, model, geotable, method)
    cverror(InterpolateNeighbors, model, geotable, method)

Estimate error of `model` in a given `geotable` with
error estimation `method`.
"""
function cverror end

# Learn entry point: resolve the `incols => outcols` selectors against the
# non-geometry columns of `geotable`, then hand the resulting setup over to
# the implementation of the given error `method`.
function cverror(::Type{Learn}, model, (incols, outcols)::Pair, geotable::AbstractGeoTable, method::ErrorMethod)
  # candidate columns: every property except the geometry column
  vars = setdiff(propertynames(geotable), [:geometry])
  resolve(cols) = selector(cols)(vars)
  setup = LearnSetup(model, resolve(incols), resolve(outcols))
  return cverror(setup, geotable, method)
end

# types that select the interpolation entry point of `cverror`
const Interp = Union{Interpolate,InterpolateNeighbors}

# Interpolation entry point: wrap `model` in an `InterpSetup` tagged with the
# requested type `I`, then hand it over to the implementation of `method`.
function cverror(::Type{I}, model::M, geotable::AbstractGeoTable, method::ErrorMethod) where {I<:Interp,M}
  setup = InterpSetup{I,M}(model)
  return cverror(setup, geotable, method)
end

# ----------------
# IMPLEMENTATIONS
# ----------------

# one file per error estimation method
include("cverrors/loo.jl") # LeaveOneOut
include("cverrors/lbo.jl") # LeaveBallOut
include("cverrors/kfv.jl") # KFoldValidation
include("cverrors/bcv.jl") # BlockValidation
include("cverrors/wcv.jl") # presumably WeightedValidation (file not shown here; confirm)
include("cverrors/drv.jl") # DensityRatioValidation
39 changes: 39 additions & 0 deletions src/cverrors/bcv.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# ------------------------------------------------------------------
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------

"""
    BlockValidation(sides; loss=Dict())

Cross-validation with blocks of given `sides`. If only one
side is provided, then blocks become cubes. Optionally, a
`loss` function from `LossFunctions.jl` can be specified
for some of the variables.

## References

* Roberts et al. 2017. [Cross-validation strategies for data with
  temporal, spatial, hierarchical, or phylogenetic structure]
  (https://onlinelibrary.wiley.com/doi/10.1111/ecog.02881)
* Pohjankukka et al. 2017. [Estimating the prediction performance
  of spatial models via spatial k-fold cross-validation]
  (https://www.tandfonline.com/doi/full/10.1080/13658816.2017.1346255)
"""
struct BlockValidation{S} <: ErrorMethod
  sides::S
  loss::Dict{Symbol,SupervisedLoss}
end

# the type parameter is taken from whatever `sides` value is supplied
function BlockValidation(sides; loss=Dict())
  S = typeof(sides)
  return BlockValidation{S}(sides, loss)
end

# Block cross-validation is weighted cross-validation with
# uniform weights and block folds.
function cverror(setup, geotable, method::BlockValidation)
  uniform = UniformWeighting()
  blocks = BlockFolding(method.sides)
  delegate = WeightedValidation(uniform, blocks, lambda=1, loss=method.loss)
  return cverror(setup, geotable, delegate)
end
64 changes: 64 additions & 0 deletions src/cverrors/drv.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# ------------------------------------------------------------------
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------

"""
    DensityRatioValidation(k; [parameters])

Density ratio validation where weights are first obtained with density
ratio estimation, and then used in `k`-fold weighted cross-validation.

## Parameters

* `shuffle`   - Shuffle the data before folding (default to `true`)
* `estimator` - Density ratio estimator (default to `LSIF()`)
* `optlib`    - Optimization library (default to `default_optlib(estimator)`)
* `lambda`    - Power of density ratios (default to `1.0`)
* `loss`      - Loss functions for some of the variables (default to `Dict()`)

Please see [DensityRatioEstimation.jl]
(https://github.com/JuliaEarth/DensityRatioEstimation.jl)
for a list of supported estimators.

## References

* Hoffimann et al. 2020. [Geostatistical Learning: Challenges and Opportunities]
  (https://arxiv.org/abs/2102.08791)
"""
struct DensityRatioValidation{T,E,O} <: ErrorMethod
  k::Int
  shuffle::Bool
  lambda::T
  dre::E
  optlib::O
  loss::Dict{Symbol,SupervisedLoss}
end

# Keyword constructor for `DensityRatioValidation`.
#
# * `k`         - number of folds, must be positive
# * `shuffle`   - stored as is in the `shuffle` field
# * `lambda`    - must lie in the closed interval [0,1]
# * `loss`      - converted to `Dict{Symbol,SupervisedLoss}` by the struct
# * `estimator` - stored in the `dre` field
# * `optlib`    - defaults to `default_optlib(estimator)`
#
# The type parameters are inferred from `lambda`, `estimator` and `optlib`.
# Throws an `AssertionError` when `k` or `lambda` is out of range.
function DensityRatioValidation(
  k::Int;
  shuffle=true,
  lambda=1.0,
  loss=Dict(),
  estimator=LSIF(),
  optlib=default_optlib(estimator)
)
  @assert k > 0 "number of folds must be positive"
  # chained comparison: both bounds are inclusive
  @assert 0 ≤ lambda ≤ 1 "lambda must lie in [0,1]"
  T = typeof(lambda)
  E = typeof(estimator)
  O = typeof(optlib)
  DensityRatioValidation{T,E,O}(k, shuffle, lambda, estimator, optlib, loss)
end

# Density ratio validation is weighted cross-validation with density-ratio
# weights (computed over the input variables of the learning setup) and
# random folds. Only defined for `LearnSetup`.
function cverror(setup::LearnSetup, geotable, method::DensityRatioValidation)
  ratios = DensityRatioWeighting(geotable, setup.input, estimator=method.dre, optlib=method.optlib)
  random = UniformFolding(method.k, method.shuffle)
  delegate = WeightedValidation(ratios, random, lambda=method.lambda, loss=method.loss)
  return cverror(setup, geotable, delegate)
end
38 changes: 38 additions & 0 deletions src/cverrors/kfv.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# ------------------------------------------------------------------
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------

"""
    KFoldValidation(k; shuffle=true, loss=Dict())

`k`-fold cross-validation. Optionally, `shuffle` the data
before folding, and specify `loss` function from
`LossFunctions.jl` for some of the variables.

## References

* Geisser, S. 1975. [The predictive sample reuse method with applications]
  (https://www.jstor.org/stable/2285815)
* Burman, P. 1989. [A comparative study of ordinary cross-validation, v-fold
  cross-validation and the repeated learning-testing methods]
  (https://www.jstor.org/stable/2336116)
"""
struct KFoldValidation <: ErrorMethod
  k::Int
  shuffle::Bool
  loss::Dict{Symbol,SupervisedLoss}
end

# keyword front-end for the positional default constructor
function KFoldValidation(k::Int; shuffle=true, loss=Dict())
  return KFoldValidation(k, shuffle, loss)
end

# k-fold cross-validation is weighted cross-validation with
# uniform weights and random folds.
function cverror(setup, geotable, method::KFoldValidation)
  uniform = UniformWeighting()
  random = UniformFolding(method.k, method.shuffle)
  delegate = WeightedValidation(uniform, random, lambda=1, loss=method.loss)
  return cverror(setup, geotable, delegate)
end
42 changes: 42 additions & 0 deletions src/cverrors/lbo.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# ------------------------------------------------------------------
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------

"""
    LeaveBallOut(ball; loss=Dict())

Leave-`ball`-out (a.k.a. spatial leave-one-out) validation.
Optionally, specify `loss` function from the
[LossFunctions.jl](https://github.com/JuliaML/LossFunctions.jl)
package for some of the variables.

    LeaveBallOut(radius; loss=Dict())

By default, use Euclidean ball of given `radius` in space.

## References

* Le Rest et al. 2014. [Spatial leave-one-out cross-validation
  for variable selection in the presence of spatial autocorrelation]
  (https://onlinelibrary.wiley.com/doi/full/10.1111/geb.12161)
"""
struct LeaveBallOut{B<:MetricBall} <: ErrorMethod
  ball::B
  loss::Dict{Symbol,SupervisedLoss}
end

# the type parameter is taken from whatever `ball` value is supplied
function LeaveBallOut(ball; loss=Dict())
  B = typeof(ball)
  return LeaveBallOut{B}(ball, loss)
end

# a plain number is interpreted as the radius of a `MetricBall`
function LeaveBallOut(radius::Number; loss=Dict())
  return LeaveBallOut(MetricBall(radius); loss)
end

# Leave-ball-out validation is weighted cross-validation with
# uniform weights and ball folds.
function cverror(setup, geotable, method::LeaveBallOut)
  uniform = UniformWeighting()
  balls = BallFolding(method.ball)
  delegate = WeightedValidation(uniform, balls, lambda=1, loss=method.loss)
  return cverror(setup, geotable, delegate)
end
32 changes: 32 additions & 0 deletions src/cverrors/loo.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# ------------------------------------------------------------------
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------

"""
    LeaveOneOut(; loss=Dict())

Leave-one-out validation. Optionally, specify `loss` function
from `LossFunctions.jl` for some of the variables.

## References

* Stone. 1974. [Cross-Validatory Choice and Assessment of Statistical Predictions]
  (https://rss.onlinelibrary.wiley.com/doi/abs/10.1111/j.2517-6161.1974.tb00994.x)
"""
struct LeaveOneOut <: ErrorMethod
  loss::Dict{Symbol,SupervisedLoss}
end

# keyword front-end for the positional default constructor
function LeaveOneOut(; loss=Dict())
  return LeaveOneOut(loss)
end

# Leave-one-out validation is weighted cross-validation with
# uniform weights and point folds.
function cverror(setup, geotable, method::LeaveOneOut)
  uniform = UniformWeighting()
  points = OneFolding()
  delegate = WeightedValidation(uniform, points, lambda=1, loss=method.loss)
  return cverror(setup, geotable, delegate)
end
Loading

0 comments on commit 53bdb08

Please sign in to comment.