In [55]:
def Collection_init(*,base_dir='__a_gis_image_db__',host='127.0.0.1',port=19530,name='image_db',
                   id={'length':256,'description':'filename','name':'id'},
                   vec={'dim':512,'description':'encoding','name':'vec'},
                   album={'length':256,'description':'name of album','name':'album'},
                   ttl_seconds=15):
    """Start the embedded Milvus server if needed, connect, and open the image collection.

    Keyword Args:
        base_dir: directory passed to ``milvus.default_server.set_base_dir``.
        host: host passed to ``pymilvus.connections.connect``.
        port: port passed to ``pymilvus.connections.connect``.
        name: name of the ``pymilvus.Collection`` to open.
        id: spec of the VARCHAR primary-key field (keys ``name``, ``length``,
            ``description``).
        vec: spec of the FLOAT_VECTOR field (keys ``name``, ``dim``). Its
            ``description`` is not written to the schema; see the note below.
        album: spec of an album field. It is stored on the returned object but is
            NOT part of the collection schema; see the note below.
        ttl_seconds: value for the ``collection.ttl.seconds`` collection property
            (default 15, as before). Pass ``None`` to omit the property.

    Returns:
        A ``_Collection`` dataclass holding copies of the settings above plus the
        ``pymilvus.Collection`` handle in ``db``.
    """
    import milvus
    import pymilvus
    import dataclasses
    import typing

    # Copy the field specs so the returned object never aliases the mutable
    # default dicts in the signature (mutating them would leak into later calls).
    id_spec, vec_spec, album_spec = dict(id), dict(vec), dict(album)

    # Optional: keep all server data under `base_dir`. Otherwise milvus uses its
    # default location (e.g. %APPDATA%/milvus-io/milvus-server on Windows).
    milvus.default_server.set_base_dir(base_dir)

    # Start the embedded Milvus server unless it is already running.
    if not milvus.default_server.running:
        milvus.default_server.start()
        # NOTE(review): milvus' own examples call cleanup() *before* start() to
        # wipe earlier data; calling it right after start() is unusual — confirm intent.
        milvus.default_server.cleanup()

    pymilvus.connections.connect(host=host, port=port)

    id_field = pymilvus.FieldSchema(name=id_spec['name'], dtype=pymilvus.DataType.VARCHAR, max_length=id_spec['length'], description=id_spec['description'], is_primary=True)
    # The literal "float vector" description (rather than vec['description']) is kept
    # so the schema stays identical to collections created by earlier runs;
    # presumably pymilvus rejects a schema that differs from an existing collection's.
    vec_field = pymilvus.FieldSchema(name=vec_spec['name'], dtype=pymilvus.DataType.FLOAT_VECTOR, description="float vector", dim=vec_spec['dim'], is_primary=False)
    # TODO: an album FieldSchema used to be built here but was never added to the
    # schema. Adding it would require Collection_insert to supply an album column
    # and a migration of existing collections, so it is intentionally left out.
    schema = pymilvus.CollectionSchema(fields=[id_field, vec_field], description="collection description")

    collection_kwargs = {}
    if ttl_seconds is not None:
        # NOTE(review): in Milvus this property makes inserted entities expire after
        # `ttl_seconds`; 15 s looks short for an image database — confirm.
        collection_kwargs['properties'] = {"collection.ttl.seconds": ttl_seconds}
    db = pymilvus.Collection(name=name, data=None, schema=schema, **collection_kwargs)

    @dataclasses.dataclass
    class _Collection:
        base_dir: str
        host: str
        name: str
        port: int
        id: dict
        vec: dict
        album: dict  # album field spec (previously mis-annotated as str)
        db: typing.Any

    return _Collection(base_dir=base_dir, host=host, name=name, port=port,
                       id=id_spec, vec=vec_spec, album=album_spec, db=db)

def Collection_insert(*,collection,ids,encodings):
    """Insert one row per (id, encoding) pair into ``collection.db``.

    Args:
        collection: object returned by ``Collection_init``; only its ``db``
            attribute (whose ``insert`` method is called) is used.
        ids: primary keys; each is converted with ``str`` (e.g. image paths).
        encodings: one vector per id. Objects with a ``.cpu()`` method (e.g. torch
            tensors) are moved to the CPU first; numpy arrays or plain sequences
            are passed to ``numpy.asarray`` directly.

    Returns:
        The same ``collection`` object, so calls can be chained/reassigned.

    Raises:
        ValueError: if ``ids`` and ``encodings`` have different lengths.
    """
    import numpy

    id_column = [str(x) for x in ids]
    vector_column = [numpy.asarray(x.cpu() if hasattr(x, 'cpu') else x) for x in encodings]

    # Enforce the column-length requirement here instead of leaving it to pymilvus.
    if len(id_column) != len(vector_column):
        raise ValueError(
            "ids and encodings must have the same length "
            "(got {} ids and {} encodings)".format(len(id_column), len(vector_column)))
    if not id_column:
        return collection  # nothing to insert

    collection.db.insert([id_column, vector_column])
    return collection

def Encoder_init(*,model='clip-ViT-B-32'):
    """Construct and return the sentence-transformers model named by ``model``.

    Args:
        model: model name or path forwarded positionally to ``SentenceTransformer``.

    Returns:
        The ``sentence_transformers.SentenceTransformer`` instance.
    """
    from sentence_transformers import SentenceTransformer
    return SentenceTransformer(model)

def Encoder_encode(*,encoder,images,batch_size=64,show_progress_bar=True):
    """Encode ``images`` with ``encoder`` and return the result as a tensor.

    Args:
        encoder: object with an ``encode`` method (e.g. from ``Encoder_init``).
        images: inputs forwarded positionally to ``encoder.encode``.
        batch_size: batch size forwarded to ``encode``.
        show_progress_bar: whether ``encode`` should display a progress bar.

    Returns:
        Whatever ``encoder.encode`` returns with ``convert_to_tensor=True``.
    """
    encode_options = dict(
        batch_size=batch_size,
        convert_to_tensor=True,
        show_progress_bar=show_progress_bar,
    )
    return encoder.encode(images, **encode_options)

def Collection_search(*,collection,encodings,topk=32,nprobe=16,nlist=1024):
    """Search ``collection.db`` for the ``topk`` nearest vectors (L2) to each encoding.

    Args:
        collection: object returned by ``Collection_init``; ``collection.db`` is the
            pymilvus collection and ``collection.vec['name']`` the vector field.
        encodings: query vectors. Objects with a ``.cpu()`` method (e.g. torch
            tensors) are moved to the CPU first; numpy arrays or plain sequences
            are passed to ``numpy.asarray`` directly.
        topk: number of hits per query (passed as ``limit``).
        nprobe: IVF search parameter passed as ``params.nprobe``.
        nlist: IVF_FLAT build parameter; only used when the index is first created.

    Returns:
        The value returned by ``collection.db.search``.
    """
    import numpy

    _METRIC_TYPE = 'L2'
    _INDEX_TYPE = 'IVF_FLAT'
    vec_field = collection.vec['name']

    search_vectors = [numpy.asarray(x.cpu() if hasattr(x, 'cpu') else x) for x in encodings]

    # Create the index only if the vector field has none yet. Re-creating it on
    # every search was redundant work, and presumably fails once an index with
    # different params (another `nlist`) exists. NOTE: an existing index is kept
    # as-is, so a new `nlist` value is not re-applied.
    if not any(getattr(ix, 'field_name', None) == vec_field for ix in collection.db.indexes):
        index_param = {
            "index_type": _INDEX_TYPE,
            "params": {"nlist": nlist},
            "metric_type": _METRIC_TYPE}
        collection.db.create_index(vec_field, index_param)

    search_param = {
        "data": search_vectors,
        "anns_field": vec_field,
        "param": {"metric_type": _METRIC_TYPE, "params": {"nprobe": nprobe}},
        "limit": topk,
    }

    collection.db.load()
    return collection.db.search(**search_param)




In [92]:
import pathlib
from PIL import Image
import os

# Folder(s) or individual files to index. Path.home() also works when HOME is
# unset (os.environ.get('HOME') + "..." raised TypeError in that case).
image_dir = str(pathlib.Path.home() / "Downloads")
image_paths = [image_dir]
IMAGE_SUFFIXES = {'.jpg', '.jpeg', '.png'}

# Collect image files. Explicit files are kept as-is; directories are scanned
# (non-recursively) for the suffixes above, case-insensitively, in sorted order.
image_names = []
for path0 in image_paths:
    path = pathlib.Path(path0)
    if path.is_file():
        image_names.append(path)
    else:
        image_names.extend(sorted(
            p for p in path.glob('*')
            if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES))
print("Images:", len(image_names))

# Load every image fully and close its file immediately. Image.open alone is
# lazy and keeps one file handle open per image until the image is loaded.
images = []
for filepath in image_names:
    with Image.open(filepath) as im:
        images.append(im.copy())

Images: 191


In [34]:
# Connect to Milvus (starting the embedded server if it is not running) and get
# a handle to the image collection, using Collection_init's defaults.
collection = Collection_init(
)

In [35]:
# Load the sentence-transformers image encoder (same model as the function default).
encoder = Encoder_init(model='clip-ViT-B-32')

In [36]:
# Embed every loaded image; Encoder_encode requests tensor output (convert_to_tensor=True).
encoded_images = Encoder_encode(
    encoder=encoder,
    images=images,
    batch_size=64,
)

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

In [45]:
# Store one (filename, embedding) row per image.
# NOTE(review): re-running this cell inserts the same ids again; Milvus may keep
# duplicate primary keys, so reset the collection first if that matters.
collection = Collection_insert(
    collection=collection,
    ids=image_names,
    encodings=encoded_images,
)

In [95]:
# Use the embedding of image #55 as the single query and fetch its 15 nearest matches.
results = Collection_search(
    collection=collection,
    encodings=[encoded_images[55]],
    topk=15,
)

In [96]:
import os
#from anytree import Node, RenderTree
import IPython.display as ipd
import ipywidgets as widgets
import json

import html
import pathlib

# Map file suffixes to the image format ipywidgets.Image expects. Previously every
# file was labelled 'png', including JPEGs.
_IMAGE_FORMATS = {'.jpg': 'jpeg', '.jpeg': 'jpeg', '.png': 'png'}

hits = results[0]  # only the first query vector's hits are shown
boxes = []
for p, dist in zip(hits.ids, hits.distances):
    image_path = pathlib.Path(p)

    # read_bytes() closes the file; open(p, 'rb').read() left it open.
    img = widgets.Image(
        value=image_path.read_bytes(),
        format=_IMAGE_FORMATS.get(image_path.suffix.lower(), 'png'),
        width='180px', height='180px')  # narrower than the 200px box to avoid padding issues

    # `color` is not a Layout property, so Layout(color='red') had no effect; the
    # distance is coloured via inline HTML instead, with the filename as a tooltip.
    label = widgets.HTML(
        value="<span style='color:red' title='{}'>{:.1f}</span>".format(
            html.escape(image_path.name, quote=True), dist),
        layout=widgets.Layout(width='180px', overflow='hidden'))

    # Stack image and caption, centred, in a fixed-width cell.
    boxes.append(widgets.VBox(
        [img, label],
        layout=widgets.Layout(align_items='center', justify_content='center', width='200px', overflow='hidden')))

# Responsive grid: as many 200px columns as fit, without overflow.
wid = widgets.GridBox(boxes, layout=widgets.Layout(grid_template_columns="repeat(auto-fill, minmax(200px, 1fr))", width='auto', overflow='hidden'))
ipd.display(wid)

GridBox(children=(VBox(children=(Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x02X\x00\x00\x02X\x…