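# Classification Experiment – Retrospective Study with Neural Fragility

Trains classifiers on neural-fragility features and evaluates each one on ten pre-computed, subject-grouped train/test folds. ROC, calibration, per-patient prediction and feature-importance results are saved as JSON.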

In [1]:
%load_ext lab_black

In [2]:
import collections
import json
import os
from itertools import product
from pathlib import Path
import sys
import pandas as pd
import numpy as np
from mne_bids.path import get_entities_from_fname
from natsort import natsorted
from rerf.rerfClassifier import rerfClassifier

# comparative classifiers
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
from sklearn.svm import SVC
from sklearn.dummy import DummyClassifier
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.neighbors import KNeighborsClassifier


from sklearn.calibration import calibration_curve
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import brier_score_loss
from sklearn.metrics import roc_curve
from sklearn.preprocessing import OrdinalEncoder
from sklearn.utils import resample
from sklearn.metrics import (
    average_precision_score,
    roc_auc_score,
    f1_score,
    roc_curve,
    balanced_accuracy_score,
    accuracy_score,
    auc,
    brier_score_loss,
    plot_precision_recall_curve,
    average_precision_score,
    precision_recall_curve,
)
from sklearn.model_selection import GroupKFold, cross_validate
from sklearn.utils import resample
from sklearn.calibration import calibration_curve

import dabest
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt

sys.path.append("../")
# functions related to the feature comparison experiment
from analysis.publication.study import (
    load_patient_dict,
    determine_feature_importances,
    extract_Xy_pairs,
    format_supervised_dataset,
    _sequential_aggregation,
    tune_hyperparameters,
)
from analysis.publication.extract_datasets import load_ictal_frag_data
from analysis.publication.utils import NumpyEncoder

from sample_code.io import read_participants_tsv
from sample_code.utils import _load_turbo, _plot_roc_curve

%matplotlib inline
# %load_ext autoreload
# %autoreload 2

In [3]:
def average_roc(fpr, tpr, n_points=200):
    """Interpolate per-split ROC curves onto a common FPR grid.

    Parameters
    ----------
    fpr, tpr : sequence of array-like
        Per-split false/true positive rates (one entry per CV split), e.g.
        the outputs of ``sklearn.metrics.roc_curve``. ``fpr[i]`` must be
        increasing, as required by ``np.interp``.
    n_points : int
        Number of points in the shared FPR grid on [0, 1]. Defaults to 200,
        the previously hard-coded value.

    Returns
    -------
    mean_fpr : ndarray of shape (n_points,)
        The shared FPR grid.
    tprs : list of ndarray
        Each split's TPR interpolated onto ``mean_fpr``, with the first point
        pinned to 0.
    aucs : list of float
        Area under each interpolated curve.

    Notes
    -----
    The mean curve and AUC spread are not returned. Callers can compute them
    as ``np.mean(tprs, axis=0)`` and ``np.std(aucs)``. The earlier version
    computed those values internally and then discarded them.
    """
    mean_fpr = np.linspace(0, 1, n_points)

    n_splits = len(fpr)
    print(f"Computing average ROC over {n_splits} CV splits")

    tprs = []
    aucs = []
    for i in range(n_splits):
        interp_tpr = np.interp(mean_fpr, fpr[i], tpr[i])
        # anchor every curve at the origin
        interp_tpr[0] = 0.0
        tprs.append(interp_tpr)
        aucs.append(auc(mean_fpr, interp_tpr))

    return mean_fpr, tprs, aucs

In [4]:
def combine_patient_predictions(
    ytrues, ypred_probs, subject_groups, pat_predictions=None, pat_true=None
):
    """Group per-sample predicted probabilities and true labels by subject.

    Parameters
    ----------
    ytrues : array-like
        True label per sample. Each entry may be a scalar or a length-1 row
        (e.g. the ``(n, 1)`` output of ``OrdinalEncoder``).
    ypred_probs : array-like of float
        Predicted positive-class probability per sample.
    subject_groups : array-like
        Subject identifier per sample.
    pat_predictions : dict, optional
        Existing ``subject -> list of probabilities`` mapping to extend.
    pat_true : dict, optional
        Existing ``subject -> label`` mapping to extend.

    Returns
    -------
    pat_predictions, pat_true : dict, dict
        The (possibly newly created) mappings. Dicts passed in are mutated.

    Raises
    ------
    RuntimeError
        If samples from the same subject carry different true labels.
    """
    # Initialise each accumulator independently. Previously, passing only one
    # of them caused both to be reset, silently discarding the caller's dict.
    if pat_predictions is None:
        pat_predictions = collections.defaultdict(list)
    if pat_true is None:
        pat_true = dict()

    for ytrue, ypred_proba, subject in zip(ytrues, ypred_probs, subject_groups):
        # setdefault also works when the caller passed a plain dict
        pat_predictions.setdefault(subject, []).append(float(ypred_proba))

        # accept both scalar labels and length-1 label rows
        label = np.ravel(ytrue)[0]
        if subject not in pat_true:
            pat_true[subject] = label
        elif pat_true[subject] != label:
            raise RuntimeError("wtf subject should all match...")
    return pat_predictions, pat_true

In [5]:
from scipy.spatial.distance import cdist

# get line between optimum and clinical op point
def create_line(x1, x2, y1, y2, n_points=200):
    """Sample ``n_points`` evenly spaced points on the segment (x1, y1) -> (x2, y2).

    Parameters
    ----------
    x1, x2 : float
        Start and end x-coordinates.
    y1, y2 : float
        Start and end y-coordinates.
    n_points : int
        Number of samples, endpoints included.

    Returns
    -------
    xs, ys : ndarray, ndarray
        Coordinates of the sampled points.

    Notes
    -----
    The unused slope computation was removed. It raised ZeroDivisionError for
    vertical segments (``x1 == x2``), which ``np.linspace`` handles fine.
    """
    xs = np.linspace(x1, x2, n_points)
    ys = np.linspace(y1, y2, n_points)
    return xs, ys


def find_intersect_idx(x1s, y1s, x2s, y2s):
    """Return the index of the curve-1 point nearest to curve 2.

    Approximates where two sampled curves meet. For every point
    ``(x1s[i], y1s[i])``, take the Euclidean distance to its nearest sample
    on curve 2 ``(x2s, y2s)``. The index with the smallest such distance is
    returned (the first one on ties).
    """
    curve2 = np.vstack((x2s, y2s)).T
    nearest_dists = [
        min(cdist(curve2, np.array([[x, y]]))) for x, y in zip(x1s, y1s)
    ]
    return np.argmin(nearest_dists)

# Configuration: Random Seed, Data Directories, and Features

In [12]:
# set seed and randomness for downstream reproducibility
seed = 12345
random_state = seed
np.random.seed(seed)

# proportion of subjects to use for training
train_size = 0.6

# classification model to use
clf_type = "mtmorf"

# BIDS root directory. To run on another machine, set the EPILEPSY_BIDS_ROOT
# environment variable instead of adding more hard-coded reassignments
# (previously three paths were assigned back-to-back and only the last one
# took effect).
bids_root = Path(
    os.environ.get("EPILEPSY_BIDS_ROOT", "/home/adam2392/hdd/Dropbox/epilepsy_bids/")
)
deriv_path = bids_root / "derivatives"
source_path = bids_root / "sourcedata"

# metadata table
excel_fpath = source_path / "organized_clinical_datasheet_raw.xlsx"

# pre-extracted intermediate datasets (passed to load_data as intermed_fpath)
intermed_fpath = deriv_path / "baselinesliced"

# where to save results (fold indices are read from study_path / "inds")
study_path = deriv_path / "study"

# feature names
feature_names = [
    "fragility",
]

In [7]:
# defining evaluation criterion
metric = "roc_auc"
# draw a bootstrap replicate of each held-out fold in run_clf_validation
BOOTSTRAP = False

# define data hyperparameters (run_clf_validation uses only the first of each)
windows = [
    (-80, 25),
]
thresholds = [
    0.5,
    0.6,
    0.7,
]
weighting_funcs = [None]

# model hyperparameter candidates
max_depth = [None, 5, 10]
max_features = ["auto", "log2"]

# Feature-image geometry. The width is the absolute extent of the first window,
# e.g. (-80, 25) -> 80 + 25 = 105. Cast to a plain int rather than np.int64.
IMAGE_HEIGHT = 20
IMAGE_WIDTH = int(np.abs(windows[0]).sum())

model_params = {
    "n_estimators": 500,
    "max_depth": max_depth[0],
    "max_features": max_features[0],
    "n_jobs": -1,
    "random_state": random_state,
}
# rerf-specific projection settings (MT-MORF patches over the feature image)
model_params.update(
    {
        #         "projection_matrix": "S-RerF",
        "projection_matrix": "MT-MORF",
        "image_height": IMAGE_HEIGHT,
        "image_width": IMAGE_WIDTH,
        "patch_height_max": 4,
        "patch_height_min": 1,
        "patch_width_max": 8,
        "patch_width_min": 1,
    }
)

# Define Classifiers

In [25]:
import os

# These variables are set before tensorflow is imported; CUDA reads
# CUDA_VISIBLE_DEVICES when it initialises, so setting it afterwards has no effect.
os.environ.update(
    {
        "TF_XLA_FLAGS": "--tf_xla_enable_xla_devices",
        "CUDA_VISIBLE_DEVICES": "0",
    }
)
import tensorflow as tf
from tensorflow.keras import layers, models

# NOTE(review): imports from TensorFlow's private `tensorflow.python` namespace,
# which is not a stable API -- confirm against the installed TF version.
from tensorflow.python.keras.wrappers.scikit_learn import KerasClassifier
from tensorflow.python.client import device_lib

gpu_devices = tf.config.experimental.list_physical_devices("GPU")
print("GPUs: ", len(gpu_devices))
print(device_lib.list_local_devices())


# CNN if tensorflow is installed
# Build CNN model
def _build_cnn(n_classes=2):
    """Build and compile a small 2D CNN for use with ``KerasClassifier``.

    Input images have shape ``(IMAGE_HEIGHT, image_width, 1)``. ``image_width``
    is the absolute extent of the first data window
    (``np.abs(windows[0]).sum()``), the same width given to the MT-MORF
    ``model_params``.

    Parameters
    ----------
    n_classes : int
        Number of softmax output units. The original body referenced an
        undefined global ``n_classes``. The default of 2 matches the binary
        outcome used elsewhere in this notebook (``predict_proba(...)[:, 1]``).

    Returns
    -------
    tf.keras.Model
        Compiled with Adam, categorical cross-entropy, and accuracy/AUC metrics.
    """
    # ``np.sum(windows)`` evaluated to -80 + 25 = -55, a negative (invalid) dimension.
    image_width = int(np.abs(windows[0]).sum())

    model = models.Sequential()
    model.add(
        layers.Conv2D(
            32,
            (3, 3),
            activation="relu",
            input_shape=(IMAGE_HEIGHT, image_width, 1),
        )
    )
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation="relu"))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation="relu"))
    model.add(layers.Flatten())
    model.add(layers.Dense(64, activation="relu"))
    model.add(layers.Dense(n_classes, activation="softmax"))

    model.compile(
        optimizer="adam",
        loss="categorical_crossentropy",
        metrics=["accuracy", tf.keras.metrics.AUC(multi_label=False)],
    )
    return model


# `tf.test.is_gpu_available` is deprecated (TF prints "Use
# `tf.config.list_physical_devices('GPU')` instead."); use the supported API.
print(len(tf.config.list_physical_devices("GPU")) > 0)
# cnn = KerasClassifier(_build_cnn, verbose=0)

GPUs:  1
[name: "/device:CPU:0"
device_type: "CPU"
memory_limit: 268435456
locality {
}
incarnation: 14663678256107216531
, name: "/device:XLA_CPU:0"
device_type: "XLA_CPU"
memory_limit: 17179869184
locality {
}
incarnation: 16907912039669404536
physical_device_desc: "device: XLA_CPU device"
, name: "/device:XLA_GPU:0"
device_type: "XLA_GPU"
memory_limit: 17179869184
locality {
}
incarnation: 15548494327529476631
physical_device_desc: "device: XLA_GPU device"
, name: "/device:GPU:0"
device_type: "GPU"
memory_limit: 10759418688
locality {
  bus_id: 2
  numa_node: 1
  links {
  }
}
incarnation: 1571639952698376271
physical_device_desc: "device: 0, name: GeForce GTX 1080 Ti, pci bus id: 0000:41:00.0, compute capability: 6.1"
]
Instructions for updating:
Use `tf.config.list_physical_devices('GPU')` instead.
True


In [17]:
## Setup for run
ncores = -1
num_runs = 1
n_est = 500  # number of estimators

# One entry per classifier to run: label -> (plot color, unfitted estimator).
# The label is passed to run_clf_validation as ``clf_type`` and ends up in the
# output filenames, so it is defined right next to the estimator it names.
# Previously ``names`` and ``classifiers`` were kept separately and zipped by
# position, which paired the active "cnn" label with LogisticRegression.
clf_registry = {
    "Log. Reg": (
        "blue",
        LogisticRegression(random_state=0, n_jobs=ncores, solver="liblinear"),
    ),
    # "Lin. SVM": ("firebrick", LinearSVC()),
    # "SVM": (
    #     "purple",
    #     SVC(C=1.0, probability=True, kernel="rbf", gamma="auto", random_state=0),
    # ),
    # "RF": (
    #     "#f86000",
    #     RandomForestClassifier(n_estimators=n_est, max_features="auto", n_jobs=ncores),
    # ),
    # "MLP": (
    #     "green",
    #     MLPClassifier(hidden_layer_sizes=(n_est,), random_state=0, max_iter=1000),
    # ),
    # "xgb": ("black", GradientBoostingClassifier(random_state=random_state)),
    # "dummy": (
    #     "gray",
    #     DummyClassifier(strategy="most_frequent", random_state=random_state),
    # ),
    # "srerf": ("red", rerfClassifier(**model_params)),
    # "mtmorf": ("orange", rerfClassifier(**model_params)),
    # "cnn": ("red", KerasClassifier(_build_cnn, verbose=0)),
}

# Derived views used by later cells: label -> color, and the estimators in
# the same order as the labels.
names = {label: color for label, (color, _) in clf_registry.items()}
classifiers = [estimator for _, estimator in clf_registry.values()]

In [8]:
def load_data(
    feature_name,
    deriv_path,
    excel_fpath,
    patient_aggregation_method=None,
    intermed_fpath=None,
    save_cv_indices: bool = False,
    subjects=None,
):
    """Load one feature's dataset and ordinal-encode its outcome labels.

    Parameters
    ----------
    feature_name : str
        ``"fragility"`` loads via ``load_ictal_frag_data``. Any other name loads
        ``intermed_fpath / f"{feature_name}_unformatted.npz"``, or, when
        ``intermed_fpath`` is not given, uses ``load_patient_dict`` followed by
        ``extract_Xy_pairs``.
    deriv_path : pathlib.Path
        BIDS derivatives root, used when ``intermed_fpath`` is falsy.
    excel_fpath : pathlib.Path
        Clinical metadata spreadsheet.
    patient_aggregation_method : optional
        Forwarded to ``extract_Xy_pairs`` (non-fragility, no ``intermed_fpath``).
    intermed_fpath : str or pathlib.Path, optional
        Directory containing pre-extracted intermediate datasets.
    save_cv_indices : bool
        If True, draw 10 ``GroupShuffleSplit`` splits (50% of subjects for
        training) and save each as ``study_path / "inds" /
        f"{feature_name}-srerf-{jdx}-inds.npz"``.
    subjects : optional
        Forwarded to ``load_patient_dict``. The original code referenced an
        undefined name ``subjects`` here, so it is now an explicit argument.

    Returns
    -------
    unformatted_X, y, subject_groups, dataset_params
        ``y`` is the ``OrdinalEncoder`` output of shape ``(n_samples, 1)``.
        ``subject_groups`` is a numpy array. ``dataset_params`` holds
        ``sozinds_list`` and ``onsetwin_list``.

    Notes
    -----
    When ``save_cv_indices`` is True this reads the notebook globals
    ``random_state`` and ``study_path``.
    """
    print(f"Loading data from {intermed_fpath}")

    # load unformatted datasets, i.e. without data-hyperparameters applied
    if feature_name == "fragility":
        data_root = intermed_fpath if intermed_fpath else deriv_path
        (
            unformatted_X,
            y,
            subject_groups,
            sozinds_list,
            onsetwin_list,
        ) = load_ictal_frag_data(data_root, excel_fpath=excel_fpath)
    elif not intermed_fpath:
        feature_subject_dict = load_patient_dict(
            deriv_path, feature_name, task="ictal", subjects=subjects
        )
        # NOTE: extract_Xy_pairs returns subject_groups last, unlike load_ictal_frag_data
        (
            unformatted_X,
            y,
            sozinds_list,
            onsetwin_list,
            subject_groups,
        ) = extract_Xy_pairs(
            feature_subject_dict,
            excel_fpath=excel_fpath,
            patient_aggregation_method=patient_aggregation_method,
            verbose=False,
        )
    else:
        # Path() so a str intermed_fpath also works with the "/" join
        feature_fpath = Path(intermed_fpath) / f"{feature_name}_unformatted.npz"

        # allow_pickle executes pickled objects: only load trusted files
        with np.load(feature_fpath, allow_pickle=True) as data_dict:
            unformatted_X, y = data_dict["unformatted_X"], data_dict["y"]
            sozinds_list = data_dict["sozinds_list"]
            onsetwin_list = data_dict["onsetwin_list"]
            subject_groups = data_dict["subject_groups"]

    # get the dataset parameters loaded in
    dataset_params = {"sozinds_list": sozinds_list, "onsetwin_list": onsetwin_list}

    # convert labels into numbers for supervised learning
    enc = OrdinalEncoder()
    y = enc.fit_transform(np.array(y)[:, np.newaxis])
    subject_groups = np.array(subject_groups)

    # Optionally persist subject-grouped train/test splits for reuse
    if save_cv_indices:
        # GroupShuffleSplit is not in the notebook's top-level imports
        from sklearn.model_selection import GroupShuffleSplit

        inds_dir = study_path / "inds"
        inds_dir.mkdir(parents=True, exist_ok=True)

        gss = GroupShuffleSplit(n_splits=10, train_size=0.5, random_state=random_state)
        for jdx, (train_inds, test_inds) in enumerate(
            gss.split(unformatted_X, y, subject_groups)
        ):
            np.savez_compressed(
                inds_dir / f"{feature_name}-srerf-{jdx}-inds.npz",
                train_inds=train_inds,
                test_inds=test_inds,
                train_pats=np.unique(subject_groups[train_inds]),
                test_pats=np.unique(subject_groups[test_inds]),
            )
    return unformatted_X, y, subject_groups, dataset_params

In [23]:
def _load_fold_indices(study_path, jdx, subject_groups):
    """Map pre-computed fold ``jdx``'s train/test subjects to sample indices.

    Reads ``train_pats``/``test_pats`` from
    ``study_path / "inds" / "fixed_folds_subjects" / f"fragility-srerf-{jdx}-inds.npz"``.
    The file name is hard-coded to the fragility folds regardless of feature.
    Returns two lists of positions into ``subject_groups``.
    """
    fold_fpath = (
        study_path
        / "inds"
        / "fixed_folds_subjects"
        / f"fragility-srerf-{jdx}-inds.npz"
    )
    # allow_pickle: the subject arrays may be stored as object arrays
    with np.load(fold_fpath, allow_pickle=True) as data_dict:
        train_pats = set(data_dict["train_pats"])
        test_pats = set(data_dict["test_pats"])

    # set membership avoids scanning the subject array once per sample
    train_inds = [idx for idx, sub in enumerate(subject_groups) if sub in train_pats]
    test_inds = [idx for idx, sub in enumerate(subject_groups) if sub in test_pats]
    return train_inds, test_inds


def _calibration_with_fallback(y_true, y_prob):
    """Return ``(fraction_of_positives, mean_predicted_value)`` for a calibration curve.

    Tries 10 quantile bins first, then 5 uniform bins. If both fail, the error
    is printed and ``([None], [None])`` is returned so the fold can still be saved.
    """
    try:
        return calibration_curve(y_true, y_prob, n_bins=10, strategy="quantile")
    except Exception as e:
        print(e)
    try:
        return calibration_curve(y_true, y_prob, n_bins=5, strategy="uniform")
    except Exception as e:
        print(e)
    return [None], [None]


def run_clf_validation(
    clf_type,
    clf_func,
    unformatted_X,
    y,
    subject_groups,
    dataset_params,
    study_path,
    windows,
    thresholds,
    weighting_funcs,
    feature_name="fragility",
):
    """Fit and evaluate one classifier on each of the 10 pre-computed subject folds.

    For each fold, the classifier is fit on the training subjects (one
    ``cross_validate`` split) and evaluated on the held-out subjects. ROC,
    calibration, per-patient predictions and feature importances are written
    to ``study_path / "clf-train-vs-test" /
    f"study_cv_scores_{clf_type}_{feature_name}_{jdx}.json"``.

    Parameters
    ----------
    clf_type : str
        Label used in output filenames. ``"rf"`` also enables parallel
        feature-importance computation.
    clf_func : estimator
        Unfitted scikit-learn compatible classifier.
    unformatted_X, y, subject_groups, dataset_params
        As returned by ``load_data``.
    study_path : pathlib.Path
        Root for fold index files and results.
    windows, thresholds, weighting_funcs : list
        Data hyperparameters. Only the first entry of each is used.
    feature_name : str
        Feature label used in output filenames. This was previously read from
        the notebook-global ``feature_name``.

    Notes
    -----
    Reads the notebook globals ``BOOTSTRAP`` and, for Keras models,
    ``IMAGE_HEIGHT``. Assumes a binary outcome (``predict_proba(...)[:, 1]``).
    """
    unformatted_X = unformatted_X.copy()
    y = y.copy()
    subject_groups = subject_groups.copy()

    # Bind the data hyperparameters once, under names distinct from the ROC
    # output below. Previously ``roc_curve`` re-bound ``thresholds``, so every
    # fold after the first was formatted with a ROC threshold.
    window = windows[0]
    threshold = thresholds[0]
    weighting_func = weighting_funcs[0]
    X_formatted, _dropped_inds = format_supervised_dataset(
        unformatted_X,
        **dataset_params,
        window=window,
        threshold=threshold,
        weighting_func=weighting_func,
    )

    if isinstance(clf_func, KerasClassifier):
        # CNN input is (n_samples, height, width, channels). The width is
        # inferred; ``np.sum(window)`` was negative and broke this reshape.
        X_formatted = np.asarray(X_formatted)
        print(X_formatted.shape)
        X_formatted = X_formatted.reshape(len(y), IMAGE_HEIGHT, -1, 1)
        print("new shape: ", X_formatted.shape)

    scoring = ["roc_auc", "accuracy", "balanced_accuracy", "average_precision"]

    for jdx in range(0, 10):
        cv_scores = collections.defaultdict(list)

        train_inds, test_inds = _load_fold_indices(study_path, jdx, subject_groups)

        # testing dataset (held out until evaluation)
        subjects_test = subject_groups[test_inds]
        print(subjects_test)

        if len(np.unique(y[test_inds])) == 1:
            print(f"Skipping group cv iteration {jdx} due to degenerate test set")
            continue

        clf = clf_func
        print("Updated classifier: ", clf)

        # a single predefined split: fit on train subjects, score on test subjects
        scores = cross_validate(
            clf,
            X_formatted,
            y,
            groups=subject_groups,
            cv=[(train_inds, test_inds)],
            scoring=scoring,
            return_estimator=True,
            return_train_score=True,
        )
        print(scores.keys())
        print(scores)

        # one split -> unwrap the length-1 arrays
        scores = {key: val[0] for key, val in scores.items()}
        estimator = scores.pop("estimator")
        print("Using estimator ", estimator)

        X_test = np.asarray(X_formatted)[test_inds, ...]
        y_test = np.asarray(y)[test_inds]

        if BOOTSTRAP:
            # One bootstrap replicate of the held-out set. Subjects are
            # resampled with the same indices so per-patient aggregation stays
            # aligned. (Previously 500 replicates were drawn, all but the last
            # were discarded, and the un-resampled labels and subjects were stored.)
            X_boot, y_boot, groups_boot = resample(
                X_test, y_test, subjects_test, n_samples=len(y_test)
            )
        else:
            X_boot, y_boot, groups_boot = (
                X_test.copy(),
                y_test.copy(),
                subjects_test.copy(),
            )

        # evaluate on the test set
        y_pred_prob = estimator.predict_proba(X_boot)[:, 1]
        y_pred = estimator.predict(X_boot)

        # store the actual outcomes and the predicted probabilities
        cv_scores["validate_ytrue"].append(list(y_boot))
        cv_scores["validate_ypred_prob"].append(list(y_pred_prob))
        cv_scores["validate_ypred"].append(list(y_pred))
        cv_scores["validate_subject_groups"].append(list(groups_boot))

        # store ROC curve metrics on the held-out test set
        fpr, tpr, roc_thresholds = roc_curve(y_boot, y_pred_prob, pos_label=1)
        fnr, tnr, _ = roc_curve(y_boot, y_pred_prob, pos_label=0)
        cv_scores["validate_fpr"].append(list(fpr))
        cv_scores["validate_tpr"].append(list(tpr))
        cv_scores["validate_fnr"].append(list(fnr))
        cv_scores["validate_tnr"].append(list(tnr))
        cv_scores["validate_thresholds"].append(list(roc_thresholds))

        print("Done analyzing ROC stats...")

        # compute calibration curve and Brier score
        fraction_of_positives, mean_predicted_value = _calibration_with_fallback(
            y_boot, y_pred_prob
        )
        clf_brier_score = np.round(
            brier_score_loss(y_boot, y_pred_prob, pos_label=np.array(y_boot).max()), 2
        )

        print("Done analyzing calibration stats...")

        # store ingredients for a calibration curve
        cv_scores["validate_brier_score"].append(float(clf_brier_score))
        cv_scores["validate_fraction_pos"].append(list(fraction_of_positives))
        cv_scores["validate_mean_pred_value"].append(list(mean_predicted_value))

        # store per-patient outputs to run McNemar's test and Cochran's Q test
        pat_predictions, pat_true = combine_patient_predictions(
            y_boot, y_pred_prob, groups_boot
        )
        cv_scores["validate_pat_predictions"].append(pat_predictions)
        cv_scores["validate_pat_true"].append(pat_true)

        # store output for feature importances
        n_jobs = -1 if clf_type == "rf" else 1
        results = determine_feature_importances(
            estimator, X_boot, y_boot, n_jobs=n_jobs
        )
        cv_scores["validate_imp_mean"].append(list(results.importances_mean))
        cv_scores["validate_imp_std"].append(list(results.importances_std))

        print("Done analyzing feature importances...")

        # where a fitted estimator would be saved; creating its parent also
        # creates the directory for the JSON below
        clf_func_path = (
            study_path
            / "clf-train-vs-test"
            / "classifiers"
            / f"{clf_type}_classifiers_{feature_name}_{jdx}.npz"
        )
        clf_func_path.parent.mkdir(exist_ok=True, parents=True)

        nested_scores_fpath = (
            study_path
            / "clf-train-vs-test"
            / f"study_cv_scores_{clf_type}_{feature_name}_{jdx}.json"
        )

        # save all the master scores as a JSON file
        with open(nested_scores_fpath, "w") as fout:
            json.dump(cv_scores, fout, cls=NumpyEncoder)

        del estimator
        del scores

# Load Data and Run Classification Experiments

In [10]:
print(f"{study_path}")

/home/adam2392/hdd/Dropbox/epilepsy_bids/derivatives/study


In [13]:
feature_name = "fragility"

unformatted_X, y, subject_groups, dataset_params = load_data(
    feature_name,
    deriv_path,
    excel_fpath,
    patient_aggregation_method=None,
    intermed_fpath=intermed_fpath,
    save_cv_indices=False,
)

# Every sample needs a label and a subject for the grouped train/test splits.
assert len(unformatted_X) == len(y) == len(subject_groups), (
    f"misaligned dataset: {len(unformatted_X)} samples, {len(y)} labels, "
    f"{len(subject_groups)} subject entries"
)

Loading data from /Users/adam2392/Dropbox/epilepsy_bids/derivatives/baselinesliced
Got 94 subjects
Got  431  datasets.
Got  94  patients
dict_keys(['jh101', 'jh103', 'jh105', 'jh107', 'jh108', 'la00', 'la01', 'la02', 'la03', 'la04', 'la05', 'la06', 'la07', 'la08', 'la09', 'la10', 'la11', 'la12', 'la13', 'la15', 'la16', 'la17', 'la20', 'la21', 'la22', 'la23', 'la24', 'la27', 'la28', 'la29', 'la31', 'nl01', 'nl03', 'nl04', 'nl05', 'nl07', 'nl08', 'nl09', 'nl10', 'nl13', 'nl14', 'nl15', 'nl16', 'nl17', 'nl18', 'nl19', 'nl20', 'nl21', 'nl22', 'nl23', 'nl24', 'pt1', 'pt2', 'pt3', 'pt6', 'pt7', 'pt8', 'pt10', 'pt11', 'pt12', 'pt13', 'pt14', 'pt15', 'pt16', 'pt17', 'tvb1', 'tvb2', 'tvb5', 'tvb7', 'tvb8', 'tvb11', 'tvb12', 'tvb14', 'tvb17', 'tvb18', 'tvb19', 'tvb23', 'tvb27', 'tvb28', 'tvb29', 'umf001', 'umf002', 'umf003', 'umf004', 'umf005', 'ummc001', 'ummc002', 'ummc003', 'ummc004', 'ummc005', 'ummc006', 'ummc007', 'ummc008', 'ummc009'])
416 416 416 416 416


In [22]:
# `names` (label -> color) and `classifiers` are paired by position. Fail
# loudly rather than letting zip() silently drop unmatched entries.
assert len(names) == len(classifiers), (
    f"{len(names)} classifier names but {len(classifiers)} classifiers"
)

for clf_name, clf_func in zip(names, classifiers):
    run_clf_validation(
        clf_name,
        clf_func,
        unformatted_X,
        y,
        subject_groups,
        dataset_params,
        study_path,
        windows,
        thresholds,
        weighting_funcs,
    )

['jh101' 'jh101' 'jh101' 'jh101' 'jh105' 'jh105' 'jh105' 'jh105' 'jh105'
 'jh105' 'jh105' 'jh105' 'jh105' 'jh105' 'la00' 'la02' 'la05' 'la05'
 'la05' 'la05' 'la05' 'la05' 'la05' 'la05' 'la05' 'la05' 'la05' 'la05'
 'la05' 'la05' 'la05' 'la09' 'la09' 'la12' 'la12' 'la12' 'la12' 'la12'
 'la12' 'la12' 'la12' 'la12' 'la15' 'la17' 'la20' 'la20' 'la20' 'la20'
 'la20' 'la20' 'la20' 'la20' 'la21' 'la21' 'la21' 'la21' 'la21' 'la23'
 'la23' 'la23' 'la24' 'la24' 'la24' 'la27' 'la27' 'la27' 'la27' 'la27'
 'la27' 'la27' 'la27' 'la27' 'la27' 'la27' 'la27' 'la27' 'la27' 'la27'
 'la27' 'la27' 'la27' 'la27' 'la27' 'la27' 'la27' 'la27' 'la27' 'la27'
 'la28' 'la28' 'la28' 'la28' 'la28' 'la28' 'la28' 'la28' 'la28' 'la29'
 'la29' 'la29' 'la29' 'la29' 'nl01' 'nl01' 'nl01' 'nl04' 'nl04' 'nl04'
 'nl05' 'nl05' 'nl05' 'nl08' 'nl08' 'nl08' 'nl10' 'nl10' 'nl10' 'nl14'
 'nl14' 'nl14' 'nl14' 'nl14' 'nl14' 'nl15' 'nl15' 'nl15' 'nl17' 'nl17'
 'nl17' 'nl17' 'nl17' 'nl20' 'nl20' 'nl20' 'nl20' 'nl21' 'nl21' 'nl21'
 'nl21

Exception: hi