In [1]:
import numpy as np
import pickle
from operator import itemgetter
from collections import Counter
from KNN_Classifier import KNN
from sklearn.metrics import accuracy_score
from PIL import Image
from prettytable import PrettyTable

In [2]:
class KNN:

    def __init__(self,k=5):
        self.k = k
        pass

    def fit(self,train_x, train_y):
        if len(train_x)!=len(train_y):
            raise IndexError('Length is different')
        else:
            self.train_x = train_x
            self.train_y = train_y


    def predict(self,test_x,verbose=True):
        predicted = []
        for index,test_item in enumerate(test_x):
            near_distance = []
            neighbours = []
            for item_index,train_item in enumerate(self.train_x):
                distance = self.calculate_eucliden_distance(test_item,train_item)
                near_distance.append([distance,item_index])

            sorted_distance_list = sorted(near_distance, key=itemgetter(0))

            neighbours = [sorted_distance_list[i] for i in range(0,self.k)]  # 1 is the index and zero is the distance
            elements_temp = [self.train_y[i[1]] for i in neighbours]
            intermediate = dict(Counter(elements_temp))
            result = max(intermediate.items(), key= itemgetter(1))
            predicted.append(result[0])
            if verbose:
                print('working ',index,'.........')
        return predicted

    def calculate_eucliden_distance(self,item_1, item_2):
        if len(item_1)!=len(item_2):
            raise IndexError('Length is different')
        else:
            result = np.mean(np.sqrt(np.square(np.subtract(item_1,item_2))))
            return result





### getting data back from binary file i.e. pickle file

In [3]:
def unpickle(file):
    with open(file,'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

In [4]:
def get_train_data(root_path,file_names):
    data=[]
    for i in file_names:
        temp = unpickle(root_path+i)
        data.append(unpickle(root_path+i))

    return data


In [5]:
def get_test_data(root_path,file_name):
    return unpickle(root_path+file_name)

In [6]:
def convert_to_UTF_8(data):
    return data.decode('UTF-8')


### you need to download the python version of the dataset from 
### https://www.cs.toronto.edu/~kriz/cifar.html

In [7]:
dataset_names = ['data_batch_1','data_batch_2','data_batch_3','data_batch_4','data_batch_5']
output_classes = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']
root_path = '../Data/cifar-10-batches-py/'

### getting complete training data

In [8]:
data = get_train_data(root_path,dataset_names)

In [9]:
model = KNN()

In [10]:
for i in range(0, len(data)):
    model.fit(data[i][b'data'],data[i][b'labels'])

In [11]:
test_data = get_test_data(root_path,'test_batch')

## Predicting the value of the test data

### Prediction is done only for first 50 values since the dataset is very large (10,000 images) 

In [12]:
predicted = model.predict(test_data[b'data'][:50],verbose=False)
actual_data = test_data[b'labels'][:50]

### converting filenames from bit to string

In [13]:
test_names = [convert_to_UTF_8(x) for x in test_data[b'filenames']]

In [14]:
print('---------------------predicted data----------------------')
print(predicted)
print('--------------------End of predicted data----------------')

---------------------predicted data----------------------
[0, 8, 3, 0, 4, 0, 6, 6, 3, 8, 0, 9, 0, 6, 8, 4, 3, 0, 8, 2, 0, 0, 0, 0, 2, 0, 4, 0, 0, 0, 4, 2, 8, 2, 1, 4, 0, 8, 5, 0, 0, 2, 0, 4, 0, 9, 2, 9, 2, 0]
--------------------End of predicted data----------------


In [15]:
print('---------------------Actual data----------------------')
print(actual_data)
print('--------------------End of Actual data----------------')

---------------------Actual data----------------------
[3, 8, 8, 0, 6, 6, 1, 6, 3, 1, 0, 9, 5, 7, 9, 8, 5, 7, 8, 6, 7, 0, 4, 9, 5, 2, 4, 0, 9, 6, 6, 5, 4, 5, 9, 2, 4, 1, 9, 5, 4, 6, 5, 6, 0, 9, 3, 9, 7, 6]
--------------------End of Actual data----------------


### accuracy score 

In [16]:
acc_score = accuracy_score(actual_data, predicted)
print('Accuracy Score -------------> ',acc_score*100,'%')

Accuracy Score ------------->  26.0 %


In [17]:
imagearray = []

In [18]:
imagearray = []
bunch_1 = test_data[b'data']

## generating and storing image for fifty sample of test data

In [19]:
for index,i in enumerate(bunch_1):
    if index<50:
        r = i[:1024]
        g = i[1024:2048]
        b = i[2048:]
        rgb = np.dstack((r,g,b))
        imgarray = np.reshape(rgb,(32,32,3))
        img = Image.fromarray(imgarray,'RGB')
        name = 'images/'+str(index)+'.png'
        img.save(name)
        imagearray.append(img)

In [20]:
imagearray

[<PIL.Image.Image image mode=RGB size=32x32 at 0x7FA5787632B0>,
 <PIL.Image.Image image mode=RGB size=32x32 at 0x7FA578313EB8>,
 <PIL.Image.Image image mode=RGB size=32x32 at 0x7FA57832B240>,
 <PIL.Image.Image image mode=RGB size=32x32 at 0x7FA57832B2B0>,
 <PIL.Image.Image image mode=RGB size=32x32 at 0x7FA57832B320>,
 <PIL.Image.Image image mode=RGB size=32x32 at 0x7FA57832B390>,
 <PIL.Image.Image image mode=RGB size=32x32 at 0x7FA57832B400>,
 <PIL.Image.Image image mode=RGB size=32x32 at 0x7FA57832B470>,
 <PIL.Image.Image image mode=RGB size=32x32 at 0x7FA57832B4E0>,
 <PIL.Image.Image image mode=RGB size=32x32 at 0x7FA57832B550>,
 <PIL.Image.Image image mode=RGB size=32x32 at 0x7FA57832B5C0>,
 <PIL.Image.Image image mode=RGB size=32x32 at 0x7FA57832B630>,
 <PIL.Image.Image image mode=RGB size=32x32 at 0x7FA57832B6A0>,
 <PIL.Image.Image image mode=RGB size=32x32 at 0x7FA57832B710>,
 <PIL.Image.Image image mode=RGB size=32x32 at 0x7FA57832B780>,
 <PIL.Image.Image image mode=RGB size=32

### getting the actual name of the output class from number

In [21]:
predicted = [output_classes[x] for x in predicted]

In [22]:
actual_data = [output_classes[x] for x in actual_data]

In [23]:
table = PrettyTable([])

### adding the name of the Image

In [24]:
table.add_column('name',test_names[:50])

In [25]:
table.add_column('predicted',predicted)

In [26]:
table.add_column('actual',actual_data)

In [27]:
print(table)

+----------------------------------+------------+------------+
|               name               | predicted  |   actual   |
+----------------------------------+------------+------------+
|    domestic_cat_s_000907.png     |  airplane  |    cat     |
|      hydrofoil_s_000078.png      |    ship    |    ship    |
|      sea_boat_s_001456.png       |    cat     |    ship    |
|      jetliner_s_001705.png       |  airplane  |  airplane  |
|     green_frog_s_001658.png      |    deer    |    frog    |
|       crapaud_s_002124.png       |  airplane  |    frog    |
|   shooting_brake_s_000973.png    |    frog    | automobile |
|     green_frog_s_000634.png      |    frog    |    frog    |
|      tabby_cat_s_001397.png      |    cat     |    cat     |
|        wagon_s_002806.png        |    ship    | automobile |
|        plane_s_000026.png        |  airplane  |  airplane  |
|      dustcart_s_000045.png       |   truck    |   truck    |
|     toy_spaniel_s_001592.png     |  airplane  |    do