In [3]:
# Core Vector Store, Embedded Instance

import sqlite3
import os
import hnswlib
import json
from PIL import Image

class Store:
    def __init__(self, name, path=None, persistent=False, max_elements=100000):
        self.name = name
        if persistent:
            assert path is not None
            self.con = sqlite3.connect(os.path.join(path, "temp.db"))
        else:
            self.con = sqlite3.connect(":memory:")
            
            
        self.cur = self.con.cursor()
        self.cur.execute("CREATE TABLE {}(id INT NOT NULL PRIMARY KEY, img_url TEXT NOT NULL)".format(self.name + "_imagestore"))
        self.cur.execute("CREATE TABLE {}(id INT NOT NULL PRIMARY KEY, text TEXT NOT NULL)".format(self.name + "_textstore"))
        self.max_elements = max_elements
        self.dim = None
        self.idx = None
        
    def insert(self, ids, embeddings, items, datatype):
        
        assert datatype in ["image", "text"]
        
        # if first insertion, set dim and construct index
        if self.dim == None:
            self.dim = embeddings.squeeze().shape[-1]
        if self.idx == None:
            self.idx = hnswlib.Index(space='l2', dim=self.dim)
            self.idx.init_index(max_elements=self.max_elements, ef_construction=20, M=100)
            self.idx.set_ef(60)
        
        # add embeddings to index
        self.idx.add_items(embeddings, ids)
        
        # add to sqlite store
        data = list(zip(ids, items))
        if datatype == "image":
            self.cur.executemany("INSERT INTO {} VALUES (?, ?)".format(self.name + "_imagestore"), data)
        elif datatype == "text":
            self.cur.executemany("INSERT INTO {} VALUES (?, ?)".format(self.name + "_textstore"), data)
            
        self.con.commit()
    
    def delete(self, ids):
        self.cur.executemany("DELETE FROM {} WHERE ids = ?".format(self.name + "_imagestore"), ids)
        self.con.commit()
        self.cur.executemany("DELETE FROM {} WHERE ids = ?".format(self.name + "_textstore"), ids)
        self.con.commit()
    
    def query(self, queries, n_results_per_query):
        
        # get ids from hnsw index
        ids, distances = self.idx.knn_query(queries, k=n_results_per_query)
        ids = ids.tolist()
        
        #get objects from stores
        out = {"ids": [], "images": [], "text": []}
        for i in range(len(queries)):
            res_images = self.cur.execute('SELECT img_url FROM {} WHERE id IN ({})'.format(self.name + "_imagestore", ', '.join('?' for _ in ids[i])), ids[i]).fetchall()
            res_text = self.cur.execute('SELECT text FROM {} WHERE id IN ({})'.format(self.name + "_textstore", ', '.join('?' for _ in ids[i])), ids[i]).fetchall()
            self.con.commit()
            out['ids'].append(ids[i])
            out['images'].append(res_images)
            out['text'].append(res_text)
        
        return out
    
    def get_pil_images(self, query_result):
        impaths = query_result['images']
        out = []
        for query_idx in range(len(impaths)):
            current_out = []
            current_query = impaths[query_idx]
            for img_idx in range(len(current_query)):
                current_impath = current_query[img_idx][0]
                current_out.append(Image.open(current_impath))
            out.append(current_out)
        return out
    
    def __del__(self):
        self.con.close()

In [4]:
# Embedder Providers

from transformers import CLIPProcessor, CLIPModel, AutoTokenizer, CLIPTextModelWithProjection, AutoProcessor, CLIPVisionModelWithProjection
from PIL import Image

class CLIPEmbedder:
    def __init__(self, model_name="openai/clip-vit-base-patch32", processor_name="openai/clip-vit-base-patch32"):
        self.model_name = model_name
        self.processor_name = processor_name
        
    def __call__(self, text=None, images=None):
        
        if images is not None:
            for i in range(len(images)):
                if type(images[i]) == str:
                    images[i] = Image.open(images[i])
        
        if images is None and text is None:
            return
        
        if images is not None and text is None:
            model = CLIPVisionModelWithProjection.from_pretrained(self.model_name)
            processor = AutoProcessor.from_pretrained(self.processor_name)
            inputs = processor(images=images, return_tensors="pt", padding=True)
            outputs = model(**inputs)
            return outputs.image_embeds
        
        if images is None and text is not None:
            model = CLIPTextModelWithProjection.from_pretrained(self.model_name)
            tokenizer = AutoTokenizer.from_pretrained(self.processor_name)
            inputs = tokenizer(text=text, return_tensors="pt", padding=True)
            outputs = model(**inputs)
            return outputs.text_embeds
        
        if images is not None and text is not None:
            model = CLIPModel.from_pretrained(self.model_name)
            processor = CLIPProcessor.from_pretrained(self.processor_name)
            inputs = processor(text=text, images=images, return_tensors="pt", padding=True)
            outputs = model(**inputs)
            return (outputs.text_embeds, outputs.image_embeds)
    
    def embed(self, text=None, images=None):
        return self.__call__(text, images)

In [None]:
# Hosted Server

from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route
import uvicorn

class Server:
    def __init__(self):
        pass
    
    def start(self, store):

        async def insert(request):
            #store.insert()
            return JSONResponse({'hello': 'world' + request.path_params[]})
        
        async def delete(request):
            #store.delete()
            return JSONResponse({'hello': 'world'})
        
        async def query(request):
            #store.query()
            return JSONResponse({'hello': 'world'})

        app = Starlette(debug=True, routes=[
            Route('/insert', insert),
            Route('/delete', delete),
            Route('/query', query),
        ])
        
        uvicorn.run(app)

    
    def terminate(self):
        pass

In [None]:
# Client

class Client:
    def __init__(self):
        pass

In [None]:
# Local Instance

class Instance:
    def __init__(self):
        pass

In [72]:
import torch
import glob
import numpy as np
from PIL import Image
ids = list(range(1000))
c = CLIPEmbedder()

In [73]:
store = Store(name="ungus")

In [74]:
images = glob.glob("images/*")
for i in range(10, 13):
    im = Image.open(images[i])
    emb = c.embed(images=[im]).detach().numpy()
    print(i, emb.shape, images[i])
    store.insert([i], emb, [images[i]], datatype="image")

10 (1, 512) images/3396004.jpg
11 (1, 512) images/243004.jpg
12 (1, 512) images/3121004.jpg


In [75]:
q = [np.random.rand(512), np.random.rand(512)]
a = store.get_pil_images(store.query(q, 2))

SELECT * FROM ungus_imagestore WHERE id IN (?, ?)
SELECT * FROM ungus_imagestore WHERE id IN (?, ?)


In [1]:
# brute force approx nearest neighbor

import gradio as gr
import sqlite3
import time
import string
import random
import json
import math
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoProcessor, CLIPVisionModelWithProjection, CLIPTextModelWithProjection

def randomword(length):
    letters = string.ascii_lowercase
    return ''.join(random.choice(letters) for i in range(length))

tablename = randomword(16)

def build_vector_db(image_dir, progress=gr.Progress()):

    global tablename
    tablename = randomword(16)
    con = sqlite3.connect("temp.db")
    cur = con.cursor()
    cur.execute("CREATE TABLE {}(emb text primary key, image text)".format(tablename))

    model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
    processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

    for d in progress.tqdm(image_dir, desc="Building"):

        if ("jpg" not in d.name) and ("png" not in d.name):
            continue
        image = Image.open(d.name)
        inputs = processor(images=image, return_tensors="pt")
        outputs = model(**inputs)
        image_embed = outputs.image_embeds.squeeze().tolist()
        image_embed_json = json.dumps(image_embed)
        cur.execute("INSERT INTO {} VALUES(\"{}\", \"{}\")".format(tablename, image_embed_json, d.name))
        con.commit()

    con.close()
    return "Vector DB Ready"

def search_query(query):

    con = sqlite3.connect("temp.db")
    cur = con.cursor()
    
    model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
    tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    inputs = tokenizer([query], return_tensors="pt")
    outputs = model(**inputs)
    text_embed = outputs.text_embeds.squeeze()

    min_image_path = None
    min_dist = math.inf

    for row in cur.execute("SELECT * FROM {}".format(tablename)):
        embedding_json = row[0]
        image_embed = torch.Tensor(json.loads(embedding_json))
        dist = torch.dist(image_embed, text_embed)
        if dist < min_dist:
            min_dist = dist
            min_image_path = row[1]

    image = Image.open(min_image_path)
    con.close()
    return image

with gr.Blocks() as interface:
    image_dir = gr.File(file_count="directory",label="Input Files", height=200)
    upload = gr.Button(value="Build Vector DB")
    outtext = gr.Textbox()
    upload.click(fn=build_vector_db, inputs=image_dir, outputs=outtext)
    query = gr.Textbox(placeholder="Text Query Here")
    search = gr.Button(value="Search")
    image = gr.Image()
    search.click(fn=search_query, inputs=query, outputs=image)

interface.queue().launch()

Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.




In [6]:
# miscninja

import gradio as gr
import sqlite3
import time
import string
import random
import json
import math
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoProcessor, CLIPVisionModelWithProjection, CLIPTextModelWithProjection

def randomword(length):
    letters = string.ascii_lowercase
    return ''.join(random.choice(letters) for i in range(length))

tablename = randomword(16)
store = None

def build_vector_db(image_dir, progress=gr.Progress()):

    global tablename
    global store
    tablename = randomword(16)
    store = Store(name=tablename)
    c = CLIPEmbedder()

    i = 0
    for d in progress.tqdm(image_dir, desc="Building"):
        if ("jpg" not in d.name) and ("png" not in d.name):
            continue
        image = Image.open(d.name)
        emb = c.embed(images=[image]).detach().numpy()
        store.insert([i], emb, [d.name], datatype="image")
        i += 1
        
    return "Vector DB Ready"

def search_query(query):
    
    c = CLIPEmbedder()
    text_embed = c.embed(text=[query]).detach().squeeze()
    image = random.choice(store.get_pil_images(store.query([text_embed.numpy()], 3))[0])
    return image

with gr.Blocks() as interface:
    image_dir = gr.File(file_count="directory",label="Input Files", height=200)
    upload = gr.Button(value="Build Vector DB")
    outtext = gr.Textbox()
    upload.click(fn=build_vector_db, inputs=image_dir, outputs=outtext)
    query = gr.Textbox(placeholder="Text Query Here")
    search = gr.Button(value="Search")
    image = gr.Image()
    search.click(fn=search_query, inputs=query, outputs=image)

interface.queue().launch()

Exception ignored in: <function Store.__del__ at 0x12fba9af0>
Traceback (most recent call last):
  File "/var/folders/30/vlx9grx116j94n9gvhf7fq7c0000gn/T/ipykernel_70974/337572358.py", line 87, in __del__
sqlite3.ProgrammingError: SQLite objects created in a thread can only be used in that same thread. The object was created in thread id 11528826880 and this is thread id 8116772992.


Running on local URL:  http://127.0.0.1:7861

To create a public link, set `share=True` in `launch()`.


