# In [None]:
# Define functions.
def divide_dataset(N, test_prop=0.1, val_prop=0.1, *, seed=None):
    """
    Divide a dataset into training, validation, and test sets by random sampling.

    The sample indices 0, 1, ..., N-1 are partitioned into three disjoint
    index arrays. Test indices are drawn first, then validation indices are
    drawn from what remains, and everything left over is the training set.

    Args
    ----
    N: int
          Total number of samples in the data set. Must not be negative.
    test_prop: float
          The proportion of the test set, between 0 and 1.
    val_prop: float
          The proportion of the validation set, between 0 and 1.
          test_prop + val_prop must not exceed 1.
    seed: int or None (keyword-only)
          If None (default), the global NumPy random state is used, so a
          preceding np.random.seed(...) call controls the result.
          Otherwise a private np.random.RandomState(seed) is used and the
          global random state is left untouched.
    Returns
    ----
    result: dictionary consisting of keys "train_ids", "val_ids", and "test_ids".
          Each value is a sorted NumPy array of integer indices.
    Raises
    ----
    ValueError
          If N is negative, a proportion is outside [0, 1], or the two
          proportions sum to more than 1.
    """
    # Fail early with a readable message instead of an opaque sampling error.
    if N < 0:
        raise ValueError("N must be non-negative, got {}".format(N))
    for label, prop in (("test_prop", test_prop), ("val_prop", val_prop)):
        # Written as "not (0 <= prop <= 1)" so that NaN is rejected as well.
        if not 0.0 <= prop <= 1.0:
            raise ValueError("{} must be in [0, 1], got {}".format(label, prop))
    # Small tolerance so floating-point pairs that mathematically sum to 1 pass.
    if test_prop + val_prop > 1.0 + 1e-12:
        raise ValueError(
            "test_prop + val_prop must not exceed 1, got {}".format(test_prop + val_prop)
        )

    # Get sample numbers for test and validation; int() keeps the sizes plain
    # Python ints even if N or the proportions are NumPy scalars.
    test_N = int(round(N * test_prop))
    # Both counts are rounded independently, so together they can overshoot N
    # by one (e.g. N=7 with 0.5/0.5 gives 4 + 4). Cap validation at what is left.
    val_N = min(int(round(N * val_prop)), N - test_N)

    # The np.random module and a RandomState instance both expose choice().
    # RandomState(seed) produces the same stream as np.random.seed(seed) would
    # on the global state, so both paths give identical splits for equal seeds.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Sampling test indices.
    test_index = np.sort(rng.choice(N, test_N, replace=False))

    # Get trainval indices by removing test indices from all indices.
    trainval_index = np.setdiff1d(np.arange(N), test_index)

    # Sampling validation indices.
    val_index = np.sort(rng.choice(trainval_index, val_N, replace=False))

    # Get train indices by removing validation indices from trainval indices.
    train_index = np.setdiff1d(trainval_index, val_index)

    # Summarize as a dictionary.
    return {"train_ids": train_index, "val_ids": val_index, "test_ids": test_index}

def train_test_split(N, test_prop=0.2, *, seed=None):
    """
    Divide a dataset into training and test sets by random sampling.

    The sample indices 0, 1, ..., N-1 are partitioned into two disjoint
    index arrays: a randomly drawn test set and the remaining training set.

    Args
    ----
    N: int
          Total number of samples in the data set. Must not be negative.
    test_prop: float
          The proportion of the test set, between 0 and 1.
    seed: int or None (keyword-only)
          If None (default), the global NumPy random state is used, so a
          preceding np.random.seed(...) call controls the result.
          Otherwise a private np.random.RandomState(seed) is used and the
          global random state is left untouched.
    Returns
    ----
    result: dictionary consisting of keys "train_ids" and "test_ids".
          Each value is a sorted NumPy array of integer indices.
    Raises
    ----
    ValueError
          If N is negative or test_prop is outside [0, 1].
    """
    # Fail early with a readable message instead of an opaque sampling error.
    if N < 0:
        raise ValueError("N must be non-negative, got {}".format(N))
    # Written as "not (0 <= prop <= 1)" so that NaN is rejected as well.
    if not 0.0 <= test_prop <= 1.0:
        raise ValueError("test_prop must be in [0, 1], got {}".format(test_prop))

    # Get the test sample size; int() keeps it a plain Python int even if N or
    # test_prop is a NumPy scalar.
    test_N = int(round(N * test_prop))

    # The np.random module and a RandomState instance both expose choice().
    # RandomState(seed) produces the same stream as np.random.seed(seed) would
    # on the global state, so both paths give identical splits for equal seeds.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Sampling test indices.
    test_index = np.sort(rng.choice(N, test_N, replace=False))

    # Get train indices by removing test indices.
    train_index = np.setdiff1d(np.arange(N), test_index)

    return {"train_ids": train_index, "test_ids": test_index}