In [1]:
import glob

# Get the filepaths of the images
paths = glob.glob('./animals_dataset/animals/animals/**/*.jpg', recursive=True)

len(paths)


5400

In [2]:
import torch
from torchvision import transforms

# Inference Arguments
BATCH_SIZE = 128

# Load the embedding model with the last layer removed
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
model = torch.nn.Sequential(*(list(model.children())[:-1]))
model.eval()

# Preprocessing for images
preprocess = transforms.Compose([
    transforms.Resize(256), # Change the size of the image (the longer side of the image will measure 256 pixels)
    transforms.CenterCrop(224), # Cut the image to have a 224*224 picture
    transforms.ToTensor(), # Change the image into a tensor
])
print(paths)

Using cache found in /home/ewwweee/.cache/torch/hub/pytorch_vision_v0.10.0


['./animals_dataset/animals/animals/swan/1ae8f734f4.jpg', './animals_dataset/animals/animals/swan/86d096a7e8.jpg', './animals_dataset/animals/animals/swan/72e127323d.jpg', './animals_dataset/animals/animals/swan/8e2ce5f005.jpg', './animals_dataset/animals/animals/swan/71a1f7f436.jpg', './animals_dataset/animals/animals/swan/234d3389c2.jpg', './animals_dataset/animals/animals/swan/1f999b23eb.jpg', './animals_dataset/animals/animals/swan/26c3011b3b.jpg', './animals_dataset/animals/animals/swan/4b39208a81.jpg', './animals_dataset/animals/animals/swan/22c6670d0d.jpg', './animals_dataset/animals/animals/swan/8ef0ae3f42.jpg', './animals_dataset/animals/animals/swan/93fab6c725.jpg', './animals_dataset/animals/animals/swan/70a9ce4a92.jpg', './animals_dataset/animals/animals/swan/6e2efc3bf3.jpg', './animals_dataset/animals/animals/swan/3a83dfbaa3.jpg', './animals_dataset/animals/animals/swan/27d864cabb.jpg', './animals_dataset/animals/animals/swan/76ea10b2e8.jpg', './animals_dataset/animals/ani

In [3]:
from pymilvus import MilvusClient
import random

# Initialisation du client Milvus Lite
client = MilvusClient("milvus_image_db.db")

In [5]:
from pymilvus import MilvusClient, DataType

schema = MilvusClient.create_schema(
    auto_id=False,
    enable_dynamic_field=True,
)

# Add fields to schema
schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=2048)
schema.add_field(field_name="path", datatype=DataType.VARCHAR, max_length=512)
schema.add_field(field_name="animal", datatype=DataType.VARCHAR, max_length=100)

{'auto_id': False, 'description': '', 'fields': [{'name': 'id', 'description': '', 'type': <DataType.INT64: 5>, 'is_primary': True, 'auto_id': False}, {'name': 'vector', 'description': '', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 2048}}, {'name': 'path', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 512}}, {'name': 'animal', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 100}}], 'enable_dynamic_field': True}

In [6]:
index_params = client.prepare_index_params()

#  Add indexes
index_params.add_index(
    field_name="vector", 
    index_type="AUTOINDEX",
    metric_type="COSINE"
)

# Create collection
client.create_collection(
    collection_name="animals_vector_db",
    schema=schema,
    index_params=index_params
)

In [7]:
from PIL import Image
from tqdm import tqdm

embed_data = []

# Embed function that embeds the batch and inserts it
def embed(data):
    with torch.no_grad():
        output = model(torch.stack(data[0])).squeeze()
        return output
        

data_batch = [[],[]]

i=0
# Read the images into batches for embedding and insertion
for path in tqdm(paths):
    
    im = Image.open(path).convert('RGB') # Load the image and convert it into RGB
    data_batch[0].append(preprocess(im)) # We add the image in the batch after change it
    data_batch[1].append(path) # We also add the path 
    if len(data_batch[0]) % BATCH_SIZE == 0: # We embed the whole batch
        i+=1
        embed_data.append(embed(data_batch))
        data_batch = [[],[]]

# Embed and insert the remainder
if len(data_batch[0]) != 0:
     embed_data.append(embed(data_batch))


100%|██████████| 5400/5400 [09:05<00:00,  9.90it/s]


In [8]:
list_vector = []
for vector in embed_data:
    for i in range(vector.shape[0]):
        list_vector.append(vector[i])

len(list_vector)

5400

In [9]:
import re 
data =[{"id":i,"vector":list_vector[i],"path":paths[i],"animal":re.split('/',paths[i])[-2]} for i in range(len(list_vector))]

In [None]:
res = client.insert(
    collection_name="animals_vector_db",
    data=data
)
res

{'insert_count': 5400, 'ids': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215,

: 