# Generate Stratified Train-Test Split with Balanced Age and Gender
See the paper [On the Stratification of Multi-label Data](https://link.springer.com/chapter/10.1007/978-3-642-23808-6_10) for a principled way of performing stratified splits with multiple labels.
## Prerequisite
- scikit-multilearn

Install using the following command:
<code>pip install scikit-multilearn</code> (the package is imported as `skmultilearn`)

In [1]:
import pandas as pd
import numpy as np
import os

from skmultilearn.model_selection import IterativeStratification

In [2]:
# Load the ground-truth table, then turn the categorical strings
# (gender and diagnosis) into integer codes.
data_file = "training-groundtruth.csv"
data = pd.read_csv(data_file)
data = data.replace({"female": 0, "male": 1, "Control": 0, "ProbableAD": 1})
data.head()

Unnamed: 0,adressfname,age,gender,educ,dx,mmse
0,adrso002,70,0,,0,26.0
1,adrso003,72,0,,0,30.0
2,adrso004,74,0,,0,30.0
3,adrso005,67,0,,0,27.0
4,adrso006,65,0,,0,28.0


In [3]:
def discretize_continuous_label(feature, discretize_bins):
    """
    Discretize a continuous variable into equal-width bins.

    The range ``[min, max]`` of ``feature`` (NaNs ignored) is cut into
    ``discretize_bins`` equal-width intervals closed on the right, and every
    value is mapped to the index of its interval (``0`` .. ``discretize_bins - 1``).

    :param feature: 1-D array-like of numeric values (e.g. a pandas Series).
    :param discretize_bins: Number of bins; must be at least 1.
    :return: Integer numpy array of bin indices, one per element of ``feature``.
    :raises ValueError: If ``discretize_bins`` is smaller than 1.
    """
    if discretize_bins < 1:
        raise ValueError("'discretize_bins' must be at least 1.")
    values = np.asarray(feature, dtype=float)
    if values.size == 0 or discretize_bins == 1:
        return np.zeros(values.shape, dtype=np.intp)
    # Use only the interior cut points. Passing the minimum as the first edge
    # (np.linspace(min, max, bins)) would give the minimum a bin of its own and
    # leave just ``bins - 1`` real intervals.
    edges = np.linspace(np.nanmin(values), np.nanmax(values), num=discretize_bins + 1)[1:-1]
    return np.digitize(values, edges, right=True)


def _one_hot(values):
    """Return a 0/1 indicator matrix with one column per distinct value of ``values``.

    :param values: 1-D sequence of category codes, one per sample.
    :return: Integer array of shape ``(len(values), n_distinct_values)``.
    """
    categories, codes = np.unique(np.asarray(values), return_inverse=True)
    return np.eye(len(categories), dtype=int)[codes.ravel()]


def split_data(
        df: pd.DataFrame,
        n_folds: int = 5,
        objective: str = "class",
        discretize_bins: int = 5,
        path: str = '/'
):
    """
    Split data into stratified cross-validation folds and save each fold to CSV.

    Folds are balanced jointly on discretized age, gender and the outcome label
    using first-order iterative stratification. Fold ``i`` is written to
    ``<path>/split/<objective>/fold_<i>/train.csv`` and ``test.csv``.

    :param df: Dataframe containing the data. Needs ``age`` and ``gender`` columns, plus
        ``dx`` when ``objective == "class"`` or ``mmse`` when ``objective == "reg"``.
    :param n_folds: Number of folds.
    :param objective: Objective of the outcome. Either "class" (classification, stratified
        on ``dx``) or "reg" (regression, stratified on discretized ``mmse``).
    :param discretize_bins: Number of bins used to discretize continuous variables (age,
        and mmse for regression). Keep it small because the dataset is small; 4 or 5 is
        probably enough.
    :param path: Directory under which the ``split`` folder is created.
    :return: None
    :raises ValueError: If ``objective`` is neither "class" nor "reg".
    """
    if objective == "class":
        output = df.dx.values
    elif objective == "reg":
        # Subjects without an MMSE score cannot be stratified on it, so drop them.
        # Filter on ``df`` itself (not a notebook-global frame) so any input works.
        df = df[df.mmse.notna()].reset_index(drop=True)
        output = discretize_continuous_label(df.mmse, discretize_bins)
    else:
        raise ValueError("Not a valid value for 'objective'.")

    age = discretize_continuous_label(df.age, discretize_bins)
    gender = df.gender.values

    # IterativeStratification treats ``y`` as a binary label-indicator matrix: only
    # the positions of non-zero entries matter. Raw codes (e.g. age bins 0..4) would
    # merge every non-zero bin into a single label, so one-hot encode each variable.
    stratified_labels = np.hstack([_one_hot(age), _one_hot(gender), _one_hot(output)])

    k_fold = IterativeStratification(n_splits=n_folds, order=1)
    for i, (train_idx, test_idx) in enumerate(k_fold.split(df.values, stratified_labels)):
        # ``iloc`` keeps each column's dtype, unlike rebuilding frames from ``df.values``.
        df_train = df.iloc[train_idx]
        df_test = df.iloc[test_idx]

        folder_path = os.path.join(path, "split", objective, f"fold_{i}")
        os.makedirs(folder_path, exist_ok=True)
        df_train.to_csv(os.path.join(folder_path, "train.csv"), index=False)
        df_test.to_csv(os.path.join(folder_path, "test.csv"), index=False)


In [4]:
# Root directory under which the generated split/ tree is written.
dir_ = os.getcwd()

# Produce one set of folds per target (classification on dx, regression on mmse);
# every set is balanced for age and gender.
for objective in ("class", "reg"):
    split_data(data, n_folds=5, objective=objective, path=dir_)