In [7]:
import re

import pandas as pd

%matplotlib inline

# Load the CIFAR-10 dataset

In [20]:
train = pd.read_csv("image_train.csv")
test = pd.read_csv("image_test.csv")

## Convert the `turi-create` dataset into a more usable format

In [21]:
def extract_features(dataframe: pd.DataFrame, column: str) -> pd.DataFrame:
    """Expand a column of bracketed number strings into numeric columns.

    Each cell of ``dataframe[column]`` is expected to hold text of the form
    ``"[v0 v1 v2 ...]"``. The text between the first ``[`` and the last ``]``
    is split on whitespace and every token is converted to ``float``.

    Parameters
    ----------
    dataframe : pd.DataFrame
        Frame containing the serialized arrays.
    column : str
        Name of the column to expand.

    Returns
    -------
    pd.DataFrame
        One row per input row, with columns named ``f"{column}{i}"`` for the
        i-th value. The result carries the same index as ``dataframe`` so it
        lines up row-for-row when concatenated column-wise with it.

    Raises
    ------
    ValueError
        If a cell contains no bracketed section, or a token inside the
        brackets cannot be converted to ``float``.
    """
    bracketed = re.compile(r"\[(.*)\]")

    rows = []
    for label, cell in dataframe.loc[:, column].items():
        match = bracketed.search(cell)
        if match is None:
            # Fail with a message that names the offending row instead of an
            # AttributeError on ``None.group``.
            raise ValueError(
                f"no bracketed array found in column {column!r} at index {label!r}"
            )
        rows.append([float(token) for token in match.group(1).split()])

    # Reuse the caller's index: a fresh RangeIndex would misalign rows in a
    # later column-wise concat whenever ``dataframe`` is not 0..n-1 indexed.
    return pd.DataFrame(rows, index=dataframe.index).add_prefix(column)

In [22]:
# Rebuild the train dataset: replace the two serialized array columns with
# their expanded numeric columns.
image_array_df = extract_features(train, 'image_array')
deep_features_df = extract_features(train, 'deep_features')

train = pd.concat(
    [
        train.drop(columns=['deep_features', 'image_array']),
        image_array_df,
        deep_features_df,
    ],
    axis=1,
)

In [23]:
# Rebuild the test dataset: replace the two serialized array columns with
# their expanded numeric columns.
image_array_df = extract_features(test, 'image_array')
deep_features_df = extract_features(test, 'deep_features')

test = pd.concat(
    [
        test.drop(columns=['deep_features', 'image_array']),
        image_array_df,
        deep_features_df,
    ],
    axis=1,
)

In [24]:
# Store the label column with the categorical dtype in both datasets.
train['label'] = train['label'].astype('category')
test['label'] = test['label'].astype('category')

# Train a nearest-neighbors model for retrieving images using deep features

In [31]:
from sklearn.neighbors import NearestNeighbors

# Names of every column whose name starts with 'deep_features'.
deep_features = [name for name in train.columns if name.startswith('deep_features')]

# Feature matrix (deep-feature columns only) and the matching labels.
X = train.loc[:, deep_features]
y = train['label']

# Unfitted nearest-neighbors model with scikit-learn's default settings.
knn_model = NearestNeighbors()

In [32]:
# Fit the neighbor index on the deep-feature matrix.
# NOTE(review): NearestNeighbors is unsupervised, so `y` is presumably ignored
# here — confirm against the scikit-learn docs.
knn_model.fit(X, y)

NearestNeighbors()

# Using the image retrieval model with deep features to find similar images

##### Cat

In [70]:
# Query image: positional row 18 of train. The list index ([18]) keeps the
# result a one-row DataFrame rather than a Series.
cat = train.iloc[[18]]

In [71]:
# Display the query row.
cat

Unnamed: 0,id,image,label,image_array0,image_array1,image_array2,image_array3,image_array4,image_array5,image_array6,...,deep_features4086,deep_features4087,deep_features4088,deep_features4089,deep_features4090,deep_features4091,deep_features4092,deep_features4093,deep_features4094,deep_features4095
18,384,Height: 32 Width: 32,cat,46.0,45.0,50.0,47.0,45.0,51.0,45.0,...,0.366557,0.0,0.0,1.69667,0.0,0.0,0.0,0.0,0.0,0.0


In [72]:
# Find the 5 nearest training rows to the cat query, compared on the
# deep-feature columns only.
distances, indices = knn_model.kneighbors(cat[deep_features], n_neighbors=5)

In [73]:
# indices[0] holds the positional rows returned for the single query; the
# first hit is the query image itself (row 18).
train.iloc[indices[0]]

Unnamed: 0,id,image,label,image_array0,image_array1,image_array2,image_array3,image_array4,image_array5,image_array6,...,deep_features4086,deep_features4087,deep_features4088,deep_features4089,deep_features4090,deep_features4091,deep_features4092,deep_features4093,deep_features4094,deep_features4095
18,384,Height: 32 Width: 32,cat,46.0,45.0,50.0,47.0,45.0,51.0,45.0,...,0.366557,0.0,0.0,1.69667,0.0,0.0,0.0,0.0,0.0,0.0
288,6910,Height: 32 Width: 32,cat,154.0,133.0,92.0,134.0,112.0,75.0,108.0,...,1.91561,0.403345,0.0,1.52395,0.0,0.0,0.0,0.714656,0.0,0.0
1565,39777,Height: 32 Width: 32,cat,145.0,166.0,165.0,164.0,185.0,184.0,185.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1468,36870,Height: 32 Width: 32,cat,16.0,20.0,19.0,14.0,19.0,17.0,11.0,...,0.348435,0.0,0.0,0.527505,0.0,0.0,0.0,1.39074,0.0,0.0
1633,41734,Height: 32 Width: 32,cat,122.0,27.0,34.0,120.0,24.0,31.0,119.0,...,0.085609,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


##### Car

In [77]:
# Query image: positional row 8 of train. The list index ([8]) keeps the
# result a one-row DataFrame rather than a Series.
car = train.iloc[[8]]

In [80]:
# Find the 5 nearest training rows to the car query, compared on the
# deep-feature columns only. n_neighbors is spelled out to match the cat
# query; 5 is also the value the model falls back to when it is omitted.
distances, indices = knn_model.kneighbors(car[deep_features], n_neighbors=5)

In [81]:
# indices[0] holds the positional rows returned for the single query; the
# first hit is the query image itself (row 8).
train.iloc[indices[0]]

Unnamed: 0,id,image,label,image_array0,image_array1,image_array2,image_array3,image_array4,image_array5,image_array6,...,deep_features4086,deep_features4087,deep_features4088,deep_features4089,deep_features4090,deep_features4091,deep_features4092,deep_features4093,deep_features4094,deep_features4095
8,136,Height: 32 Width: 32,automobile,35.0,59.0,53.0,36.0,56.0,56.0,42.0,...,0.0,0.0,5.31948,0.0,0.666042,0.0,0.0,1.34132,0.0,0.0
372,8977,Height: 32 Width: 32,automobile,186.0,195.0,199.0,182.0,192.0,198.0,184.0,...,0.0,0.0,4.84683,0.0,2.18244,0.0,0.0,1.66367,0.0,0.0
1757,44395,Height: 32 Width: 32,automobile,89.0,95.0,50.0,83.0,84.0,43.0,69.0,...,0.0,0.0,3.39069,0.0,1.04151,0.0,0.0,0.403854,0.0,0.0
1343,33261,Height: 32 Width: 32,automobile,110.0,118.0,104.0,98.0,104.0,80.0,92.0,...,0.0,0.0,3.5074,0.0,0.712008,0.0,0.0,1.18539,0.0,0.0
1009,24146,Height: 32 Width: 32,automobile,229.0,231.0,227.0,232.0,235.0,231.0,231.0,...,0.0,0.0,2.83524,0.0,0.0,0.0,0.0,2.27581,0.0,0.0
