# Combining Clustering and Attribution methods

In [1]:
# imports

from cnn_architecture import CNN2Model
from utils import *
from load_datasets import load_and_prep_dataset

import tensorflow_datasets as tfds
import pandas as pd
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import random
import ast

2024-06-28 20:46:16.496080: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Number of per-file contribution-score arrays (files "..._wt0.npy" to "..._wt14.npy").
N_SCORE_FILES = 15
# Number of clusters (columns "mean_c0".."mean_c3" / "ratio_c0".."ratio_c3").
N_CLUSTERS = 4
# A mean sign rate above / below these bounds is reported as positive / negative.
SIGN_POSITIVE_ABOVE = 0.55
SIGN_NEGATIVE_BELOW = 0.45


def _rank_label(value, all_values):
    """Return "highest", "lowest" or "medium" for `value` relative to `all_values`.

    `value` is expected to be one of the entries of `all_values`; it is
    compared for equality against their maximum and minimum.
    """
    if value == np.max(all_values):
        return "highest"
    if value == np.min(all_values):
        return "lowest"
    return "medium"


def _sign_label(value):
    """Return "positive", "negative" or "balanced" for a mean sign rate."""
    if value > SIGN_POSITIVE_ABOVE:
        return "positive"
    if value < SIGN_NEGATIVE_BELOW:
        return "negative"
    return "balanced"


for dataset in ["CIFAR","CINIC","SVHN"]:


    ####################################################################
    #       combine sign distributions, clusters and c_scores          #
    ####################################################################

    # get the clustering table
    clustered_sign_distr = pd.read_csv(f'5b Clusters/clustered_sign_distr_{dataset}_dense1.csv')
    # get the c_scores: load every per-file array, join them once and wrap the
    # result in a single explicitly named column (no repeated pd.concat, and
    # the column name does not depend on pandas' Series-to-frame conversion)
    c_scores_all = pd.DataFrame({
        "c_scores": np.concatenate([
            np.load(f"5d Contribution scores/c_scores_dense1_{dataset}_wt{i}.npy")
            for i in range(N_SCORE_FILES)
        ])
    })
    # The column-wise concat below aligns on the index, so a length mismatch
    # would silently introduce NaN rows and distort the per-cluster means.
    # NOTE(review): assumes one score per row of the clustering table - confirm.
    if len(c_scores_all) != len(clustered_sign_distr):
        raise ValueError(
            f"{dataset}: {len(c_scores_all)} contribution scores but "
            f"{len(clustered_sign_distr)} rows in the clustering table"
        )
    # combine
    clustered_sign_distr_combi = pd.concat([clustered_sign_distr,c_scores_all], axis=1)

    ####################################################################
    #       get statistics and observations                            #
    ####################################################################

    cluster_stats = pd.read_csv(f"5b Clusters/cluster_stats_{dataset}.csv")
    cluster_stats = cluster_stats.set_index("variable")

    # collect all means: one list of per-cluster means for each variable
    means = cluster_stats[[f"mean_c{cluster}" for cluster in range(N_CLUSTERS)]]
    means_by_variable = {
        variable: list(means.loc[variable])
        for variable in ("prune_rate_in", "prune_rate_out")
    }
    # mean contribution score of the rows assigned to each cluster
    means_c_scores = [
        np.mean(
            clustered_sign_distr_combi.loc[
                clustered_sign_distr_combi["cluster"] == cluster, "c_scores"
            ]
        )
        for cluster in range(N_CLUSTERS)
    ]

    observations = []
    for cluster in range(N_CLUSTERS):
        mean_column = f"mean_c{cluster}"
        cluster_obs = {}

        # for prune in and prune out check whether highest, lowest or medium
        for variable in ("prune_rate_in", "prune_rate_out"):
            value = cluster_stats.loc[variable, mean_column]
            label = _rank_label(value, means_by_variable[variable])
            cluster_obs[variable] = label + f"({np.round(value,3)})"

        # for sign in and out check whether balanced, positive or negative
        for variable in ("sign_rate_in", "sign_rate_out"):
            value = cluster_stats.loc[variable, mean_column]
            cluster_obs[variable] = _sign_label(value) + f"({np.round(value,3)})"

        # enter the ratio of neurons (read from the "prune_rate_in" row, as before)
        cluster_obs["ratio"] = np.round(cluster_stats.loc["prune_rate_in", f"ratio_c{cluster}"],3)

        # for c_scores check whether highest, lowest or medium;
        # three significant digits because the scores span many magnitudes
        mean_c_score = means_c_scores[cluster]
        rounded_mean = "%.3g" % float(mean_c_score)
        cluster_obs["c_scores"] = _rank_label(mean_c_score, means_c_scores) + f"({rounded_mean})"

        observations.append(cluster_obs)

    # build the summary table in one go, one row per cluster
    cluster_rows = pd.DataFrame(
        observations,
        index=[f"cluster_{cluster}" for cluster in range(N_CLUSTERS)],
    )
    cluster_rows = cluster_rows.sort_values("ratio",ascending=False)
    display(cluster_rows)
    print(cluster_rows.to_latex())


Unnamed: 0,prune_rate_in,prune_rate_out,sign_rate_in,sign_rate_out,ratio,c_scores
cluster_0,lowest(0.907),medium(0.244),balanced(0.516),balanced(0.517),0.603,medium(0.00252)
cluster_2,medium(0.957),highest(0.312),negative(0.443),balanced(0.509),0.318,medium(0.000289)
cluster_1,medium(0.931),medium(0.216),positive(0.713),positive(0.582),0.072,highest(0.0341)
cluster_3,highest(0.997),lowest(0.214),negative(0.021),balanced(0.49),0.006,lowest(2.02e-09)


\begin{tabular}{lllllrl}
\toprule
 & prune_rate_in & prune_rate_out & sign_rate_in & sign_rate_out & ratio & c_scores \\
\midrule
cluster_0 & lowest(0.907) & medium(0.244) & balanced(0.516) & balanced(0.517) & 0.603000 & medium(0.00252) \\
cluster_2 & medium(0.957) & highest(0.312) & negative(0.443) & balanced(0.509) & 0.318000 & medium(0.000289) \\
cluster_1 & medium(0.931) & medium(0.216) & positive(0.713) & positive(0.582) & 0.072000 & highest(0.0341) \\
cluster_3 & highest(0.997) & lowest(0.214) & negative(0.021) & balanced(0.49) & 0.006000 & lowest(2.02e-09) \\
\bottomrule
\end{tabular}



Unnamed: 0,prune_rate_in,prune_rate_out,sign_rate_in,sign_rate_out,ratio,c_scores
cluster_0,lowest(0.879),medium(0.282),balanced(0.505),balanced(0.528),0.498,medium(0.00389)
cluster_2,medium(0.95),highest(0.36),negative(0.443),balanced(0.505),0.447,medium(0.000446)
cluster_1,medium(0.919),lowest(0.265),positive(0.651),positive(0.605),0.051,highest(0.0504)
cluster_3,highest(0.996),medium(0.28),negative(0.035),balanced(0.505),0.005,lowest(1.73e-10)


\begin{tabular}{lllllrl}
\toprule
 & prune_rate_in & prune_rate_out & sign_rate_in & sign_rate_out & ratio & c_scores \\
\midrule
cluster_0 & lowest(0.879) & medium(0.282) & balanced(0.505) & balanced(0.528) & 0.498000 & medium(0.00389) \\
cluster_2 & medium(0.95) & highest(0.36) & negative(0.443) & balanced(0.505) & 0.447000 & medium(0.000446) \\
cluster_1 & medium(0.919) & lowest(0.265) & positive(0.651) & positive(0.605) & 0.051000 & highest(0.0504) \\
cluster_3 & highest(0.996) & medium(0.28) & negative(0.035) & balanced(0.505) & 0.005000 & lowest(1.73e-10) \\
\bottomrule
\end{tabular}



Unnamed: 0,prune_rate_in,prune_rate_out,sign_rate_in,sign_rate_out,ratio,c_scores
cluster_2,medium(0.949),highest(0.298),balanced(0.475),balanced(0.517),0.417,medium(0.001)
cluster_0,lowest(0.886),medium(0.272),balanced(0.511),balanced(0.5),0.37,medium(0.00499)
cluster_1,medium(0.9),medium(0.233),balanced(0.547),positive(0.564),0.181,highest(0.0121)
cluster_3,highest(0.994),lowest(0.21),negative(0.058),balanced(0.496),0.032,lowest(4.74e-10)


\begin{tabular}{lllllrl}
\toprule
 & prune_rate_in & prune_rate_out & sign_rate_in & sign_rate_out & ratio & c_scores \\
\midrule
cluster_2 & medium(0.949) & highest(0.298) & balanced(0.475) & balanced(0.517) & 0.417000 & medium(0.001) \\
cluster_0 & lowest(0.886) & medium(0.272) & balanced(0.511) & balanced(0.5) & 0.370000 & medium(0.00499) \\
cluster_1 & medium(0.9) & medium(0.233) & balanced(0.547) & positive(0.564) & 0.181000 & highest(0.0121) \\
cluster_3 & highest(0.994) & lowest(0.21) & negative(0.058) & balanced(0.496) & 0.032000 & lowest(4.74e-10) \\
\bottomrule
\end{tabular}

