
Commit

uff
pat-alt committed Nov 29, 2022
1 parent 65f1f6b commit 7dd1b4b
Showing 3 changed files with 75 additions and 48 deletions.
2 changes: 1 addition & 1 deletion docs/src/intro.qmd
@@ -97,7 +97,7 @@ Predictions can then be computed using the generic `predict` method. The code be…

```{julia}
#| output: true
-n = 10
+n = 5
Xtest = selectrows(X, first(test,n))
ytest = y[first(test,n)]
predict(mach, Xtest)
74 changes: 74 additions & 0 deletions src/ConformalModels/inductive_bayes.jl
@@ -0,0 +1,74 @@
# Simple
"The `SimpleInductiveBayes` is the simplest approach to Inductive Conformalized Bayes."
mutable struct SimpleInductiveBayes{Model <: Supervised} <: ConformalModel
    model::Model
    coverage::AbstractFloat
    scores::Union{Nothing,AbstractArray}
    heuristic::Function
    train_ratio::AbstractFloat
end

function SimpleInductiveBayes(model::Supervised; coverage::AbstractFloat=0.95, heuristic::Function=f(y, ŷ)=-ŷ, train_ratio::AbstractFloat=0.5)
    return SimpleInductiveBayes(model, coverage, nothing, heuristic, train_ratio)
end
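
To make the intended workflow concrete, here is a minimal usage sketch (not part of the diff). It assumes the new type is reachable from the package namespace, that the standard MLJ machine workflow applies to it, and that MLJLinearModels is available for the base model; the data and model choices are illustrative only.

```julia
# Hypothetical usage sketch (assumptions: SimpleInductiveBayes is exported and
# follows the usual MLJ machine workflow; MLJLinearModels provides the base model).
using MLJ
using ConformalPrediction

X, y = make_blobs(100, 2; centers = 2)                  # toy classification data
train, test = partition(eachindex(y), 0.8; shuffle = true)

LogisticClassifier = @load LogisticClassifier pkg = MLJLinearModels
conf_model = SimpleInductiveBayes(LogisticClassifier(); coverage = 0.95, train_ratio = 0.5)

mach = machine(conf_model, X, y)
fit!(mach; rows = train)
predict(mach, selectrows(X, test))                      # set-valued (or `missing`) predictions
```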

@doc raw"""
    MMI.fit(conf_model::SimpleInductiveBayes, verbosity, X, y)

For the [`SimpleInductiveBayes`](@ref), nonconformity scores are computed as follows:

``
S_i^{\text{CAL}} = s(X_i, Y_i) = h(\hat\mu(X_i), Y_i), \ i \in \mathcal{D}_{\text{calibration}}
``

A typical choice for the heuristic function is ``h(\hat\mu(X_i), Y_i)=1-\hat\mu(X_i)_{Y_i}``, where ``\hat\mu(X_i)_{Y_i}`` denotes the softmax output for the true class and ``\hat\mu`` denotes the model fitted on the training data ``\mathcal{D}_{\text{train}}``. This simple approach only takes the softmax probability of the true label into account.
"""
function MMI.fit(conf_model::SimpleInductiveBayes, verbosity, X, y)

    # Data Splitting:
    train, calibration = partition(eachindex(y), conf_model.train_ratio)
    Xtrain = selectrows(X, train)
    ytrain = y[train]
    Xtrain, ytrain = MMI.reformat(conf_model.model, Xtrain, ytrain)
    Xcal = selectrows(X, calibration)
    ycal = y[calibration]
    Xcal, ycal = MMI.reformat(conf_model.model, Xcal, ycal)

    # Training:
    fitresult, cache, report = MMI.fit(conf_model.model, verbosity, Xtrain, ytrain)

    # Nonconformity Scores:
    ŷ = pdf.(MMI.predict(conf_model.model, fitresult, Xcal), ycal) # predict returns a vector of distributions
    conf_model.scores = @.(conf_model.heuristic(ycal, ŷ))

    return (fitresult, cache, report)
end
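
As a sanity check on the score computation above, the following standalone sketch applies the heuristic named in the docstring, h(μ̂(Xᵢ), Yᵢ) = 1 - μ̂(Xᵢ)_{Yᵢ}, to made-up softmax outputs; all numbers and names are illustrative, not taken from the package.

```julia
# Illustrative only: nonconformity scores as one minus the predicted
# probability of the true class, for three toy calibration points.
softmax_outputs = [0.7 0.2 0.1;     # rows: μ̂(Xᵢ) over three classes
                   0.1 0.8 0.1;
                   0.3 0.3 0.4]
true_class = [1, 2, 3]              # Yᵢ for each calibration point

ŷ = [softmax_outputs[i, true_class[i]] for i in eachindex(true_class)]
scores = 1 .- ŷ                     # [0.3, 0.2, 0.6]
```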

@doc raw"""
    MMI.predict(conf_model::SimpleInductiveBayes, fitresult, Xnew)

For the [`SimpleInductiveBayes`](@ref), prediction sets are computed as follows:

``
\hat{C}_{n,\alpha}(X_{n+1}) = \left\{y: s(X_{n+1},y) \le \hat{q}_{n, \alpha}^{+} \{S_i^{\text{CAL}}\} \right\}, \ i \in \mathcal{D}_{\text{calibration}},
``

where ``\mathcal{D}_{\text{calibration}}`` denotes the designated calibration data and ``\hat{q}_{n, \alpha}^{+} \{S_i^{\text{CAL}}\}`` the corresponding empirical quantile of the calibration scores.
"""
function MMI.predict(conf_model::SimpleInductiveBayes, fitresult, Xnew)
    p̂ = MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, Xnew)...)
    v = conf_model.scores
    # Empirical quantile of the calibration scores:
    q̂ = Statistics.quantile(v, conf_model.coverage)
    p̂ = map(p̂) do pp
        L = p̂.decoder.classes
        probas = pdf.(pp, L)
        # Keep every label whose nonconformity score does not exceed the quantile:
        is_in_set = 1.0 .- probas .<= q̂
        if !all(is_in_set .== false)
            pp = UnivariateFinite(L[is_in_set], probas[is_in_set])
        else
            pp = missing
        end
        return pp
    end
    return p̂
end
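
Reading the thresholding in `MMI.predict` step by step: take the empirical quantile of the calibration scores, then keep every label whose score (one minus its predicted probability) stays at or below it. A toy sketch with made-up numbers, not taken from the package:

```julia
using Statistics

# Illustrative only: form a prediction set from assumed calibration scores.
calibration_scores = [0.05, 0.10, 0.20, 0.35, 0.50]
q̂ = Statistics.quantile(calibration_scores, 0.95)       # ≈ 0.47

probas = [0.60, 0.25, 0.10, 0.05]                        # softmax output for one new point
labels = ["a", "b", "c", "d"]                            # hypothetical label names
is_in_set = 1.0 .- probas .<= q̂
prediction_set = labels[is_in_set]                       # ["a"]
```

If no label clears the threshold, the commit's implementation returns `missing` for that point.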
47 changes: 0 additions & 47 deletions src/ConformalModels/score_functions.jl

This file was deleted.
