In [None]:
import datetime
from pathlib import Path
import json
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from typing import Any, Union, Optional, Callable
from pydantic import SecretStr
from urllib.error import HTTPError
import time
from Bio import Entrez
from tqdm.auto import tqdm
import feedparser
import fitz
import requests
import hashlib
from newspaper import Article
from bs4 import BeautifulSoup
from langchain_core.documents import Document
from qdrant_client import QdrantClient, models
from qdrant_client.http.models import Distance, VectorParams, SparseVectorParams
from langchain_qdrant import QdrantVectorStore, FastEmbedSparse, RetrievalMode
# import os
import ast

In [None]:
class DataExtractionModule():
    """
    Collects recent biomedical documents from PubMed Central (PMC), arXiv, ClinicalTrials.gov,
    and biotech news RSS feeds. The rest of this docstring describes the PMC path (`fetch_pmc`).

    `fetch_pmc` retrieves metadata and full-text content for research papers indexed in PMC,
    within the time window specified by the module's `lookup_window_size`. It handles batching,
    rate limiting, and error retries. Articles that have already been ingested (based on the
    provided `ingested_logs`) are skipped.

    For each fetched article, the following information is extracted and returned as a dictionary:
        - doc_id: Unique identifier for the document (PMC ID).
        - source: Source label ("PubMed Central").
        - doc_type: Document type ("research_paper").
        - title: Cleaned article title.
        - abstract: Cleaned abstract text.
        - body_text: Cleaned body text (or abstract if body is missing).
        - published_date: ISO 8601 formatted publication date, if available.
        - published_date_ts: Publication timestamp (float), or 0.0 if unavailable.
        - url: Direct URL to the article on PMC.

    Returns:
        List[dict[str, Any]]: A list of dictionaries, each representing a parsed PMC research article.

    Notes:
        - Articles outside the specified date window or already ingested are excluded.
        - Handles HTTP errors, rate limiting (429), and retries failed fetches up to a maximum.
        - Abstracts and bodies are recursively extracted and HTML-cleaned.
        - Uses the NCBI Entrez API with the provided API key.
        - Progress is displayed using tqdm.
    """
    
    def __init__(
        self,
        ncbi_api_key: Union[SecretStr, str],
        ingested_logs: set[str],
        lookup_window_size: int = 90 # (in days)
    ):
        """
        Store the credentials, the look-back window and a shared HTTP session.

        Args:
            ncbi_api_key: NCBI API key, either as a plain string or already wrapped in
                a `SecretStr`. Plain strings are wrapped on assignment.
            ingested_logs: Document ids that were ingested earlier; the set is stored
                by reference (not copied) and consulted by the fetchers to skip documents.
            lookup_window_size: Size of the look-back window in days.
        """
        # Normalize the key so the attribute always holds a SecretStr.
        self.ncbi_api_key = (
            SecretStr(ncbi_api_key) if isinstance(ncbi_api_key, str) else ncbi_api_key
        )

        self.lookup_window_size = lookup_window_size
        # Midnight (naive) of the day `lookup_window_size` days before today.
        self.cutoff = self._days_ago(self.lookup_window_size)
        self.ingested: set[str] = ingested_logs

        # One session is reused for all HTTP requests; it identifies itself with a
        # desktop-browser User-Agent.
        browser_user_agent = (
            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
            "AppleWebKit/537.36 (KHTML, like Gecko) "
            "Chrome/114.0.0.0 Safari/537.36"
        )
        http_session = requests.Session()
        http_session.headers.update({"User-Agent": browser_user_agent})
        self.session = http_session
        
    def fetch_pmc(self) -> list[dict[str, Any]]:
        """
        Fetch recent PubMed Central (PMC) articles and convert them to document dicts.

        Searches PMC through the NCBI Entrez API for records whose publication date lies
        between `self.cutoff` and today (UTC), downloads the full-text XML in batches and
        builds one dictionary per article that is not already listed in `self.ingested`.

        Returns:
            list[dict[str, Any]]: One dict per article with the keys `doc_id`, `source`,
            `doc_type`, `title`, `abstract`, `body_text`, `published_date`,
            `published_date_ts` and `url`. An empty list is returned when the search
            fails or matches nothing.

        Notes:
            - A 429 response pauses for 60 seconds before the next attempt; any other
              HTTP error skips the batch; other failures are retried after 3 seconds.
            - Articles that fail to parse are reported and skipped.
        """
        RETMAX: int = 2000 # currently small value for prototyping
        BATCH_SIZE: int = 400
        THROTTLE_SECONDS: float = 6.0
        MAX_RETRIES: int = 3
        # E-utilities documents YYYY/MM/DD (not ISO hyphens) for mindate/maxdate.
        ENTREZ_DATE_FORMAT: str = "%Y/%m/%d"

        def _as_list(value: Any) -> list[Any]:
            """Wrap a single parsed element in a list; pass lists through unchanged."""
            return value if isinstance(value, list) else [value]

        def _get_pmcid(article_id_list: list[Any]) -> Optional[str]:
            """Return the first id whose `pub-id-type` attribute is 'pmcid', else None."""
            for elem in article_id_list:
                id_type = getattr(elem, "attributes", {}).get("pub-id-type", "")
                if id_type == 'pmcid':
                    return str(elem)
            return None

        def _paragraph_texts(elem: dict) -> list[str]:
            """Return the stripped string paragraphs stored under the element's "p" key."""
            p_value = elem.get("p", [])
            if isinstance(p_value, str):
                return [p_value.strip()]
            if isinstance(p_value, list):
                return [p.strip() for p in p_value if isinstance(p, str)]
            return []

        def _extract_abstract(dict_elements: Union[dict, list[dict]], max_depth: Optional[int] = None, _depth: int = 0) -> str:
            """Collect paragraph text (recursing into "sec" children) as one cleaned line."""
            paragraphs: list[str] = []

            for elem in _as_list(dict_elements):
                if not isinstance(elem, dict):
                    continue

                paragraphs.extend(_paragraph_texts(elem))

                if max_depth is None or _depth < max_depth:
                    # A lone "sec" may arrive as a dict; iterating it directly would walk
                    # its keys and silently drop the nested text.
                    for sec in _as_list(elem.get("sec", [])):
                        nested = _extract_abstract(sec, max_depth, _depth + 1)
                        if nested:
                            paragraphs.append(nested)

            raw_abstract = "\n\n".join(p for p in paragraphs if p)
            cleaned = BeautifulSoup(raw_abstract, "html.parser").get_text(separator=" ")
            return " ".join(cleaned.split())

        def _parse_pub_date(pub_date_elements, preferred_types: tuple = ('epub','ppub','collection')) -> Union[datetime.datetime, None]:
            """
            Pick the best publication date among the `pub-date` elements.

            Earlier entries of `preferred_types` win; ties go to the more precise date.
            NOTE(review): assumes each element iterates as numeric parts ordered
            day, month, year (or month, year / year) - confirm against the parsed XML.
            """
            best = None

            for elem in pub_date_elements or []:
                attrs = getattr(elem, 'attributes', None)
                pub_type = ''
                if isinstance(attrs, dict):
                    pub_type = attrs.get('pub-type','').lower()
                if pub_type not in preferred_types:
                    continue

                try:
                    parts = [int(x) for x in list(elem)]
                except Exception:
                    continue

                if len(parts) == 3:
                    day, month, year = parts
                elif len(parts) == 2:
                    day = 1
                    month, year = parts
                elif len(parts) == 1:
                    day = 1
                    month = 1
                    year = parts[0]
                else:
                    continue

                try:
                    dt = datetime.datetime(year, month, day)
                except ValueError:
                    continue

                idx = preferred_types.index(pub_type)
                precision = len(parts)

                if (best is None or idx < best[1] or (idx == best[1] and precision > best[2])):
                    best = (dt, idx, precision)

            return best[0] if best else None

        def _extract_body_text(elems: Union[dict, list[dict]], max_depth: Optional[int] = None, _depth: int = 0) -> str:
            """Collect paragraph text (recursing into "sec" children), blank-line separated."""
            paragraphs: list[str] = []

            for elem in _as_list(elems):
                if not isinstance(elem, dict):
                    continue

                paragraphs.extend(_paragraph_texts(elem))

                if max_depth is None or _depth < max_depth:
                    for child_sec in _as_list(elem.get("sec", [])):
                        extracted = _extract_body_text(child_sec, max_depth, _depth + 1)
                        if extracted:
                            paragraphs.append(extracted)

            return "\n\n".join([para for para in paragraphs if para])

        docs: list[dict[str, Any]] = []
        # NOTE(review): the contact e-mail is hard-coded; consider making it configurable.
        Entrez.email = 'vivalaraza234@gmail.com'
        Entrez.api_key = self.ncbi_api_key.get_secret_value()

        today = datetime.datetime.now(datetime.UTC).date()

        try:
            # mindate and maxdate must be supplied together (with a datetype) for the
            # date restriction to be applied by E-utilities.
            search_handle = Entrez.esearch(
                db="pmc",
                term="all[sb]",
                datetype="pdat",
                mindate=self.cutoff.strftime(ENTREZ_DATE_FORMAT),
                maxdate=today.strftime(ENTREZ_DATE_FORMAT),
                retmax=RETMAX
            )
            try:
                record = Entrez.read(search_handle, validate=False)
            finally:
                search_handle.close()
        except Exception as e:
            print(f"[ERROR] Entrez.esearch failed: {e}")
            return []

        id_list = record.get("IdList", [])
        total_ids = len(id_list)
        if total_ids == 0:
            return []

        pbar = tqdm(total=total_ids, desc="PMC docs processed", unit="id")

        try:
            for i in range(0, total_ids, BATCH_SIZE):
                batch_ids = id_list[i: i + BATCH_SIZE]
                articles = []
                fetched = False

                for attempt in range(MAX_RETRIES):
                    try:
                        fetch_handle = Entrez.efetch(
                            db="pmc",
                            id=",".join(batch_ids),
                            rettype="full",
                            retmode="xml"
                        )
                        try:
                            articles = Entrez.read(fetch_handle, validate=False)
                        finally:
                            fetch_handle.close()
                        fetched = True
                        break
                    except HTTPError as e:
                        if e.code == 429:
                            print("[WARN] Rate limit hit (429). Sleeping 60s before retry…")
                            time.sleep(60)
                        else:
                            print(f"[ERROR] Entrez.efetch HTTPError: {e}. Skipping this batch.")
                            break
                    except Exception as e:
                        print(f"[WARN] Entrez.efetch failed (attempt {attempt+1}): {e}. Retrying in 3s…")
                        time.sleep(3)
                else:
                    print("[ERROR] All retries for this batch failed. Skipping batch.")

                if not fetched:
                    # Account for the whole skipped batch so the bar still reaches its total.
                    pbar.update(len(batch_ids))
                    time.sleep(THROTTLE_SECONDS)
                    continue

                for art in articles:
                    # Bound before the try so the error message below can never hit an
                    # unbound (or stale) name.
                    pmc_id: Optional[str] = None
                    try:
                        front = art.get("front", {}).get("article-meta", {})
                        pmc_id = _get_pmcid(front.get("article-id", []))
                        if not pmc_id:
                            continue

                        doc_id = f"PMC_{pmc_id}"
                        if doc_id in self.ingested:
                            continue

                        pubdate = _parse_pub_date(front.get("pub-date", []))

                        raw_title = front.get("title-group", {}).get("article-title", "")
                        raw_abstract = _extract_abstract(front.get("abstract", []), max_depth=1)
                        raw_body = _extract_body_text(art.get("body", {}), max_depth=100)

                        title = self._clean_html(raw_title)
                        abstract = self._clean_html(raw_abstract)
                        # Articles without a body fall back to the abstract.
                        body = self._clean_html(raw_body) or abstract

                        docs.append({
                            "doc_id": doc_id,
                            "source": "PubMed Central",
                            "doc_type": "research_paper",
                            "title": title,
                            "abstract": abstract,
                            "body_text": body,
                            "published_date": None if not pubdate else pubdate.isoformat(),
                            "published_date_ts": 0.0 if not pubdate else float(pubdate.timestamp()),
                            "url": f"https://pmc.ncbi.nlm.nih.gov/articles/{pmc_id}/"
                        })

                    except Exception as e:
                        label = f"PMC_{pmc_id}" if pmc_id else "<unknown id>"
                        print(f"[ERROR] Failed parsing article {label}: {e}. Skipping it.")
                    finally:
                        # Every examined article counts as processed, kept or not.
                        pbar.update(1)

                time.sleep(THROTTLE_SECONDS)
        finally:
            pbar.close()

        return docs
    
    def fetch_arxiv(self, categories: Optional[list[str]] = None, per_category_pdf_cap: int = 300) -> list[dict[str, Any]]:
        """
        Fetch recent arXiv papers per category and extract their PDF text.

        For each category the arXiv API is paged newest-first until an entry older than
        `self.cutoff` is reached, the feed runs out, or `per_category_pdf_cap` documents
        were kept. Entries already in `self.ingested`, entries already collected during
        this call (cross-listed papers) and entries without extractable PDF text are skipped.

        Args:
            categories: arXiv category queries; defaults to
                ["q-bio.*", "physics.bio-ph", "physics.med-ph"] when None or empty.
            per_category_pdf_cap: Maximum number of documents kept per category.

        Returns:
            list[dict[str, Any]]: One dict per paper with the keys `doc_id`, `source`,
            `doc_type`, `title`, `abstract`, `body_text`, `published_date`,
            `published_date_ts`, `url` and `metadata` (authors, category).
        """
        categories = categories or ["q-bio.*", "physics.bio-ph", "physics.med-ph"] # excluded {"cs.CV", "cs.LG", "cs.AI"} for prototype
        BASE = "http://export.arxiv.org/api/query?"
        ITEMS_PER_PAGE = 200
        MAX_RETRIES = 3
        THROTTLE_SECONDS = 3 # arXiv has 1 req/3s rate limit

        def _extract_pdf_text(pdf_url: str) -> str:
            """Download a PDF and return its page texts joined by blank lines ('' on failure)."""
            try:
                response = self.session.get(pdf_url, timeout=20)
                response.raise_for_status()
                with fitz.open(stream=response.content, filetype="pdf") as doc:
                    return "\n\n".join(p.get_text() for p in doc)
            except Exception as e:
                print(f"[WARN] PDF parse failed ({pdf_url}): {e}")
                return ""

        all_docs: list[dict[str, Any]] = []
        # Papers cross-listed in several categories must be downloaded and emitted once.
        seen_ids: set[str] = set()

        for cat in categories:
            kept: list[dict[str, Any]] = []
            start = 0

            pbar = tqdm(total=per_category_pdf_cap, desc=f"arXiv {cat:<8} processed", unit="pdf")

            try:
                while len(kept) < per_category_pdf_cap:
                    url = (
                        f"{BASE}search_query=cat:{cat}"
                        f"&sortBy=submittedDate&sortOrder=descending"
                        f"&start={start}&max_results={ITEMS_PER_PAGE}"
                    )

                    feed = None
                    for attempt in range(MAX_RETRIES):
                        try:
                            response = self.session.get(url, timeout=10)
                            # Without this an HTTP error page parses as an empty feed
                            # and the retry loop is never used.
                            response.raise_for_status()
                            feed = feedparser.parse(response.text)
                            break
                        except Exception as e:
                            print(f"[WARN] arXiv fetch fail (try {attempt+1}): {e}")
                            time.sleep(THROTTLE_SECONDS)
                    if not feed or not feed.entries:
                        time.sleep(THROTTLE_SECONDS)
                        break

                    stop_early = False
                    for entry in feed.entries:
                        published = entry.get("published_parsed")
                        if not published:
                            continue
                        # feedparser normalizes parsed dates to UTC; `self.cutoff` is a
                        # naive datetime built from a UTC date, so the comparison is naive.
                        pub_dt = datetime.datetime(*published[:6])
                        if pub_dt < self.cutoff:
                            # Results are sorted newest-first, so everything after is older.
                            stop_early = True
                            break

                        entry_id = entry.get("id")
                        if not entry_id:
                            continue
                        arxiv_id = entry_id.rsplit("/", 1)[-1]
                        doc_id = f"ARXIV_{arxiv_id}"
                        if doc_id in self.ingested or doc_id in seen_ids:
                            continue

                        pdf_url = next(
                            (
                                lnk.get("href")
                                for lnk in entry.get("links", [])
                                if lnk.get("rel") == "related" and lnk.get("type") == "application/pdf"
                            ),
                            None
                        )
                        if not pdf_url:
                            print(f"[WARN] No PDF URL found for arXiv entry with id {arxiv_id}, skipping.")
                            continue

                        body = _extract_pdf_text(pdf_url)
                        if not body.strip():
                            # A failed download still hit the server; keep the rate limit.
                            time.sleep(THROTTLE_SECONDS)
                            continue

                        kept.append({
                            "doc_id": doc_id,
                            "source": "arXiv",
                            "doc_type": "research_paper",
                            "title": entry.get("title", ""),
                            "abstract": entry.get("summary", ""),
                            "body_text": body,
                            "published_date": pub_dt.isoformat(),
                            # Attach UTC explicitly: a naive `.timestamp()` would interpret
                            # the UTC wall time in the machine's local zone.
                            "published_date_ts": float(pub_dt.replace(tzinfo=datetime.UTC).timestamp()),
                            "url": entry.get("link", ""),
                            "metadata": {
                                "authors": [a.get("name", "") for a in entry.get("authors", [])],
                                "category": cat
                            },
                        })
                        seen_ids.add(doc_id)
                        pbar.update(1)

                        if len(kept) >= per_category_pdf_cap:
                            break
                        time.sleep(THROTTLE_SECONDS)

                    if (
                        stop_early
                        or len(kept) >= per_category_pdf_cap
                        or len(feed.entries) < ITEMS_PER_PAGE
                    ):
                        break
                    start += ITEMS_PER_PAGE
                    # Pages whose entries were all skipped would otherwise be requested
                    # back-to-back.
                    time.sleep(THROTTLE_SECONDS)
            finally:
                pbar.close()

            all_docs.extend(kept)

        return all_docs
    
    def fetch_clinical_trials(self) -> list[dict[str, Any]]:
        """
        Fetch recently updated studies from the ClinicalTrials.gov v2 API.

        Pages through studies whose `LastUpdatePostDate` lies between `self.cutoff` and
        today (UTC) and converts each study that is not in `self.ingested` and has a
        non-empty description or summary into a document dict. At most `RETMAX`
        documents are collected.

        Returns:
            list[dict[str, Any]]: One dict per study with the keys `doc_id`, `source`,
            `doc_type`, `title`, `abstract`, `body_text`, `published_date`,
            `published_date_ts`, `url` and `metadata`. `published_date` is None and
            `published_date_ts` is 0.0 when the study carries no parseable date.
        """
        BATCH_SIZE = 1000  # max allowed is 1000; currently small value for prototype
        THROTTLE_SECONDS = 1
        MAX_RETRIES = 3
        RETMAX = 10000
        STUDIES_URL = "https://clinicaltrials.gov/api/v2/studies"

        def _parse_date(text: str) -> Optional[datetime.datetime]:
            """Parse an ISO date; also accept month- or year-precision values."""
            if not text:
                return None
            try:
                return datetime.datetime.fromisoformat(text)
            except Exception:
                pass
            # The API may report reduced-precision dates such as "2024-05" or "2024",
            # which `fromisoformat` rejects.
            for reduced_format in ("%Y-%m", "%Y"):
                try:
                    return datetime.datetime.strptime(text, reduced_format)
                except Exception:
                    continue
            return None

        kept_docs = []
        next_page_token = None

        pbar = tqdm(total=RETMAX, desc="CT.gov trials processed", unit="trial")

        today = datetime.datetime.now(datetime.UTC).date()
        expr = f"AREA[LastUpdatePostDate]RANGE[{self.cutoff.date()},{today}]"

        try:
            while len(kept_docs) < RETMAX:
                # Passing `params` lets requests URL-encode the expression and page token.
                params = {
                    "query.term": expr,
                    "format": "json",
                    "pageSize": BATCH_SIZE,
                }
                if next_page_token:
                    params["pageToken"] = next_page_token

                data = None
                for attempt in range(MAX_RETRIES):
                    try:
                        response = self.session.get(STUDIES_URL, params=params, timeout=10)
                        response.raise_for_status()
                        data = response.json()
                        break
                    except Exception as e:
                        page_info = f"page {next_page_token}" if next_page_token else "first page"
                        print(f"[WARN] CT.gov fetch fail on {page_info} (try {attempt+1}): {e}")
                        time.sleep(THROTTLE_SECONDS * 2)
                if data is None:
                    break

                studies = data.get("studies", [])
                if not studies:
                    break

                for st in studies:
                    if len(kept_docs) >= RETMAX:
                        break

                    # Bound before the try so the warning below can always name the study.
                    nct = None
                    try:
                        protocol = st.get("protocolSection", {})
                        ident_mod = protocol.get("identificationModule", {})
                        nct = ident_mod.get("nctId")
                        if not nct:
                            print("[WARN] Skipping study with missing nctId")
                            continue
                        doc_id = f"CT_{nct}"
                        if doc_id in self.ingested:
                            continue

                        title = ident_mod.get("officialTitle", "") or ident_mod.get("briefTitle", "")

                        description_mod = protocol.get("descriptionModule", {})
                        summary = self._clean_html(description_mod.get("briefSummary", ""))
                        details = self._clean_html(description_mod.get("detailedDescription", ""))
                        # Prefer the detailed description; fall back to the brief summary.
                        body_txt = (details or summary).strip()
                        if not body_txt:
                            continue

                        status_mod = protocol.get("statusModule", {})
                        start_date = _parse_date(status_mod.get("startDateStruct", {}).get("date", ""))
                        first_posted = _parse_date(status_mod.get("studyFirstPostDateStruct", {}).get("date", ""))
                        last_updated = _parse_date(status_mod.get("lastUpdatePostDateStruct", {}).get("date", ""))

                        if last_updated and last_updated.date() < self.cutoff.date():
                            continue

                        design_mod = protocol.get("designModule", {})
                        enrollment_info = design_mod.get("enrollmentInfo", {})
                        eligibility_mod = protocol.get("eligibilityModule", {})
                        arms_mod = protocol.get("armsInterventionsModule", {})
                        outcomes_mod = protocol.get("outcomesModule", {})
                        sponsor_mod = protocol.get("sponsorCollaboratorsModule", {})
                        locations_mod = protocol.get("contactsLocationsModule", {})

                        metadata = {
                            "status": status_mod.get("overallStatus", ""),
                            "phase": design_mod.get("phases", []),
                            "study_type": design_mod.get("studyType", ""),
                            "enrollment": {
                                "count": enrollment_info.get("count", None),
                                "type": enrollment_info.get("type", ""),
                            },
                            "eligibility": {
                                "gender": eligibility_mod.get("sex", ""),
                                "min_age": eligibility_mod.get("minimumAge", ""),
                                "max_age": eligibility_mod.get("maximumAge", ""),
                                "healthy_vols": eligibility_mod.get("healthyVolunteers", False),
                                "criteria": eligibility_mod.get("eligibilityCriteria", ""),
                            },
                            "conditions": protocol.get("conditionsModule", {}).get("conditions", []),
                            "interventions": [
                                f"{i.get('type', '')}:{i.get('name', '')}"
                                for i in arms_mod.get("interventions", [])
                            ],
                            "arms": [
                                {
                                    "label": arm.get("label", ""),
                                    "type": arm.get("type", ""),
                                    "description": arm.get("description", "")
                                } for arm in arms_mod.get("armGroups", [])
                            ],
                            "outcomes": {
                                "primary": [
                                    {
                                        "measure": o.get("measure", ""),
                                        "description": o.get("description", "")
                                    } for o in outcomes_mod.get("primaryOutcomes", [])
                                ],
                                "secondary": [
                                    {
                                        "measure": o.get("measure", ""),
                                        "description": o.get("description", "")
                                    } for o in outcomes_mod.get("secondaryOutcomes", [])
                                ]
                            },
                            "sponsors": [sponsor_mod.get("leadSponsor", {}).get("name", "")],
                            "collaborators": sponsor_mod.get("collaborators", []),
                            "locations": [
                                {
                                    "facility": loc.get("facility", ""),
                                    "city": loc.get("city", ""),
                                    "state": loc.get("state", ""),
                                    "country": loc.get("country", "")
                                } for loc in locations_mod.get("locations", [])
                            ]
                        }

                        # Most recent meaningful date first: update, first post, then start.
                        published_date = last_updated or first_posted or start_date

                        kept_docs.append({
                            "doc_id": doc_id,
                            "source": "ClinicalTrials.gov",
                            "doc_type": "clinical_trial",
                            "title": title,
                            "abstract": summary,
                            "body_text": body_txt,
                            # Guarded like the timestamp below; an undated study used to
                            # raise here and be dropped.
                            "published_date": None if not published_date else published_date.isoformat(),
                            "published_date_ts": 0.0 if not published_date else float(published_date.timestamp()),
                            "url": f"https://clinicaltrials.gov/study/{nct}",
                            "metadata": metadata
                        })
                        pbar.update(1)

                    except Exception as e:
                        print(f"[WARN] Failed processing study {nct or 'unknown'}: {e}")
                        continue

                next_page_token = data.get("nextPageToken")
                if not next_page_token:
                    break
                time.sleep(THROTTLE_SECONDS)
        finally:
            pbar.close()

        return kept_docs
    
    def fetch_news(self) -> list[dict[str, Any]]:
        """
        Fetch recent articles from a fixed set of biotech news RSS feeds.

        Each feed is downloaded (with retries), entries dated before `self.cutoff` or
        already present in `self.ingested` are skipped, and the full article text is
        extracted from the entry link. When extraction yields no text, the cleaned feed
        summary is used for both the abstract and the body.

        Returns:
            list[dict[str, Any]]: One dict per article with the keys `doc_id`, `source`,
            `doc_type`, `title`, `abstract`, `body_text`, `published_date` (date only),
            `published_date_ts`, `url` and `metadata`.
        """
        RETMAX_PER_FEED = 1000
        THROTTLE_SECONDS = 1
        MAX_RETRIES = 3

        FEEDS = {
            "Labiotech": "https://www.labiotech.eu/feed/",
            "BioPharma Dive": "https://www.biopharmadive.com/feeds/news/",
            "STAT News": "https://www.statnews.com/feed/",
            "GEN News": "https://www.genengnews.com/feed/",
            "Nature Biotechnology": "https://www.nature.com/nbt.rss"
        }

        def _hash(text: str) -> str:
            """Return a short, stable hex digest used only to build the document id."""
            return hashlib.sha1(text.encode()).hexdigest()[:20]

        def _get_entry_datetime(entry) -> Optional[datetime.datetime]:
            """Return the entry's published (or updated) time as a naive datetime, or None."""
            dt_struct = getattr(entry, "published_parsed", None) or getattr(entry, "updated_parsed", None)
            if dt_struct:
                return datetime.datetime(*dt_struct[:6])
            return None

        def _looks_like_html_error(content: str) -> bool:
            """Return True when the response starts with an `<html` tag instead of a feed."""
            lowered = content.lower()
            error_signals = [
                "attention required",
                "cloudflare",
                "sorry, you have been blocked",
                "enable cookies",
                "please turn javascript on",
                "<html",
                "<body"
            ]
            # NOTE(review): "<html" is itself one of the signals, so any body starting
            # with "<html" is treated as blocked regardless of the other phrases.
            if lowered.strip().startswith("<html") and any(sig in lowered for sig in error_signals):
                return True
            return False

        kept_docs: list[dict[str, Any]] = []

        headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
                          "AppleWebKit/537.36 (KHTML, like Gecko) "
                          "Chrome/114.0.0.0 Safari/537.36"
        }

        for name, url in FEEDS.items():
            successes = 0
            feed = None

            for attempt in range(1, MAX_RETRIES + 1):
                try:
                    resp = self.session.get(url, timeout=20)
                    resp.raise_for_status()
                    raw = resp.text

                    if _looks_like_html_error(raw):
                        print(f"[WARN] Blocked or invalid content detected for feed {name} on attempt {attempt}")
                        time.sleep(THROTTLE_SECONDS * 3)
                        continue

                    feed = feedparser.parse(raw)
                    break
                except Exception as e:
                    print(f"[WARN] RSS fetch {name} (try {attempt}/{MAX_RETRIES}): {e}")
                    time.sleep(THROTTLE_SECONDS * 2)

            if not feed or not feed.entries:
                print(f"[WARN] No valid feed entries for {name}, skipping.")
                continue

            # Size the bar by what the feed can actually deliver rather than the cap alone.
            pbar = tqdm(
                total=min(len(feed.entries), RETMAX_PER_FEED),
                desc=f"NEWS - {name:<15} processed",
                unit="article"
            )

            try:
                for entry in feed.entries:
                    if successes >= RETMAX_PER_FEED:
                        break

                    pub_dt = _get_entry_datetime(entry)
                    if pub_dt is None or pub_dt.date() < self.cutoff.date():
                        continue

                    link = entry.get("link")
                    if not link:
                        continue

                    guid = entry.get("id") or link
                    doc_id = f"NEWS_{_hash(guid)}"
                    if doc_id in self.ingested:
                        continue

                    # Start from empty strings so a total extraction failure cannot
                    # pass None to the HTML cleaner.
                    full_text = ""
                    summary = ""
                    for n in range(1, MAX_RETRIES + 1):
                        try:
                            art = Article(url=link, browser_user_agent=headers["User-Agent"])
                            art.download()
                            art.parse()
                            full_text = art.text or ""
                            summary = art.meta_description or entry.get("summary", "")
                            break
                        except Exception as e:
                            print(f"[WARN] Extraction failed for {link} (try {n}/{MAX_RETRIES}): {e}")
                            time.sleep(THROTTLE_SECONDS * 2)

                    summary = self._clean_html(summary)

                    if not full_text:
                        # Fall back to the feed summary, cleaned of HTML markup.
                        full_text = self._clean_html(entry.get("summary", ""))
                        summary = full_text

                    kept_docs.append({
                        "doc_id": doc_id,
                        "source": name,
                        "doc_type": "news_article",
                        "title": (entry.get("title") or "").strip(),
                        "abstract": summary.strip() if summary else "",
                        "body_text": full_text.strip() if full_text else "",
                        "published_date": pub_dt.date().isoformat(),
                        # feedparser normalizes parsed dates to UTC; attach the zone so
                        # the epoch value does not depend on the machine's local time.
                        "published_date_ts": float(pub_dt.replace(tzinfo=datetime.UTC).timestamp()),
                        "url": link,
                        "metadata": {}
                    })
                    pbar.update(1)

                    successes += 1
                    time.sleep(THROTTLE_SECONDS)
            finally:
                pbar.close()

            time.sleep(THROTTLE_SECONDS)

        return kept_docs
    
    def run(self) -> list[dict[str, Any]]:
        collected_docs = []
        
        try:
            collected_docs.extend(self.fetch_pmc())
        except Exception as e:
            print(f"[ERROR] fetch_pmc failed: {e}")

        try:
            collected_docs.extend(self.fetch_arxiv())
        except Exception as e:
            print(f"[ERROR] fetch_arxiv failed: {e}")
            
        try:
            collected_docs.extend(self.fetch_clinical_trials())
        except Exception as e:
            print(f"[ERROR] fetch_clinical_trials failed: {e}")
            
        try:
            collected_docs.extend(self.fetch_news())
        except Exception as e:
            print(f"[ERROR] fetch_news failed: {e}")
        
        return collected_docs
        
    def _days_ago(self, n: int) -> datetime.datetime:
        today = datetime.datetime.now(datetime.UTC).date()
        return datetime.datetime.combine(today - datetime.timedelta(days=n), datetime.datetime.min.time())
    
    @staticmethod
    def _clean_html(text: str) -> str:
        """
        Strip HTML markup from `text` and collapse all whitespace runs to single spaces.

        Falsy input (None or an empty string) is treated as an empty document.
        """
        markup = text if text else ""
        plain_text = BeautifulSoup(markup, "html.parser").get_text(separator=" ")
        tokens = plain_text.split()
        return " ".join(tokens)

In [None]:
class DataIngestionModule():
    """
    Ingests extracted biomedical documents into a Qdrant collection with hybrid (dense + sparse) search.

    Workflow implemented by `run`:
        1. **Client and Collection Initialization:**
           Connects to Qdrant (with retries) and makes sure the target collection and its payload indexes exist.
        2. **Document Preparation:**
           For each extracted document:
               - The abstract (or the title when the abstract is empty) is wrapped into a `Document` with metadata.
               - The body text is split into overlapping chunks by the configured splitter; each chunk is paired with metadata.
        3. **Summarization and Keyword Extraction:**
           Each body chunk is sent to the LLM summarization chain, which is asked for a Python dictionary string
           holding a summary and keywords. If the call fails after all retries or the output cannot be parsed,
           the original chunk text is used as the summary with no keywords.
        4. **Batch Embedding and Storage:**
           Abstract and summary documents are added to the vector store in batches, with exponential backoff on failures.
        5. **Callback Invocation:**
           After each stored batch, `on_successful_ingest` (if provided) is called with that batch's source document IDs.

    Notes:
        - A source document ID is reported as stored as soon as any batch containing one of its items is stored.
        - Batch sizes, retry counts and pauses are class attributes and can be overridden on a subclass or instance.
        - Progress is shown via `tqdm`; events, errors and warnings are printed to the console.

    Example:
        >>> ingestion = DataIngestionModule(qdrant_url, qdrant_api_key, google_api_key)
        >>> stored_ids = ingestion.run(extracted_docs)
        >>> print(f"{len(stored_ids)} documents ingested.")
    """

    # Batching and retry policy used by `run` and its helpers.
    SUMMARY_BATCH_SIZE = 100
    EMBEDDING_BATCH_SIZE = 100
    MAX_RETRIES = 5
    RETRY_BACKOFF_SECS = 5 # first retry delay; doubled on every further attempt
    SUMMARY_PAUSE_SECS = 2 # pause after each summarized batch, for avoiding LangChain internal server errors
    STORE_PAUSE_SECS = 5 # pause after each stored batch, for avoiding 429 ResourceExhausted errors

    SUMMARY_PROMPT_TEMPLATE = """You are a highly specialized text processing AI. Your sole task is to analyze the provided bio-medical text and generate a Python dictionary string.

Follow these instructions METICULOUSLY:

1.  **Output Format:** Your response MUST be a single, valid Python dictionary string.
    The dictionary must conform EXACTLY to this structure:
    `{{ 'summary': "<summary_string>", 'keywords': ["<keyword1>", "<keyword2>", ...] }}`

2.  **Content for 'summary' key:**
    * The value must be a concise and comprehensive summary of the input text.
    * Focus on the most critical bio-medical information, key findings, and essential details.
    * The summary MUST be strictly 200 words or less.
    * It must be a single string.

3.  **Content for 'keywords' key:**
    * The value must be a Python list of strings.
    * Extract atmost 10 of the most relevant and specific bio-medical keywords or keyphrases from the text.
    * These keywords should be suitable for use in a hybrid search system (meaning they should be significant terms, entities, or concepts).
    * Each item in the list must be a string.

4.  **Example of PERFECT output format:**
    `{{ 'summary': "The study investigates the effect of Compound X on murine models of Alzheimer's disease. Results indicate a significant reduction in amyloid-beta plaques and improved cognitive function compared to placebo. No adverse effects were reported at the tested dosage.", 'keywords': ["Compound X", "Alzheimer's disease", "murine models", "amyloid-beta plaques", "cognitive function", "placebo"] }}`

5.  **CRITICAL - Do NOT:**
    * Do NOT include any text, explanation, apologies, or conversational filler before or after the Python dictionary string.
    * Do NOT use markdown (e.g., ```python ... ``` or ```json ... ```) to wrap your output.
    * Do NOT deviate from the specified dictionary structure or key names.
    * Ensure all strings within the dictionary (keys and values) are properly quoted using single quotes for the outer dictionary and keys, and single or double quotes for string values as per valid Python syntax. Double check your quote usage to ensure the output is a valid Python dictionary string.

Text to process:
{text}

Your Python dictionary string output:
"""

    def __init__(
        self,
        qdrant_url: str,
        # Third-party types are quoted so the annotations are not evaluated at definition time.
        qdrant_api_key: "Union[SecretStr, str]",
        google_api_key: "Union[SecretStr, str]",
        summary_model: str = "gemini-1.5-flash-8b",
        embed_model: str = "models/embedding-001",
        chunk_size: int = 2000,
        chunk_overlap: int = 200,
        base_dir: Path = Path("./project_asclepius"),
        qdrant_dir: Path = Path("./qdrant_db"),
        on_successful_ingest: Optional[Callable[[set[str]], None]] = None
    ):
        """
        Builds the splitter, embedding models and summarization chain. No Qdrant connection is opened here.

        Args:
            qdrant_url: URL of the Qdrant server.
            qdrant_api_key: Qdrant API key; a `SecretStr` is unwrapped to a plain string.
            google_api_key: Google API key; a plain string is wrapped into a `SecretStr`.
            summary_model: Model name passed to `ChatGoogleGenerativeAI`.
            embed_model: Model name passed to `GoogleGenerativeAIEmbeddings`.
            chunk_size: Chunk size passed to the text splitter.
            chunk_overlap: Chunk overlap passed to the text splitter.
            base_dir: Directory that is created if it does not exist.
            qdrant_dir: Only stored on the instance as `self.qdrant_dir`.
            on_successful_ingest: Optional callback invoked with the source document IDs of every stored batch.
        """
        base_dir.mkdir(parents=True, exist_ok=True)

        if isinstance(qdrant_api_key, SecretStr):
            qdrant_api_key = qdrant_api_key.get_secret_value()

        self.qdrant_api_key = qdrant_api_key
        self.qdrant_url = qdrant_url

        if isinstance(google_api_key, str):
            google_api_key = SecretStr(google_api_key)

        self.google_api_key = google_api_key

        self.on_successful_ingest = on_successful_ingest

        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size = chunk_size, chunk_overlap = chunk_overlap
        )
        self.embedding_model = GoogleGenerativeAIEmbeddings(model=embed_model, google_api_key=google_api_key)
        self.sparse_model = FastEmbedSparse(model_name="Qdrant/bm25")

        self.qdrant_dir = qdrant_dir
        self.collection_name = "project_asclepius"
        self._client = None # set by `_get_client`, released by `close`

        self.summary_generation_agent = ChatGoogleGenerativeAI(
            model=summary_model,
            api_key=google_api_key,
            temperature=0,
            max_tokens=2048 # crucial because we are dealing with lots of data and there is a good chance that without this restriction -> LLM can generate really long summaries for atleast few of the data
        )
        self.summary_generation_prompt_template = PromptTemplate(
            input_variables=["text"],
            template=self.SUMMARY_PROMPT_TEMPLATE
        )
        self.summary_generation_chain = (
            self.summary_generation_prompt_template
            | self.summary_generation_agent
            | StrOutputParser()
        )

    def run(self, extracted_docs: list[dict[str, Any]]) -> set[str]:
        """
        Summarizes, embeds and stores the given documents in Qdrant.

        Args:
            extracted_docs (list[dict[str, Any]]):
                Documents as dictionaries. The keys read are `abstract`, `title`, `body_text`,
                `published_date`, `published_date_ts`, `source`, `doc_id`, `url` and `doc_type`;
                missing or `None` values are replaced by empty defaults.

        Returns:
            set[str]:
                Unique `source_doc_id` values of all items stored in the vector database.
                Empty if `extracted_docs` is empty or if Qdrant cannot be reached
                (the `ConnectionError` is caught and logged, not raised).
        """
        if not extracted_docs:
            return set()

        try:
            client = self._get_client(self.qdrant_url, self.qdrant_api_key)
            # Reuse the same client so only one connection is opened (and later closed by `close`).
            self._ensure_collection_exists(self.qdrant_url, self.qdrant_api_key, client=client)
        except ConnectionError as e:
            print(f"[ERROR] Cannot connect to Qdrant at {self.qdrant_url}: {e}")
            return set()

        self.vectorstore = QdrantVectorStore(
            client=client,
            collection_name=self.collection_name,
            embedding=self.embedding_model,
            sparse_embedding=self.sparse_model,
            retrieval_mode=RetrievalMode.HYBRID,
            vector_name="dense",
            sparse_vector_name="langchain-sparse"
        )

        abstract_docs, chunk_records = self._prepare_documents(extracted_docs)
        chunk_summaries = self._summarize_chunks(chunk_records)

        chunk_docs: list[Document] = []
        for rec, parsed in zip(chunk_records, chunk_summaries):
            meta = rec["meta"]
            # The summary is what gets embedded; the raw chunk is kept in the payload.
            meta["original_content"] = rec["text"]
            meta["keywords"] = parsed["keywords"]
            chunk_docs.append(Document(page_content=parsed["summary"], metadata=meta))

        return self._store_documents(abstract_docs + chunk_docs)

    @staticmethod
    def _build_metadata(doc: dict[str, Any], content_type: str) -> dict[str, Any]:
        """Builds the payload metadata shared by abstract and chunk documents; missing/None values become defaults."""
        return {
            "published_date": doc.get("published_date") or "",
            "published_date_ts": doc.get("published_date_ts") or 0.0,
            "content_type": content_type,
            "source": doc.get("source") or "",
            "source_doc_id": doc.get("doc_id") or "",
            "url": doc.get("url") or "",
            "title": doc.get("title") or "",
            "original_doc_type": doc.get("doc_type") or ""
        }

    def _prepare_documents(self, extracted_docs: list[dict[str, Any]]):
        """
        Turns extracted documents into abstract `Document`s and body-chunk records.

        Returns:
            tuple: `(abstract_docs, chunk_records)` where every chunk record is
            `{"text": <chunk text>, "meta": <metadata dict incl. chunk_index>}`.
        """
        abstract_docs: list[Document] = []
        chunk_records: list[dict[str, Any]] = []

        for doc in extracted_docs:
            # Fall back to the title when the abstract is missing, None or blank.
            abstract = (doc.get("abstract") or "").strip() or (doc.get("title") or "").strip()
            # if there is no abstract -> based on `title`/`body_text` you can use an agent to generate abstracts
            if abstract:
                abstract_docs.append(
                    Document(page_content=abstract, metadata=self._build_metadata(doc, "abstract"))
                )

            body = (doc.get("body_text") or "").strip()
            if not body:
                continue

            for idx, chunk in enumerate(self.text_splitter.split_text(body)):
                meta = self._build_metadata(doc, "chunk")
                meta["chunk_index"] = idx
                chunk_records.append({"text": chunk, "meta": meta})

        return abstract_docs, chunk_records

    @staticmethod
    def _parse_summary_output(raw: str, fallback_text: str) -> dict[str, Any]:
        """
        Parses one LLM response into `{"summary": str, "keywords": list[str]}`.

        Text before the first `{` and after the last `}` (e.g. backticks or a markdown
        fence) is ignored. If the remainder is not a dictionary literal of the expected
        shape, `fallback_text` is returned as the summary with an empty keyword list.
        """
        fallback = {"summary": fallback_text, "keywords": []}
        try:
            first, last = raw.find("{"), raw.rfind("}")
            candidate = raw[first:last + 1] if 0 <= first < last else raw
            data = ast.literal_eval(candidate)
        except Exception:
            return fallback

        is_valid = (
            isinstance(data, dict)
            and isinstance(data.get("summary"), str)
            and isinstance(data.get("keywords"), list)
            and all(isinstance(kw, str) for kw in data["keywords"])
        )
        if not is_valid:
            return fallback
        return {"summary": data["summary"], "keywords": data["keywords"]}

    def _summarize_chunks(self, chunk_records: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """
        Summarizes all chunk records in batches through the LLM chain.

        Returns:
            list[dict[str, Any]]: One `{"summary", "keywords"}` dict per chunk record, in the same
            order. Chunks whose batch failed on every attempt keep their original text as summary.
        """
        if chunk_records:
            print(f"[INFO] Summarizing {len(chunk_records):,} chunks…")

        chunk_summaries: list[dict[str, Any]] = [{} for _ in chunk_records]
        for start in tqdm(range(0, len(chunk_records), self.SUMMARY_BATCH_SIZE), unit="batch", desc="Batches summarized"):
            end = start + self.SUMMARY_BATCH_SIZE
            batch = chunk_records[start:end]
            batch_inputs = [{"text": rec["text"]} for rec in batch]

            for attempt in range(1, self.MAX_RETRIES + 1):
                try:
                    outputs: list[str] = self.summary_generation_chain.batch(batch_inputs)
                    if len(outputs) != len(batch):
                        raise RuntimeError("LLM returned unexpected number of summaries.")

                    chunk_summaries[start:end] = [
                        self._parse_summary_output(output_str, rec["text"])
                        for output_str, rec in zip(outputs, batch)
                    ]
                    time.sleep(self.SUMMARY_PAUSE_SECS)
                    break
                except Exception as e:
                    if attempt == self.MAX_RETRIES:
                        print(f"[WARN] Final failure summarizing batch {start}-{end}: {e}")
                    else:
                        wait = self.RETRY_BACKOFF_SECS * (2 ** (attempt - 1))
                        print(
                            f"[WARN] Summary batch {start}-{end} failed ({e}). "
                            f"Retrying in {wait}s…"
                        )
                        time.sleep(wait)

            # Fallback to storing original chunk text if summary generation failed
            for i in range(start, min(end, len(chunk_summaries))):
                if not chunk_summaries[i]:
                    chunk_summaries[i] = {"summary": chunk_records[i]["text"], "keywords": []}

        return chunk_summaries

    def _store_documents(self, all_docs: list) -> set[str]:
        """
        Adds the documents to `self.vectorstore` in batches with retry and exponential backoff.

        The `on_successful_ingest` callback runs outside the retry scope: a failing
        callback is logged but never causes an already stored batch to be inserted again.

        Returns:
            set[str]: `source_doc_id` values of all documents in successfully stored batches.
        """
        stored_source_ids: set[str] = set()
        stored_doc_count = 0

        if all_docs:
            print(f"[INFO] Persisting {len(all_docs):,} docs to Qdrant…")

        for start in tqdm(range(0, len(all_docs), self.EMBEDDING_BATCH_SIZE), unit="batch", desc="Batches stored"):
            end = start + self.EMBEDDING_BATCH_SIZE
            batch_docs = all_docs[start:end]

            for attempt in range(1, self.MAX_RETRIES + 1):
                try:
                    self.vectorstore.add_documents(batch_docs)
                except Exception as e:
                    if attempt == self.MAX_RETRIES:
                        print(f"[WARN] Final failure embedding batch {start}-{end}: {e}")
                    else:
                        wait = self.RETRY_BACKOFF_SECS * (2 ** (attempt - 1))
                        print(
                            f"[WARN] Embedding batch {start}-{end} failed ({e}). "
                            f"Retrying in {wait}s…"
                        )
                        time.sleep(wait)
                    continue

                batch_ids: set[str] = {d.metadata["source_doc_id"] for d in batch_docs}
                stored_source_ids |= batch_ids
                stored_doc_count += len(batch_docs)
                if self.on_successful_ingest:
                    try:
                        self.on_successful_ingest(batch_ids)
                    except Exception as e:
                        print(f"[WARN] on_successful_ingest callback failed for batch {start}-{end}: {e}")
                time.sleep(self.STORE_PAUSE_SECS)
                break

        print(f"[INFO] Ingestion complete · {stored_doc_count:,} items belonging to {len(stored_source_ids):,} source docs stored.")
        return stored_source_ids

    def _ensure_collection_exists(self, qdrant_url: str, qdrant_api_key: str, client: "Optional[QdrantClient]" = None):
        """
        Creates the collection (dense + sparse vectors) if it cannot be fetched, then ensures payload indexes.

        Args:
            qdrant_url: Used to open a connection when `client` is not given.
            qdrant_api_key: Used to open a connection when `client` is not given.
            client: Already connected client to reuse; a new one is created via `_get_client` when omitted.

        Raises:
            ConnectionError: From `_get_client` when no client is given and Qdrant cannot be reached.
        """
        if client is None:
            client = self._get_client(qdrant_url=qdrant_url, qdrant_api_key=qdrant_api_key)
        try:
            client.get_collection(self.collection_name)
        except Exception:
            # Any failure to fetch the collection is treated as "collection does not exist yet".
            print(f"[INFO] Creating collection '{self.collection_name}'...")
            # The dense vector size is taken from a probe embedding instead of being hard-coded.
            sample_embedding = self.embedding_model.embed_query("test")
            embedding_dim = len(sample_embedding)

            client.create_collection(
                collection_name=self.collection_name,
                vectors_config={
                    "dense": VectorParams(
                        size=embedding_dim,
                        distance=Distance.COSINE
                    )
                },
                sparse_vectors_config={
                    "langchain-sparse": SparseVectorParams(
                        index=models.SparseIndexParams(on_disk=False)
                    )
                }
            )
            print(f"[SUCCESS] Collection created with dense dimension {embedding_dim} and sparse vectors.")

        self._ensure_payload_indexes(client)

    def _ensure_payload_indexes(self, client: "QdrantClient"):
        """Creates every required payload index that is not yet part of the collection's payload schema."""
        required_indexes = [
            ("metadata.content_type", "keyword"),
            ("metadata.published_date_ts", "float"),
            ("metadata.original_doc_type", "keyword"),
            ("metadata.source", "keyword"),
            ("metadata.source_doc_id", "keyword"),
            ("metadata.url", "keyword"),
            ("metadata.chunk_index", "integer"),
            ("metadata.title", "keyword"),
            ("metadata.keywords", "keyword"),
            ("metadata.original_content", "text")
        ]

        try:
            collection_info = client.get_collection(self.collection_name)
            existing_indexes = set(collection_info.payload_schema.keys())
        except Exception:
            # Schema unknown: attempt to create every index; "already exists" errors are tolerated below.
            existing_indexes = set()

        for field_name, field_type in required_indexes:
            if field_name in existing_indexes:
                continue
            try:
                client.create_payload_index(
                    collection_name=self.collection_name,
                    field_name=field_name,
                    field_schema=field_type
                )
                print(f"[DataIngestPipeline] Successfully created index for {field_name}")
            except Exception as e:
                if "already exists" in str(e).lower():
                    print(f"[DataIngestPipeline] Index for {field_name} already exists")
                else:
                    print(f"[DataIngestPipeline] Warning: Could not create index for {field_name}: {e}")

    def _get_client(self, qdrant_url: str, qdrant_api_key: str):
        """
        Opens a Qdrant connection, verifies it and remembers it as `self._client`.

        Up to three attempts are made, sleeping 4s and then 8s between them.

        Raises:
            ConnectionError: If every attempt fails; the last error is attached as the cause.
        """
        max_attempts = 3
        backoff = 4
        for attempt in range(1, max_attempts + 1):
            try:
                client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
                # quick health check
                client.get_collections()
                self._client = client
                return client
            except Exception as exc:
                print(f"[WARN] Qdrant connection attempt {attempt} failed: {exc}")
                if attempt >= max_attempts:
                    raise ConnectionError("All Qdrant connection attempts failed.") from exc
                time.sleep(backoff)
                backoff *= 2

    def close(self):
        """Closes the remembered Qdrant client, if any. Errors while closing are logged, not raised."""
        if getattr(self, "_client", None) is None:
            return
        try:
            self._client.close()
            print("[INFO] QdrantClient closed successfully.")
        except Exception as e:
            print(f"[WARN] Error closing client: {e}")
        finally:
            self._client = None

In [None]:
class DataIngestPipeline():
    """
    End-to-end biomedical data extraction and ingestion pipeline.

    The pipeline wires a `DataExtractionModule` and a `DataIngestionModule` together
    and keeps a JSON log of already ingested source document IDs. `run` performs:

        1. **Extraction:**
           Calls `DataExtractionModule.run()`. The extraction module receives the set of
           already ingested IDs (`ingested_logs`) when it is constructed.

        2. **Ingestion:**
           Calls `DataIngestionModule.run()` with the extracted documents. After every stored
           batch the ingestion module calls back into `_eagerly_update_log`, which adds the
           batch's source document IDs to the log and saves it to disk.

        3. **Progress Reporting:**
           Prints progress and completion messages to the console.

        4. **Resource Cleanup:**
           Closes the ingestion module's Qdrant client, regardless of success or errors.

    Notes:
        - The in-memory ID set is shared by reference with the extraction module and is
          updated in place, so IDs stored during a run are visible to it immediately.
        - The log file is written to a temporary file first and then moved into place, so an
          interrupted write does not leave a truncated log behind.

    Example:
        >>> pipeline = DataIngestPipeline(qdrant_url, qdrant_api_key, ncbi_api_key, google_api_key)
        >>> pipeline.run()
        [INFO] Starting data extraction...
        [INFO] Extraction complete. 3,500 docs fetched for ingestion.
        [INFO] Starting ingestion to Qdrant...
        [INFO] Pipeline finished · 2,770 new source docs added.
    """

    def __init__(
        self,
        qdrant_url: str,
        # Third-party types are quoted so the annotations are not evaluated at definition time.
        qdrant_api_key: "Union[SecretStr, str]",
        ncbi_api_key: "Union[SecretStr, str]",
        google_api_key: "Union[SecretStr, str]",
        lookup_window_size: int = 365 * 3, # (in days)
        ingest_log_path: Path = Path("./ingested.json"),
        summary_model: str = "gemini-1.5-flash-8b",
        embed_model: str = "models/embedding-001",
        chunk_size: int = 2000,
        chunk_overlap: int = 200,
        base_dir: Path = Path("./project_asclepius"),
        qdrant_dir: Path = Path("./qdrant_db")
    ):
        """
        Validates the Qdrant settings, loads the ingest log and builds both modules.

        Args:
            qdrant_url: URL of the Qdrant server (required).
            qdrant_api_key: Qdrant API key (required); a `SecretStr` is unwrapped to a plain string.
            ncbi_api_key: Passed to `DataExtractionModule`.
            google_api_key: Passed to `DataIngestionModule`.
            lookup_window_size: Look-back window in days, passed to `DataExtractionModule`.
            ingest_log_path: JSON log of ingested IDs; a relative path is resolved against `base_dir`.
            summary_model, embed_model, chunk_size, chunk_overlap, qdrant_dir:
                Passed through to `DataIngestionModule`.
            base_dir: Working directory; created if it does not exist.

        Raises:
            ValueError: If `qdrant_url` or `qdrant_api_key` is empty.
        """
        # Validate before touching the filesystem.
        if not qdrant_url:
            print("[ERROR] Qdrant URL is required but was not provided.")
            raise ValueError("Qdrant URL is required.")
        if not qdrant_api_key:
            print("[ERROR] Qdrant API key is required but was not provided.")
            raise ValueError("Qdrant API key is required.")

        base_dir.mkdir(parents=True, exist_ok=True)
        if not ingest_log_path.is_absolute():
            ingest_log_path = base_dir / ingest_log_path

        if isinstance(qdrant_api_key, SecretStr):
            qdrant_api_key = qdrant_api_key.get_secret_value()

        self.qdrant_url = qdrant_url
        self.qdrant_api_key = qdrant_api_key

        self.ingest_log_path = ingest_log_path
        self.ingested: set[str] = self._load_ingest_log()

        self.data_extraction_module = DataExtractionModule(
            ncbi_api_key=ncbi_api_key,
            ingested_logs=self.ingested,
            lookup_window_size=lookup_window_size
        )

        self.data_ingestion_module = DataIngestionModule(
            qdrant_url=self.qdrant_url,
            qdrant_api_key=self.qdrant_api_key,
            google_api_key=google_api_key,
            summary_model=summary_model,
            embed_model=embed_model,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            base_dir=base_dir,
            qdrant_dir=qdrant_dir,
            on_successful_ingest=self._eagerly_update_log
        )

    def run(self) -> set[str]:
        """
        Runs extraction followed by ingestion and always closes the ingestion module afterwards.

        Returns:
            set[str]: Source document IDs reported as stored by the ingestion module.

        Raises:
            Any exception from the extraction or ingestion module propagates after cleanup.
        """
        try:
            print("[INFO] Starting data extraction...")
            extracted_docs = self.data_extraction_module.run()
            print(f"[INFO] Extraction complete. {len(extracted_docs):,} docs fetched for ingestion.")
            print("[INFO] Starting ingestion to Qdrant...")
            stored_ids = self.data_ingestion_module.run(extracted_docs)
            print(f"[INFO] Pipeline finished · {len(stored_ids):,} new source docs added.")
            return stored_ids
        finally:
            self.data_ingestion_module.close()

    def _load_ingest_log(self) -> set[str]:
        """
        Reads the ingest log from `self.ingest_log_path`.

        Returns:
            set[str]: The logged IDs (each coerced to `str`). An empty set if the file does not
            exist, cannot be read, is not valid JSON or does not contain a JSON list.
        """
        if not self.ingest_log_path.exists():
            return set()

        try:
            data = json.loads(self.ingest_log_path.read_text(encoding="utf-8"))
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            print(f"[WARN] Ingest log corrupted or not valid JSON: {e}. Starting with empty log.")
            return set()
        except Exception as e:
            print(f"[WARN] Could not read ingest log: {e}. Starting with empty log.")
            return set()

        if not isinstance(data, list):
            print("[WARN] Ingest log is not a list. Starting with empty log.")
            return set()
        return {str(x) for x in data}

    def _save_ingest_log(self):
        """
        Writes the sorted ID list to `self.ingest_log_path` as JSON.

        The data is written to a sibling `.tmp` file and then moved over the log, so the
        log is replaced in one step. Failures are logged, never raised.
        """
        tmp_path = self.ingest_log_path.with_name(self.ingest_log_path.name + ".tmp")
        try:
            self.ingest_log_path.parent.mkdir(parents=True, exist_ok=True)
            tmp_path.write_text(json.dumps(sorted(self.ingested)), encoding="utf-8")
            tmp_path.replace(self.ingest_log_path)
        except Exception as e:
            print(f"[WARN] Failed to save ingest log: {e}")

    def _eagerly_update_log(self, new_ids: set[str]):
        """Adds `new_ids` to the in-memory ID set (in place) and saves the log to disk right away."""
        self.ingested.update(new_ids)
        self._save_ingest_log()

In [None]:
# Run the data ingestion pipeline.
#
# Credentials are read from environment variables so that no secret is stored in the
# notebook: QDRANT_URL, QDRANT_API_KEY, NCBI_API_KEY, GOOGLE_API_KEY.
# An unset variable yields an empty string; DataIngestPipeline raises ValueError when the
# Qdrant URL or API key is empty.
import os

qdrant_url = os.environ.get("QDRANT_URL", "")
qdrant_api_key = os.environ.get("QDRANT_API_KEY", "")
ncbi_api_key = os.environ.get("NCBI_API_KEY", "")
google_api_key = os.environ.get("GOOGLE_API_KEY", "")

obj = DataIngestPipeline(
    qdrant_url=qdrant_url,
    qdrant_api_key=qdrant_api_key,
    ncbi_api_key=ncbi_api_key,
    google_api_key=google_api_key
)

res = obj.run()