In [3]:
import torch
import glob
from torchvision import transforms
from PIL import Image
#!pip3 install --upgrade pymilvus
from pymilvus import utility
from pymilvus import FieldSchema, CollectionSchema, DataType, Collection
from pymilvus import connections
from getpass import getpass
from matplotlib import pyplot as plt
from matplotlib.pyplot import figure
import numpy as np
import pickle
import random

### Image similarity search with **Inception v3**, PyTorch and Milvus

In this example, we will perform image similarity search using PyTorch and Milvus. 

We use the Animals-10 dataset available on Kaggle. Download the archive and extract it into an `animals/` directory in the notebook's working directory, so that the images end up under `animals/raw-img/<class-name>/`.

https://www.kaggle.com/datasets/alessiocorrado99/animals10

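We use a pre-trained Inception v3 model to generate vector embeddings from the images, and then use those embeddings for the similarity search. Note that the full network (including its final classification layer) is used, so each embedding is the model's 1000-dimensional output vector.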

In [4]:
# Get the filepaths of the images (one sub-directory per animal class).
# '*.j*' picks up the .jpg / .jpeg files; `recursive=True` is not needed without '**'.
# Sort first because glob's order depends on the filesystem, then shuffle with a
# fixed seed so the subset selected by `img_limit` is reproducible across runs.
SHUFFLE_SEED = 42
paths = sorted(glob.glob('animals/raw-img/*/*.j*'))
random.Random(SHUFFLE_SEED).shuffle(paths)

In [8]:
# Load a pre-trained Inception v3 from PyTorch Hub (torchvision repo pinned at v0.10.0).
# NOTE: the whole network is kept, including its final classification layer, so the
# "embedding" produced below is the model's 1000-dimensional output vector rather
# than a penultimate-layer feature.
model = torch.hub.load('pytorch/vision:v0.10.0', 'inception_v3', pretrained=True)
# Inference mode: BatchNorm uses its running statistics and dropout is disabled.
model.eval()

# Preprocessing for images: Inception v3 takes 299x299 inputs, normalised with the
# standard torchvision ImageNet channel mean/std.
INPUT_SIZE = 299
preprocess = transforms.Compose([
    transforms.Resize(INPUT_SIZE),       # shorter side -> INPUT_SIZE px, aspect ratio kept
    transforms.CenterCrop(INPUT_SIZE),   # central INPUT_SIZE x INPUT_SIZE crop
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

Using cache found in /home/instructor/.cache/torch/hub/pytorch_vision_v0.10.0


In [11]:
# Function to create embeddings from the model
def embed(data, net=None):
    """Run one preprocessed image through the network and return its flattened output.

    Args:
        data: a single preprocessed image tensor *without* a batch dimension,
            e.g. ``preprocess(img)`` with shape (3, 299, 299). A batch dimension
            of size 1 is added before the forward pass.
        net: the module to run. Defaults to the notebook-level ``model``
            (looked up at call time, so reloading ``model`` is picked up).

    Returns:
        list[float]: the network output flattened to a plain Python list
        (1000 values for the Inception v3 classifier loaded in this notebook).
    """
    if net is None:
        net = model
    # No autograd graph is needed for inference.
    with torch.no_grad():
        output = net(data.unsqueeze(0))
    # Convert straight from tensor to Python floats (no numpy round-trip);
    # .cpu() keeps this working if the model is ever moved to a GPU.
    return output.flatten().cpu().tolist()

In [12]:
# Test the embedding generation on a single image.
test = 'animals/raw-img/cane/OIP--2z_zAuTMzgYM_KynUl9CQHaE7.jpeg'
with Image.open(test) as img:
    # Force 3-channel RGB (as the embedding loop does): grayscale/CMYK/RGBA inputs
    # would otherwise not match the 3-channel Normalize in `preprocess`.
    im = preprocess(img.convert('RGB'))
print(im.shape)
emb = embed(im)
print(len(emb))

torch.Size([3, 299, 299])
1000


In [None]:
# ---- Configuration: Milvus connection & collection ----
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"
COLLECTION_NAME = 'SIM_SEARCH_TORCH'   # Milvus collection holding the image embeddings
DIMENSION = 1000                       # length of each embedding vector (the model output size)

# ---- Configuration: inference ----
BATCH_SIZE = 128
TOP_K = 3
img_limit = 2_000                      # caps how many images the embedding loop processes

In [None]:
# Connect to the Milvus server using the host/port from the config cell instead of
# repeating them here. For a server with authentication enabled, additionally pass
# user='root' and password=getpass('Milvus Password: ').
connections.connect(
    alias="default",
    host=MILVUS_HOST,
    port=MILVUS_PORT,
)

In [None]:
# Milvus: start from a clean slate by removing any collection with this name
# that a previous run may have left behind.
stale_collection_exists = utility.has_collection(COLLECTION_NAME)
if stale_collection_exists:
    utility.drop_collection(COLLECTION_NAME)


In [None]:
# The image file path doubles as the collection's VARCHAR primary key.
filepath_field = FieldSchema(
    name='filepath',
    dtype=DataType.VARCHAR,
    max_length=4000,
    is_primary=True,
)
# One DIMENSION-length float vector per image, as produced by `embed`.
embedding_field = FieldSchema(
    name='inception_embedding',
    dtype=DataType.FLOAT_VECTOR,
    dim=DIMENSION,
)
fields = [filepath_field, embedding_field]
schema = CollectionSchema(fields=fields)

# Create the collection on the "default" connection, then display every collection name.
collection = Collection(name=COLLECTION_NAME, schema=schema, using='default')
utility.list_collections()

In [None]:

# Embed at most `img_limit` images (all of them if img_limit is None).
# data_batch[0] collects the file paths (primary keys), data_batch[1] the embeddings.
data_batch = [[], []]

# Slicing fixes the old off-by-one: breaking at ind == img_limit embedded img_limit + 1 images.
selected_paths = paths if img_limit is None else paths[:img_limit]
for done, path in enumerate(selected_paths, start=1):
    # `with` closes each file; convert('RGB') guarantees the 3 channels `preprocess` normalises.
    with Image.open(path) as img:
        im = preprocess(img.convert('RGB'))
    data_batch[0].append(path)
    data_batch[1].append(embed(im))

    if done % 100 == 0:
        print(f'Completed {done} of {len(selected_paths)} images')

print(f'Completed all {len(data_batch[0])} selected images')

In [None]:
# Cache the [paths, embeddings] pair to disk so the slow embedding step can be skipped
# on later runs. The `with` block closes the file even if pickling fails.
with open('img_embeddings.pkl', 'wb') as pickle_file:
    pickle.dump(data_batch, pickle_file)

In [None]:
# Insert the cached embeddings into Milvus in fixed-size batches.
# NOTE: pickle.load can execute arbitrary code -- only load a file this notebook wrote.
with open('img_embeddings.pkl', 'rb') as handle:
    data_batch = pickle.load(handle)

insert_batch_size = 1000
tmp_batch = [[], []]
batches_inserted = 0
for x in range(len(data_batch[0])):
    tmp_batch[0].append(data_batch[0][x])
    tmp_batch[1].append(data_batch[1][x])

    # Send a batch as soon as it is full, so every batch has exactly
    # insert_batch_size rows (the old x % size test made the first one 1001 rows).
    if len(tmp_batch[0]) == insert_batch_size:
        collection.insert(tmp_batch)
        tmp_batch = [[], []]
        batches_inserted += 1
        print(f'Inserted the batch {batches_inserted} to Milvus collection with insert batch size of {insert_batch_size}')

# Insert the remaining partial batch, if any.
if tmp_batch[0]:
    collection.insert(tmp_batch)

# Actually flush (previously commented out even though the message below claimed it).
collection.flush()
print('Flushed the data to Milvus')

In [None]:
# Sanity check: element type of a stored embedding vector (the schema declares FLOAT_VECTOR).
# Uses the first embedding instead of the loop variable `x` leaked from another cell.
assert data_batch[1], 'no embeddings were computed'
print(type(data_batch[1][0][0]))

In [None]:
# Create an IVF_FLAT (L2) index on the embedding field. Drop any old remnant index
# with the same name first so the cell can be re-run.
INDEX_NAME = "IVF_FLAT_INDX_IMG_SEARCH"

if collection.has_index(index_name=INDEX_NAME):
    collection.drop_index(index_name=INDEX_NAME)

index_params = {
    "metric_type": "L2",
    "index_type": "IVF_FLAT",
    # NOTE(review): nlist=1024 clusters is large for ~img_limit (2000) vectors; a smaller
    # value (Milvus guidance is roughly 4 * sqrt(n)) may be more appropriate -- left unchanged.
    "params": {"nlist": 1024},
}

# The index name must be passed as create_index's `index_name` keyword. Inside
# index_params it is not used as the name, so the drop above could never match it.
collection.create_index(
    field_name="inception_embedding",
    index_params=index_params,
    index_name=INDEX_NAME,
)

In [None]:
# Embed a query image using the same preprocessing as the indexed images.
test = 'sheep.jpg'
with Image.open(test) as query_img:
    # convert('RGB') matches the indexing loop and yields the 3 channels Normalize expects.
    im = preprocess(query_img.convert('RGB'))
search_embedding = embed(im)

In [None]:
# Load the collection (a single in-memory replica) so it can be searched.
collection.load(replica_number=1)

In [None]:
# Search for the images closest to the query embedding by L2 distance.
# pymilvus expects index-specific search parameters (nprobe = number of IVF
# clusters scanned) nested under "params", next to metric_type.
search_params = {"metric_type": "L2", "params": {"nprobe": 128}}
# limit=5: the plotting cell shows the query image plus these matches.
search_res = collection.search(
    data=[search_embedding],
    anns_field='inception_embedding',
    param=search_params,
    limit=5,
    output_fields=['filepath'],
)


In [None]:
# Show the query image followed by each match, titled with its distance.
# (The old leading plt.figure() only produced an extra empty figure.)
hits = search_res[0]
# One row for the query plus one per hit; squeeze=False keeps axarr 2-D even for a
# single row, so the indexing below works however many hits come back.
fig, axarr = plt.subplots(len(hits) + 1, 1, figsize=(32, 32), squeeze=False)
axarr = axarr[:, 0]

with Image.open(test) as query_img:
    axarr[0].imshow(query_img.resize((512, 512), Image.Resampling.LANCZOS))
axarr[0].set_axis_off()
axarr[0].set_title('Query Image')

for indx, result in enumerate(hits):
    axarr[indx + 1].set_title('Distance: ' + str(result.distance))
    with Image.open(result.entity.get('filepath')) as hit_img:
        axarr[indx + 1].imshow(hit_img.resize((512, 512), Image.Resampling.LANCZOS))
    axarr[indx + 1].set_axis_off()

plt.show()

In [None]:
# Preview the first few collected image paths; printing every path floods the notebook output.
for path in paths[:10]:
    print(path)

In [7]:
# Display the network's layer structure (an nn.Module prints as its repr).
print(repr(model))

Sequential(
  (0): BasicConv2d(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (1): BasicConv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (2): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (4): BasicConv2d(
    (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (5): BasicConv2d(
    (conv): Conv2d(80, 192, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(192, eps=0.001, moment