<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Examples-with-different-splitters" data-toc-modified-id="Examples-with-different-splitters-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Examples with different splitters</a></span><ul class="toc-item"><li><span><a href="#Getting-the-data-and-splitting-X-and-y" data-toc-modified-id="Getting-the-data-and-splitting-X-and-y-1.1"><span class="toc-item-num">1.1&nbsp;&nbsp;</span>Getting the data and splitting X and y</a></span></li><li><span><a href="#ShuffleSplit" data-toc-modified-id="ShuffleSplit-1.2"><span class="toc-item-num">1.2&nbsp;&nbsp;</span>ShuffleSplit</a></span></li><li><span><a href="#KFold" data-toc-modified-id="KFold-1.3"><span class="toc-item-num">1.3&nbsp;&nbsp;</span>KFold</a></span></li><li><span><a href="#RepeatedKFold" data-toc-modified-id="RepeatedKFold-1.4"><span class="toc-item-num">1.4&nbsp;&nbsp;</span>RepeatedKFold</a></span></li><li><span><a href="#StratifiedKFold" data-toc-modified-id="StratifiedKFold-1.5"><span class="toc-item-num">1.5&nbsp;&nbsp;</span>StratifiedKFold</a></span></li><li><span><a href="#RepeatedStratifiedKFold" data-toc-modified-id="RepeatedStratifiedKFold-1.6"><span class="toc-item-num">1.6&nbsp;&nbsp;</span>RepeatedStratifiedKFold</a></span></li><li><span><a href="#StratifiedShuffleSplit" data-toc-modified-id="StratifiedShuffleSplit-1.7"><span class="toc-item-num">1.7&nbsp;&nbsp;</span>StratifiedShuffleSplit</a></span></li><li><span><a href="#Classic-train_test_split" data-toc-modified-id="Classic-train_test_split-1.8"><span class="toc-item-num">1.8&nbsp;&nbsp;</span>Classic train_test_split</a></span></li><li><span><a href="#Classic-train_test_split-with-stratify-parameter" data-toc-modified-id="Classic-train_test_split-with-stratify-parameter-1.9"><span class="toc-item-num">1.9&nbsp;&nbsp;</span>Classic train_test_split with <em>stratify</em> parameter</a></span></li></ul></li></ul></div>

This notebook should be read along with <a href="https://medium.com/@mvolkanyurtseven/three-steps-in-case-of-imbalanced-data-and-close-look-at-the-splitter-classes-8b73628a25e6">this blog post</a>.

# Examples with different splitters

In [1]:
import pandas as pd
from sklearn.model_selection import KFold, ShuffleSplit, StratifiedKFold, StratifiedShuffleSplit, RepeatedKFold, RepeatedStratifiedKFold, train_test_split

In [2]:
#generic function we'll use for all the splitter classes:
#for each split, print the class proportions and the indices of the test set
def showSplits(splitter,X,y):
    for i, (train_index, test_index) in enumerate(splitter.split(X, y)):
        print(f"split no: {i}")
        y_test = y.iloc[test_index]  # positional indexing, independent of the Series' index labels
        print(y_test.value_counts(normalize=True),end="\n\n")
        print(test_index,end="\n\n--------------------------------\n\n")

## Getting the data and splitting X and y

In [3]:
df=pd.read_excel(r"E:\OneDrive\Dökümanlar\GitHub\dataset\traintestsplitdata.xlsx")
#you can find the data at https://github.com/VolkiTheDreamer/dataset/ (adjust the path above to your local copy)

In [4]:
df["Label"].value_counts()

0    990
1     10
Name: Label, dtype: int64

In [5]:
#only the last 10 items have label 1
df[df["Label"]==1]

     Feature  Label
990     f991      1
991     f992      1
992     f993      1
993     f994      1
994     f995      1
995     f996      1
996     f997      1
997     f998      1
998     f999      1
999    f1000      1


In [6]:
X=df.iloc[:,0]
y=df.iloc[:,1]

For each splitter class, we generate several train/test splits, varying the **random_state** or **shuffle** parameters, and print the class proportions and test indices of every split. Observe how the indices change (or don't).
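As a minimal sketch of what `random_state` controls, the cell below builds a synthetic stand-in for the dataset (an assumption: 990 rows of class 0 followed by 10 rows of class 1, matching the counts above; `X_demo`, `y_demo` and `collect_test_indices` are illustrative names, not part of the original notebook) and compares the test indices that `ShuffleSplit` produces for equal and for different integer seeds.

```python
import numpy as np
from sklearn.model_selection import ShuffleSplit

# Synthetic stand-in for the notebook's data (assumption): 990 rows of class 0, then 10 of class 1
y_demo = np.array([0] * 990 + [1] * 10)
X_demo = np.arange(len(y_demo)).reshape(-1, 1)


def collect_test_indices(splitter):
    """Hypothetical helper: return the test indices of every split as a list of arrays."""
    return [test for _, test in splitter.split(X_demo, y_demo)]


run_a = collect_test_indices(ShuffleSplit(n_splits=3, test_size=0.1, random_state=0))
run_b = collect_test_indices(ShuffleSplit(n_splits=3, test_size=0.1, random_state=0))
run_c = collect_test_indices(ShuffleSplit(n_splits=3, test_size=0.1, random_state=1))

# same integer seed -> identical splits; different seed -> different splits
same_seed_identical = all(np.array_equal(a, b) for a, b in zip(run_a, run_b))
other_seed_identical = all(np.array_equal(a, c) for a, c in zip(run_a, run_c))
print(same_seed_identical, other_seed_identical)
```

An integer seed makes the splits reproducible; with `random_state=None` the splitter draws from NumPy's global random state, so the splits change on every run.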

## ShuffleSplit

The dataset is reshuffled before every split and then divided, so, as the documentation notes, the test sets of different splits may overlap.
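To make the overlap concrete, the sketch below (on the same assumed synthetic 990/10 stand-in, with an illustrative helper `count_test_appearances`) counts how many test sets each sample falls into over 10 splits. With `ShuffleSplit` some samples are tested several times and others never, whereas the 10 test folds of `KFold` cover every sample exactly once.

```python
import numpy as np
from sklearn.model_selection import KFold, ShuffleSplit

# Synthetic stand-in (assumption): 1000 samples, 990 of class 0 then 10 of class 1
y_demo = np.array([0] * 990 + [1] * 10)
X_demo = np.arange(len(y_demo)).reshape(-1, 1)


def count_test_appearances(splitter):
    """Hypothetical helper: count how many test sets each sample index lands in."""
    counts = np.zeros(len(y_demo), dtype=int)
    for _, test in splitter.split(X_demo, y_demo):
        counts[test] += 1
    return counts


ss_counts = count_test_appearances(ShuffleSplit(n_splits=10, test_size=0.1, random_state=0))
kf_counts = count_test_appearances(KFold(n_splits=10, shuffle=True, random_state=0))

# ShuffleSplit: some samples appear in several test sets, others in none
print("ShuffleSplit: max", ss_counts.max(), "| never tested:", (ss_counts == 0).sum())
# KFold: every sample is in exactly one test fold
print("KFold: max", kf_counts.max(), "| never tested:", (kf_counts == 0).sum())
```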

In [7]:
ss = ShuffleSplit(n_splits=5, test_size=.1, random_state=0)
showSplits(ss,X,y)

split no: 0
0    0.98
1    0.02
Name: Label, dtype: float64

[993 859 298 553 672 971  27 231 306 706 496 558 784 239 578  55 906 175
  14  77  31 481 310 311 883 788  45 103 760   1 823 710 614 790 408 736
 957 366 918 267 230 996 635 698 251 783 819 141 316 587 331 295 262 432
 862 582 272 270 987 319 569 643 142 202 413 196 264 531 252 576 738 299
 740 247 926 412 389 796 601 654 261 456 386 982 909 693 236 501 497 874
 452 494 923 279 638 485 568 108 367 644]

--------------------------------

split no: 1
0    1.0
Name: Label, dtype: float64

[ 18 939 532 671 506 276 601 845 386 201  84 641 417 586 356 928 166 736
 303 237 711 852 664 628 148 123 342 429 246 795 233 107  27 850  76  17
 687 163  90 390 558 325 787 556 515 798 353 531 463 335 489 145 706 759
 137 127 525 791 774 307 290 424 673  67 772 263 855 810 880 514 701 366
 566 314 825 588 721 584 882  69 895 955 289  53 854 592 675 847 331 816
  80 405 425 519 813  37 981 187 384 459]

--------------------------------

split

## KFold

In [8]:
kf = KFold(n_splits=10,random_state=0,shuffle=True)
showSplits(kf,X,y)

split no: 0
0    0.98
1    0.02
Name: Label, dtype: float64

[  1  14  27  31  45  55  77 103 108 141 142 175 196 202 230 231 236 239
 247 251 252 261 262 264 267 270 272 279 295 298 299 306 310 311 316 319
 331 366 367 386 389 408 412 413 432 452 456 481 485 494 496 497 501 531
 553 558 568 569 576 578 582 587 601 614 635 638 643 644 654 672 693 698
 706 710 736 738 740 760 783 784 788 790 796 819 823 859 862 874 883 906
 909 918 923 926 957 971 982 987 993 996]

--------------------------------

split no: 1
0    0.98
1    0.02
Name: Label, dtype: float64

[  5   8  30  34  54  60  65  97 101 122 144 145 150 204 214 215 240 249
 255 271 278 283 312 315 317 320 351 362 372 378 382 384 385 395 403 418
 424 435 457 458 466 467 483 489 492 495 513 520 523 529 545 580 583 596
 608 615 624 630 636 648 653 666 679 688 713 717 722 751 753 757 767 769
 773 780 785 789 795 826 832 834 844 860 864 873 876 886 902 904 922 927
 939 942 945 946 958 975 986 989 990 999]

----------------------------

In [9]:
#now setting shuffle=False: the folds are consecutive blocks of rows, so no random_state is needed
kf = KFold(n_splits=10,shuffle=False)
showSplits(kf,X,y)

split no: 0
0    1.0
Name: Label, dtype: float64

[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
 96 97 98 99]

--------------------------------

split no: 1
0    1.0
Name: Label, dtype: float64

[100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117
 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153
 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171
 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189
 190 191 192 193 194 195 196 197 198 199]

--------------------------------

split no: 2
0    1.0
Name: Label, dtype: float64

[200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 21
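Because the 10 positive rows sit at the end of the file, unshuffled folds are consecutive blocks of 100 rows and all positives end up in the last fold. The output above is truncated, so the sketch below reproduces the fold-by-fold positive counts on a synthetic stand-in (an assumption: 990 zeros followed by 10 ones, as in the real data).

```python
import numpy as np
from sklearn.model_selection import KFold

# Synthetic stand-in (assumption): 990 zeros followed by 10 ones
y_demo = np.array([0] * 990 + [1] * 10)
X_demo = np.arange(len(y_demo)).reshape(-1, 1)

# number of class-1 samples in each test fold when the folds are consecutive blocks
positives_per_fold = [
    int(y_demo[test].sum())
    for _, test in KFold(n_splits=10, shuffle=False).split(X_demo)
]
print(positives_per_fold)  # -> [0, 0, 0, 0, 0, 0, 0, 0, 0, 10]
```

Nine of the ten test folds contain no positive at all, which is why shuffling (or stratifying) matters on sorted, imbalanced data.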

## RepeatedKFold

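Use it when you need to run KFold n times with a different randomization in each repetition. RepeatedKFold always shuffles, so it has no `shuffle` parameter; `n_repeats` sets the number of repetitions.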
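A small sketch of how `RepeatedKFold` organizes its output, on the same assumed 990/10 synthetic stand-in: `n_splits * n_repeats` splits are yielded repetition by repetition, each repetition is a full partition of the data, and the repetitions use different shuffles.

```python
import numpy as np
from sklearn.model_selection import RepeatedKFold

# Synthetic stand-in (assumption): 990 zeros followed by 10 ones
y_demo = np.array([0] * 990 + [1] * 10)
X_demo = np.arange(len(y_demo)).reshape(-1, 1)

n_splits, n_repeats = 10, 2
rkf = RepeatedKFold(n_splits=n_splits, n_repeats=n_repeats, random_state=0)
test_sets = [test for _, test in rkf.split(X_demo)]
print(len(test_sets), rkf.get_n_splits())  # -> 20 20

# within one repetition the test folds partition the data, exactly like a single KFold run
for r in range(n_repeats):
    block = test_sets[r * n_splits:(r + 1) * n_splits]
    covered = np.sort(np.concatenate(block))
    assert np.array_equal(covered, np.arange(len(y_demo)))

# the repetitions are shuffled differently, so their first folds differ
repeats_identical = np.array_equal(test_sets[0], test_sets[n_splits])
print(repeats_identical)
```

In this notebook's output, splits 0 and 1 of `RepeatedKFold(random_state=0)` are identical to those of the shuffled `KFold(random_state=0)` above, which suggests the first repetition is seeded the same way.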

In [10]:
rkf = RepeatedKFold(n_splits=10,n_repeats=2,random_state=0)
showSplits(rkf,X,y)

split no: 0
0    0.98
1    0.02
Name: Label, dtype: float64

[  1  14  27  31  45  55  77 103 108 141 142 175 196 202 230 231 236 239
 247 251 252 261 262 264 267 270 272 279 295 298 299 306 310 311 316 319
 331 366 367 386 389 408 412 413 432 452 456 481 485 494 496 497 501 531
 553 558 568 569 576 578 582 587 601 614 635 638 643 644 654 672 693 698
 706 710 736 738 740 760 783 784 788 790 796 819 823 859 862 874 883 906
 909 918 923 926 957 971 982 987 993 996]

--------------------------------

split no: 1
0    0.98
1    0.02
Name: Label, dtype: float64

[  5   8  30  34  54  60  65  97 101 122 144 145 150 204 214 215 240 249
 255 271 278 283 312 315 317 320 351 362 372 378 382 384 385 395 403 418
 424 435 457 458 466 467 483 489 492 495 513 520 523 529 545 580 583 596
 608 615 624 630 636 648 653 666 679 688 713 717 722 751 753 757 767 769
 773 780 785 789 795 826 832 834 844 860 864 873 876 886 902 904 922 927
 939 942 945 946 958 975 986 989 990 999]

----------------------------

split no: 19
0    1.0
Name: Label, dtype: float64

[  3  10  15  21  33  60  64  70  78  79  83 103 106 117 129 131 165 172
 185 205 213 236 252 257 267 281 284 287 311 324 328 349 359 371 373 391
 395 396 415 416 418 439 448 470 472 476 486 500 504 505 507 523 528 545
 585 590 594 595 600 614 617 635 656 662 685 693 699 703 718 725 728 734
 741 749 754 779 783 784 797 808 819 826 833 836 844 848 849 853 858 879
 897 904 907 914 919 947 952 966 969 978]

--------------------------------



## StratifiedKFold

In [11]:
skf = StratifiedKFold(n_splits=10,shuffle=False) #shuffle is False, so the folds follow the row order and no random_state is needed
showSplits(skf,X,y)

split no: 0
0    0.99
1    0.01
Name: Label, dtype: float64

[  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35
  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53
  54  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71
  72  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89
  90  91  92  93  94  95  96  97  98 990]

--------------------------------

split no: 1
0    0.99
1    0.01
Name: Label, dtype: float64

[ 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116
 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134
 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152
 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170
 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188
 189 190 191 192 193 194 195 196 197 991]

----------------------------
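With 10 positives and 10 folds, stratification leaves no freedom for the minority class: each test fold should get exactly one positive and 99 negatives, which is what the 0.99/0.01 proportions above show. A quick check on the assumed synthetic 990/10 stand-in:

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Synthetic stand-in (assumption): 990 zeros followed by 10 ones
y_demo = np.array([0] * 990 + [1] * 10)
X_demo = np.arange(len(y_demo)).reshape(-1, 1)

skf = StratifiedKFold(n_splits=10, shuffle=False)
# [class-0 count, class-1 count] for every test fold
fold_counts = [
    np.bincount(y_demo[test], minlength=2).tolist()
    for _, test in skf.split(X_demo, y_demo)
]
for i, (neg, pos) in enumerate(fold_counts):
    print(f"fold {i}: class 0 = {neg}, class 1 = {pos}")  # every fold: 99 and 1
```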

In [12]:
#now setting shuffle=True; without a random_state, these splits change on every run
skf = StratifiedKFold(n_splits=10,shuffle=True)
showSplits(skf,X,y)

split no: 0
0    0.99
1    0.01
Name: Label, dtype: float64

[ 18  19  22  31  33  46  63  66  85 103 110 113 115 120 140 160 183 189
 191 202 206 208 218 235 244 250 251 257 262 275 288 290 299 303 349 353
 356 374 378 384 389 391 412 413 423 437 442 448 458 460 474 478 499 504
 551 560 561 565 568 569 579 583 588 599 602 608 649 655 669 702 706 719
 730 741 753 758 763 784 788 803 818 829 832 833 834 846 873 890 903 904
 910 911 912 915 921 930 932 954 982 990]

--------------------------------

split no: 1
0    0.99
1    0.01
Name: Label, dtype: float64

[  5  14  24  29  34  51  61  73  77  83  91  94 106 119 128 130 142 145
 147 156 162 192 197 231 236 237 277 300 304 365 371 390 411 422 434 439
 440 444 456 467 479 481 482 483 488 500 516 519 521 529 541 545 553 571
 585 613 629 637 668 674 675 680 708 714 720 721 728 731 746 747 751 759
 761 774 775 776 789 793 796 804 813 825 826 831 838 839 852 855 863 886
 898 927 948 953 961 964 974 978 986 993]

----------------------------
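Since the cell above passes no `random_state`, its output will differ each time the notebook is run. The sketch below (the synthetic stand-in and the helper `first_fold` are illustrative assumptions) shows both sides of shuffling: fixing `random_state` makes the folds reproducible, and whatever the seed, each fold still receives exactly one positive.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Synthetic stand-in (assumption): 990 zeros followed by 10 ones
y_demo = np.array([0] * 990 + [1] * 10)
X_demo = np.arange(len(y_demo)).reshape(-1, 1)


def first_fold(random_state):
    """Hypothetical helper: test indices of the first fold of a shuffled StratifiedKFold."""
    skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=random_state)
    _, test = next(skf.split(X_demo, y_demo))
    return test


# fixing random_state makes the shuffled folds reproducible...
reproducible = np.array_equal(first_fold(42), first_fold(42))

# ...while the class counts per fold do not depend on the seed
positives = {seed: int(y_demo[first_fold(seed)].sum()) for seed in range(5)}
print(reproducible, positives)  # -> True {0: 1, 1: 1, 2: 1, 3: 1, 4: 1}
```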

## RepeatedStratifiedKFold

In [13]:
rskf = RepeatedStratifiedKFold(n_splits=10, n_repeats=2, random_state=0)
showSplits(rskf,X,y)

split no: 0
0    0.99
1    0.01
Name: Label, dtype: float64

[  4   5   7  10  15  27  89  95 107 127 138 139 141 180 199 205 213 219
 220 231 236 238 245 246 262 275 288 290 301 306 307 313 314 320 322 331
 334 347 365 373 379 380 391 417 446 450 462 463 467 508 513 514 516 538
 543 560 572 587 604 621 641 658 669 671 685 693 717 718 719 721 728 758
 759 788 807 812 829 830 833 834 835 837 839 840 851 855 873 892 896 902
 903 909 915 924 928 960 967 975 981 990]

--------------------------------

split no: 1
0    0.99
1    0.01
Name: Label, dtype: float64

[ 13  17  33  59  85  99 115 130 160 167 186 197 203 204 222 223 253 263
 269 278 284 289 308 311 324 330 332 333 348 357 358 360 381 386 401 402
 404 412 416 422 429 456 480 485 488 494 496 497 521 526 530 544 564 567
 571 580 602 638 646 651 652 656 657 691 694 698 699 706 707 745 763 765
 766 779 786 789 790 791 797 821 828 845 848 853 872 876 886 888 890 906
 926 929 934 945 949 950 957 966 986 997]

----------------------------
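`RepeatedStratifiedKFold` reruns a shuffled `StratifiedKFold` `n_repeats` times, so all `n_splits * n_repeats` test folds keep the class ratio. A check on the assumed synthetic 990/10 stand-in:

```python
import numpy as np
from sklearn.model_selection import RepeatedStratifiedKFold

# Synthetic stand-in (assumption): 990 zeros followed by 10 ones
y_demo = np.array([0] * 990 + [1] * 10)
X_demo = np.arange(len(y_demo)).reshape(-1, 1)

rskf = RepeatedStratifiedKFold(n_splits=10, n_repeats=2, random_state=0)
# number of class-1 samples in each of the 20 test folds
positives = [int(y_demo[test].sum()) for _, test in rskf.split(X_demo, y_demo)]
print(len(positives), set(positives))  # -> 20 {1}
```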

## StratifiedShuffleSplit

In [14]:
sss = StratifiedShuffleSplit(n_splits=10, test_size=0.2, random_state=0)
showSplits(sss,X,y)

split no: 0
0    0.99
1    0.01
Name: Label, dtype: float64

[778 639 782 995 265  32 227 257 201  28 368 131 882 559 537 774  41 926
  42 431 387 562 599 847 739 410 305 714 510 396 802 488  47 763 123 174
 754 277 600 226 286  72 256 660 550 260 723 459 244 912 269 574 275 139
 841 169  53 296 423 280 307 770  36 598 430 543 119 846 966 755 324 147
 921 606  23 292 607  94 551 756 594 684 248 719 128 373 328 184 115 512
 956 335 629 901 973 555 850 705  58 377 633 610  84 209 637 611 835  11
 398 130 868 515 734 359  70 870 916  80 804 816 939 791 288 544 861   9
  99 845 273 910 970 709 203 779 589 148 925 697 411 952 197 177 461 560
 621 675 472 860 192  91 323 370 797 690 444 183 749 992 933 182 699 442
 486  43 806 291 863 832 388  86 450  95 528 143 565 777 581 383  19 707
 321 896 617  87 696 180 554 151 371 448 937 677  57 888 659 849  98 314
 853  82]

--------------------------------

split no: 1
0    0.99
1    0.01
Name: Label, dtype: float64

[201  97 651 283 905 743 772 1
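`StratifiedShuffleSplit` draws each test set independently, like `ShuffleSplit`, but within each class. With `test_size=0.2`, every test set should hold 198 negatives and 2 positives, and, unlike the `KFold` family, two test sets can share samples. Sketched on the assumed synthetic 990/10 stand-in:

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

# Synthetic stand-in (assumption): 990 zeros followed by 10 ones
y_demo = np.array([0] * 990 + [1] * 10)
X_demo = np.arange(len(y_demo)).reshape(-1, 1)

sss = StratifiedShuffleSplit(n_splits=10, test_size=0.2, random_state=0)
test_sets = [test for _, test in sss.split(X_demo, y_demo)]

# each test set holds 20% of each class: 198 negatives and 2 positives
class_counts = {tuple(np.bincount(y_demo[t], minlength=2).tolist()) for t in test_sets}
print(class_counts)  # -> {(198, 2)}

# like ShuffleSplit (and unlike StratifiedKFold), different test sets may share samples
shared = len(np.intersect1d(test_sets[0], test_sets[1]))
print("samples shared by the first two test sets:", shared)
```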

## Classic train_test_split

In [15]:
for i in range(3):
    print(f"i={i}")
    X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.2,random_state=i)
    #print(len(y_test))
    print(y_test.index)
    print(y_test.value_counts(normalize=True),end="\n\n**************\n\n")

i=0
Int64Index([993, 859, 298, 553, 672, 971,  27, 231, 306, 706,
            ...
            834, 832, 435,   5, 769, 679, 722, 215, 653, 150],
           dtype='int64', length=200)
0    0.98
1    0.02
Name: Label, dtype: float64

**************

i=1
Int64Index([507, 818, 452, 368, 242, 929, 262, 810, 318,  49,
            ...
            699, 826, 233, 584, 974, 430, 874, 550, 608, 207],
           dtype='int64', length=200)
0    0.985
1    0.015
Name: Label, dtype: float64

**************

i=2
Int64Index([ 37, 726, 846, 295, 924, 658, 682, 286, 880, 272,
            ...
            484, 451, 577,  77, 575, 839, 810, 930, 616, 809],
           dtype='int64', length=200)
0    1.0
Name: Label, dtype: float64

**************



On the third trial (i=2), the test set contains only class 0: none of the 10 positive samples was held out, so the model's performance on the minority class cannot be evaluated on this split.
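How often does this happen? For a random 200-row test set drawn from 1000 rows containing 10 positives, the chance of drawing no positive at all is roughly 0.8^10, about one seed in ten. The sketch below (on an assumed synthetic 990/10 stand-in) counts the seeds in 0..99 for which a plain `train_test_split` yields a test set without positives, and compares with `stratify=y`.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic stand-in (assumption): 990 zeros followed by 10 ones
y_demo = np.array([0] * 990 + [1] * 10)
X_demo = np.arange(len(y_demo)).reshape(-1, 1)

seeds = range(100)
# train_test_split returns [X_train, X_test, y_train, y_test]; index 3 is y_test
plain = [
    int(train_test_split(X_demo, y_demo, test_size=0.2, random_state=s)[3].sum())
    for s in seeds
]
stratified = [
    int(train_test_split(X_demo, y_demo, test_size=0.2, random_state=s, stratify=y_demo)[3].sum())
    for s in seeds
]

print("plain: seeds with no positive in the test set:", plain.count(0))
print("stratified: positives per test set:", set(stratified))  # always exactly 2
```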

## Classic train_test_split with *stratify* parameter

In [16]:
for i in range(3): #3 is arbitrary; whatever random_state you choose, the class proportions stay the same
    print(f"i={i}")
    X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.2,random_state=i,stratify=y)
    #print(len(y_test))
    print(y_test.index)
    print(y_test.value_counts(normalize=True),end="\n\n**************\n\n")

i=0
Int64Index([778, 639, 782, 995, 265,  32, 227, 257, 201,  28,
            ...
            937, 677,  57, 888, 659, 849,  98, 314, 853,  82],
           dtype='int64', length=200)
0    0.99
1    0.01
Name: Label, dtype: float64

**************

i=1
Int64Index([170, 844, 416, 588, 655, 393, 466, 927,  75, 297,
            ...
            321, 628, 367, 129, 606, 515, 492, 562, 621, 753],
           dtype='int64', length=200)
0    0.99
1    0.01
Name: Label, dtype: float64

**************

i=2
Int64Index([367, 336, 348, 148, 539, 901, 224, 366, 873, 168,
            ...
            784, 764, 875, 527, 951, 487, 994,  61, 760, 758],
           dtype='int64', length=200)
0    0.99
1    0.01
Name: Label, dtype: float64

**************
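The indices above also reveal how `train_test_split` works: for `i=0`, the stratified test set starts with 778, 639, 782, exactly like split 0 of `StratifiedShuffleSplit(random_state=0)`, and the unstratified one starts like split 0 of `ShuffleSplit(random_state=0)`. In scikit-learn, `train_test_split` takes the first split of a `ShuffleSplit`, or of a `StratifiedShuffleSplit` when `stratify` is given. The sketch below checks this on an assumed synthetic 990/10 stand-in, comparing the sets of test indices (the feature values equal the row indices, so `X_test` gives the indices directly).

```python
import numpy as np
from sklearn.model_selection import ShuffleSplit, StratifiedShuffleSplit, train_test_split

# Synthetic stand-in (assumption): 990 zeros followed by 10 ones; X holds the row index
y_demo = np.array([0] * 990 + [1] * 10)
X_demo = np.arange(len(y_demo)).reshape(-1, 1)
seed = 0

# stratified: train_test_split vs the first split of StratifiedShuffleSplit
_, X_test, _, _ = train_test_split(X_demo, y_demo, test_size=0.2, random_state=seed, stratify=y_demo)
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=seed)
_, sss_test = next(sss.split(X_demo, y_demo))
stratified_match = np.array_equal(np.sort(X_test.ravel()), np.sort(sss_test))

# unstratified: train_test_split vs the first split of ShuffleSplit
_, X_test, _, _ = train_test_split(X_demo, y_demo, test_size=0.2, random_state=seed)
ss = ShuffleSplit(n_splits=1, test_size=0.2, random_state=seed)
_, ss_test = next(ss.split(X_demo))
plain_match = np.array_equal(np.sort(X_test.ravel()), np.sort(ss_test))

print(stratified_match, plain_match)
```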

