In [None]:
import faiss
import pandas as pd
import matplotlib.pyplot as plt

import os
from glob import glob
import json
import numpy as np
import pickle

from PIL import Image
import random

In [None]:
# Placeholder paths -- fill these in before running the notebook.
data_path = ''
feature_vector_folder = ''

# glob() returns matches in arbitrary, filesystem-dependent order. Sorting pins
# the order in which the feature files are loaded, so the row order of
# everything derived from them is the same on every run.
densenet_features_files = sorted(glob(feature_vector_folder + '/*.json'))

In [None]:
# Read every feature file and keep its content as two parallel lists:
# feature_values[k] holds the values of the k-th JSON object as an array,
# file_names[k] holds the matching keys in the same order.
feature_values = []
file_names = []
for feature_file in densenet_features_files:
    with open(feature_file, "r") as handle:
        features_by_name = json.load(handle)

    feature_values.append(np.array(list(features_by_name.values())))
    file_names.append(np.array(list(features_by_name.keys())))




In [None]:
# Flatten the per-file arrays into one entry per image.
# feature_list[k] and file_list[k] must describe the same image: later cells
# pair file_list with the k-means assignments row by row.
feature_list = []
file_list = []
i = 0
for folder_features, folder_names in zip(feature_values, file_names):
    # Iterating a 2-D array yields its rows, i.e. one feature vector per image.
    feature_list.extend(folder_features)
    # file_list used to stay empty, so every later cell that reads it saw no
    # files. tolist() turns the numpy strings into plain Python str objects.
    file_list.extend(folder_names.tolist())
    i = i + 1
    print(i)  # progress: number of feature files flattened so far

assert len(feature_list) == len(file_list), "features and file names are out of step"

In [None]:
# Reduce the feature vectors to PCA_OUTPUT_DIM dimensions with a faiss PCA.
PCA_OUTPUT_DIM = 500

# Build the float32 matrix once and reuse it for training and projection;
# converting the whole feature list twice doubles the peak memory for nothing.
feature_matrix = np.asarray(feature_list, dtype='float32')

# Take the input width from the data rather than hard-coding 1024, so the cell
# keeps working if the feature extractor's output size changes.
mat = faiss.PCAMatrix(feature_matrix.shape[1], PCA_OUTPUT_DIM)
mat.train(feature_matrix)
assert mat.is_trained
feature_values_transformed = mat.apply(feature_matrix)

In [None]:
# k-means settings.
ncentroids = 20   # number of clusters
niter = 20        # number of training iterations
verbose = True    # forwarded to faiss as verbose=...

# Fit k-means on the PCA-reduced vectors; the dimensionality is read from the data.
kmeans = faiss.Kmeans(
    feature_values_transformed.shape[1],
    ncentroids,
    niter=niter,
    verbose=verbose,
)
kmeans.train(feature_values_transformed)

In [None]:
# Query the trained centroids with every vector, asking for 1 neighbour each:
# D receives the distances and I the matching centroid ids (one column each).
D, I = kmeans.index.search(
    feature_values_transformed,
    1,
)

In [None]:
# Flat L2 index over all reduced vectors, so it can be queried from the centroid side.
d = feature_values_transformed.shape[1]
index = faiss.IndexFlatL2(d)
index.add(feature_values_transformed)

# For each centroid, fetch its 20 nearest vectors: D_c = distances, I_c = row numbers.
D_c, I_c = index.search(kmeans.centroids, 20)

In [None]:
# One row per image: its file name, assigned cluster id and distance to that centroid.
# Fail early with a readable message if the file names and the assignments do
# not line up, instead of letting pandas raise an obscure error.
if len(file_list) != len(I):
    raise ValueError(
        "file_list has {} entries but there are {} cluster assignments".format(
            len(file_list), len(I)
        )
    )

# I and D come from a 1-neighbour search, so they have a single column; flatten
# them so each one becomes an ordinary 1-D DataFrame column.
data_frame = pd.DataFrame({
    'filename': file_list,
    'Cluster': I.ravel(),
    'Distance': D.ravel(),
})
# NOTE(review): absolute, machine-specific output path -- consider deriving it from a configurable folder.
data_frame.to_csv('/mnt/largedrive0/katariap/feature_extraction/data/Dataset/Clusters_densenet.csv')

In [None]:
# Group the image files by assigned cluster: {cluster id -> [files]}.
# Keys are stored as plain ints rather than numpy scalars: they hash and compare
# the same for lookups such as clusters[3], and the dict no longer needs numpy
# to be unpickled.
clusters = {}
for i, path in enumerate(file_list):
    # I[i] is a 1-element row (nearest centroid only), so take its single entry.
    cluster_id = int(I[i][0])
    # setdefault replaces the manual membership test, which rebuilt the key list
    # and compared a numpy array against it on every iteration.
    clusters.setdefault(cluster_id, []).append(path)




In [None]:
# Show up to 10 randomly chosen images from every cluster, one figure per cluster.
for number in range(ncentroids):
    # A centroid that attracted no images has no entry in `clusters`;
    # skip it instead of raising KeyError.
    files = clusters.get(number, [])
    if not files:
        continue
    if len(files) > 10:
        files = random.sample(files, 10)

    fig = plt.figure(figsize=(30, 30))
    # Label the figure so the cluster id can be read off the output.
    fig.suptitle('Cluster {}'.format(number))

    # The counter is called `position`, not `index`, because `index` is the
    # notebook-level faiss index and a loop variable of that name overwrites it.
    for position, file in enumerate(files, start=1):
        plt.subplot(5, 5, position)
        name = file.split('/')[-1]
        # Close the image file as soon as its pixels have been copied out.
        with Image.open(file) as image:
            img = np.array(image)
        plt.imshow(img)
        plt.axis('off')
        plt.title(name, fontsize=7)

    # Render now, then release the figure so memory does not grow with the
    # number of clusters.
    plt.show()
    plt.close(fig)

In [None]:
# Load a previously saved {cluster id -> file list} mapping from disk.
# NOTE(review): this replaces any `clusters` built earlier in the session, and
# the cell that writes this pickle sits further down the notebook -- confirm
# the intended run order.
# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
cluster_file = '/mnt/largedrive0/katariap/feature_extraction/data/Dataset/clusters.pickle'
clusters = {}
with open(cluster_file, 'rb') as handle:
    clusters = pickle.load(handle)
    


In [None]:
# Hard-coded ids of the clusters to keep; their file lists are concatenated
# in exactly this order.
selected_clusters = [1, 2, 4, 6, 7, 8, 9, 10, 11, 12, 13, 17, 18, 19]
final_list = [
    patch
    for cluster_id in selected_clusters
    for patch in clusters[cluster_id]
]

In [None]:
# Write the selected entries to CSV as a single 'Patch' column.
selected_csv = '/mnt/largedrive0/katariap/feature_extraction/data/Dataset/selected_after_clustering.csv'
selected_patches = pd.DataFrame(final_list, columns=['Patch'])
selected_patches.to_csv(selected_csv)

In [None]:
# Save the current `clusters` mapping to disk as a pickle.
clusters_pickle = '/mnt/largedrive0/katariap/feature_extraction/data/Dataset/clusters.pickle'
with open(clusters_pickle, 'wb') as handle:
    pickle.dump(clusters, handle)