# Import Libraries

In [1]:
using DataFrames
using DataFramesMeta
using PyCall
using PlotlyJS
using Random
import Statistics
import Statistics: cor
using Clustering

@pyimport sklearn.metrics as Metrics
@pyimport sklearn.ensemble as Ensemble
@pyimport sklearn.tree as Tree
@pyimport sklearn.datasets as Datasets
@pyimport sklearn.model_selection as ModelSelection

include("GenerateSyntheticData.jl");
include("Clustering.jl")



covToCorr

# Generate Synthetic Test Data

In [2]:
# Synthetic dataset: 10_000 samples with 40 features, 5 of them informative and 30
# redundant. `sigmaStd` is presumably the noise scale applied when building the
# redundant features — confirm against GenerateSyntheticData.jl.
X, y = getTestData(;
    nSamples = 10_000,
    nFeatures = 40,
    nInformative = 5,
    nRedundant = 30,
    sigmaStd = 0.1,
);

# Cluster Features

In [3]:
"""
    clusterKMeansBase(correlation; numberClusters = 10, iterations = 10)

Cluster the variables of `correlation` with repeated k-means and keep the clustering
whose silhouette samples have the highest `mean / std` ratio.

`correlation` is mapped to the distance matrix `sqrt.((1 .- correlation) ./ 2)`, whose
columns (one per variable) are clustered with `kmeans` for every `k` in
`2:numberClusters`. The whole sweep is repeated `iterations` times because k-means
starts from a random initialisation, so restarts can find different solutions. Each
candidate is scored on the samples returned by
`sklearn.metrics.silhouette_samples(distance, labels)`, i.e. the rows of the distance
matrix are used as feature vectors.

# Arguments
- `correlation`: square correlation matrix, e.g. `cor(Matrix(X))`. For the distance
  computation only, `NaN` entries are treated as zero correlation and values are
  clamped to `[-1, 1]`; the returned `correlationSorted` keeps the original values.

# Keywords
- `numberClusters = 10`: largest number of clusters tried. It is capped at
  `size(correlation, 1) - 1`, the largest `k` for which silhouette scores are defined.
- `iterations = 10`: number of times the `2:numberClusters` sweep is repeated.

# Returns
`(correlationSorted, clusters, silh, indexSorted, kmeansOut)` where
- `correlationSorted` is `correlation[indexSorted, indexSorted]`;
- `clusters` is a `Dict` from the cluster label (as a `String`) to its member indices;
- `silh` is a `DataFrame` with one column `silh`, the silhouette sample of each variable;
- `indexSorted` is the permutation that orders the variables by cluster label;
- `kmeansOut` is the winning k-means result.

# Throws
- `ArgumentError` if `correlation` is not square or has fewer than 3 rows, or if
  `numberClusters < 2` or `iterations < 1`.
"""
function clusterKMeansBase(correlation;
                           numberClusters = 10,
                           iterations = 10)
    n = size(correlation, 1)
    size(correlation, 2) == n ||
        throw(ArgumentError("correlation must be square, got size $(size(correlation))"))
    # silhouette scores need at least 2 clusters and at least one more point than clusters
    n >= 3 || throw(ArgumentError("need at least 3 variables to cluster, got $n"))
    numberClusters >= 2 ||
        throw(ArgumentError("numberClusters must be at least 2, got $numberClusters"))
    iterations >= 1 ||
        throw(ArgumentError("iterations must be at least 1, got $iterations"))

    # Julia's `sqrt` throws a DomainError on negative reals, so round-off above 1 must be
    # clamped away; undefined (NaN) correlations are treated as uncorrelated.
    cleaned = map(c -> isnan(c) ? zero(c) : clamp(c, -one(c), one(c)), correlation)
    distance = sqrt.((1 .- cleaned) ./ 2)

    # sklearn's silhouette_samples rejects k > n - 1
    maxClusters = min(numberClusters, n - 1)

    bestKmeans = nothing
    bestSilh = Float64[]
    # Starting from NaN guarantees the first candidate is accepted; a NaN incumbent
    # score (zero-spread silhouettes) is likewise always replaced by the next candidate.
    bestScore = NaN
    for _ in 1:iterations, k in 2:maxClusters
        candidate = kmeans(distance, k)
        candidateSilh = Metrics.silhouette_samples(distance, assignments(candidate))
        # t-statistic-like quality: mean silhouette relative to its spread
        score = Statistics.mean(candidateSilh) / Statistics.std(candidateSilh)
        if isnan(bestScore) || score > bestScore
            bestKmeans, bestSilh, bestScore = candidate, candidateSilh, score
        end
    end

    labels = assignments(bestKmeans)
    indexSorted = sortperm(labels)                      # group variables by cluster label
    correlationSorted = correlation[indexSorted, indexSorted]
    clusters = Dict("$label" => filter(p -> labels[p] == label, indexSorted)
                    for label in unique(labels))
    silh = DataFrames.DataFrame(silh = bestSilh)
    return correlationSorted, clusters, silh, indexSorted, bestKmeans
end

# Cluster the feature correlation matrix, trying k = 2:25 clusters and restarting the
# whole sweep 20 times; the fifth return value (the k-means result) is not kept.
correlationSorted, clusters, silh, indexSorted =
    clusterKMeansBase(cor(Matrix(X)); numberClusters = 25, iterations = 20);

# Plot Results

In [4]:
# Feature names in clustered order, so the axis labels follow the permuted matrix.
columnsSorted = let featureNames = names(X)
    [featureNames[j] for j in indexSorted]
end

# Use the dark theme for subsequent PlotlyJS figures; the bare expression below only
# displays the template registry in the notebook.
templates.default = "plotly_dark";
PlotlyJS.templates

# Heatmap of the cluster-sorted correlation matrix.
toSavePlot = let
    trace = heatmap(; z = correlationSorted, x = columnsSorted, y = columnsSorted)
    layout = Layout(;
        title = "ONC clusters together with informative and redundant features.",
        xaxis_title = "Features",
        yaxis_title = "Features",
        height = 800,
        width = 800,
    )
    plot(trace, layout)
end

# Save Results

In [5]:
# Write the heatmap to disk. The output directory is created first so the save does not
# fail when `Figs/` is missing (e.g. in a fresh checkout); the cell still evaluates to
# savefig's return value.
let figurePath = "Figs/clusters_together_with_informative_and_redundant_features.png"
    mkpath(dirname(figurePath))
    PlotlyJS.savefig(toSavePlot, figurePath)
end

"Figs/clusters_together_with_informative_and_redundant_features.png"