# Scikit-learn's Metadata Routing API

### 1. What is Metadata?

- definition of metadata:
    - metadata can be any data that we want to apply on top of our tabular data, without it necessarily being part of it
    - alternatively: metadata is any data that a method or function in a data science "pipeline/process" handles besides X and y; it is a parameter that influences how this function treats the data

- examples of metadata:
    - classic examples you might know from scikit-learn: sample_weight (and groups)
    - but other libraries offer support for other kinds of metadata
    - (graphically show some other examples of metadata: gender, race, sex, zip code, ... and the area/library each is used in)
    - self-defined metadata to be used in custom metrics (and possibly custom estimators)

- use cases for metadata:
    - sample_weight can be used to re-balance data, and groups can prevent data leakage in cross-validation
    - fairness-related use cases
    - business logic
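As a minimal sketch of re-balancing, assume a hypothetical binary attribute `sex` in which one group is under-represented. scikit-learn's `compute_sample_weight` with `class_weight="balanced"` gives every sample the weight n_samples / (n_groups × count of its group), so both groups contribute the same total weight. Applying it to a sensitive attribute rather than to class labels is our choice for illustration.

```python
import numpy as np
from sklearn.utils.class_weight import compute_sample_weight

# hypothetical attribute: group 1 is under-represented (2 of 8 samples)
sex = np.array([0, 0, 0, 0, 0, 0, 1, 1])

# weight = n_samples / (n_groups * count of the sample's group)
# group 0: 8 / (2 * 6) ≈ 0.667, group 1: 8 / (2 * 2) = 2.0
sample_weight = compute_sample_weight(class_weight="balanced", y=sex)
print(sample_weight)

# both groups now carry the same total weight (≈ 4.0 each)
print(sample_weight[sex == 0].sum(), sample_weight[sex == 1].sum())
```

The resulting array could then be passed as `fit(X, y, sample_weight=sample_weight)`.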

- definition of routing:
    - routing just means that we pass metadata around between the functions involved in a data science pipeline/process, to wherever it is used (consumed)
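The idea of routing can be illustrated with a toy sketch in plain Python. This is not how scikit-learn implements it: the function `route` and the consumer names are hypothetical. Each consumer declares which metadata it wants, and the router hands each one only those keys.

```python
def route(metadata, requests):
    """Toy router (hypothetical, not scikit-learn's implementation).

    metadata: dict mapping a metadata name to its value (e.g. sample_weight, groups)
    requests: dict mapping a consumer name to the set of metadata names it consumes
    Returns, for each consumer, only the metadata it requested.
    """
    return {
        consumer: {key: metadata[key] for key in wanted if key in metadata}
        for consumer, wanted in requests.items()
    }


metadata = {"sample_weight": [0.5, 2.0, 1.0], "groups": [0, 0, 1]}
requests = {
    "estimator.fit": {"sample_weight"},
    "scorer.score": {"sample_weight"},
    "splitter.split": {"groups"},
}
routed = route(metadata, requests)
print(routed["splitter.split"])  # {'groups': [0, 0, 1]}
print(routed["estimator.fit"])   # {'sample_weight': [0.5, 2.0, 1.0]}
```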

- before the metadata routing API:
    - we were limited to using sample_weight in some of the metrics that accept it
    - we could not consistently use it in larger or more nested structures

- with metadata routing API:
    - we can pass sample_weight and groups through several levels of estimators and pipelines in scikit-learn
    - we can combine objects from other libraries with scikit-learn estimators while still passing their metadata
    - we can define our own custom metrics using self defined metadata



### 2. Passing metadata without the routing API


In [1]:
import pandas as pd
data = pd.DataFrame({"sex":[1,0,1,1,0], "age":[17,32,82,27,54], "race":[1,0,0,1,0], "severity":[4,8,2,9,5], "medication":[1,1,1,0,0], "recovery_time":[10,22,90,32,5]})
data

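|   | sex | age | race | severity | medication | recovery_time |
|---|-----|-----|------|----------|------------|---------------|
| 0 | 1 | 17 | 1 | 4 | 1 | 10 |
| 1 | 0 | 32 | 0 | 8 | 1 | 22 |
| 2 | 1 | 82 | 0 | 2 | 1 | 90 |
| 3 | 1 | 27 | 1 | 9 | 0 | 32 |
| 4 | 0 | 54 | 0 | 5 | 0 | 5 |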


In [31]:
import numpy as np

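# the small DataFrame above only illustrates the kind of data we have in mind;
# the examples below use synthetic random data instead
# X = data.iloc[:, :-1].to_numpy()
# y = data.iloc[:, -1].to_numpy()

# synthetic random data (not real patient data)
rng = np.random.RandomState(42)

X = rng.rand(200, 5)  # 200 samples, 5 features
y = rng.randint(0, 2, size=X.shape[0])  # target with values 0 and 1

groups = rng.randint(0, 10, size=X.shape[0])  # e.g. which of 10 hospitals each patient came from
sample_weight = rng.rand(X.shape[0])  # one random weight per sample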

- example: a medical study on the effectiveness of a treatment
    - we would want to group the patients by hospital, using the groups parameter scikit-learn offers
        - we assume that each hospital's data may have systematic biases due to factors like medical devices, policies, the socioeconomic status of its patients, ...
        - since a hospital is a collection of patients, all of its patients share these biases
    - the same hospital should not appear both in the training and in the validation set
    - groups is metadata that is used exclusively by splitters: it makes sure that if patterns exist within the data, we don't leak them between the training and test set, because we want our models to learn the target and not other patterns within the data

    - without the metadata routing API:


In [29]:
from sklearn.model_selection import GroupKFold
from sklearn.linear_model import Ridge
from sklearn.model_selection import cross_validate
from sklearn.metrics import get_scorer

ridge = Ridge()

cv = GroupKFold(n_splits=2) # GroupKFold considers groups while splitting
scoring = get_scorer("neg_mean_squared_error")

cross_validate(ridge, X, y, groups=groups, cv=cv, scoring=scoring)

{'fit_time': array([0.00228572, 0.00149226]),
 'score_time': array([0.00090718, 0.00065589]),
 'test_score': array([-0.27150562, -0.26527336])}
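GroupKFold's guarantee that one hospital never ends up on both sides of a split can be checked directly. A minimal sketch, assuming random hospital ids as group labels (`hospital` is a hypothetical name for them):

```python
import numpy as np
from sklearn.model_selection import GroupKFold

rng = np.random.RandomState(42)
X = rng.rand(200, 5)
# hypothetical hospital id (0-9) for every patient, used as `groups`
hospital = rng.randint(0, 10, size=X.shape[0])

overlaps = []
for train_idx, test_idx in GroupKFold(n_splits=2).split(X, groups=hospital):
    # hospitals that appear in both the training and the validation fold
    shared = set(hospital[train_idx]) & set(hospital[test_idx])
    overlaps.append(shared)

print(overlaps)  # [set(), set()]: no hospital is on both sides of a split
```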

- or we might also want to pass sample_weight

- we would use sample_weight when we want to draw the machine learning algorithm's attention to a specific group of samples that our data under-represents in some way

- two ways of getting data:
    - a randomised trial under very constrained conditions
    - real-world observations, with statistics applied on top of the data we train on

- in our example of a medical study on the effectiveness of a treatment, sample_weight might
    - encode the sex or race of a patient (to balance out the data)
    - or, if we suspect a correlation between a feature and whether a patient received the new treatment we are interested in (e.g. a bias in which patients were chosen for the treatment), sample_weight could be used to counter-balance that

- there are several methods to determine sample_weight: from calculating proportions or more advanced statistics from the data, to using more general statistical principles or natural laws that we know will take effect
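One of the more advanced statistical methods is inverse probability of treatment weighting (IPTW), the topic of the blog post listed under Further information. A minimal sketch under stated assumptions: the covariates and the biased treatment assignment are simulated, and a logistic regression is assumed to be an adequate propensity model. Each patient is weighted by the inverse probability of the treatment they actually received, so patients who were unlikely to end up in their group count more.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.RandomState(0)
n_patients = 500

# simulated covariates, e.g. scaled age and severity (hypothetical data)
covariates = rng.rand(n_patients, 2)

# simulated, non-randomised treatment assignment:
# more severe cases are more likely to receive the new treatment
p_treatment = 1 / (1 + np.exp(-(3 * covariates[:, 1] - 1.5)))
treated = rng.binomial(1, p_treatment)

# 1. estimate the propensity score P(treated = 1 | covariates)
propensity_model = LogisticRegression().fit(covariates, treated)
propensity = propensity_model.predict_proba(covariates)[:, 1]

# 2. inverse probability of treatment weights:
#    treated patients get 1 / p, untreated patients get 1 / (1 - p)
sample_weight = np.where(treated == 1, 1 / propensity, 1 / (1 - propensity))

print(sample_weight[:5])
```

The resulting array could then be passed as sample_weight when fitting the outcome model.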

-----------------------------------------
older ideas of what sample_weight is:
- when do we want to pass sample_weight?
    - when we are more interested in minimizing the prediction error for a certain sub-group of the samples than the overall error (by giving this sub-group a higher sample_weight than the rest of the data)
    - the loss for this sub-group is then often smaller than when training only on the samples we are interested in, because we take the richness of all the data into account
-----------------------------------------
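How sample_weight shifts attention to a sub-group can be seen in the weighted mean squared error, which is a weighted average of the squared errors, `sum_i w_i * (y_i - ŷ_i)² / sum_i w_i`. In this small hand-made example (the values are hypothetical), the two up-weighted samples dominate the result:

```python
import numpy as np
from sklearn.metrics import mean_squared_error

y_true = np.array([1.0, 2.0, 3.0, 4.0])
y_pred = np.array([1.5, 2.0, 2.0, 4.0])
# hypothetical weights: the last two samples form the sub-group we care about most
sample_weight = np.array([1.0, 1.0, 3.0, 3.0])

# weighted MSE = sum(w * (y - y_pred)^2) / sum(w) = 3.25 / 8
manual = np.sum(sample_weight * (y_true - y_pred) ** 2) / np.sum(sample_weight)
sklearn_value = mean_squared_error(y_true, y_pred, sample_weight=sample_weight)

print(manual)         # 0.40625 (the unweighted MSE would be 0.3125)
print(sklearn_value)  # same value
```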

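- here, we could try to pass sample_weight to fit() before cross-validating. However, cross_validate clones the estimator and refits it on every training fold, so this weighted fit is discarded; this is why the test scores below are identical to the unweighted run above: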

In [46]:
from sklearn.model_selection import GroupKFold
from sklearn.linear_model import Ridge
from sklearn.model_selection import cross_validate
from sklearn.metrics import get_scorer

ridge = Ridge().fit(X, y, sample_weight=sample_weight)  # note: cross_validate clones ridge and refits it without sample_weight

cv = GroupKFold(n_splits=2) # GroupKFold considers groups while splitting
scoring = get_scorer("neg_mean_squared_error")

cross_validate(ridge, X, y, groups=groups, cv=cv, scoring=scoring)

{'fit_time': array([0.00187016, 0.00127435]),
 'score_time': array([0.00099516, 0.00067568]),
 'test_score': array([-0.27150562, -0.26527336])}
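Why are the scores identical? cross_validate works on clones of the estimator, and `sklearn.base.clone` copies only the constructor parameters (such as `alpha`), not the fitted state. A minimal check on the same kind of synthetic data:

```python
import numpy as np
from sklearn.base import clone
from sklearn.linear_model import Ridge

rng = np.random.RandomState(42)
X = rng.rand(200, 5)
y = rng.randint(0, 2, size=X.shape[0])
sample_weight = rng.rand(X.shape[0])

fitted = Ridge().fit(X, y, sample_weight=sample_weight)
# clone() keeps the hyperparameters but drops everything learned in fit()
unfitted_copy = clone(fitted)

print(hasattr(fitted, "coef_"), hasattr(unfitted_copy, "coef_"))  # True False
```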

- or we could pass it to fit() and to the scoring, but not to the cross validation:

In [39]:
from sklearn.linear_model import Ridge
from sklearn.metrics import get_scorer

ridge = Ridge().fit(X,y, sample_weight=sample_weight) # Ridge().fit() uses sample_weight

scoring = get_scorer("neg_mean_squared_error") # mean_squared_error can consume sample_weight

scoring(ridge, X, y, sample_weight=sample_weight)

np.float64(-0.2271912296930205)

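- but we couldn't combine everything: without routing, cross_validate has no way to forward sample_weight to the scorer (and the weighted fit of ridge is discarded as well, since cross_validate refits a clone). Note that the error below does not match the code as shown: it reports that `scoring` was the float returned by the previous cell, so it stems from an earlier version of this cell: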

In [45]:
from sklearn.model_selection import GroupKFold
from sklearn.linear_model import Ridge
from sklearn.model_selection import cross_validate
from sklearn.metrics import get_scorer

ridge = Ridge().fit(X,y, sample_weight=sample_weight) # Ridge().fit() uses sample_weight

scoring = get_scorer("neg_mean_squared_error") # mean_squared_error can consume sample_weight

cv=GroupKFold(n_splits=2) # uses groups

cross_validate(ridge, X, y, groups=groups, cv=cv, scoring=scoring)

InvalidParameterError: The 'scoring' parameter of cross_validate must be a str among {'accuracy', 'neg_mean_squared_log_error', 'positive_likelihood_ratio', 'recall', 'precision_weighted', 'jaccard_micro', 'neg_mean_squared_error', 'jaccard_samples', 'jaccard', 'matthews_corrcoef', 'roc_auc', 'recall_weighted', 'neg_mean_absolute_percentage_error', 'd2_absolute_error_score', 'completeness_score', 'balanced_accuracy', 'max_error', 'neg_mean_gamma_deviance', 'neg_median_absolute_error', 'roc_auc_ovo', 'roc_auc_ovr_weighted', 'explained_variance', 'jaccard_weighted', 'fowlkes_mallows_score', 'neg_mean_poisson_deviance', 'normalized_mutual_info_score', 'neg_mean_absolute_error', 'average_precision', 'f1_weighted', 'neg_root_mean_squared_error', 'recall_samples', 'adjusted_rand_score', 'f1', 'top_k_accuracy', 'precision_macro', 'homogeneity_score', 'precision_micro', 'roc_auc_ovo_weighted', 'v_measure_score', 'rand_score', 'r2', 'mutual_info_score', 'neg_negative_likelihood_ratio', 'f1_macro', 'neg_root_mean_squared_log_error', 'precision_samples', 'recall_micro', 'f1_samples', 'roc_auc_ovr', 'precision', 'recall_macro', 'neg_log_loss', 'neg_brier_score', 'adjusted_mutual_info_score', 'f1_micro', 'jaccard_macro'}, a callable, an instance of 'list', an instance of 'tuple', an instance of 'dict' or None. Got np.float64(-0.2271912296930205) instead.
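Without routing, the way to use groups in the splitter and sample_weight in both fitting and scoring is to write the cross-validation loop by hand and slice each piece of metadata ourselves. A minimal sketch on the same synthetic data; it works at one level, but the same bookkeeping would have to be repeated inside every nested object, such as a GridSearchCV:

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import GroupKFold

rng = np.random.RandomState(42)
X = rng.rand(200, 5)
y = rng.randint(0, 2, size=X.shape[0])
groups = rng.randint(0, 10, size=X.shape[0])
sample_weight = rng.rand(X.shape[0])

scores = []
for train, test in GroupKFold(n_splits=2).split(X, y, groups=groups):
    # fit on the training fold with its slice of sample_weight
    model = Ridge().fit(X[train], y[train], sample_weight=sample_weight[train])
    # score on the test fold with its slice of sample_weight
    mse = mean_squared_error(
        y[test], model.predict(X[test]), sample_weight=sample_weight[test]
    )
    scores.append(-mse)  # negated to match the "neg_mean_squared_error" convention

print(scores)
```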

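- and look what happens if we try to pass metadata to a cross-validation that runs a grid search over an estimator that uses metadata: cross_validate has no sample_weight parameter, and its groups argument is only passed to cross_validate's own splitter, not to the GroupKFold inside GridSearchCV: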

In [52]:
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import GroupKFold
from sklearn.metrics import get_scorer

ridge = Ridge().fit(X,y, sample_weight=sample_weight)

param_grid={"alpha": [0.1, 1, 10]}
cv=GroupKFold(n_splits=2)
scoring = get_scorer("neg_mean_squared_error")

search = GridSearchCV(ridge, param_grid=param_grid, cv=cv, scoring=scoring)

cross_validate(
    search,
    X,
    y,
    groups=groups,
    sample_weight=sample_weight,
    cv=cv,
    scoring=scoring,
)

TypeError: got an unexpected keyword argument 'sample_weight'

In [51]:
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import GroupKFold
from sklearn.metrics import get_scorer

ridge = Ridge()

param_grid={"alpha": [0.1, 1, 10]}
cv=GroupKFold(n_splits=2)
scoring = get_scorer("neg_mean_squared_error")

search = GridSearchCV(ridge, param_grid=param_grid, cv=cv, scoring=scoring)

search.fit(X, y, groups=groups)  # works: at a single level, GridSearchCV.fit() passes groups on to its splitter

search.best_estimator_

- we have just had a glimpse of what metadata can be used for

- and if we want fairer, less biased models that predict treatment effectiveness for future patients, even if past data was not gathered in a randomised trial and has some inherent dependencies, we want to use metadata everywhere in the data science pipeline

- but we have also seen the limitations of using metadata in nested structures

### 3. Using the metadata routing API

In [23]:
import sklearn
from sklearn.linear_model import Ridge
from sklearn.metrics import get_scorer
from sklearn.model_selection import GridSearchCV, GroupKFold, cross_validate

with sklearn.config_context(enable_metadata_routing=True):
    # Ridge.fit() should receive sample_weight
    ridge = Ridge().set_fit_request(sample_weight=True)

    param_grid = {"alpha": [0.1, 1, 10]}
    cv = GroupKFold(n_splits=2)  # GroupKFold requests groups by default
    # the scorer should receive sample_weight, too
    scoring = get_scorer("neg_mean_squared_error").set_score_request(sample_weight=True)

    search = GridSearchCV(ridge, param_grid=param_grid, cv=cv, scoring=scoring)

    # all metadata goes into `params`; routing delivers each item to the objects
    # that requested it, on the outer (cross_validate) and inner (GridSearchCV) level
    cv_results = cross_validate(
        search,
        X,
        y,
        params={"sample_weight": sample_weight, "groups": groups},
        cv=cv,
        scoring=scoring,
    )

cv_results["test_score"]

- show that we get better results than without metadata routing? maybe too much for 7 minutes
    - the score might be over-confident without metadata routing (because we leaked data without groups)
    - use new test data to show that

- explain:
    - the scorer takes sample_weight (the scoring= parameter is found mainly in the estimators ending in CV and in model selection functions such as cross_validate); the scorer then passes the metadata on to the metric that evaluates each model on the internal validation set during cross-validation
    - the splitter splits the data into cross-validation folds and is mainly interested in groups

- summing up: with the metadata routing API:
    - we can pass sample_weight and groups through several levels of estimators and pipelines in scikit-learn
    - we can combine objects from other libraries with scikit-learn estimators while still passing their metadata
    - we can define our own custom metrics using self-defined metadata
    - and use it in a special setting with [TunedThresholdClassifierCV](https://scikit-learn.org/dev/auto_examples/model_selection/plot_cost_sensitive_learning.html#cost-sensitive-learning-when-gains-and-costs-are-not-constant) (as we will see in the next part)

### 4. Further information

- User Guide on Metadata Routing: <br>
[https://scikit-learn.org/stable/metadata_routing.html#metadata-routing](https://scikit-learn.org/stable/metadata_routing.html#metadata-routing)

- Developer Guide on Metadata Routing: <br>
[https://scikit-learn.org/stable/auto_examples/miscellaneous/plot_metadata_routing.html#metadata-routing](https://scikit-learn.org/stable/auto_examples/miscellaneous/plot_metadata_routing.html#metadata-routing)

- Adrin Jalali's talk on the internal logic of metadata routing at EuroPython Conference 2023: <br>
[https://www.youtube.com/watch?v=1rf6HI-pYq8](https://www.youtube.com/watch?v=1rf6HI-pYq8)

- Blog post by Florian Wilhelm on Inverse Probability of Treatment Weighting: <br>
[https://florianwilhelm.info/2017/04/causal_inference_propensity_score](https://florianwilhelm.info/2017/04/causal_inference_propensity_score)

- link to Vincent's VW video