# Train Test Split

## Import the relevant libraries

In [2]:
import numpy as np
from sklearn.model_selection import train_test_split

## Generate some data we are going to split

In [8]:
# Build a 1-D NumPy array 'a' holding every integer from 1 to 100
# (np.arange excludes the stop value, hence 101)
a = np.arange(1, 101)
a

array([  1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
        14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,
        27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,
        40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,
        53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,  65,
        66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,
        79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,
        92,  93,  94,  95,  96,  97,  98,  99, 100])

In [10]:
# Build a second ndarray 'b' holding the integers from 401 to 500
# Each value is the matching value of 'a' shifted by 400, which makes it easy
# to see, after a split, which element of 'b' goes with which element of 'a'
b = np.arange(401, 501)
b

array([401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413,
       414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426,
       427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439,
       440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452,
       453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465,
       466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478,
       479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491,
       492, 493, 494, 495, 496, 497, 498, 499, 500])

## Split the data

In [12]:
# train_test_split(x) shuffles x and splits it into random train and test subsets
# It returns a list [train, test]; with no extra arguments 75% of the samples go to train and 25% to test
# No 'random_state' is passed here, so re-running this cell gives a different shuffle each time
train_test_split(a)

[array([  7,  80,  60,  45,  47,  46,   5,  58,   1,  35,  85,  51,  89,
         78,  31,  23,  63,  36,  24,   3,  26,  83,  88,  13,  75,  72,
         87,  50,   4,  34,  38,   2,  66,  74,  59,  14,  53,  40,  21,
         37,  29,  54,  57,  17,  95,  67,  55,  77,  65,  25,  62,  61,
        100,  79,  30,  98,  49,  19,  64,  94,  42,  27,  76,  10,  93,
         69,  84,  48,  32,  16,  56,  22,  68,  71,  52]),
 array([81, 90,  6, 28, 91, 11, 86,  8, 44, 97, 15, 96,  9, 12, 33, 20, 82,
        18, 41, 92, 39, 43, 99, 73, 70])]

In [16]:
# train_test_split accepts several optional arguments
# In practice we usually have inputs and targets, i.e. two arrays that must be split in the same way
# Here 'a' and 'b' play those two roles

# 'test_size' or 'train_size' sets the proportions of the split
# The two are complementary, so specifying one of them is enough
# Common train-test splits are 75-25, 80-20, 85-15 and 90-10

# 'random_state' fixes the seed of the shuffle, so every run produces the SAME split

# Passing 2 arrays returns 4 arrays, in the order: train1, test1, train2, test2
# Unpack them into 4 variables so they can be used later on
a_train, a_test, b_train, b_test = train_test_split(
    a,
    b,
    test_size=0.2,
    random_state=365,
)

## Explore the result

In [18]:
# Check the shapes of the two parts of 'a' to see the effect of 'test_size'
# With test_size=0.2, the 100 elements should be divided into 80 for training and 20 for testing
a_train.shape, a_test.shape

((80,), (20,))

In [20]:
# Inspect the training part of 'a' by eye: the values are shuffled, not in their original order
a_train

array([ 25,  32,  99,  73,  91,  66,   3,  59,  94,   1,   8,  15,  90,
        54,  31,  20,  77,  82,  30,  35,  95,  42,  38,   7,  11,  50,
        21,  48,   2,  17,  10,  58,  68,  43,  41,  16,  88,  72,  79,
       100,  80,  39,  24,  86,  22,  23,  62,  76,  18,  47,  55,  26,
        60,  19,  71,  64,  51,  63,  65,  28,  12,  78,  13,  44,  75,
        87,  40,   4,  29,  49,  37,  57,  27,  74,   6,  45,  92,  34,
        53,  83])

In [22]:
# Inspect the test part of 'a' by eye: it holds the 20 values that did not end up in a_train
a_test

array([ 9, 69, 81, 56, 33, 93, 84, 61, 46, 89, 85, 67, 97,  5, 70, 36, 98,
       96, 14, 52])

In [24]:
# 'b' was split in the same call as 'a', so its parts should have the same shapes: 80 and 20 elements
b_train.shape, b_test.shape

((80,), (20,))

In [32]:
b_train # position by position this mirrors a_train: 25 pairs with 425, 32 with 432, etc.

array([425, 432, 499, 473, 491, 466, 403, 459, 494, 401, 408, 415, 490,
       454, 431, 420, 477, 482, 430, 435, 495, 442, 438, 407, 411, 450,
       421, 448, 402, 417, 410, 458, 468, 443, 441, 416, 488, 472, 479,
       500, 480, 439, 424, 486, 422, 423, 462, 476, 418, 447, 455, 426,
       460, 419, 471, 464, 451, 463, 465, 428, 412, 478, 413, 444, 475,
       487, 440, 404, 429, 449, 437, 457, 427, 474, 406, 445, 492, 434,
       453, 483])

In [30]:
b_test # position by position this mirrors a_test: 9 pairs with 409, 69 with 469, etc.

array([409, 469, 481, 456, 433, 493, 484, 461, 446, 489, 485, 467, 497,
       405, 470, 436, 498, 496, 414, 452])