# Import Libraries

In [1]:
import sklearn.ensemble as ensemble

from RiskLabAI.features.feature_importance.generate_synthetic_data import *
from RiskLabAI.features.feature_importance.clustering import *

from RiskLabAI.features.feature_importance.feature_importance_factory import FeatureImportanceFactory

from RiskLabAI.features.feature_importance.feature_importance_mdi import FeatureImportanceMDI
from RiskLabAI.features.feature_importance.feature_importance_mda import FeatureImportanceMDA
from RiskLabAI.features.feature_importance.feature_importance_sfi import FeatureImportanceSFI
from RiskLabAI.features.feature_importance.orthogonal_features import orthogonal_features

from RiskLabAI.features.feature_importance.clustered_feature_importance_mdi import ClusteredFeatureImportanceMDI

# Generate Synthetic Test Data

In [2]:
X, y = get_test_dataset(
    n_features=40,
    n_informative=5, 
    n_redundant=30,
    n_samples=10000,
    sigma_std=0.1
)

In [3]:
X.columns

Index(['I_0', 'I_1', 'I_2', 'I_3', 'I_4', 'N_0', 'N_1', 'N_2', 'N_3', 'N_4',
       'R_0', 'R_1', 'R_2', 'R_3', 'R_4', 'R_5', 'R_6', 'R_7', 'R_8', 'R_9',
       'R_10', 'R_11', 'R_12', 'R_13', 'R_14', 'R_15', 'R_16', 'R_17', 'R_18',
       'R_19', 'R_20', 'R_21', 'R_22', 'R_23', 'R_24', 'R_25', 'R_26', 'R_27',
       'R_28', 'R_29'],
      dtype='object')

# MDI

In [4]:
classifier = ensemble.RandomForestClassifier(
    n_estimators=10,
    max_features=1.0,
    max_samples=1.0,
    oob_score=False,
)

strategy = FeatureImportanceMDI(
    classifier, x=X, y=y,
)

results = FeatureImportanceFactory().\
        build(strategy).\
        get_results()

results = results.sort_values(by='Mean', ascending=False)
results

Unnamed: 0,Mean,StandardDeviation
I_2,0.121092,0.013242
R_5,0.091206,0.039162
I_1,0.085116,0.039256
R_18,0.064997,0.035767
R_19,0.059288,0.016105
R_16,0.049463,0.024808
R_12,0.045055,0.013701
R_13,0.042412,0.026281
I_4,0.041678,0.014187
R_11,0.027764,0.011287


# MDA

In [5]:
classifier = ensemble.RandomForestClassifier(
    n_estimators=10,
    max_features=1.0,
    max_samples=1.0,
    oob_score=False,
)

strategy = FeatureImportanceMDA(
    classifier, x=X, y=y, n_splits=5, 
)

results = FeatureImportanceFactory().\
        build(strategy).\
        get_results()

results = results.sort_values(by='Mean', ascending=False)
results

Fold 0 start ...
Fold 1 start ...
Fold 2 start ...
Fold 3 start ...
Fold 4 start ...


Unnamed: 0,Mean,StandardDeviation
I_3,-0.010264,0.024802
N_3,-0.013066,0.011061
N_2,-0.01547,0.014726
N_1,-0.015811,0.004653
N_0,-0.021184,0.018419
N_4,-0.022642,0.010627
R_28,-0.062935,0.013944
R_22,-0.069364,0.030597
R_3,-0.069815,0.011269
R_2,-0.070349,0.044298


# SFI

In [6]:
classifier = ensemble.RandomForestClassifier(
    n_estimators=10,
    max_features=1.0,
    max_samples=1.0,
    oob_score=False,
)

strategy = FeatureImportanceSFI(
    classifier, x=X, y=y, n_splits=5, 
)

results = FeatureImportanceFactory().\
        build(strategy).\
        get_results()

results = results.sort_values(by='Mean', ascending=False)
results

Unnamed: 0,FeatureName,Mean,StandardDeviation
26,R_16,-5.386203,0.283883
23,R_13,-5.444189,0.395572
28,R_18,-5.528684,0.272121
15,R_5,-5.59063,0.342327
25,R_15,-5.605896,0.366522
37,R_27,-5.684375,0.267969
1,I_1,-5.72725,0.290894
12,R_2,-7.723966,0.760229
32,R_22,-7.816613,0.902991
29,R_19,-7.860419,0.937492


# Clustered MDI

In [7]:
import warnings

warnings.filterwarnings('ignore')

corr0, clusters, silh = cluster_kmeans_base(
    X.corr(),
    number_clusters=25,
    iterations=20
)

In [8]:
classifier = ensemble.RandomForestClassifier(
    n_estimators=10,
    max_features=1.0,
    max_samples=1.0,
    oob_score=False,
)

strategy = ClusteredFeatureImportanceMDI(
    classifier, clusters=clusters, x=X, y=y, 
)

results = FeatureImportanceFactory().\
        build(strategy).\
        get_results()

results = results.sort_values(by='Mean', ascending=False)
results

Unnamed: 0,Mean,StandardDeviation
C_3,0.358832,0.002727
C_5,0.215858,0.0031
C_4,0.17317,0.00245
C_2,0.112733,0.003512
C_0,0.091413,0.002911
C_1,0.047993,0.001354


# Clustered MDA

In [9]:
import warnings

warnings.filterwarnings('ignore')

_, clusters, __ = cluster_kmeans_base(
    X.corr(),
    number_clusters=25,
    iterations=20
)

In [10]:
classifier = ensemble.RandomForestClassifier(
    n_estimators=10,
    max_features=1.0,
    max_samples=1.0,
    oob_score=False,
)

strategy = ClusteredFeatureImportanceMDI(
    classifier, clusters=clusters, x=X, y=y, 
)

results = FeatureImportanceFactory().\
        build(strategy).\
        get_results()

results = results.sort_values(by='Mean', ascending=False)
results

Unnamed: 0,Mean,StandardDeviation
C_2,0.356855,0.003315
C_4,0.208143,0.004026
C_3,0.18292,0.004439
C_0,0.112756,0.002157
C_1,0.090347,0.002545
C_5,0.048979,0.001931


# Orthogonal

In [11]:
orthogonal_features, eigen_dataframe = orthogonal_features(
    X,
    variance_threshold=0.95
)

orthogonal_features

Unnamed: 0,0,1,2,3,4,5,6,7,8
0,-1.966873,0.036157,0.173458,1.077774,-1.503938,0.130776,0.668909,-2.058400,-1.006863
1,-0.942325,0.077657,-0.353651,0.501068,-0.202194,0.613058,0.346068,-1.031550,0.131250
2,-0.716439,-0.100930,1.283325,0.848877,-0.707392,-0.737716,0.302418,-0.792607,0.121775
3,-1.134535,-0.250868,-0.106080,-0.126013,-2.322576,-0.998544,-0.297545,-0.326369,0.363300
4,-0.329927,1.214664,-0.432226,-0.428989,-0.210331,-0.518758,-0.687217,-0.505478,-0.652000
...,...,...,...,...,...,...,...,...,...
9995,1.376160,1.152348,1.887187,-0.651099,2.773032,1.152963,-1.127660,0.659086,0.587282
9996,1.462891,0.338382,0.231409,-0.444821,0.675293,-1.244900,-0.090487,1.449306,0.378862
9997,-0.418783,-1.451355,0.065706,0.890295,-1.736130,-0.795710,1.396982,0.312831,0.852758
9998,0.314944,-0.312340,1.215733,-0.166766,0.009004,0.138418,-0.108078,0.610226,1.498948


In [12]:
eigen_dataframe

Unnamed: 0,Index,EigenValue,EigenVector,CumulativeVariance
39,PC 40,109727.227777,"[-0.0002533457614532855, -0.002371335154301076...",0.274346
38,PC 39,100434.654131,"[-0.004061988096088499, -0.0028192330267524826...",0.525457
37,PC 38,67438.681324,"[-0.012750920514659276, -0.1553819808291369, -...",0.694071
36,PC 37,45755.859659,"[-0.1181753265816836, 0.013753531931235574, -0...",0.808472
35,PC 36,25059.321881,"[-0.0007120315058587331, -0.000790528582097759...",0.871126
34,PC 35,10281.068949,"[0.0025267470908919842, -0.0012484949515612909...",0.896832
33,PC 34,10158.207601,"[-0.12141511085293638, 0.00793315264629496, -0...",0.92223
32,PC 33,9969.954282,"[-0.0008416811807057875, 0.003441308930441123,...",0.947157
31,PC 32,9885.702847,"[-0.11697775455102763, 0.006973184401138588, -...",0.971874
