In [None]:
# ============================================================
# JEWELLERY MULTIMODAL SEARCH BACKEND (FASTAPI)
# ============================================================

In [None]:
# ============================================================
# IMPORTS
# ============================================================

import json
import os
import re
from typing import List, Dict

import torch
import clip
import numpy as np
import chromadb
from chromadb.config import Settings

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field

In [None]:
# ============================================================
# CONFIG
# ============================================================

# Root of the backend checkout. Set the JEWELLERY_BACKEND_DIR environment
# variable to run on another machine; when it is unset the original
# developer path is used, so existing setups keep working unchanged.
BASE_DIR = os.environ.get(
    "JEWELLERY_BACKEND_DIR",
    "/home/akash/Jewellary_RAG/backend",
)

# Everything below is derived from BASE_DIR.
CHROMA_PATH = os.path.join(BASE_DIR, "chroma")   # <- chroma_primary
DATA_DIR = os.path.join(BASE_DIR, "data", "tanishq")
IMAGE_DIR = os.path.join(DATA_DIR, "images")
BLIP_CAPTIONS_PATH = os.path.join(DATA_DIR, "blip_captions.json")

# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# ============================================================
# LOAD MODELS (ONCE)
# ============================================================

# The model is loaded once at module level so every request reuses it.
print("🔹 Loading CLIP model...")
# clip.load returns (model, preprocess). The image preprocess transform is
# discarded: the visible code only ever calls clip_model.encode_text.
clip_model, _ = clip.load("ViT-B/16", device=DEVICE)
clip_model.eval()  # switch the module to evaluation mode for inference

In [None]:
# Captions are looked up by image id elsewhere in the file
# (BLIP_CAPTIONS.get(image_id, "")), so the JSON is expected to be an object.
print("🔹 Loading BLIP captions...")
# Pin the encoding: without it open() uses the platform default, which can
# mis-decode non-ASCII caption text on systems whose default is not UTF-8.
with open(BLIP_CAPTIONS_PATH, "r", encoding="utf-8") as f:
    BLIP_CAPTIONS = json.load(f)

In [None]:
# ============================================================
# LOAD CHROMA (PERSISTED DB)
# ============================================================

print("🔹 Connecting to Chroma DB...")

# A bare Client(Settings(persist_directory=...)) does not by itself select a
# persistent backend: Chroma 0.4+ needs PersistentClient (or
# is_persistent=True) and Chroma 0.3.x needs chroma_db_impl="duckdb+parquet".
# Pick whichever form the installed release provides so the on-disk
# collections are actually opened.
if hasattr(chromadb, "PersistentClient"):
    chroma_client = chromadb.PersistentClient(
        path=CHROMA_PATH,
        settings=Settings(anonymized_telemetry=False),
    )
else:
    chroma_client = chromadb.Client(
        Settings(
            chroma_db_impl="duckdb+parquet",
            persist_directory=CHROMA_PATH,
            anonymized_telemetry=False
        )
    )

# get_collection (not get_or_create) fails loudly if the index was never
# built, instead of silently serving an empty collection.
image_collection = chroma_client.get_collection("jewelry_images")
metadata_collection = chroma_client.get_collection("jewelry_metadata")

print(
    "✅ Chroma loaded | Images:",
    image_collection.count(),
    "| Metadata:",
    metadata_collection.count()
)

In [None]:
# ============================================================
# FASTAPI APP
# ============================================================

# Application object that the route decorators further down attach to.
app = FastAPI(title="Jewellery Multimodal Search")

# CORS is left fully open: every origin, method and header is accepted.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

In [None]:
# ============================================================
# REQUEST / RESPONSE SCHEMAS
# ============================================================

class TextSearchRequest(BaseModel):
    """Body of POST /search/text."""

    # Free-text query; it is CLIP-encoded and scanned for attribute keywords.
    query: str
    # Number of results to return. Constrained to >= 1 so a negative value is
    # rejected by validation instead of slicing the ranking from the end.
    top_k: int = Field(5, ge=1)


class SimilarSearchRequest(BaseModel):
    """Body of POST /search/similar."""

    # Id of the indexed image whose neighbours are requested.
    image_id: str
    # Number of neighbours to return. Constrained to >= 1 so the vector query
    # is never issued with a non-positive n_results.
    top_k: int = Field(5, ge=1)

In [None]:
# ============================================================
# CLIP QUERY ENCODING (TEXT ONLY)
# ============================================================

def encode_text_clip(text: str) -> np.ndarray:
    """Encode one text query with CLIP and return its unit-length embedding.

    Args:
        text: Query string. Text longer than CLIP's context length is
            truncated rather than rejected.

    Returns:
        A 1-D NumPy array: the text embedding divided by its L2 norm.
    """
    # Without truncate=True, clip.tokenize raises RuntimeError for inputs
    # longer than the context length, which would turn a long user query
    # into a server error.
    tokens = clip.tokenize([text], truncate=True).to(DEVICE)
    with torch.no_grad():
        emb = clip_model.encode_text(tokens)
        emb = emb / emb.norm(dim=-1, keepdim=True)
    # The batch holds a single query; return its row.
    return emb.cpu().numpy()[0]

In [None]:
# ============================================================
# INTENT & ATTRIBUTE DETECTION (LIGHT, SAFE)
# ============================================================

def detect_intent_and_attributes(query: str) -> Dict:
    """Extract coarse jewellery attributes from a free-text query.

    Each attribute is set to the first of its candidate values found in the
    query (case-insensitive). A keyword only counts when it starts a word, so
    "rings" and "golden" still match but "earrings" or "string" no longer
    register as category "ring".

    Args:
        query: Raw user query.

    Returns:
        {"intent": "search", "attributes": {...}} where attributes may hold
        "category", "metal" and "primary_stone".
    """
    # (attribute, candidate values in priority order) - earlier values win.
    vocabulary = (
        ("category", ("necklace", "ring")),
        ("metal", ("gold", "silver")),
        ("primary_stone", ("pearl", "diamond")),
    )

    q = query.lower()
    attrs = {}

    for attribute, values in vocabulary:
        for value in values:
            # The lookbehind rejects matches embedded in a longer word.
            if re.search(r"(?<![a-z])" + re.escape(value), q):
                attrs[attribute] = value
                break

    return {
        "intent": "search",
        "attributes": attrs
    }

In [None]:
# ============================================================
# VISUAL RETRIEVAL (NO LANGCHAIN)
# ============================================================

def retrieve_visual_candidates(query_text: str, k: int = 100):
    """Return the nearest indexed images for a text query.

    Args:
        query_text: Query to encode with CLIP.
        k: Maximum number of candidates. It is capped at the number of
            indexed images.

    Returns:
        A list of {"image_id", "visual_score"} dicts in the order Chroma
        returns them. "visual_score" is the distance reported by Chroma, so
        smaller means closer. Empty when k <= 0 or the collection is empty.
    """
    # Some Chroma releases raise when n_results exceeds the collection size,
    # so never ask for more rows than exist.
    n_results = min(k, image_collection.count())
    if n_results <= 0:
        return []

    q_emb = encode_text_clip(query_text)

    res = image_collection.query(
        # Pass a plain list of Python floats rather than a NumPy array whose
        # dtype depends on the device CLIP ran on.
        query_embeddings=[np.asarray(q_emb, dtype=np.float32).tolist()],
        n_results=n_results
    )

    return [
        {
            "image_id": img_id,
            "visual_score": dist
        }
        for img_id, dist in zip(res["ids"][0], res["distances"][0])
    ]

In [None]:
# ============================================================
# METADATA SCORING REFINEMENTS
# ============================================================

def adaptive_alpha(query_attrs: Dict) -> float:
    """Return the metadata weight: 0.1 plus 0.1 for each query attribute."""
    base_weight = 0.1
    per_attribute = 0.1
    return base_weight + per_attribute * len(query_attrs)


def refined_metadata_adjustment(meta: Dict, query_attrs: Dict) -> float:
    """Score how well an item's metadata agrees with the requested attributes.

    For each requested attribute, the stored "confidence_<attr>" value
    (0.0 when absent) is added on a match. On a mismatch, 0.3 times that
    confidence is subtracted, but only when the confidence exceeds 0.6.
    """
    total = 0.0

    for name, wanted in query_attrs.items():
        confidence = meta.get(f"confidence_{name}", 0.0)

        if meta.get(name) == wanted:
            total += confidence
            continue

        # Mismatch: penalise only when the stored label is confident.
        if confidence > 0.6:
            total -= 0.3 * confidence

    return total


def apply_metadata_boost(candidates: List[Dict], query_attrs: Dict):
    """Re-rank visual candidates using their stored metadata.

    Args:
        candidates: Dicts with "image_id" and "visual_score" (a distance,
            smaller is better).
        query_attrs: Attribute name -> requested value.

    Returns:
        A new list of {"image_id", "visual_score", "metadata_boost",
        "final_score"} dicts sorted by ascending final_score, where
        final_score = visual_score - alpha * metadata_boost. Candidates
        without a metadata row get a boost of 0.0.
    """
    if not candidates:
        return []

    alpha = adaptive_alpha(query_attrs)

    # Fetch all metadata in one call instead of one round-trip per candidate.
    # dict.fromkeys drops repeated ids while keeping first-seen order.
    wanted_ids = list(dict.fromkeys(c["image_id"] for c in candidates))
    fetched = metadata_collection.get(
        ids=wanted_ids,
        include=["metadatas"]
    )
    # Key by the ids Chroma returns rather than by position, so the result
    # does not depend on row order or on every requested id being present.
    meta_by_id = dict(zip(fetched["ids"], fetched["metadatas"]))

    ranked = []

    for c in candidates:
        meta = meta_by_id.get(c["image_id"]) or {}

        adjust = refined_metadata_adjustment(meta, query_attrs)
        # Scores are distances, so a positive adjustment is subtracted.
        final_score = c["visual_score"] - alpha * adjust

        ranked.append({
            "image_id": c["image_id"],
            "visual_score": c["visual_score"],
            "metadata_boost": adjust,
            "final_score": final_score
        })

    return sorted(ranked, key=lambda x: x["final_score"])

In [None]:
# ============================================================
# LLM-FREE EXPLANATION (SAFE, GROUNDED)
# ============================================================

def explain_match(image_id: str, query_attrs: Dict) -> str:
    """Build a short template-based explanation for one result.

    Args:
        image_id: Id of the recommended image.
        query_attrs: Attribute name -> requested value.

    Returns:
        A sentence naming the matched attribute values, or a generic
        visual-similarity sentence when nothing matches or the image has no
        metadata row.
    """
    fetched = metadata_collection.get(
        ids=[image_id],
        include=["metadatas"]
    )["metadatas"]
    # Tolerate a missing row or a null metadata entry instead of raising.
    meta = (fetched[0] if fetched else None) or {}

    # str() keeps the join below safe for non-string metadata values.
    reasons = [
        str(v)
        for k, v in query_attrs.items()
        if k in meta and meta[k] == v
    ]

    if not reasons:
        return "Recommended due to strong visual similarity."

    # "unknown" is the placeholder value the similar-search endpoint already
    # filters out; treat it like a missing category.
    category = meta.get("category")
    if category in (None, "", "unknown"):
        category = "piece of jewellery"

    return (
        f"Recommended because it visually resembles a {category} "
        f"and matches attributes such as {', '.join(reasons)}."
    )

In [None]:
# ============================================================
# API ENDPOINTS
# ============================================================

@app.post("/search/text")
def search_text(req: TextSearchRequest):
    """Text search: CLIP retrieval followed by metadata re-ranking.

    Returns a dict with the original "query", the detected attributes under
    "intent", and up to top_k "results", each carrying an explanation and
    its visual / metadata / final scores.
    """
    attrs = detect_intent_and_attributes(req.query)["attributes"]

    # A negative top_k would otherwise slice from the end of the ranking.
    limit = max(req.top_k, 0)
    # Re-rank over at least 100 candidates, and more when top_k asks for
    # more, so a large top_k is not silently capped at 100.
    pool_size = max(100, limit)

    candidates = retrieve_visual_candidates(req.query, k=pool_size)
    ranked = apply_metadata_boost(candidates, attrs)[:limit]

    results = [
        {
            "image_id": r["image_id"],
            "explanation": explain_match(r["image_id"], attrs),
            "scores": {
                "visual": r["visual_score"],
                "metadata": r["metadata_boost"],
                "final": r["final_score"]
            }
        }
        for r in ranked
    ]

    return {
        "query": req.query,
        "intent": attrs,
        "results": results
    }

In [None]:
@app.post("/search/similar")
def search_similar(req: SimilarSearchRequest):
    """Find images similar to an already-indexed image.

    The base image's stored embedding is used as the query vector and its
    own metadata (category, metal, primary_stone) as the attributes for
    re-ranking. The base image itself is excluded from the results.

    Raises:
        HTTPException: 404 when image_id is not in the image collection.
    """
    # A negative top_k would otherwise slice from the end of the ranking.
    limit = max(req.top_k, 0)

    base = image_collection.get(
        ids=[req.image_id],
        include=["embeddings"]
    )
    # Check the returned ids, not the embeddings: the embeddings may come
    # back as a NumPy array, whose truth value is ambiguous.
    if not base["ids"]:
        raise HTTPException(status_code=404, detail="Image not found")
    base_embedding = base["embeddings"][0]

    res = image_collection.query(
        query_embeddings=[base_embedding],
        # One extra row because the base image is its own nearest neighbour.
        # Capped at the collection size, which some Chroma releases require.
        n_results=min(limit + 1, image_collection.count())
    )

    base_metas = metadata_collection.get(
        ids=[req.image_id],
        include=["metadatas"]
    )["metadatas"]
    # Tolerate a missing row or a null metadata entry instead of raising.
    base_meta = (base_metas[0] if base_metas else None) or {}

    # Skip attributes that are absent as well as those marked "unknown";
    # indexing an absent key previously raised KeyError.
    attrs = {
        k: base_meta[k]
        for k in ["category", "metal", "primary_stone"]
        if base_meta.get(k) not in (None, "unknown")
    }

    candidates = [
        {
            "image_id": img_id,
            "visual_score": dist
        }
        for img_id, dist in zip(res["ids"][0], res["distances"][0])
        if img_id != req.image_id
    ]

    ranked = apply_metadata_boost(candidates, attrs)[:limit]

    results = [
        {
            "image_id": r["image_id"],
            "explanation": explain_match(r["image_id"], attrs),
            "scores": {
                "visual": r["visual_score"],
                "metadata": r["metadata_boost"],
                "final": r["final_score"]
            }
        }
        for r in ranked
    ]

    return {
        "base_image": req.image_id,
        "results": results
    }

In [None]:
@app.get("/image/{image_id}")
def get_image(image_id: str):
    """Serve an image file from IMAGE_DIR.

    Raises:
        HTTPException: 404 when the id does not name a regular file inside
            IMAGE_DIR.
    """
    root = os.path.abspath(IMAGE_DIR)
    # abspath collapses "." and ".." segments, so the containment check
    # below sees the location the request would really read.
    path = os.path.abspath(os.path.join(root, image_id))

    # image_id comes from the URL: refuse anything that escapes IMAGE_DIR.
    # isfile (rather than exists) also rejects directories, which
    # FileResponse cannot serve.
    inside_root = path.startswith(root + os.sep)
    if not inside_root or not os.path.isfile(path):
        raise HTTPException(status_code=404, detail="Image not found")
    return FileResponse(path)