In [None]:
import os
import uuid

import pandas as pd
from dotenv import load_dotenv
from groq import Groq
from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    FieldCondition,
    Filter,
    MatchValue,
    PointStruct,
    VectorParams,
)
from sentence_transformers import SentenceTransformer

# =========================
# SETUP
# =========================
# Load variables from a local .env file (if present) into the process
# environment so that GROQ_API_KEY can be read below.
load_dotenv()
# Sentence-embedding model shared by ingestion and by every query in this
# file; the Qdrant collections must use a vector size matching its output.
encoder = SentenceTransformer("all-MiniLM-L6-v2")
# ":memory:" selects qdrant-client's local in-memory mode, so the collections
# start empty on every fresh run and have to be re-ingested.
qdrant = QdrantClient(":memory:")

# Groq chat client used to generate answers. os.getenv returns None when the
# variable is unset.
llm = Groq(api_key=os.getenv("GROQ_API_KEY"))

# Collection holding article-text embeddings, keyed by compound name.
TEXT_COLLECTION = "study_text"
# Collection holding image-caption embeddings plus the image URL.
IMAGE_COLLECTION = "study_images"

# =========================
# COLLECTIONS
# =========================

# Take the vector size from the encoder itself instead of hard-coding 384, so
# swapping the embedding model cannot silently create mismatched collections.
VECTOR_SIZE = encoder.get_sentence_embedding_dimension()

# Collections that already exist (e.g. when this cell is re-run in the same
# session) must not be created a second time.
existing = {c.name for c in qdrant.get_collections().collections}

for collection_name in (TEXT_COLLECTION, IMAGE_COLLECTION):
    if collection_name in existing:
        continue
    qdrant.create_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(
            size=VECTOR_SIZE,
            distance=Distance.COSINE,
        ),
    )

print("✅ All collections ready")

# =========================
# IMAGE INGESTION
# =========================

df_img = pd.read_csv("query.csv")

# pandas loads empty CSV cells as NaN (a float). A NaN label would crash on
# .lower(), and a NaN image URL is useless in an image index, so rows missing
# either field are dropped before encoding.
df_img = df_img.dropna(subset=["compoundLabel", "image"])

img_labels = df_img["compoundLabel"].astype(str).tolist()
img_urls = df_img["image"].tolist()

# Encode every caption in a single batched call rather than one model
# invocation per row.
img_captions = [f"chemical structure of {label}" for label in img_labels]
img_vectors = encoder.encode(img_captions) if img_captions else []

img_points = [
    PointStruct(
        id=str(uuid.uuid4()),
        vector=vec.tolist(),
        payload={
            # Names are stored lower-cased so lookups are case-insensitive.
            "compound_name": label.lower(),
            "image": url,
        },
    )
    for label, url, vec in zip(img_labels, img_urls, img_vectors)
]

# Skip the upsert entirely when the CSV yielded no usable rows.
if img_points:
    qdrant.upsert(IMAGE_COLLECTION, img_points)
print(f"✅ {len(img_points)} images ingested")

# =========================
# TEXT INGESTION
# =========================

df_text = pd.read_csv("chemistry3.csv")

# pandas loads empty CSV cells as NaN, and NaN is truthy, so a plain
# `not value` check never filters them out. Drop incomplete rows explicitly
# so they cannot reach the encoder or .lower().
df_text = df_text.dropna(subset=["compoundLabel", "article"])

txt_labels = df_text["compoundLabel"].astype(str).tolist()
txt_articles = df_text["article"].astype(str).tolist()

# Encode all articles in a single batched call rather than one model
# invocation per row.
txt_vectors = encoder.encode(txt_articles) if txt_articles else []

txt_points = [
    PointStruct(
        id=str(uuid.uuid4()),
        vector=vec.tolist(),
        payload={
            # Names are stored lower-cased so lookups are case-insensitive.
            "compound_name": label.lower(),
            "features": article,
        },
    )
    for label, article, vec in zip(txt_labels, txt_articles, txt_vectors)
]

# Skip the upsert entirely when the CSV yielded no usable rows.
if txt_points:
    qdrant.upsert(TEXT_COLLECTION, txt_points)
print(f"✅ {len(txt_points)} text entries ingested")

# =========================
# MEMORY
# =========================

# Rolling history of recent user queries, oldest entry first.
USER_MEMORY = []
# Maximum number of queries kept in the history.
MAX_MEMORY = 5


def store_user_memory(q):
    """Record a query in the rolling history.

    The query is lower-cased and stripped before it is stored. A query equal
    to the most recent entry is ignored, and once the history exceeds
    ``MAX_MEMORY`` entries the oldest one is discarded.
    """
    normalized = q.lower().strip()
    is_repeat = bool(USER_MEMORY) and USER_MEMORY[-1] == normalized
    if not is_repeat:
        USER_MEMORY.append(normalized)
        if len(USER_MEMORY) > MAX_MEMORY:
            del USER_MEMORY[0]


def get_user_memory():
    """Return a new list of the stored queries, most recent first."""
    return list(reversed(USER_MEMORY))

# =========================
# SEARCH FUNCTIONS
# =========================

def search_by_text(query, top_k=1):
    """Find the text entries closest to a free-text query.

    The query is embedded with the module-level ``encoder`` and searched
    against ``TEXT_COLLECTION``.

    Args:
        query: Text to embed and search for.
        top_k: Maximum number of hits to request.

    Returns:
        A list of dicts with keys ``compound``, ``text`` and ``score``, in
        the order Qdrant returned the hits.
    """
    query_vector = encoder.encode(query).tolist()
    response = qdrant.query_points(
        TEXT_COLLECTION,
        query=query_vector,
        limit=top_k,
        with_payload=True,
    )

    results = []
    for point in response.points:
        payload = point.payload
        results.append({
            "compound": payload["compound_name"],
            "text": payload["features"],
            "score": point.score,
        })
    return results


def retrieve_images(query, top_k=3):
    """Return the best image hit per compound for a free-text query.

    The query is embedded with the module-level ``encoder`` and searched
    against ``IMAGE_COLLECTION``. Three times ``top_k`` hits are requested so
    that several images of the same compound can be collapsed while still
    leaving room for up to ``top_k`` distinct compounds.

    Args:
        query: Text to embed and search for.
        top_k: Maximum number of distinct compounds to return.

    Returns:
        A list of at most ``top_k`` dicts with keys ``compound``, ``image``
        and ``score`` (rounded to 3 decimals), ordered by first appearance
        in the hit list. Empty when ``top_k`` is not positive.
    """
    if top_k <= 0:
        # Nothing was asked for; skip the encode and the round trip to Qdrant.
        return []

    hits = qdrant.query_points(
        IMAGE_COLLECTION,
        query=encoder.encode(query).tolist(),
        limit=top_k * 3,
        with_payload=True,
    ).points

    # Keep the highest-scoring hit per compound. Raw scores are compared:
    # comparing against an already-rounded score could let a slightly
    # lower-scoring duplicate replace the best hit.
    best_by_name = {}
    for hit in hits:
        name = hit.payload["compound_name"]
        current = best_by_name.get(name)
        if current is None or hit.score > current.score:
            best_by_name[name] = hit

    # Rounding is for presentation only and happens once, on output.
    results = [
        {
            "compound": name,
            "image": hit.payload["image"],
            "score": round(hit.score, 3),
        }
        for name, hit in best_by_name.items()
    ]
    return results[:top_k]


def filter_image_text_intersection(images):
    """Keep only image hits whose compound also has a text entry.

    For each image hit, ``TEXT_COLLECTION`` is searched with the embedding of
    the compound name, restricted by a payload filter to entries whose
    ``compound_name`` equals that compound. Without the filter the nearest
    article of any compound would be attached under this compound's name,
    and every image would pass.

    Args:
        images: Iterable of dicts with keys ``compound``, ``image`` and
            ``score`` (the shape returned by ``retrieve_images``).

    Returns:
        A list of dicts with keys ``compound``, ``image``, ``image_score``,
        ``text`` and ``text_score``, one per image whose compound has a
        matching text entry, in input order.
    """
    valid = []
    for img in images:
        name = img["compound"]
        hits = qdrant.query_points(
            TEXT_COLLECTION,
            query=encoder.encode(name).tolist(),
            # Only accept text stored for this exact compound.
            query_filter=Filter(
                must=[
                    FieldCondition(
                        key="compound_name",
                        match=MatchValue(value=name),
                    )
                ]
            ),
            limit=1,
            with_payload=True,
        ).points

        if not hits:
            # No article for this compound: not part of the intersection.
            continue

        best = hits[0]
        valid.append({
            "compound": name,
            "image": img["image"],
            "image_score": img["score"],
            "text": best.payload["features"],
            "text_score": best.score,
        })
    return valid


def choose_best_compound(valid):
    """Pick the candidate with the highest combined image and text score.

    Args:
        valid: Non-empty iterable of dicts carrying ``image_score`` and
            ``text_score``.

    Returns:
        The candidate whose score sum is largest; the earliest one on a tie.

    Raises:
        ValueError: If ``valid`` is empty.
    """
    def combined_score(candidate):
        return candidate["image_score"] + candidate["text_score"]

    return max(valid, key=combined_score)


def recommend_related_compounds(compound, top_k=3):
    """Suggest other compounds whose text is closest to the given compound.

    The compound name is embedded with the module-level ``encoder`` and
    searched against ``TEXT_COLLECTION``. ``top_k + 3`` hits are requested to
    leave room for the compound itself and for duplicate names.

    Args:
        compound: Name of the compound to find neighbours for.
        top_k: Maximum number of distinct related compound names to return.

    Returns:
        Up to ``top_k`` distinct compound names in hit order, never including
        ``compound`` itself. Empty when ``top_k`` is not positive.
    """
    if top_k <= 0:
        # Without this guard the `== top_k` stop condition below is never
        # met for 0 and results would be returned anyway.
        return []

    hits = qdrant.query_points(
        TEXT_COLLECTION,
        query=encoder.encode(compound).tolist(),
        limit=top_k + 3,
        with_payload=True,
    ).points

    # Stored names are lower-cased at ingestion, so the name to exclude is
    # lower-cased too. Seeding `seen` with it excludes the compound itself.
    seen = {compound.lower()}
    out = []
    for hit in hits:
        name = hit.payload["compound_name"]
        if name in seen:
            continue
        seen.add(name)
        out.append(name)
        if len(out) == top_k:
            break
    return out

# =========================
# RAG
# =========================

def rag_answer(query, chosen, memory, model="llama-3.3-70b-versatile"):
    """Answer a question with the LLM, grounded in one compound's description.

    Args:
        query: The user's question.
        chosen: Dict with at least ``compound`` (name) and ``text``
            (description) for the compound to answer from.
        memory: Iterable of earlier user queries, included as context.
        model: Groq chat model to use. Defaults to the model this file has
            always used, so existing callers are unaffected.

    Returns:
        The model's reply as a string; an empty string if the completion
        carries no content, so callers can safely concatenate onto it.
    """
    # str() each item so a non-string history entry cannot break the join.
    history = ", ".join(str(item) for item in memory)
    prompt = f"""
You are a chemistry assistant.

Answer ONLY using the information below.

Compound: {chosen['compound']}
Description:
{chosen['text']}

User History:
{history}

Question:
{query}
"""
    res = llm.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
    )
    return res.choices[0].message.content or ""

# =========================
# RUN QUERY (FIXED)
# =========================

query = "what is ibuprofen?"

store_user_memory(f"user asked about: {query}")
memory = get_user_memory()

print("🧠 MEMORY USED:", memory)

# Resolve the free-text question to a known compound name first. Searching
# the image collection with the raw question would match on the wording of
# the question rather than on the compound it is about.
text_hits = search_by_text(query, top_k=1)
resolved_compound = text_hits[0]["compound"] if text_hits else query

images = retrieve_images(resolved_compound, top_k=3)
valid = filter_image_text_intersection(images)

if valid:
    chosen = choose_best_compound(valid)
    recs = recommend_related_compounds(chosen["compound"])
    # Guard against a None reply so the concatenation below cannot fail.
    answer = rag_answer(query, chosen, memory) or ""

    # Only add the section when there is something to recommend; otherwise
    # the header would be printed with an empty list under it.
    if recs:
        topics = "".join(f"- {r}\n" for r in recs)
        answer += "\n\n📌 Recommended Next Topics:\n" + topics
else:
    # No compound has both an image and a text entry: fall back to an
    # ungrounded general explanation from the model.
    answer = llm.chat.completions.create(
        model="llama-3.3-70b-versatile",
        messages=[{
            "role": "user",
            "content": f"Explain {query} in general chemistry terms.",
        }],
    ).choices[0].message.content

print("\n🧠 FINAL ANSWER:\n", answer)

print("\n🖼️ TOP IMAGE MATCHES:")
for img in images:
    print(f"- {img['compound']} | score={img['score']}")
    print(f"  Image: {img['image']}")


✅ All collections ready
✅ 986 images ingested
✅ 818 text entries ingested
🧠 MEMORY USED: ['user asked about: what are alcohols?']

🧠 FINAL ANSWER:
 Alcohols are a class of organic compounds in which the hydroxyl (-OH) functional group is bonded to a carbon atom. They can be categorized into different types, such as monohydric, dihydric, and polyhydric alcohols, based on the number of hydroxyl groups present.

📌 Recommended Next Topics:
- (2e)-geranic acid
- nervonic acid
- isonicotinic acid


🖼️ TOP IMAGE MATCHES:
- neranic acid | score=0.779
  Image: http://commons.wikimedia.org/wiki/Special:FilePath/Nerolic%20acid.svg
- (2e)-geranic acid | score=0.633
  Image: http://commons.wikimedia.org/wiki/Special:FilePath/Geranic%20acid.png
- nervonic acid | score=0.493
  Image: http://commons.wikimedia.org/wiki/Special:FilePath/Nervonic%20acid%20v2.svg
