In [26]:
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit
from sklearn.datasets import make_regression, make_classification
import numpy as np
import pandas as pd


# Shared configuration: RNG seed and the size of the synthetic dataset.
seed = 23
n_samples = 10_000
n_features = 10
# Legacy NumPy RandomState seeded from `seed`. The instance is stateful:
# every consumer that receives `rng` advances the same random stream.
rng = np.random.RandomState(seed=seed)

In [9]:
# Generate the synthetic regression dataset.
# The integer `seed` is passed instead of the shared `rng` instance so that
# this cell is idempotent: re-running it (or running cells out of order)
# always yields the same X and y, and it no longer advances a random stream
# that other cells also consume. On a fresh kernel the data are unchanged,
# because scikit-learn turns an int seed into RandomState(seed) internally.
X, y = make_regression(
    n_samples=n_samples,
    n_features=n_features,
    random_state=seed,
)

In [10]:
# Inspect the feature matrix returned by make_regression.
X

array([[-4.58239332e-01, -2.31174144e-01,  1.48403992e+00, ...,
         1.04525677e+00,  9.35572992e-02,  1.95460397e+00],
       [ 8.42035387e-02,  3.26011179e-01, -1.04484218e-01, ...,
        -8.76047552e-01,  1.00626734e+00, -2.42725658e-01],
       [-1.30822891e-02,  5.88473314e-01, -1.41204680e+00, ...,
        -4.31196128e-01, -1.26424982e+00, -3.88891262e-01],
       ...,
       [ 1.59467201e+00,  3.05973375e+00,  2.79991984e+00, ...,
         1.29513239e-01,  1.25824178e+00,  1.84573710e+00],
       [-4.80649163e-01,  1.72656540e+00,  1.04270320e+00, ...,
        -3.40189476e-01, -3.11752701e-01, -1.99238970e-01],
       [ 1.92684959e+00, -2.87173177e+00,  1.39559117e+00, ...,
        -3.97895336e-04, -2.78057779e-02,  3.40758638e-01]])

In [11]:
# Inspect the continuous regression target returned by make_regression.
y

array([ 113.66270029,  107.99109493, -442.82334839, ...,  346.19944174,
         48.77139234, -383.54835383])

In [12]:
# Peek at the first two samples, rounded to two decimals for readability.
print(f"First two rows of data:\n {X[:2].round(2)}")

First two rows of data:
 [[-0.46 -0.23  1.48 -0.31 -0.48  0.94  0.16  1.05  0.09  1.95]
 [ 0.08  0.33 -0.1   0.2   0.54 -0.83  0.68 -0.88  1.01 -0.24]]


In [15]:
# Wrap the target array in a pandas Series so it can be summarised and
# binned with pandas tools in the following cells.
y_ = pd.Series(y)
y_

0       113.662700
1       107.991095
2      -442.823348
3       104.462599
4      -195.337412
           ...    
9995   -251.593463
9996   -118.858454
9997    346.199442
9998     48.771392
9999   -383.548354
Length: 10000, dtype: float64

In [16]:
# Summary statistics (count, mean, std, quartiles, min/max) of the target.
y_.describe()

count    10000.000000
mean        -5.050810
std        222.434158
min       -813.574055
25%       -152.195532
50%         -3.890952
75%        146.296339
max        908.739307
dtype: float64

In [17]:
# Discretise the continuous target into equal-width bins labelled
# 0..n_bins-1 so it can act as the stratification key for the split.
# `bins` receives the bin edges computed by pd.cut (retbins=True).
n_bins = 10
y_cat, bins = pd.cut(
    y_,
    bins=n_bins,
    labels=range(n_bins),
    retbins=True,
)

In [19]:
# Show how many samples fall into each bin label.
print(y_cat.value_counts())

4    3014
5    2519
3    1956
6    1194
2     767
7     293
1     178
8      55
0      19
9       5
dtype: int64


In [22]:
# Wrap the feature matrix in a DataFrame (default integer row/column labels).
X_ = pd.DataFrame(X)

In [23]:
# Display the feature DataFrame.
X_

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
0,-0.458239,-0.231174,1.484040,-0.309721,-0.480573,0.935285,0.157378,1.045257,0.093557,1.954604
1,0.084204,0.326011,-0.104484,0.198806,0.537453,-0.829147,0.680906,-0.876048,1.006267,-0.242726
2,-0.013082,0.588473,-1.412047,-0.614927,-0.517049,-1.022472,-0.925048,-0.431196,-1.264250,-0.388891
3,-0.302820,-0.797809,0.592562,0.052309,-1.120761,0.752672,1.447970,0.329184,0.127921,0.437623
4,-2.557962,-1.567870,0.582708,1.387223,-1.450129,1.258563,1.046836,-2.101903,-0.887932,0.947598
...,...,...,...,...,...,...,...,...,...,...
9995,0.317593,-1.515174,-0.020282,-1.059574,0.558489,0.388389,-0.339238,0.143179,-2.208290,0.227925
9996,0.354165,0.709040,0.632137,-0.774783,-0.673294,-1.083729,0.934332,0.893051,-1.111011,-0.431090
9997,1.594672,3.059734,2.799920,-0.344799,-0.199630,-0.383699,0.774620,0.129513,1.258242,1.845737
9998,-0.480649,1.726565,1.042703,1.392867,-0.731790,-0.745989,0.595309,-0.340189,-0.311753,-0.199239


In [21]:
# Per-column summary statistics of the features.
X_.describe()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
count,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0
mean,-0.010624,-0.005884,1.2e-05,-0.001914,-0.026992,-0.011965,0.008083,-0.018381,-0.000722,-0.002267
std,1.014159,0.993799,1.000025,0.992644,0.99753,0.999088,0.993898,0.991265,1.020129,1.009748
min,-4.152989,-4.172294,-3.960401,-3.826886,-4.422657,-4.108991,-4.01581,-4.287271,-3.625988,-4.045573
25%,-0.683566,-0.675737,-0.673139,-0.678429,-0.697577,-0.699026,-0.659753,-0.67823,-0.690844,-0.68674
50%,-0.005726,0.00264,-0.004549,-0.000145,-0.025879,-0.018316,0.006912,-0.012854,0.004401,-0.00257
75%,0.673601,0.662853,0.67143,0.677847,0.649557,0.669323,0.661144,0.654594,0.696231,0.680064
max,4.521074,3.361317,4.589343,3.70346,3.823261,4.067649,3.916595,3.623694,3.929562,3.5687


In [27]:
# Single stratified 80/20 shuffle split.
# The integer `seed` is used instead of the shared `rng` instance: a
# RandomState object is stateful, so the resulting split would depend on how
# many random draws other cells had already taken from it (i.e. on cell
# execution order). An int seed gives the same split on every run.
split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=seed)

In [28]:
# `split.split` yields (train, test) arrays of *positional* row indices,
# stratified on the binned target `y_cat`. The splitter is configured with
# n_splits=1, so take the single pair directly rather than looping.
train_index, test_index = next(split.split(X, y_cat))

# Positional indices must be applied with .iloc. Label-based .loc only
# returns the same rows while X_ and y_cat keep their default RangeIndex,
# and would silently select wrong rows after any filtering or re-indexing.
strat_train_set = X_.iloc[train_index]
strat_y_train = y_cat.iloc[train_index]
strat_test_set = X_.iloc[test_index]
strat_y_test = y_cat.iloc[test_index]

In [29]:
# Inspect the training-set features selected by the stratified split.
strat_train_set

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
5498,-0.722896,-2.167057,-0.828859,-0.297737,0.394178,-0.864133,-0.925511,0.364701,0.948596,-1.121018
4301,0.438987,-1.972582,-0.680537,0.941492,-0.288819,0.192195,-0.101067,-0.141447,-1.151059,0.049520
3215,0.719472,-1.294699,0.129087,-0.163816,-0.180900,-1.534752,-1.306688,0.844039,-0.156383,0.821687
5991,0.604157,0.042875,-1.640517,0.177905,-0.506002,1.140042,0.220835,-1.229406,-1.003975,0.807004
8745,2.079151,0.351218,0.844595,-0.408200,0.574743,0.735139,0.762412,-0.727781,1.583575,0.428892
...,...,...,...,...,...,...,...,...,...,...
277,0.714483,0.516366,-0.585050,-0.544144,0.841181,-0.555988,-0.303443,-0.819769,0.993353,0.786345
4415,0.339498,0.744561,1.552490,-0.913270,0.437374,-0.664930,-0.387265,0.616102,0.888187,2.606137
2233,-0.798417,-0.283834,-0.878529,0.811151,-0.444508,1.408542,-0.690978,0.147137,0.273363,0.410989
944,0.178873,0.348849,-1.331275,-2.259645,0.351337,-0.100015,1.665513,-0.260179,-0.478667,-3.040155


In [32]:
# Inspect the training-set labels (binned target categories).
strat_y_train

5498    3
4301    4
3215    3
5991    4
8745    6
       ..
277     4
4415    5
2233    4
944     3
7699    4
Length: 8000, dtype: category
Categories (10, int64): [0 < 1 < 2 < 3 ... 6 < 7 < 8 < 9]

## `y_cat` holds our target values (the binned regression target)