#Building an image retrieval system with deep features


#Fire up GraphLab Create

In [1]:
import graphlab



#Load the CIFAR-10 dataset

We will use CIFAR-10, a popular benchmark dataset in computer vision.

(We've reduced the data to just four categories: 'cat', 'bird', 'automobile', and 'dog'.)

This dataset is already split into a training set and test set. In this simple retrieval example, there is no notion of "testing", so we will only use the training data.

In [2]:
image_train = graphlab.SFrame('image_train_data/')

[INFO] This non-commercial license of GraphLab Create is assigned to joe@joebooth.com and will expire on October 09, 2016. For commercial licensing options, visit https://dato.com/buy/.

[INFO] Start server at: ipc:///tmp/graphlab_server-86508 - Server binary: /Users/joebooth/.graphlab/anaconda/lib/python2.7/site-packages/graphlab/unity_server - Server log: /tmp/graphlab_server_1448157147.log
[INFO] GraphLab Server Version: 1.6.1


#Computing deep features for our images

The two commented-out lines below compute deep features. This computation takes a little while, so the features have already been computed and saved as the deep_features column of the data you loaded.

(Note that if you would like to compute such deep features yourself and have a GPU on your machine, you should use the GPU-enabled GraphLab Create, which will be significantly faster for this task.)

In [3]:
#deep_learning_model = graphlab.load_model('http://s3.amazonaws.com/GraphLab-Datasets/deeplearning/imagenet_model_iter45')
#image_train['deep_features'] = deep_learning_model.extract_features(image_train)
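
As a rough illustration of what a "deep feature" is, the sketch below treats the activations of a hidden layer as the feature vector for an input. Everything here is a toy assumption: the weights are made up, `toy_extract_features` is a hypothetical helper, and the real ImageNet model and its `extract_features` method are far larger and are not reproduced here. The ReLU step may also explain why many entries in a feature vector are exactly 0.0.

```python
def relu(values):
    """Clamp negative activations to zero."""
    return [max(0.0, v) for v in values]

def dense(inputs, weights, biases):
    """Fully connected layer: one output per row of `weights`."""
    return [sum(w * x for w, x in zip(row, inputs)) + b
            for row, b in zip(weights, biases)]

# Made-up weights standing in for a pretrained network's hidden layer.
HIDDEN_W = [[0.5, -1.0, 0.25],
            [-0.5, 0.5, 0.0]]
HIDDEN_B = [0.0, -0.1]

def toy_extract_features(pixels):
    """Return hidden-layer activations as the feature vector (toy version)."""
    return relu(dense(pixels, HIDDEN_W, HIDDEN_B))

toy_features = toy_extract_features([1.0, 0.5, 2.0])
print(toy_features)
```

The point of the sketch is only that the feature vector comes from inside the network rather than from its final class prediction.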

In [4]:
image_train.head()

id,image,label,deep_features,image_array
24,Height: 32 Width: 32,bird,"[0.242871761322, 1.09545373917, 0.0, ...","[73.0, 77.0, 58.0, 71.0, 68.0, 50.0, 77.0, 69.0, ..."
33,Height: 32 Width: 32,cat,"[0.525087952614, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[7.0, 5.0, 8.0, 7.0, 5.0, 8.0, 5.0, 4.0, 6.0, 7.0, ..."
36,Height: 32 Width: 32,cat,"[0.566015958786, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[169.0, 122.0, 65.0, 131.0, 108.0, 75.0, ..."
70,Height: 32 Width: 32,dog,"[1.12979578972, 0.0, 0.0, 0.778194487095, 0.0, ...","[154.0, 179.0, 152.0, 159.0, 183.0, 157.0, ..."
90,Height: 32 Width: 32,bird,"[1.71786928177, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[216.0, 195.0, 180.0, 201.0, 178.0, 160.0, ..."
97,Height: 32 Width: 32,automobile,"[1.57818555832, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[33.0, 44.0, 27.0, 29.0, 44.0, 31.0, 32.0, 45.0, ..."
107,Height: 32 Width: 32,dog,"[0.0, 0.0, 0.220677852631, 0.0, ...","[97.0, 51.0, 31.0, 104.0, 58.0, 38.0, 107.0, 61.0, ..."
121,Height: 32 Width: 32,bird,"[0.0, 0.23753464222, 0.0, 0.0, 0.0, 0.0, ...","[93.0, 96.0, 88.0, 102.0, 106.0, 97.0, 117.0, ..."
136,Height: 32 Width: 32,automobile,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 7.5737862587, 0.0, ...","[35.0, 59.0, 53.0, 36.0, 56.0, 56.0, 42.0, 62.0, ..."
138,Height: 32 Width: 32,bird,"[0.658935725689, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[205.0, 193.0, 195.0, 200.0, 187.0, 193.0, ..."


#Train a nearest-neighbors model for retrieving images using deep features

We will now build a simple image retrieval system that finds the nearest neighbors for any image.

In [5]:
knn_model = graphlab.nearest_neighbors.create(image_train,
                                              features=['deep_features'],
                                              label='id')

PROGRESS: Starting brute force nearest neighbors model training.
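
The progress message mentions a brute-force search, meaning each query is compared against every reference row. A minimal pure-Python sketch of that idea follows. It assumes Euclidean distance, which is an assumption here rather than a confirmed GraphLab Create default, and `euclidean`, `brute_force_query`, and the toy vectors are hypothetical names and data.

```python
import math

def euclidean(a, b):
    """Euclidean distance between two equal-length feature vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def brute_force_query(reference, query_vec, k=5):
    """Return the k closest (label, distance, rank) tuples to query_vec.

    `reference` is a list of (label, feature_vector) pairs.
    """
    scored = [(label, euclidean(vec, query_vec)) for label, vec in reference]
    scored.sort(key=lambda pair: pair[1])  # smallest distance first
    return [(label, dist, rank)
            for rank, (label, dist) in enumerate(scored[:k], start=1)]

# Toy stand-in for the deep_features column (made-up 3-d vectors).
toy_reference = [
    (24, [0.2, 1.1, 0.0]),
    (33, [0.5, 0.0, 0.0]),
    (36, [0.6, 0.0, 0.0]),
    (70, [1.1, 0.0, 0.8]),
]

toy_result = brute_force_query(toy_reference, [0.5, 0.0, 0.0], k=3)
for row in toy_result:
    print(row)
```

Brute force is simple and exact, but its cost grows with the number of reference rows, since every query scans the whole table.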


#Use image retrieval model with deep features to find similar images

Let's find images similar to this cat picture.

In [6]:
graphlab.canvas.set_target('ipynb')  # render canvas output inline in the notebook
cat = image_train[18:19]  # slicing keeps a one-row SFrame to pass to query()
cat['image'].show()

In [7]:
knn_model.query(cat)

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 25.447ms     |
PROGRESS: | Done         |         | 100         | 225.995ms    |
PROGRESS: +--------------+---------+-------------+--------------+


query_label,reference_label,distance,rank
0,384,0.0,1
0,6910,36.9403137951,2
0,39777,38.4634888975,3
0,36870,39.7559623119,4
0,41734,39.7866014148,5
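
Rank 1 has distance 0.0, which suggests that reference label 384 is the query image itself, since the query row comes from the same table the model was built on. If only the other images are wanted, the self-match can be dropped. The sketch below works on plain tuples copied from the output above; `drop_self_match` is a hypothetical helper, not a GraphLab Create function.

```python
# Rows copied from the query output: (query_label, reference_label, distance, rank)
query_rows = [
    (0, 384, 0.0, 1),
    (0, 6910, 36.9403137951, 2),
    (0, 39777, 38.4634888975, 3),
    (0, 36870, 39.7559623119, 4),
    (0, 41734, 39.7866014148, 5),
]

def drop_self_match(rows, self_label):
    """Remove the row whose reference_label is the query image's own id."""
    return [row for row in rows if row[1] != self_label]

other_neighbors = drop_self_match(query_rows, 384)
for row in other_neighbors:
    print(row)
```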


To save typing, we will create a simple function that looks up the images for the nearest neighbors returned by a query:

In [8]:
def get_images_from_ids(query_result):
    # Keep the rows of image_train whose id matches a returned neighbor.
    return image_train.filter_by(query_result['reference_label'], 'id')
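
For readers unfamiliar with `filter_by`, it keeps the rows whose value in the named column appears in the given list. A rough pure-Python analogue on a list of dicts is sketched below. The rows are made up and `filter_rows_by` is a hypothetical helper. The sketch returns rows in table order, which is a property of this sketch only and not a statement about how SFrame orders its results.

```python
def filter_rows_by(rows, values, column):
    """Keep the rows whose `column` value appears in `values`."""
    wanted = set(values)  # set membership keeps the lookup cheap
    return [row for row in rows if row[column] in wanted]

# Made-up rows standing in for image_train.
toy_rows = [
    {"id": 24, "label": "bird"},
    {"id": 33, "label": "cat"},
    {"id": 36, "label": "cat"},
    {"id": 70, "label": "dog"},
]

toy_neighbors = filter_rows_by(toy_rows, [36, 33], "id")
print(toy_neighbors)
```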

In [9]:
cat_neighbors = get_images_from_ids(knn_model.query(cat))

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 23.639ms     |
PROGRESS: | Done         |         | 100         | 229.383ms    |
PROGRESS: +--------------+---------+-------------+--------------+


In [10]:
cat_neighbors['image'].show()

Very cool results showing similar cats.

##Finding similar images to a car

In [13]:
car = image_train[8:9]
car['image'].show()

In [14]:
get_images_from_ids(knn_model.query(car))['image'].show()

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 23.539ms     |
PROGRESS: | Done         |         | 100         | 215.183ms    |
PROGRESS: +--------------+---------+-------------+--------------+


#Just for fun, let's create a lambda to find and show nearest neighbor images

In [15]:
show_neighbors = lambda i: get_images_from_ids(knn_model.query(image_train[i:i+1]))['image'].show()

In [16]:
show_neighbors(8)

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 27.259ms     |
PROGRESS: | Done         |         | 100         | 217.899ms    |
PROGRESS: +--------------+---------+-------------+--------------+


In [17]:
show_neighbors(26)

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 22.427ms     |
PROGRESS: | Done         |         | 100         | 191.956ms    |
PROGRESS: +--------------+---------+-------------+--------------+


In [18]:
show_neighbors(126)

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 25.717ms     |
PROGRESS: | Done         |         | 100         | 214.212ms    |
PROGRESS: +--------------+---------+-------------+--------------+


In [19]:
show_neighbors(326)

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 29.34ms      |
PROGRESS: | Done         |         | 100         | 217.799ms    |
PROGRESS: +--------------+---------+-------------+--------------+


In [20]:
len(image_train)

2005

In [21]:
image_test = graphlab.SFrame('image_test_data/')

In [22]:
len(image_train), len(image_test)

(2005, 4000)

In [23]:
image_test

id,image,label,deep_features,image_array
0,Height: 32 Width: 32,cat,"[1.13469004631, 0.0, 0.0, 0.0, 0.0366497635841, ...","[158.0, 112.0, 49.0, 159.0, 111.0, 47.0, ..."
6,Height: 32 Width: 32,automobile,"[0.23135882616, 0.0, 0.0, 0.0, 0.0, 0.226023137 ...","[160.0, 37.0, 13.0, 185.0, 49.0, 11.0, 20 ..."
8,Height: 32 Width: 32,cat,"[0.0, 0.0, 0.0344192385674, 0.0, ...","[23.0, 19.0, 23.0, 19.0, 21.0, 28.0, 21.0, 16.0, ..."
9,Height: 32 Width: 32,automobile,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 11.6065092087, 0.0, ...","[217.0, 215.0, 209.0, 210.0, 208.0, 202.0, ..."
12,Height: 32 Width: 32,dog,"[0.322317481041, 0.0, 1.24933350086, 0.0, 0.0, ...","[91.0, 64.0, 30.0, 82.0, 58.0, 30.0, 87.0, 73.0, ..."
16,Height: 32 Width: 32,dog,"[0.0, 0.0, 0.347357034683, 0.0, ...","[95.0, 76.0, 78.0, 92.0, 77.0, 78.0, 89.0, 77.0, ..."
24,Height: 32 Width: 32,dog,"[1.31557655334, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[136.0, 134.0, 118.0, 142.0, 141.0, 126.0, ..."
25,Height: 32 Width: 32,bird,"[0.0, 0.317288756371, 0.0, 1.36552882195, ...","[100.0, 103.0, 74.0, 68.0, 91.0, 65.0, 116.0, ..."
31,Height: 32 Width: 32,dog,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 9.26018810272, 0.0, ...","[127.0, 130.0, 81.0, 130.0, 133.0, 88.0, ..."
33,Height: 32 Width: 32,dog,"[0.130786716938, 0.727667212486, 0.0, ...","[118.0, 113.0, 81.0, 122.0, 117.0, 83.0, ..."


In [24]:
image_big = image_test.append(image_train)

In [25]:
len(image_big)

6005
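
One thing worth checking before building a nearest-neighbors model on image_big with label='id': the two previews above both contain ids 24 and 33, so ids may not be unique once the tables are appended. The sketch below uses only the ids visible in those previews, so it says nothing about the remaining rows, and the variable names are hypothetical.

```python
from collections import Counter

# Ids copied from the ten-row previews of image_train and image_test.
train_ids_preview = [24, 33, 36, 70, 90, 97, 107, 121, 136, 138]
test_ids_preview = [0, 6, 8, 9, 12, 16, 24, 25, 31, 33]

# Same order as image_test.append(image_train).
combined_ids = test_ids_preview + train_ids_preview

# Any id seen more than once would be ambiguous as a lookup key.
duplicate_ids = sorted(i for i, n in Counter(combined_ids).items() if n > 1)
print(duplicate_ids)
```

If duplicates turn out to matter, one option is to add a fresh unique row identifier to the combined table and use that as the label instead.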