# Setup and test BQ connection

## Prerequisites
- Enable the Vertex AI API.
- Create a Vertex AI connection object in BigQuery (here `graphite-cell-472319-d2.US.vertex-ai`).
- Create a remote Gemini embedding model that uses that Vertex AI connection:
```sql 
CREATE OR REPLACE MODEL `graphite-cell-472319-d2.amazon_esci.gemini_embedding`
REMOTE WITH CONNECTION `graphite-cell-472319-d2.US.vertex-ai`
OPTIONS (ENDPOINT = 'gemini-embedding-001');
```

In [None]:
# Shared Google Cloud clients used by later cells; both are bound to the same GCP project.
from google.cloud import bigquery, storage

_GCP_PROJECT = 'graphite-cell-472319-d2'

# Cloud Storage client
storage_client = storage.Client(project=_GCP_PROJECT)

# BigQuery client
bigquery_client = bigquery.Client(project=_GCP_PROJECT)


# Load dataset and upload to BQ tables

In [None]:
# Install deps
import kagglehub
from kagglehub import KaggleDatasetAdapter
import pandas as pd
from google.cloud import bigquery

# ---- CONFIGURE THESE ----
PROJECT_ID = "graphite-cell-472319-d2"   # <-- replace with your own GCP project ID
DATASET_ID = "amazon_esci"               # BigQuery dataset name
# -------------------------

# Relative paths of the three files inside the Kaggle dataset, exactly as published
examples_fp = "shopping_queries_dataset/shopping_queries_dataset_examples.parquet"
products_fp = "shopping_queries_dataset/shopping_queries_dataset_products.parquet"
sources_fp = "shopping_queries_dataset/shopping_queries_dataset_sources.csv"

# Read each file into a pandas DataFrame via kagglehub (loaded in the order listed)
df_examples, df_products, df_sources = (
    kagglehub.load_dataset(KaggleDatasetAdapter.PANDAS, "marquis03/amazon-esci", path)
    for path in (examples_fp, products_fp, sources_fp)
)

# Report the shape of every frame, then peek at the examples table
for label, frame in (
    ("Examples:", df_examples),
    ("Products:", df_products),
    ("Sources:", df_sources),
):
    print(label, frame.shape)

print(df_examples.head())

# Init BigQuery client
client = bigquery.Client(project=PROJECT_ID)

# Create the dataset only if it is missing.
# `exists_ok=True` makes the call idempotent without swallowing unrelated failures
# (permissions, quota, network, wrong project). The previous `except Exception`
# turned any such failure into a misleading create attempt.
dataset_ref = bigquery.Dataset(f"{PROJECT_ID}.{DATASET_ID}")
# Keep this in the same location as the Vertex AI connection (`US.vertex-ai`).
dataset_ref.location = "US"
dataset = client.create_dataset(dataset_ref, exists_ok=True)
print(f"Dataset {DATASET_ID} is ready (location: {dataset.location})")

# Helper that loads a DataFrame into a table of PROJECT_ID.DATASET_ID
def upload_to_bq(df: pd.DataFrame, table_name: str):
    """Load ``df`` into ``PROJECT_ID.DATASET_ID.<table_name>``, replacing existing rows.

    Uses the module-level BigQuery ``client`` and blocks until the load job is done.
    """
    destination = ".".join((PROJECT_ID, DATASET_ID, table_name))
    load_config = bigquery.LoadJobConfig(
        write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
        autodetect=True,
    )
    client.load_table_from_dataframe(df, destination, job_config=load_config).result()
    print(f"Uploaded {len(df)} rows to {destination}")

# Upload the three DataFrames, each into a table named after its source file
for frame, table in (
    (df_examples, "shopping_queries_dataset_examples"),
    (df_products, "shopping_queries_dataset_products"),
    (df_sources, "shopping_queries_dataset_sources"),
):
    upload_to_bq(frame, table)


# Audit Amazon ESCI Dataset

In [None]:
# Sanity check: fetch the first 10 rows of the uploaded examples table.
query = """
    SELECT * FROM `graphite-cell-472319-d2.amazon_esci.shopping_queries_dataset_examples` LIMIT 10
"""
results = bigquery_client.query(query).result().to_dataframe()

print(results)

# Embed all queries with batched inference

In [None]:
from google.cloud import bigquery
from tqdm.notebook import tqdm

def run_batched_embedding(
    project_id: str,
    dataset_id: str,
    model_id: str,
    source_query: str,
    dest_table: str,
    batch_size: int = 10000,
):
    """
    Run batched embedding generation with ML.GENERATE_EMBEDDING in BigQuery.
    
    Args:
        project_id: GCP project ID
        dataset_id: BigQuery dataset ID
        model_id: Full model path for embedding (e.g. "project.dataset.model")
        source_query: SQL query that must output a column named 'content'
        dest_table: Destination table name (without project/dataset)
        batch_size: Number of rows per batch
    """

    client = bigquery.Client(project=project_id)
    dest_full_table = f"{project_id}.{dataset_id}.{dest_table}"

    # Count rows to figure out number of batches
    count_query = f"SELECT COUNT(*) as total FROM ({source_query})"
    total_rows = client.query(count_query).result().to_dataframe()["total"].iloc[0]
    total_batches = (total_rows + batch_size - 1) // batch_size

    print(f"Total rows: {total_rows}, Batch size: {batch_size}, Total batches: {total_batches}")

    for i in tqdm(range(total_batches), desc="Embedding batches"):
        offset = i * batch_size

        # Get i-th batch of queries
        batch_query = f"""
        WITH base AS (
          SELECT content
          FROM ({source_query})
          ORDER BY content
          LIMIT {batch_size} OFFSET {offset}
        ),
        emb AS (
          SELECT *
          FROM ML.GENERATE_EMBEDDING(
            MODEL `{model_id}`,
            (SELECT content FROM base),
            STRUCT(TRUE AS flatten_json_output, 'RETRIEVAL_DOCUMENT' AS task_type)
          )
        )
        SELECT 
          base.content,
          emb.ml_generate_embedding_result,
          emb.ml_generate_embedding_statistics,
          emb.ml_generate_embedding_status
        FROM base
        JOIN emb
        USING (content)
        """

        # Push batch result to the destination table
        job_config = bigquery.QueryJobConfig(
            destination=dest_full_table,
            write_disposition="WRITE_TRUNCATE" if i == 0 else "WRITE_APPEND"
        )

        job = client.query(batch_query, job_config=job_config)
        job.result()  # wait for completion

    print(f"✅ Finished embedding into {dest_full_table}")


In [None]:
# Embed every distinct normalized (lower-cased, trimmed) search query.
query_embedding_source = """
        SELECT DISTINCT LOWER(TRIM(query)) AS content
        FROM `graphite-cell-472319-d2.amazon_esci.shopping_queries_dataset_examples`
    """

run_batched_embedding(
    project_id="graphite-cell-472319-d2",
    dataset_id="amazon_esci",
    model_id="graphite-cell-472319-d2.amazon_esci.gemini_embedding",
    source_query=query_embedding_source,
    dest_table="query_embeddings",
    batch_size=10000,
)


# Select the products to embed

Filter: `small_version = 1 AND product_locale = 'us'`

In [None]:
client = bigquery.Client(project="graphite-cell-472319-d2")

# Build one row per US, small_version = 1 product. Title, bullet points and
# description are concatenated into a single `content` field, with missing parts
# rendered as '...'. The join with the examples table restricts the products to
# those judged in small_version = 1; ANY_VALUE + GROUP BY collapses the repeated
# rows that join produces.
query = """
CREATE OR REPLACE TABLE `graphite-cell-472319-d2.amazon_esci.products_to_embed`
AS
SELECT 
  product_id,
  ANY_VALUE(CONCAT(
      'PRODUCT_TITLE: ', IFNULL(product_title, '...'), '\\n\\n',
      'PRODUCT_BULLET_POINTS: ', IFNULL(product_bullet_point, '...'), '\\n\\n',
      'PRODUCT_DESCRIPTION: ', IFNULL(product_description, '...')
  )) AS content
FROM `graphite-cell-472319-d2.amazon_esci.shopping_queries_dataset_examples`
JOIN `graphite-cell-472319-d2.amazon_esci.shopping_queries_dataset_products`
USING (product_id, product_locale)
WHERE product_locale = 'us'
  AND small_version = 1
GROUP BY product_id
"""

client.query(query).result()  # block until the CREATE TABLE job has finished
print("Table created: graphite-cell-472319-d2.amazon_esci.products_to_embed")

In [None]:
# Inspect a handful of the concatenated product documents.
preview_sql = "SELECT * FROM `graphite-cell-472319-d2.amazon_esci.products_to_embed` LIMIT 5"
df_preview = client.query(preview_sql).to_dataframe()

df_preview

# Embed all products with batched inference
> **Warning**: takes up to 4 hours to run on the whole filtered dataset (482k embeddings)


In [None]:
# Embed every product selected into `products_to_embed`.
# The destination is `product_embeddings` because every downstream cell reads
# that table: NDCG evaluation, centroid generation, the vector index and the
# sample searches. Writing to `product_embeddings_test` left those cells pointing
# at a table this notebook never builds.
# NOTE: those cells also join on `product_id`, so the destination table must
# contain that column.
run_batched_embedding(
    project_id="graphite-cell-472319-d2",
    dataset_id="amazon_esci",
    model_id="graphite-cell-472319-d2.amazon_esci.gemini_embedding",
    source_query="""
        SELECT product_id, content
        FROM `graphite-cell-472319-d2.amazon_esci.products_to_embed`
    """,
    dest_table="product_embeddings",
    batch_size=10000
)


# Calculate NDCG with baseline vector search
Measure the performance of the general-purpose embeddings using the NDCG metric.

We converted the ESCI labels to numerical relevance scores, in line with the original [paper](https://arxiv.org/pdf/2206.06588) that introduced the dataset.
| Label | Relevance Score |
| :---: | :-------------: |
|   E   |       1.0       |
|   S   |       0.1       |
|   C   |      0.01       |
|   I   |       0.0       |

In [None]:
from google.cloud import bigquery

# Initialize client
client = bigquery.Client(project="graphite-cell-472319-d2")

# Full NDCG query for the baseline (query-text) embeddings.
# - Query embeddings were generated from LOWER(TRIM(query)) (see the
#   query-embedding cell). The raw query is normalized the same way before
#   joining; otherwise any query with capitals or stray whitespace silently
#   drops out of the evaluation.
# - Queries whose judged products are all Irrelevant have IDCG = 0, so their
#   NDCG is NULL. AVG already skips them; COUNT(ndcg) makes the reported count
#   match the set of queries actually averaged.
ndcg_query = """
-- Step 1: Aggregate labels
-- Weight ESCI labels for every (query, product) pair
WITH Examples_Aggregated AS (
  SELECT
    query,
    product_id,
    AVG(CASE
          WHEN esci_label = "E" THEN 1.0
          WHEN esci_label = "S" THEN 0.1
          WHEN esci_label = "C" THEN 0.01
          WHEN esci_label = "I" THEN 0.0
        END) AS relevance
  FROM `graphite-cell-472319-d2.amazon_esci.shopping_queries_dataset_examples`
  WHERE product_locale = "us"
    AND small_version = 1
  GROUP BY query, product_id
),

-- Step 2: Join embeddings
-- Get query and product embeddings (query embeddings are keyed by LOWER(TRIM(query)))
EvaluationSet AS (
  SELECT
    ex.query,
    ex.relevance,
    qe.ml_generate_embedding_result AS query_embedding,
    pe.ml_generate_embedding_result AS product_embedding
  FROM Examples_Aggregated ex
  JOIN `graphite-cell-472319-d2.amazon_esci.query_embeddings` qe
    ON LOWER(TRIM(ex.query)) = qe.content
  JOIN `graphite-cell-472319-d2.amazon_esci.product_embeddings` pe
    ON ex.product_id = pe.product_id
),

-- Step 3: Distance
-- Calculate cosine distance between query and product embeddings
BaseResults AS (
  SELECT
    query,
    relevance,
    COSINE_DISTANCE(query_embedding, product_embedding) AS distance
  FROM EvaluationSet
),

-- Step 4: DCG
DCG_Calc AS (
  SELECT
    query,
    SUM(relevance / LOG(1 + rank, 2)) AS dcg
  FROM (
    SELECT
      query,
      relevance,
      ROW_NUMBER() OVER(PARTITION BY query ORDER BY distance ASC) AS rank
    FROM BaseResults
  )
  GROUP BY query
),

-- Step 5: IDCG
IDCG_Calc AS (
  SELECT
    query,
    SUM(relevance / LOG(1 + rank, 2)) AS idcg
  FROM (
    SELECT
      query,
      relevance,
      ROW_NUMBER() OVER(PARTITION BY query ORDER BY relevance DESC) AS rank
    FROM BaseResults
  )
  GROUP BY query
),

-- Step 6: NDCG (NULL when IDCG = 0, i.e. no relevant product for the query)
FinalNDCG AS (
  SELECT
    d.query,
    SAFE_DIVIDE(d.dcg, i.idcg) AS ndcg
  FROM DCG_Calc d
  JOIN IDCG_Calc i
    ON d.query = i.query
)

-- Final Result: count only the queries that contribute to the mean
SELECT
  COUNT(ndcg) AS num_queries_evaluated,
  AVG(ndcg) AS mean_ndcg
FROM FinalNDCG;
"""

# Run query
job = client.query(ndcg_query)
result = job.result().to_dataframe()

# Print result
print(result)


# Generate centroid vectors for each query
Calculate a centroid vector for each query based on users' interactions with the products after the search.

Because the dataset only contains ESCI labels, we simulate Add-To-Cart (ATC) counts by drawing a random value from the following range for each label:
| Label | Add to Cart (ATC) Range |
| :---: | :---------------------: |
|   E   |      10,000-99,999      |
|   S   |       1,000-9,999       |
|   C   |         100-999         |
|   I   |          0-10           |


In [None]:
from google.cloud import bigquery

client = bigquery.Client(project="graphite-cell-472319-d2")

# Build one normalized centroid embedding per query from the embeddings of its
# judged products, weighted by a simulated add-to-cart (ATC) count.
# The statement is a multi-statement BigQuery script: two temporary JavaScript
# functions followed by CREATE OR REPLACE TABLE.
# Outlier products (distance z-score > 1.5 from the initial centroid) are removed
# before the final centroid is computed. Queries with zero distance variance
# (e.g. a single judged product) are kept instead of being dropped, so they still
# get a centroid.
# NOTE: the ATC counts come from RAND(), so the centroids (and any metric derived
# from them) change on every run of this cell.
centroid_query = r'''
-- JavaScript UDAF to aggregate embeddings with their weight and return a normalized centroid of them
CREATE TEMPORARY AGGREGATE FUNCTION GET_EMBEDDING_CENTROID(sku_embedding ARRAY<FLOAT64>, atc FLOAT64)
  RETURNS ARRAY<FLOAT64>
  LANGUAGE js
  AS r"""

  export function initialState() {
    return { sumVector: null };
  }

  export function aggregate(state, sku_embedding, atc) {
    const weight = atc;
    if (!sku_embedding || sku_embedding.length === 0 || weight <= 0) {
      return;
    }

    if (state.sumVector === null) {
      state.sumVector = sku_embedding.map(num => num * weight);
    } else {
      for (let i = 0; i < state.sumVector.length; i++) {
        state.sumVector[i] += sku_embedding[i] * weight;
      }
    }
  }

  export function merge(state, partialState) {
    if (!partialState.sumVector) {
      return;
    }

    if (state.sumVector === null) {
      state.sumVector = partialState.sumVector;
    } else {
      for (let i = 0; i < state.sumVector.length; i++) {
        state.sumVector[i] += partialState.sumVector[i];
      }
    }
  }

  export function finalize(state) {
    if (!state.sumVector || state.sumVector.length === 0) {
      return [];
    }

    // Normalize the vector
    const vecLength = Math.sqrt(state.sumVector.reduce((acc, num) => acc + num * num, 0));
    if (vecLength === 0) {
      return [];
    }
    return state.sumVector.map(num => num / vecLength);
  }

""";

-- JavaScript UDF to calculate a standard deviation on a weighted input  
CREATE TEMPORARY FUNCTION StandardDeviation(arr ARRAY<STRUCT<distance FLOAT64, weight FLOAT64>>)
RETURNS FLOAT64
LANGUAGE js AS """
if (arr.length === 0) {
  return 0;
}

const weight_sum = arr.reduce((acc, num) => acc + num.weight, 0);
const mean = arr.reduce((acc, num) => acc + num.weight * num.distance, 0) / weight_sum;
const sdv = arr.reduce((acc, num) => acc + Math.pow(num.distance - mean, 2) * num.weight, 0) / weight_sum;
return Math.sqrt(sdv);
""";

CREATE OR REPLACE TABLE `graphite-cell-472319-d2.amazon_esci.centroid_vectors`
AS

WITH

-- Simulate the number of ATC events on each (query, product) from its ESCI label
  product_query_actions AS (
    SELECT query, 
           product_id,
           (CASE
              WHEN esci_label = "E" THEN CAST(FLOOR(RAND() * (99999 - 10000 + 1)) + 10000 AS INT64)
              WHEN esci_label = "S" THEN CAST(FLOOR(RAND() * (9999 - 1000 + 1)) + 1000 AS INT64)
              WHEN esci_label = "C" THEN CAST(FLOOR(RAND() * (999 - 100 + 1)) + 100 AS INT64)
              WHEN esci_label = "I" THEN CAST(FLOOR(RAND() * (10 - 0 + 1)) + 0 AS INT64)
           END) AS atc,
    FROM `graphite-cell-472319-d2.amazon_esci.shopping_queries_dataset_examples`
    WHERE TRUE
      AND small_version = 1
      AND product_locale = "us"
  ),
  
-- Normalize number of ATC (percentage of the query's total ATC)
  normalized_actions AS (
    SELECT query, 
           product_id, 
           100 * atc / SUM(atc) OVER (PARTITION BY query) AS normalized_atc
    FROM product_query_actions
    WHERE atc > 0
    GROUP BY query, product_id, atc
  ),
  
-- Calculate initial centroid of each query (include all products)
  query_centroid AS (
    SELECT
      na.query,
      GET_EMBEDDING_CENTROID(pe.ml_generate_embedding_result, na.normalized_atc) AS centroid
    FROM normalized_actions AS na
    JOIN `graphite-cell-472319-d2.amazon_esci.product_embeddings` AS pe ON na.product_id = pe.product_id
    GROUP BY query
    HAVING ARRAY_LENGTH(centroid) > 0
  ),

-- Check the distance of each product from the initial query centroid
  product_distance AS (
    SELECT
      na.query,
      na.product_id,
      na.normalized_atc,
      pe.ml_generate_embedding_result AS product_embedding,
      qc.centroid,
      ML.DISTANCE(qc.centroid, pe.ml_generate_embedding_result, "COSINE") / 2 AS distance
    FROM normalized_actions AS na
    JOIN `graphite-cell-472319-d2.amazon_esci.product_embeddings` AS pe ON na.product_id = pe.product_id
    JOIN query_centroid AS qc ON na.query = qc.query
  ),

-- Calculate weighted standard deviation and mean of product distances from the initial centroid.
-- Queries with sdv = 0 (e.g. a single judged product) are kept: they simply have no outliers.
  query_aggregations AS (
    SELECT query, 
           StandardDeviation(ARRAY_AGG(STRUCT(distance, normalized_atc))) AS sdv,
           SUM(distance * normalized_atc) / SUM(normalized_atc) AS mean
    FROM product_distance
    GROUP BY ALL
  ),

-- Filter outlier products (z-score of the distance above 1.5).
-- SAFE_DIVIDE returns NULL when sdv = 0; that is treated as a z-score of 0, so
-- zero-variance queries keep all their products instead of disappearing.
  filtered_actions AS (
    SELECT query, 
           product_id, 
           normalized_atc, 
           product_embedding
    FROM product_distance
    JOIN query_aggregations USING (query)
    WHERE IFNULL(SAFE_DIVIDE(distance - mean, sdv), 0) <= 1.5
  ),

-- Calculate the final sanitized centroid for each query 
  sanitized_centroid AS (
    SELECT
      query,
      GET_EMBEDDING_CENTROID(product_embedding, normalized_atc) AS centroid
    FROM filtered_actions
    GROUP BY query
  )
  
SELECT query, 
       centroid
FROM sanitized_centroid
'''

# Run the query to create the table
job = client.query(centroid_query)
job.result()  # wait for completion

print("✅ centroid_vectors table created.")

# Preview the results
preview_df = client.query("""
    SELECT * 
    FROM `graphite-cell-472319-d2.amazon_esci.centroid_vectors`
    LIMIT 5
""").to_dataframe()

print(preview_df)

# Calculate NDCG with centroid vector search
Evaluate the centroid vectors using the NDCG metric (with the same ESCI relevance weights as above).

In [None]:
from google.cloud import bigquery

# Initialize BigQuery client
bqclient = bigquery.Client(project="graphite-cell-472319-d2")

# NDCG query for centroid vectors.
# Queries whose judged products are all Irrelevant have IDCG = 0, so their NDCG
# is NULL. AVG already skips them; COUNT(ndcg) makes the reported count match
# the set of queries actually averaged.
query = """
-- Step 1: Aggregate labels
-- Weight ESCI labels for every (query, product) pair
WITH Examples_Aggregated AS (
  SELECT
    query,
    product_id,
    AVG(
      CASE
        WHEN esci_label = "E" THEN 1.0
        WHEN esci_label = "S" THEN 0.1
        WHEN esci_label = "C" THEN 0.01
        WHEN esci_label = "I" THEN 0.0
      END
    ) AS relevance
  FROM `graphite-cell-472319-d2.amazon_esci.shopping_queries_dataset_examples`
  WHERE product_locale = "us"
    AND small_version = 1
  GROUP BY query, product_id
),

-- Step 2: Join embeddings
-- Get query centroid and product embeddings
EvaluationSet AS (
  SELECT
    ex.query,
    ex.relevance,
    qe.centroid AS query_embedding,
    pe.ml_generate_embedding_result AS product_embedding
  FROM Examples_Aggregated ex
  JOIN `graphite-cell-472319-d2.amazon_esci.centroid_vectors` qe
    ON ex.query = qe.query
  JOIN `graphite-cell-472319-d2.amazon_esci.product_embeddings` pe
    ON ex.product_id = pe.product_id
),

-- Step 3: Compute distances
-- Calculate cosine distance between query and product embeddings 
BaseResults AS (
  SELECT
    query,
    relevance,
    COSINE_DISTANCE(query_embedding, product_embedding) AS distance
  FROM EvaluationSet
),

-- Step 4: Compute DCG
DCG_Calc AS (
  SELECT
    query,
    SUM(relevance / LOG(1 + rank, 2)) AS dcg
  FROM (
    SELECT
      query,
      relevance,
      ROW_NUMBER() OVER (PARTITION BY query ORDER BY distance ASC) AS rank
    FROM BaseResults
  )
  GROUP BY query
),

-- Step 5: Compute IDCG (ideal DCG)
IDCG_Calc AS (
  SELECT
    query,
    SUM(relevance / LOG(1 + rank, 2)) AS idcg
  FROM (
    SELECT
      query,
      relevance,
      ROW_NUMBER() OVER (PARTITION BY query ORDER BY relevance DESC) AS rank
    FROM BaseResults
  )
  GROUP BY query
),

-- Step 6: Normalize DCG (NULL when IDCG = 0, i.e. no relevant product for the query)
FinalNDCG AS (
  SELECT
    d.query,
    SAFE_DIVIDE(d.dcg, i.idcg) AS ndcg
  FROM DCG_Calc d
  JOIN IDCG_Calc i
    ON d.query = i.query
)

-- Final Result: count only the queries that contribute to the mean
SELECT
  COUNT(ndcg) AS num_queries_evaluated,
  AVG(ndcg) AS mean_ndcg
FROM FinalNDCG;
"""

# Run query
job = bqclient.query(query)
result = job.result().to_dataframe()

# Print results
print("Number of queries evaluated:", result["num_queries_evaluated"].iloc[0])
print("Mean NDCG:", result["mean_ndcg"].iloc[0])


# Generate Vector Index over product search space

In [None]:
from google.cloud import bigquery

client = bigquery.Client(project="graphite-cell-472319-d2")

# IVF index over the product embeddings with COSINE distance. This matches the
# COSINE_DISTANCE used by the NDCG evaluation cells.
# IF NOT EXISTS makes the cell safe to re-run: a plain CREATE VECTOR INDEX fails
# when the index already exists. Use CREATE OR REPLACE VECTOR INDEX instead to
# rebuild it with different options.
# NOTE(review): BigQuery may populate the index asynchronously after this DDL
# returns. Confirm index coverage before relying on it in VECTOR_SEARCH.
create_index_query = """
CREATE VECTOR INDEX IF NOT EXISTS `graphite-cell-472319-d2.amazon_esci.product_embeddings_index`
ON `graphite-cell-472319-d2.amazon_esci.product_embeddings`(ml_generate_embedding_result)
OPTIONS(
  index_type = 'IVF',
  distance_type = 'COSINE',
  ivf_options = '{"num_lists": 100}'
)
"""

job = client.query(create_index_query)
job.result()
print("✅ Vector index ready (num_lists=100; created only if it did not already exist)")



# Sample search run on baseline

In [None]:
from google.cloud import bigquery

client = bigquery.Client(project="graphite-cell-472319-d2")

# hardcoded query for testing purposes
search_text = "rc drone without camera"

# return top 20 search results
top_k = 20

# Rank the products judged for `search_text` by cosine distance to the query's
# baseline (text) embedding.
# - VECTOR_SEARCH's first argument is the BASE table (the candidate products) and
#   its third argument is the QUERY table. The output exposes them as the `base`
#   and `query` structs.
# - distance_type defaults to EUCLIDEAN, so COSINE is requested explicitly. This
#   makes `cosine_distance` match its name and the COSINE vector index.
# - search_text is bound as a query parameter rather than spliced into the SQL,
#   so quotes in the text cannot break (or inject into) the statement.
# Candidates are restricted to products labelled for this query, so this is a
# re-ranking of the judged set, not a full-catalogue search.
baseline_sql = f"""
SELECT query.content AS query,
       base.product_id,
       LEFT(base.content, 70) AS product_details,
       distance AS cosine_distance,
       (CASE
        WHEN base.esci_label = "E" THEN "Exact"
        WHEN base.esci_label = "S" THEN "Substitute"
        WHEN base.esci_label = "C" THEN "Complement"
        WHEN base.esci_label = "I" THEN "Irrelevant"
       END
       ) AS esci_label,
       (CASE
        WHEN base.esci_label = "E" THEN 1
        WHEN base.esci_label = "S" THEN 0.1
        WHEN base.esci_label = "C" THEN 0.01
        WHEN base.esci_label = "I" THEN 0
       END
       ) AS relevance
FROM
VECTOR_SEARCH(
  (SELECT product_id,
          content,
          ml_generate_embedding_result,
          esci_label
   FROM `graphite-cell-472319-d2.amazon_esci.product_embeddings`
   JOIN `graphite-cell-472319-d2.amazon_esci.shopping_queries_dataset_examples` USING (product_id)
   WHERE TRUE
     AND product_locale = "us"
     AND query = @search_text
     AND small_version = 1),
  'ml_generate_embedding_result',
  (SELECT content,
          ml_generate_embedding_result
   FROM `graphite-cell-472319-d2.amazon_esci.query_embeddings`
   WHERE content = @search_text),
  top_k => {int(top_k)},
  distance_type => 'COSINE'
)
ORDER BY distance
LIMIT {int(top_k)}
"""
job_config = bigquery.QueryJobConfig(
    query_parameters=[bigquery.ScalarQueryParameter("search_text", "STRING", search_text)]
)
query_job = client.query(baseline_sql, job_config=job_config)
baseline_results = query_job.result().to_dataframe()

print("🔹 Baseline Vector Search Results")
display(baseline_results)



# Sample search run on centroid vectors

In [None]:
from google.cloud import bigquery

client = bigquery.Client(project="graphite-cell-472319-d2")

# hardcoded query for testing purposes
search_text = "rc drone without camera"

# return top 20 search results
top_k = 20

# Rank the products judged for `search_text` by cosine distance to the query's
# centroid vector.
# - VECTOR_SEARCH's first argument is the BASE table (the candidate products) and
#   its third argument is the QUERY table. Because the query column has a
#   different name ('centroid'), it is named via query_column_to_search.
# - distance_type defaults to EUCLIDEAN, so COSINE is requested explicitly. This
#   makes `cosine_distance` match its name and the COSINE vector index.
# - search_text is bound as a query parameter rather than spliced into the SQL.
# Candidates are restricted to products labelled for this query, so this is a
# re-ranking of the judged set, not a full-catalogue search.
centroid_sql = f"""
SELECT query.query AS query,
       base.product_id,
       LEFT(base.content, 70) AS product_details,
       distance AS cosine_distance,
       (CASE
        WHEN base.esci_label = "E" THEN "Exact"
        WHEN base.esci_label = "S" THEN "Substitute"
        WHEN base.esci_label = "C" THEN "Complement"
        WHEN base.esci_label = "I" THEN "Irrelevant"
       END
       ) AS esci_label,
       (CASE
        WHEN base.esci_label = "E" THEN 1
        WHEN base.esci_label = "S" THEN 0.1
        WHEN base.esci_label = "C" THEN 0.01
        WHEN base.esci_label = "I" THEN 0
       END
       ) AS relevance
FROM
VECTOR_SEARCH(
  (SELECT product_id,
          content,
          ml_generate_embedding_result,
          esci_label
   FROM `graphite-cell-472319-d2.amazon_esci.product_embeddings`
   JOIN `graphite-cell-472319-d2.amazon_esci.shopping_queries_dataset_examples` USING (product_id)
   WHERE TRUE
     AND product_locale = "us"
     AND query = @search_text
     AND small_version = 1),
  'ml_generate_embedding_result',
  (SELECT query,
          centroid
   FROM `graphite-cell-472319-d2.amazon_esci.centroid_vectors`
   WHERE query = @search_text),
  query_column_to_search => 'centroid',
  top_k => {int(top_k)},
  distance_type => 'COSINE'
)
ORDER BY distance
LIMIT {int(top_k)}
"""
job_config = bigquery.QueryJobConfig(
    query_parameters=[bigquery.ScalarQueryParameter("search_text", "STRING", search_text)]
)
query_job = client.query(centroid_sql, job_config=job_config)
centroid_results = query_job.result().to_dataframe()

print("🔹 Centroid Vector Search Results")
display(centroid_results)

