In [1]:
# stratified-kfold for regression 
import numpy as np 
import pandas as pd 
 
from sklearn import datasets 
from sklearn import model_selection 

In [2]:
# we create a sample dataset with 15000 samples  
# and 100 features and 1 target 
X, y = datasets.make_regression( 
    n_samples=15000, n_features=100, n_targets=1 
) 
 
# create a dataframe out of our numpy arrays 
df = pd.DataFrame( 
    X, 
    columns=[f"f_{i}" for i in range(X.shape[1])] 
) 
df.loc[:, "target"] = y 

In [3]:
df

Unnamed: 0,f_0,f_1,f_2,f_3,f_4,f_5,f_6,f_7,f_8,f_9,...,f_91,f_92,f_93,f_94,f_95,f_96,f_97,f_98,f_99,target
0,0.223970,-1.664681,0.428109,-0.110086,-1.227328,-0.977546,0.597022,0.499344,-0.919195,-0.554525,...,-0.471622,-0.826717,-0.570486,0.097506,0.456067,0.919311,-0.609127,1.156041,-1.294603,-199.966531
1,-2.134054,-1.359385,0.870963,1.322919,0.163915,-0.338145,-0.953794,0.675151,-1.160180,1.867907,...,-1.411619,0.696711,-1.562239,-0.545749,-0.755881,-0.782515,-0.899546,0.105559,1.129632,123.432316
2,-0.190019,0.097398,0.697805,-1.120048,-0.254720,-1.055490,0.242970,0.867908,-1.207903,0.432400,...,-1.376247,-2.341474,-1.136189,-0.695651,0.437186,-0.235150,-0.563329,0.214958,0.046323,-273.838806
3,0.256697,-0.529162,0.708929,0.649498,1.370773,-0.378014,0.513940,-0.555078,2.089526,-1.001197,...,-0.230791,0.248942,-0.709519,0.375288,-0.718326,0.681851,-0.623601,3.344169,-0.280413,-41.063996
4,-0.610768,1.789506,-1.205058,-1.879438,0.783265,0.228100,0.244645,1.560855,0.434679,1.137960,...,0.805311,-1.067458,-0.916404,-2.145369,0.160943,0.405452,0.533245,1.368570,-0.723004,-274.555563
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
14995,0.839829,-0.310830,0.343549,-1.194158,0.267259,-0.064102,-1.027815,0.323880,0.749266,-0.239301,...,0.033154,0.703723,-1.248646,1.215352,1.247183,-0.732674,-0.025275,0.581372,-0.733778,-228.840728
14996,-0.243759,0.560275,-0.476375,1.189348,0.407169,0.237948,0.401241,0.159243,-0.251356,-0.833833,...,0.800656,-0.322952,0.336633,1.149071,0.411646,-0.654955,1.678458,-0.561241,-0.530203,147.359785
14997,0.226550,-0.140586,-1.012690,0.447480,-0.702770,-0.641372,0.059138,1.675124,0.295142,-0.736423,...,-0.633563,0.025168,-1.589290,-0.186061,0.158419,0.698538,0.454031,-0.229974,1.468166,-120.558863
14998,-0.413925,1.709374,0.795769,0.016329,-1.540249,-2.310901,-0.065999,-0.357786,0.255153,-0.738194,...,-0.193731,0.829684,0.278808,0.229010,0.329094,1.542992,-0.409581,0.401992,-0.946111,197.694253


In [5]:
# we create a new column called kfold and fill it with -1 
df["kfold"] = -1 

In [6]:
df

Unnamed: 0,f_0,f_1,f_2,f_3,f_4,f_5,f_6,f_7,f_8,f_9,...,f_92,f_93,f_94,f_95,f_96,f_97,f_98,f_99,target,kfold
0,0.223970,-1.664681,0.428109,-0.110086,-1.227328,-0.977546,0.597022,0.499344,-0.919195,-0.554525,...,-0.826717,-0.570486,0.097506,0.456067,0.919311,-0.609127,1.156041,-1.294603,-199.966531,-1
1,-2.134054,-1.359385,0.870963,1.322919,0.163915,-0.338145,-0.953794,0.675151,-1.160180,1.867907,...,0.696711,-1.562239,-0.545749,-0.755881,-0.782515,-0.899546,0.105559,1.129632,123.432316,-1
2,-0.190019,0.097398,0.697805,-1.120048,-0.254720,-1.055490,0.242970,0.867908,-1.207903,0.432400,...,-2.341474,-1.136189,-0.695651,0.437186,-0.235150,-0.563329,0.214958,0.046323,-273.838806,-1
3,0.256697,-0.529162,0.708929,0.649498,1.370773,-0.378014,0.513940,-0.555078,2.089526,-1.001197,...,0.248942,-0.709519,0.375288,-0.718326,0.681851,-0.623601,3.344169,-0.280413,-41.063996,-1
4,-0.610768,1.789506,-1.205058,-1.879438,0.783265,0.228100,0.244645,1.560855,0.434679,1.137960,...,-1.067458,-0.916404,-2.145369,0.160943,0.405452,0.533245,1.368570,-0.723004,-274.555563,-1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
14995,0.839829,-0.310830,0.343549,-1.194158,0.267259,-0.064102,-1.027815,0.323880,0.749266,-0.239301,...,0.703723,-1.248646,1.215352,1.247183,-0.732674,-0.025275,0.581372,-0.733778,-228.840728,-1
14996,-0.243759,0.560275,-0.476375,1.189348,0.407169,0.237948,0.401241,0.159243,-0.251356,-0.833833,...,-0.322952,0.336633,1.149071,0.411646,-0.654955,1.678458,-0.561241,-0.530203,147.359785,-1
14997,0.226550,-0.140586,-1.012690,0.447480,-0.702770,-0.641372,0.059138,1.675124,0.295142,-0.736423,...,0.025168,-1.589290,-0.186061,0.158419,0.698538,0.454031,-0.229974,1.468166,-120.558863,-1
14998,-0.413925,1.709374,0.795769,0.016329,-1.540249,-2.310901,-0.065999,-0.357786,0.255153,-0.738194,...,0.829684,0.278808,0.229010,0.329094,1.542992,-0.409581,0.401992,-0.946111,197.694253,-1


In [7]:
# the next step is to randomize the rows of the data 
df = df.sample(frac=1).reset_index(drop=True) 

In [8]:
df

Unnamed: 0,f_0,f_1,f_2,f_3,f_4,f_5,f_6,f_7,f_8,f_9,...,f_92,f_93,f_94,f_95,f_96,f_97,f_98,f_99,target,kfold
0,-1.871023,0.002347,-0.100897,-0.699443,-0.258656,-0.675550,-1.944814,0.637522,1.883547,0.209237,...,0.335607,0.331743,-0.408142,0.800736,0.319487,0.806897,0.563989,-1.942365,-143.374443,-1
1,0.243421,1.246340,-3.795698,-1.112081,1.978044,0.733066,0.165492,-0.757500,1.203836,1.924050,...,-0.471518,1.742315,1.281713,-0.299137,0.853786,0.097654,-0.951888,0.121161,-32.620309,-1
2,-0.746145,0.445028,0.501101,0.737003,0.305016,-0.229007,-2.606008,1.353098,-1.081187,0.900805,...,1.122604,-1.150231,-0.958864,-0.958121,-0.816682,-1.265704,3.666872,0.447208,-34.447372,-1
3,0.134426,0.668969,-0.286982,-1.665236,2.059995,-1.559147,-0.205234,-0.694217,0.519911,-0.845182,...,2.070275,-0.122821,-1.162960,0.378726,0.442660,-1.370974,1.517360,0.834702,-125.332504,-1
4,0.371534,-0.400097,-0.190660,1.193273,-0.981540,1.608261,0.966934,2.166185,-0.496196,-0.442643,...,0.862871,-0.981960,0.526344,0.132604,-0.506084,-0.963304,2.484995,-0.831349,312.020720,-1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
14995,-0.694858,2.072531,0.743673,-0.286398,0.553287,-0.955437,-1.433266,1.098900,-2.176012,-2.672660,...,-0.707683,-0.972383,0.134028,-1.145988,0.446997,0.495582,0.226269,-0.674155,-98.241490,-1
14996,-2.079279,1.121122,0.173969,-0.070473,1.182968,0.935029,-1.202881,-2.795834,-1.395464,0.168294,...,-1.975873,0.222066,0.608003,-0.383162,-1.189989,-1.035347,0.731851,-0.764954,46.488710,-1
14997,-0.274027,-1.010586,-3.133371,1.034192,0.505735,-0.337084,0.463254,-1.644425,-0.386557,-0.323093,...,-0.995583,-0.738850,-0.230283,-1.213200,-0.365869,-0.893766,-0.217317,1.015299,-52.401252,-1
14998,0.388453,0.293064,-0.102664,0.424681,-0.757930,-2.168874,-0.819553,-0.816970,-0.531885,1.826212,...,0.845327,1.253695,-0.142836,-0.642236,-0.344600,-0.110986,0.530337,0.602445,50.575459,-1


In [10]:
# calculate the number of bins by Sturge's rule 
# I take the floor of the value, you can also 
# just round it 
num_bins = np.floor(1 + np.log2(len(df))) 

In [11]:
num_bins

14.0

In [13]:
# bin targets 
df.loc[:, "bins"] = pd.cut( 
    df["target"], bins=num_bins, labels=False
)

In [14]:
df

Unnamed: 0,f_0,f_1,f_2,f_3,f_4,f_5,f_6,f_7,f_8,f_9,...,f_93,f_94,f_95,f_96,f_97,f_98,f_99,target,kfold,bins
0,-1.871023,0.002347,-0.100897,-0.699443,-0.258656,-0.675550,-1.944814,0.637522,1.883547,0.209237,...,0.331743,-0.408142,0.800736,0.319487,0.806897,0.563989,-1.942365,-143.374443,-1,5
1,0.243421,1.246340,-3.795698,-1.112081,1.978044,0.733066,0.165492,-0.757500,1.203836,1.924050,...,1.742315,1.281713,-0.299137,0.853786,0.097654,-0.951888,0.121161,-32.620309,-1,6
2,-0.746145,0.445028,0.501101,0.737003,0.305016,-0.229007,-2.606008,1.353098,-1.081187,0.900805,...,-1.150231,-0.958864,-0.958121,-0.816682,-1.265704,3.666872,0.447208,-34.447372,-1,6
3,0.134426,0.668969,-0.286982,-1.665236,2.059995,-1.559147,-0.205234,-0.694217,0.519911,-0.845182,...,-0.122821,-1.162960,0.378726,0.442660,-1.370974,1.517360,0.834702,-125.332504,-1,5
4,0.371534,-0.400097,-0.190660,1.193273,-0.981540,1.608261,0.966934,2.166185,-0.496196,-0.442643,...,-0.981960,0.526344,0.132604,-0.506084,-0.963304,2.484995,-0.831349,312.020720,-1,10
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
14995,-0.694858,2.072531,0.743673,-0.286398,0.553287,-0.955437,-1.433266,1.098900,-2.176012,-2.672660,...,-0.972383,0.134028,-1.145988,0.446997,0.495582,0.226269,-0.674155,-98.241490,-1,6
14996,-2.079279,1.121122,0.173969,-0.070473,1.182968,0.935029,-1.202881,-2.795834,-1.395464,0.168294,...,0.222066,0.608003,-0.383162,-1.189989,-1.035347,0.731851,-0.764954,46.488710,-1,7
14997,-0.274027,-1.010586,-3.133371,1.034192,0.505735,-0.337084,0.463254,-1.644425,-0.386557,-0.323093,...,-0.738850,-0.230283,-1.213200,-0.365869,-0.893766,-0.217317,1.015299,-52.401252,-1,6
14998,0.388453,0.293064,-0.102664,0.424681,-0.757930,-2.168874,-0.819553,-0.816970,-0.531885,1.826212,...,1.253695,-0.142836,-0.642236,-0.344600,-0.110986,0.530337,0.602445,50.575459,-1,7


In [15]:
# initiate the kfold class from model_selection module 
kf = model_selection.StratifiedKFold(n_splits=5) 

In [16]:
# fill the new kfold column 
# note that, instead of targets, we use bins! 
for f, (t_, v_) in enumerate(kf.split(X=df, y=df.bins.values)): 
    df.loc[v_, 'kfold'] = f 



In [17]:
df

Unnamed: 0,f_0,f_1,f_2,f_3,f_4,f_5,f_6,f_7,f_8,f_9,...,f_93,f_94,f_95,f_96,f_97,f_98,f_99,target,kfold,bins
0,-1.871023,0.002347,-0.100897,-0.699443,-0.258656,-0.675550,-1.944814,0.637522,1.883547,0.209237,...,0.331743,-0.408142,0.800736,0.319487,0.806897,0.563989,-1.942365,-143.374443,0,5
1,0.243421,1.246340,-3.795698,-1.112081,1.978044,0.733066,0.165492,-0.757500,1.203836,1.924050,...,1.742315,1.281713,-0.299137,0.853786,0.097654,-0.951888,0.121161,-32.620309,0,6
2,-0.746145,0.445028,0.501101,0.737003,0.305016,-0.229007,-2.606008,1.353098,-1.081187,0.900805,...,-1.150231,-0.958864,-0.958121,-0.816682,-1.265704,3.666872,0.447208,-34.447372,0,6
3,0.134426,0.668969,-0.286982,-1.665236,2.059995,-1.559147,-0.205234,-0.694217,0.519911,-0.845182,...,-0.122821,-1.162960,0.378726,0.442660,-1.370974,1.517360,0.834702,-125.332504,0,5
4,0.371534,-0.400097,-0.190660,1.193273,-0.981540,1.608261,0.966934,2.166185,-0.496196,-0.442643,...,-0.981960,0.526344,0.132604,-0.506084,-0.963304,2.484995,-0.831349,312.020720,0,10
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
14995,-0.694858,2.072531,0.743673,-0.286398,0.553287,-0.955437,-1.433266,1.098900,-2.176012,-2.672660,...,-0.972383,0.134028,-1.145988,0.446997,0.495582,0.226269,-0.674155,-98.241490,4,6
14996,-2.079279,1.121122,0.173969,-0.070473,1.182968,0.935029,-1.202881,-2.795834,-1.395464,0.168294,...,0.222066,0.608003,-0.383162,-1.189989,-1.035347,0.731851,-0.764954,46.488710,4,7
14997,-0.274027,-1.010586,-3.133371,1.034192,0.505735,-0.337084,0.463254,-1.644425,-0.386557,-0.323093,...,-0.738850,-0.230283,-1.213200,-0.365869,-0.893766,-0.217317,1.015299,-52.401252,4,6
14998,0.388453,0.293064,-0.102664,0.424681,-0.757930,-2.168874,-0.819553,-0.816970,-0.531885,1.826212,...,1.253695,-0.142836,-0.642236,-0.344600,-0.110986,0.530337,0.602445,50.575459,4,7


In [18]:
df.to_csv('df1.csv')