In [2]:
import pandas as pd
import numpy as np
import os

from collections import Counter

import chromadb
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
from chromadb.utils.data_loaders import ImageLoader

In [3]:
data = pd.read_csv("datasets/train.csv")

# Filter to only include item_weight entity

# Extract numeric value and unit from entity_value
def split_value(value):
    value = value.strip()
    parts = value.strip().split(' ')
    if len(parts) == 2:
        try: 
            float(parts[0])
        except:
            return None, None
        return float(parts[0]), parts[1]
    return None, None

data[['numeric_value', 'unit']] = data['entity_value'].apply(split_value).apply(pd.Series)


# Drop rows with missing values
data.dropna(inplace=True)

sampled_data = data.groupby('unit').apply(lambda x: x.sample(min(len(x), 100))).reset_index(drop=True)
# Save processed data to a new CSV file
sampled_data.to_csv('vector.csv', index=True)

print("Data preprocessing complete, dataset ready to add to vectorDB")

Data preprocessing complete, dataset ready to add to vectorDB


  sampled_data = data.groupby('unit').apply(lambda x: x.sample(min(len(x), 100))).reset_index(drop=True)


In [4]:
from src.training.height.utils_height import download_images
download_images(sampled_data['image_link'].tolist(), sampled_data.index.tolist(), "datasets/vector_images")

100%|██████████| 2343/2343 [00:00<00:00, 42264.12it/s]


In [5]:
client = chromadb.Client()
embedding_function = OpenCLIPEmbeddingFunction()
data_loader = ImageLoader()

  from .autonotebook import tqdm as notebook_tqdm


In [30]:
collection = client.get_or_create_collection(
    name="products",
    embedding_function=embedding_function,
    data_loader=data_loader
)

In [32]:
vector_data = []
for idx, row in sampled_data.iterrows():
    image_location = os.path.join("datasets/vector_images", f"{idx}.jpg")
    metadata = {
        "unit": row["unit"],
        "entity_name": row["entity_name"],
        "numeric_value": row["numeric_value"]
    }
    vector_data.append({
        "image_location": image_location,
        "metadata": metadata
    })

In [33]:
vector_data[1]

{'image_location': 'datasets/vector_images/1.jpg',
 'metadata': {'unit': 'candela',
  'entity_name': 'item_weight',
  'numeric_value': 4.0}}

In [34]:
for idx, val in enumerate(vector_data):
    if idx % 100 == 0 and idx != 0:
        print("At", idx)
    collection.add(
        ids = [str(idx)],
        uris=[val['image_location']],
        metadatas=[val['metadata']]
    )

At 100
At 200
At 300
At 400
At 500
At 600
At 700
At 800
At 900
At 1000
At 1100
At 1200
At 1300
At 1400
At 1500
At 1600
At 1700
At 1800
At 1900
At 2000
At 2100
At 2200
At 2300


In [42]:
def query_and_find_most_frequent(img_path, entity_name):
    # query_text = remove_numbers(query_text)
    results = collection.query(query_uris=[img_path],
                               where={"entity_name": entity_name},
                               n_results=12)
    # print(results) 
    string_values = [result['unit'] for result in results['metadatas'][0]]
    
    string_counts = Counter(string_values)
    # print(string_values)
    # print(string_counts.most_common(1)) 
    most_frequent_string = string_counts.most_common(1)[0][0]
    
    return most_frequent_string

In [43]:
query_data = pd.read_csv("src/training/height/data/height_with_ocr.csv")
query_data_id = query_data["id"]
query_data_unit = query_data["unit"]
entity_name = "height"

In [44]:
def process_item(id, label, entity_name, idx):
    path = "src/training/height/dataset/" + str(id) + ".jpg"
    res = query_and_find_most_frequent(path, entity_name)
    return idx, res == label

# Initialize variables
total_count = len(query_data_id)
count = 0

# Process each item in a single thread
for idx, (label, id) in enumerate(zip(query_data_unit, query_data_id)):
    path = "src/training/height/dataset/" + str(id) + ".jpg"
    res = query_and_find_most_frequent(path, entity_name)
    if idx % 500 == 0 and idx != 0:
        print("At", idx, "Accuracy:", str(count / (idx+1)))
    if res == label:
        count += 1

# Print the final accuracy
print(count / total_count)

At 500 Accuracy: 0.5169660678642715
At 1000 Accuracy: 0.4915084915084915


KeyboardInterrupt: 