In [106]:
# Core Vector Store, Embedded Instance

import sqlite3
import os
import hnswlib
import json

class Store:
    """Small embedded vector store: an HNSW index plus two SQLite tables.

    Embeddings live in an ``hnswlib`` index (L2 space) that is created lazily
    on the first ``insert``. The payload for each id is kept in SQLite, in
    ``<name>_imagestore`` (column ``img_url``) or ``<name>_textstore``
    (column ``text``). Both tables and the index share one id space.

    Args:
        name: Prefix used for the two SQLite table names.
        path: Directory holding ``temp.db``; required when ``persistent``.
        persistent: Use an on-disk database instead of ``:memory:``.
        max_elements: Capacity passed to ``hnswlib.Index.init_index``.

    Note:
        Only the SQLite side is persisted. The HNSW index is always rebuilt
        in memory, so rows from an earlier session are not searchable until
        their embeddings are inserted again.
    """

    _IMAGE_SUFFIX = "_imagestore"
    _TEXT_SUFFIX = "_textstore"

    def __init__(self, name, path=None, persistent=False, max_elements=100000):
        self.name = name
        if persistent:
            assert path is not None
            self.con = sqlite3.connect(os.path.join(path, "temp.db"))
        else:
            self.con = sqlite3.connect(":memory:")

        self.cur = self.con.cursor()
        # IF NOT EXISTS lets a persistent store be reopened without failing
        # on the tables created by the previous session.
        self.cur.execute(
            "CREATE TABLE IF NOT EXISTS {}(id INT NOT NULL PRIMARY KEY, img_url TEXT NOT NULL)".format(
                self._table(self._IMAGE_SUFFIX)
            )
        )
        self.cur.execute(
            "CREATE TABLE IF NOT EXISTS {}(id INT NOT NULL PRIMARY KEY, text TEXT NOT NULL)".format(
                self._table(self._TEXT_SUFFIX)
            )
        )
        self.max_elements = max_elements
        self.dim = None
        self.idx = None

    def _table(self, suffix):
        """Return the double-quoted SQL identifier for one of the two tables."""
        # Table names cannot be bound as parameters, so quote the identifier
        # instead of splicing the raw store name into the statement.
        return '"{}"'.format((self.name + suffix).replace('"', '""'))

    def _fetch_rows(self, suffix, row_ids):
        """Fetch the rows whose id is in ``row_ids``, ordered like ``row_ids``."""
        placeholders = ", ".join("?" for _ in row_ids)
        sql = "SELECT * FROM {} WHERE id IN ({})".format(self._table(suffix), placeholders)
        # Materialise immediately: the shared cursor drops any pending result
        # set as soon as the next statement is executed on it.
        rows = self.cur.execute(sql, row_ids).fetchall()
        rank = {label: position for position, label in enumerate(row_ids)}
        rows.sort(key=lambda row: rank.get(row[0], len(rank)))
        return rows

    def insert(self, ids, embeddings, items, datatype):
        """Add embeddings to the index and their payloads to SQLite.

        Args:
            ids: Integer ids, one per item.
            embeddings: Array-like accepted by ``hnswlib.Index.add_items``.
                On the very first insert it must expose ``squeeze().shape``
                so the index dimension can be derived.
            items: Payloads (image URLs/paths or texts), one per id.
            datatype: ``"image"`` or ``"text"``; selects the target table.

        Raises:
            AssertionError: If ``datatype`` is not ``"image"`` or ``"text"``.
            ValueError: If ``ids`` and ``items`` differ in length.
            sqlite3.Error: If the rows cannot be written (e.g. duplicate id).
        """
        assert datatype in ["image", "text"]

        # sqlite3 cannot bind numpy integer scalars, so normalise to int.
        ids = [int(i) for i in ids]
        items = list(items)
        if len(ids) != len(items):
            raise ValueError(
                "ids and items must have the same length ({} != {})".format(len(ids), len(items))
            )

        # if first insertion, set dim and construct index
        if self.dim is None:
            self.dim = embeddings.squeeze().shape[-1]
        if self.idx is None:
            index = hnswlib.Index(space='l2', dim=self.dim)
            index.init_index(max_elements=self.max_elements, ef_construction=100, M=16)
            index.set_ef(10)
            self.idx = index

        suffix = self._IMAGE_SUFFIX if datatype == "image" else self._TEXT_SUFFIX
        try:
            # Write the rows first: a rejected row (e.g. duplicate id) must not
            # leave a vector in the index that has no payload behind it.
            self.cur.executemany(
                "INSERT INTO {} VALUES (?, ?)".format(self._table(suffix)),
                list(zip(ids, items)),
            )
            self.idx.add_items(embeddings, ids)
        except Exception:
            self.con.rollback()
            raise
        self.con.commit()

    def delete(self, ids):
        """Remove the given ids from both tables and from the HNSW index.

        Ids that are unknown to a table or to the index are ignored.
        NOTE(review): depending on the hnswlib version, re-inserting an id
        after it was marked deleted may be rejected by the index - confirm.
        """
        params = [(int(i),) for i in ids]
        for suffix in (self._IMAGE_SUFFIX, self._TEXT_SUFFIX):
            self.cur.executemany(
                "DELETE FROM {} WHERE id = ?".format(self._table(suffix)), params
            )
        self.con.commit()

        if self.idx is not None:
            for (label,) in params:
                try:
                    # Hide the vector from later knn_query calls.
                    self.idx.mark_deleted(label)
                except RuntimeError:
                    # Deliberate best-effort: the label is not in the index
                    # (or is already deleted), so there is nothing to hide.
                    pass

    def query(self, queries, n_results_per_query, verbose=True):
        """Return the nearest stored items for each query embedding.

        Args:
            queries: Query vectors accepted by ``hnswlib.Index.knn_query``.
            n_results_per_query: Number of neighbours (``k``) per query.
            verbose: Print the raw id matrix returned by the index. Defaults
                to ``True`` so existing callers keep their output.

        Returns:
            dict with keys ``"ids"``, ``"images"`` and ``"text"``. Each value
            holds one entry per query: the neighbour ids (nearest first), the
            matching ``(id, img_url)`` rows and the matching ``(id, text)``
            rows. Rows are ordered like the neighbour ids.
        """
        # get ids from hnsw index
        ids, distances = self.idx.knn_query(queries, k=n_results_per_query)
        if verbose:
            print(ids)
        ids = ids.tolist()

        # get objects from stores
        out = {"ids": [], "images": [], "text": []}
        # Iterate over the index result rather than len(queries): a single
        # 1-D query vector yields exactly one row of ids.
        for row_ids in ids:
            out['ids'].append(row_ids)
            out['images'].append(self._fetch_rows(self._IMAGE_SUFFIX, row_ids))
            out['text'].append(self._fetch_rows(self._TEXT_SUFFIX, row_ids))

        return out

    def __del__(self):
        # __init__ may have failed before the connection was opened.
        con = getattr(self, "con", None)
        if con is not None:
            con.close()

In [3]:
# Embedder Providers

from transformers import CLIPProcessor, CLIPModel, AutoTokenizer, CLIPTextModelWithProjection, AutoProcessor, CLIPVisionModelWithProjection
from PIL import Image

class CLIPEmbedder:
    """Embed text and/or images with a Hugging Face CLIP checkpoint.

    Models, processors and tokenizers are loaded lazily on first use and then
    cached on the instance, so repeated calls do not reload them each time.

    Args:
        model_name: Checkpoint passed to the model's ``from_pretrained``.
        processor_name: Checkpoint passed to the processor/tokenizer's
            ``from_pretrained``.
    """

    def __init__(self, model_name="openai/clip-vit-base-patch32", processor_name="openai/clip-vit-base-patch32"):
        self.model_name = model_name
        self.processor_name = processor_name
        # (kind, checkpoint name) -> loaded object
        self._cache = {}

    def _cached(self, kind, name, factory):
        """Return the cached object for ``(kind, name)``, loading it on first use."""
        key = (kind, name)
        if key not in self._cache:
            self._cache[key] = factory(name)
        return self._cache[key]

    def __call__(self, text=None, images=None):
        """Compute CLIP embeddings.

        Args:
            text: Text input forwarded to the tokenizer/processor, or ``None``.
            images: Sequence of images or ``None``. ``str`` entries are opened
                with ``PIL.Image.open``; the caller's sequence is not modified.

        Returns:
            ``None`` when both inputs are ``None``; ``image_embeds`` for
            images only; ``text_embeds`` for text only; the tuple
            ``(text_embeds, image_embeds)`` when both are given.
        """
        if images is None and text is None:
            return None

        if images is not None:
            # Build a new list instead of overwriting the caller's entries.
            images = [Image.open(item) if isinstance(item, str) else item for item in images]

        if text is None:
            model = self._cached(
                "vision_model", self.model_name,
                lambda name: CLIPVisionModelWithProjection.from_pretrained(name),
            )
            processor = self._cached(
                "image_processor", self.processor_name,
                lambda name: AutoProcessor.from_pretrained(name),
            )
            inputs = processor(images=images, return_tensors="pt", padding=True)
            outputs = model(**inputs)
            return outputs.image_embeds

        if images is None:
            model = self._cached(
                "text_model", self.model_name,
                lambda name: CLIPTextModelWithProjection.from_pretrained(name),
            )
            tokenizer = self._cached(
                "tokenizer", self.processor_name,
                lambda name: AutoTokenizer.from_pretrained(name),
            )
            inputs = tokenizer(text=text, return_tensors="pt", padding=True)
            outputs = model(**inputs)
            return outputs.text_embeds

        model = self._cached(
            "clip_model", self.model_name,
            lambda name: CLIPModel.from_pretrained(name),
        )
        processor = self._cached(
            "clip_processor", self.processor_name,
            lambda name: CLIPProcessor.from_pretrained(name),
        )
        inputs = processor(text=text, images=images, return_tensors="pt", padding=True)
        outputs = model(**inputs)
        return (outputs.text_embeds, outputs.image_embeds)

    def embed(self, text=None, images=None):
        """Alias for calling the embedder directly."""
        return self(text=text, images=images)

In [None]:
# Hosted Server

from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route
import uvicorn

class Server:
    """Placeholder HTTP front end for a ``Store``, built on Starlette.

    The three endpoints are stubs: they do not call the store yet and each
    returns a fixed JSON body.
    """

    def __init__(self):
        pass

    def start(self, store):
        """Build the Starlette app and serve it with ``uvicorn.run``.

        Args:
            store: The store the endpoints are meant to operate on. It is
                currently unused because the handlers are stubs.

        NOTE(review): no ``methods`` are passed to ``Route`` - confirm
        whether insert/delete should accept POST.
        """

        async def insert(request):
            # TODO: call store.insert() with the request payload.
            # The routes declare no path parameters, so nothing is appended
            # to the placeholder response.
            return JSONResponse({'hello': 'world'})

        async def delete(request):
            # TODO: call store.delete() with the request payload.
            return JSONResponse({'hello': 'world'})

        async def query(request):
            # TODO: call store.query() with the request payload.
            return JSONResponse({'hello': 'world'})

        app = Starlette(debug=True, routes=[
            Route('/insert', insert),
            Route('/delete', delete),
            Route('/query', query),
        ])

        uvicorn.run(app)

    def terminate(self):
        """Placeholder; shutting the server down is not implemented."""
        pass

In [None]:
# Client

class Client:
    """Placeholder for a client of the hosted server; holds no state yet."""

    def __init__(self):
        """Create an empty client (nothing is configured)."""

In [None]:
# Local Instance

class Instance:
    """Placeholder for a local instance; holds no state yet."""

    def __init__(self):
        """Create an empty instance (nothing is configured)."""

In [8]:
import torch
import glob
import numpy as np
from PIL import Image
# Integer ids 0..999 (not referenced again by the cells visible here).
ids = list(range(1000))
# Shared CLIP embedder using the default model and processor names.
c = CLIPEmbedder()

In [158]:
# Fresh in-memory store (persistent defaults to False); tables are prefixed "ungus".
store = Store(name="ungus")

In [159]:
# Collect every path under images/; glob does not guarantee any ordering.
images = glob.glob("images/*")
# Embed and store entries 10..19 of that list, one image per insert call.
for i in range(10, 20):
    im = Image.open(images[i])
    # Detach the embedding returned by embed() and convert it to numpy for the store.
    emb = c.embed(images=[im]).detach().numpy()
    print(i, emb.shape, images[i])
    # The list position doubles as the stored id; the file path is the stored item.
    store.insert([i], emb, [images[i]], datatype="image")

10 (1, 512) images/3396004.jpg
11 (1, 512) images/243004.jpg
12 (1, 512) images/3121004.jpg
13 (1, 512) images/2416004.jpg
14 (1, 512) images/2347004.jpg
15 (1, 512) images/197004.jpg
16 (1, 512) images/2575004.jpg
17 (1, 512) images/3042004.jpg
18 (1, 512) images/2224004.jpg
19 (1, 512) images/320004.jpg


In [160]:
# Issue 100 random queries; the id lines shown below are printed by Store.query.
for i in range(100):
    q = [np.random.rand(512)]  # one unseeded uniform-random 512-dim vector
    store.query(q, 4)  # 4 neighbours per query; the returned dict is discarded

[[12 14 17 16]]
[[12 17 16 14]]
[[12 17 14 15]]
[[12 17 15 14]]
[[14 12 17 15]]
[[12 14 17 15]]
[[12 14 17 15]]
[[12 17 14 15]]
[[12 17 14 16]]
[[12 14 16 17]]
[[12 17 14 16]]
[[14 12 17 15]]
[[12 17 14 15]]
[[12 17 14 15]]
[[12 17 14 16]]
[[12 17 14 16]]
[[12 17 14 15]]
[[12 14 17 16]]
[[12 14 15 17]]
[[12 17 14 10]]
[[12 17 14 15]]
[[12 17 14 16]]
[[12 17 14 16]]
[[12 17 14 18]]
[[12 15 17 14]]
[[12 17 16 14]]
[[12 14 17 15]]
[[12 17 14 10]]
[[12 17 14 15]]
[[12 17 14 15]]
[[12 17 14 15]]
[[12 17 15 14]]
[[12 17 14 16]]
[[12 17 14 15]]
[[12 15 14 17]]
[[12 14 17 15]]
[[12 17 15 16]]
[[12 17 14 15]]
[[12 14 17 16]]
[[12 14 17 16]]
[[12 17 16 14]]
[[12 17 14 16]]
[[12 17 15 16]]
[[12 17 14 15]]
[[12 17 14 15]]
[[12 17 14 10]]
[[12 14 17 15]]
[[12 17 14 16]]
[[12 17 14 15]]
[[12 17 14 15]]
[[12 17 15 14]]
[[14 12 17 18]]
[[12 17 14 15]]
[[12 17 14 16]]
[[12 14 17 15]]
[[12 17 14 16]]
[[12 17 14 15]]
[[12 14 17 16]]
[[12 14 15 17]]
[[12 17 14 15]]
[[12 14 15 17]]
[[12 17 15 14]]
[[12 17 