# Import Libraries

In [1]:
using DataFrames
using DataFramesMeta
using PyCall
using PlotlyJS
using Random
import Statistics: cor
using Clustering

@pyimport sklearn.metrics as Metrics
@pyimport sklearn.ensemble as Ensemble
@pyimport sklearn.tree as Tree
@pyimport sklearn.datasets as Datasets
@pyimport sklearn.model_selection as ModelSelection

include("GenerateSyntheticData.jl");
include("Clustering.jl")
include("ClusteredMeanDecreaseImpurity.jl")



clusteredFeatureImportanceMDI

# Generate Synthetic Test Data

In [2]:
# Build the synthetic feature table `X` and label table `y` used by the rest of
# the notebook. All settings are forwarded as keywords to `getTestData`
# (presumably defined in GenerateSyntheticData.jl -- confirm).
# The trailing `;` suppresses the cell output.
X, y = getTestData(;
    nSamples=10000,
    nFeatures=40,
    nInformative=5,
    nRedundant=30,
    sigmaStd=0.1,
);

# Fit Clustered MDI

## Clustering

In [3]:
# Cluster the features on their pairwise correlation matrix.
#
# `cor` is called unqualified on purpose: the file only does
# `import Statistics: cor`, which binds `cor` but not the module name
# `Statistics`, so `Statistics.cor` resolves only when some other code happens
# to have loaded the `Statistics` module into scope.
#
# NOTE(review): `numberClusters=25` looks like an upper bound on the number of
# clusters searched (the recorded run ended with 6) -- confirm in Clustering.jl.
corr0, clusters, silh = clusterKMeansBase(
    cor(Matrix(X));
    numberClusters=25,
    iterations=20,
);

# Show the resulting cluster assignment as the cell output.
clusters

Dict{String, Vector{Int64}} with 6 entries:
  "4" => [6, 7, 8, 9, 10]
  "1" => [4, 13, 14, 17, 18, 19, 26, 32, 40]
  "5" => [2, 12, 23, 29, 33, 37]
  "2" => [5, 15, 34, 36, 38, 39]
  "6" => [1, 11, 16, 20, 24, 27, 28, 30]
  "3" => [3, 21, 22, 25, 31, 35]

## Fit

In [5]:
# Base learner: an entropy-criterion decision tree with balanced class weights
# that considers a single feature per split (`max_features=1`).
baseTree = Tree.DecisionTreeClassifier(
    criterion="entropy",
    max_features=1,
    class_weight="balanced",
    min_weight_fraction_leaf=0,
)

# Bag 500 such trees, each trained on all features and a full-size sample.
# The base learner is passed positionally: its keyword was renamed from
# `base_estimator` to `estimator` in scikit-learn 1.2 and the old name was
# removed in 1.4, while the first positional slot works across releases.
classifier = Ensemble.BaggingClassifier(
    baseTree,
    n_estimators=500,
    max_features=1.0,
    max_samples=1.0,
    oob_score=false,
)

# scikit-learn expects a plain feature matrix and a 1-D label vector, so the
# tables are converted and the first column of `y` is used as the target.
fit = classifier.fit(Matrix(X), Matrix(y)[:, 1])

# Aggregate the fitted ensemble's impurity-based importances per cluster.
importances = clusteredFeatureImportanceMDI(fit, names(X), clusters)
importances

Row,ClusterIndex,Mean,StandardDeviation
,Any,Float64,Float64
1,Cluster 4,0.0686559,0.00096215
2,Cluster 1,0.0957109,0.000770608
3,Cluster 5,0.237582,0.008809
4,Cluster 2,0.163763,0.00243807
5,Cluster 6,0.181276,0.00679068
6,Cluster 3,0.253012,0.0080844


# Plot Results

In [6]:
# Switch the default Plotly template to the dark theme.
templates.default = "plotly_dark";
PlotlyJS.templates

# Horizontal bar chart of the per-cluster importance (`Mean`) with the
# `StandardDeviation` column as error bars. The title now names the method this
# notebook actually runs (clustered MDI); it previously read "MDA Results",
# which contradicted both the analysis and the saved file name.
toSavePlot = plot(
    bar(
        importances,
        x=:Mean,
        y=:ClusterIndex,
        error_x=attr(type="data", array=:StandardDeviation, visible=true),
        orientation="h",
    ),
    PlotlyJS.Layout(
        title="Clustered MDI Results",
        width=800, height=600,
        xaxis_title="Feature Importance",
        yaxis_title="Cluster Index",
    ),
)

# Save Results

In [8]:
# Write the figure to disk. The target directory is created first (a no-op when
# it already exists) so the export does not fail on a fresh checkout that has
# no `Figs/` folder yet.
outputPath = "Figs/clustered_MDI_results.png"
mkpath(dirname(outputPath))
PlotlyJS.savefig(toSavePlot, outputPath)

"Figs/clustered_MDI_results.png"