### Installing required libraries

In [292]:
%pip install -r requirements.txt

Note: you may need to restart the kernel to use updated packages.


In [293]:
import os
import json
import time
import pandas as pd
import google.genai as genai
from google.genai import types
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from dotenv import load_dotenv
from tqdm import tqdm
import typing_extensions as typing

### Loading local embedding model

In [307]:
# Load the sentence-embedding model first so that the confirmation message
# below is only printed once the model object actually exists.
embedder = SentenceTransformer('all-mpnet-base-v2')
print("Embedding model loaded: all-mpnet-base-v2")

Embedding model loaded: all-mpnet-base-v2


### Constraint JSON schema

In [296]:
# Response schema passed to Gemini for constraint extraction.
# Every property is nullable; the (field, type) pairs below are expanded
# into the per-property spec so the repeated "nullable" flag is written once.
constraint_schema = {
    "type": "OBJECT",
    "properties": {
        field: {"type": kind, "nullable": True}
        for field, kind in (
            ("category", "STRING"),
            ("temperature", "STRING"),
            ("max_calories", "NUMBER"),
            ("max_sugar", "NUMBER"),
            ("max_price", "NUMBER"),
            ("dairy_free", "BOOLEAN"),
            ("vegan", "BOOLEAN"),
            ("caffeine_level", "STRING"),
        )
    },
    # Keys the model must always emit (their values may still be null).
    "required": ["category", "temperature", "max_price"],
}

### Loading product data and training data

In [308]:
# Read the product catalogue and the query set from the local data folder.
try:
    df_products = pd.read_csv('data/products.csv')
    df_queries = pd.read_csv('data/queries_train.csv')
    print("Data loaded successfully.")
except FileNotFoundError:
    print("Error: CSV files not found. Please check your 'data' folder.")
    # Re-raise rather than calling exit(): inside a notebook exit() ends the
    # kernel session, while re-raising simply stops this cell (and a
    # "Run All") with a traceback that names the missing file.
    raise

Data loaded successfully.


### Pre-computing product embeddings

In [309]:
# Build one text string per product from the listed columns and embed the
# whole catalogue once, so each query only needs a single extra encode call.
print("Generating product embeddings...")
cols_to_use = [
    'name', 'category', 'subcategory', 'temperature',
    'caffeine_mg', 'calories', 'sugar_g', 'protein_g',
    'contains_dairy', 'contains_nuts', 'contains_gluten', 'is_vegan',
    'description', 'price',
]

# Missing values become empty strings and every value is stringified before
# the columns of each row are joined with single spaces.
df_products['embedding_text'] = df_products[cols_to_use].fillna('').astype(str).agg(' '.join, axis=1)

# Row i of product_embeddings corresponds to row i of df_products.
product_embeddings = embedder.encode(df_products['embedding_text'].tolist())
print("Product embeddings ready.")

Generating product embeddings...
Product embeddings ready.


### Vertex AI initialization/authentication and constraint extraction using prompt

In [314]:
import time
import json
import os
import vertexai
import warnings
from vertexai.generative_models import GenerativeModel, GenerationConfig

# Suppress all Python warnings for the rest of the session.
warnings.filterwarnings("ignore")

# Authentication: point the Google client libraries at the local credentials file.
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "auth.json"

# The same credentials file also provides the project id used for initialization.
with open("auth.json") as f:
    creds = json.load(f)
PROJECT_ID = creds["project_id"]

# Initialize Vertex AI against that project.
vertexai.init(project=PROJECT_ID, location="us-central1")
print(f"Connected to Project: {PROJECT_ID}")

def extract_constraints_gemini(query_text):
    """Ask Gemini to turn a free-text drink request into structured constraints.

    The model is called with a JSON response MIME type and `constraint_schema`
    as the response schema, and its reply is parsed with `json.loads`.

    Args:
        query_text: The raw user query to analyse.

    Returns:
        dict: The parsed constraints. String values that are empty or the
        literal text "null" are normalised to None so they cannot act as
        filters downstream. An empty dict is returned when the call fails,
        the reply is not valid JSON, or the reply is not a JSON object.

    Note:
        Sleeps 1 second after a successful call and 2 seconds after a failure
        (presumably to stay under API rate limits).
    """
    try:
        model = GenerativeModel("gemini-2.5-flash")

        prompt = f"""
        You are a Starbucks Data Assistant. Extract search constraints from this query:
        Query: "{query_text}"
        
        Return a valid JSON object. Use null for missing values.
        - category: "espresso", "brewed", "cold_brew", "frappuccino", "refresher", "tea" or null
        - temperature: "hot", "iced", "blended" or null
        - max_calories: number or null
        - max_sugar: number (grams) or null
        - max_price: number or null
        - dairy_free: true or null
        - vegan: true or null
        - caffeine_level: "none", "low", "medium", "high" or null
        Return strictly valid JSON.
        Use null (without quotes) for missing values.
        Do NOT use 0 unless explicitly stated in the query.
        Do NOT return "null" as a string.
        """

        response = model.generate_content(
            prompt,
            generation_config=GenerationConfig(
                response_mime_type="application/json",
                response_schema=constraint_schema
            )
        )

        parsed = json.loads(response.text)
        time.sleep(1)

        # Callers use dict.get on the result, so anything that is not a JSON
        # object is treated like a failed extraction.
        if not isinstance(parsed, dict):
            print(f"Error: expected a JSON object, got {type(parsed).__name__}")
            return {}

        # The prompt forbids the string "null", but the model may still emit
        # it; as a truthy string it would otherwise be applied as a real
        # filter. "none" is deliberately NOT normalised: it is a valid
        # caffeine_level value.
        return {
            key: None if isinstance(value, str) and value.strip().lower() in ("", "null") else value
            for key, value in parsed.items()
        }

    except Exception as e:
        # Best-effort: report the problem and let the caller fall back to
        # unfiltered ranking.
        print(f"Error: {e}")
        time.sleep(2)
        return {}

Connected to Project: starbucks-barista-486010


### Filtering products based on constraints

In [315]:
def filter_products(products_df, constraints):
    """Drop products that violate the constraints extracted from a query.

    Args:
        products_df: Product table. The columns read are 'category',
            'temperature', 'calories', 'sugar_g', 'price', 'contains_dairy',
            'is_vegan' and 'caffeine_mg', each only when the matching
            constraint is present.
        constraints: Mapping of constraint name to value. Missing keys and
            None values mean "no constraint".

    Returns:
        A filtered copy of `products_df`; the input frame is not modified.
    """
    filtered = products_df.copy()

    # Exact-match constraints: the constraint key is also the column name.
    for column in ('category', 'temperature'):
        wanted = constraints.get(column)
        if wanted:
            filtered = filtered[filtered[column] == wanted]

    # Upper-bound constraints. Compare against None rather than relying on
    # truthiness so that an explicit limit of 0 (e.g. "zero sugar") is applied
    # instead of being silently skipped.
    for key, column in (
        ('max_calories', 'calories'),
        ('max_sugar', 'sugar_g'),
        ('max_price', 'price'),
    ):
        limit = constraints.get(key)
        if limit is not None:
            filtered = filtered[filtered[column] <= limit]

    # Boolean constraints only ever restrict: a False/None value means the
    # user expressed no preference.
    if constraints.get('dairy_free'):
        # Element-wise comparison is intentional here (pandas Series).
        filtered = filtered[filtered['contains_dairy'] == False]
    if constraints.get('vegan'):
        filtered = filtered[filtered['is_vegan'] == True]

    # Caffeine buckets in mg: none = 0, low = (0, 70], medium = (70, 150],
    # high = > 150. Unrecognised levels leave the table unchanged.
    level = constraints.get('caffeine_level')
    if level:
        if level == "none":
            filtered = filtered[filtered['caffeine_mg'] == 0]
        elif level == 'low':
            filtered = filtered[(filtered['caffeine_mg'] > 0) & (filtered['caffeine_mg'] <= 70)]
        elif level == 'medium':
            filtered = filtered[(filtered['caffeine_mg'] > 70) & (filtered['caffeine_mg'] <= 150)]
        elif level == 'high':
            filtered = filtered[filtered['caffeine_mg'] > 150]

    return filtered

### Ranking products on Cosine similarity

In [316]:
def rank_products(query_text, candidate_df):
    """Order candidate products by cosine similarity to the query text.

    Args:
        query_text: The user query to embed.
        candidate_df: Products that survived filtering. Its index values are
            used to select rows of the global `product_embeddings` array.

    Returns:
        list: 'product_id' values, most similar first; an empty list when
        there are no candidates.
    """
    if candidate_df.empty:
        return []

    # Embed the query and score it against the candidates' stored vectors.
    query_embedding = embedder.encode([query_text])
    similarities = cosine_similarity(
        query_embedding,
        product_embeddings[candidate_df.index],
    )[0]

    # assign() returns a new frame, so the caller's DataFrame is not modified.
    scored = candidate_df.assign(score=similarities)
    return scored.sort_values(by='score', ascending=False)['product_id'].tolist()

### Main execution loop

In [317]:
# Run the extract -> filter -> rank pipeline for every query.
results = []
print(f"Processing {len(df_queries)} queries...")

for i, row in tqdm(df_queries.iterrows(), total=len(df_queries)):
    q_id, q_text = row['query_id'], row['query_text']

    constraints = extract_constraints_gemini(q_text)

    # Filter only when constraints were extracted; if nothing was extracted,
    # or filtering removes every product, rank the full catalogue instead.
    candidates = filter_products(df_products, constraints) if constraints else df_products
    if candidates.empty:
        candidates = df_products

    # Ranking
    ranked_ids = rank_products(q_text, candidates)

    # One row per query, product ids joined with ';'.
    results.append({"query_id": q_id, "products": ";".join(ranked_ids)})

df_results = pd.DataFrame(results)
print("Results Extracted!")
df_results.head()

Processing 100 queries...


100%|██████████| 100/100 [06:16<00:00,  3.77s/it]

Results Extracted!





Unnamed: 0,query_id,products
0,TRAIN_001,ICE_007;ESP_013;ICE_009;ICE_008;ICE_015;ESP_00...
1,TRAIN_002,TEA_011;TEA_005;TEA_010;TEA_008;ICT_010;ICT_00...
2,TRAIN_003,BRW_004;BRW_005;BRW_003;BRW_002;BRW_001
3,TRAIN_004,TEA_005;TEA_012;TEA_008;ICT_002;ICT_004;ICT_00...
4,TRAIN_005,CBR_012;CBR_011;CBR_002;CBR_001;CBR_009


### Evaluation of results (Recall & NDCG)

In [320]:
import ast
import numpy as np

if 'relevant_products' in df_queries.columns:
    print("True product list detected. Calculating NDCG, Recall, and Accuracy.")

    # One row per query with both the ground truth and the predicted ranking.
    df_eval = pd.merge(df_queries[['query_id', 'relevant_products']], df_results, on='query_id')

    # NDCG calculation function
    def get_ndcg(row):
        """Return binary-relevance NDCG for one query row.

        `row['products']` is the ';'-joined predicted ranking and
        `row['relevant_products']` is parsed with `ast.literal_eval`.
        Returns 0.0 when there are no predictions, the ground truth cannot
        be parsed, or the ground truth is empty.
        """
        pred_str = row['products']
        truth_str = row['relevant_products']

        if pd.isna(pred_str) or not pred_str:
            return 0.0

        predicted_list = pred_str.split(';')

        try:
            truth_list = ast.literal_eval(truth_str)
        except Exception:
            # Unparseable ground truth scores zero. `Exception` (rather than a
            # bare except) lets KeyboardInterrupt/SystemExit propagate.
            return 0.0

        # Gain of 1 at every rank that holds a relevant product, discounted
        # by log2(rank + 1) with 1-based ranks.
        dcg = sum(1.0 / np.log2(i + 2) for i, p in enumerate(predicted_list) if p in truth_list)

        # Ideal ordering: all relevant products at the top of the list.
        idcg = sum(1.0 / np.log2(i + 2) for i in range(len(truth_list)))

        return dcg / idcg if idcg > 0 else 0.0

    # Recall & top-1 accuracy function
    def get_recall_and_accuracy(row):
        """Return recall and top-1 accuracy for one query row.

        Recall is the share of relevant products that appear anywhere in the
        prediction; top-1 accuracy is 1.0 when the first predicted product is
        relevant. Both are 0.0 when there are no predictions or the ground
        truth cannot be parsed.
        """
        pred_str = row['products']
        truth_str = row['relevant_products']

        if pd.isna(pred_str) or not pred_str:
            return pd.Series({'recall': 0.0, 'top_1_accuracy': 0.0})

        predicted_list = pred_str.split(';')

        try:
            truth_list = ast.literal_eval(truth_str)
        except Exception:
            # Same policy as get_ndcg: unparseable ground truth scores zero,
            # without swallowing KeyboardInterrupt/SystemExit.
            truth_list = []

        set_pred = set(predicted_list)
        set_truth = set(truth_list)

        intersection = set_pred.intersection(set_truth)

        recall = len(intersection) / len(set_truth) if set_truth else 0.0

        top_1 = 1.0 if (predicted_list and predicted_list[0] in set_truth) else 0.0

        return pd.Series({'recall': recall, 'top_1_accuracy': top_1})

    df_eval['ndcg'] = df_eval.apply(get_ndcg, axis=1)
    df_eval[['recall', 'top_1_accuracy']] = df_eval.apply(get_recall_and_accuracy, axis=1)

    print("\n" + "="*40)
    print(" Pipeline Performance Report")
    print("="*40)
    print(f"Average NDCG:      {df_eval['ndcg'].mean():.4f}")
    print(f"Average Recall:    {df_eval['recall'].mean():.4f}")
    print(f"Top-1 Accuracy:    {df_eval['top_1_accuracy'].mean():.4f}")
    print("="*40 + "\n")

    # Create the output folder if needed so a fresh checkout can run this cell.
    os.makedirs('output', exist_ok=True)
    df_eval.to_csv('output/training_evaluation.csv', index=False)
    print("Detailed evaluation saved to output/training_evaluation.csv")

else:
    # If no true labels, saving predictions (for test set)
    os.makedirs('output', exist_ok=True)
    df_results.to_csv('output/submission.csv', index=False)
    print("Test submission saved to output/submission.csv. Ready to upload!")

True product list detected. Calculating NDCG, Recall, and Accuracy.

 Pipeline Performance Report
Average NDCG:      0.9650
Average Recall:    0.9770
Top-1 Accuracy:    0.9700

Detailed evaluation saved to output/training_evaluation.csv


### Saving final results

In [321]:
# Write the per-query rankings collected in `results` to the submission file.
submission = pd.DataFrame(results)
# Create the output folder if needed so the write cannot fail on a fresh checkout.
os.makedirs('output', exist_ok=True)
submission.to_csv('output/submission.csv', index=False)
print("File saved to output/submission.csv")

File saved to output/submission.csv


## Tuning Code & Testing On Single Query

In [322]:
# Tuning sandbox: rebuild the embedding text on a copy of the catalogue so
# the column list can be changed here without touching df_products.
print("Generating product embeddings...")
df_products_copy = df_products.copy()
cols_to_use = [
    'name', 'category', 'subcategory', 'temperature',
    'caffeine_mg', 'calories', 'sugar_g', 'protein_g',
    'contains_dairy', 'contains_nuts', 'contains_gluten', 'is_vegan',
    'description', 'price',
]

# Missing values become empty strings and every value is stringified before
# the columns of each row are joined with single spaces.
df_products_copy['embedding_text'] = df_products_copy[cols_to_use].fillna('').astype(str).agg(' '.join, axis=1)

# Row i of product_embeddings_1 corresponds to row i of df_products_copy.
product_embeddings_1 = embedder.encode(df_products_copy['embedding_text'].tolist())
print("Product embeddings ready.")

Generating product embeddings...
Product embeddings ready.


In [323]:
# Run the full pipeline for one hand-picked training query.
results_1 = []

q_id_1 = 'TRAIN_019'
q_text_1 = "running late but i need something refreshing and fruity that's no more than 150 cal and no more than $4.5"

constraints_1 = extract_constraints_gemini(q_text_1)

# Same fallback policy as the main loop: rank the whole catalogue when no
# constraints were extracted or when filtering removes every product.
candidates_1 = filter_products(df_products, constraints_1) if constraints_1 else df_products
if candidates_1.empty:
    candidates_1 = df_products

ranked_ids_1 = rank_products(q_text_1, candidates_1)

results_1.append({"query_id": q_id_1, "products": ";".join(ranked_ids_1)})

df_results_1 = pd.DataFrame(results_1)
print("Done!")
df_results_1.head()

Done!


Unnamed: 0,query_id,products
0,TRAIN_019,REF_001;REF_008;REF_002


In [324]:
candidates_1

Unnamed: 0,product_id,name,category,subcategory,temperature,caffeine_mg,calories,sugar_g,protein_g,contains_dairy,contains_nuts,contains_gluten,is_vegan,description,price,embedding_text
80,REF_001,Strawberry Acai Refresher,refresher,water,iced,45,90,20,0,False,False,False,True,Sweet strawberry and acai flavors with real fr...,4.45,Strawberry Acai Refresher refresher water iced...
81,REF_002,Mango Dragonfruit Refresher,refresher,water,iced,45,90,19,1,False,False,False,True,Mango and dragonfruit flavors with real fruit ...,4.45,Mango Dragonfruit Refresher refresher water ic...
87,REF_008,Pineapple Passionfruit Refresher,refresher,water,iced,45,100,22,0,False,False,False,True,Tropical pineapple and passionfruit flavors,4.45,Pineapple Passionfruit Refresher refresher wat...


In [325]:
len(candidates_1)

3

In [328]:
candidates_1.index

Index([80, 81, 87], dtype='int64')

In [329]:
print(constraints_1)

{'category': 'refresher', 'temperature': 'iced', 'max_calories': 150, 'max_sugar': None, 'max_price': 4.5, 'dairy_free': None, 'vegan': None, 'caffeine_level': None}


In [330]:
qvec = embedder.encode([q_text_1])

In [331]:
qvec.shape

(1, 768)

In [332]:
q_text_1

"running late but i need something refreshing and fruity that's no more than 150 cal and no more than $4.5"

In [333]:
# Select the candidates' embedding rows. The DataFrame index labels are used
# as array positions here, which assumes df_products kept its default
# 0..n-1 index from read_csv (TODO confirm if the catalogue is ever re-indexed).
cand_vectors = product_embeddings_1[candidates_1.index]

In [334]:
cand_vectors.shape

(3, 768)

In [335]:
# cosine_similarity returns a (1, n_candidates) matrix for the single query
# vector; [0] takes that one row, giving one score per candidate.
score1 = cosine_similarity(qvec, cand_vectors)[0]
score1

array([0.5859856 , 0.51825416, 0.545011  ], dtype=float32)

In [336]:
# Attach the similarity scores to a copy of the candidates (candidates_1 is
# left untouched) and show them best match first.
candidates_1_copy = candidates_1.assign(score=score1)
ranked1 = candidates_1_copy.sort_values(by='score', ascending=False)
ranked1

Unnamed: 0,product_id,name,category,subcategory,temperature,caffeine_mg,calories,sugar_g,protein_g,contains_dairy,contains_nuts,contains_gluten,is_vegan,description,price,embedding_text,score
80,REF_001,Strawberry Acai Refresher,refresher,water,iced,45,90,20,0,False,False,False,True,Sweet strawberry and acai flavors with real fr...,4.45,Strawberry Acai Refresher refresher water iced...,0.585986
87,REF_008,Pineapple Passionfruit Refresher,refresher,water,iced,45,100,22,0,False,False,False,True,Tropical pineapple and passionfruit flavors,4.45,Pineapple Passionfruit Refresher refresher wat...,0.545011
81,REF_002,Mango Dragonfruit Refresher,refresher,water,iced,45,90,19,1,False,False,False,True,Mango and dragonfruit flavors with real fruit ...,4.45,Mango Dragonfruit Refresher refresher water ic...,0.518254
