In [1]:
from typing import Union

import numpy as np
import pandas as pd
import sklearn
from mlxtend.feature_selection import SequentialFeatureSelector as SFS
from scipy.spatial import distance_matrix
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import SelectKBest
from sklearn.manifold import TSNE
from sklearn.metrics import (
    precision_recall_curve,
    roc_curve,
    average_precision_score,
    roc_auc_score,
    confusion_matrix)
from sklearn.model_selection import KFold, train_test_split

In [2]:
def get_selected_features(
        clf: sklearn.base.BaseEstimator,
        X: Union[pd.DataFrame, np.ndarray],
        y: Union[pd.Series, np.ndarray],
        scoring: str,
        splits: Union[int, list[tuple[np.ndarray, np.ndarray]], None] = None,
        k_features: Union[int, tuple[int, int], str] = 'parsimonious',
        verbose: int = 1,
        n_jobs: int = -1) -> SFS:
    """Select features with mlxtend Sequential Backward Selection.

    Starting from the full feature set, one feature is removed per step
    (``forward=False``, ``floating=False``). Every candidate subset is
    re-scored with ``scoring`` under the validation scheme given by
    ``splits``. This is *not* scikit-learn's Recursive Feature Elimination
    (RFE), which prunes by model coefficients/importances instead of
    re-scoring subsets.

    Args:
        clf: A scikit-learn compatible estimator, refit for every candidate
            feature subset.
        X: The input dataset (DataFrame or 2-D array).
        y: The input labels.
        scoring: A scikit-learn scorer name (e.g. ``'balanced_accuracy'``)
            used to rank candidate subsets.
        splits: The cross-validation scheme. Pass an integer number of folds
            or an explicit list of ``(train_idx, valid_idx)`` pairs. A list
            is required rather than a generator, because the folds are reused
            for every subset. With ``None`` (the default), mlxtend performs
            no cross-validation and scores subsets on ``X`` itself.
        k_features: The target subset size. Pass an int, a ``(min, max)``
            range, ``'best'`` or ``'parsimonious'``. The default,
            ``'parsimonious'``, picks the smallest subset within one standard
            error of the best score.
        verbose: mlxtend verbosity level; the default 1 logs each step.
        n_jobs: The number of parallel jobs; -1 uses all cores.

    Returns:
        The fitted SFS object. ``k_feature_idx_`` holds the selected column
        indices, and ``subsets_`` holds the score of every evaluated subset.

    """
    sfs = SFS(
        clf,
        k_features=k_features,
        forward=False,   # backward elimination: start from all features
        floating=False,  # no conditional re-inclusion step
        verbose=verbose,
        scoring=scoring,
        n_jobs=n_jobs,
        cv=splits)
    return sfs.fit(X, y)

In [9]:
# Demo: backward feature selection on iris with a fixed, shuffled 2-fold split.
# Imports for this cell live in the top imports cell.
SEED = 0  # one seed for the split, the folds and the forest -> reproducible run

X, y = load_iris(return_X_y=True)
# X_test / y_test are held out and deliberately not used during selection.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=SEED)
clf = RandomForestClassifier(n_estimators=100, random_state=SEED)
kf = KFold(n_splits=2, shuffle=True, random_state=SEED)
# Materialise the folds: every candidate subset is cross-validated on the same
# splits, and a generator would be exhausted after the first evaluation.
splits = list(kf.split(X_train))
sfs = get_selected_features(clf, X_train, y_train, 'balanced_accuracy', splits)

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done   4 out of   4 | elapsed:    0.5s remaining:    0.0s
[Parallel(n_jobs=-1)]: Done   4 out of   4 | elapsed:    0.5s finished
Features: 3/1[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done   3 out of   3 | elapsed:    0.2s finished
Features: 2/1[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done   2 out of   2 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=-1)]: Done   2 out of   2 | elapsed:    0.1s finished
Features: 1/1

In [10]:
# Column indices kept by the backward selection (bare expression -> rich display).
sfs.k_feature_idx_

(3,)
