# XAI THESIS

## Startup

### GPU Settings

### Imports

In [None]:
#############################
#							#
#         Main.py			#
#							#
#############################

# Import needed libraries
from dis import dis
import os
from sre_constants import GROUPREF_EXISTS
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Suppress warnings
import tensorflow as tf
print('Tensorflow version: ' + tf.__version__)
print('Keras version: ' + tf.keras.__version__)
tf.get_logger().setLevel('ERROR') # Suppress warnings
from buildModel import buildModelVGG16ExplicitTL, buildModelVGG16ExplicitFT
from dataLoader import loadPreprocessedData
import random
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.metrics import confusion_matrix
from datetime import datetime
import numpy as np
from scipy.spatial import distance 
from tensorflow.keras.applications.vgg16 import preprocess_input as preprocess_vgg16

# Short aliases for the Keras namespaces used throughout the notebook
tfk = tf.keras
tfkl = tf.keras.layers
print('')

### Seed

In [None]:
# Seed every random number generator used by the notebook (Python's random,
# NumPy and TensorFlow) with the same value.
seed = 432 #543
random.seed(seed)
# NOTE(review): PYTHONHASHSEED is read when the interpreter starts, so setting it
# here presumably only affects processes spawned from this one — confirm.
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
tf.random.set_seed(seed)
tf.compat.v1.set_random_seed(seed)

### Folder Settings

In [None]:
# Timestamp of this session; used to name the experiment folder and the save paths
now = datetime.now().strftime('%b%d_%H-%M-%S')
# Root folder of the dataset and root folder for trained models
dataset_dir = 'imagenette2-320'
model_dir = 'CNN'

# One sub-folder per data split inside the dataset root
training_dir = os.path.join(dataset_dir, 'train')
validation_dir = os.path.join(dataset_dir, 'val')
test_dir = os.path.join(dataset_dir, 'test')

### Labels Settings

In [None]:
# Label names are the entries of the training folder in sorted order.
# NOTE(review): os.listdir also returns stray files (e.g. hidden files), which
# would end up in `labels` — confirm the folder holds class sub-folders only.
labels = sorted(os.listdir(training_dir))

print('Using dataset from folder: ' + dataset_dir)

## VGG-16 Network

### Main parameters

In [None]:
#######################################################
#   VGG16 Network (Transfer Learning & Fine Tuning)   #
#######################################################

# Define model metadata
# input_shape and classes are passed to the model builder; tl_epochs and
# ft_epochs are the epoch limits of the transfer-learning and fine-tuning fits.
input_shape = (256, 256, 3)
classes = 10
tl_epochs = 200
ft_epochs = 200

# Load dataset
# loadPreprocessedData returns a mapping; only its 'train' and 'validation'
# entries are read here.
batch_size = 32
train_val_gen = loadPreprocessedData(training_dir, validation_dir, test_dir, seed, batch_size, preprocess_vgg16)
train_gen = train_val_gen['train']
valid_gen = train_val_gen['validation']

#### Network Training settings

In [None]:
# Set to 'Y' to train the network; any other value makes the next cell load a
# previously saved model instead.
if_train = 'n'

### Neural Network

In [None]:
# Either train the VGG16-based classifier in two stages (transfer learning, then
# fine tuning) and report validation metrics, or load a saved model from disk.
if if_train == 'Y':
        ### TRANSFER LEARNING ###

        # Create model
        model = buildModelVGG16ExplicitTL(input_shape, classes, tfk, tfkl, seed)

        # Create folders
        # NOTE(review): exp_dir is <model_dir>/CNN_<now>, but the model and history
        # below are saved under <model_dir>/<now>/ — exp_dir is never used. Confirm
        # which location is intended.
        exps_dir = os.path.join(model_dir)
        if not os.path.exists(exps_dir): os.makedirs(exps_dir)
        exp_dir = os.path.join(exps_dir, 'CNN_' + str(now))
        if not os.path.exists(exp_dir): os.makedirs(exp_dir)

        # Train the model
        # Early stopping watches val_loss with a patience of 10 epochs and restores
        # the best weights; only the history dictionary of the fit is kept.
        history = model.fit(
                x = train_gen,
                epochs = tl_epochs,
                validation_data = valid_gen,
                callbacks = [tfk.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=10, restore_best_weights=True)]
        ).history

        # Save best epoch model
        print()
        model.save(model_dir + "/" + str(now) + '/model_tl')
        np.save(model_dir + "/" + str(now) + "/history_tl.npy", history)
        print()

        ### FINE TUNING ###

        # Create model
        # The fine-tuning model is derived from the transfer-learning model above.
        model = buildModelVGG16ExplicitFT(model, tfk)

        # Create folders
        # (same folders as in the transfer-learning stage; already present by now)
        exps_dir = os.path.join(model_dir)
        if not os.path.exists(exps_dir): os.makedirs(exps_dir)
        exp_dir = os.path.join(exps_dir, 'CNN_' + str(now))
        if not os.path.exists(exp_dir): os.makedirs(exp_dir)

        # Train the model
        history_ft = model.fit(
                x = train_gen,
                epochs = ft_epochs,
                validation_data = valid_gen,
                callbacks = [tfk.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=10, restore_best_weights=True)]
        ).history

        # Save best epoch model
        print()
        model.save(model_dir + "/" + str(now) + '/model_ft')
        np.save(model_dir + "/" + str(now) + "/history_ft.npy", history_ft)
        print()

        # Plot the training history
        # First figure: loss curves of both stages (dashed = training, solid = validation)
        plt.figure(figsize=(15, 5))
        plt.plot(history['loss'], label='Training TL',
                alpha=.3, color='#ff7f0e', linestyle='--')
        plt.plot(history['val_loss'],
                label='Validation TL', alpha=.8, color='#ff7f0e')
        plt.plot(history_ft['loss'], label='Training FT',
                alpha=.3, color='#8fce00', linestyle='--')
        plt.plot(history_ft['val_loss'],
                label='Validation FT', alpha=.8, color='#8fce00')
        plt.legend(loc='upper left')
        plt.title('Categorical Crossentropy')
        plt.grid(alpha=.3)
        # Second figure: accuracy curves of both stages
        plt.figure(figsize=(15, 5))
        plt.plot(history['accuracy'], label='Training TL',
                alpha=.8, color='#ff7f0e', linestyle='--')
        plt.plot(history['val_accuracy'],
                label='Validation TL', alpha=.8, color='#ff7f0e')
        plt.plot(history_ft['accuracy'], label='Training FT',
                alpha=.8, color='#8fce00', linestyle='--')
        plt.plot(history_ft['val_accuracy'],
                label='Validation FT', alpha=.8, color='#8fce00')
        plt.legend(loc='upper left')
        plt.title('Accuracy')
        plt.grid(alpha=.3)
        plt.show()

        # Evaluation
        # NOTE(review): the predictions are compared against valid_gen.classes, which
        # assumes valid_gen yields samples in that same order (no shuffling) —
        # confirm in dataLoader.
        predictions = model.predict(valid_gen)

        # Compute the confusion matrix
        cmat = confusion_matrix(valid_gen.classes, np.argmax(predictions, axis=-1))

        # Compute the classification metrics
        # Precision, recall and F1 are macro-averaged over the classes.
        accuracy = accuracy_score(valid_gen.classes, np.argmax(predictions, axis=-1))
        precision = precision_score(valid_gen.classes, np.argmax(predictions, axis=-1), average='macro')
        recall = recall_score(valid_gen.classes, np.argmax(predictions, axis=-1), average='macro')
        f1 = f1_score(valid_gen.classes, np.argmax(predictions, axis=-1), average='macro')
        print()
        print('Validation Metrics:')
        print('Accuracy:',accuracy.round(4))
        print('Precision:',precision.round(4))
        print('Recall:',recall.round(4))
        print('F1:',f1.round(4))

        # Plot the confusion matrix
        # The matrix is transposed so that the x axis carries the true labels.
        plt.figure(figsize=(10,8))
        sns.heatmap(cmat.T, xticklabels=labels, yticklabels=labels)
        plt.xlabel('True labels')
        plt.ylabel('Predicted labels')
        plt.show()

else:
        # No training requested: ask interactively for a saved model and load it
        model_path = input("Insert the path to the model folder you want to use: ")
        model = tf.keras.models.load_model(model_path)
        model.summary()

## Feature Maps Extraction

### Image choice

In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Two generators over the validation folder, one image per batch:
#   *_np : pixel values only rescaled by 1/255 (used for display)
#   *_p  : images passed through the VGG16 preprocessing function (used as model input)
# Both shuffle with the same seed.
# NOTE(review): the later cells pair batch k of one generator with batch k of the
# other, which presumably relies on the shared seed giving the same order — confirm.
valid_data_generator_np = ImageDataGenerator(rescale=1/255.)

valid_generator_np = valid_data_generator_np.flow_from_directory(directory=validation_dir,
                                                target_size=(256,256),
                                                color_mode='rgb',
                                                classes=None, # can be set to labels
                                                class_mode='categorical',
                                                batch_size=1,
                                                shuffle=True,
                                                seed=seed)

valid_data_generator_p = ImageDataGenerator(preprocessing_function=preprocess_vgg16)

valid_generator_p = valid_data_generator_p.flow_from_directory(directory=validation_dir,
                                                target_size=(256,256),
                                                color_mode='rgb',
                                                classes=None, # can be set to labels
                                                class_mode='categorical',
                                                batch_size=1,
                                                shuffle=True,
                                                seed=seed)

In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Same pair of generators as for the validation folder, but reading from test_dir
# with 50 images per batch. The variable names still start with `valid_`, so
# running this cell rebinds the generators created from the validation folder.
valid_data_generator_np = ImageDataGenerator(rescale=1/255.)

valid_generator_np = valid_data_generator_np.flow_from_directory(directory=test_dir,
                                                target_size=(256,256),
                                                color_mode='rgb',
                                                classes=None, # can be set to labels
                                                class_mode='categorical',
                                                batch_size=50,
                                                shuffle=True,
                                                seed=seed)

valid_data_generator_p = ImageDataGenerator(preprocessing_function=preprocess_vgg16)

valid_generator_p = valid_data_generator_p.flow_from_directory(directory=test_dir,
                                                target_size=(256,256),
                                                color_mode='rgb',
                                                classes=None, # can be set to labels
                                                class_mode='categorical',
                                                batch_size=50,
                                                shuffle=True,
                                                seed=seed)

In [None]:
# Draw one batch from each generator and reset the image cursor. Re-running the
# next cell walks through the batch one image at a time via `i`.
batch_np = next(valid_generator_np)
batch_p = next(valid_generator_p)
i = 0

In [None]:
# Take the i-th sample of the current batch: the rescaled copy is shown to the
# user, the VGG16-preprocessed copy is fed to the network.
original_image = batch_np[0][i]
image = batch_p[0][i]

plt.imshow(original_image)
plt.show()

# Scores of the single image as a plain Python list (one entry per class)
predicted_class = model.predict(image.reshape(1, 256, 256, 3))[0].tolist()

# Most likely class = position of the highest score
max_value = max(predicted_class)
mlc = predicted_class.index(max_value)
print("Predicted Class: " + labels[mlc])
print("\nConfidence")
# One "<label>: <score in percent>%" line per class
confidence = [
    name + ': ' + str(np.around(score*100, 3)) + '%'
    for name, score in zip(labels, predicted_class)
]
for conf in confidence:
    print(conf)
# Advance the cursor so that the next run of this cell shows the next image
i += 1

### Feature maps from Grad-CAM + weights + Grad-CAMs

In [None]:
from statistics import mean
import skimage.filters
from matplotlib import cm

# Compute a guided Grad-CAM style heatmap for every Conv2D layer of the model,
# keeping the feature maps, the per-channel weights and the heatmap of each layer.

# Remove the activation of the last layer for the duration of this cell, so the
# gradients are taken with respect to the raw outputs (restored at the bottom).
model.layers[-1].activation = None

# Output tensors of all convolutional layers, in model order
conv_layers = [layer.output for layer in model.layers if isinstance(layer, tf.keras.layers.Conv2D)]

fmaps = []      # per layer: feature maps of the selected image
weights = []    # per layer: one weight per feature map (channel)
grad_cam = []   # per layer: heatmap resized to 256x256 and min-max scaled

for layer in conv_layers:

   # Model that returns both the activations of this layer and the final output
   gradModel = tfk.models.Model(
      inputs = [model.inputs],
      outputs = [layer, model.output]
   )

   with tf.GradientTape() as tape:
      inputs = tf.cast(tf.expand_dims(image, 0), tf.float32)
      (convOutputs, predictions) = gradModel(inputs)
      # NOTE(review): the explained class is hard-coded to index 2 rather than the
      # predicted class — confirm this is intended for the selected image.
      loss = predictions[..., 2]

   # Gradient of the class score with respect to the layer activations
   grads = tape.gradient(loss, convOutputs)

   # Keep a gradient only where both the activation and the gradient are positive
   castConvOutputs = tf.cast(convOutputs > 0, "float32")
   castGrads = tf.cast(grads > 0, "float32")
   guidedGrads = castConvOutputs * castGrads * grads

   # Drop the batch dimension (a single image was fed)
   convOutputs = convOutputs[0]
   guidedGrads = guidedGrads[0]

   # Channel weight = mean of the kept gradients over the two leading axes;
   # heatmap = weighted sum of the feature maps over the channel axis
   w = tf.reduce_mean(guidedGrads, axis=(0, 1))
   cam = tf.reduce_sum(tf.multiply(w, convOutputs), axis=-1)

   # Add a trailing axis and resize to 256x256
   cam = tf.image.resize(np.expand_dims(cam, axis=2), [256, 256])

   # Min-max scale to [0, 1]; a constant map is left untouched
   cam = np.divide(
                np.subtract(cam, np.min(cam)), (np.max(cam)-np.min(cam))) if np.max(cam)-np.min(cam) != 0 else cam

   fmaps.append(convOutputs)
   weights.append(w)
   grad_cam.append(cam)

# Output of the last gradModel call (activation of the last layer removed)
print(predictions)


# Show the heatmap of every layer blended 50/50 with the displayed image
for g in grad_cam:

   # Colour-map the heatmap; index 0 removes the singleton axis added before the
   # resize, 0:3 keeps the first three colour channels
   g_heatmap = (cm.jet(g))[:,:,0,0:3]

   g_heatmap = skimage.filters.gaussian(
      g_heatmap, sigma=(1.0, 1.0), truncate=6.5, channel_axis=2)
   g_heatmap = np.add(np.multiply(g_heatmap,0.5), np.multiply(original_image, 0.5))

   plt.imshow(g_heatmap)
   plt.show()

# Combine the heatmaps of all layers.
# NOTE(review): despite the name, mean_gcam is the element-wise MAXIMUM over the
# layers, not their mean — confirm which is intended.
mean_gcam = np.max(grad_cam, axis=0)
mean_gcam = np.divide(
                np.subtract(mean_gcam, np.min(mean_gcam)), (np.max(mean_gcam)-np.min(mean_gcam))) if np.max(mean_gcam)-np.min(mean_gcam) != 0 else mean_gcam
mean_gcam_heatmap = (cm.jet(mean_gcam))[:,:,0,0:3]
mean_gcam_heatmap = skimage.filters.gaussian(
         mean_gcam_heatmap, sigma=(1.0, 1.0), truncate=6.5, channel_axis=2)
mean_gcam_overlay = np.add(np.multiply(mean_gcam_heatmap,0.5), np.multiply(original_image, 0.5))

plt.imshow(mean_gcam_overlay)
plt.show()

# Restore the softmax activation removed at the top of this cell
model.layers[-1].activation = tf.keras.activations.softmax

## Feature Maps Clustering

### Imports

In [None]:
from sklearn.cluster import KMeans, DBSCAN, OPTICS
from sklearn.cluster import AgglomerativeClustering
from sklearn import metrics
from skimage.transform import resize
from yellowbrick.cluster import silhouette_visualizer
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import AffinityPropagation
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import skimage.filters
from matplotlib import cm
from mpl_toolkits.axes_grid1 import ImageGrid

### Utility Functions

In [None]:
def reshape_fmaps(fmaps):
    """Turn a stack of feature maps into a list of flat, min-max scaled vectors.

    `fmaps` is indexed as ``fmaps[:, :, k]`` for every ``k`` along its last
    axis. Each slice is flattened and rescaled to [0, 1]; a slice whose values
    are all equal is returned flattened but otherwise unchanged.

    Returns a list with one 1-D vector per feature map.
    """
    def _flatten_and_scale(channel):
        vec = fmaps[:, :, channel].flatten()
        spread = np.max(vec) - np.min(vec)
        if spread != 0:
            return np.divide(np.subtract(vec, np.min(vec)), spread)
        # Constant map: nothing to rescale
        return vec

    return [_flatten_and_scale(channel) for channel in range(0, fmaps.shape[-1])]

def computeOverlay(heatmap, original_image, fmap_size):
        """Blend a colour-mapped heatmap 50/50 with an image.

        The heatmap is min-max scaled to [0, 1] (left as is when constant),
        colour-mapped with `cm.jet` and averaged with `original_image`.

        Returns a dict with:
          'overlay' : the blend multiplied by 255 and cast to uint8
          'heatmap' : the scaled heatmap resized to (fmap_size, fmap_size)
        """
        #norm = np.linalg.norm(heatmap)
        #if(norm != 0.0): heatmap = heatmap/norm
        heatmap = np.divide(
                np.subtract(heatmap, np.min(heatmap)), (np.max(heatmap)-np.min(heatmap))) if np.max(heatmap)-np.min(heatmap) != 0 else heatmap
        #threshold = (np.max(heatmap) - np.min(heatmap))*0.33
        #for i in range(len(heatmap)):
        #        for j in range(len(heatmap[0])):
        #                if heatmap[i][j] < threshold: heatmap[i][j] = 0.0

        #heatmap = 0.5*np.cos(heatmap*np.pi+np.pi)+0.5
        # NOTE(review): new_heatmap (smoothed and re-scaled) is never read afterwards;
        # the colour map and the returned 'heatmap' both use the unsmoothed `heatmap`.
        # Confirm whether the smoothed version was meant to be used.
        # NOTE(review): channel_axis=2 is passed although the callers in this file
        # appear to hand in a 2-D map — verify against the installed scikit-image.
        new_heatmap = skimage.filters.gaussian(
                heatmap, sigma=(1.0, 1.0), truncate=5.5, channel_axis=2)
        new_heatmap = np.divide(
                np.subtract(new_heatmap, np.min(new_heatmap)), (np.max(new_heatmap)-np.min(new_heatmap))) if np.max(new_heatmap)-np.min(new_heatmap) != 0 else new_heatmap
        # Keep the first three channels of the colour-mapped heatmap
        jet_heatmap = (cm.jet(heatmap))[:,:,0:3]
        overlay = np.add(np.multiply(jet_heatmap,0.5), np.multiply(original_image, 0.5))

        return {'overlay':(255*overlay).astype(np.uint8), 'heatmap':resize(heatmap, (fmap_size, fmap_size))}

def computeOverlayForMax(heatmap, original_image, fmap_size):
    """Blend a colour-mapped heatmap with an image using a 40/60 mix.

    Unlike `computeOverlay`, the heatmap is used as given: no rescaling and no
    smoothing are applied before colour mapping.

    Returns a dict with:
      'overlay' : the blend multiplied by 255 and cast to uint8
      'heatmap' : the input heatmap resized to (fmap_size, fmap_size)
    """
    # Colour-map the heatmap and keep its first three channels
    coloured = cm.jet(heatmap)[:, :, 0:3]

    # 40 % heatmap + 60 % image
    heat_part = np.multiply(coloured, 0.4)
    image_part = np.multiply(original_image, 0.6)
    blended = np.add(heat_part, image_part)

    overlay_u8 = (255*blended).astype(np.uint8)
    resized_heatmap = resize(heatmap, (fmap_size, fmap_size))
    return {'overlay': overlay_u8, 'heatmap': resized_heatmap}

def compute_cluster_grid(groups, fmaps, cluster_no, original_image, indices, weights, if_print):
    """Show up to 16 feature maps of one cluster as overlays in a 4x4 grid.

    groups         : list of clusters; each cluster lists positions into `indices`
    fmaps          : feature maps, indexed as fmaps[:, :, channel]
    cluster_no     : which entry of `groups` to draw
    original_image : image the overlays are blended with
    indices        : maps a cluster member to its channel in `fmaps`
    weights        : accepted for signature compatibility; not used here
    if_print       : when falsy the function does nothing

    Returns None; the figure is displayed with plt.show().
    """
    if not if_print:
        return
    cluster = groups[cluster_no]
    if len(cluster) == 0:
        return

    n_rows, n_cols = 4, 4
    capacity = n_rows*n_cols

    # Spatial size of the maps before they are enlarged for display
    fmap_size = len(fmaps)
    fmaps = tf.image.resize(fmaps, size=[256, 256]) # resize for visualization
    original_image = tf.image.resize(original_image, size=[256, 256])

    fig = plt.figure(figsize=(30, 30))
    grid = ImageGrid(fig, 111,  # similar to subplot(111)
                     nrows_ncols=(n_rows, n_cols),
                     axes_pad=0.1,  # pad between axes in inch.
                     )

    # Draw every member, but never more than the grid can hold
    n_shown = len(cluster) if len(cluster) < capacity else capacity
    for cell in range(n_shown):
        channel = indices[cluster[cell]]
        overlay = computeOverlay(fmaps[:, :, channel], original_image, fmap_size)['overlay']
        axis = grid[cell]
        axis.get_yaxis().set_ticks([])
        axis.get_xaxis().set_ticks([])
        axis.imshow(overlay, aspect='auto')

    plt.show()

def compute_cluster_mean_median_max(groups, fmaps, cluster_no, original_image, w, indices, if_print, cluster_w, avg_cluster_w):
    """Summarise one cluster by its weighted-mean, median and maximum heatmaps.

    groups         : list of clusters; each cluster lists positions into `indices` / `w`
    fmaps          : feature maps, indexed as fmaps[:, :, channel]
    cluster_no     : which entry of `groups` to summarise
    original_image : image the overlays are blended with
    w              : per-member weights, forwarded to the mean and max helpers
    indices        : maps a cluster member to its channel in `fmaps`
    if_print       : when truthy, the three overlays are shown side by side
    cluster_w      : total weight of the cluster, copied into the result
    avg_cluster_w  : average weight of the cluster, copied into the result

    Returns None for an empty cluster, otherwise a dict with the keys
    'mean' (heatmap of the weighted mean), 'total_weight', 'average_weight'
    and 'image_count' (number of members of the cluster).
    """
    cluster = groups[cluster_no]
    if len(cluster) == 0:
        return None

    row_size = 1
    col_size = 3
    # Spatial size of the maps before they are enlarged for display
    fmap_size = len(fmaps)
    fmaps = tf.image.resize(fmaps, size=[256, 256]) # resize for visualization
    original_image = tf.image.resize(original_image, size=[256, 256])

    mean_image = compute_cluster_mean(cluster, fmaps, original_image, w, indices, fmap_size)
    median_image = compute_cluster_median(cluster, fmaps, original_image, indices, fmap_size)
    max_image = compute_cluster_max(cluster, fmaps, original_image, indices, w, fmap_size)

    if if_print:
        fig = plt.figure(figsize=(30, 30))
        grid = ImageGrid(fig, 111,  # similar to subplot(111)
                         nrows_ncols=(row_size, col_size),
                         axes_pad=0.1,  # pad between axes in inch.
                         share_all=True
                         )
        grid[0].get_yaxis().set_ticks([])
        grid[0].get_xaxis().set_ticks([])
        grid[0].imshow(mean_image['overlay'], aspect='auto')
        grid[1].imshow(median_image['overlay'], aspect='auto')
        grid[2].imshow(max_image['overlay'], aspect='auto')
        plt.show()

    return {
        'mean': mean_image['heatmap'],
        #'median': median_image['heatmap'],
        #'max': max_image['heatmap'],
        'total_weight': cluster_w,
        'average_weight': avg_cluster_w,
        # Count the members of THIS cluster. The previous `len(groups[i])` read an
        # `i` that is not defined in this function and therefore depended on
        # whatever global `i` the notebook happened to hold.
        'image_count': len(cluster)
    }

def compute_cluster_mean(cluster, fmaps, original_image, weights, indices, fmap_size):
    """Weighted mean of the feature maps of one cluster, returned as an overlay.

    cluster        : positions into `weights` / `indices` of the cluster members
    fmaps          : feature maps, indexed as fmaps[:, :, channel]
    original_image : image the heatmap is blended with
    weights        : weight of every candidate feature map
    indices        : maps a cluster member to its channel in `fmaps`
    fmap_size      : forwarded to computeOverlay (size of the returned heatmap)

    Returns the dict produced by computeOverlay for the weighted-mean map.
    """
    # Accumulate in a NumPy array. The previous accumulator was a Python list, and
    # `list += array` EXTENDS the list with the rows of the array instead of adding
    # element-wise, so the result grew by one map per cluster member.
    weighted_sum = None
    sumw = 0.0
    for member in cluster:
        w = weights[member]
        term = np.multiply(w, fmaps[:, :, indices[member]])
        weighted_sum = term if weighted_sum is None else np.add(weighted_sum, term)
        sumw += w

    if weighted_sum is None:
        # Empty cluster: fall back to an all-zero map of the display size
        weighted_sum = np.zeros((256, 256))

    # Normalise by the total weight; skip the division when the weights sum to
    # zero so that the map is not filled with inf/nan
    mean = np.divide(weighted_sum, sumw) if sumw != 0 else weighted_sum
    return computeOverlay(mean, original_image, fmap_size)

def compute_cluster_median(cluster, fmaps, original_image, indices, fmap_size):
    """Element-wise median of the feature maps of one cluster, as an overlay.

    Every member of `cluster` is mapped to its channel through `indices`, the
    selected maps are stacked and their median is taken along the stacking
    axis. Returns the dict produced by computeOverlay for that median map.
    """
    members = [fmaps[:, :, indices[cluster[pos]]] for pos in range(len(cluster))]
    median = np.median(members, axis=0)
    return computeOverlay(median, original_image, fmap_size)

def compute_cluster_max(cluster, fmaps, original_image, indices, weights, fmap_size):
    """Element-wise maximum of the min-max scaled feature maps of one cluster.

    cluster        : positions into `indices` of the cluster members
    fmaps          : feature maps, indexed as fmaps[:, :, channel]
    original_image : image the heatmap is blended with
    indices        : maps a cluster member to its channel in `fmaps`
    weights        : accepted for signature compatibility; not used
    fmap_size      : forwarded to computeOverlayForMax

    Returns the dict produced by computeOverlayForMax for the maximum map.
    An empty cluster makes np.max fail, as before; callers skip empty clusters.
    """
    # Local names no longer shadow the built-ins `map` and `max`, and the unused
    # per-member weight lookup has been dropped.
    scaled_maps = []
    for member in cluster:
        fmap = fmaps[:, :, indices[member]]
        lowest = np.min(fmap)
        spread = np.max(fmap) - lowest
        if spread != 0:
            # Scale to [0, 1] so that every map contributes on the same footing
            fmap = np.divide(np.subtract(fmap, lowest), spread)
        scaled_maps.append(fmap)

    peak = np.max(scaled_maps, axis=0)
    return computeOverlayForMax(peak, original_image, fmap_size)

def compute_cluster_weight(cluster, weights):
    """Return the summed weight of the members of `cluster`, rounded to 3 decimals.

    `cluster` holds positions into `weights`.
    """
    total = 0.0
    for member in cluster:
        total = total + weights[member]
    return np.round(total, 3)

def nullifyFeatureMaps(fmaps):
    """Zero, in place, every value below 33 % of its map's value range.

    For each map the cut-off is ``(max - min) * 0.33``; entries that are not
    greater than or equal to it are overwritten with 0.0. The same container
    that was passed in is returned.
    """
    for row in range(0, len(fmaps)):
        values = fmaps[row]
        cutoff = (np.max(values) - np.min(values))*0.33
        for col in range(0, len(values)):
            if not values[col] >= cutoff:
                values[col] = 0.0

    return fmaps

def get_data_radiant(data):
    """Angle of the diagonal of the data's bounding box, clipped to a narrow band.

    `data` is indexed as ``data[:, 0]`` (x) and ``data[:, 1]`` (y). The angle
    ``arctan2(y range, x range)`` is clipped to the interval between 0.35 and
    0.55 degrees and returned in radians.
    """
    xs = data[:, 0]
    ys = data[:, 1]
    height = ys.max() - ys.min()
    width = xs.max() - xs.min()
    lower, upper = np.radians(0.35), np.radians(0.55)
    return np.clip(np.arctan2(height, width), lower, upper)

def find_elbow(data, theta):
    """Return the row index of the smallest value after rotating `data` by `theta`.

    `data` is multiplied by the 2x2 rotation matrix built from `theta`
    (radians). The result is the first row, in row-major order, that contains
    the minimum of the rotated array.
    """
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)

    # 2-D rotation matrix
    rotation = np.array([[cos_t, -sin_t],
                         [sin_t, cos_t]])

    rotated = data.dot(rotation)

    # Row coordinates of every entry equal to the minimum; take the first one
    rows_of_minimum = np.where(rotated == rotated.min())[0]
    return rows_of_minimum[0]

def computeDiscardValue(n_fmaps, depth):
    """Return the weight threshold below which a feature map is discarded.

    The threshold depends only on the number of feature maps: fixed values
    for 64 maps and more, ``76.8 / n_fmaps`` below that. `depth` is accepted
    but does not influence the result.
    """
    # (minimum number of maps, threshold), checked from the largest tier down
    tiers = (
        (512, 0.175),
        (256, 0.35),
        (128, 0.7),
        (64, 1.4),
    )
    for minimum, threshold in tiers:
        if n_fmaps >= minimum:
            return threshold

    return 76.8/n_fmaps


#### Clustering Settings

In [None]:
# Selects which clustering cell below is active; each cell runs only when this
# value equals its own keyword ('kmeans', 'affinity', 'dbscan', 'optics',
# 'agglomerative', ...).
if_clustering = 'agglomerative_pruned'

#### K-Means

In [None]:
nclusters = 8
if if_clustering == 'kmeans':
    # Cluster the feature maps of every convolutional layer with K-Means.
    # NOTE(review): KMeansL1L2 and plot_cluster_grid are neither defined nor
    # imported in the part of the notebook visible here — confirm they exist
    # before enabling this branch.
    # The throw-away assignment `min = 0` that used to live in this loop has been
    # removed: it rebound the built-in min(), which the DBSCAN cell calls.
    for depth in range(0, len(fmaps)):

        print("Clustering at depth: " + str(depth))

        # One flattened, min-max scaled vector per feature map of this layer
        fmap_atdepth = fmaps[depth]
        elements = np.array(reshape_fmaps(fmap_atdepth))
        print("Number of feature maps: " + str(len(elements)))

        # Reduce every vector to 64 principal components before clustering
        pca = PCA(n_components=64, random_state=43)
        pca.fit(elements)
        x = pca.transform(elements)

        cluster_res = KMeansL1L2(n_clusters=nclusters, n_init = 100, max_iter= 300, tol=1e-6, random_state=None, norm='L2').fit(x)

        centroids = cluster_res.cluster_centers_
        n_labels = cluster_res.labels_

        # groups[k] collects the positions of the feature maps assigned to cluster k
        groups = [[] for i in range(len(centroids))]

        for i in range(0, n_labels.size):
            groups[n_labels[i]].append(i)

        print("Number of clusters: " + str(len(groups)))
        print(centroids)
        print()
        # Pairwise Euclidean distances between the cluster centres
        for i in range(len(centroids)-1):
            for j in range(i+1, len(centroids)):
                print("Distance between "+str(i+1)+" and "+str(j+1)+": "+str(distance.euclidean(centroids[i],centroids[j]))+"\n")
        for i in range(len(groups)): plot_cluster_grid(groups, fmap_atdepth, i)

#### Affinity Propagation

In [None]:
if if_clustering == 'affinity':
    # Cluster the feature maps of every convolutional layer with Affinity
    # Propagation (the number of clusters is chosen by the algorithm).
    # NOTE(review): plot_cluster_grid is neither defined nor imported in the part
    # of the notebook visible here — confirm it exists before enabling this branch.
    # The throw-away assignment `min = 0` that used to live in this loop has been
    # removed: it rebound the built-in min(), which the DBSCAN cell calls.
    for depth in range(0, len(fmaps)):

        print("Clustering at depth: " + str(depth))

        # One flattened, min-max scaled vector per feature map of this layer
        fmap_atdepth = fmaps[depth]
        elements = np.array(reshape_fmaps(fmap_atdepth))
        print("Number of feature maps: " + str(len(elements)))

        # Reduce every vector to 64 principal components before clustering
        pca = PCA(n_components=64, random_state=43)
        pca.fit(elements)
        x = pca.transform(elements)

        cluster_res = AffinityPropagation(max_iter= 500, convergence_iter = 60, damping=0.7, random_state=5).fit(x)

        centroids = cluster_res.cluster_centers_
        n_labels = cluster_res.labels_

        # groups[k] collects the positions of the feature maps assigned to cluster k
        groups = [[] for i in range(len(centroids))]

        for i in range(0, n_labels.size):
            groups[n_labels[i]].append(i)

        print("Number of clusters: " + str(len(groups)))
        print(centroids)
        print()
        # Pairwise Euclidean distances between the cluster centres
        for i in range(len(centroids)-1):
            for j in range(i+1, len(centroids)):
                print("Distance between "+str(i+1)+" and "+str(j+1)+": "+str(distance.euclidean(centroids[i],centroids[j]))+"\n")
        for i in range(len(groups)): plot_cluster_grid(groups, fmap_atdepth, i)

#### DBSCAN

In [None]:
print_all_images = 1

if if_clustering == 'dbscan':
    # Cluster the feature maps of every convolutional layer with DBSCAN; epsilon is
    # estimated per layer from the sorted nearest-neighbour distances.
    # NOTE(review): plot_cluster_grid and plot_cluster_mean_median are neither
    # defined nor imported in the part of the notebook visible here — confirm they
    # exist before enabling this branch.
    for depth in range(0, len(fmaps)):

        print('\n')
        print("=" * 180)

        print("Clustering at depth: " + str(depth))

        # One flattened, min-max scaled vector per feature map of this layer
        fmap_atdepth = fmaps[depth]
        elements = np.array(reshape_fmaps(fmap_atdepth))
        print("Number of feature maps: " + str(len(elements)))
        print("Size of feature map: " + str(len(elements[0])))

        # Keep as many principal components as needed to explain 80 % of the variance
        pca = PCA(n_components=0.8, random_state=43) #np.minimum(len(elements), len(elements[0]))
        pca.fit(elements)
        x = pca.transform(elements)

        # Sorted distance of every point to its nearest neighbour (column 0 is the
        # point itself); the elbow of this curve is used as epsilon
        neigh = NearestNeighbors(n_neighbors=2).fit(x)
        distances, indexes = neigh.kneighbors(x)
        distances = np.sort(distances, axis = 0)
        distances=distances[:,1]
        plt.plot(distances)
        ind = np.arange(len(distances))
        new_distances = np.vstack((ind, distances)).T
        print("Angle: " + str(np.rad2deg(get_data_radiant(new_distances))))
        opt_eps = distances[find_elbow(new_distances, get_data_radiant(new_distances))]
        # opt_eps = 0.3 + 0.1*np.log(depth+1)/np.log(25)
        opt_eps = min(opt_eps, 0.5)
        print("Optimal epsilon: " + str(opt_eps))
        plt.show()

        cluster_res = DBSCAN(eps = opt_eps, min_samples=2, metric='euclidean').fit(x)

        n_labels = cluster_res.labels_
        # Labels are -1 for noise and 0..k-1 for the clusters, so k+1 groups are
        # needed (group 0 = noise). Counting the unique labels gave only k groups
        # when no sample was labelled noise, and groups[label+1] then ran out of range.
        centroids = int(np.max(n_labels)) + 2

        print("Number of clusters: " + str(centroids-1) + " + noisy cluster")

        groups = [[] for i in range(centroids)]

        # Shift by one so that the noise label -1 lands in groups[0]
        for i in range(0, n_labels.size):
            groups[n_labels[i]+1].append(i)

        for i in range(1, len(groups)):
            if len(groups[i]) == 0: continue
            print("Cluster " + str(i) + ":")
            if print_all_images: plot_cluster_grid(groups, fmap_atdepth, i, original_image)
            else: plot_cluster_mean_median(groups, fmap_atdepth, i, original_image)

        if print_all_images:
            print("Noisy cluster :")
            plot_cluster_grid(groups, fmap_atdepth, 0, original_image)

#### OPTICS

In [None]:
print_all_images = 1
print_noisy = 1

if if_clustering == 'optics':
    # Cluster the feature maps of every convolutional layer with OPTICS.
    # NOTE(review): plot_cluster_grid and plot_cluster_mean_median are neither
    # defined nor imported in the part of the notebook visible here — confirm they
    # exist before enabling this branch.
    for depth in range(0, len(fmaps)):

        print('\n')
        print("=" * 180)

        print("Clustering at depth: " + str(depth))

        # One flattened, min-max scaled vector per feature map of this layer
        fmap_atdepth = np.asarray(fmaps[depth])
        elements = np.array(reshape_fmaps(fmap_atdepth))
        print("Number of feature maps: " + str(len(elements)))
        print("Size of feature map: " + str(len(fmap_atdepth)) + 'x' + str(len(fmap_atdepth[0])))

        # Keep as many principal components as needed to explain 80 % of the variance
        pca = PCA(n_components=0.8, random_state=43) #np.minimum(len(elements), len(elements[0]))
        pca.fit(elements)
        x = pca.transform(elements)

        cluster_res = OPTICS(min_samples = 5, metric="correlation", cluster_method= "xi" , xi= 0, predecessor_correction=False).fit(x)
        #cluster_res = OPTICS(min_samples = 3+int(len(elements)/512), p=1.5).fit(x)

        n_labels = cluster_res.labels_
        # Labels are -1 for noise and 0..k-1 for the clusters, so k+1 groups are
        # needed (group 0 = noise). Counting the unique labels gave only k groups
        # when no sample was labelled noise, and groups[label+1] then ran out of range.
        centroids = int(np.max(n_labels)) + 2

        print("Number of clusters: " + str(centroids-1) + " + noisy cluster")

        groups = [[] for i in range(centroids)]

        # Shift by one so that the noise label -1 lands in groups[0]
        for i in range(0, n_labels.size):
            groups[n_labels[i]+1].append(i)

        for i in range(1, len(groups)):
            if len(groups[i]) == 0: continue
            print("Cluster " + str(i) + ": " + str(len(groups[i])) + " images")
            if print_all_images: plot_cluster_grid(groups, fmap_atdepth, i, original_image)
            else: plot_cluster_mean_median(groups, fmap_atdepth, i, original_image)

        if (print_all_images*print_noisy):
            print("Noisy cluster: " + str(len(groups[0])) + " images")
            plot_cluster_grid(groups, fmap_atdepth, 0, original_image)

#### Agglomerative Clustering

In [None]:
# Agglomerative clustering of the feature maps, one pass per layer depth.
# Output switches for this cell:
if_print = 1  # forwarded unchanged to the compute_cluster_* helpers
print_all_images = 0  # truthy -> compute_cluster_grid, falsy -> compute_cluster_mean_median_max
print_silhouette = 0  # truthy -> silhouette_visualizer is called for every candidate n
# Both bounds point at the last entry of fmaps, so only that depth is processed.
starting_depth = len(fmaps)-1
ending_depth = len(fmaps)-1

from typing_extensions import final
import warnings
warnings.filterwarnings("ignore")  # installs a global "ignore" filter for all warnings

# One list per processed depth, holding the values returned by
# compute_cluster_mean_median_max (the per-depth list stays empty when
# print_all_images is truthy).
cluster_groups = []

if if_clustering == 'agglomerative':
        for depth in range(starting_depth,ending_depth+1):

                print('\n')
                print("=" * 180)

                print("Clustering at depth: " + str(depth))

                fmap_atdepth = np.asarray(fmaps[depth])
                weights_atdepth = weights[depth]

                # Rescale the weights of this depth so that they sum to 100 (percentages).
                normalw = np.divide(weights_atdepth, sum(weights_atdepth))
                normalw = np.multiply(normalw, 100)

                # presumably one row per feature map after reshape_fmaps -- verify against that helper
                elements = np.array(reshape_fmaps(fmap_atdepth))

                # Keep only the entries whose percentage weight exceeds the discard value;
                # newindices records their positions in the unfiltered arrays.
                newfmaps = []
                newweights = []
                newindices = []
                for w in range(0, len(normalw)): 
                        # NOTE(review): the arguments do not change inside this loop, so this
                        # call could be made once before it (if computeDiscardValue is pure).
                        discard = computeDiscardValue(len(elements), depth)
                        if normalw[w] > discard:
                                newfmaps.append(elements[w])
                                newweights.append(normalw[w])
                                newindices.append(w)
                                
                elements = newfmaps
                #elements = nullifyFeatureMaps(elements)
                print("Number of feature maps: " + str(len(elements)) + " (" + str(len(normalw) - len(elements)) + " images removed)")
                print("Size of feature map: " + str(len(fmap_atdepth)) + 'x' + str(len(fmap_atdepth[0])))
                
                # A float n_components asks scikit-learn's PCA to keep as many components
                # as needed to explain that fraction (here 95%) of the variance.
                pca = PCA(n_components = 0.95, random_state=seed) #np.minimum(len(elements), len(elements[0]))
                pca.fit(elements)
                x = pca.transform(elements)

                print("Dimensions for clustering: " + str(len(x[0])))
                
                # Model selection: try every candidate cluster count and keep the labels
                # of the best-scoring clustering.
                # NOTE(review): n_labels is only assigned when a candidate is accepted; if
                # none is (e.g. the score is NaN), the code after the loop uses a value
                # left over from an earlier run or fails with NameError.
                best_score = -np.inf
                best_silhouette = -1
                for n in range(3, 11): # candidate cluster counts 3..10
                        cluster_res = AgglomerativeClustering(n_clusters = n).fit(x)
                        silhouette_score = metrics.silhouette_score(x, cluster_res.labels_)
                        
                        if print_silhouette: silhouette_visualizer(cluster_res, x, colors='yellowbrick')
                        
                        # Member indices of every cluster for this candidate.
                        n_labels_temp = cluster_res.labels_
                        centroids_temp = len(np.unique(n_labels_temp))
                        groups_temp = [[] for i in range(centroids_temp)]
                        for i in range(0, n_labels_temp.size):
                                groups_temp[n_labels_temp[i]].append(i)

                        # Mean squared difference between cluster sizes over all ordered
                        # pairs of clusters (including each cluster paired with itself).
                        sum_size_distance = 0
                        for i in range(0, len(groups_temp)):
                                #print(len(groups_temp[i]))
                                for j in range(0, len(groups_temp)):
                                        sum_size_distance += (len(groups_temp[i]) - len(groups_temp[j]))**2 #distance squared
                        avg_size_distance = (sum_size_distance/len(groups_temp)**2)

                        # Silhouette raised to an exponent that grows by one for every 256
                        # weights at this depth, divided by the size imbalance above.
                        # NOTE(review): avg_size_distance is 0 when all clusters have the
                        # same size, which makes this a division by zero -- confirm intended.
                        score = np.longdouble((silhouette_score)**(int(len(normalw)/256) + 1))/(avg_size_distance)
                        # Accept a candidate only if it improves the score and its silhouette
                        # is above 80% of the best silhouette accepted so far.
                        if(score > best_score and silhouette_score > best_silhouette*0.8): 
                                best_score = score
                                if(silhouette_score > best_silhouette):
                                        best_silhouette = silhouette_score
                                n_labels = cluster_res.labels_

                centroids = len(np.unique(n_labels))
                
                print("Number of clusters: " + str(centroids))
                # Sum of the percentage weights that survived the discard filter.
                print("Leftover weight: " + str(np.sum(newweights)) + "%")

                # groups[c] lists the positions (into the filtered arrays) assigned to cluster c.
                groups = [[] for i in range(centroids)]

                for i in range(0, n_labels.size):
                        groups[n_labels[i]].append(i)

                cluster_images = []

                for i in range(0,len(groups)): 
                        if len(groups[i]) == 0: continue
                        print("Cluster " + str(i+1) + ": " + str(len(groups[i])) + " images")
                        cluster_w = compute_cluster_weight(groups[i], newweights)
                        print("Total weight (sum): " + str(cluster_w) + "%")
                        avg_cluster_w = np.round(cluster_w/len(groups[i]), 5)
                        print("Average image weight: " + str(avg_cluster_w) + "%")
                        # NOTE(review): the thresholding cell calls compute_cluster_grid with an
                        # additional newweights argument -- check which call matches the helper.
                        if print_all_images: compute_cluster_grid(groups, fmap_atdepth, i, original_image, newindices, if_print)
                        else: 
                                dict_images = compute_cluster_mean_median_max(groups, fmap_atdepth, i, original_image, newweights, newindices, if_print, cluster_w, avg_cluster_w)
                                cluster_images.append(dict_images)
                
                cluster_groups.append(cluster_images)


#### Agglomerative Clustering with thresholding

In [None]:
# Agglomerative clustering on a 2-D t-SNE embedding, followed by removal of
# clusters whose total weight falls below a threshold.
# Output switches for this cell:
if_print = 1  # forwarded unchanged to the compute_cluster_* helpers
print_all_images = 0  # truthy -> compute_cluster_grid, falsy -> compute_cluster_mean_median_max
print_silhouette = 0  # truthy -> plot silhouette score against cluster count
print_full_silhouette = 0  # truthy -> silhouette_visualizer for every candidate n
print_threshold = 0  # truthy -> scatter plot of cluster weights with the threshold line
# Depths 12 .. len(fmaps)-1 (inclusive) are processed.
starting_depth = 12
ending_depth = len(fmaps)-1

from typing_extensions import final
import warnings
warnings.filterwarnings("ignore")  # installs a global "ignore" filter for all warnings

# One list per processed depth, holding the values returned by
# compute_cluster_mean_median_max for the clusters that pass the threshold.
cluster_groups = []

if if_clustering == 'agglomerative_pruned':
        for depth in range(starting_depth,ending_depth+1):

                print('\n')
                print("=" * 180)

                print("Clustering at depth: " + str(depth))

                fmap_atdepth = np.asarray(fmaps[depth])
                weights_atdepth = weights[depth]

                # Rescale the weights of this depth so that they sum to 100 (percentages).
                normalw = np.divide(weights_atdepth, sum(weights_atdepth))
                normalw = np.multiply(normalw, 100)

                # presumably one row per feature map after reshape_fmaps -- verify against that helper
                elements = np.array(reshape_fmaps(fmap_atdepth))

                # Keep only the entries whose percentage weight exceeds the discard value;
                # newindices records their positions in the unfiltered arrays.
                newfmaps = []
                newweights = []
                newindices = []
                for w in range(0, len(normalw)): 
                        # NOTE(review): the arguments do not change inside this loop, so this
                        # call could be made once before it (if computeDiscardValue is pure).
                        discard = computeDiscardValue(len(elements), depth)
                        if normalw[w] > discard:
                                newfmaps.append(elements[w])
                                newweights.append(normalw[w])
                                newindices.append(w)
                                
                elements = newfmaps
                
                #elements = nullifyFeatureMaps(elements)
                print("Number of feature maps: " + str(len(elements)) + " (" + str(len(normalw) - len(elements)) + " images removed)")
                print("Size of feature map: " + str(len(fmap_atdepth)) + 'x' + str(len(fmap_atdepth[0])))
                
                # PCA down to at most 50 components (bounded by the number of rows and
                # columns of elements), then t-SNE down to 2 dimensions for clustering.
                pca_n_components = min(50, min(len(elements), len(elements[0])))
                pca = PCA(n_components=pca_n_components, random_state=seed) #np.minimum(len(elements), len(elements[0]))
                x_pca = pca.fit_transform(elements)
                # NOTE(review): tsne_perplexity is computed but never passed to TSNE, so
                # the library default perplexity is used -- confirm whether that is intended.
                tsne_perplexity = (2/3)*len(elements)
                tsne = TSNE(n_components = 2, random_state=seed)
                x = tsne.fit_transform(np.array(x_pca))

                print("Dimensions for clustering: " + str(len(x[0])))
                
                # Model selection over 3..8 clusters using the silhouette score alone.
                # NOTE(review): n_labels is only assigned when a candidate is accepted; if
                # none is, the code after the loop uses a value left over from an earlier
                # run or fails with NameError.
                best_score = -np.inf
                best_silhouette = -1
                sil_x, sil_y = [],[]
                for n in range(3, 9): # 3-8 clusters
                        cluster_res = AgglomerativeClustering(n_clusters = n).fit(x)
                        silhouette_score = metrics.silhouette_score(x, cluster_res.labels_)

                        # Recorded for the optional silhouette-vs-n plot below.
                        sil_x.append(n)
                        sil_y.append(silhouette_score)
                        
                        score = silhouette_score

                        #print(score)
                        
                        # Accept a candidate only if it improves the score and its silhouette
                        # is above 90% of the best silhouette accepted so far.
                        if(score > best_score and silhouette_score > best_silhouette*0.9): 
                                best_score = score
                                if(silhouette_score > best_silhouette):
                                        best_silhouette = silhouette_score
                                n_labels = cluster_res.labels_

                        if print_full_silhouette: silhouette_visualizer(cluster_res, x, colors='yellowbrick')

                if print_silhouette:
                        plt.plot(sil_x, sil_y)
                        plt.xlabel('Number of Clusters')
                        plt.ylabel('Silhouette Score')
                        plt.title('Silhouette Scores with varying Cluster')
                        plt.show()
                centroids = len(np.unique(n_labels))
                
                print("Number of clusters: " + str(centroids))
                # Sum of the percentage weights that survived the discard filter.
                print("Leftover weight: " + str(np.sum(newweights)) + "%")

                # groups[c]: positions (into the filtered arrays) assigned to cluster c.
                # neurons_indices[c]: the same members expressed as unfiltered indices.
                groups = [[] for i in range(centroids)]
                neurons_indices = [[] for i in range(centroids)]

                for i in range(0, n_labels.size):
                        groups[n_labels[i]].append(i)

                cluster_images = []
                cluster_w = []  # total weight per cluster, aligned with groups

                for i in range(0,len(groups)): 
                        for n in range(0, len(groups[i])):
                                neurons_indices[i].append(newindices[groups[i][n]])
                        print(neurons_indices[i])
                        if len(groups[i]) == 0:
                                cluster_w.append(0)
                                continue
                        cluster_w.append(compute_cluster_weight(groups[i], newweights))
                
                # Clusters below the larger of (heaviest cluster / 3) and (mean weight / 2)
                # are reported as removed and skipped.
                threshold_w = max(max(cluster_w)/3, np.mean(cluster_w)/2)

                print('Threshold:', str(threshold_w)+'%')
                print()

                if print_threshold:
                        linspace = np.linspace(1, len(cluster_w), len(cluster_w))
                        plt.scatter(linspace, cluster_w)
                        plt.plot([linspace[0], linspace[len(linspace)-1]], [threshold_w, threshold_w], label='Threshold', c='r')
                        plt.legend()
                        plt.show()

                for i in range(len(groups)):
                        deleted = '[removed]' if cluster_w[i] < threshold_w else ''
                        print("Cluster " + str(i+1) + ": " + str(len(groups[i])) + " images", deleted)
                        print("Total weight (sum): " + str(cluster_w[i]) + "%")
                        if cluster_w[i] < threshold_w: continue
                        avg_cluster_w = np.round(cluster_w[i]/len(groups[i]), 5)
                        # NOTE(review): the non-thresholded cell calls compute_cluster_grid
                        # without the newweights argument -- check which call matches the helper.
                        if print_all_images: compute_cluster_grid(groups, fmap_atdepth, i, original_image, newindices, newweights, if_print)
                        else: 
                                # NOTE(review): cluster_w is the whole per-cluster list here, while
                                # the non-thresholded cell passes a single cluster's weight in this
                                # position -- confirm whether cluster_w[i] was intended.
                                dict_images = compute_cluster_mean_median_max(groups, fmap_atdepth, i, original_image, newweights, newindices, if_print, cluster_w, avg_cluster_w)
                                cluster_images.append(dict_images)
                
                cluster_groups.append(cluster_images)

## Human Knowledge Collection Preparation

#### Testing partial image display

In [None]:
# Export switches, all off by default:
#   if_save_partial -> write each partially revealed image to disk
#   if_save_final   -> write the resized original image to disk
#   if_save_overlay -> write the heat-map overlay to disk
if_save_partial = if_save_final = if_save_overlay = False

# Selects the feature map to display:
# cluster_groups[partial_depth][partial_cluster][partial_type].
partial_depth, partial_cluster, partial_type = 0, 2, 'mean'

# Starting value of the reveal percentage used by the step-counting loop below.
q = 0

from numpy.lib.stride_tricks import as_strided

def computeBlur(val):
    """Return the blur value paired with ``val``.

    The recognised inputs 2, 4, 8, 16 and 32 yield 3.0, 6.0, 12.0, 16.0 and
    24.0 respectively; every other input yields the fallback 3.0.
    """
    known_pairs = ((2, 3.0), (4, 6.0), (8, 12.0), (16, 16.0), (32, 24.0))
    for candidate, blur in known_pairs:
        if val == candidate:
            return blur
    return 3.0

def apply_modifier(number):
    """Return the next value of the schedule: ``number * 1.5 + 0.5``."""
    scaled = number * 1.5
    return scaled + 0.5

def tile_array(a, b0, b1):
    """Enlarge a 2-D array by repeating each element in a ``b0`` x ``b1`` tile.

    Args:
        a: two-dimensional NumPy array.
        b0: number of times every row is repeated.
        b1: number of times every column is repeated.

    Returns:
        Array of shape ``(rows * b0, cols * b1)``.
    """
    n_rows, n_cols = a.shape
    row_stride, col_stride = a.strides
    # The two inserted axes get stride 0, so they re-read the same element;
    # the reshape then collapses them into the enlarged rows and columns.
    expanded = as_strided(
        a,
        shape=(n_rows, b0, n_cols, b1),
        strides=(row_stride, 0, col_stride, 0),
    )
    return expanded.reshape(n_rows * b0, n_cols * b1)

# Feature map chosen by partial_depth / partial_cluster / partial_type.
original_fmap = cluster_groups[partial_depth][partial_cluster][partial_type]

np_image = np.asarray(original_image)

# Upsample the feature map by an integer factor per axis so that it covers the
# image. The factors are truncated, so the result only matches the image size
# exactly when the image dimensions are multiples of the feature-map dimensions.
original_fmap = tile_array(original_fmap, int(len(np_image)/len(original_fmap)), int(len(np_image[0])/len(original_fmap[0])))

total_len = len(original_image)*len(original_image[0])

# Flat copy of the upsampled map, used only to compute percentiles.
# np.resize repeats or truncates values if the sizes differ.
ordered_map = np.resize(original_fmap, total_len)
user_input = ''

# Dry run of the schedule: count how many apply_modifier steps it takes for q
# (starting from its value at this point) to reach 100, and print that count.
number = 0
while q < 100:
    number+=1
    q = apply_modifier(q)

print(number)
q = 0.5
number = 0

# Progressive reveal: at each step the top q percent of the feature-map values
# select the visible image region. Stops when the user enters 'q' or q exceeds 100.
while not (user_input == 'q') and q <= 100:

    # Value above which the top q percent of the map lies.
    percentile = np.percentile(ordered_map, q=100-q, axis=0)

    print(percentile)

    # 1.0 where the map reaches the percentile, 0.0 elsewhere.
    mask = tf.constant(original_fmap)
    mask = tf.cast(mask >= percentile, "float32")

    # Count the selected entries before the mask is blurred.
    mask_trues = len(mask[mask == True])

    # Soften the mask edges with a Gaussian blur.
    mask = skimage.filters.gaussian(mask, sigma=(12, 12), truncate=1.5, channel_axis=-1)

    # Share of the image selected by the un-blurred mask, in percent.
    coverage = np.round(mask_trues*100/total_len, decimals = 3)

    print(mask_trues)
    print(str(coverage) + '%')

    # Apply the same mask to all three colour channels.
    mask3d = np.stack((mask, mask, mask), axis=2)
    modified_image = original_image*mask3d
    modified_image = tf.image.resize(modified_image, size=[512, 512])

    fig = plt.imshow(modified_image)
    fig.axes.get_xaxis().set_visible(False)
    fig.axes.get_yaxis().set_visible(False)
    plt.show()

    # Save numbered frames only while less than 95% of the image is selected.
    if if_save_partial and coverage < 95:
        image_to_save = np.asarray(modified_image)
        fname = "img_test"+str(number)+".png"
        plt.imsave(fname, image_to_save)

    q = apply_modifier(q)
    number+=1
    # Interactive mode: wait for input between frames ('q' ends the loop).
    if not if_save_partial: user_input = input()

plt.imshow(original_fmap, cmap='jet')
plt.show()
plt.imshow(original_image)
plt.show()

def apply_map(image, fmap):
    """Blend a blurred, jet-coloured feature map over an image.

    Args:
        image: image array to overlay; it must broadcast against the
            ``(rows, cols, 3)`` colour map produced from ``fmap``.
        fmap: two-dimensional feature map with the same rows and columns
            as ``image``.

    Returns:
        ``0.4 * coloured_map + 0.6 * image``.
    """
    fmap = skimage.filters.gaussian(fmap, sigma=(128, 128), truncate=6, channel_axis=-1)

    # Min-max normalise the blurred map; a constant map is left unchanged so
    # that the normalisation never divides by zero.
    lowest = np.min(fmap)
    value_range = np.max(fmap) - lowest
    if value_range != 0:
        fmap = np.divide(np.subtract(fmap, lowest), value_range)

    # Map the values through the jet colour map and keep the first three
    # channels of its output.
    fmap = (cm.jet(fmap))[:,:,0:3]

    return np.add(np.multiply(fmap,0.4), np.multiply(image, 0.6))

# Build and display the heat-map overlay for the selected feature map.
overlay_image = apply_map(original_image, original_fmap)
plt.imshow(overlay_image)
plt.show()

# Optionally export the original image, resized to 512x512.
if if_save_final:
    img_to_save = np.asarray(tf.image.resize(original_image, size=[512, 512]))
    plt.imsave("img_original.png", img_to_save)

# Optionally export the overlay, resized to 256x256.
if if_save_overlay:
    overlay_image = np.asarray(tf.image.resize(overlay_image, size=[256,256]))
    plt.imsave("img_overlay.png", overlay_image)