In [1]:
import torch
import pandas as pd
import sys
import os
import mlflow
import mlflow.pytorch
from datetime import datetime
from collections import defaultdict

In [2]:
sys.path.append('../src')
from embedding import initialize_clip_model, generate_embedding
from retrieval import hybrid_retrieval, PostgresVectorRetrieval, TextSearchRetrieval, FaissVectorRetrieval

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
device = "cuda" if torch.cuda.is_available() else "mps" if torch.mps.is_available() else "cpu"
print(f"Device is {device}")

Device is mps


In [4]:
df = pd.read_csv('benchmark_query.csv')
df.head()

Unnamed: 0,Pid,Name,Category,Basic_query,Attribute_query,Natural_query
0,230765.156074.8EA6270009853D26.9B6C065E0E0C70F...,"TEMU 3pcs Cat Toy Set, Hemp Rope And Feather M...",Animals and Pet Supplies,cat toy,3pcs cat toy set,3pcs Cat Toy Set with fast delivery
1,230765.156074.8EA6270009853D26.5997EEBEEF688B9...,TEMU High-density Aquarium Biochemical Filter ...,Animals and Pet Supplies,filter sponge,high density filter sponge,TEMU High density Aquarium Biochemical for bus...
2,159496.2.E9BF3C1C3B82E113.98C4E18825C4C542.737...,Life Extension Florassist Daily Bowel Regulari...,Animals and Pet Supplies,bowel capsules,vegetarian bowel capsules,Daily Bowel capsule for pets
3,159496.2.E9BF3C1C3B82E113.57057A47B3BB8BF4.088...,Carlson Co-Q10 100 mg - 60 Softgels,Animals and Pet Supplies,CoQ10 capsules,CoQ10 gel capsules,Best CoQ10 gel capsules
4,159496.2.E9BF3C1C3B82E113.1AC1A04CB4B08FD2.853...,Youtheory Collagen 6000 mg - 290 Tablets,Animals and Pet Supplies,collagen 6000,youtheory collagen 6000 tablets,Youtheory Collagen 6000 mg 290 that works well...


In [5]:
# Models to test
clip_model = "openai/clip-vit-base-patch32"
initialize_clip_model(clip_model)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


(CLIPProcessor:
 - image_processor: CLIPImageProcessor {
   "crop_size": {
     "height": 224,
     "width": 224
   },
   "do_center_crop": true,
   "do_convert_rgb": true,
   "do_normalize": true,
   "do_rescale": true,
   "do_resize": true,
   "image_mean": [
     0.48145466,
     0.4578275,
     0.40821073
   ],
   "image_processor_type": "CLIPImageProcessor",
   "image_std": [
     0.26862954,
     0.26130258,
     0.27577711
   ],
   "resample": 3,
   "rescale_factor": 0.00392156862745098,
   "size": {
     "shortest_edge": 224
   }
 }
 
 - tokenizer: CLIPTokenizerFast(name_or_path='openai/clip-vit-base-patch32', vocab_size=49408, model_max_length=77, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|startoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
 	49406: AddedToken("<|startoftext|>", rstrip=False, lstrip=False, single_wo

In [6]:
# Initialize MLflow
mlflow.set_tracking_uri(uri="http://35.209.59.178:8591")

In [14]:
# Runs a single mlflow experiment
def run_experiment(config, dataset_name=None, model_name=None):
    with mlflow.start_run(run_name=f"{config['name']}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"):
        # Log parameters
        component_weights = {}
        for c, w in zip(config['components'], config['weights']):
            if isinstance(c, PostgresVectorRetrieval):
                component_name = f"postgres_vector_{c.column_name}"
            elif isinstance(c, FaissVectorRetrieval):
                component_name = f"faiss_vector_{c.column_name}"  # Use column_name which contains 'text' or 'image'
            else:
                component_name = type(c).__name__.lower()
            component_weights[component_name] = w
        
        mlflow.log_params({
            "dataset": dataset_name,
            "model": model_name,
            "clip_model": clip_model,
            "top_k": TOP_K,
            **component_weights
        })
        
        # Initialize counters
        query_types = ['Basic_query', 'Attribute_query', 'Natural_query']
        results = {
            'overall': {'hits': 0, 'total': 0},
            'Basic_query': {'hits': 0, 'total': 0},
            'Attribute_query': {'hits': 0, 'total': 0},
            'Natural_query': {'hits': 0, 'total': 0}
        }
        
        category_results = defaultdict(lambda: {
            'Basic_query': {'hits': 0, 'total': 0},
            'Attribute_query': {'hits': 0, 'total': 0},
            'Natural_query': {'hits': 0, 'total': 0}
        })

        # Process each query type
        for query_type in query_types:
            print(f"\nProcessing {query_type}...")
            print("-" * 80)
            
            for _, row in df.iterrows():
                query = row[query_type]
                target_pid = row['Pid']
                category = row['Category']
                
                # Generate embedding and run hybrid search
                query_embedding = generate_embedding(query_text=query)
                pids, _ = hybrid_retrieval(
                    query=query,
                    query_embedding=query_embedding,
                    components=config['components'],
                    weights=config['weights'],
                    top_k=TOP_K
                )
                
                # Check if the ground truth Pid is in the results
                hit = target_pid in pids
                
                # Update counters
                results[query_type]['total'] += 1
                results['overall']['total'] += 1
                category_results[category][query_type]['total'] += 1
                
                if hit:
                    results[query_type]['hits'] += 1
                    results['overall']['hits'] += 1
                    category_results[category][query_type]['hits'] += 1
        
        # Calculate and log metrics
        for category in results:
            if results[category]['total'] > 0:
                recall = results[category]['hits'] / results[category]['total']
                print(f"\n{category} Recall@{TOP_K}: {recall:.2f}")
                mlflow.log_metric(f"{category}_recall_at_k", round(recall, 2))
                mlflow.log_metric(f"{category}_total_queries", results[category]['total'])
                mlflow.log_metric(f"{category}_total_hits", results[category]['hits'])

        # Log category-specific results
        print("\nCategory-specific results:")
        print("=" * 80)
        for category in category_results:
            print(f"\nCategory: {category}")
            for query_type in query_types:
                if category_results[category][query_type]['total'] > 0:
                    recall = category_results[category][query_type]['hits'] / category_results[category][query_type]['total']
                    print(f"{query_type} Recall@{TOP_K}: {recall:.2f}")
                    mlflow.log_metric(f"{category}_{query_type}_recall_at_k", round(recall, 2))
                    mlflow.log_metric(f"{category}_{query_type}_total_queries", 
                                    category_results[category][query_type]['total'])
                    mlflow.log_metric(f"{category}_{query_type}_total_hits", 
                                    category_results[category][query_type]['hits'])

In [8]:
# Database configuration
DB_CONFIG = {
    'dbname': os.getenv('PGDATABASE', 'finly'),
    'user': os.getenv('PGUSER', 'postgres'),
    'password': os.getenv('PGPASSWORD', 'postgres'),
    'host': os.getenv('PGHOST', 'localhost'),
    'port': os.getenv('PGPORT', '5432')
}

# Top K for the retrieval and recall@K calculation
TOP_K = 5

# Name and model of dataset for logging
dataset_name = "benchmark_query"
model_name = clip_model

# List of configurations to run
configurations = [
    {
        'name': 'text_search_only',
        'components': [
            TextSearchRetrieval('ts_rank_cd', DB_CONFIG)
        ],
        'weights': [1]
    },
    {
        'name': 'text_embedding_only',
        'components': [
            PostgresVectorRetrieval('text_embedding', DB_CONFIG),
        ],
        'weights': [1]
    },
    {
        'name': 'image_embedding_only',
        'components': [
            PostgresVectorRetrieval('image_embedding', DB_CONFIG),
        ],
        'weights': [1]
    },
    {
        'name': 'text_and_image_embedding',
        'components': [
            PostgresVectorRetrieval('text_embedding', DB_CONFIG),
            PostgresVectorRetrieval('image_embedding', DB_CONFIG),
        ],
        'weights': [0.5, 0.5]
    },
    {
        'name': 'hybrid_pg_vector_ts',
        'components': [
            PostgresVectorRetrieval('text_embedding', DB_CONFIG),
            PostgresVectorRetrieval('image_embedding', DB_CONFIG),
            TextSearchRetrieval('ts_rank_cd', DB_CONFIG)
        ],
        'weights': [0.4, 0.3, 0.3]
    }
]

In [9]:
experiment_name = "postgres_vector_experiments"
mlflow.set_experiment(experiment_name)

for config in configurations:
    print(f"\nRunning experiment: {config['name']}")
    print("=" * 80)
    run_experiment(config, dataset_name=dataset_name, model_name=model_name)

2025/05/08 16:13:40 INFO mlflow.tracking.fluent: Experiment with name 'postgres_vector_experiments' does not exist. Creating a new experiment.



Running experiment: text_search_only

Processing Basic_query...
--------------------------------------------------------------------------------

Processing Attribute_query...
--------------------------------------------------------------------------------

Processing Natural_query...
--------------------------------------------------------------------------------

overall Recall@5: 0.40

Basic_query Recall@5: 0.52

Attribute_query Recall@5: 0.56

Natural_query Recall@5: 0.11

Category-specific results:

Category: Animals and Pet Supplies
Basic_query Recall@5: 0.50
Attribute_query Recall@5: 0.50
Natural_query Recall@5: 0.10

Category: Apparel and Accessories
Basic_query Recall@5: 0.30
Attribute_query Recall@5: 0.90
Natural_query Recall@5: 0.20

Category: Arts and Entertainment
Basic_query Recall@5: 0.60
Attribute_query Recall@5: 0.80
Natural_query Recall@5: 0.30

Category: Business and Industrial
Basic_query Recall@5: 0.80
Attribute_query Recall@5: 0.80
Natural_query Recall@5: 0.20

C

In [12]:
# List of configurations to run
configurations = [
    {
        'name': 'text_embedding_only',
        'components': [
            FaissVectorRetrieval('text'),
        ],
        'weights': [1]
    },
    {
        'name': 'image_embedding_only',
        'components': [
            FaissVectorRetrieval('image'),
        ],
        'weights': [1]
    },
    {
        'name': 'text_and_image_embedding',
        'components': [
            FaissVectorRetrieval('text'),
            FaissVectorRetrieval('image'),
        ],
        'weights': [0.5, 0.5]
    },
    {
        'name': 'hybrid_pg_vector_ts',
        'components': [
            FaissVectorRetrieval('text'),
            FaissVectorRetrieval('image'),
            TextSearchRetrieval('ts_rank_cd', DB_CONFIG)
        ],
        'weights': [0.4, 0.3, 0.3]
    }
]

In [15]:
experiment_name = "faiss_vector_experiments"
mlflow.set_experiment(experiment_name)

for config in configurations:
    print(f"\nRunning experiment: {config['name']}")
    print("=" * 80)
    run_experiment(config, dataset_name=dataset_name, model_name=model_name)


Running experiment: text_embedding_only

Processing Basic_query...
--------------------------------------------------------------------------------

Processing Attribute_query...
--------------------------------------------------------------------------------

Processing Natural_query...
--------------------------------------------------------------------------------

overall Recall@5: 0.14

Basic_query Recall@5: 0.00

Attribute_query Recall@5: 0.23

Natural_query Recall@5: 0.18

Category-specific results:

Category: Animals and Pet Supplies
Basic_query Recall@5: 0.00
Attribute_query Recall@5: 0.10
Natural_query Recall@5: 0.00

Category: Apparel and Accessories
Basic_query Recall@5: 0.00
Attribute_query Recall@5: 0.10
Natural_query Recall@5: 0.10

Category: Arts and Entertainment
Basic_query Recall@5: 0.00
Attribute_query Recall@5: 0.10
Natural_query Recall@5: 0.00

Category: Business and Industrial
Basic_query Recall@5: 0.00
Attribute_query Recall@5: 0.10
Natural_query Recall@5: 0.00