In [1]:
from PIL import Image
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel
import chromadb
import torch
# import easyocr
from sklearn.metrics.pairwise import cosine_distances
import warnings
warnings.filterwarnings("ignore")

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
model.to(device)
# reader = easyocr.Reader(['en'])

CLIPModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 512)
      (position_embedding): Embedding(77, 512)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=512, out_features=2048, bias=True)
            (fc2): Linear(in_features=2048, out_features=512, bias=True)
          )
          (layer_norm2): LayerNorm((512,), eps=1e-05,

In [3]:
client = chromadb.PersistentClient(
    "db/",
)
image_collection = client.get_or_create_collection(
    "images", metadata={"hnsw:space": "cosine"}
)
# text_collection = client.get_or_create_collection(
#     "texts", metadata={"hnsw:space": "cosine"}
# )

In [4]:
def get_image_embeddings(image):
    image = Image.open(image)
    with torch.no_grad():
        processed_image = processor(images=image, return_tensors="pt").to(device)
        image_features = model.get_image_features(**processed_image)
    return image_features.cpu().squeeze(0).numpy().tolist()

# def apply_OCR(image_path, OCR_threshold = 0.5):
#     text = ""
#     result = reader.readtext(image_path)
#     for detection in result:
#         if detection[2] > OCR_threshold:
#             text += detection[1] + " "
#     if text == "":
#         return None
#     text = text.replace(",", "")
#     return text


def get_text_embeddings(text):
    with torch.no_grad():
        processed_text = processor(text=text, return_tensors="pt").to(device)
        text_features = model.get_text_features(**processed_text)
    return text_features.cpu().squeeze(0).numpy().tolist()

In [5]:
text = get_text_embeddings("a fish in the deep sea")
image = get_image_embeddings("images/Attention_is_all_you_need.jpg")
cosine_distances([text], [image])

array([[0.80367397]])

In [14]:
# read all OS images
all_images = []
with open("file_names.txt", "r") as f:
    for line in f:
        all_images.append(line.strip())

linux_paths = [path.replace("\\", "/").replace("C:", "/mnt/c") for path in all_images]

In [15]:
for i, image in tqdm(
    enumerate(linux_paths), total=len(linux_paths), desc="Indexing images"
):
    if len(image_collection.get(ids=all_images[i])["ids"]) > 0:
        continue
    image_embeddings = get_image_embeddings(linux_paths[i])
    image_collection.upsert(ids=[all_images[i]], embeddings=image_embeddings)
    # ocr_text = apply_OCR(linux_paths[i])
    # text_embeddings = get_text_embeddings(ocr_text)
    # if text_embeddings is not None:
    #     text_collection.upsert(ids=[all_images[i]], embeddings=image_embeddings)

# remove data in collections but not in the file system
for i, id in tqdm(
    enumerate(image_collection.get()["ids"]),
    total=len(image_collection.get()["ids"]),
    desc="Cleaning up database",
):
    if id not in all_images:
        image_collection.delete(ids=[id])
    try:
        image_collection.delete(ids=[id])
    except:
        pass

Indexing images:  30%|███       | 31/102 [00:05<00:13,  5.25it/s]


KeyboardInterrupt: 

In [None]:
# print model nparams
print(f"Model parameters: {sum(p.numel() for p in model.parameters())}")
# model size in memory
print(f"memory consumption ")

Model parameters: 149620737


In [None]:
image = get_image_embeddings("images/fig_accuracy_latency.png")
text = get_text_embeddings("group of men eating")

cosine_distances([text], [image])

array([[0.86681233]])

In [None]:
# import cosine similarity
from sklearn.metrics.pairwise import cosine_distances, cosine_similarity
empty = get_text_embeddings("")
test = get_text_embeddings("How do you turn this on")

print(cosine_similarity([empty], [test]))
print(cosine_distances([empty], [test]))

[[0.89013494]]
[[0.10986506]]


In [None]:
def search(text):
    text_embedding = get_text_embeddings(text)
    results = collection.query(text_embedding, n_results=5)
    distances = results["distances"][0]
    paths = results["ids"][0]
    return paths, distances

In [None]:
search("robotic llama")

NameError: name 'collection' is not defined

In [None]:
from flask import Flask, request, jsonify

app = Flask(__name__)


@app.route("/search", methods=["POST"])
def search_route():
    text = request.json.get("text", "")
    paths, distances = search(text)
    print(distances)
    return jsonify(paths)


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000)

 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:5000
 * Running on http://172.25.97.13:5000
[33mPress CTRL+C to quit[0m
127.0.0.1 - - [31/May/2024 12:58:18] "POST /search HTTP/1.1" 200 -


[0.7487562894821167, 0.7599424123764038, 0.7623105645179749, 0.7630934715270996, 0.7659106850624084]


127.0.0.1 - - [31/May/2024 12:58:26] "POST /search HTTP/1.1" 200 -


[0.7218299508094788, 0.7218299508094788, 0.7234379053115845, 0.7277137041091919, 0.7313181161880493]


127.0.0.1 - - [31/May/2024 12:58:34] "POST /search HTTP/1.1" 200 -


[0.7028124928474426, 0.728251576423645, 0.7285056114196777, 0.7296743392944336, 0.7304768562316895]
