In [32]:
import pandas as pd
import numpy as np
from os.path import join as oj

from imodels.importance import R2FExp, GeneralizedMDI, GeneralizedMDIJoint
from imodels.importance import LassoScorer, RidgeScorer, ElasticNetScorer, RobustScorer, LogisticScorer, JointRidgeScorer, JointLogisticScorer, JointRobustScorer
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.inspection import permutation_importance
import shap

In [2]:
DATA_DIR = "../../data"

In [3]:
X_df = pd.read_csv(oj(DATA_DIR, "X_tcga_var_filtered_log_transformed.csv"))
X = X_df.to_numpy()
y = pd.read_csv(oj(DATA_DIR, "Y_tcga.csv")).to_numpy().ravel()

In [5]:
X_df.head()

Unnamed: 0,A2M,NAT1,RP11.986E7.7,AAMP,AARS,ABAT,ABCA2,ABCA3,ABCF1,ABL1,...,EIF3CL.1,PLIN4,MTPN,RELL1,SNHG8.1,OST4,TSTD1,GAGE12J.4,NBPF10,PRICKLE4.1
0,9.614527,7.401763,10.859011,7.765291,7.625208,7.715826,7.500286,6.677448,7.228231,7.572464,...,9.237782,5.359995,8.073793,6.500534,6.411474,7.729637,7.369472,0.0,6.609727,7.551117
1,8.725323,9.375599,10.128432,7.388699,7.65829,7.463426,7.031804,8.70224,7.566244,7.363025,...,9.349586,6.583817,8.1859,6.551568,3.425952,7.161099,7.052134,0.0,7.665767,7.310317
2,8.671585,7.756825,10.690369,8.064953,7.976935,8.461162,7.444768,8.378548,7.731562,7.365497,...,9.417242,4.251032,8.359054,7.816944,5.631785,7.984734,7.587183,0.309101,6.425312,7.941761
3,9.33917,6.155162,7.340154,7.815608,8.993226,6.216886,6.922385,7.90677,8.111439,7.358005,...,9.887913,4.239956,8.562925,7.082101,7.108862,7.41429,5.987086,0.0,7.137403,7.841682
4,8.90942,4.402246,8.196574,8.366875,7.949886,5.881482,8.07633,7.432998,7.809034,7.864155,...,9.525658,5.304333,8.157452,6.425711,6.033721,8.060675,7.341724,0.0,6.776412,7.650708


In [6]:
y

array(['LumA', 'LumB', 'LumA', ..., 'LumA', 'LumA', 'LumA'], dtype=object)

In [7]:
X.shape, y.shape

((1083, 5000), (1083,))

## gjMDI Ridge

In [8]:
rf_model = RandomForestClassifier(n_estimators=100, min_samples_leaf=1, max_features="sqrt")

In [11]:
%%time
scorer = JointRidgeScorer(criterion="gcv")
gjMDI_obj = GeneralizedMDIJoint(rf_model, scorer=scorer, normalize_raw=True, random_state=331)
imp_values, scores, n_stumps, n_stumps_chosen, class_scores = gjMDI_obj.get_importance_scores(X, y, diagnostics=True)

CPU times: user 1min 32s, sys: 13.6 s, total: 1min 45s
Wall time: 48.6 s


In [13]:
pd.DataFrame({"feature": X_df.columns, "importance": imp_values}).sort_values("importance", ascending=False)

Unnamed: 0,feature,importance
626,ESR1,0.019130
731,GATA3,0.014185
686,FOXC1,0.013678
885,FOXA1,0.012837
1747,RRM2,0.012595
...,...,...
809,GRSF1,-0.000082
1857,SMARCA4,-0.000082
1852,SLIT3,-0.000084
344,CHML,-0.000119


In [14]:
scores.shape

(100, 5000)

In [15]:
class_scores.shape

(100, 5000, 5)

In [18]:
mean_class_scores = class_scores.mean(axis=0)

In [21]:
for class_idx, class_label in enumerate(gjMDI_obj.estimator.classes_):
    print(pd.DataFrame({"feature": X_df.columns, class_label: mean_class_scores[:, class_idx]}).sort_values(class_label, ascending=False))

      feature     Basal
686     FOXC1  0.054506
885     FOXA1  0.048119
626      ESR1  0.028315
731     GATA3  0.028036
4289     MLPH  0.023093
...       ...       ...
4853  GLYATL2 -0.000484
1747     RRM2 -0.000488
2773   TUBA1B -0.000499
4334    ASB13 -0.000525
3531  RACGAP1 -0.000544

[5000 rows x 2 columns]
     feature      Her2
626     ESR1  0.021726
186     BCL2  0.011076
796     GRB7  0.010931
1747    RRM2  0.010186
1890   SOX11  0.009610
...      ...       ...
2895   MGEA5 -0.000258
1338   NUMA1 -0.000290
3520   ATAD2 -0.000302
2024    TMPO -0.000344
4694   DEGS2 -0.000414

[5000 rows x 2 columns]
     feature      LumA
1747    RRM2  0.020509
626     ESR1  0.019233
2393    PRC1  0.016170
1234   MYBL2  0.015886
731    GATA3  0.015177
...      ...       ...
3324  TRIM29 -0.000470
2217    FZD4 -0.000482
2203     MIA -0.000543
720    GABRP -0.000591
1051   KRT16 -0.000774

[5000 rows x 2 columns]
       feature      LumB
1803     SFRP1  0.014807
626       ESR1  0.012417
3531   RAC

## MDI and Permutation

Note: MDI results are very unstable

In [27]:
rf_model.fit(X, y)
pd.DataFrame({"feature": X_df.columns, "MDI": rf_model.feature_importances_}).sort_values("MDI", ascending=False)

Unnamed: 0,feature,MDI
1989,TFF3,0.013160
2127,XBP1,0.012060
885,FOXA1,0.009465
688,FOXM1,0.009199
3114,TPX2,0.008971
...,...,...
3644,MRPL37,0.000000
3642,SHISA5,0.000000
3641,MZB1,0.000000
3640,FAM203A,0.000000


In [28]:
rf_model.fit(X, y)
pd.DataFrame({"feature": X_df.columns, "MDI": rf_model.feature_importances_}).sort_values("MDI", ascending=False)

Unnamed: 0,feature,MDI
4341,ZNF552,0.011308
1449,PLK1,0.011082
626,ESR1,0.010124
688,FOXM1,0.009293
731,GATA3,0.008125
...,...,...
2658,ABI1,0.000000
2659,HDAC5,0.000000
2660,PDCD6IP,0.000000
2667,UBA2,0.000000


In [29]:
perm_results = permutation_importance(rf_model, X, y, n_repeats=10, random_state=0)
pd.DataFrame({"feature": X_df.columns, "Permutation": perm_results.importances_mean}).sort_values("Permutation", ascending=False)

KeyboardInterrupt: 

In [None]:
perm_results = permutation_importance(rf_model, X, y, n_repeats=10, random_state=331)
pd.DataFrame({"feature": X_df.columns, "Permutation": perm_results.importances_mean}).sort_values("Permutation", ascending=False)

## TreeSHAP

In [33]:
explainer = shap.TreeExplainer(rf_model)
shap_values = explainer.shap_values(X)
shap_values
# results = abs(shap_values)
# results = results.mean(axis=0)

[array([[ 3.16174679e-06, -2.74692621e-03,  9.28960353e-05, ...,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [-1.15930716e-05, -1.13244270e-03,  4.74957173e-05, ...,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [-2.11362703e-05, -9.63393690e-04,  5.53187835e-05, ...,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        ...,
        [ 2.28530572e-06, -1.26429179e-03,  6.08020185e-05, ...,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [-1.15930716e-05, -1.03765433e-03,  1.99174636e-06, ...,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 2.09019333e-05, -2.46904165e-03,  4.72206470e-05, ...,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00]]),
 array([[-1.54514781e-05, -5.84947150e-04, -5.20077443e-04, ...,
          0.00000000e+00, -5.13663308e-05,  0.00000000e+00],
        [ 6.45471991e-05, -2.46507495e-04, -1.96589508e-04, ...,
          0.00000000e+00, -1.85695556e

AttributeError: 'list' object has no attribute 'shape'