# Scikit-learn's Metadata Routing API

### Abstract
This talk will introduce scikit-learn users to the new API for **metadata routing**, a feature introduced in recent releases and almost fully available since version 1.5 (released in May 2024).

We will explore what metadata is, how it can be used in machine learning pipelines, and how the new API simplifies routing metadata throughout a workflow. Metadata routing refers to the internal mechanism that passes metadata between the components of a data science pipeline, ensuring it reaches the functions that consume it.

Using well-known metadata such as `sample_weight` and `groups`, which are supported by many scikit-learn estimators, metrics and evaluation tools, we will examine the restrictions on passing metadata prior to the introduction of the new API. Then, we will enable the new routing API and demonstrate how it solves these challenges with examples that involve several levels of nesting through cross-validation, hyperparameter tuning, or pipelines. We will explain the core components of the API, including methods like `set_fit_request()`, and how to actually pass our metadata.

Attendees will leave with an understanding of how to enable and use the new routing API, including passing metadata through `Pipeline` objects and validation tools like `cross_validate`. References to the metadata routing user guide and developer guide will be provided for those interested in further exploration.

### 1. What is Metadata Routing?

- definition of metadata:
    - metadata can be any data that we want to apply on top of our tabular data without it necessarily being part of it
    - alternatively: metadata is any data that some function in a data science "pipeline/process" handles besides `X` and `y`; it is a parameter that influences how this function treats the data

- examples for metadata:
    - examples you might know from scikit-learn: `sample_weight` and `groups`
    - other libraries support other kinds of metadata
    - (slide only: graphically show some other examples of metadata: gender, race, sex, zipcode, ... and the area/library they are used in)
    - self-defined metadata to be used in custom metrics (and possibly custom estimators)

- use cases for metadata:
    - fairness-related use cases
    - business logic
    - `sample_weight` can be used to re-balance data, and `groups` can prevent data leakage in cross-validation
    - note: passing metadata doesn't necessarily give us a better score, but it gives us a more realistic model

- definition of routing:
    - routing just means passing metadata between the components involved in a data science pipeline to wherever it is used or consumed

- prior to metadata routing API:
    - we could only pass `sample_weight` directly to the estimators and metrics that define it
    - we could not consistently pass it through nested structures, such as a grid search inside a cross-validation

- with metadata routing API:
    - we can pass `sample_weight` and `groups` through several levels of nesting between the components of a data science pipeline
    - we can combine objects from other libraries with scikit-learn estimators while still passing their metadata
    - we can define our own custom metrics using self-defined metadata



### 2. Passing metadata without the routing API

In [3]:
import pandas as pd
import numpy as np

# Note: To keep this example simple, we simplify the categorical data and only use
# two sexes, two races, and two socio-economic classes, which spares us from
# one-hot-encoding categorical data. This example DataFrame is not meant to reflect
# reality; it is supposed to be readable on a slide during a presentation.

rng = np.random.RandomState(41)
sex = rng.randint(0, 2, size=100)
age = rng.normal(loc=70, scale=20, size=100)
age = np.clip(age, 0, 100).astype(int)
race = rng.choice([0, 1], size=100, p=[0.7, 0.3])
social_class = rng.choice([0, 1], size=100, p=[0.6, 0.4])
severity = rng.choice(range(11), size=100)
medication = rng.randint(0, 2, size=100)
recovery_time = rng.normal(loc=25, scale=20, size=100)
recovery_time = np.clip(recovery_time, 0, 200).astype(int)

data = pd.DataFrame(
    {
        "sex": sex,
        "age": age,
        "race": race,
        "class": social_class,
        "severity": severity,
        "medication": medication,
        "recovery_time": recovery_time,
    }
)
data

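| | sex | age | race | class | severity | medication | recovery_time |
|---|---|---|---|---|---|---|---|
| 0 | 0 | 88 | 1 | 1 | 6 | 1 | 33 |
| 1 | 1 | 71 | 1 | 1 | 1 | 1 | 8 |
| 2 | 0 | 100 | 1 | 1 | 4 | 0 | 32 |
| 3 | 0 | 87 | 1 | 0 | 0 | 1 | 56 |
| 4 | 0 | 50 | 0 | 0 | 7 | 1 | 28 |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 95 | 1 | 71 | 1 | 1 | 9 | 1 | 16 |
| 96 | 1 | 99 | 1 | 0 | 0 | 1 | 26 |
| 97 | 0 | 38 | 1 | 1 | 1 | 0 | 0 |
| 98 | 1 | 75 | 0 | 0 | 2 | 0 | 0 |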


In [4]:
X = data.iloc[:, :-1].to_numpy()
y = data.iloc[:, -1].to_numpy()

# hospital IDs (10 hospitals) used as `groups`, and random `sample_weight` values:
groups = rng.randint(0, 10, size=X.shape[0])
sample_weight = rng.rand(X.shape[0])

- example: medical study on the effectiveness of a treatment
    - we want to group the samples by hospital using the `groups` parameter scikit-learn offers
        - we assume that each hospital's data may have systematic biases due to factors like medical devices, policies, socioeconomic status of the patients, ...
        - since a hospital is a collection of patients, we group the samples per hospital
    - the same hospital should not appear in both the training and the validation set when we cross-validate
    - `groups` is metadata that is used only by splitters, to make sure that if patterns exist within the data, we don't leak them between the training and validation sets, because we want our models to learn to predict the target and not to pick up other patterns (such as hospital-specific effects) within the data


In [9]:
from sklearn.model_selection import GroupKFold, cross_validate
from sklearn.linear_model import Ridge

ridge = Ridge()

# GroupKFold considers groups while splitting:
cv = GroupKFold(n_splits=2)

# providing `groups` for GroupKFold splitter:
cross_validate(ridge, X, y, cv=cv, groups=groups, return_train_score=True)

{'fit_time': array([0.00191188, 0.00093627]),
 'score_time': array([0.00070477, 0.00059056]),
 'test_score': array([-0.2923336 , -0.12131859]),
 'train_score': array([0.21447416, 0.06814853])}
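To check the claim above that no hospital ends up in both the training and the validation set, we can iterate over the splits directly. This is a minimal sketch on synthetic stand-in data (random features and 10 hypothetical hospital IDs), not the talk's DataFrame; it only uses the public `GroupKFold.split()` method.

```python
import numpy as np
from sklearn.model_selection import GroupKFold

# Synthetic stand-in data: 100 patients spread over 10 hypothetical hospitals.
rng = np.random.RandomState(0)
X = rng.rand(100, 6)
y = rng.rand(100)
groups = rng.randint(0, 10, size=100)

cv = GroupKFold(n_splits=2)

overlaps = []
for train_idx, test_idx in cv.split(X, y, groups=groups):
    train_hospitals = np.unique(groups[train_idx])
    test_hospitals = np.unique(groups[test_idx])
    # Hospitals that appear in both the training and the validation split:
    overlaps.append(np.intersect1d(train_hospitals, test_hospitals))
    print("train:", train_hospitals.tolist(), "validation:", test_hospitals.tolist())

print("overlap per split:", [o.tolist() for o in overlaps])
```

The overlap comes out as `[[], []]`: every hospital lands entirely on one side of each split, which is the leakage protection that `groups` provides.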

- we might also want to pass `sample_weight`

- we would use `sample_weight` when we want to draw the attention of the model to a specific sub-group of samples that our data misrepresents in some way, or to a problem with their distribution with respect to another feature

- two ways of obtaining data:
    - a randomised trial under tightly controlled conditions
    - real-world observations, where we then apply statistical corrections on top of the data we train on

- in our example of a medical study on the effectiveness of a treatment, `sample_weight` might
    - encode the sex or race of a patient (to balance out the data)
    - or, if we suspect a correlation between a feature (like race) and whether a patient received the specific treatment, e.g. a bias in which patients were chosen for the treatment, then `sample_weight` could be used to counter-balance that
        - there are several methods to determine `sample_weight` (from calculating proportions or more advanced statistics from the data, to using more general statistical principles or natural laws that we know will take effect)

- giving certain samples a higher `sample_weight` than the rest makes the error or loss on these samples count more during minimization than the general error, while we still take the richness of all the data into account

- we could pass `sample_weight` into the fit method and into the scoring:
    - and both would be taken into account

In [10]:
from sklearn.linear_model import Ridge
from sklearn.metrics import get_scorer

# Ridge().fit() can consume `sample_weight`:
ridge = Ridge().fit(X, y, sample_weight=sample_weight)

# `mean_squared_error` can consume `sample_weight`:
scoring = get_scorer("neg_mean_squared_error")

# thus we pass `sample_weight` to fit (above) and to the scorer, which evaluates the fitted ridge:
scoring(ridge, X, y, sample_weight=sample_weight)

np.float64(-234.7176229868505)
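The `sample_weight` used in these cells is random. As a sketch of the "calculating proportions" approach mentioned above, inverse-frequency weights can be computed with `sklearn.utils.class_weight.compute_sample_weight`. Assumptions: we re-balance a binary `race` feature drawn like the one above (the function is named for class labels, but it only needs discrete labels), and inverse-frequency weighting is just one of several possible methods, not the one used in the talk.

```python
import numpy as np
from sklearn.utils.class_weight import compute_sample_weight

rng = np.random.RandomState(41)
# Binary feature with an imbalanced 70/30 split, drawn like `race` above.
race = rng.choice([0, 1], size=100, p=[0.7, 0.3])

# "balanced": each sample gets n_samples / (n_groups * size_of_its_group)
sample_weight = compute_sample_weight("balanced", y=race)

for value in (0, 1):
    mask = race == value
    print(f"race={value}: {mask.sum()} samples, total weight {sample_weight[mask].sum():.1f}")
```

However imbalanced the two groups are, each ends up with the same total weight (50.0 for 100 samples), so the minority group counts as much as the majority group during fitting and scoring.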

- but fitting `Ridge` with `sample_weight` and then passing it into a cross-validation has no effect:
    - the results are the same as if we hadn't passed `sample_weight` at all
    - because `cross_validate` clones the estimator and refits it on each training split without `sample_weight`, and the scorer doesn't receive it either

In [6]:
from sklearn.model_selection import GroupKFold, cross_validate
from sklearn.linear_model import Ridge
from sklearn.metrics import get_scorer

# Ridge().fit() can consume `sample_weight`, but `cross_validate` clones the estimator
# and refits it on each training split without `sample_weight`, so this pre-fit is discarded:
ridge = Ridge().fit(X, y, sample_weight=sample_weight)

# the scorer doesn't receive `sample_weight` either:
scoring = get_scorer("neg_mean_squared_error")

# GroupKFold considers `groups` while splitting:
cv = GroupKFold(n_splits=2)

cross_validate(ridge, X, y, groups=groups, cv=cv, scoring=scoring, return_train_score=True)

{'fit_time': array([0.00111079, 0.00096726]),
 'score_time': array([0.00051475, 0.0003984 ]),
 'test_score': array([-365.16178094, -279.02891722]),
 'train_score': array([-195.47024953, -263.3039523 ])}
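The claim that the pre-fit is ignored can be verified by comparing against an unfitted `Ridge`. A minimal sketch on synthetic stand-in data (random features, 10 hypothetical hospital IDs and random weights, not the talk's DataFrame):

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.model_selection import GroupKFold, cross_validate

# Synthetic stand-ins for X, y, groups and sample_weight.
rng = np.random.RandomState(0)
X = rng.rand(100, 6)
y = rng.rand(100)
groups = rng.randint(0, 10, size=100)
sample_weight = rng.rand(100)

cv = GroupKFold(n_splits=2)

# An unfitted Ridge versus a Ridge pre-fitted with sample_weight:
plain = cross_validate(Ridge(), X, y, groups=groups, cv=cv)
prefit = cross_validate(
    Ridge().fit(X, y, sample_weight=sample_weight), X, y, groups=groups, cv=cv
)

# cross_validate clones the estimator before every fit, so the pre-fit is discarded.
print(np.allclose(plain["test_score"], prefit["test_score"]))  # True
```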

- now we're going to combine everything into a more realistic scenario
- what if we wanted to pass `sample_weight` to a cross-validation of a grid search over an estimator that can consume `sample_weight`?
- note that `Ridge().fit()` can consume `sample_weight`, but if we fit `Ridge` with `sample_weight` and then wrap it in `GridSearchCV`, that fit is discarded, because `GridSearchCV` refits the estimator internally
    - so we try to pass `sample_weight` into `cross_validate` instead

In [22]:
from sklearn.linear_model import Ridge
from sklearn.model_selection import GroupKFold, GridSearchCV, cross_validate
from sklearn.metrics import get_scorer

# Ridge().fit() can consume `sample_weight`, but GridSearchCV refits it internally,
# so a pre-fit with `sample_weight` would be discarded:
ridge = Ridge()

param_grid = {"alpha": [0.1, 1, 10]}
# GroupKFold considers `groups` while splitting:
cv = GroupKFold(n_splits=2)
# `mean_squared_error` can consume `sample_weight`:
scoring = get_scorer("neg_mean_squared_error")

search = GridSearchCV(ridge, param_grid=param_grid, cv=cv, scoring=scoring)

cross_validate(
    search,
    X,
    y,
    cv=cv,
    scoring=scoring,
    sample_weight=sample_weight,
    groups=groups,
)

TypeError: got an unexpected keyword argument 'sample_weight'

- we get a `TypeError`, because `cross_validate` has no `sample_weight` parameter

- we have just witnessed the limitations of using metadata in nested structures

- if we want to predict treatment effectiveness for future patients with fairer, less biased models, even if past data was not gathered from a randomised trial, we want to use metadata throughout the whole data science pipeline

### 3. Using the metadata routing API

In [23]:
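from sklearn.linear_model import Ridge
from sklearn.model_selection import GroupKFold, GridSearchCV, cross_validate
from sklearn.metrics import get_scorer

import sklearn

# metadata routing is still experimental and has to be enabled explicitly:
with sklearn.config_context(enable_metadata_routing=True):
    # request `sample_weight` for Ridge.fit():
    ridge = Ridge().set_fit_request(sample_weight=True)

    param_grid = {"alpha": [0.1, 1, 10]}
    # GroupKFold requests `groups` by default:
    cv = GroupKFold(n_splits=2)
    # request `sample_weight` for the scorer:
    scoring = get_scorer("neg_mean_squared_error").set_score_request(sample_weight=True)

    search = GridSearchCV(ridge, param_grid=param_grid, cv=cv, scoring=scoring)

    # pass all metadata in one central place; it is routed to where it was requested:
    results = cross_validate(
        search,
        X,
        y,
        cv=cv,
        scoring=scoring,
        params={"sample_weight": sample_weight, "groups": groups},
    )

results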


- at the time of writing, metadata routing is still experimental, so we have to enable it explicitly, here with `sklearn.config_context(enable_metadata_routing=True)` (or globally with `sklearn.set_config(enable_metadata_routing=True)`)

- the main idea of the routing API is that we pass all metadata in one central place: our top-level tool, which in this case is `cross_validate`
- here, we pass both pieces of metadata, `groups` and `sample_weight`, via the `params` argument

- then, we define where this metadata is routed to, that is, where it should be consumed
- we use the `set_{method}_request()` methods to declare which metadata a given method expects
- we do that with `set_fit_request()` for `Ridge.fit()` and with `set_score_request()` on the scorer object

- as before, we use `GroupKFold` as the cv splitter, which accepts `groups` by design
- since a group splitter cannot work without `groups`, it requests them by default, so we don't need to set any request for it

- in short: we have enabled metadata routing, set the requests that define where `sample_weight` should be routed, and passed the values for `sample_weight` and `groups` into `cross_validate` as `params`
- now `cross_validate` routes our metadata to the `fit` method of our estimator (through `GridSearchCV` down to `Ridge.fit()`), to the splitters used in the outer and inner cross-validation, and to the scorer used by both `GridSearchCV` and `cross_validate` to evaluate the validation sets

- ongoing work:
    - with the 1.5 release of scikit-learn in May 2024, metadata routing is almost fully available
    - in the 1.6 release (probably in November), we expect all estimators to be compatible with the metadata routing API
    - we're working on defining good default settings, so that users won't need to set every request explicitly, as they currently do
    - but the flexibility you see here will still be available for your custom business metrics

- summing up, with the metadata routing API:
    - we can pass `sample_weight` and `groups` through several levels of nesting between the components of a data science pipeline
    - we can combine objects from other libraries with scikit-learn estimators while still passing their metadata
    - we can define our own custom metrics using self-defined metadata ...
    - ... and use them in special settings such as [TunedThresholdClassifierCV](https://scikit-learn.org/dev/auto_examples/model_selection/plot_cost_sensitive_learning.html#cost-sensitive-learning-when-gains-and-costs-are-not-constant)

### 4. Further information

- [Notebook for this talk](https://github.com/StefanieSenger/Talks/tree/main/2024_Metadata-Routing-API)

- more information on metadata routing:
    - [User Guide on Metadata Routing](https://scikit-learn.org/stable/metadata_routing.html#metadata-routing)
    - [Developer Guide on Metadata Routing](https://scikit-learn.org/stable/auto_examples/miscellaneous/plot_metadata_routing.html#metadata-routing)
    - [Adrin Jalali's talk on the internal logic of metadata routing at EuroPython Conference 2023](https://www.youtube.com/watch?v=1rf6HI-pYq8)

- usage of `sample_weight`:
    - [:probabl. Whiteboard Series by Vincent Warmerdam: Improving models via subsets](https://www.youtube.com/watch?v=REIg5NH2SNc)
    - [Blog post by Florian Wilhelm on Inverse Probability of Treatment Weighting](https://florianwilhelm.info/2017/04/causal_inference_propensity_score)