In [27]:
from sklego.datasets import load_heroes
import numpy as np
import pandas as pd 

from sklearn.preprocessing import StandardScaler, LabelBinarizer
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import make_scorer, mean_squared_error

def clean_data(dataf):
    """Drop incomplete rows and add a ``dmg`` column.

    Rows containing any missing value are removed first; ``dmg`` is then the
    element-wise product of the ``attack`` and ``attack_spd`` columns of the
    remaining rows. The input frame is left untouched; a new frame is returned.
    """
    complete_rows = dataf.dropna()
    damage = complete_rows['attack'] * complete_rows['attack_spd']
    return complete_rows.assign(dmg=damage)

# Load the heroes data as a pandas object and run it through the cleaning step.
df = load_heroes(give_pandas=True)
ml_df = clean_data(df)

# Two feature columns and the target column, extracted via `.values`.
X = ml_df.loc[:, ["health", "dmg"]].values
y = ml_df.loc[:, "attack_type"].values

In [28]:
# Two-step pipeline: a scaler followed by a logistic regression.
pipe = Pipeline(steps=[
    ("scale", StandardScaler()),
    ("model", LogisticRegression(solver='lbfgs')),
])

# 5-fold grid search over both scaler switches and the model's C parameter.
# Train scores are requested so they can be compared with test scores later.
model = GridSearchCV(
    estimator=pipe,
    param_grid={
        'scale__with_mean': [True, False],
        'scale__with_std': [True, False],
        'model__C': [0.01, 0.1, 1.0, 10.0, 100.0],
    },
    return_train_score=True,
    cv=5,
)

# Trailing semicolon suppresses the estimator repr in the cell output.
model.fit(X, y);

In [37]:
import pandas as pd


class RefitPolicy:
    """Chainable recipe for ranking the rows of a grid search's ``cv_results_``.

    Steps are recorded with :meth:`filter`, :meth:`assign` and :meth:`sort`
    (each returns ``self`` so calls can be chained) and are only executed when
    :meth:`pick_best_estimator` is called.
    """

    # Name of the column that receives the output of the sorting function.
    # It is reserved: `assign` refuses to create a column with this name.
    _sort_func_name = "final_score_sort"

    def __init__(self):
        """Start with no filters, no assignments and no sorting function."""
        self.filter_funcs = []
        self.assignments = []
        self.sorting_func = None

    def filter(self, func):
        """Register a row filter and return ``self``.

        ``func`` is later used as ``dataf.loc[func]``, so it should be a
        callable taking the results frame and returning a valid ``.loc``
        indexer (for example a boolean mask). Filters run in registration
        order.
        """
        self.filter_funcs.append(func)
        return self

    def assign(self, **kwargs):
        """Register new columns to compute and return ``self``.

        Each keyword is a column name and each value is passed to
        ``DataFrame.assign`` under that name.

        Raises:
            ValueError: if a keyword equals the reserved internal column name.
        """
        if self._sort_func_name in kwargs:
            raise ValueError(f"We use {self._sort_func_name} internally, cannot use this column name")
        self.assignments.append(kwargs)
        return self

    def sort(self, func):
        """Set the function producing the ranking score and return ``self``.

        Calling this again replaces the previously set function.
        """
        self.sorting_func = func
        return self

    def pick_best_estimator(self, gridsearch_obj):
        """Apply the recorded steps to ``gridsearch_obj.cv_results_``.

        Filters are applied first, then assignments (in registration order),
        and finally the sorting function's output is stored in the reserved
        score column.

        Returns:
            The resulting DataFrame sorted by the score column in descending
            order, so the highest-scoring row comes first.

        Raises:
            ValueError: if no sorting function was set via :meth:`sort`.
        """
        # Without a sorting function the score column would be filled with
        # None and the "ranking" would silently be meaningless; fail loudly.
        if self.sorting_func is None:
            raise ValueError(
                "No sorting function set; call .sort(func) before pick_best_estimator"
            )
        dataf = pd.DataFrame(gridsearch_obj.cv_results_)
        for row_filter in self.filter_funcs:
            dataf = dataf.loc[row_filter]
        for assignment in self.assignments:
            # One column per assign call so that later columns can build on
            # earlier ones from the same `assign(...)` registration.
            for name, func in assignment.items():
                dataf = dataf.assign(**{name: func})
        dataf = dataf.assign(**{self._sort_func_name: self.sorting_func})
        return dataf.sort_values(self._sort_func_name, ascending=False)

In [None]:
def pd_pipeline(dataf):
    """Add the train/test score gap and order rows by it, smallest first.

    ``train_test_diff`` is ``mean_train_score`` minus ``mean_test_score``.
    The input frame is not modified; a new, sorted frame is returned.
    """
    score_gap = dataf['mean_train_score'] - dataf['mean_test_score']
    with_gap = dataf.assign(train_test_diff=score_gap)
    return with_gap.sort_values(by='train_test_diff')

In [39]:
# Rank the grid-search results by the gap between mean train and mean test
# score; the bare final expression is the cell's displayed output.
RefitPolicy().assign(
    train_test_diff=lambda scores: scores['mean_train_score'] - scores['mean_test_score'],
).sort(
    lambda scores: scores['train_test_diff'],
).pick_best_estimator(model)

Unnamed: 0,mean_fit_time,std_fit_time,mean_score_time,std_score_time,param_model__C,param_scale__with_mean,param_scale__with_std,params,split0_test_score,split1_test_score,...,rank_test_score,split0_train_score,split1_train_score,split2_train_score,split3_train_score,split4_train_score,mean_train_score,std_train_score,train_test_diff,final_score_sort
4,0.001782,9e-05,0.000302,1.2e-05,0.1,True,True,"{'model__C': 0.1, 'scale__with_mean': True, 's...",0.941176,0.823529,...,19,0.818182,0.833333,0.863636,0.865672,0.850746,0.846314,0.01819,0.064696,0.064696
6,0.002458,0.000249,0.000291,3e-06,0.1,False,True,"{'model__C': 0.1, 'scale__with_mean': False, '...",0.941176,0.823529,...,19,0.818182,0.833333,0.863636,0.865672,0.850746,0.846314,0.01819,0.064696,0.064696
10,0.002713,0.000131,0.000297,1.8e-05,1.0,False,True,"{'model__C': 1.0, 'scale__with_mean': False, '...",0.941176,0.823529,...,1,0.818182,0.848485,0.863636,0.865672,0.880597,0.855314,0.021172,0.061197,0.061197
12,0.001824,5.5e-05,0.000298,1.9e-05,10.0,True,True,"{'model__C': 10.0, 'scale__with_mean': True, '...",0.941176,0.823529,...,1,0.818182,0.848485,0.863636,0.865672,0.880597,0.855314,0.021172,0.061197,0.061197
8,0.001769,6.7e-05,0.000292,2e-06,1.0,True,True,"{'model__C': 1.0, 'scale__with_mean': True, 's...",0.941176,0.823529,...,1,0.818182,0.848485,0.863636,0.865672,0.880597,0.855314,0.021172,0.061197,0.061197
14,0.002899,0.000158,0.000282,1e-06,10.0,False,True,"{'model__C': 10.0, 'scale__with_mean': False, ...",0.941176,0.823529,...,1,0.818182,0.848485,0.863636,0.865672,0.880597,0.855314,0.021172,0.061197,0.061197
18,0.002949,0.000107,0.000285,4e-06,100.0,False,True,"{'model__C': 100.0, 'scale__with_mean': False,...",0.941176,0.823529,...,1,0.818182,0.848485,0.863636,0.865672,0.865672,0.852329,0.018247,0.058212,0.058212
17,0.003971,0.000146,0.000319,8.1e-05,100.0,True,False,"{'model__C': 100.0, 'scale__with_mean': True, ...",0.941176,0.823529,...,1,0.818182,0.848485,0.863636,0.865672,0.865672,0.852329,0.018247,0.058212,0.058212
16,0.001745,6.7e-05,0.000282,1e-06,100.0,True,True,"{'model__C': 100.0, 'scale__with_mean': True, ...",0.941176,0.823529,...,1,0.818182,0.848485,0.863636,0.865672,0.865672,0.852329,0.018247,0.058212,0.058212
15,0.005065,0.000332,0.000279,6e-06,10.0,False,False,"{'model__C': 10.0, 'scale__with_mean': False, ...",0.941176,0.823529,...,1,0.818182,0.848485,0.863636,0.865672,0.865672,0.852329,0.018247,0.058212,0.058212
