In [1]:
# Author: Ahmet Yildirim
# Date: 30.10.2020

In [2]:
import numpy as np
import pandas as pd
from sklearn.model_selection import cross_validate
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import make_scorer, mean_squared_error, r2_score
from sklearn.datasets import load_boston
import warnings

In [3]:
# Notebook-wide settings: silence every warning category and widen the
# pandas display limits so large frames are not truncated when rendered.
warnings.simplefilter("ignore")
pd.set_option("display.max_rows", 500)
pd.set_option("display.max_columns", 50)

In [10]:
def stratified_regression_split(data_size, n_splits=5, random_state=26):
    """Generate (train, test) position pairs for target-stratified K-fold CV.

    The positions ``0 .. data_size - 1`` are cut into consecutive bins of
    ``n_splits`` neighbours. Each bin is shuffled and hands exactly one
    position to every fold, so when the rows are ordered by the regression
    target beforehand each fold samples the whole target range.

    Positions that do not fill a complete bin (``data_size % n_splits`` of
    them, at the end of the range) are assigned to distinct, randomly chosen
    folds, so every position lands in exactly one test fold.

    Parameters
    ----------
    data_size : int
        Number of rows to split; must be at least ``n_splits``.
    n_splits : int, default 5
        Number of folds; must be at least 2.
    random_state : int or None, default 26
        Seed handed to ``np.random.RandomState`` for the shuffling.

    Returns
    -------
    generator of (ndarray, ndarray)
        One ``(train_positions, test_positions)`` pair per fold. The generator
        is single-pass; wrap it in ``list(...)`` to reuse the folds.

    Raises
    ------
    ValueError
        If ``n_splits < 2`` or ``data_size < n_splits``.
    """
    if n_splits < 2:
        raise ValueError(
            "n_splits must be at least 2, got {}".format(n_splits))
    if data_size < n_splits:
        raise ValueError(
            "data_size ({}) must be at least n_splits ({})".format(
                data_size, n_splits))

    rng = np.random.RandomState(seed=random_state)
    n_bins, n_leftover = divmod(data_size, n_splits)
    n_binned = n_bins * n_splits

    # One row per bin of neighbouring positions; rows are shuffled one after
    # another from the same random stream.
    bins = np.arange(n_binned).reshape(n_bins, n_splits)
    shuffled = np.array([rng.permutation(row) for row in bins])

    # Column k of the shuffled bins is the test set of fold k.
    test_folds = list(shuffled.transpose())

    if n_leftover:
        # Spread the trailing positions over distinct folds instead of
        # dropping them from every split.
        chosen_folds = rng.permutation(n_splits)[:n_leftover]
        for fold_id, position in zip(chosen_folds, range(n_binned, data_size)):
            test_folds[fold_id] = np.concatenate(
                [test_folds[fold_id], [position]])

    # The training set of a fold is every other fold's test set, in fold order.
    return (
        (np.concatenate(test_folds[:k] + test_folds[k + 1:]), test_folds[k])
        for k in range(n_splits)
    )

In [11]:
# Load the dataset into one frame: feature columns plus a 'target' column.
# NOTE(review): load_boston is deprecated since scikit-learn 1.0 and was
# removed in 1.2 — confirm the pinned version or swap in another loader.
dataset = load_boston()
data = pd.DataFrame(dataset['data']).assign(target=dataset['target'])

# Hold out a random 20% of the rows (fixed seed) as the test set.
test_size = 0.2
test = data.sample(n=int(len(data) * test_size), random_state=11)

# Everything not held out is training data, ordered by target so that
# position-based stratified splitting can be applied to it.
train = data.drop(index=test.index).sort_values('target')

In [12]:
# Defining X and y
# `columns=` is used instead of the positional axis argument
# (`drop(labels, 1)`), which pandas deprecated in 1.3 and rejects from 2.0.
X_train = train.drop(columns='target').to_numpy()
y_train = train['target'].to_numpy()

X_test = test.drop(columns='target').to_numpy()
y_test = test['target'].to_numpy()

In [13]:
# Estimator
# A fixed random_state keeps the fitted model, and therefore the
# cross-validation scores, reproducible across kernel restarts.
gb = GradientBoostingRegressor(random_state=26)

In [14]:
# Cross-validate the estimator on the stratified folds, recording R^2 and
# negative MSE for both the train and the test part of every fold.
# The rows of X_train / y_train must already be sorted by the target,
# because the folds are built from row positions.
gb_results = cross_validate(
    estimator=gb,
    X=X_train,
    y=y_train,
    cv=stratified_regression_split(len(X_train), n_splits=5),
    scoring=('r2', 'neg_mean_squared_error'),
    return_train_score=True,
)

In [15]:
# One row per fold: fit/score timings plus the train and test metrics.
pd.DataFrame.from_dict(gb_results)

Unnamed: 0,fit_time,score_time,test_r2,train_r2,test_neg_mean_squared_error,train_neg_mean_squared_error
0,0.052002,0.001,0.920915,0.985006,-6.7254,-1.252209
1,0.033003,0.001,0.912648,0.985257,-7.211364,-1.240448
2,0.031998,0.001001,0.821904,0.983943,-15.084663,-1.342388
3,0.032,0.001,0.905489,0.983465,-7.969533,-1.383897
4,0.063005,0.001,0.911381,0.984256,-7.308974,-1.32496
