From e8b60610505e8f5882506c71957a6ab895ed1fb2 Mon Sep 17 00:00:00 2001 From: Elias Carvalho Date: Mon, 18 Dec 2023 16:28:07 -0300 Subject: [PATCH] Add 'defaultloss' helper function --- Project.toml | 2 ++ src/GeoStatsValidation.jl | 9 ++++++++- src/utils.jl | 7 +++++++ 3 files changed, 17 insertions(+), 1 deletion(-) create mode 100644 src/utils.jl diff --git a/Project.toml b/Project.toml index 673050a..45bb439 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.1.0" [deps] ColumnSelectors = "9cc86067-7e36-4c61-b350-1ac9833d277f" +DataScienceTraits = "6cb2f572-2d2b-4ba6-bdb3-e710fa044d6c" DensityRatioEstimation = "ab46fb84-d57c-11e9-2f65-6f72e4a7229f" GeoStatsBase = "323cb8eb-fbf6-51c0-afd0-f8fba70507b2" GeoStatsModels = "ad987403-13c5-47b5-afee-0a48f6ac4f12" @@ -17,6 +18,7 @@ Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" [compat] ColumnSelectors = "0.1" +DataScienceTraits = "0.2" DensityRatioEstimation = "1.2" GeoStatsBase = "0.42" GeoStatsModels = "0.2" diff --git a/src/GeoStatsValidation.jl b/src/GeoStatsValidation.jl index 43da7c2..2f99cfc 100644 --- a/src/GeoStatsValidation.jl +++ b/src/GeoStatsValidation.jl @@ -1,8 +1,13 @@ +# ------------------------------------------------------------------ +# Licensed under the MIT License. See LICENSE in the project root. +# ------------------------------------------------------------------ + module GeoStatsValidation using Meshes using GeoTables using Transducers +using DataScienceTraits using DensityRatioEstimation using GeoStatsModels: GeoStatsModel @@ -13,9 +18,11 @@ using GeoStatsTransforms: Interpolate, InterpolateNeighbors using ColumnSelectors: selector using GeoStatsBase: WeightingMethod, DensityRatioWeighting, UniformWeighting using GeoStatsBase: FoldingMethod, BallFolding, BlockFolding, OneFolding, UniformFolding -using GeoStatsBase: weight, folds, defaultloss, mean +using GeoStatsBase: weight, folds, mean +using DataScienceTraits: Continuous, Categorical using LossFunctions.Traits: SupervisedLoss +include("utils.jl") include("cverror.jl") export cverror, LeaveOneOut, LeaveBallOut, KFoldValidation, BlockValidation, WeightedValidation, DensityRatioValidation diff --git a/src/utils.jl b/src/utils.jl new file mode 100644 index 0000000..90f200c --- /dev/null +++ b/src/utils.jl @@ -0,0 +1,7 @@ +# ------------------------------------------------------------------ +# Licensed under the MIT License. See LICENSE in the project root. +# ------------------------------------------------------------------ + +defaultloss(val) = defaultloss(scitype(val)) +defaultloss(::Type{Continuous}) = L2DistLoss() +defaultloss(::Type{Categorical}) = MisclassLoss()