In [22]:
# from google.colab import drive
# drive.mount('/content/drive')  # use this if you are using google colab

In [23]:
from sklearn.cluster import KMeans
import pickle, os
from scipy.spatial import distance
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
import numpy as np

In [24]:
def load_dataset(name_file):

    desired_directory = '.' # Replace with your desired directory path

    file_path = os.path.join(desired_directory, name_file) 

    with open(file_path, 'rb') as f:
        data_dict = pickle.load(f)

    return data_dict


In [25]:
loaders_dict = load_dataset(f"dataset-flowers102-features.pkl")
x_train = loaders_dict["x_train"]
x_test = loaders_dict["x_test"]
y_train = loaders_dict["y_train"]
y_test = loaders_dict["y_test"]
print(f"x_train:{x_train.shape}, y_train:{y_train.shape}")
print(f"x_test:{x_test.shape}, y_test:{y_test.shape}")


x_train:(4094, 512), y_train:(4094,)
x_test:(4095, 512), y_test:(4095,)


## clustering images

In [26]:
k_number = 50
kmeans = KMeans(n_clusters=k_number)
clusters = kmeans.fit_predict(x_train)
centroids = kmeans.cluster_centers_

  super()._check_params_vs_input(X, default_n_init=10)


In [27]:
# centroid_distances = [distance.euclidean(x_test[0], centroid) for centroid in centroids]
# centroid_distances

## find nearest clusters neighbors

### find nearest clusters and their indices

In [28]:
def find_nearest_clusters_neighbors(x_train, x_test, k):
    knn_classifier = KNeighborsClassifier(n_neighbors=k)
    knn_classifier.fit(x_train, range(len(x_train)))
    nearest_indices = knn_classifier.kneighbors(x_test, n_neighbors=k, return_distance=False)
    # nearest_neighbors = x_train[nearest_indices]
    nearest_neighbors = [x_train[indices] for indices in nearest_indices]

    return nearest_indices, nearest_neighbors

In [29]:
n_nearest_neighbors = 10
x_train_clusters = centroids
nearest_clusters_indices, nearest_clusters_neighbors = find_nearest_clusters_neighbors(x_train_clusters, 
                                                                                       x_test, 
                                                                                       k=n_nearest_neighbors)

In [30]:
print(f"clusters Indices:")
print(nearest_clusters_indices)
print(f"nearest neighbor clusters centroids:")
print(nearest_clusters_neighbors)

clusters Indices:
[[20 12 18 ... 38 41  4]
 [46 42 32 ...  7 26 48]
 [21  4 31 ...  6 10 27]
 ...
 [ 0 12 20 ... 31 41 10]
 [28  0 12 ... 31 47 45]
 [36 32 40 ...  4 17  7]]
nearest neighbor clusters centroids:
[array([[0.22892004, 0.82133806, 1.1913738 , ..., 1.7158766 , 0.38292915,
        1.6183989 ],
       [0.7596075 , 0.6826276 , 0.66603184, ..., 1.7282653 , 0.31310606,
        2.4220743 ],
       [0.19584107, 0.2816192 , 1.1604404 , ..., 0.6059605 , 2.7469993 ,
        1.6321579 ],
       ...,
       [2.517344  , 1.2339938 , 2.2233503 , ..., 0.05033731, 0.9710306 ,
        0.3003127 ],
       [1.4004402 , 1.7415406 , 0.8638333 , ..., 1.0276184 , 0.9682396 ,
        0.44224083],
       [0.8228469 , 1.3843824 , 0.8859598 , ..., 0.756986  , 2.1044142 ,
        0.9414741 ]], dtype=float32), array([[0.5359869 , 0.21914327, 0.6179406 , ..., 1.6925219 , 1.4959364 ,
        1.5529635 ],
       [0.06056881, 0.51961637, 1.3644825 , ..., 1.0720811 , 1.5133944 ,
        1.9949055 ],
       

### get the data of nearest clusters

In [31]:
def gather_clusters_data(k_number, clusters):
    clusters_data = {i: [] for i in range(k_number)}
    clusters_data_labels = {i: [] for i in range(k_number)}
    for i, label in enumerate(clusters):
        clusters_data[label].append(x_train[i])
        clusters_data_labels[label].append(y_train[i])
    return clusters_data, clusters_data_labels

In [32]:
clusters_data, clusters_data_labels = gather_clusters_data(k_number, clusters)

In [33]:
def get_nearest_clusters_data(nearest_clusters_indices, clusters_data, clusters_data_labels):
    # Access data points in the cluster with centroid index 'nearest_clusters_indices'
    all_data_in_nearest_clusters = []
    all_labels_in_nearest_clusters = []
    for indices in nearest_clusters_indices:
        data_in_nearest_clusters = []
        labels_in_nearest_clusters = []
        for index in indices:
            data_in_nearest_clusters.extend(clusters_data[index])
            labels_in_nearest_clusters.extend(clusters_data_labels[index])
        all_data_in_nearest_clusters.append(data_in_nearest_clusters)
        all_labels_in_nearest_clusters.append(labels_in_nearest_clusters)
    return all_data_in_nearest_clusters, all_labels_in_nearest_clusters

In [34]:
all_data_in_nearest_clusters, all_labels_in_nearest_clusters = get_nearest_clusters_data(nearest_clusters_indices, 
                                                                                 clusters_data, 
                                                                                 clusters_data_labels)

## classify using data in nearest clusters

In [46]:
def classify_knn(x_train, y_train, x_test, y_test, k):
    y_preds = []
    for data_train, label_train, one_data_test in zip(x_train, y_train, x_test):
        knn_classifier = KNeighborsClassifier(n_neighbors=k)
        knn_classifier.fit(data_train, label_train)
        y_preds.append(knn_classifier.predict([one_data_test]))
    
    accuracy = accuracy_score(y_test, y_preds)
    return y_preds, accuracy

In [49]:
k = 100
x_train_data = all_data_in_nearest_clusters
y_train_data = all_labels_in_nearest_clusters
predictions, accuracy = classify_knn(x_train_data, y_train_data, x_test, y_test, k=k)

In [56]:
print(f"number of predicted labels: {len(predictions)}")
print(f"Accuracy: {accuracy * 100:.2f}%")
print(f"Predicted labels:")
for i in range(len(predictions)):
    print(f"{predictions[i]} --> {y_test[i]}")

number of predicted labels: 4095
Accuracy: 70.99%
Predicted labels:
[77] --> 77
[91] --> 91
[82] --> 85
[29] --> 29
[22] --> 22
[80] --> 51
[60] --> 60
[15] --> 15
[50] --> 94
[73] --> 73
[88] --> 52
[76] --> 34
[74] --> 19
[10] --> 10
[98] --> 98
[79] --> 79
[100] --> 100
[40] --> 53
[69] --> 69
[93] --> 93
[77] --> 38
[36] --> 36
[64] --> 64
[45] --> 45
[59] --> 30
[93] --> 93
[36] --> 36
[93] --> 93
[73] --> 73
[93] --> 31
[45] --> 45
[1] --> 1
[76] --> 76
[71] --> 71
[76] --> 76
[37] --> 37
[36] --> 36
[64] --> 64
[80] --> 80
[45] --> 45
[87] --> 87
[14] --> 96
[72] --> 72
[73] --> 83
[76] --> 76
[51] --> 83
[42] --> 90
[82] --> 82
[8] --> 81
[50] --> 50
[84] --> 44
[64] --> 64
[11] --> 11
[71] --> 71
[53] --> 53
[73] --> 73
[85] --> 85
[50] --> 50
[36] --> 36
[81] --> 83
[50] --> 50
[39] --> 39
[37] --> 37
[23] --> 23
[54] --> 54
[45] --> 45
[76] --> 76
[89] --> 90
[73] --> 73
[72] --> 72
[51] --> 51
[36] --> 36
[88] --> 88
[85] --> 85
[45] --> 45
[77] --> 38
[46] --> 46
[72] --> 

In [None]:
redictions}")
print(f"Accuracy: {accuracy * 100:.2f}%")