# Combining Clustering and Attribution methods

In [3]:
# imports

from cnn_architecture import CNN2Model
from utils import *
from load_datasets import load_and_prep_dataset

import tensorflow_datasets as tfds
import pandas as pd
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import random

In [11]:
def rank_label(value, values):
    """Classify `value` relative to `values` as "highest", "lowest" or "medium".

    The maximum is checked first, so when all entries of `values` are equal
    the result is "highest".
    """
    if value == np.max(values):
        return "highest"
    if value == np.min(values):
        return "lowest"
    return "medium"


def sign_label(rate, upper=0.55, lower=0.45):
    """Classify a sign rate as "positive" (> upper), "negative" (< lower) or "balanced"."""
    if rate > upper:
        return "positive"
    if rate < lower:
        return "negative"
    return "balanced"


def summarize_clusters(cluster_stats, clustered, n_clusters=4):
    """Build one row of qualitative observations per cluster.

    Parameters
    ----------
    cluster_stats : pd.DataFrame
        Indexed by variable name; must contain the rows "prune_rate_in",
        "prune_rate_out", "sign_rate_in" and "sign_rate_out" and, for every
        cluster k, the columns f"mean_c{k}" and f"ratio_c{k}".
    clustered : pd.DataFrame
        One row per observation with a "cluster" column and a "c_scores"
        column.
    n_clusters : int
        Number of clusters to summarise (cluster ids 0 .. n_clusters - 1).

    Returns
    -------
    pd.DataFrame
        Indexed "cluster_0" .. f"cluster_{n_clusters - 1}", with the columns
        prune_rate_in, prune_rate_out, sign_rate_in, sign_rate_out, c_scores
        (each a "label(value)" string) and ratio, sorted by ratio, descending.
    """
    means = cluster_stats[[f"mean_c{k}" for k in range(n_clusters)]]
    prune_in_all = means.loc["prune_rate_in"].tolist()
    prune_out_all = means.loc["prune_rate_out"].tolist()

    # mean contribution score of the observations assigned to each cluster
    c_score_means = [
        clustered.loc[clustered["cluster"] == k, "c_scores"].mean()
        for k in range(n_clusters)
    ]

    records = []
    for k in range(n_clusters):
        mean_col = f"mean_c{k}"
        prune_in = cluster_stats.loc["prune_rate_in", mean_col]
        prune_out = cluster_stats.loc["prune_rate_out", mean_col]
        sign_in = cluster_stats.loc["sign_rate_in", mean_col]
        sign_out = cluster_stats.loc["sign_rate_out", mean_col]
        c_score = c_score_means[k]

        records.append({
            # prune_rate_in and c_scores are printed at full precision on
            # purpose: their cluster means can be too close together to tell
            # apart after rounding.
            "prune_rate_in": f"{rank_label(prune_in, prune_in_all)}({prune_in})",
            "prune_rate_out": f"{rank_label(prune_out, prune_out_all)}({np.round(prune_out, 2)})",
            "sign_rate_in": f"{sign_label(sign_in)}({np.round(sign_in, 2)})",
            "sign_rate_out": f"{sign_label(sign_out)}({np.round(sign_out, 2)})",
            "c_scores": f"{rank_label(c_score, c_score_means)}({c_score})",
            # the cluster's ratio is read from the prune_rate_in row
            "ratio": np.round(cluster_stats.loc["prune_rate_in", f"ratio_c{k}"], 2),
        })

    table = pd.DataFrame(
        records,
        index=[f"cluster_{k}" for k in range(n_clusters)],
        columns=["prune_rate_in", "prune_rate_out", "sign_rate_in",
                 "sign_rate_out", "c_scores", "ratio"],
    )
    return table.sort_values("ratio", ascending=False)


# Number of contribution-score files per dataset (wt0 .. wt14).
N_C_SCORE_FILES = 15

# `__name__` is "__main__" inside a notebook kernel and when run as a script,
# so the analysis below executes exactly as before; the guard only keeps the
# file reads from firing if the helpers above are imported from elsewhere.
if __name__ == "__main__":
    for dataset in ["CIFAR", "CINIC", "SVHN"]:

        ####################################################################
        #       combine sign distributions, clusters and c_scores          #
        ####################################################################

        # get the clustering table
        clustered_sign_distr = pd.read_csv(f'5b Clusters/clustered_sign_distr_{dataset}_dense1.csv')

        # get the c_scores: load every file, join the arrays once and label
        # the column explicitly instead of growing a DataFrame with
        # pd.concat on every iteration
        c_score_parts = [
            np.load(f"5d Contribution scores/c_scores_dense1_{dataset}_wt{i}.npy")
            for i in range(N_C_SCORE_FILES)
        ]
        c_scores_all = pd.Series(np.concatenate(c_score_parts), name="c_scores").to_frame()

        # combine (column-wise; rows are matched on the default integer index)
        clustered_sign_distr_combi = pd.concat([clustered_sign_distr, c_scores_all], axis=1)

        ####################################################################
        #       get statistics and observations                            #
        ####################################################################

        cluster_stats = pd.read_csv(f"5b Clusters/cluster_stats_{dataset}.csv")
        cluster_stats = cluster_stats.set_index("variable")

        cluster_rows = summarize_clusters(cluster_stats, clustered_sign_distr_combi)
        display(cluster_rows)






    # get statistics and observations


# add the means to the observations

Unnamed: 0,prune_rate_in,prune_rate_out,sign_rate_in,sign_rate_out,c_scores,ratio
cluster_0,lowest(0.9070765990386508),medium(0.24),balanced(0.52),balanced(0.52),medium(0.0025227038653827306),0.6
cluster_2,medium(0.9570605948274712),highest(0.31),negative(0.44),balanced(0.51),medium(0.00028910683841705173),0.32
cluster_1,medium(0.9312464204945958),medium(0.22),positive(0.71),positive(0.58),highest(0.03405462616797443),0.07
cluster_3,highest(0.997359566066576),lowest(0.21),negative(0.02),balanced(0.49),lowest(2.0224951320673527e-09),0.01


Unnamed: 0,prune_rate_in,prune_rate_out,sign_rate_in,sign_rate_out,c_scores,ratio
cluster_0,lowest(0.8789691996843633),medium(0.28),balanced(0.5),balanced(0.53),medium(0.003887623825354077),0.5
cluster_2,medium(0.950068678412714),highest(0.36),negative(0.44),balanced(0.51),medium(0.00044622539518254735),0.45
cluster_1,medium(0.9187995062934028),lowest(0.27),positive(0.65),positive(0.61),highest(0.05039863628970256),0.05
cluster_3,highest(0.9957347196691176),medium(0.28),negative(0.03),balanced(0.5),lowest(1.7338606568599438e-10),0.01


Unnamed: 0,prune_rate_in,prune_rate_out,sign_rate_in,sign_rate_out,c_scores,ratio
cluster_2,medium(0.9486253526475694),highest(0.3),balanced(0.48),balanced(0.52),medium(0.001004765399084735),0.42
cluster_0,lowest(0.8862551664694761),medium(0.27),balanced(0.51),balanced(0.5),medium(0.004993789056938509),0.37
cluster_1,medium(0.900448052166718),medium(0.23),balanced(0.55),positive(0.56),highest(0.012080559253671453),0.18
cluster_3,highest(0.9940372043185765),lowest(0.21),negative(0.06),balanced(0.5),lowest(4.735172130351183e-10),0.03
