In [2]:
import numpy as np
import pandas as pd
from sklearn import datasets
from sklearn import model_selection

In [69]:
def create_folds(data):
    data['kfold'] = -1
    
    data = data.sample(frac=1).reset_index(drop=True)
    
    num_bins = np.floor(1 + np.log2(len(data)))
    
    data.loc[:, 'bins'] = pd.cut(
        data['target'], bins = int(num_bins), labels = False
    )
    kf = model_selection.StratifiedKFold(n_splits = 5)
    for f, (t_, v_) in enumerate(kf.split(X=data, y = data.bins.values)):
        print(f), print(' ')
        print(t_), print(' ')
        print(v_), print(' ')
        data.loc[v_, 'kfold'] = f
    
    data = data.drop('bins', axis = 1)
    
    return data

In [70]:
X, y = datasets.make_regression( n_samples = 15000, n_features = 100, n_targets = 1)
df = pd.DataFrame(X, columns = [f'f_{i}' for i in range(X.shape[1])])
df.loc[:, 'target'] = y

In [71]:
df.head()

Unnamed: 0,f_0,f_1,f_2,f_3,f_4,f_5,f_6,f_7,f_8,f_9,...,f_91,f_92,f_93,f_94,f_95,f_96,f_97,f_98,f_99,target
0,-1.484328,1.182369,-0.574792,-0.615991,0.51405,0.114719,-0.124999,-0.213301,0.178923,-1.044408,...,-0.469886,-0.375407,1.19308,3.014717,0.28456,-0.073605,-0.912613,0.948379,-1.03681,3.024093
1,0.19311,0.167554,1.232662,0.118989,0.413704,-1.188519,-0.015786,-1.751908,-1.272523,-0.154628,...,-0.006736,-0.723223,1.152635,1.449803,-0.426057,-1.859479,1.558751,-1.940464,0.198883,-30.831905
2,0.655615,-0.458892,-0.690228,0.532104,0.111328,-0.671696,-1.452936,1.203722,0.447595,0.208076,...,2.197944,0.749309,0.498138,0.155201,-1.589911,0.11409,0.579026,-0.67349,0.378613,-37.02112
3,2.857441,1.30613,-0.867378,0.117355,-0.511969,0.364917,0.657136,-2.088921,-0.54597,1.946379,...,-0.70631,-3.060594,0.228974,-2.018923,1.277559,-0.985244,-0.814621,1.118917,-0.777233,243.374892
4,-1.869826,0.005379,1.331509,-1.147966,0.967723,-1.176837,-0.052228,0.12861,-0.36532,-0.653881,...,0.148805,-0.325057,0.686479,2.318396,-0.610475,-0.333273,-1.975015,-0.41152,2.56203,-158.964849


In [78]:
df_1 = create_folds(df)

0
 
[ 2372  2830  2833 ... 14997 14998 14999]
 
[   0    1    2 ... 3222 3275 3394]
 
1
 
[    0     1     2 ... 14997 14998 14999]
 
[2372 2830 2833 ... 6504 6519 6842]
 
2
 
[    0     1     2 ... 14997 14998 14999]
 
[ 4366  5541  5574 ... 10246 10255 10439]
 
3
 
[    0     1     2 ... 14997 14998 14999]
 
[ 5974  8632  8786 ... 12400 12424 12569]
 
4
 
[    0     1     2 ... 12400 12424 12569]
 
[ 6467  9904 11703 ... 14997 14998 14999]
 


In [79]:
df_1.tail(30)

Unnamed: 0,f_0,f_1,f_2,f_3,f_4,f_5,f_6,f_7,f_8,f_9,...,f_92,f_93,f_94,f_95,f_96,f_97,f_98,f_99,target,kfold
14970,-0.073409,1.935839,2.815946,-0.040539,-0.177488,0.35447,-0.405054,0.721314,-1.576358,-0.578664,...,1.621799,-0.324799,-0.385271,1.900062,0.983945,-0.002397,-1.266173,2.512389,-4.913091,4
14971,-0.183316,0.085383,0.467468,-1.255173,-0.167519,0.404414,1.975946,0.671486,0.405693,2.004559,...,1.96039,0.618714,-1.38353,-0.80013,1.554795,0.657643,0.550875,-0.536484,144.121896,4
14972,-0.837978,0.065367,0.578595,1.084332,0.202756,0.213729,-0.325475,1.281426,-0.645752,0.107934,...,0.385044,1.117194,-0.053291,-0.658643,-0.344624,0.574285,1.749406,0.536858,-31.436541,4
14973,-1.438704,1.120182,0.52437,-0.940348,1.693549,0.632267,0.015285,-1.102906,0.864716,-0.777419,...,1.179465,0.640164,-0.39653,0.935326,-1.571516,0.348231,2.105435,-0.856301,-145.266273,4
14974,-1.825863,-0.295549,-0.602726,-1.139397,0.282674,0.336064,0.582661,-0.059644,0.093835,1.919468,...,0.524582,-1.528883,-0.349817,0.846935,1.690331,0.584405,0.154899,0.86051,-128.392516,4
14975,-0.285147,0.617213,0.956291,0.276001,-0.15419,-0.452725,0.452454,-1.572187,-1.408494,0.243112,...,0.534993,-0.186079,1.659598,-0.082719,0.117717,0.924834,0.255039,0.007507,175.668006,4
14976,0.77193,-0.669553,1.59464,0.272783,3.171255,1.279284,-0.064103,-1.138377,-0.478886,-0.839825,...,0.664863,0.177619,0.172181,0.770331,0.795984,-1.930706,0.45862,-2.081792,-295.457729,4
14977,-0.076034,-0.894372,-0.492656,-0.045238,0.589584,0.084541,-0.442672,-0.346102,-0.675127,-2.030113,...,0.77183,0.239702,-0.410019,-0.364564,0.437983,1.948024,-0.178698,-0.814335,28.813154,4
14978,-1.788061,0.454306,0.595587,0.33084,0.143994,-2.313516,-0.452894,-0.415574,-0.442174,0.302479,...,-1.260758,1.54195,-0.550306,1.056369,0.734458,0.297625,0.363205,-0.707953,-231.271735,4
14979,0.432027,-0.543121,0.016403,-0.328724,-0.701942,-1.255637,1.184327,-1.916637,-0.48538,0.553426,...,0.738095,0.167206,0.85809,-1.055052,-0.594399,1.897908,0.155332,0.761156,63.361407,4


In [80]:
df_1 = df_1.sample(frac=1).reset_index()

In [81]:
df_1.head()

Unnamed: 0,index,f_0,f_1,f_2,f_3,f_4,f_5,f_6,f_7,f_8,...,f_92,f_93,f_94,f_95,f_96,f_97,f_98,f_99,target,kfold
0,5736,-0.949732,2.648627,-0.956029,-0.659839,0.001028,0.368446,-0.649001,1.010351,-0.266103,...,-0.164147,-1.806141,1.664402,-0.803973,-1.483595,-1.962491,0.009401,-0.603869,-60.918021,1
1,3897,-0.666421,-1.634138,0.711889,-0.361687,1.265688,-0.00118,0.933482,-1.039037,0.047661,...,1.275942,-0.635454,-0.612535,-0.190689,-0.685433,-0.243252,-0.249976,-0.469559,141.110046,1
2,12385,0.422896,2.775942,1.39576,0.006723,0.095657,3.444344,-1.237212,-0.477181,0.658501,...,-1.195856,-1.664252,2.699797,1.095726,-1.105466,0.557212,0.565329,1.421766,364.577941,3
3,10553,-0.450566,-0.75219,-0.311539,0.48403,0.383313,-1.69504,-0.407795,0.564327,-0.652386,...,-0.862184,2.477488,1.673945,-0.054777,2.595066,1.421216,-0.314818,-0.806506,-198.441232,3
4,3731,-2.210837,-0.260171,-0.135624,-1.222982,-0.284386,-0.096053,1.086329,-0.052931,-0.010976,...,0.539758,0.152914,0.122563,-0.189758,-0.87958,0.935963,0.225654,-1.581781,72.961683,1
