In [1]:
import os

import numpy as np
from PIL import Image

import util
from NNS import NNS

## 1. Data Loading

In [2]:
# Load the two pickled matrices used throughout the notebook: the repository
# (fitted on later) and the test set (queried against it later).
# NOTE(review): the `.pkl` extension suggests util.load_data unpickles these
# files — only load data from a trusted source; confirm in util.
retrieval_repository_data = util.load_data('./data/image_retrieval_repository_data.pkl')
retrieval_test_data = util.load_data('./data/image_retrieval_test_data.pkl')

## 2. Data Exploration

In [3]:
# Report the (rows, columns) of both raw matrices, one labelled line per split.
for shape_label, split_data in (
    ("Image Retrieval Repository Data Shape:", retrieval_repository_data),
    ("Image Retrieval Test Data Shape:", retrieval_test_data),
):
    print(shape_label, split_data.shape)

Image Retrieval Repository Data Shape: (5000, 257)
Image Retrieval Test Data Shape: (1000, 257)


## 3. Data Preprocessing 

In [4]:
# Column 0 of each matrix is the sample index; everything after it is features.
# Peel the index off so that only feature values are passed to the model.
# NOTE: the feature names are rebound to the sliced arrays, so running this
# cell a second time would strip one more column — reload the data first.
repository_data_index, retrieval_repository_data = (
    retrieval_repository_data[:, 0],
    retrieval_repository_data[:, 1:],
)
test_data_index, retrieval_test_data = (
    retrieval_test_data[:, 0],
    retrieval_test_data[:, 1:],
)

In [5]:
# Sanity check: each index vector should be 1-D with one entry per row of its split.
repository_data_index.shape, test_data_index.shape

((5000,), (1000,))

In [14]:
# Write every repository sample to disk as a small image for visual inspection.
# Each sample's feature vector is laid out as a square IMAGE_SIDE x IMAGE_SIDE grid.
IMAGE_SIDE = 16

# Image.save does not create missing folders; make sure the target exists so a
# fresh checkout does not fail on the first write.
os.makedirs('img', exist_ok=True)

for i, sample in enumerate(retrieval_repository_data):
    pixels = np.array(sample).reshape(IMAGE_SIDE, IMAGE_SIDE)
    img = Image.fromarray(pixels)
    # Float input produces a mode "F" image, which cannot be written as JPEG;
    # convert to RGB before saving.
    if img.mode == "F":
        img = img.convert('RGB')

    # NOTE(review): files are suffixed "_test" although the rows come from the
    # repository set — confirm the intended naming.
    img.save('img/{}_test.jpg'.format(i))

## 4. Model

In [7]:
# Instantiate the project's NNS model with k=10.
# NOTE(review): the prediction output in this notebook has 10 columns per query,
# so k is presumably the number of neighbours returned — confirm in NNS.
nns_model = NNS(k=10)

## 5. Train

In [8]:
# Fit on the repository features only (the index column was split off earlier).
nns_model.fit(X_train=retrieval_repository_data)

## 6. Predict

In [9]:
# Query the fitted model with every test row; the result has one row per query.
# NOTE(review): whether the returned values are row positions in the repository
# array or repository index values is not visible here — confirm in NNS.predict.
k_nearest = nns_model.predict(retrieval_test_data)

  0%|          | 0/1000 [00:00<?, ?it/s]

100%|██████████| 1000/1000 [00:15<00:00, 66.58it/s]


In [10]:
# Sanity check: expect (number of test rows, k).
k_nearest.shape


(1000, 10)

In [11]:
# Build the submission matrix: column 0 is each test sample's index, and the
# remaining columns are the neighbours the model returned for that sample.
# np.hstack promotes both inputs to a common dtype, so the merged values are
# stored as floats (visible in the displayed array).
# NOTE(review): repository_data_index is never applied to k_nearest. If
# NNS.predict returns row positions rather than repository index values, map
# them through repository_data_index before submitting — confirm.
submit_data = np.hstack((
    test_data_index.reshape(-1, 1),
    k_nearest
    ))
# Only a cell's last expression is displayed, so show the merged array itself.
submit_data

array([[7204.,  952., 1517., ...,  235.,  246., 2198.],
       [9613., 2503., 1008., ..., 1038., 1659.,  235.],
       [6119., 2355., 3745., ..., 4662., 4819., 3184.],
       ...,
       [4296., 4073., 3710., ..., 1994., 2931., 4116.],
       [3156., 1074., 2947., ..., 4453., 2178., 3539.],
       [5776., 4862., 3105., ..., 4881.,  462.,  159.]])

In [12]:
# Persist the submission matrix next to the notebook via the project helper.
util.save_data('./retrieval_results.pkl', submit_data)

Saved successfully
