In [None]:
# | default_exp semantic_search


In [None]:

# | export
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams, PointStruct, Distance
from datetime import datetime
from typing import Dict, List, Optional
from uuid import uuid4



In [None]:

# | export
class KeywordVectorDB:
    """Store and query per-page keyword embeddings in a Qdrant instance.

    Every page URL is mapped to its own Qdrant collection. Each point in a
    collection holds the embedding of one keyword plus a payload with its
    performance figures (``clicks``, ``impressions``, ``position``, ``ctr``)
    and two flags (``in_content``, ``is_important``).
    """

    # Number of points requested per page when reading a whole collection.
    _SCROLL_PAGE_SIZE = 1000

    def __init__(
        self,
        model_name: str = "Omartificial-Intelligence-Space/Arabic-Triplet-Matryoshka-V2",
        device: str = "cuda",
        host: str = "localhost",
        port: int = 6333,
    ):
        """Load the embedding model and connect to Qdrant.

        Args:
            model_name: Name of the SentenceTransformer model to load.
            device: Device passed to SentenceTransformer (default ``"cuda"``).
            host: Qdrant location (default ``"localhost"``).
            port: Qdrant port (default ``6333``).
        """
        self.model = SentenceTransformer(model_name, device=device)
        self.client = QdrantClient(host, port=port)

    def _get_collection_name(self, page_url: str) -> str:
        """Convert a page URL into a collection name.

        The ``https://`` / ``http://`` scheme is removed, then every ``/`` and
        ``.`` is replaced by ``_``.
        """
        without_scheme = page_url.replace("https://", "").replace("http://", "")
        return without_scheme.replace("/", "_").replace(".", "_")

    def _scroll_all(self, collection_name: str) -> List:
        """Return every point of a collection, following scroll pagination."""
        records = []
        offset = None
        while True:
            page, offset = self.client.scroll(
                collection_name=collection_name,
                limit=self._SCROLL_PAGE_SIZE,
                offset=offset,
            )
            records.extend(page)
            # A missing continuation offset (or an empty page) means we are done.
            if offset is None or not page:
                return records

    def _find_keyword_point(self, collection_name: str, keyword: str):
        """Return the first point whose payload ``keyword`` matches, or None."""
        matches = self.client.scroll(
            collection_name=collection_name,
            scroll_filter={"must": [{"key": "keyword", "match": {"value": keyword}}]},
            limit=1,
        )[0]
        return matches[0] if matches else None

    def get_or_create_collection(self, page_url: str) -> str:
        """Return the collection name for a page, creating the collection if needed.

        A new collection uses cosine distance and the embedding dimension
        reported by the loaded model.
        """
        collection_name = self._get_collection_name(page_url)
        existing = self.client.get_collections().collections

        if not any(c.name == collection_name for c in existing):
            self.client.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(
                    size=self.model.get_sentence_embedding_dimension(),
                    distance=Distance.COSINE,
                ),
            )
        return collection_name

    def store_keywords(self, page_url: str, keywords_data: List[Dict]):
        """Embed keywords and store them with their performance data.

        Args:
            page_url: Page whose collection receives the keywords.
            keywords_data: Dicts with a required ``"keyword"`` entry and
                optional ``"clicks"``, ``"impressions"``, ``"position"`` and
                ``"ctr"`` entries (each defaults to 0).

        Every stored point gets a freshly generated UUID, so storing the same
        keyword twice adds a second point rather than replacing the first.
        New points start with ``in_content`` and ``is_important`` set to False.
        """
        collection_name = self.get_or_create_collection(page_url)
        if not keywords_data:
            return

        keywords = [kw_data["keyword"] for kw_data in keywords_data]
        # Encode the whole list in one call instead of one call per keyword.
        vectors = self.model.encode(keywords)

        points = [
            PointStruct(
                id=str(uuid4()),
                vector=vector.tolist(),
                payload={
                    "keyword": keyword,
                    "clicks": kw_data.get("clicks", 0),
                    "impressions": kw_data.get("impressions", 0),
                    "position": kw_data.get("position", 0),
                    "ctr": kw_data.get("ctr", 0),
                    "in_content": False,
                    "is_important": False,
                    "last_updated": datetime.now().isoformat(),
                },
            )
            for kw_data, keyword, vector in zip(keywords_data, keywords, vectors)
        ]
        self.client.upsert(collection_name=collection_name, points=points)

    def search_keywords(
        self, page_url: str, query_text: str, limit: int = 10
    ) -> List[Dict]:
        """Return the stored keywords most similar to ``query_text``.

        Each result dict contains ``"keyword"``, ``"score"`` and every other
        payload field of the matching point. At most ``limit`` results are
        requested.
        """
        collection_name = self._get_collection_name(page_url)
        query_vector = self.model.encode(query_text)

        results = self.client.search(
            collection_name=collection_name,
            query_vector=query_vector.tolist(),
            limit=limit,
        )

        return [
            {"keyword": r.payload["keyword"], "score": r.score, **r.payload}
            for r in results
        ]

    def get_keywords(
        self, page_url: str, min_clicks: Optional[int] = None
    ) -> List[Dict]:
        """Return the payloads of all keywords stored for a page.

        Args:
            page_url: Page whose keywords are read.
            min_clicks: When given, keep only keywords whose ``clicks`` value
                is at least this number; a missing or empty ``clicks`` value
                counts as 0.
        """
        collection_name = self._get_collection_name(page_url)
        records = self._scroll_all(collection_name)

        keywords = [{"keyword": r.payload["keyword"], **r.payload} for r in records]

        if min_clicks is not None:
            keywords = [k for k in keywords if (k.get("clicks") or 0) >= min_clicks]

        return keywords

    def update_content_status(self, page_url: str, content: str):
        """Set ``in_content`` on every keyword according to whether it occurs in ``content``.

        The check is a case-insensitive substring match. ``last_updated`` is
        refreshed on every point that is touched.
        """
        collection_name = self._get_collection_name(page_url)
        haystack = content.lower()

        present_ids = []
        absent_ids = []
        for point in self._scroll_all(collection_name):
            found = point.payload["keyword"].lower() in haystack
            (present_ids if found else absent_ids).append(point.id)

        timestamp = datetime.now().isoformat()
        # One request per flag value instead of one request per point.
        for flag, ids in ((True, present_ids), (False, absent_ids)):
            if ids:
                self.client.set_payload(
                    collection_name=collection_name,
                    payload={"in_content": flag, "last_updated": timestamp},
                    points=ids,
                )

    def set_keyword_importance(
        self, page_url: str, keyword: str, is_important: bool
    ) -> bool:
        """Set the ``is_important`` flag of a keyword.

        Returns:
            True when a matching point was found and updated, False otherwise.
        """
        collection_name = self._get_collection_name(page_url)
        point = self._find_keyword_point(collection_name, keyword)
        if point is None:
            return False

        self.client.set_payload(
            collection_name=collection_name,
            payload={
                "is_important": is_important,
                "last_updated": datetime.now().isoformat(),
            },
            points=[point.id],
        )
        return True

    def get_keyword_importance(self, page_url: str, keyword: str) -> Optional[bool]:
        """Return the ``is_important`` flag of a keyword, or None if it is not stored."""
        collection_name = self._get_collection_name(page_url)
        point = self._find_keyword_point(collection_name, keyword)
        return point.payload.get("is_important") if point is not None else None
