Implementing the MLJ measures API #1

Open
ablaom opened this issue Jun 30, 2020 · 1 comment

Comments

@ablaom
Collaborator

ablaom commented Jun 30, 2020

It would be nice if the measures provided here could be used in MLJ performance evaluation and elsewhere. This would mean implementing the MLJ measures API documented here, which does not appear to have been done yet.

Note that an MLJ measure does not have to return a numerical value. For example, we regard confmat (the confusion matrix) as a measure:

```julia
using MLJ
X, y = @load_crabs
model = @load DecisionTreeClassifier
y = coerce(y, OrderedFactor)  # confmat expects an ordered target
e = evaluate(model, X, y, measure=confmat, operation=predict_mode)
```

```
julia> e.per_fold[1][1]
              ┌───────────────────────────┐
              │       Ground Truth        │
┌─────────────┼─────────────┬─────────────┤
│  Predicted  │      B      │      O      │
├─────────────┼─────────────┼─────────────┤
│      B      │     31      │      0      │
├─────────────┼─────────────┼─────────────┤
│      O      │      3      │      0      │
└─────────────┴─────────────┴─────────────┘
```

So here's one tentative suggestion for implementing the MLJ API.

In MLJ one can already have a measure m with signature m(yhat, y, X), where X represents the full table of input features, which we can suppose is a Tables.jl table. In your case, you only care about one particular column of X (call it the group column), whose classes you want to filter on, e.g., a column like ["male", "female", "male", "male", "binary"]. One could:

  1. Introduce a new parameter for each MLJFair metric, called group_name (or similar), that specifies the name of the group column. One would then instantiate the measure like this: m = MLJFair.TruePositive(group_name=:gender).

  2. Overload calling of the metric so that m(yhat, y, X) returns a dictionary of numerical values keyed on the group class, e.g., Dict("male" => 2, "female" => 3, "binary" => 1). One could instead return a struct of some kind, but a dictionary is probably the most user-friendly option.

  3. To complete the API you may have to overload some measure traits, for example:

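```julia
using MLJBase  # trait functions and scientific types such as OrderedFactor

MLJBase.name(::Type{<:MLJFair.TruePositive}) = "TruePositive"
MLJBase.target_scitype(::Type{<:MLJFair.TruePositive}) = AbstractVector{<:OrderedFactor{2}}
MLJBase.supports_weights(::Type{<:MLJFair.TruePositive}) = false  # for now
MLJBase.prediction_type(::Type{<:MLJFair.TruePositive}) = :deterministic
MLJBase.orientation(::Type{<:MLJFair.TruePositive}) = :other  # other options are :score, :loss
MLJBase.reports_each_observation(::Type{<:MLJFair.TruePositive}) = false
MLJBase.aggregation(::Type{<:MLJFair.TruePositive}) = MLJBase.Sum()
MLJBase.is_feature_dependent(::Type{<:MLJFair.TruePositive}) = true  # important: the measure needs X
```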
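A minimal, hypothetical sketch of steps 1 and 2 that does not depend on MLJBase or MLJFair. The struct name GroupedTruePositive, its positive_class field, and the assumption that X is a NamedTuple of columns are all illustrative choices, not an existing API. It counts true positives within each class of the group column and returns the counts in a dictionary keyed on that class.

```julia
# Hypothetical sketch of a group-aware measure (not MLJFair's actual implementation).
# X is assumed to be a NamedTuple of columns, so getproperty(X, name) returns a column.
Base.@kwdef struct GroupedTruePositive
    group_name::Symbol = :group   # name of the group column in X
    positive_class = 1            # label treated as the positive class (assumed)
end

# Calling m(yhat, y, X) returns Dict(group class => number of true positives).
function (m::GroupedTruePositive)(yhat, y, X)
    groups = getproperty(X, m.group_name)
    counts = Dict{eltype(groups),Int}()
    for (g, pred, truth) in zip(groups, yhat, y)
        is_tp = pred == m.positive_class && truth == m.positive_class
        counts[g] = get(counts, g, 0) + is_tp   # a Bool adds as 0 or 1
    end
    return counts
end

# Usage with a toy table; the age column is ignored, only the group column matters.
X = (gender = ["male", "female", "male", "male", "binary"], age = [23, 35, 41, 29, 52])
y    = [1, 1, 0, 1, 1]
yhat = [1, 1, 1, 0, 1]
m = GroupedTruePositive(group_name = :gender)
result = m(yhat, y, X)   # contains "male" => 1, "female" => 1, "binary" => 1
```

For a general Tables.jl table, the column could instead be fetched with Tables.getcolumn(Tables.columns(X), m.group_name). Every group class gets an entry, even when its count is zero, so the dictionaries from different folds have matching keys.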

If you did this, then calls like evaluate(model, X, y, measure=MLJFair.TruePositive(group_name=:gender), resampling=CV()) should work.

How does this sound?

@ablaom
Collaborator Author

ablaom commented Jun 30, 2020

The other idea, mentioned on the call, is a wrapper, which does not assume you already have structs such as MLJFair.TruePositive. The user would write something like m = FairnessMetric(measure=TruePositive(), group_name=:gender), and then m(yhat, y, X) would return the dictionary of true positive counts, keyed on gender.
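A rough sketch of the wrapper idea under the same assumptions as before (X is a NamedTuple of columns, and the wrapped measure is any callable taking (yhat, y)). The fields of FairnessMetric follow the proposal; true_positive_count is a hypothetical stand-in for TruePositive(), not an MLJ function. The wrapper splits the observations by group class and applies the wrapped measure to each subset.

```julia
# Hypothetical sketch of the wrapper idea; X is assumed to be a NamedTuple of columns.
Base.@kwdef struct FairnessMetric{M}
    measure::M            # any callable with signature measure(yhat, y)
    group_name::Symbol    # name of the group column in X
end

# Apply the wrapped measure separately to the observations of each group class.
function (f::FairnessMetric)(yhat, y, X)
    groups = getproperty(X, f.group_name)
    return Dict(g => f.measure(yhat[groups .== g], y[groups .== g])
                for g in unique(groups))
end

# Hypothetical stand-in for TruePositive(): counts pairs where both labels equal 1.
true_positive_count(yhat, y) = sum((yhat .== 1) .& (y .== 1))

X = (gender = ["male", "female", "male", "male", "binary"],)
y    = [1, 1, 0, 1, 1]
yhat = [1, 1, 1, 0, 1]
m = FairnessMetric(measure = true_positive_count, group_name = :gender)
result = m(yhat, y, X)   # true positive counts keyed on gender
```

Keeping the grouping logic in the wrapper means individual measures would not each need their own group_name field.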
