Skip to content

Commit

Permalink
Rename internal functions
Browse files Browse the repository at this point in the history
  • Loading branch information
eliascarv committed Dec 15, 2023
1 parent 3568482 commit ace25a1
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions src/cverrors/wcv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ WeightedValidation(weighting::W, folding::F; lambda::T=one(T), loss=Dict()) wher
WeightedValidation{W,F,T}(weighting, folding, lambda, loss)

function cverror(setup, geotable, method::WeightedValidation)
# retrieve problem info
ovars = _outputvars(setup, geotable)
ovars = _outputs(setup, geotable)
loss = method.loss
for var in ovars
if var keys(loss)
Expand All @@ -52,8 +51,8 @@ function cverror(setup, geotable, method::WeightedValidation)

# error for a fold
function ε(f)
# solve sub-problem
solution = _solution(setup, geotable, f)
# fold prediction
pred = _prediction(setup, geotable, f)

# holdout set
holdout = view(geotable, f[2])
Expand All @@ -64,7 +63,7 @@ function cverror(setup, geotable, method::WeightedValidation)
# loss for each variable
losses = map(ovars) do var
= loss[var]
= getproperty(solution, var)
= getproperty(pred, var)
y = getproperty(holdout, var)
var => mean(ℒ, ŷ, y, 𝓌, normalize=false)
end
Expand All @@ -79,18 +78,18 @@ function cverror(setup, geotable, method::WeightedValidation)
Dict(var => mean(get.(εs, var, 0)) for var in ovars)
end

# output variables of the problem
_outputvars(::InterpSetup, gtb) = setdiff(propertynames(gtb), [:geometry])
_outputvars(s::LearnSetup, gtb) = s.output
# output variables
_outputs(::InterpSetup, gtb) = setdiff(propertynames(gtb), [:geometry])
_outputs(s::LearnSetup, gtb) = s.output

# solution for a given fold
function _solution(s::InterpSetup{I}, geotable, f) where {I}
# prediction for a given fold
function _prediction(s::InterpSetup{I}, geotable, f) where {I}
sdat = view(geotable, f[1])
sdom = view(domain(geotable), f[2])
sdat |> I(sdom, s.model)
end

function _solution(s::LearnSetup, geotable, f)
function _prediction(s::LearnSetup, geotable, f)
source = view(geotable, f[1])
target = view(geotable, f[2])
target |> Learn(source, s.model, s.input => s.output)
Expand Down

0 comments on commit ace25a1

Please sign in to comment.