In [2]:
import warnings
warnings.filterwarnings('ignore')

In [3]:
# Stratified K-Fold for Regression
import numpy as np
import pandas as pd

from sklearn import datasets, model_selection

In [8]:
def create_folds(data, n_splits=5, target_col="target", random_state=None):
    """Assign stratified K-fold ids to a regression dataset.

    Regression targets are continuous, so they are first discretised into
    equal-width bins and ``StratifiedKFold`` stratifies on the bin labels.
    The number of bins follows Sturges' rule, ``floor(1 + log2(n_rows))``.
    This keeps the target distribution roughly similar across folds.

    Parameters
    ----------
    data : pandas.DataFrame
        Input rows; must contain ``target_col``. It is not modified.
    n_splits : int, default 5
        Number of folds.
    target_col : str, default "target"
        Column holding the continuous target to stratify on.
    random_state : int, numpy RandomState/Generator, or None, default None
        Seed for the row shuffle (passed to ``DataFrame.sample``).
        ``None`` uses NumPy's global random state, as before.

    Returns
    -------
    pandas.DataFrame
        A shuffled copy of ``data`` with a fresh ``RangeIndex`` and an integer
        ``kfold`` column holding values in ``[0, n_splits)``.

    Raises
    ------
    ValueError
        If ``data`` has fewer rows than ``n_splits``.
    """
    if len(data) < n_splits:
        raise ValueError(
            f"need at least n_splits={n_splits} rows to build folds, got {len(data)}"
        )

    # `sample` returns a new frame, so the caller's `data` is never mutated.
    # reset_index(drop=True) makes index labels equal row positions, which the
    # label-based `.loc[valid_idx, ...]` assignment below relies on.
    data = data.sample(frac=1, random_state=random_state).reset_index(drop=True)
    data["kfold"] = -1  # placeholder; every row is overwritten by exactly one fold

    # Sturges' rule for the number of equal-width target bins.
    num_bins = int(np.floor(1 + np.log2(len(data))))
    # Keep bin labels in a local array rather than a temporary column, so an
    # existing "bins" column in the caller's data is neither overwritten nor dropped.
    # Sparse tail bins may trigger StratifiedKFold's "least populated class"
    # UserWarning; folds are still produced.
    bins = pd.cut(data[target_col], bins=num_bins, labels=False).to_numpy()

    kf = model_selection.StratifiedKFold(n_splits=n_splits)
    for fold, (_, valid_idx) in enumerate(kf.split(X=data, y=bins)):
        data.loc[valid_idx, "kfold"] = fold

    return data

In [9]:
if __name__ == "__main__":  # always true inside a Jupyter kernel
    # Fixed seed so Restart-Kernel -> Run-All reproduces the same data and folds.
    # Seeding NumPy's global RNG also fixes the `DataFrame.sample` shuffle inside
    # `create_folds`, which uses the global state when no random_state is passed.
    SEED = 42
    N_SAMPLES = 15000
    N_FEATURES = 100

    np.random.seed(SEED)
    X, y = datasets.make_regression(
        n_samples=N_SAMPLES, n_features=N_FEATURES, n_targets=1, random_state=SEED
    )

    # Feature columns are named f_0 .. f_{N_FEATURES - 1}; the target goes last.
    df = pd.DataFrame(X, columns=[f"f_{i}" for i in range(X.shape[1])])
    df["target"] = y

    df = create_folds(df)

In [10]:
# Preview a few rows: feature columns f_0..f_99, the target, and the assigned fold.
df.head(n=5)

Unnamed: 0,f_0,f_1,f_2,f_3,f_4,f_5,f_6,f_7,f_8,f_9,...,f_92,f_93,f_94,f_95,f_96,f_97,f_98,f_99,target,kfold
0,-2.247263,1.103322,-0.161552,-1.327398,-0.25352,-2.804632,1.415777,0.606284,0.549696,0.577328,...,-0.813081,-1.322948,-0.52742,0.373597,-0.632935,0.425368,0.272859,-0.091118,152.119413,0
1,0.032676,-0.845902,1.170086,-0.711318,0.359619,0.238567,0.484936,1.488793,1.442507,1.023722,...,0.592769,0.044931,1.165851,1.940168,-0.207212,-0.582139,0.246515,0.13554,-136.799137,0
2,2.760491,-0.189535,0.567314,0.788226,-0.558714,0.679646,-0.285486,-0.810294,-0.328314,-0.906481,...,0.361322,-0.113551,-2.117612,-0.892497,-1.056271,-1.149352,-1.079195,-1.204498,-43.795505,0
3,-2.384299,-1.450349,0.274419,0.503464,-1.452006,-0.211453,-0.802653,-1.027676,0.781841,0.263892,...,-0.529914,0.67964,1.038054,-3.219707,2.895261,0.55778,0.152217,-0.281395,-399.383032,0
4,-1.254776,-0.555948,1.322382,0.544544,0.56304,-0.784334,-0.470122,-0.229424,0.147596,-2.12677,...,1.060788,-0.965381,-2.031138,-1.560985,-0.461603,1.978358,2.19125,-0.008365,105.397653,0


In [11]:
# Sanity check: list the distinct fold ids (in order of first appearance).
# A leftover -1 would mean some row was never assigned to a validation fold.
pd.unique(df["kfold"])

array([0, 1, 2, 3, 4], dtype=int64)