In [1]:
import os

lst_images = os.listdir('../images')
lst_images

['image2.png',
 'image0.png',
 'image3.png',
 'image6.jpeg',
 'image5.jpg',
 'image4.jpg',
 'image1.jpg']

In [2]:
# use chromadb for the same
import chromadb

client = chromadb.PersistentClient('./db')
client.list_collections()

[Collection(name=clip_embeddings)]

In [29]:
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
from chromadb.utils.data_loaders import ImageLoader
data_loader = ImageLoader()


In [None]:
embedding_function = OpenCLIPEmbeddingFunction('ViT-B-16-SigLIP', 'webli')

# device='cuda' for GPU


In [32]:
client.delete_collection('clip_embeddings')

In [33]:
collection = client.get_or_create_collection(name='clip_embeddings', embedding_function=embedding_function, data_loader=data_loader)

In [34]:
from PIL import Image
import numpy as np

collection.add(
    ids=lst_images,
    uris=[os.path.join('../images', img) for img in lst_images],
    metadatas=[{'image': img} for img in lst_images],
)

In [37]:
results = collection.query(
    query_texts=["a photo of staircase"],
    n_results=5,
)
results

{'ids': [['image0.png',
   'image1.jpg',
   'image5.jpg',
   'image4.jpg',
   'image6.jpeg']],
 'embeddings': None,
 'documents': [[None, None, None, None, None]],
 'uris': None,
 'data': None,
 'metadatas': [[{'image': 'image0.png'},
   {'image': 'image1.jpg'},
   {'image': 'image5.jpg'},
   {'image': 'image4.jpg'},
   {'image': 'image6.jpeg'}]],
 'distances': [[1.801679647785668,
   1.966377784601841,
   1.974831337251781,
   1.982125198556581,
   1.9870082452673323]],
 'included': [<IncludeEnum.distances: 'distances'>,
  <IncludeEnum.documents: 'documents'>,
  <IncludeEnum.metadatas: 'metadatas'>]}

In [41]:
results['ids'][0]
results['metadatas'][0]

[{'image': 'image0.png'},
 {'image': 'image1.jpg'},
 {'image': 'image5.jpg'},
 {'image': 'image4.jpg'},
 {'image': 'image6.jpeg'}]

In [None]:
# update old image with a new image

collection.update(
    ids=lst_images[0],
    uri=os.path.join('../images', 'staircase.jpg'),
    metadata={'image': 'staircase.jpg'}
)

In [44]:
import chromadb
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
from chromadb.utils.data_loaders import ImageLoader

# format
# {
#     '12_34_180': {
#         'image_path': '12_34_180.jpg',
#         'label': '12_34_180'
#     },
#     '12_34_270': {
#         'image_path': '12_34_270.jpg',
#         'label': '12_34_270'
#     }
# }

db_client = chromadb.PersistentClient('./db')

def store_images(data, collection_name='embeddings'):    
    collection = db_client.get_or_create_collection(name=collection_name, embedding_function=embedding_function, data_loader=ImageLoader())
    collection.add(
        ids=data.keys(),        # list of x_y_yaw strings
        uris=[node['image_path'] for node in data],
        metadatas=[data[node] for node in data]
    )
    
def update_images(new_data, collection_name='embeddings'):
    collection = db_client.get_or_create_collection(name=collection_name, embedding_function=embedding_function, data_loader=ImageLoader())
    for node in new_data:
        collection.update(
            ids=node,
            uri=new_data[node]['image_path'],
            metadata=new_data[node]
        )
        
def query_images(query_text, n_results=20, collection_name='embeddings'):
    collection = db_client.get_or_create_collection(name=collection_name, embedding_function=embedding_function, data_loader=ImageLoader())
    results = collection.query(
        query_texts=[query_text],
        n_results=n_results,
    )
    
    return results

In [48]:
import asyncio
import aiohttp
import pandas as pd
import json
from tqdm.asyncio import tqdm_asyncio
import os
import gc
import requests


OPENAI_API_KEY = "EMPTY"
OPENAI_API_BASE = "http://0.0.0.0:8000/v1"
MODEL_NAME = "allenai/Molmo-7B-D-0924"
CONCURRENT_REQUESTS = 10  # Number of concurrent API requests
BATCH_SIZE = 10  # Number of images to process in a batch

In [56]:
from icecream import ic

In [71]:
def ask_text_query(text_prompt, timeout=5):
    # For 1 text query, return the output of the VLM
    # Send the prompt to the API and get the results
    try:
        payload = {
            "model": MODEL_NAME,
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "text", "text": text_prompt},
                ],
            }],
        }

        headers = {
            "Authorization": f"Bearer {OPENAI_API_KEY}",
            "Content-Type": "application/json",
        }

        # Set a specific timeout for the request
        response = requests.post(f"{OPENAI_API_BASE}/chat/completions", json=payload, headers=headers, timeout=timeout)
        data = response.json()
        
        ic(response.status_code)
        ic(data)
        
        vlm_output = data['choices'][0]['message']['content']
    
    except Exception as e:
        vlm_output = f"Exception: {str(e)}"

    return vlm_output

In [73]:
ask_text_query("hi")

ic| response.status_code: 200
ic| data: {'choices': [{'finish_reason': 'stop',
                        'index': 0,
                        'logprobs': None,
                        'message': {'content': " Hello! I'm here to assist you with any "
                                               'questions or tasks you might have. What '
                                               'can I help you with today?',
                                    'role': 'assistant',
                                    'tool_calls': []},
                        'stop_reason': None}],
           'created': 1733258623,
           'id': 'chatcmpl-234ab4e781d14ff3ba2cd96744915530',
           'model': 'allenai/Molmo-7B-D-0924',
           'object': 'chat.completion',
           'prompt_logprobs': None,
           'usage': {'completion_tokens': 26,
                     'prompt_tokens': 6,
                     'prompt_tokens_details': None,
                     'total_tokens': 32}}


" Hello! I'm here to assist you with any questions or tasks you might have. What can I help you with today?"

In [None]:
from .vlm import run_multiple_image_query
import subprocess
import json

def multiple_image_query(prompt, image_dir):
    results = asyncio.run(run_multiple_image_query(image_dir, prompt))
    return results

multiple_image_query("count no of cardboard cartons or boxes if any", '../images/')

[' There are none.',
 ' Counting the <points x1="71.5" y1="91.3" x2="74.6" y2="77.6" x3="80.0" y3="78.1" x4="81.0" y4="89.5" x5="90.5" y5="87.3" alt="no of cardboard cartons or boxes if any">no of cardboard cartons or boxes if any</points> shows a total of 5.',
 ' There are none.',
 ' Counting the <points x1="46.8" y1="32.9" x2="47.2" y2="36.4" x3="47.3" y3="30.6" x4="47.6" y4="28.4" x5="49.2" y5="29.6" x6="49.2" y6="40.5" x7="49.8" y7="35.9" x8="51.2" y8="41.4" x9="52.0" y9="36.2" x10="52.4" y10="41.1" x11="53.3" y11="36.2" x12="53.8" y12="33.6" x13="56.6" y13="39.6" x14="58.4" y14="43.8" x15="58.8" y15="39.6" x16="59.6" y16="44.8" x17="59.9" y17="40.5" x18="60.7" y18="45.3" x19="61.3" y19="40.2" x20="62.2" y20="45.0" alt="no of cardboard cartons or boxes if any">no of cardboard cartons or boxes if any</points> shows a total of 20.',
 'Timeout Error: Request took longer',
 'Timeout Error: Request took longer',
 'Timeout Error: Request took longer']

In [None]:
import re

def extract_points(text):
    # Parse the <points> tag and extract relevant data
    match = re.search(r'<points([^>]*)>(.*?)</points>', text)
    if not match:
        return None

    attributes = match.group(1)
    main_message = match.group(2)

    # Extract the coordinates
    x_coords = []
    y_coords = []
    alt_message = None

    # Parse the attributes of the points tag
    for attr in attributes.split():
        if attr.startswith('x'):
            x_coords.append(float(attr.split('=')[1].strip('"')))
        elif attr.startswith('y'):
            y_coords.append(float(attr.split('=')[1].strip('"')))
        elif attr.startswith('alt'):
            alt_message = re.search(r'alt="([^"]*)"', attributes).group(1)

    return {
        "x_coordinates": x_coords,
        "y_coordinates": y_coords,
        "alt_message": alt_message,
        "main_message": main_message
    }

In [None]:
# Example usage
input_text = 'so this is <points x1="14.4" y1="63.5" x2="44.7" y2="31.4" x3="44.9" y3="70.2" x4="58.4" y4="68.7" x5="63.2" y5="29.9" x6="86.3" y6="59.8" x7="94.4" y7="84.1" x8="98.0" y8="93.6" alt="all main objects in the image">all main objects in the image</points> and that is the answer'

result = extract_points(input_text)
print(result)