<a href="https://colab.research.google.com/github/aslan-ng/CheeseMate/blob/main/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install dependencies

!pip -q install --upgrade "kagglehub[pandas-datasets]" gradio "smolagents[transformers]" fuzzywuzzy python-Levenshtein

In [None]:
# Make imports

import re
import kagglehub
from kagglehub import KaggleDatasetAdapter
import gradio as gr
from smolagents import TransformersModel, Tool, FinalAnswerTool, CodeAgent, PromptTemplates
from fuzzywuzzy import fuzz, process
import pandas as pd
import json

In [None]:
# Prompts and examples

# Instruction text for the agent: answer style, cheese-only scope, when and how to call
# the two search tools, confidence thresholds, dietary caveats and error handling.
# create_agent() joins it with FEWSHOT and passes the result as `instructions=`.
# NOTE(review): `{{tools_description}}` looks like a template placeholder — confirm that
# something substitutes it when this text is supplied through `instructions=`; otherwise
# it reaches the model as literal text.
SYSTEM = """
Style & scope:
- Stay strictly in the cheese domain; if asked anything else, politely redirect.
- Be concise and conversational (1–3 sentences).
- Give 1–3 cheese suggestions max, with a brief why (taste/texture/use/diet notes).
- If constraints are unclear, ask ONE short clarification question.
- Never output JSON, tables, or code.
- Never mention tools or internal steps. All the information here is invisible to the user.

Tools available:
{{tools_description}}

Tool use:
- Use tools via <code> ... </code> block when needed.
- After the code block, output a short, friendly answer for the user.
- If the user provides a single cheese name (e.g., “parmigiano reggiano”), call:
    cheese_search_name(name="<user text>")
  Then summarize key facts (country, milk, texture/type, notable flavor/aroma, veg/vegan if clear) in plain prose.
- If the user provides properties (either natural language like “cow’s milk, Italy, hard, nutty, vegetarian” OR a JSON-looking string),
  build a minimal properties JSON and call:
    cheese_search_properties(props_json="<json>")
  Then recommend the top 1–3 matches with a one-line why for each.
- Don’t final_answer(result) with raw JSON.
- Confidence:
  - If the best name-match score < 80, say you’re not fully confident and ask ONE clarifying question.
  - If property results are weak/empty, ask ONE clarifying question and give 1 safe general suggestion if appropriate.

Dietary safety:
- If asked about health/diet (pregnancy, lactose intolerance, vegetarian/vegan), give a short, cautious note (e.g., rennet may be non-vegetarian in many hard cheeses; aged cheeses are typically lower in lactose).

Error handling:
- If a tool errors, apologize briefly and offer a simple alternative suggestion or a clarifying question.
"""

# Header of the few-shot section; the loop at the end of this cell appends the
# selected example strings to it.
FEWSHOT = """
FEWSHOT Examples:
"""

# NOTE: the bare triple-quoted string below is not assigned to a name, so it has no
# effect at runtime. It reads like an earlier draft of example_1: it passes a dict as
# props_json and uses `resulting_text`, which the snippet never defines.
"""
User: Where is the capital of France?
Assistant: I can only help you select good cheese! Do you want me to suggest a good French cheese?
User: Yes
Assistant:
<code>
result = cheese_search_properties(props_json={"country": "France"})
# Now, summarize into 1–3 lines of plain English
final_answer(resulting_text)
</code>
"""
# Few-shot dialogue 1: an off-topic question is redirected to cheese, then a property
# search for country "France" is condensed into at most three one-line bullets.
example_1 = """
User: Where is the capital of France?
Assistant:
<code>
msg = "I can only help you select good cheese! Do you want me to suggest a good French cheese?"
final_answer(msg)
</code>
User: Yes
Assistant:
<code>
result = cheese_search_properties(props_json='{"country":"France"}')

import json
data = json.loads(result)
hits = (data.get("results") or [])[:3]

def one_line(h):
    props = h.get("properties", {})
    name = h.get("cheese", "Unknown")
    milk = props.get("milk") or "unknown milk"
    tex  = props.get("texture") or props.get("type") or ""
    flavor = props.get("flavor") or props.get("aroma") or ""
    line = f"{name} — {milk}; {tex}; {flavor}".strip().strip(";")
    return line

out_lines = [one_line(h) for h in hits]
msg = "Here are a few French options:\\n- " + "\\n- ".join(out_lines) if out_lines else "I couldn’t find solid French matches. Any style you prefer?"
final_answer(msg)
</code>
"""

# NOTE: the bare triple-quoted string below is not assigned to a name, so it has no
# effect at runtime. It reads like an earlier draft of example_2: it passes a dict as
# props_json and uses `resulting_text`, which the snippet never defines.
"""
User: I'm vegan. What do you recommend?
Assistant:
<code>
result = cheese_search_properties(props_json={"vegan": True})
# Now, summarize into 1–3 lines of plain English
final_answer(resulting_text)
</code>
"""
# Few-shot dialogue 2: a vegan request becomes a property search with {"vegan": true},
# summarised as at most three bullets, with a clarifying question as the fallback.
example_2 = """
User: I'm vegan. What do you recommend?
Assistant:
<code>
result = cheese_search_properties(props_json='{"vegan":true}')

import json
data = json.loads(result)
hits = (data.get("results") or [])[:3]

def one_line(h):
    props = h.get("properties", {})
    name = h.get("cheese", "Unknown")
    tex  = props.get("texture") or props.get("type") or ""
    flavor = props.get("flavor") or props.get("aroma") or ""
    return f"{name} — {tex}; {flavor}".strip().strip(";")

out_lines = [one_line(h) for h in hits]
msg = ("Here are vegan cheeses I found:\\n- " + "\\n- ".join(out_lines)) if out_lines else "Vegan cheese options vary—do you want firm slices or something melty?"
final_answer(msg)
</code>
"""

# Few-shot dialogue 3: a direct prose answer without any tool call.
# It is commented out in the `examples` list further down, so it is not part of FEWSHOT.
example_3 = """
User: Low fat, high protein cheeses
Assistant: Cottage cheese is a solid choice—high in protein and available in low-fat varieties.
"""

# NOTE: the bare triple-quoted string below is not assigned to a name, so it has no
# effect at runtime. It reads like an earlier, shorter draft of example_4.
"""
User: Is Parmigiano Reggiano a safe choice for cats?
Assistant:
<code>
result = cheese_search_name(name="Parmigiano Reggiano")
# interpret lactose/salt content in result["properties"]
# Now, summarize into 1–3 lines of plain English and then use final_answer(resulting_text), like:
final_answer("Parmigiano Reggiano is very low in lactose, but can be salty — safe only as a tiny occasional nibble.")
</code>
"""
# Few-shot dialogue 4: a name lookup followed by a cautious pet-safety summary built from
# whatever lactose / salt fields the matched row happens to contain.
# NOTE(review): the snippet imports `re` as well as `json`; make sure the agent's
# authorized imports cover both, or a model imitating this example will hit an import error.
example_4 = """
User: Is Parmigiano Reggiano a safe choice for cats?
Assistant:
<code>
res = cheese_search_name(name="Parmigiano Reggiano")

import json, re
data = json.loads(res)
props = data.get("properties") or {}

def to_pct(x):
    if x is None: return None
    m = re.search(r"[-+]?\\d*\\.?\\d+", str(x))
    return float(m.group(0)) if m else None

lact = to_pct(props.get("lactose") or props.get("lactose_%"))
salt = to_pct(props.get("salt") or props.get("salt_content") or props.get("sodium") or props.get("sodium_%"))

notes = []
notes.append("Aged hard cheeses are typically very low in lactose.")
if salt is None:
    notes.append("Salt level unknown; hard cheeses can be salty.")
elif salt > 2.5:
    notes.append(f"Salt ~{salt:.1f}% (high).")
else:
    notes.append(f"Salt ~{salt:.1f}% (moderate).")

rec = "Okay only as a tiny occasional nibble, not a regular treat."
if salt and salt > 2.5:
    rec = "Best avoided due to high salt."

msg = "Parmigiano Reggiano for cats: " + " ".join(notes) + f" Recommendation: {rec}"
final_answer(msg)
</code>
"""

# Few-shot dialogue 5: a direct two-cheese suggestion without any tool call.
# It is commented out in the `examples` list further down, so it is not part of FEWSHOT.
example_5 = """
User: I’m making pizza—suggest 2 cheeses.
Assistant: Low-moisture mozzarella for stretch and mild melt; add a little provolone for extra flavor without overpowering the sauce.
"""

# Few-shot examples that are active in the prompt. example_3 and example_5 are
# deliberately left out; add them to this list to enable them.
examples = [example_1, example_2, example_4]

# Append every active example, each followed by one blank line, to the FEWSHOT header.
FEWSHOT += "".join(f"{snippet}\n\n" for snippet in examples)

In [None]:
# Create model (Qwen)

def create_model(parameters: float = 0.5):
    """
    Build a Qwen2.5-Coder instruct model wrapped in a smolagents ``TransformersModel``.

    Args:
        parameters: Model size in billions of parameters. It is formatted into the
            model id, e.g. ``0.5`` -> ``Qwen/Qwen2.5-Coder-0.5B-Instruct`` and
            ``7`` or ``7.0`` -> ``Qwen/Qwen2.5-Coder-7B-Instruct``.

    Returns:
        The configured ``TransformersModel`` instance.
    """
    # Use the argument (not a module-level global defined later in the cell), and format
    # with ``g`` so a whole-number float such as 7.0 renders as "7" rather than "7.0".
    model_id = f"Qwen/Qwen2.5-Coder-{float(parameters):g}B-Instruct"
    return TransformersModel(
        model_id=model_id,
        device_map="auto",
        torch_dtype="auto",
        max_new_tokens=256,
        temperature=0.2,
    )


# Model size used for this notebook run.
PARAMETERS_COUNT = 0.5  # Billions
# Instantiate the language model once; the agent and the GUI's clear handler reuse it.
model = create_model(PARAMETERS_COUNT)

In [None]:
# Load the latest version of database

# Load "cheeses.csv" from the Kaggle dataset "umerhaddii/global-cheese-dataset" through
# kagglehub's pandas adapter; `df` is the table both search tools are built on.
df = kagglehub.dataset_load(
  KaggleDatasetAdapter.PANDAS,
  "umerhaddii/global-cheese-dataset",
  "cheeses.csv",
)

#print("First 5 records:", df.head())

In [None]:
# Defining the tool to search the cheese by its name in the database

class CheeseSearchName(Tool):
    """
    Fuzzy match a cheese name against a dataset and return the best match + its properties.

    ``forward`` returns a JSON string with the keys ``query``, ``matched_name``,
    ``score`` (integer) and ``properties`` (the matched row, missing values as null),
    or ``{"error": ...}`` when no usable name was supplied.
    """
    name = "cheese_search_name"
    description = "Match a cheese name with fuzzy search and return the best match and its properties."
    inputs = {
        "name": {"type": "string", "description": "Cheese name to search"},
    }
    output_type = "string"

    def __init__(self, df: pd.DataFrame, name_col: str = "cheese"):
        """
        Args:
            df: Table of cheeses. It is copied; the caller's frame is left untouched.
            name_col: Column of ``df`` that holds the cheese names.

        Raises:
            ValueError: If ``name_col`` is not a column of ``df``.
        """
        super().__init__()
        if name_col not in df.columns:
            raise ValueError(f"Column '{name_col}' not found in DataFrame.")
        self.name_col = name_col
        # Drop missing names BEFORE casting to str: astype(str) would turn NaN into the
        # literal text "nan", which would then be offered as a matchable cheese name.
        cleaned = df.dropna(subset=[name_col]).copy()
        cleaned[name_col] = cleaned[name_col].astype(str)
        self.df = cleaned[cleaned[name_col].str.strip() != ""]
        self.names = self.df[name_col].tolist()
        # Keep index list aligned with names
        self.idx_list = self.df.index.tolist()

    def forward(self, name: str) -> str:
        """
        Look up the closest cheese name.

        Args:
            name: Free-text cheese name; surrounding whitespace is ignored.

        Returns:
            JSON string as described in the class docstring.
        """
        if not isinstance(name, str) or not name.strip():
            return json.dumps({"error": "Name is required"}, ensure_ascii=False)
        query = name.strip()
        # Skip the matcher entirely when there is nothing to match against.
        best = process.extractOne(query, self.names, scorer=fuzz.WRatio) if self.names else None
        if not best:
            return json.dumps({"query": query, "matched_name": None, "score": 0, "properties": {}},
                              ensure_ascii=False)
        # Only the first two items are used, so a (match, score) pair and a longer
        # (match, score, extra) tuple are both accepted.
        best_name, score = best[0], best[1]
        # self.names is positionally aligned with self.df; index() yields the first
        # occurrence, so duplicates resolve to the first matching row.
        row = self.df.iloc[self.names.index(best_name)]
        score = int(round(score))  # scores may be floats; normalize to int
        props = {}
        for key, value in row.to_dict().items():
            if pd.isna(value):
                props[key] = None
            else:
                # numpy scalars are not JSON serializable; unwrap them to Python values
                props[key] = value.item() if hasattr(value, "item") else value
        out = {"query": query, "matched_name": best_name, "score": score, "properties": props}
        return json.dumps(out, ensure_ascii=False)


# Build the name-search tool on the loaded dataset; names are read from the "cheese" column.
cheese_search_name = CheeseSearchName(df, name_col="cheese")
#print(cheese_search_name.forward(name="parmigiano reggiano"))

In [None]:
# Defining the tool to search the cheese by its properties in the database

from typing import Dict, Any, List

def _as_tokens(x: Any) -> List[str]:
    """Split comma/space separated descriptors into lowercase tokens."""
    if x is None or (isinstance(x, float) and pd.isna(x)):
        return []
    s = str(x).lower()
    # split on commas first, then spaces inside each part
    parts = []
    for chunk in s.split(","):
        t = chunk.strip()
        if not t:
            continue
        parts.extend(t.split())
    # dedupe while preserving order
    seen = set(); out=[]
    for t in parts:
        if t and t not in seen:
            seen.add(t); out.append(t)
    return out

def _contains_ci(hay: Any, needle: str) -> bool:
    if hay is None or (isinstance(hay, float) and pd.isna(hay)):
        return False
    return needle.lower() in str(hay).lower()

def _parse_fat_range(cell: Any):
    """
    Parse fat_content strings like '40-46%' or '45%' into a (min,max) tuple of floats.
    Returns (None, None) if unknown.
    """
    if cell is None or (isinstance(cell, float) and pd.isna(cell)):
        return (None, None)
    s = str(cell)
    m = re.findall(r"(\d+(?:\.\d+)?)", s)
    if not m:
        return (None, None)
    nums = list(map(float, m))
    if len(nums) == 1:
        return (nums[0], nums[0])
    return (min(nums), max(nums))

def _range_overlap(r1, r2) -> float:
    """Return overlap length between ranges (a1,a2) and (b1,b2) normalized by r2 width as score in [0,1]."""
    (a1, a2), (b1, b2) = r1, r2
    if None in (a1, a2, b1, b2):
        return 0.0
    if a2 < b1 or b2 < a1:
        return 0.0
    inter = min(a2, b2) - max(a1, b1)
    base = max(1e-9, (b2 - b1))  # normalize by desired range width
    return max(0.0, inter / base)


class CheeseSearchProperties(Tool):
    """
    Find top cheeses that match requested properties from a pandas DataFrame.

    Every row is scored against the query on a 0..100 scale (weighted share of the
    requested properties it satisfies). The best ``top_k`` rows with a positive score are
    returned as a JSON string ``{"results": [{"score", "cheese", "properties"}, ...]}``;
    the list is empty when nothing matches. Malformed input yields ``{"error": ...}``.

    Besides the properties listed in ``description``, the query may carry
    ``fat_content_min`` / ``fat_content_max`` (numbers, percent) instead of ``fat_content``.
    """
    name = "cheese_search_properties"
    description = """
    Lookup best cheeses by requested properties; returns top matches with scores.
    The properties are:
    - milk: the source of milk, e.g., cow, sheep
    - country: the country that is associated with the cheese, e.g., Switzerland, France
    - region: the region of the country that is associated with the cheese, e.g., Burgundy, Savoie
    - family: the family of the cheese
    - type: the type of the cheese, e.g., semi-soft, semi-hard, artisan, brined, soft
    - fat_content: the fat content of the cheese in percentage, in form of range, e.g., 40-46%
    - calcium_content: the calcium content of the cheese
    - texture: the texture of the cheese, e.g., buttery, firm, smooth, dense
    - rind: whether the cheese has a rind in form of washed or natural
    - color: the color of cheese, e.g., ivory, yellow, white
    - flavor: the flavor of cheese, e.g., sweet, burnt caramel, acidic, milky, nutty, fruity
    - aroma: the aroma of cheese, e.g., buttery, lanoline, aromatic, earthy, barnyardy, pungent, perfumed
    - vegetarian: whether the cheese is vegetarian in form of True or False
    - vegan: whether the cheese is vegan in form of True or False
    """
    inputs = {
        "props_json": {"type": "string", "description": "JSON of requested properties."},
    }
    output_type = "string"

    def __init__(self, df: pd.DataFrame, name_col: str = "cheese", top_k: int = 3):
        """
        Args:
            df: Table of cheeses. It is copied; helper columns prefixed with ``__`` are
                added to the copy only.
            name_col: Column of ``df`` that holds the cheese names.
            top_k: Maximum number of matches returned by ``forward``.

        Raises:
            ValueError: If ``name_col`` is not a column of ``df``.
        """
        super().__init__()
        self.df = df.copy()
        if name_col not in self.df.columns:
            raise ValueError(f"Column '{name_col}' not in DataFrame")
        self.name_col = name_col
        self.top_k = top_k

        # Pre-tokenize common text fields for faster scoring
        self._tok_cols = ["type","texture","flavor","aroma"]
        for c in self._tok_cols:
            if c in self.df.columns:
                self.df[f"__tok_{c}"] = self.df[c].apply(_as_tokens)
            else:
                self.df[f"__tok_{c}"] = [[] for _ in range(len(self.df))]

        # Cache parsed fat ranges. Built from a plain list so an empty frame works too
        # (unpacking zip(*...) of an empty sequence would fail).
        if "fat_content" in self.df.columns:
            fat_ranges = [_parse_fat_range(cell) for cell in self.df["fat_content"]]
        else:
            fat_ranges = [(None, None)] * len(self.df)
        self.df["__fat_min"] = [low for low, _ in fat_ranges]
        self.df["__fat_max"] = [high for _, high in fat_ranges]

    def _as_bool(self, value: Any):
        """Map a bool-like value (bool, numpy bool, 'true'/'false' text, 0/1) to True/False; None if unclear."""
        if value is None:
            return None
        if isinstance(value, str):
            text = value.strip().lower()
            if text in ("true", "yes", "y", "1"):
                return True
            if text in ("false", "no", "n", "0"):
                return False
            return None
        if isinstance(value, float) and pd.isna(value):
            return None
        try:
            return bool(value)
        except (TypeError, ValueError):
            return None

    def _to_float(self, value: Any):
        """Convert a number or numeric text (optionally ending in '%') to float; None if not possible."""
        if value is None or isinstance(value, bool):
            return None
        if isinstance(value, str):
            value = value.strip().rstrip("%").strip()
        try:
            number = float(value)
        except (TypeError, ValueError):
            return None
        return None if pd.isna(number) else number

    def _desired_fat_range(self, q: Dict[str, Any]):
        """Return the requested (min, max) fat range, or None when the query has no usable fat constraint."""
        low = self._to_float(q.get("fat_content_min"))
        high = self._to_float(q.get("fat_content_max"))
        if low is None and high is None and q.get("fat_content") is not None:
            # Honor the advertised `fat_content` property, e.g. "40-46%".
            low, high = _parse_fat_range(q.get("fat_content"))
        if low is None and high is None:
            return None
        # A one-sided bound leaves the other side open on the 0..100 percent scale.
        if low is None:
            low = min(0.0, high)
        if high is None:
            high = max(100.0, low)
        return (min(low, high), max(low, high))

    def _query_tokens(self, value: Any) -> List[str]:
        """Tokenize a requested descriptor given as text or as a list of texts."""
        if not value:
            return []
        if isinstance(value, (list, tuple, set)):
            return [token for item in value for token in _as_tokens(item)]
        return _as_tokens(value)

    def _json_value(self, value: Any):
        """Make a scalar cell JSON friendly: missing -> None, numpy scalar -> Python scalar."""
        if pd.isna(value):
            return None
        return value.item() if hasattr(value, "item") else value

    def _score_row(self, row: pd.Series, q: Dict[str, Any]) -> float:
        """
        Score one row against the query.

        Returns:
            100 * (satisfied weight / requested weight); 0.0 when the query names no
            property this method knows.
        """
        score = 0.0
        weight = 0.0

        # Exact-ish fields: case-insensitive substring match; a list matches if any item does
        for field, w in (("milk",10), ("country",8), ("region",5), ("family",6), ("rind",3), ("color",2)):
            wanted = q.get(field)
            if not wanted:
                continue
            weight += w
            options = wanted if isinstance(wanted, (list, tuple, set)) else [wanted]
            if any(_contains_ci(row.get(field), str(option)) for option in options):
                score += w

        # Boolean fields: both sides are normalized to Python bools before comparing,
        # because an identity check fails for numpy bools and for "true"/"false" text
        for field, w in (("vegetarian",6), ("vegan",6),):
            wanted = self._as_bool(q.get(field))
            if wanted is None:
                continue
            weight += w
            if self._as_bool(row.get(field)) == wanted:
                score += w

        # Token overlap fields: partial credit for the fraction of requested tokens present
        for field, w in (("type",8), ("texture",6), ("flavor",6), ("aroma",4)):
            want_tokens = set(self._query_tokens(q.get(field)))
            if not want_tokens:
                continue
            weight += w
            have_tokens = set(row.get(f"__tok_{field}", []))
            score += w * len(want_tokens & have_tokens) / len(want_tokens)

        # Fat content range (optional)
        want_range = self._desired_fat_range(q)
        if want_range is not None:
            have_range = (row.get("__fat_min"), row.get("__fat_max"))
            weight += 6
            score += 6 * _range_overlap(have_range, want_range)

        # Normalize to 0..100
        if weight == 0:
            return 0.0
        return 100.0 * (score / weight)

    def forward(self, props_json: str) -> str:
        """
        Rank the dataset against the requested properties.

        Args:
            props_json: JSON object (as text) of requested properties; a dict is accepted too.

        Returns:
            JSON string with ``results`` (possibly empty) or ``error``.
        """
        # Parse query
        try:
            if isinstance(props_json, dict):
                q = props_json  # already a dict
            elif isinstance(props_json, str):
                q = json.loads(props_json or "{}")
            else:
                return json.dumps({"error": f"Invalid input type {type(props_json)}; must be str or dict"},
                                  ensure_ascii=False)
        except (TypeError, ValueError):  # json.JSONDecodeError is a ValueError
            return json.dumps({"error": "Invalid JSON in props_json"}, ensure_ascii=False)
        if not isinstance(q, dict):
            # Valid JSON that is not an object (a list, a number, ...) cannot be scored.
            return json.dumps({"error": "props_json must be a JSON object of properties"}, ensure_ascii=False)
        if self.df.empty:
            return json.dumps({"results": []}, ensure_ascii=False)

        # Score all rows; rank by position with a stable sort so ties keep dataset order
        scores = self.df.apply(lambda r: self._score_row(r, q), axis=1).astype(float)
        ranked = scores.reset_index(drop=True).sort_values(ascending=False, kind="mergesort")
        results = []
        for position, value in ranked.items():
            # Rows that satisfy nothing are not matches; stop at the first of them.
            if len(results) >= self.top_k or not value > 0:
                break
            row = self.df.iloc[position]
            props = {k: self._json_value(v) for k, v in row.to_dict().items() if not str(k).startswith("__")}
            results.append({
                "score": round(float(value), 2),
                "cheese": self._json_value(row[self.name_col]),
                "properties": props
            })
        return json.dumps({"results": results}, ensure_ascii=False)

# Build the property-search tool on the loaded dataset.
cheese_search_properties = CheeseSearchProperties(df, name_col="cheese")
# Sample query for manual testing; only the commented-out print below refers to it.
query = {
    "milk": "cow",
    "country": "Italy",
    "family": "Parmesan",
    "type": ["hard", "artisan"],
    "texture": ["dense"],
    "vegetarian": False,
    "fat_content_min": 30,
    "fat_content_max": 50
}
#print(cheese_search_properties.forward(props_json=json.dumps(query)))

In [None]:
# Create agent

def create_agent(model):
    """
    Assemble the CheeseMate ``CodeAgent``.

    Args:
        model: The language model the agent should drive.

    Returns:
        A ``CodeAgent`` equipped with the final-answer tool and the two cheese search
        tools, instructed with ``SYSTEM`` followed by ``FEWSHOT``, limited to one step.
    """
    return CodeAgent(
        tools=[
            FinalAnswerTool(),
            cheese_search_name,
            cheese_search_properties,
        ],
        model=model,
        instructions=SYSTEM + "\n\n" + FEWSHOT,
        add_base_tools=False,
        max_steps=1,
        # The few-shot examples import both `json` and `re`; a model imitating them
        # must be allowed to import both.
        additional_authorized_imports=["json", "re"],
    )

# Shared agent instance used by the GUI; the GUI's clear handler replaces it with a fresh one.
agent = create_agent(model)

#agent.run('Resolve this: "parmigiano reggiano"')
#agent.run('Where is the capital of Italy?')
#agent.run('Which cheese should I give to my cats? Why?')

In [None]:
# Create GUI
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("## 🧀 CheeseMate\nI only help you select good cheese. Ask away!")

    chat = gr.Chatbot(height=150, type="messages")  # type='messages' keeps roles tidy
    txt = gr.Textbox(placeholder="Type your message about cheese…", autofocus=True)
    clear = gr.Button("Clear")

    def user_submit(user_message, history):
        """
        Handle one chat turn: append the user's message and the agent's reply to the history.

        Args:
            user_message: Text from the textbox.
            history: List of dicts [{"role":"user"/"assistant","content":...}, ...] or None.

        Returns:
            Updates for the chatbot (new history) and the textbox (cleared).
        """
        history = list(history or [])
        message = (user_message or "").strip()
        if not message:
            # Nothing to send: leave the conversation as it is.
            return gr.update(value=history), gr.update(value="")
        history.append({"role": "user", "content": message})
        try:
            # reset=False keeps the agent's memory between turns, so a follow-up such as
            # "Yes" is answered in context; the Clear button starts a new conversation.
            bot_reply = agent.run(message, reset=False)
        except Exception as exc:  # keep the chat usable whatever the model or a tool raises
            print(f"CheeseMate agent error: {exc!r}")
            bot_reply = "Sorry, I ran into a problem with that one. Could you rephrase your cheese question?"
        # The chatbot expects text; the agent's final answer is not guaranteed to be a str.
        history.append({"role": "assistant", "content": str(bot_reply)})
        return gr.update(value=history), gr.update(value="")

    def clear_fn():
        """Start a new conversation: swap in a fresh agent and empty the chat and textbox."""
        # NOTE(review): `agent` is a single module-level instance, so its memory is shared
        # by everyone using the app — confirm that this is acceptable with share=True.
        global agent
        agent = create_agent(model)
        return [], ""

    txt.submit(user_submit, [txt, chat], [chat, txt])
    clear.click(clear_fn, [], [chat, txt])

demo.launch(share=True)