In [24]:
%pip install -r requirements.txt

Note: you may need to restart the kernel to use updated packages.


In [25]:
import os
import json
import time
import pandas as pd
import google.genai as genai
from google.genai import types
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from dotenv import load_dotenv
from tqdm import tqdm
import typing_extensions as typing

In [26]:
# 1. SETUP & CONFIGURATION
# ------------------------
load_dotenv()  # Load API key from .env file
client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))

In [27]:
# Load the local embedding model (Free, runs on your laptop)
# This handles "Stage 3: Ranking"
print("Loading embedding model")
embedder = SentenceTransformer('all-MiniLM-L6-v2')

Loading embedding model


In [28]:
# Define the strict structure for the LLM output (Stage 1)
# This forces Gemini to return valid JSON every time.
class ConstraintSchema(typing.TypedDict):
    category: str | None
    temperature: str | None
    max_calories: int | None
    max_sugar: int | None
    max_price: float | None
    dairy_free: bool | None
    vegan: bool | None
    caffeine_level: str | None

In [29]:
# 2. LOAD DATA
# ------------
# We use pandas to read the CSV files
try:
    df_products = pd.read_csv('data/products.csv')
    # START WITH TRAINING DATA to test your logic. 
    # Once it works, switch this filename to 'queries_test.csv' for the final run.
    df_queries = pd.read_csv('data/queries_train.csv') 
    print("Data loaded successfully.")
except FileNotFoundError:
    print("Error: CSV files not found. Please check your 'data' folder.")
    exit()

Data loaded successfully.


In [30]:
# 3. PRE-COMPUTE PRODUCT EMBEDDINGS
# ---------------------------------
# We turn every product into a math vector now so we don't have to do it later.
print("Generating product embeddings...")
df_products['embedding_text'] = (
    df_products['name'].fillna('') + " " + 
    df_products['description'].fillna('') + " " + 
    df_products['category'].fillna('')
)
product_embeddings = embedder.encode(df_products['embedding_text'].tolist())
print("Product embeddings ready.")

Generating product embeddings...
Product embeddings ready.


In [31]:
# 4. STAGE 1: CONSTRAINT EXTRACTION (Using Gemini)
# ------------------------------------------------
from sklearn import exceptions

def extract_constraints_gemini(query_text):
    """
    Sends the user query to Gemini 2.5 Flash and asks for a JSON response.
    """
    # FIX 1: Add this loop so 'continue' has somewhere to go
    while True:
        prompt = f"""
        You are a Starbucks Data Assistant. Extract search constraints from this query:
        Query: "{query_text}"
        
        Return a JSON object with these exact keys. If a constraint is not mentioned, use null.
        - category: "espresso", "brewed", "cold_brew", "frappuccino", "refresher", "tea" or null
        - temperature: "hot", "iced", "blended" or null
        - max_calories: number or null
        - max_sugar: number (grams) or null
        - max_price: number or null
        - dairy_free: true (if "no milk", "dairy free") or null
        - vegan: true (if "vegan", "plant based") or null
        - caffeine_level: "none", "low", "medium", "high" or null
        """

        try:
            response = client.models.generate_content(
                model="gemini-2.5-flash",
                contents=prompt,
                config=types.GenerateContentConfig(
                    response_mime_type="application/json",
                    response_schema=ConstraintSchema 
                )
            )
            # SUCCESS: Parse and return
            parsed = json.loads(response.text)

            # Sleep 4s ensures we stay under the 15 RPM limit safely
            time.sleep(4)
            return parsed
        
        # FIX 2: Catch generic exceptions and check the message for "429" or "Quota"
        # This is safer than importing specific exception libraries that might conflict.
        except Exception as e:
            error_msg = str(e).lower()
            if "429" in error_msg or "quota" in error_msg or "resource_exhausted" in error_msg:
                print(f"‚ö†Ô∏è Quota hit! Sleeping 60s...")
                time.sleep(60)
                continue # Now this works because it's inside 'while True'
            else:
                # Real error? Stop and return empty.
                print(f"‚ùå Extraction Error: {e}")
                return {}

In [32]:
# 5. STAGE 2: FILTERING
# ---------------------
def filter_products(products_df, constraints):
    """
    Removes products that violate the specific constraints found by Gemini.
    """
    filtered = products_df.copy()
    
    # Text Filters
    if constraints.get('category'):
        filtered = filtered[filtered['category'] == constraints['category']]
    if constraints.get('temperature'):
        filtered = filtered[filtered['temperature'] == constraints['temperature']]
        
    # Number Filters (Using <= for max limits)
    if constraints.get('max_calories'):
        filtered = filtered[filtered['calories'] <= constraints['max_calories']]
    if constraints.get('max_sugar'):
        filtered = filtered[filtered['sugar_g'] <= constraints['max_sugar']]
    if constraints.get('max_price'):
        filtered = filtered[filtered['price'] <= constraints['max_price']]
        
    # Boolean Filters
    if constraints.get('dairy_free'):
        filtered = filtered[filtered['contains_dairy'] == False]
    if constraints.get('vegan'):
        filtered = filtered[filtered['is_vegan'] == True]
    
    # Caffeine Filter (Simple Mapping)
    if constraints.get('caffeine_level'):
        level = constraints['caffeine_level']
        if level == 'none':
            filtered = filtered[filtered['caffeine_mg'] < 5]
        elif level == 'high':
            filtered = filtered[filtered['caffeine_mg'] > 150]
            
    return filtered

In [33]:
# 6. STAGE 3: RANKING
# -------------------
def rank_products(query_text, candidate_df):
    """
    Sorts the remaining products by how similar they are to the user query.
    """
    if candidate_df.empty:
        return []
    
    # 1. Encode the user's query into a vector
    query_vec = embedder.encode([query_text])
    
    # 2. Get the vectors for ONLY the valid candidates
    # (We use the dataframe index to grab the correct pre-computed vectors)
    candidate_vectors = product_embeddings[candidate_df.index]
    
    # 3. Calculate similarity scores
    scores = cosine_similarity(query_vec, candidate_vectors)[0]
    
    # 4. Attach scores and sort
    candidate_df = candidate_df.copy()
    candidate_df['score'] = scores
    ranked = candidate_df.sort_values(by='score', ascending=False)
    
    return ranked['product_id'].tolist()

In [34]:
# 7. MAIN EXECUTION LOOP
# ----------------------
results = []
print(f"Processing {len(df_queries)} queries...")

# Loop through every query in the CSV
for i, row in tqdm(df_queries.iterrows(), total=len(df_queries)):
    q_id = row['query_id']
    q_text = row['query_text']
    
    # A. Extract (Wait 10 seconds to respect Free Tier limits)
    constraints = extract_constraints_gemini(q_text)
    
    # B. Filter
    # CRITICAL FIX: Check if constraints exist before filtering
    if constraints:
        candidates = filter_products(df_products, constraints)
    else:
        # If extraction failed, treat it as "no constraints" (search everything)
        candidates = df_products
    
    # Fallback: If filtering kills everything, ignore filters and rank everything
    if candidates.empty:
        candidates = df_products
        
    # C. Rank
    ranked_ids = rank_products(q_text, candidates)
    
    # D. Save Result
    results.append({
        "query_id": q_id,
        "products": ";".join(ranked_ids)  # Format: ID1;ID2;ID3
    })

    # Convert to DataFrame and view
df_results = pd.DataFrame(results)
print("‚úÖ Done!")
df_results.head()

Processing 100 queries...


  7%|‚ñã         | 7/100 [00:41<09:03,  5.84s/it]

‚ö†Ô∏è Quota hit! Sleeping 60s...


 17%|‚ñà‚ñã        | 17/100 [02:38<09:13,  6.66s/it]

‚ö†Ô∏è Quota hit! Sleeping 60s...


 18%|‚ñà‚ñä        | 18/100 [03:44<33:42, 24.67s/it]

‚ö†Ô∏è Quota hit! Sleeping 60s...
‚ö†Ô∏è Quota hit! Sleeping 60s...
‚ö†Ô∏è Quota hit! Sleeping 60s...


 18%|‚ñà‚ñä        | 18/100 [06:52<31:19, 22.92s/it]


KeyboardInterrupt: 

In [None]:
# 8. EXPORT FINAL CSV
# -------------------
submission = pd.DataFrame(results)
submission.to_csv('output/submission.csv', index=False)
print("Success! File saved to output/submission.csv")

Success! File saved to output/submission.csv


In [None]:
# 1. ENSURE LIBRARY IS INSTALLED
# If you get an import error, uncomment the line below and run it:
# %pip install -U google-genai

import os
from google import genai
from pydantic import BaseModel, TypeAdapter
from dotenv import load_dotenv

# 2. SETUP CLIENT
load_dotenv(override=True)
api_key = os.getenv("GEMINI_API_KEY")

if not api_key:
    print("‚ùå Error: .env file not found or empty. Please check your key.")
else:
    print(f"Key loaded: ...{api_key[-5:]}")
    client = genai.Client(api_key=api_key)

    # 3. DEFINE SCHEMA (The new way uses Pydantic or Dicts easily)
    # We will use a standard dict for simplicity here, similar to your TypedDict
    schema_config = {
        "response_mime_type": "application/json",
        "response_schema": {
            "type": "OBJECT",
            "properties": {
                "category": {"type": "STRING"},
                "price": {"type": "NUMBER"},
                "is_hot": {"type": "BOOLEAN"}
            }
        }
    }

    print("\n--- TEST 1: Basic Connection ---")
    try:
        response = client.models.generate_content(
            model="gemini-2.5-flash", 
            contents="Say 'Hello V2!'"
        )
        print(f"‚úÖ Success: {response.text}")
    except Exception as e:
        print(f"‚ùå Connection Failed: {e}")

    print("\n--- TEST 2: JSON Extraction ---")
    try:
        response = client.models.generate_content(
            model="gemini-2.5-flash",
            contents="Find me a hot latte for $4.50",
            config=schema_config
        )
        
        # In V2, response.parsed automatically converts it to a dict/object!
        print(f"‚úÖ Parsed Data: {response.parsed}")
        
        if response.parsed['price'] == 4.5:
            print("üöÄ SYSTEM READY: JSON parsing is working perfectly.")
            
    except Exception as e:
        print(f"‚ùå JSON Failed: {e}")

Key loaded: ...gIDY8

--- TEST 1: Basic Connection ---
‚úÖ Success: Hello V2!

--- TEST 2: JSON Extraction ---
‚úÖ Parsed Data: {'category': 'latte', 'price': 4.5, 'is_hot': True}
üöÄ SYSTEM READY: JSON parsing is working perfectly.
