In [1]:
import os

import duckdb
from dotenv import load_dotenv
import pandas as pd
from lotus.dtype_extensions import ImageArray
from lotus.types import CascadeArgs, ProxyModel

from join_optimizer.lotus.evaluate import evaluate_filter

load_dotenv()

LOAD_INDEX = True

OFF_DATASET_DIR = os.getenv("OFF_DATASET_DIR")

OFF_PARQUET = os.path.join(OFF_DATASET_DIR, "products.parquet")
OFF_IMAGES_DIR = os.path.join(OFF_DATASET_DIR, "images")
DATASET_CAPTION_DB_BLIP = os.path.join(OFF_DATASET_DIR, "off_uk_top2000_with_images_caps_blip-image-captioning-large.db")
DATASET_CAPTION_DB_INSTRUCTBLIP = os.path.join(OFF_DATASET_DIR, "off_uk_top2000_with_images_caps_instructblip-flan-t5-xl.db")

sample_size_percentage = 1
seed = 80
df = duckdb.query(f"""
    SELECT *
    FROM parquet_scan('{OFF_PARQUET}')

    USING SAMPLE {sample_size_percentage} PERCENT (reservoir, {seed})
    ORDER BY code ASC

""").to_df()

df["image"] = ImageArray(df["code"].apply(lambda i: os.path.join(OFF_IMAGES_DIR, f"{str(i)}.jpg")))
df["image_url"] = ImageArray(df["image_front_url"]
                             # .apply(lambda i: i.replace('.400.', '.full.'))
                             )


  from tqdm.autonotebook import tqdm, trange


#### Creating the index

In [2]:
from lotus.fts_store.db_fts_store import SQLiteFTSStore
from lotus.vector_store import FaissVS
import lotus
from lotus.models import LM, SentenceTransformersRM

gpt_5_nano = LM("gpt-5-nano")
gpt_5_mini = LM("gpt-5-mini")

gpt_5_____sure = LM("gpt-5")

gpt_4o_mini = LM("gpt-4o-mini")
gpt_4o = LM("gpt-4o")

# CLIP embedding model – works for both text & image
# rm  = SentenceTransformersRM(model="clip-ViT-B-32")
rm = SentenceTransformersRM(model="clip-ViT-L-14", max_batch_size=32)

lotus.settings.configure(lm=gpt_5_mini, helper_lm=gpt_5_nano, rm=rm, vs=FaissVS(), cs=SQLiteFTSStore())

2025-08-28 17:15:17,195 - INFO - Load pretrained SentenceTransformer: clip-ViT-L-14
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [41]:
if not LOAD_INDEX:
    df = df.sem_index("image", index_dir=f"{OFF_DATASET_DIR}/image{sample_size_percentage}_index")



In [42]:
df = df.load_sem_index("image", index_dir=f"{OFF_DATASET_DIR}/image{sample_size_percentage}_index")
df = df.load_sem_index("image_url", index_dir=f"{OFF_DATASET_DIR}/image{sample_size_percentage}_index")

In [43]:
df = df.sample(n=1000, random_state=seed)


# Prompt

In [44]:
prompt = "dark chocolate"

# Full LLM calls

In [45]:
df_resd_llm = df.sem_filter(prompt, col_li=["image"], return_stats=False)

Filtering: 100%|██████████ 1000/1000 LM calls [01:32<00:00, 10.83it/s]


# Binary search filter

In [36]:

cascade_args = CascadeArgs(
    recall_target=0.9,
    precision_target=0.9,
    sampling_percentage=0.1,
    proxy_model=ProxyModel.EMBEDDING_MODEL,
)

df_resd_binary_s = df.sem_filter(prompt, col_li=["image"], cascade_args=cascade_args,
                                              return_stats=True, find_top_k=True)


Filtering: 100%|██████████ 1/1 LM calls [00:04<00:00,  4.59s/it]
Filtering: 100%|██████████ 1/1 LM calls [00:08<00:00,  8.49s/it]
Filtering: 100%|██████████ 1/1 LM calls [00:07<00:00,  7.73s/it]
Filtering: 100%|██████████ 1/1 LM calls [00:04<00:00,  4.54s/it]
Filtering: 100%|██████████ 1/1 LM calls [00:08<00:00,  8.88s/it]
Filtering: 100%|██████████ 1/1 LM calls [00:06<00:00,  6.42s/it]
Filtering: 100%|██████████ 1/1 LM calls [00:05<00:00,  5.90s/it]
Filtering: 100%|██████████ 1/1 LM calls [00:05<00:00,  5.04s/it]
Filtering: 100%|██████████ 1/1 LM calls [00:08<00:00,  8.31s/it]
Filtering: 100%|██████████ 1/1 LM calls [00:04<00:00,  4.73s/it]
Filtering: 100%|██████████ 1/1 LM calls [00:08<00:00,  8.97s/it]
Filtering: 100%|██████████ 1/1 LM calls [00:05<00:00,  5.43s/it]


In [37]:
metrics, FP, FN = evaluate_filter(
    dataset_df=df_resd_llm,
    filtered_df=df_resd_binary_s,
    id_column='code'
)
print(metrics)

{'TP': 20, 'FP': 1, 'FN': 16, 'precision': 0.9523809523809523, 'recall': 0.5555555555555556, 'f1': 0.7017543859649122}


# Sampling

In [22]:
cascade_args = CascadeArgs(
    recall_target=0.95,
    precision_target=0.9,
    sampling_percentage=0.1,
    proxy_model=ProxyModel.EMBEDDING_MODEL,
    cascade_IS_weight=1,
    cascade_num_calibration_quantiles=100,
    failure_probability=0.1,
    cascade_IS_random_seed=114,
)
df_resd_lotus = df.sem_filter(prompt, col_li=["image"], cascade_args=cascade_args, return_stats=False,
                                           find_top_k=False)


Running oracle for threshold learning: 100%|██████████ 100/100 LM calls [00:18<00:00,  5.51it/s]
2025-08-26 19:23:06,065 - INFO - Sample recall: 1.0
2025-08-26 19:23:06,065 - INFO - Sample precision: 1.0
2025-08-26 19:23:06,066 - INFO - Learned cascade thresholds: (0.2213810384273529, 0.145962193608284)
2025-08-26 19:23:06,066 - INFO - Num routed to smaller model: 845
Running predicate evals with oracle LM: 100%|██████████ 155/155 LM calls [00:19<00:00,  8.06it/s]


In [23]:
metrics, FP, FN = evaluate_filter(
    dataset_df=df_resd_llm,
    filtered_df=df_resd_lotus,
    id_column='code'
)
print(metrics)

{'TP': 33, 'FP': 1, 'FN': 3, 'precision': 0.9705882352941176, 'recall': 0.9166666666666666, 'f1': 0.9428571428571428}


# Caption Search BLIP


In [24]:
df = df.sem_captions_index.attach_index("image", index_dir=DATASET_CAPTION_DB_BLIP)
df = df.sem_captions_index.load("image")
df_resd_blip = df.sem_captions_index.search(prompt, "image")

In [25]:
metrics, FP, FN = evaluate_filter(
    dataset_df=df_resd_llm,
    filtered_df=df_resd_blip,
    id_column='code'
)
print(metrics)


{'TP': 10, 'FP': 1, 'FN': 26, 'precision': 0.9090909090909091, 'recall': 0.2777777777777778, 'f1': 0.4255319148936171}


### with prompt augmentation

In [26]:
def augment_prompt(prompt, augmentation_prompt):
    prompt_as_df = pd.DataFrame({"query": [prompt]})
    return prompt_as_df.sem_map(augmentation_prompt, suffix="augmented_prompt")["augmented_prompt"][0].replace("'", " ").replace("-", " ")


In [27]:
prompt_augmentation_prompt = "you will receive a {query} to do a full text search filter on a dataset. since the search is sintactical, provide 10 other prompts similar to the one provided, so that similar items can be obtained. Separate the results with a simple space and without delimiters like \" or \«. only respond with the result."

prompt_augmentation_prompt = """
You will receive a plain-language search {query} and must return a SINGLE valid SQLite FTS5 MATCH expression (the right-hand side of `... MATCH <expr>`). Return ONLY the expression, with no quotes around the whole thing, no SQL, no code fences, and no explanations.




REQUIREMENTS
1) Output must be valid FTS5 boolean syntax using ONLY: parentheses `()`, `AND`, `OR`, `NOT`, double-quoted phrases, and (optionally) `NEAR` when allowed below. Do NOT use field qualifiers, weights, or other SQL.
2) Group synonyms/near-lexicon with OR inside parentheses. Use AND between concept buckets.
   - Example shape: `(concept1_a OR concept1_b OR "concept1 phrase") AND (concept2_a OR concept2_b) ...`
3) Expand the user_query into 2–5 concept buckets (meaningful facets like style, color/tone, item types, descriptors, etc.). Inside each bucket, include common synonyms, close lexical variants, and singular/plural irregulars. The database already handles case, diacritics, and stemming—only add explicit variants when helpful (e.g., "tuxedo OR tuxedos", "black-tie OR \"black tie\"").
4) If must_include is provided, ensure each term/phrase is present by adding extra AND groups for them (quoted as needed).
5) If exclude is provided, append `AND NOT (...)` with OR-joined terms/phrases to filter them out.
6) Phrases must use double quotes (e.g., "black tie"). Do NOT wrap the entire output in quotes.
7) Avoid `*` wildcards unless the input explicitly asks for prefix search.
8) Proximity:
   - If require_proximity = true, use NEAR **only in properly nested binary form** and at most to tie TWO buckets: `((bucketA) NEAR (bucketB)) AND (bucketC) ...`. Never chain `A NEAR B NEAR C` without nesting.
   - If require_proximity = false (default), do NOT use NEAR.
9) Keep the expression concise (<1000 characters).

OUTPUT
- Only the MATCH expression
"""

augmented_prompt = augment_prompt(prompt, prompt_augmentation_prompt)
print(augmented_prompt)

Mapping: 100%|██████████ 1/1 LM calls [00:11<00:00, 11.94s/it]

(dark OR "darkest" OR bitter OR bittersweet OR "semi sweet" OR semisweet OR "extra dark" OR intense) AND (chocolate OR chocolates OR "chocolate bar" OR "chocolate bars" OR cocoa OR cacao OR truffle OR bonbon)





In [28]:
df_resd_blip_augmented = df.sem_captions_index.search(augmented_prompt, "image")


In [29]:
metrics, FP, FN = evaluate_filter(
    dataset_df=df_resd_llm,
    filtered_df=df_resd_blip,
    id_column='code'
)
print(metrics)


{'TP': 10, 'FP': 1, 'FN': 26, 'precision': 0.9090909090909091, 'recall': 0.2777777777777778, 'f1': 0.4255319148936171}


# Caption Search INSTRUCTBLIP

In [48]:
df = df.sem_captions_index.attach_index("image", index_dir=DATASET_CAPTION_DB_INSTRUCTBLIP)
df = df.sem_captions_index.load("image")
df_resd_instructblip = df.sem_captions_index.search(prompt, "image")

In [49]:
metrics, FP, FN = evaluate_filter(
    dataset_df=df_resd_llm,
    filtered_df=df_resd_instructblip,
    id_column='code'
)
print(metrics)


{'TP': 19, 'FP': 1, 'FN': 15, 'precision': 0.95, 'recall': 0.5588235294117647, 'f1': 0.7037037037037037}


### Augmented

In [32]:
df_resd_instructblip_augmented = df.sem_captions_index.search(augmented_prompt, "image")

In [33]:
metrics, FP, FN = evaluate_filter(
    dataset_df=df_resd_llm,
    filtered_df=df_resd_instructblip_augmented,
    id_column='code'
)
print(metrics)


{'TP': 20, 'FP': 1, 'FN': 16, 'precision': 0.9523809523809523, 'recall': 0.5555555555555556, 'f1': 0.7017543859649122}


### Pass through small model

In [34]:
df_resd_captions_small_model = df.sem_filter(
    prompt, col_li=["image_cap"])

Filtering:   0%|           0/1000 LM calls [00:00<?, ?it/s]2025-08-26 19:24:28,180 - INFO - Retrying request to /chat/completions in 0.450501 seconds
Filtering: 100%|██████████ 1000/1000 LM calls [01:07<00:00, 14.85it/s]


In [38]:
metrics, FP, FN = evaluate_filter(
    dataset_df=df_resd_llm,
    filtered_df=df_resd_captions_small_model,
    id_column='code'
)
print(metrics)

{'TP': 22, 'FP': 0, 'FN': 14, 'precision': 1.0, 'recall': 0.6111111111111112, 'f1': 0.7586206896551725}

=== Usage Statistics ===
Virtual  = Total usage if no caching was used
Physical = Actual usage with caching applied

Virtual Cost:     $0.822935
Physical Cost:    $0.822935
Virtual Tokens:   1,451,181
Physical Tokens:  1,451,181
Cache Hits:       0

