# SHAP feature selection #
## Code to select features with a combination of LightGBM and SHAP ##

## Parameters cell ##

Parameters are overridden by papermill when the notebook is run inside DVC stages.



In [1]:
number_of_bootstraps = 5  # global setting: how many bootstraps to use

# Parameters handed to LightGBM (via ModelFactory further down in the notebook).
lgb_params = {
    'boosting_type': 'gbdt',
    'objective': 'regression',
    # A list instead of a set literal: a set has no defined order and cannot be
    # expressed in JSON/YAML, so a papermill override could never have the same
    # type as this default.
    'metric': ['l2', 'l1'],
    'max_leaves': 20,
    'max_depth': 3,
    'learning_rate': 0.07,
    'feature_fraction': 0.8,
    'bagging_fraction': 1,
    'min_data_in_leaf': 6,
    'lambda_l1': 0.9,
    'lambda_l2': 0.9,
    "verbose": -1
}

debug_local = True  # use the local (in-repository) version of yspecies when available

In [2]:
from pathlib import Path
import sys
import inspect

local = (Path("..") / "yspecies").resolve()
if debug_local and local.exists():
  sys.path.insert(0, Path("..").as_posix())
  #sys.path.insert(0, local.as_posix())
  print("extending pathes with local yspecies")
  print(sys.path)
  %load_ext autoreload
  %autoreload 2

extending pathes with local yspecies
['..', '/data/sources/yspecies/notebooks', '/opt/miniconda3/envs/yspecies/lib/python38.zip', '/opt/miniconda3/envs/yspecies/lib/python3.8', '/opt/miniconda3/envs/yspecies/lib/python3.8/lib-dynload', '', '/opt/miniconda3/envs/yspecies/lib/python3.8/site-packages', '/opt/miniconda3/envs/yspecies/lib/python3.8/site-packages/IPython/extensions', '/home/antonkulaga/.ipython']


In [3]:
from typing import *
from yspecies.dataset import *
from yspecies.utils import *
from yspecies.workflow import *
from yspecies.partition import *
from yspecies.selection import *

In [4]:
from dataclasses import dataclass
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [5]:
import pandas as pd
import shap
from pprint import pprint
import random
import numpy as np
import lightgbm as lgb
from scipy.stats import kendalltau
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.utils import resample
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error, accuracy_score, recall_score, precision_score, f1_score

In [6]:
# Display settings
pd.set_option('display.max_columns', None)  # None disables the column display limit
pd.set_option('display.max_rows', None)     # None disables the row display limit
pd.set_option('display.float_format', lambda x: '%.3f' % x)  # render floats with 3 decimals

# Import only the class: a bare `import pprint` here would rebind the name
# `pprint` from the function imported earlier in the notebook to the module,
# making any later `pprint(...)` call fail with "'module' object is not callable".
from pprint import PrettyPrinter
pp = PrettyPrinter(indent=4)

### Loading data ###
Let's load the data for the species/genes/expressions selected by the select_samples.py notebook.

In [7]:
from pathlib import Path
# Use the current directory as the project root when it contains ./data;
# otherwise fall back to the parent directory.
locations: Locations
if Path("./data").exists():
    locations = Locations("./")
else:
    locations = Locations("../")

In [8]:
# Build the ExpressionDataset from the folder referenced by locations.interim.selected.
data = ExpressionDataset.from_folder(locations.interim.selected)
data  # bare last expression so the notebook renders the dataset summary

expressions,genes,species,samples,Genes Metadata,Species Metadata
"(445, 18329)",18329,39,445,"(18329, 2)","(40, 18)"


## Setting up SHAP selection pipeline ##

Deciding on selection parameters (which fields to include, exclude, predict)

In [9]:
from sklearn.pipeline import Pipeline

# Declare which metadata fields accompany the expressions and which column is the target.
selection = FeatureSelection(
    samples=["tissue", "species"],      # samples metadata to include
    species=[],                         # species metadata other than the Y label to include
    exclude_from_training=["species"],  # fields to exclude from LightGBM training
    to_predict="lifespan",              # column to predict
    categorical=["tissue"],
)


Setting up the pipeline

In [10]:
# Three-stage pipeline: extract -> partition -> train LightGBM and select features.
pipe = Pipeline(steps=[
    # extracts the data required for ML from the dataset
    ("extractor", DataExtractor(selection)),
    # partitions it according to sorted stratification
    ("partitioner", DataPartitioner(species_in_validation=2, not_validated_species=["Homo_sapiens"])),
    # trains LightGBM and does the feature selection
    ("shap_computation", ShapSelector(ModelFactory(parameters=lgb_params))),
])

Training and fitting on GeneExpression data

In [111]:
# Run every pipeline stage (extractor, partitioner, shap_computation) on the dataset.
results = pipe.fit_transform(data)
results  # bare last expression so the notebook renders the returned result

Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[78]	valid_0's l1: 2.25003	valid_0's l2: 42.6584
Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[65]	valid_0's l1: 2.63869	valid_0's l2: 38.4333
Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[51]	valid_0's l1: 2.16113	valid_0's l2: 18.1163
Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[70]	valid_0's l1: 3.57256	valid_0's l2: 90.7147
Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[19]	valid_0's l1: 9.27652	valid_0's l2: 193.914


Setting feature_perturbation = "tree_path_dependent" because no background data was given.
Setting feature_perturbation = "tree_path_dependent" because no background data was given.
Setting feature_perturbation = "tree_path_dependent" because no background data was given.
Setting feature_perturbation = "tree_path_dependent" because no background data was given.
Setting feature_perturbation = "tree_path_dependent" because no background data was given.


MEAN metrics = R^2    0.862
MSE   76.767
MAE    3.980
dtype: float64


Unnamed: 0_level_0,symbol,gain_score_to_lifespan,kendall_tau_to_lifespan,shap_sum_fold_0,shap_sum_fold_1,shap_sum_fold_2,shap_sum_fold_3,shap_sum_fold_4
ensembl_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
Unnamed: 0_level_2,R^2,MSE,MAE,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2
ENSG00000116251,RPL22,535461.879,0.026,-1.175,25.896,-77.771,-48.37,2.43
ENSG00000182534,MXRA7,172043.799,-0.201,6.053,2.588,-11.843,-14.597,1.438
ENSG00000211693,TRGV11,165870.544,-0.584,-127.627,-158.144,-157.94,-123.302,-91.296
ENSG00000187912,CLEC17A,57025.545,0.078,-41.171,-4.137,59.414,12.91,-46.419
ENSG00000150687,PRSS23,54054.657,0.42,94.193,25.966,57.861,114.401,356.593
ENSG00000096654,ZNF184,53416.88,0.285,9.719,6.437,86.137,52.035,-1.851
ENSG00000104835,SARS2,33166.771,0.022,-32.044,-33.978,-38.036,-34.412,-22.932
ENSG00000175854,SWI5,10384.018,-0.596,24.851,-3.198,46.269,43.501,9.027
ENSG00000165501,LRR1,1349.369,0.628,4.478,0.392,13.354,6.59,6.178
0,0.932,42.658,2.25,,,,,

Unnamed: 0_level_0,symbol,gain_score_to_lifespan,kendall_tau_to_lifespan,shap_sum_fold_0,shap_sum_fold_1,shap_sum_fold_2,shap_sum_fold_3,shap_sum_fold_4
ensembl_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
ENSG00000116251,RPL22,535461.879,0.026,-1.175,25.896,-77.771,-48.37,2.43
ENSG00000182534,MXRA7,172043.799,-0.201,6.053,2.588,-11.843,-14.597,1.438
ENSG00000211693,TRGV11,165870.544,-0.584,-127.627,-158.144,-157.94,-123.302,-91.296
ENSG00000187912,CLEC17A,57025.545,0.078,-41.171,-4.137,59.414,12.91,-46.419
ENSG00000150687,PRSS23,54054.657,0.42,94.193,25.966,57.861,114.401,356.593
ENSG00000096654,ZNF184,53416.88,0.285,9.719,6.437,86.137,52.035,-1.851
ENSG00000104835,SARS2,33166.771,0.022,-32.044,-33.978,-38.036,-34.412,-22.932
ENSG00000175854,SWI5,10384.018,-0.596,24.851,-3.198,46.269,43.501,9.027
ENSG00000165501,LRR1,1349.369,0.628,4.478,0.392,13.354,6.59,6.178

Unnamed: 0,R^2,MSE,MAE
0,0.932,42.658,2.25
1,0.938,38.433,2.639
2,0.968,18.116,2.161
3,0.863,90.715,3.573
4,0.61,193.914,9.277


# PROBLEM #

## One more re-training and we have a totally different result ##

In [112]:
# Re-fit the same pipeline on the same data to compare against the previous run.
# NOTE(review): no random seed is set in the visible cells — confirm whether the
# run-to-run differences come from unseeded randomness inside the pipeline stages.
pipe.fit_transform(data)

Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[12]	valid_0's l1: 10.6719	valid_0's l2: 235.503
Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[84]	valid_0's l1: 2.85972	valid_0's l2: 37.6545
Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[119]	valid_0's l1: 1.42366	valid_0's l2: 18.5336
Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[63]	valid_0's l1: 3.38839	valid_0's l2: 120.358
Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[115]	valid_0's l1: 2.6169	valid_0's l2: 48.8461


Setting feature_perturbation = "tree_path_dependent" because no background data was given.
Setting feature_perturbation = "tree_path_dependent" because no background data was given.
Setting feature_perturbation = "tree_path_dependent" because no background data was given.
Setting feature_perturbation = "tree_path_dependent" because no background data was given.
Setting feature_perturbation = "tree_path_dependent" because no background data was given.


MEAN metrics = R^2    0.721
MSE   92.179
MAE    4.192
dtype: float64


Unnamed: 0_level_0,symbol,gain_score_to_lifespan,kendall_tau_to_lifespan,shap_sum_fold_0,shap_sum_fold_1,shap_sum_fold_2,shap_sum_fold_3,shap_sum_fold_4
ensembl_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
Unnamed: 0_level_2,R^2,MSE,MAE,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2
ENSG00000116251,RPL22,549363.431,0.095,-17.085,-15.373,-4.788,2.226,-0.391
ENSG00000182534,MXRA7,194007.618,-0.152,-5.163,2.616,-1.322,3.7,3.065
ENSG00000211693,TRGV11,154477.176,-0.622,-87.713,-142.545,-131.418,-143.914,-145.22
ENSG00000187912,CLEC17A,55692.87,0.102,-18.267,12.216,25.125,-2.634,-29.564
ENSG00000150687,PRSS23,52384.633,0.279,241.282,90.521,86.462,78.419,55.729
ENSG00000104835,SARS2,34696.831,0.015,-26.634,-43.602,-24.205,-40.106,-36.62
ENSG00000179909,ZNF154,26422.976,-0.268,-33.894,11.365,-2.895,-27.167,0.032
ENSG00000255181,CCDC166,6456.032,-0.104,-1.772,34.961,3.961,1.837,-0.486
0,-0.070,235.503,10.672,,,,,
1,0.933,37.655,2.86,,,,,

Unnamed: 0_level_0,symbol,gain_score_to_lifespan,kendall_tau_to_lifespan,shap_sum_fold_0,shap_sum_fold_1,shap_sum_fold_2,shap_sum_fold_3,shap_sum_fold_4
ensembl_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
ENSG00000116251,RPL22,549363.431,0.095,-17.085,-15.373,-4.788,2.226,-0.391
ENSG00000182534,MXRA7,194007.618,-0.152,-5.163,2.616,-1.322,3.7,3.065
ENSG00000211693,TRGV11,154477.176,-0.622,-87.713,-142.545,-131.418,-143.914,-145.22
ENSG00000187912,CLEC17A,55692.87,0.102,-18.267,12.216,25.125,-2.634,-29.564
ENSG00000150687,PRSS23,52384.633,0.279,241.282,90.521,86.462,78.419,55.729
ENSG00000104835,SARS2,34696.831,0.015,-26.634,-43.602,-24.205,-40.106,-36.62
ENSG00000179909,ZNF154,26422.976,-0.268,-33.894,11.365,-2.895,-27.167,0.032
ENSG00000255181,CCDC166,6456.032,-0.104,-1.772,34.961,3.961,1.837,-0.486

Unnamed: 0,R^2,MSE,MAE
0,-0.07,235.503,10.672
1,0.933,37.655,2.86
2,0.97,18.534,1.424
3,0.84,120.358,3.388
4,0.932,48.846,2.617


## One more re-training and again we have a totally different result ##

In [113]:
# Third fit of the identical pipeline on identical data, again for comparison.
# NOTE(review): no random seed is set in the visible cells — confirm where the
# variation between runs originates before trusting any single selection.
pipe.fit_transform(data)

Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[42]	valid_0's l1: 5.39749	valid_0's l2: 103.318
Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[140]	valid_0's l1: 2.73662	valid_0's l2: 59.2562
Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[82]	valid_0's l1: 3.60679	valid_0's l2: 50.5347
Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[51]	valid_0's l1: 3.09597	valid_0's l2: 26.8577
Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[107]	valid_0's l1: 1.12726	valid_0's l2: 4.43289


Setting feature_perturbation = "tree_path_dependent" because no background data was given.
Setting feature_perturbation = "tree_path_dependent" because no background data was given.
Setting feature_perturbation = "tree_path_dependent" because no background data was given.
Setting feature_perturbation = "tree_path_dependent" because no background data was given.
Setting feature_perturbation = "tree_path_dependent" because no background data was given.


MEAN metrics = R^2    0.910
MSE   48.880
MAE    3.193
dtype: float64


Unnamed: 0_level_0,symbol,gain_score_to_lifespan,kendall_tau_to_lifespan,shap_sum_fold_0,shap_sum_fold_1,shap_sum_fold_2,shap_sum_fold_3,shap_sum_fold_4
ensembl_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
Unnamed: 0_level_2,R^2,MSE,MAE,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2
ENSG00000116251,RPL22,608064.466,0.085,-167.934,-2.946,-8.28,94.779,5.98
ENSG00000182534,MXRA7,192219.074,-0.031,-52.572,-1.265,-2.518,37.655,6.761
ENSG00000211693,TRGV11,167549.824,-0.645,-141.992,-154.131,-145.877,-141.376,-116.505
ENSG00000187912,CLEC17A,70580.18,0.01,-241.223,44.22,0.24,83.166,4.441
ENSG00000096654,ZNF184,63931.532,0.313,-16.382,30.773,1.165,150.298,115.334
ENSG00000150687,PRSS23,45146.042,0.346,386.477,59.956,99.624,46.345,52.218
ENSG00000108433,GOSR2,39571.828,0.098,-6.316,-24.09,-30.972,-13.751,-23.649
ENSG00000104835,SARS2,32621.475,0.018,-26.254,-29.68,-31.649,-28.109,-43.112
ENSG00000186376,ZNF75D,27907.674,0.491,-1.071,13.694,55.943,10.521,21.017
ENSG00000250644,AC068580.4,11190.577,0.088,-51.186,22.348,0.231,45.48,17.047

Unnamed: 0_level_0,symbol,gain_score_to_lifespan,kendall_tau_to_lifespan,shap_sum_fold_0,shap_sum_fold_1,shap_sum_fold_2,shap_sum_fold_3,shap_sum_fold_4
ensembl_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
ENSG00000116251,RPL22,608064.466,0.085,-167.934,-2.946,-8.28,94.779,5.98
ENSG00000182534,MXRA7,192219.074,-0.031,-52.572,-1.265,-2.518,37.655,6.761
ENSG00000211693,TRGV11,167549.824,-0.645,-141.992,-154.131,-145.877,-141.376,-116.505
ENSG00000187912,CLEC17A,70580.18,0.01,-241.223,44.22,0.24,83.166,4.441
ENSG00000096654,ZNF184,63931.532,0.313,-16.382,30.773,1.165,150.298,115.334
ENSG00000150687,PRSS23,45146.042,0.346,386.477,59.956,99.624,46.345,52.218
ENSG00000108433,GOSR2,39571.828,0.098,-6.316,-24.09,-30.972,-13.751,-23.649
ENSG00000104835,SARS2,32621.475,0.018,-26.254,-29.68,-31.649,-28.109,-43.112
ENSG00000186376,ZNF75D,27907.674,0.491,-1.071,13.694,55.943,10.521,21.017
ENSG00000250644,AC068580.4,11190.577,0.088,-51.186,22.348,0.231,45.48,17.047

Unnamed: 0,R^2,MSE,MAE
0,0.769,103.318,5.397
1,0.908,59.256,2.737
2,0.917,50.535,3.607
3,0.962,26.858,3.096
4,0.994,4.433,1.127
