# 1 - Setup

In [None]:
%load_ext autoreload
%autoreload 2

from modules.model_selection import *
from modules.visualize import *
from modules.utils import *
from modules.feature_extraction import *
from modules.tracking import *

from scipy.stats import loguniform, uniform, randint
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
import shap
import pandas as pd
import numpy as np

import pickle
from pathlib import Path

# dataset location plus the directories this notebook writes to
DATA_DIR = Path('./dataset')
FIGURE_DIR = Path('./figures')
CACHE_DIR = Path('./cache')
# only the figure and cache directories are created here; DATA_DIR is not
for _out_dir in (FIGURE_DIR, CACHE_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)


In [None]:
# FLAGS: notebook-wide switches read by the cells below
ANIMATE_ALL = False # set to True to generate video visualizations for each sample
EVAL = True # set to True to train and evaluate models
REFIT = False # set to True to rerun hyperparameter tuning
WEIGHTED = False # set to True to weight same-patient samples during training

# 2 - Data Loading

In [None]:
# parse demographic and UPDRS data for each patient
# (DemographicData / UPDRSdata are not defined in this notebook; presumably
# they come from one of the star-imported project modules -- verify)
# NOTE(review): demographic_data is not referenced again in the cells visible here
demographic_data = DemographicData(DATA_DIR/'demographics.xlsx')
updrs_data = UPDRSdata(DATA_DIR/'motor_scales_all_2021.03.08.xlsx')

In [None]:
# load, clean, and subsegment gait recordings
pose_files = sorted(DATA_DIR.glob('**/*gait*csv'), key=natsort)
poses = []
ids = []
for f in pose_files:
    try:
        pose = PoseSeries(f, min_duration=3, max_duration=7, tolerance=5) # parameters for subsegmentation
        if pose.seg_cnt == 0:
            # recording produced no usable segment; leave it out entirely
            continue
        ids.append(pose.id)
        poses.append(pose)
    except Exception as e:
        # best effort: report the recording that failed and keep going
        print(f"Failed to build pose for {f}: {str(e)}.")
ids = np.unique(ids)
print(f"Built {len(poses)} pose series from {len(ids)} patients.")

if ANIMATE_ALL:
    for p in tqdm(poses, desc='Pose Animation'):
        p.animate(visualize=False, show_good_range=True, save_fig=True)

# Per-recording segment counts and per-segment durations. Each duration is
# computed with the dt of the recording the segment belongs to, rather than
# with the dt of whichever recording happened to be iterated last.
seg_cnts, all_segments, durations = [], [], []
for p in poses:
    segments = list(p.segments)
    seg_cnts.append(p.seg_cnt)
    all_segments.extend(segments)
    durations.extend((s[1] - s[0]) * p.dt for s in segments)

print(f'Number of total segments: {np.sum(seg_cnts)}')
print(
    f'Segments per Recording: {np.mean(seg_cnts):.2f} +/- {np.std(seg_cnts):.2f}')
print(
    f'Segment Durations: {np.mean(durations):.2f} +/- {np.std(durations):.2f} s')

# left: number of good segments per recording, right: segment durations
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
sns.histplot(seg_cnts, ax=ax1, discrete=True, lw=8)
sns.histplot(durations, ax=ax2, lw=8)
ax1.set_xlabel('# Segments')
ax1.set_title('Good Segments per Gait Recording')
ax2.set_xlabel('Duration [s]')
ax2.set_title('Segment Durations')

plt.show()


In [None]:
# load, clean, and subsegment finger tapping recordings
hand_files = sorted(DATA_DIR.glob('**/*fingertap*csv'), key=natsort)
hands = []
ids = []
for f in hand_files:
    try:
        hand = HandSeries(f, min_duration=3, max_duration=7, tolerance=5) # parameters for subsegmentation
        if hand.seg_cnt == 0:
            # recording produced no usable segment; leave it out entirely
            continue
        ids.append(hand.id)
        hands.append(hand)
    except Exception as e:
        # best effort: report the recording that failed and keep going
        print(f"Failed to build hand for {f}: {str(e)}.")
ids = np.unique(ids)
print(f"Built {len(hands)} hand series from {len(ids)} patients.")

if ANIMATE_ALL:
    # animate every recording, as the ANIMATE_ALL flag describes (a stray
    # `break` previously stopped this loop after the first recording)
    for h in tqdm(hands, desc='Fingertap Animation'):
        h.animate(visualize=False, show_good_range=True, save_fig=True)

# Per-recording segment counts and per-segment durations. Each duration is
# computed with the dt of the recording the segment belongs to, rather than
# with the dt of whichever recording happened to be iterated last.
seg_cnts, all_segments, durations = [], [], []
for h in hands:
    segments = list(h.segments)
    seg_cnts.append(h.seg_cnt)
    all_segments.extend(segments)
    durations.extend((s[1] - s[0]) * h.dt for s in segments)

print(f'Number of total segments: {np.sum(seg_cnts)}')
print(
    f'Segments per Recording: {np.mean(seg_cnts):.2f} +/- {np.std(seg_cnts):.2f}')
print(
    f'Segment Durations: {np.mean(durations):.2f} +/- {np.std(durations):.2f} s')

# left: number of good segments per recording, right: segment durations
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
sns.histplot(seg_cnts, ax=ax1, discrete=True, lw=8)
sns.histplot(durations, ax=ax2, lw=8)
ax1.set_xlabel('# Segments')
ax1.set_title('Good Segments per Hand Recording')
ax2.set_xlabel('Duration [s]')
ax2.set_title('Segment Durations')

plt.show()


# 3 - Feature Extraction

In [None]:
# generate gait/body features
# Joint specifications consumed by the gait feature-extraction loop below.
# The exact meaning of the numeric/boolean fields is defined by the project
# helpers that receive them -- TODO confirm against modules.feature_extraction.

# angle specs: the first four fields are passed positionally to p.get_angle,
# the last field is passed to get_all_features as its fourth argument
angle_joints = [
    ('nose', 'middle_shoulder', 'left_shoulder', 90, True),
    ('left_elbow', 'left_shoulder', 'left_hip', 0, False),
    ('right_elbow', 'right_shoulder', 'right_hip', 0, False),
]
# distance specs: the first two fields go to p.get_distance, the last field to
# get_all_features; entries alternate left/right and are consumed in pairs
distance_joints = [
    ('left_wrist', 'left_shoulder', False),
    ('right_wrist', 'right_shoulder', False),
    ('left_ankle', 'left_hip', False),
    ('right_ankle', 'right_hip', False),
]
# x-displacement specs: the first two fields go to p.get_x_displacement, the
# last field to get_all_features
x_displacement_joints = [
    ('left_ankle', 'right_ankle', False),
    ('left_knee', 'right_knee', False),
    ('middle_shoulder', 'middle_hip', True),
]
# y-displacement specs: the first two fields go to p.get_y_displacement, the
# last field to get_all_features
y_displacement_joints = [
    ('left_ankle', 'right_ankle', False),
    ('left_shoulder', 'right_shoulder', False),
    ('left_hip', 'right_hip', False),
]


def min_max_diff(f1, f2, c):
    """Merge two parallel feature vectors into interleaved max/min features.

    For every position the element-wise maximum and minimum of ``f1`` and
    ``f2`` are taken and emitted next to each other (max first, then min).
    The column names in ``c`` are expanded the same way, with ``'_max'`` and
    ``'_min'`` suffixes.

    Returns:
        (features, names): two lists, each twice as long as its input.
    """
    upper = np.maximum(f1, f2).tolist()
    lower = np.minimum(f1, f2).tolist()

    # interleave as max_0, min_0, max_1, min_1, ...
    merged = [value for pair in zip(upper, lower) for value in pair]
    names = [name + suffix for name in c for suffix in ('_max', '_min')]

    return merged, names


# Build one feature row and one score row per good gait segment and cache
# both tables as CSV.
all_feats, all_scores = [], []
feat_cols, score_cols = None, None

# include temp measures to resolve issue with unilateral features
# (left/right feature vectors are merged by min_max_diff into element-wise
# max/min values instead of being kept per side)
for i, p in enumerate(tqdm(poses)):
    info = [p.id]
    # scores are looked up once per recording and repeated for each segment
    scores = updrs_data.get_all_scores(p.id, p.on_med).to_numpy().ravel()
    for j in range(p.seg_cnt):
        feats, col_names = [], []

        # disabled variant: one feature set per angle spec, without merging sides
        # for joints in angle_joints:
        #     base_name = f'{{{joints[0]}, {joints[1]}, {joints[2]}}} (angle)'
        #     angle = p.get_angle(*joints[0:4])[j]
        #     f, c = get_all_features(angle, base_name, p.dt, joints[-1])
        #     feats.extend(f)
        #     col_names.extend(c)

        # the first angle spec has no left/right counterpart; use it as-is
        for joints in angle_joints[0:1]:
            base_name = f'{{{joints[0]}, {joints[1]}, {joints[2]}}} (angle)'
            angle = p.get_angle(*joints[0:4])[j]
            f, c = get_all_features(angle, base_name, p.dt, joints[-1])
            feats.extend(f)
            col_names.extend(c)

        # remaining angle specs come in left/right pairs; the side prefix is
        # stripped from the column name and both sides are merged to max/min
        for j1, j2 in zip(angle_joints[1::2], angle_joints[2::2]):
            base_name = f'{{{j1[0].split("_")[-1]}, {j1[1].split("_")[-1]}, {j1[2].split("_")[-1]}}} (angle)'
            f1, c = get_all_features(p.get_angle(
                *j1[0:4])[j], base_name, p.dt, j1[-1])
            f2, _ = get_all_features(p.get_angle(
                *j2[0:4])[j], base_name, p.dt, j2[-1])
            f, c = min_max_diff(f1, f2, c)
            feats.extend(f)
            col_names.extend(c)

        # disabled variant: one feature set per distance spec, without merging sides
        # for joints in distance_joints:
        #     base_name = f'{{{joints[0]}, {joints[1]}, {joints[2]}}} (distance)'
        #     distance = p.get_distance(*joints[0:2])[j]
        #     f, c = get_all_features(distance, base_name, p.dt, joints[-1])
        #     feats.extend(f)
        #     col_names.extend(c)

        # distance specs alternate left/right; merge each pair to max/min
        for j1, j2 in zip(distance_joints[::2], distance_joints[1::2]):
            base_name = f'{{{j1[0].split("_")[-1]}, {j1[1].split("_")[-1]}}} (distance)'
            f1, c = get_all_features(p.get_distance(
                *j1[0:2])[j], base_name, p.dt, j1[-1])
            f2, _ = get_all_features(p.get_distance(
                *j2[0:2])[j], base_name, p.dt, j2[-1])
            f, c = min_max_diff(f1, f2, c)
            feats.extend(f)
            col_names.extend(c)

        # displacement specs already relate a left joint to a right joint (or
        # two midline joints), so they are used without max/min merging
        for joints in x_displacement_joints:
            base_name = f'{{{joints[0]}, {joints[1]}}} (x_displacement)'
            x_displacement = p.get_x_displacement(*joints[0:2])[j]
            f, c = get_all_features(
                x_displacement, base_name, p.dt, joints[-1])
            feats.extend(f)
            col_names.extend(c)

        for joints in y_displacement_joints:
            base_name = f'{{{joints[0]}, {joints[1]}}} (y_displacement)'
            y_displacement = p.get_y_displacement(*joints[0:2])[j]
            f, c = get_all_features(
                y_displacement, base_name, p.dt, joints[-1])
            feats.extend(f)
            col_names.extend(c)

        # prefix every row with the patient id so it can become the index
        all_feats.append(np.concatenate([info, feats]))
        all_scores.append(np.concatenate([info, scores]))
        # column names are rebuilt on every segment; the last assignment wins
        # NOTE(review): if `poses` is empty these stay None and the
        # set_index('pidn') calls below will fail
        feat_cols = np.concatenate([['pidn'], col_names])
        score_cols = np.concatenate([['pidn'], UPDRSdata.score_names])

# one row per segment, indexed by patient id (ids repeat across segments)
feat_df = pd.DataFrame(all_feats, columns=feat_cols).set_index('pidn')
score_df = pd.DataFrame(all_scores, columns=score_cols).set_index('pidn')
feat_df.to_csv(CACHE_DIR/'gait_features.csv')
score_df.to_csv(CACHE_DIR/'gait_scores.csv')


In [None]:
# generate finger tapping/hand features
# distance specs alternate right/left for the same joint pair; the loop below
# consumes them two at a time and keeps only one side per segment
distance_joints = [
    ('right_thumb', 'right_index'),
    ('left_thumb', 'left_index'),

    ('right_middle', 'right_wrist'),
    ('left_middle', 'left_wrist'),

    ('right_ring', 'right_wrist'),
    ('left_ring', 'left_wrist'),

    ('right_pinky', 'right_wrist'),
    ('left_pinky', 'left_wrist'),
]

all_feats, all_scores = [], []
feat_cols, score_cols = None, None
# NOTE(review): this list is never used; is_right is rebound to a boolean
# for every segment inside the loop
is_right = []

for i, h in enumerate(tqdm(hands)):
    info = [h.id]
    # scores are looked up once per recording and repeated for each segment
    scores = updrs_data.get_all_scores(h.id, h.on_med).to_numpy().ravel()
    for j in range(h.seg_cnt):
        feats, col_names = [], []

        # pick the side whose wrist has the smaller mean y-coordinate, i.e. the
        # wrist that is higher in the frame -- presumably the hand performing
        # the task; TODO confirm
        wrist_r = h.get_joint('right_wrist', raw=True)[j][:, 1]
        wrist_l = h.get_joint('left_wrist', raw=True)[j][:, 1]
        is_right = wrist_r.mean() < wrist_l.mean()  # note: for y-coordinate, top is 0

        for (joints_r, joints_l) in zip(distance_joints[::2], distance_joints[1::2]):
            distance_r = h.get_distance(*joints_r)[j]
            distance_l = h.get_distance(*joints_l)[j]
            distance = distance_r if is_right else distance_l

            # column name drops the side prefix so both sides share one column
            base_name = f"{{{joints_r[0].split('_')[-1]}, {joints_r[1].split('_')[-1]}}} (distance)"
            f, c = get_all_features(distance, base_name, h.dt)
            feats.extend(f)
            col_names.extend(c)

        # prefix every row with the patient id so it can become the index
        all_feats.append(np.concatenate([info, feats]))
        all_scores.append(np.concatenate([info, scores]))
        # column names are rebuilt on every segment; the last assignment wins
        feat_cols = np.concatenate([['pidn'], col_names])
        score_cols = np.concatenate([['pidn'], UPDRSdata.score_names])

# one row per segment, indexed by patient id (ids repeat across segments)
feat_df = pd.DataFrame(all_feats, columns=feat_cols).set_index('pidn')
score_df = pd.DataFrame(all_scores, columns=score_cols).set_index('pidn')
feat_df.to_csv(CACHE_DIR/'fingertap_features.csv')
score_df.to_csv(CACHE_DIR/'fingertap_scores.csv')


# 4 - Data Analysis

In [None]:
# generate design matrix for gait/body feature set
gait_features = pd.read_csv(CACHE_DIR/'gait_features.csv', index_col=0)
gait_scores = pd.read_csv(CACHE_DIR/'gait_scores.csv', index_col=0)

# per-row sample id; the pidn index repeats because there is one row per segment
gait_features['id'] = np.arange(len(gait_features))
gait_scores['id'] = np.arange(len(gait_scores))

# keep only rows where both the scores and the features are complete
mask = ~gait_scores.isna().any(axis=1)
mask &= ~gait_features.isna().any(axis=1)
gait_features = gait_features[mask]
gait_scores = gait_scores[mask]

# binarize the total score: 0 (mild) if <= 32, 1 (severe) otherwise.
# Work on an explicit copy so the assignment never targets a slice of
# gait_scores (avoids pandas' SettingWithCopyWarning).
gait_labels = gait_scores[['3.total', 'id']].copy()
gait_labels['3.total'] = (gait_labels['3.total'] > 32).astype(int)
gait_labels = gait_labels.rename(columns={'3.total': 'label'}).astype(int)

print(f'Sampling from {len(gait_features.index.unique())} valid patient entries:')
print(
    f'    Feature Matrix: {gait_features.shape[0]} x {gait_features.shape[1]}')
print()
# .get(..., 0) reports 0 instead of raising when a class is absent
label_counts = gait_labels['label'].value_counts()
print(f'Total Score Labels [Cutoff = 32]:')
print(f'    # Mild: {label_counts.get(0, 0)}')
print(f'    # Severe: {label_counts.get(1, 0)}')


In [None]:
# generate design matrix for finger tapping/hand feature set
ft_features = pd.read_csv(CACHE_DIR/'fingertap_features.csv', index_col=0)
ft_scores = pd.read_csv(CACHE_DIR/'fingertap_scores.csv', index_col=0)

# per-row sample id; the pidn index repeats because there is one row per segment
ft_features['id'] = np.arange(len(ft_features))
ft_scores['id'] = np.arange(len(ft_scores))

# keep only rows where both the scores and the features are complete
mask = ~ft_scores.isna().any(axis=1)
mask &= ~ft_features.isna().any(axis=1)
ft_features = ft_features[mask]
ft_scores = ft_scores[mask]

# binarize the total score: 0 (mild) if <= 32, 1 (severe) otherwise.
# Work on an explicit copy so the assignment never targets a slice of
# ft_scores (avoids pandas' SettingWithCopyWarning).
ft_labels = ft_scores[['3.total', 'id']].copy()
ft_labels['3.total'] = (ft_labels['3.total'] > 32).astype(int)
ft_labels = ft_labels.rename(columns={'3.total': 'label'}).astype(int)

print(f'Sampling from {len(ft_features.index.unique())} valid patient entries:')
print(f'    Feature Matrix: {ft_features.shape[0]} x {ft_features.shape[1]}')
print()
# .get(..., 0) reports 0 instead of raising when a class is absent
label_counts = ft_labels['label'].value_counts()
print(f'Total Score Labels [Cutoff = 32]:')
print(f'    # Mild: {label_counts.get(0, 0)}')
print(f'    # Severe: {label_counts.get(1, 0)}')


In [None]:
# pairwise combine body and hand features to generate design matrix for combined/integrated feature set
# Every gait segment of a patient is paired with every finger tapping segment
# of the same patient; 'id' is the gait sample id, 'id2' the hand sample id.
# NOTE(review): the merge key is pidn only, so segments recorded under
# different medication states may be paired; labels below come from the gait
# recording's scores.
int_features = gait_features.reset_index().merge(
    ft_features.reset_index(), on='pidn')
int_features = int_features.set_index('pidn')
int_features = int_features.rename(columns={'id_x': 'id', 'id_y': 'id2'})

# Look up the gait score row for every combined row through the gait sample
# id (unique per gait row). A single indexed lookup replaces the per-row
# filter-and-concat loop, which also shadowed the builtin `id`.
gait_scores_by_id = gait_scores.reset_index().set_index('id', drop=False)
int_scores = gait_scores_by_id.loc[int_features['id'].to_numpy()].set_index('pidn')

# binarize the total score: 0 (mild) if <= 32, 1 (severe) otherwise
int_labels = int_scores[['3.total', 'id']].copy()
int_labels['3.total'] = (int_labels['3.total'] > 32).astype(int)
int_labels = int_labels.rename(columns={'3.total': 'label'}).astype(int)

print(f'Sampling from {len(int_features.index.unique())} valid patient entries:')
print(
    f'    Feature Matrix: {int_features.shape[0]} x {int_features.shape[1]}')
print()

# counts are per combined row (one per gait/hand segment pair);
# .get(..., 0) reports 0 instead of raising when a class is absent
label_counts = int_labels['label'].value_counts()
print(f'Total Score Labels [Cutoff = 32]:')
print(f'    # Mild: {label_counts.get(0, 0)}')
print(f'    # Severe: {label_counts.get(1, 0)}')
print()


### Evaluate Models

In [None]:
# models to evaluate
# (base estimators used when REFIT is True; the evaluation cells replace this
# dict with fixed hyperparameters when REFIT is False)
models = {
    'SVM': LinearSVC(dual=False, fit_intercept=False, max_iter=5000),
    'LR': LogisticRegression(penalty='elasticnet', solver='saga', fit_intercept=False, max_iter=5000),
    'LDA': LinearDiscriminantAnalysis(solver='lsqr'),
    'RF': RandomForestClassifier(),
    'AB': AdaBoostClassifier(),
    'KNN': KNeighborsClassifier(),
    'GNB': GaussianNB(),
}
# search space for models
# (keys match `models`; values mix fixed candidate lists with scipy.stats
# distributions, and a list of dicts gives alternative sub-spaces -- the
# format is presumably what a randomized search inside eval_all expects;
# TODO confirm)
params = {
    'SVM': {'penalty': ['l1', 'l2'],
            'C': loguniform(1e-3, 1e0)},
    'LR': {'C': loguniform(1e-3, 1e0),
           'l1_ratio': uniform(0, 1)},
    'LDA': [{'shrinkage': [None, 'auto']},
            {'shrinkage': uniform(0, 1)}],
    'RF': [{'max_depth': randint(3, 12),
            'min_samples_split': randint(2, 10),
            'min_samples_leaf': randint(1, 10)},
           {'max_depth': [None],
            'min_samples_split': randint(2, 10),
            'min_samples_leaf': randint(1, 10)}],
    'AB': {'learning_rate': loguniform(1e-3, 1e0)},
    'KNN': {'n_neighbors': randint(2, 30)},
    'GNB': {},
}
# no preset alphas when tuning; the evaluation cells set a list when REFIT is False
alphas = None


In [None]:
# Evaluate all models on the gait/body feature set.
# With REFIT disabled the models use fixed hyperparameters, the search space
# is empty for every model, and a preset list of alphas is passed through.
if not REFIT:
    # `defaultdict` is not imported explicitly by this notebook; import it
    # here so the cell does not depend on a star import providing the name
    from collections import defaultdict

    models = {
        'SVM': LinearSVC(penalty='l2',
                         C=0.01,
                         dual=False,
                         fit_intercept=False,
                         max_iter=5000,
                         class_weight='balanced'),
        'LR': LogisticRegression(penalty='elasticnet',
                                 l1_ratio=0.2,
                                 C=1,
                                 solver='saga',
                                 fit_intercept=False,
                                 max_iter=5000,
                                 class_weight='balanced'),
        'LDA': LinearDiscriminantAnalysis(shrinkage='auto', solver='lsqr'),
        'RF': RandomForestClassifier(n_estimators=100,
                                     max_depth=8,
                                     min_samples_split=2,
                                     min_samples_leaf=1,
                                     class_weight='balanced'),
        'AB': AdaBoostClassifier(learning_rate=0.1),
        'KNN': KNeighborsClassifier(n_neighbors=9),
        'GNB': GaussianNB(),
    }
    # any model name maps to an empty search space
    params = defaultdict(dict)
    # presumably regularization strengths kept from an earlier tuning run --
    # TODO confirm how eval_all consumes them
    alphas = [0.030774356208980617, 0.024757157946593427, 0.02141453563853955, 0.018523222156741202, 0.008341801332506338,
              0.005398628767382501, 0.011149209970080915, 0.02862153445389273, 0.007215521357901021, 0.02862153445389273,
              0.009643883791544459, 0.008969245378672715, 0.011987818459583773, 0.011149209970080915, 0.03308910644196496,
              0.030774356208980617, 0.02141453563853955, 0.007215521357901021, 0.01490144370527747, 0.007758250168566794,
              0.03308910644196496, 0.011987818459583773, 0.008969245378672715]

if EVAL:
    # NOTE(review): n_repeats=1 here, while the finger tapping and integrated
    # evaluations use n_repeats=32 -- confirm this is intentional
    results = eval_all(
        models, params, gait_features, gait_labels, n_repeats=1, alphas=alphas, weighted=WEIGHTED)
    # cache the results for the visualization cells
    with open(CACHE_DIR/'gait.dat', 'wb') as f:
        pickle.dump(results, f)


In [None]:
# Evaluate all models on the finger tapping/hand feature set.
# With REFIT disabled the models use fixed hyperparameters, the search space
# is empty for every model, and a preset list of alphas is passed through.
if not REFIT:
    # `defaultdict` is not imported explicitly by this notebook; import it
    # here so the cell does not depend on a star import providing the name
    from collections import defaultdict

    models = {
        'SVM': LinearSVC(penalty='l2',
                         C=0.001,
                         dual=False,
                         fit_intercept=False,
                         max_iter=5000,
                         class_weight='balanced'),
        'LR': LogisticRegression(penalty='l2',
                                 C=0.0001,
                                 dual=False,
                                 solver='liblinear',
                                 fit_intercept=False,
                                 max_iter=5000,
                                 class_weight='balanced'),
        'LDA': LinearDiscriminantAnalysis(shrinkage='auto', solver='lsqr'),
        'RF': RandomForestClassifier(n_estimators=100,
                                     max_depth=5,
                                     min_samples_split=2,
                                     min_samples_leaf=1,
                                     class_weight='balanced'),
        'AB': AdaBoostClassifier(learning_rate=0.05),
        'KNN': KNeighborsClassifier(n_neighbors=5),
        'GNB': GaussianNB(),
    }
    # any model name maps to an empty search space
    params = defaultdict(dict)
    # presumably regularization strengths kept from an earlier tuning run --
    # TODO confirm how eval_all consumes them
    alphas = [0.10558981944335787, 0.04422514763163044, 0.06833538873292302, 0.10558981944335787, 0.02862153445389273,
              0.09820327790631257, 0.05112830759482943, 0.084944224982638, 0.09133346228248625, 0.084944224982638,
              0.04755162406834014, 0.09820327790631257, 0.026619313461261316, 0.09133346228248625, 0.059108990642279875,
              0.03825402746632875, 0.10558981944335787, 0.01602228340877477, 0.04422514763163044, 0.09133346228248625,
              0.03308910644196496, 0.084944224982638, 0.07900194712408971, 0.06355498291362295, 0.03557796490339495]

if EVAL:
    results = eval_all(models, params, ft_features, ft_labels, n_repeats=32, alphas=alphas, weighted=WEIGHTED)
    # cache the results for the visualization cells
    with open(CACHE_DIR/'fingertap.dat', 'wb') as f:
        pickle.dump(results, f)


In [None]:
# Evaluate all models on the combined/integrated feature set.
# With REFIT disabled the models use fixed hyperparameters, the search space
# is empty for every model, and a preset list of alphas is passed through.
if not REFIT:
    # `defaultdict` is not imported explicitly by this notebook; import it
    # here so the cell does not depend on a star import providing the name
    from collections import defaultdict

    models = {
        'SVM': LinearSVC(penalty='l1',
                         C=0.005,
                         dual=False,
                         fit_intercept=False,
                         max_iter=5000,
                         class_weight='balanced'),
        'LR': LogisticRegression(penalty='elasticnet',
                                 l1_ratio=0.7,
                                 C=0.01,
                                 dual=False,
                                 solver='saga',
                                 fit_intercept=False,
                                 max_iter=5000,
                                 class_weight='balanced'),
        'LDA': LinearDiscriminantAnalysis(shrinkage=0.5, solver='lsqr'),
        'RF': RandomForestClassifier(n_estimators=100,
                                     max_depth=3,
                                     min_samples_split=2,
                                     min_samples_leaf=1,
                                     class_weight='balanced'),
        'AB': AdaBoostClassifier(learning_rate=0.5),
        'KNN': KNeighborsClassifier(n_neighbors=20),
        'GNB': GaussianNB(),
    }
    # any model name maps to an empty search space
    params = defaultdict(dict)
    # presumably regularization strengths kept from an earlier tuning run --
    # TODO confirm how eval_all consumes them
    alphas = [0.03825402746632875, 0.07347536163492151, 0.0411313750341857, 0.04422514763163044, 0.002810728502080728,
              0.07900194712408971, 0.059108990642279875, 0.06833538873292302, 0.059108990642279875, 0.05112830759482943,
              0.06355498291362295, 0.06833538873292302, 0.05112830759482943, 0.059108990642279875, 0.07900194712408971,
              0.06833538873292302, 0.06355498291362295, 0.0411313750341857, 0.002102977594546134, 0.07900194712408971]

if EVAL:
    results = eval_all(models, params, int_features, int_labels, n_repeats=32, alphas=alphas, weighted=WEIGHTED)
    # cache the results for the visualization cells
    with open(CACHE_DIR/'integrated.dat', 'wb') as f:
        pickle.dump(results, f)


### Visualize

In [None]:
# visualize LASSO results
def get_count(results):
    """Return the size of every selection stored under ``results['selected']``.

    ``results['selected']`` is a mapping whose values are sequences of
    selections; the selections of all keys are flattened in mapping order
    and the length of each one is returned as a NumPy array.
    """
    sizes = [
        len(selection)
        for group in results['selected'].values()
        for selection in group
    ]
    return np.array(sizes)


# Load the cached evaluation results through CACHE_DIR, the same location the
# evaluation cells write to, instead of a second hard-coded path.
# NOTE: these pickles are this notebook's own cache; do not point CACHE_DIR
# at files from an untrusted source.
with open(CACHE_DIR/'gait.dat', 'rb') as f:
    a = pickle.load(f)
with open(CACHE_DIR/'fingertap.dat', 'rb') as f:
    b = pickle.load(f)
with open(CACHE_DIR/'integrated.dat', 'rb') as f:
    c = pickle.load(f)

# number of selected features per stored selection, for each feature set
a, b, c = get_count(a), get_count(b), get_count(c)

# overlay the three distributions; labels make the histograms tell-apart-able
fig = plt.figure(figsize=(10, 2))
ax = plt.gca()
sns.histplot(data=a, stat='proportion', ax=ax, binwidth=5, label='Body')
sns.histplot(data=b, stat='proportion', ax=ax, binwidth=5, label='Hand')
sns.histplot(data=c, stat='proportion', ax=ax, binwidth=5, label='Combined')
ax.legend()

plt.savefig(FIGURE_DIR/'lasso.png')

print(f'Body: mean={a.mean():.2f}, std={a.std():.2f}')
print(f'Hand: mean={b.mean():.2f}, std={b.std():.2f}')
print(f'Combined: mean={c.mean():.2f}, std={c.std():.2f}')


In [None]:
# visualize gait dataset CV results
# (read from CACHE_DIR, the location the evaluation cell writes to)
with open(CACHE_DIR/'gait.dat', 'rb') as f:
    results = pickle.load(f)
gait_selected = VisualizeAll(
    gait_features, gait_labels, results, out_dir=FIGURE_DIR/'g')


In [None]:
# visualize finger tapping dataset CV results
# (read from CACHE_DIR, the location the evaluation cell writes to)
with open(CACHE_DIR/'fingertap.dat', 'rb') as f:
    results = pickle.load(f)
ft_selected = VisualizeAll(ft_features, ft_labels,
                           results, out_dir=FIGURE_DIR/'ft')


In [None]:
# visualize combined dataset CV results
# (read from CACHE_DIR, the location the evaluation cell writes to)
with open(CACHE_DIR/'integrated.dat', 'rb') as f:
    results = pickle.load(f)
int_selected = VisualizeAll(
    int_features, int_labels, results, out_dir=FIGURE_DIR/'int')


In [None]:
# SHAP analysis w/ best model
# restrict the analysis to the features reported for the integrated set
selected = int_selected['feat'].to_numpy()

# Elastic-net logistic regression. scikit-learn needs an explicit l1_ratio
# when penalty='elasticnet'; use the value fixed for the integrated feature
# set in the evaluation cell.
# NOTE(review): C is left at its default, whereas the integrated evaluation
# uses C=0.01 -- confirm which one the SHAP analysis should use.
clf = LogisticRegression(penalty='elasticnet',
                         l1_ratio=0.7,
                         dual=False,
                         solver='saga',
                         fit_intercept=False,
                         max_iter=5000,
                         class_weight='balanced')

shap_values, X_test_array = model_shap(clf, int_features, int_labels, selected=selected)
# per-sample summary plot first, then the bar-type summary plot
shap.summary_plot(shap_values, X_test_array, feature_names=selected, plot_size=(15, 10))
shap.summary_plot(shap_values, X_test_array, feature_names=selected, plot_size=(15, 10), plot_type="bar")