# SHAP feature selection #
## Code to select features with a combination of LightGBM and SHAP ##

## Parameters cell ##

Parameters are overridden by papermill when the notebook is run inside DVC stages



In [1]:
number_of_bootstraps = 5 # global setting: how many bootstraps to use

# Parameters handed to LightGBM (through ModelFactory further down the notebook).
lgb_params = {
    'boosting_type': 'gbdt',
    'objective': 'regression',
    # A list instead of a set: a set of strings has no stable iteration order
    # across interpreter processes (string hashes are randomised), so the order
    # in which the metrics reach LightGBM could differ from one kernel to the
    # next. The order below matches the one seen in the training logs.
    'metric': ['l1', 'l2'],
    'max_leaves': 20,
    'max_depth': 3,
    'learning_rate': 0.07,
    'feature_fraction': 0.8,
    'bagging_fraction': 1,
    'min_data_in_leaf': 6,
    'lambda_l1': 0.9,
    'lambda_l2': 0.9,
    "verbose": -1
}

debug_local = True # when True, the local checkout of yspecies is put on sys.path

In [2]:
from pathlib import Path
import sys
import inspect

local = (Path("..") / "yspecies").resolve()
if debug_local and local.exists():
  sys.path.insert(0, Path("..").as_posix())
  #sys.path.insert(0, local.as_posix())
  print("extending pathes with local yspecies")
  print(sys.path)
  %load_ext autoreload
  %autoreload 2

extending pathes with local yspecies
['..', '/data/sources/yspecies/notebooks', '/opt/miniconda3/envs/yspecies/lib/python38.zip', '/opt/miniconda3/envs/yspecies/lib/python3.8', '/opt/miniconda3/envs/yspecies/lib/python3.8/lib-dynload', '', '/opt/miniconda3/envs/yspecies/lib/python3.8/site-packages', '/opt/miniconda3/envs/yspecies/lib/python3.8/site-packages/IPython/extensions', '/home/antonkulaga/.ipython']


In [3]:
from typing import *
from yspecies.dataset import *
from yspecies.utils import *
from yspecies.workflow import *
from yspecies.partition import *
from yspecies.selection import *

In [4]:
from dataclasses import dataclass
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [5]:
import pandas as pd
import shap
from pprint import pprint
import random
import numpy as np
import lightgbm as lgb
from scipy.stats import kendalltau
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.utils import resample
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error, accuracy_score, recall_score, precision_score, f1_score

In [6]:
# Notebook-wide display settings: show every column and every row of a frame,
# and render floats with three digits after the decimal point.
pd.set_option(
    'display.max_columns', None,
    'display.max_rows', None,
    'display.float_format', lambda number: '%.3f' % number,
)

# NOTE: this rebinds the name `pprint` to the module, replacing the `pprint`
# function imported earlier in the notebook.
import pprint
pp = pprint.PrettyPrinter(indent=4)

### Loading data ###
Let's load the species/genes/expressions data selected by the select_samples.py notebook

In [7]:
from pathlib import Path
# Anchor the project locations at the current directory when it contains a
# `data` folder, otherwise at the parent directory.
locations: Locations
if Path("./data").exists():
    locations = Locations("./")
else:
    locations = Locations("../")

In [8]:
# Load the expression dataset from the folder referenced by locations.interim.selected.
data = ExpressionDataset.from_folder(locations.interim.selected)
# Bare last expression: the notebook renders the dataset summary as the cell output.
data

expressions,genes,species,samples,Genes Metadata,Species Metadata
"(452, 18301)",18301,43,452,"(18301, 2)","(44, 18)"


In [9]:
# Sanity check: evaluates to True only when the samples table has no "series" column.
"series" not in data.samples.head().columns

False

## Setting up SHAP selection pipeline ##

Deciding on selection parameters (which fields to include, exclude, predict)

In [16]:
from sklearn.pipeline import Pipeline

# Which fields take part in the feature selection run.
selection = SelectedFeatures(
    to_predict="lifespan",              # column to predict
    samples=["species"],                # samples metadata to include
    species=[],                         # species metadata other than the Y label to include
    exclude_from_training=["species"],  # fields kept out of LightGBM training
)


In [17]:
# Preview the samples table together with the tissue, species and lifespan columns.
# NOTE(review): the two lists are presumably sample-level and species-level fields,
# and 10,10 presumably caps the displayed rows/columns - confirm against
# extended_samples and show.
show(data.extended_samples(["tissue","species"], ["lifespan"]),10,10)

Unnamed: 0_level_0,tissue,species,lifespan
run,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
SRR1287653,Blood,Ailuropoda_melanoleuca,36.8
SRR1287654,Blood,Ailuropoda_melanoleuca,36.8
SRR1287655,Blood,Ailuropoda_melanoleuca,36.8
SRR2308103,Liver,Ailuropoda_melanoleuca,36.8
SRR1981979,Brain,Aotus_nancymaae,20.0
SRR1981981,Liver,Aotus_nancymaae,20.0
SRR1981987,Heart,Aotus_nancymaae,20.0
SRR1981988,Kidney,Aotus_nancymaae,20.0
SRR636839,Liver,Bos_taurus,20.0
SRR636840,Liver,Bos_taurus,20.0


Setting up the pipeline

In [18]:
# Three-stage pipeline; the stages run in the order listed.
pipe = Pipeline(steps=[
    # extract the data required for ML from the dataset
    ("extractor", DataExtractor(selection)),
    # partition it according to sorted stratification
    ("partitioner", DataPartitioner(
        species_in_validation=2,
        not_validated_species=["Homo_sapiens"],
    )),
    # train LightGBM and do the feature selection
    ("shap_computation", ShapSelector(ModelFactory(parameters=lgb_params))),
])

Training and fitting on GeneExpression data

In [19]:
# Run every stage of the pipeline on the loaded dataset and keep the result.
d = pipe.fit_transform(data)
# Display the result. Known bug: the last metrics row for the first run is broken.
d # known bug: the last metrics row for the first run is broken

Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[71]	valid_0's l1: 2.00136	valid_0's l2: 12.3935
Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[56]	valid_0's l1: 3.20237	valid_0's l2: 116.326
Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[45]	valid_0's l1: 3.59401	valid_0's l2: 43.0496
Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[103]	valid_0's l1: 4.10909	valid_0's l2: 99.8946
Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[106]	valid_0's l1: 1.48717	valid_0's l2: 8.04701


Setting feature_perturbation = "tree_path_dependent" because no background data was given.
Setting feature_perturbation = "tree_path_dependent" because no background data was given.
Setting feature_perturbation = "tree_path_dependent" because no background data was given.
Setting feature_perturbation = "tree_path_dependent" because no background data was given.
Setting feature_perturbation = "tree_path_dependent" because no background data was given.


MEAN metrics = R^2    0.902
MSE   55.942
MAE    2.879
dtype: float64


Unnamed: 0_level_0,symbol,gain_score_to_lifespan,shap,kendall_tau_to_lifespan
ensembl_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
Unnamed: 0_level_2,R^2,MSE,MAE,Unnamed: 4_level_2
ENSG00000211693,TRGV11,232740.997,-193.345,-0.615
ENSG00000096654,ZNF184,68743.218,36.828,0.236
ENSG00000187912,CLEC17A,56230.505,-29.761,0.016
ENSG00000104835,SARS2,21629.789,-44.152,0.109
ENSG00000211716,TRBV9,16751.509,-5.815,-0.395
ENSG00000250644,AC068580.4,15598.666,18.559,0.03
ENSG00000175854,SWI5,7896.644,-2.081,-0.593
ENSG00000226979,LTA,7434.786,-15.03,0.187
ENSG00000255181,CCDC166,6251.737,11.948,-0.091
ENSG00000165501,LRR1,1766.778,3.353,0.677

Unnamed: 0_level_0,symbol,gain_score_to_lifespan,shap,kendall_tau_to_lifespan
ensembl_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
ENSG00000211693,TRGV11,232740.997,-193.345,-0.615
ENSG00000096654,ZNF184,68743.218,36.828,0.236
ENSG00000187912,CLEC17A,56230.505,-29.761,0.016
ENSG00000104835,SARS2,21629.789,-44.152,0.109
ENSG00000211716,TRBV9,16751.509,-5.815,-0.395
ENSG00000250644,AC068580.4,15598.666,18.559,0.03
ENSG00000175854,SWI5,7896.644,-2.081,-0.593
ENSG00000226979,LTA,7434.786,-15.03,0.187
ENSG00000255181,CCDC166,6251.737,11.948,-0.091
ENSG00000165501,LRR1,1766.778,3.353,0.677

Unnamed: 0,R^2,MSE,MAE
0,0.981,12.393,2.001
1,0.784,116.326,3.202
2,0.922,43.05,3.594
3,0.836,99.895,4.109
4,0.988,8.047,1.487


# PROBLEM #

## One more re-training and we have a totally different result ##

In [20]:
# Disabled SHAP summary-plot calls, kept for reference; they use names
# (shap_values, df, shap_feature_names) that are not defined in the visible cells.
#shap.summary_plot('mass_g', shap_values, df)
#shap.summary_plot(shap_values, df, feature_names=shap_feature_names, sort=False, plot_type='dot', max_display=100, show=False)
# Fit the unchanged pipeline on the unchanged data again to compare with the previous run.
pipe.fit_transform(data)

Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[29]	valid_0's l1: 5.27857	valid_0's l2: 75.6098
Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[101]	valid_0's l1: 2.04033	valid_0's l2: 28.6696
Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[120]	valid_0's l1: 2.199	valid_0's l2: 7.9687
Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[158]	valid_0's l1: 1.1327	valid_0's l2: 4.10436
Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[81]	valid_0's l1: 3.10023	valid_0's l2: 67.3767


Setting feature_perturbation = "tree_path_dependent" because no background data was given.
Setting feature_perturbation = "tree_path_dependent" because no background data was given.
Setting feature_perturbation = "tree_path_dependent" because no background data was given.
Setting feature_perturbation = "tree_path_dependent" because no background data was given.
Setting feature_perturbation = "tree_path_dependent" because no background data was given.


MEAN metrics = R^2    0.938
MSE   36.746
MAE    2.750
dtype: float64


Unnamed: 0_level_0,symbol,gain_score_to_lifespan,shap,kendall_tau_to_lifespan
ensembl_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
Unnamed: 0_level_2,R^2,MSE,MAE,Unnamed: 4_level_2
ENSG00000196166,C8orf86,734163.666,-21.64,0.119
ENSG00000211693,TRGV11,196388.198,-168.727,-0.619
ENSG00000096654,ZNF184,87878.862,65.001,0.289
ENSG00000187912,CLEC17A,76750.492,-39.988,-0.028
ENSG00000226979,LTA,23853.347,-37.179,0.135
ENSG00000104835,SARS2,19252.304,-24.452,0.078
ENSG00000175854,SWI5,12161.896,-8.424,-0.658
ENSG00000146063,TRIM41,4777.021,25.036,-0.247
ENSG00000235961,PNMA6A,4662.676,3.482,-0.131
ENSG00000165501,LRR1,1231.872,1.18,0.651

Unnamed: 0_level_0,symbol,gain_score_to_lifespan,shap,kendall_tau_to_lifespan
ensembl_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
ENSG00000196166,C8orf86,734163.666,-21.64,0.119
ENSG00000211693,TRGV11,196388.198,-168.727,-0.619
ENSG00000096654,ZNF184,87878.862,65.001,0.289
ENSG00000187912,CLEC17A,76750.492,-39.988,-0.028
ENSG00000226979,LTA,23853.347,-37.179,0.135
ENSG00000104835,SARS2,19252.304,-24.452,0.078
ENSG00000175854,SWI5,12161.896,-8.424,-0.658
ENSG00000146063,TRIM41,4777.021,25.036,-0.247
ENSG00000235961,PNMA6A,4662.676,3.482,-0.131
ENSG00000165501,LRR1,1231.872,1.18,0.651

Unnamed: 0,R^2,MSE,MAE
0,0.85,75.61,5.279
1,0.954,28.67,2.04
2,0.981,7.969,2.199
3,0.994,4.104,1.133
4,0.909,67.377,3.1


## One more re-training and again we have a totally different result ##

In [21]:
# Third fit of the unchanged pipeline on the unchanged data; compare the
# ranking and metrics with the two earlier runs.
pipe.fit_transform(data)

Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[114]	valid_0's l1: 2.32957	valid_0's l2: 32.0603
Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[247]	valid_0's l1: 2.62773	valid_0's l2: 63.8004
Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[113]	valid_0's l1: 1.66096	valid_0's l2: 5.75718
Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[54]	valid_0's l1: 2.92446	valid_0's l2: 24.9066
Training until validation scores don't improve for 7 rounds
Early stopping, best iteration is:
[51]	valid_0's l1: 3.85619	valid_0's l2: 36.8788


Setting feature_perturbation = "tree_path_dependent" because no background data was given.
Setting feature_perturbation = "tree_path_dependent" because no background data was given.
Setting feature_perturbation = "tree_path_dependent" because no background data was given.
Setting feature_perturbation = "tree_path_dependent" because no background data was given.
Setting feature_perturbation = "tree_path_dependent" because no background data was given.


MEAN metrics = R^2    0.949
MSE   32.681
MAE    2.680
dtype: float64


Unnamed: 0_level_0,symbol,gain_score_to_lifespan,shap,kendall_tau_to_lifespan
ensembl_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
Unnamed: 0_level_2,R^2,MSE,MAE,Unnamed: 4_level_2
ENSG00000196166,C8orf86,725252.559,-32.767,0.065
ENSG00000211693,TRGV11,234871.937,-206.719,-0.648
ENSG00000096654,ZNF184,69714.26,29.646,0.229
ENSG00000187912,CLEC17A,60489.851,-5.328,-0.047
ENSG00000175854,SWI5,26683.365,-9.359,-0.636
ENSG00000104835,SARS2,16511.682,-22.628,0.115
ENSG00000184933,OR6A2,9377.53,6.091,0.487
ENSG00000134757,DSG3,6047.504,-2.382,-0.023
ENSG00000255181,CCDC166,5236.526,6.549,-0.086
ENSG00000179909,ZNF154,4797.725,3.106,-0.384

Unnamed: 0_level_0,symbol,gain_score_to_lifespan,shap,kendall_tau_to_lifespan
ensembl_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
ENSG00000196166,C8orf86,725252.559,-32.767,0.065
ENSG00000211693,TRGV11,234871.937,-206.719,-0.648
ENSG00000096654,ZNF184,69714.26,29.646,0.229
ENSG00000187912,CLEC17A,60489.851,-5.328,-0.047
ENSG00000175854,SWI5,26683.365,-9.359,-0.636
ENSG00000104835,SARS2,16511.682,-22.628,0.115
ENSG00000184933,OR6A2,9377.53,6.091,0.487
ENSG00000134757,DSG3,6047.504,-2.382,-0.023
ENSG00000255181,CCDC166,5236.526,6.549,-0.086
ENSG00000179909,ZNF154,4797.725,3.106,-0.384

Unnamed: 0,R^2,MSE,MAE
0,0.961,32.06,2.33
1,0.907,63.8,2.628
2,0.988,5.757,1.661
3,0.954,24.907,2.924
4,0.936,36.879,3.856
