In [10]:
import re
import os
from collections import defaultdict
from datetime import datetime
from typing import List, Dict, Tuple, Optional

import numpy as np
from dateutil.parser import parse
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util

# Optional numeric parsers (we try them in order if available)
try:
    from number_parser import parse_number as _np_parse_number  # pip install number-parser
except Exception:
    _np_parse_number = None

try:
    from word2number import w2n  # pip install word2number
except Exception:
    w2n = None

# -------------------- small utilities --------------------

def _strip_commas_int(s: str) -> Optional[int]:
    s = (s or "").strip()
    if not s:
        return None
    m = re.search(r"\d[\d,]*", s)
    if not m:
        return None
    return int(m.group(0).replace(",", ""))

def _normalize_wordnum(s: str) -> str:
    # normalize dash variants and whitespace
    return (s or "").replace("–", "-").replace("—", "-").replace("−", "-").strip()

def _int_from_strnum(s: str) -> Optional[int]:
    """
    Convert a string number to int.
    Supports digits with commas, 'Sixty-seven', 'twenty one', etc.
    Tries: direct digits -> number_parser -> word2number -> custom manual.
    """
    if not s or not isinstance(s, str):
        return None

    s = _normalize_wordnum(s)

    # 1) Digits like "67", "88,345"
    d = _strip_commas_int(s)
    if d is not None:
        return d

    # 2) number-parser (handles hyphenated/cased words)
    if _np_parse_number is not None:
        try:
            val = _np_parse_number(s)
            if isinstance(val, (int, float)):
                return int(val)
        except Exception:
            pass

    # 3) word2number (works better for lowercase tokens)
    if w2n is not None:
        try:
            val = w2n.word_to_num(s.lower().replace("-", " "))
            if isinstance(val, int):
                return val
        except Exception:
            pass

    # 4) last-ditch: small manual parser
    words = re.split(r"[-\s]+", s.lower())
    basic_map = {
        "zero":0,"one":1,"two":2,"three":3,"four":4,"five":5,"six":6,"seven":7,"eight":8,"nine":9,"ten":10,
        "eleven":11,"twelve":12,"thirteen":13,"fourteen":14,"fifteen":15,"sixteen":16,"seventeen":17,"eighteen":18,
        "nineteen":19,"twenty":20,"thirty":30,"forty":40,"fifty":50,"sixty":60,"seventy":70,"eighty":80,"ninety":90,
        "hundred":100,"thousand":1000,"million":1_000_000
    }
    total, cur = 0, 0
    had_any = False
    for w in words:
        if w in ("hundred","thousand","million"):
            if cur == 0:
                cur = 1
            cur *= basic_map[w]
            total += cur
            cur = 0
            had_any = True
        elif w in basic_map:
            cur += basic_map[w]
            had_any = True
    total += cur
    return total if had_any and total > 0 else None

def split_sentences(text: str) -> List[str]:
    """
    Naive sentence splitter.

    Breaks ``text`` at any whitespace run that directly follows '.', '?' or
    '!' and drops pieces that are empty or whitespace-only. The terminating
    punctuation stays attached to its sentence. None is treated as "".
    """
    boundary = re.compile(r'(?<=[\.\?!])\s+')
    sentences: List[str] = []
    for chunk in boundary.split(text or ""):
        if chunk.strip():
            sentences.append(chunk)
    return sentences

def parse_latest_date(cands: List[str]) -> Optional[str]:
    """
    Return the candidate string that denotes the most recent date.

    Each candidate is parsed with ``dateutil`` in fuzzy mode after stripping
    English ordinal suffixes (1st/2nd/3rd/4th). Candidates that cannot be
    parsed or compared are skipped.

    Determinism notes:
      * A fixed ``default`` is passed to the parser so that fields missing from
        a candidate (e.g. the day in "March 2024") resolve to the first day /
        month / year instead of inheriting today's date; without it the winner
        could change depending on the day the code is run.
      * Candidates are de-duplicated in first-seen order, so when two strings
        denote the same date the one that appeared first is returned.

    Args:
        cands: raw date-like strings; None or empty yields None.

    Returns:
        The original (unmodified) candidate string with the latest date, or
        None if nothing could be parsed.
    """
    fill_in = datetime(1, 1, 1)
    best = None
    for x in dict.fromkeys(cands or []):
        try:
            # strip ordinals like 1st/2nd/3rd/4th
            dt = parse(re.sub(r'(\d+)(st|nd|rd|th)\b', r'\1', x), fuzzy=True, default=fill_in)
            if best is None or dt > best[0]:
                best = (dt, x)
        except Exception:
            # unparseable text, or naive/aware datetimes that cannot be compared
            continue
    return best[1] if best else None

# -------------------- main extractor --------------------

class FastContextExtractor:
    """
    Fast hybrid extractor:
      • Sentence embeddings (MiniLM) for context retrieval
      • Tiny QA head for dates/counts anchoring
      • Robust rule-based guards for 'included studies'
      • Country counts + sample sizes (n=) with global + contextual scan
      sent_model_name: str = "sentence-transformers/all-MiniLM-L6-v2","allenai/specter2_base"
      qa_model_name: str = "deepset/tinyroberta-squad2",
    """

    def __init__(
        self,
        sent_model_name: str = "allenai/specter2_base",
        qa_model_name: str = "deepset/tinyroberta-squad2",
        top_k: int = 8,
        device: str = "cpu"
    ):
        self.embedder = SentenceTransformer(sent_model_name, device=device)
        self.qa = pipeline("question-answering", model=qa_model_name, tokenizer=qa_model_name,
                           device=-1 if device == "cpu" else 0)
        self.top_k = top_k

        # Countries list (can be swapped for pycountry for full coverage)
        self.countries = set([
            "Afghanistan","Albania","Algeria","Andorra","Angola","Argentina","Armenia","Australia","Austria",
            "Azerbaijan","Bahamas","Bahrain","Bangladesh","Barbados","Belarus","Belgium","Belize","Benin","Bhutan",
            "Bolivia","Bosnia and Herzegovina","Botswana","Brazil","Brunei","Bulgaria","Burkina Faso","Burundi",
            "Cabo Verde","Cambodia","Cameroon","Canada","Central African Republic","Chad","Chile","China","Colombia",
            "Comoros","Congo","Costa Rica","Croatia","Cuba","Cyprus","Czech Republic","Denmark","Djibouti","Dominica",
            "Dominican Republic","Ecuador","Egypt","El Salvador","Equatorial Guinea","Eritrea","Estonia","Eswatini",
            "Ethiopia","Fiji","Finland","France","Gabon","Gambia","Georgia","Germany","Ghana","Greece","Grenada",
            "Guatemala","Guinea","Guinea-Bissau","Guyana","Haiti","Honduras","Hungary","Iceland","India","Indonesia",
            "Iran","Iraq","Ireland","Israel","Italy","Jamaica","Japan","Jordan","Kazakhstan","Kenya","Kiribati",
            "Kuwait","Kyrgyzstan","Laos","Latvia","Lebanon","Lesotho","Liberia","Libya","Liechtenstein","Lithuania",
            "Luxembourg","Madagascar","Malawi","Malaysia","Maldives","Mali","Malta","Mauritania","Mauritius","Mexico",
            "Moldova","Monaco","Mongolia","Montenegro","Morocco","Mozambique","Myanmar","Namibia","Nauru","Nepal",
            "Netherlands","New Zealand","Nicaragua","Niger","Nigeria","North Korea","North Macedonia","Norway","Oman",
            "Pakistan","Palau","Panama","Papua New Guinea","Paraguay","Peru","Philippines","Poland","Portugal","Qatar",
            "Romania","Russia","Rwanda","Saint Kitts and Nevis","Saint Lucia","Saint Vincent and the Grenadines","Samoa",
            "San Marino","Sao Tome and Principe","Saudi Arabia","Senegal","Serbia","Seychelles","Sierra Leone","Singapore",
            "Slovakia","Slovenia","Solomon Islands","Somalia","South Africa","South Korea","South Sudan","Spain","Sri Lanka",
            "Sudan","Suriname","Sweden","Switzerland","Syria","Taiwan","Tajikistan","Tanzania","Thailand","Timor-Leste",
            "Togo","Tonga","Trinidad and Tobago","Tunisia","Turkey","Turkmenistan","Tuvalu","Uganda","Ukraine",
            "United Arab Emirates","United Kingdom","United States","Uruguay","Uzbekistan","Vanuatu","Vatican City",
            "Venezuela","Vietnam","Yemen","Zambia","Zimbabwe"
        ])

        # QA prompts
        self.q_date = "What is the last literature search date?"
        self.q_total = "How many studies were included?"
        self.q_rct   = "How many randomized controlled trials?"
        self.q_coh   = "How many cohort studies?"
        self.q_cc    = "How many case-control studies?"
        self.q_cs    = "How many cross-sectional studies?"
        self.q_nrsi  = "How many non-randomized studies?"

        # Retrieval prompts
        self.p_date = "last search date; literature search date; search conducted on; last updated"
        self.p_stud = "included studies; eligible studies; PRISMA; randomized controlled trials; cohort; case-control; cross-sectional; non-randomized; study count"
        self.p_ctry = "country counts for included studies; numbers per country; Italy (1), Germany (10); sample sizes by country (n=); locations of studies; participants from"

        # Inclusion guards / decoy blockers
        self._include_verbs = re.compile(
            r"\b(included|were included|was included|retained|synthesi[sz]ed|meta-?analy[sz]ed|analy[sz]ed|used for analysis|for data extraction|in the review)\b",
            flags=re.I
        )
        self._decoy_block = re.compile(
            r"\b(excluded|missing data|dropped|non-eligible|not included|screened out|removed)\b",
            flags=re.I
        )

    # ---------------- embeddings / QA helpers ----------------

    def _top_k_sents(self, sentences: List[str], prompt: str, k: Optional[int] = None) -> List[Tuple[int, str]]:
        if not sentences:
            return []
        sentences = [s for s in sentences if s and s.strip()]
        if not sentences:
            return []
        k = k or self.top_k
        embs = self.embedder.encode(sentences, convert_to_tensor=True, normalize_embeddings=True)
        qemb = self.embedder.encode([prompt], convert_to_tensor=True, normalize_embeddings=True)[0]
        sims = util.cos_sim(embs, qemb).cpu().numpy().reshape(-1)
        idx = np.argsort(-sims)[:k]
        return [(int(i), sentences[int(i)]) for i in idx if sentences[int(i)].strip()]

    def _best_qa(self, question: str, contexts: List[str]) -> Dict:
        best = {"answer": "", "score": 0.0, "context": ""}
        for c in contexts:
            if not c or not c.strip():
                continue
            try:
                out = self.qa({"question": question, "context": c})
            except Exception:
                continue
            if out.get("score", 0.0) > best["score"]:
                best = {"answer": out.get("answer","") or "", "score": out["score"], "context": c}
        return best

    def _is_included_context(self, s: str) -> bool:
        # Sentence likely refers to included/synthesized material (avoid decoys)
        if self._decoy_block.search(s):
            return False
        return bool(self._include_verbs.search(s) or re.search(r"\b(included|eligible)\s+stud(?:y|ies)\b", s, flags=re.I))

    # ------------------- public APIs -------------------

    def extract_lit_search_date(self, text: str) -> Tuple[Optional[str], Dict]:
        sents = split_sentences(text)
        cand = self._top_k_sents(sents, self.p_date)
        contexts = [s for _, s in cand if s and s.strip()]
        if not contexts:
            # global fallback
            date_like = re.findall(
                r"(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|"
                r"Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{1,2},?\s+\d{4}|"
                r"(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|"
                r"Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{4}|"
                r"\d{4}-\d{2}-\d{2}",
                text, flags=re.I
            )
            return parse_latest_date(date_like), {"evidence": ""}

        qa_res = self._best_qa(self.q_date, contexts)

        # dates found in contexts (for fallback / latest)
        date_like = []
        for s in contexts:
            date_like += re.findall(
                r"(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|"
                r"Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{1,2},?\s+\d{4}|"
                r"(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|"
                r"Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{4}|"
                r"\d{4}-\d{2}-\d{2}",
                s, flags=re.I
            )

        val = qa_res["answer"].strip() if qa_res.get("answer") else ""
        if not val:
            val = parse_latest_date(date_like) or ""
        evidence = qa_res["context"] if qa_res.get("context") else (contexts[0] if contexts else "")
        return (val or None), {"evidence": evidence}

    def _extract_number_near_studies(self, s: str) -> Optional[int]:
        """
        Strong rule-based patterns targeting *included* study totals.
        Handles both 'Sixty-seven studies ... were included' and 'included 67 studies'.
        Avoids decoys with 'excluded', 'missing data', etc.
        """
        if not self._is_included_context(s):
            return None

        # Pattern A: "<number> studies/articles/trials ... (were) included/analyzed/synthesized"
        pat_a = re.compile(
            r"\b([A-Za-z][A-Za-z\- ]+|\d[\d,]*)\s+(?:stud(?:y|ies)|articles?|trials?)\b[^.]{0,120}?\b"
            r"(?:included|analy[sz]ed|synthesi[sz]ed|meta-?analy[sz]ed|used for analysis|for data extraction|in the review)\b",
            flags=re.I
        )

        # Pattern B: "(were) included/analyzed/... <number> studies"
        pat_b = re.compile(
            r"\b(?:included|analy[sz]ed|synthesi[sz]ed|meta-?analy[sz]ed|used for analysis|for data extraction|in the review)\b"
            r"[^.]{0,120}?\b([A-Za-z][A-Za-z\- ]+|\d[\d,]*)\s+(?:stud(?:y|ies)|articles?|trials?)\b",
            flags=re.I
        )

        for pat in (pat_a, pat_b):
            m = pat.search(s)
            if m:
                n = _int_from_strnum(m.group(1))
                if n is not None:
                    return n
        return None

    def _ask_int(self, question: str, contexts: List[str]) -> Tuple[Optional[int], Dict]:
        res = self._best_qa(question, contexts)
        # try to coerce to int (supports word numbers)
        m = re.search(r"[A-Za-z][A-Za-z\- ]+|\d[\d,]*", res.get("answer",""))
        val = _int_from_strnum(m.group(0)) if m else None
        return val, res

    def extract_study_counts(self, text: str) -> Tuple[Dict, List[Dict]]:
        sents = split_sentences(text)
        cand = self._top_k_sents(sents, self.p_stud, k=max(self.top_k, 14))
        contexts = [s for _, s in cand if s and s.strip()]
        evid = []

        # Rank contexts: prioritize those that look like "included studies"
        contexts_ranked = sorted(
            contexts,
            key=lambda s: (
                0 if self._is_included_context(s) else 1,
                -len(re.findall(r'\bstud(?:y|ies)\b', s, flags=re.I))
            )
        )

        # RULE FIRST for total (stronger precision)
        total_rule = None
        for s in contexts_ranked[:12]:
            n = self._extract_number_near_studies(s)
            if n is not None:
                total_rule = n
                evid.append({"field": "studies.total_rule", "text": s})
                break

        # QA backstop
        total_qa, ev_total = self._ask_int(self.q_total, contexts_ranked[:8] if contexts_ranked else [])
        if total_rule is None and ev_total.get("context"):
            evid.append({"field": "studies.total_qa", "text": ev_total["context"]})

        total = total_rule if total_rule is not None else (total_qa or 0)

        # Typed counts (regex + QA)
        def typed_from_regex(kind: str, labels: List[str]) -> Optional[int]:
            # e.g., "14 randomized controlled trials", "three cohort studies"
            pat = re.compile(
                rf"\b([A-Za-z][A-Za-z\- ]+|\d[\d,]*)\s+(?:{'|'.join(labels)})\b",
                flags=re.I
            )
            for s in contexts_ranked[:16]:
                if not self._is_included_context(s) and "stud" not in s.lower():
                    continue
                m = pat.search(s)
                if m:
                    n = _int_from_strnum(m.group(1))
                    if n is not None:
                        evid.append({"field": f"studies.{kind}", "text": s})
                        return n
            return None

        rct_regex = typed_from_regex("rct", ["randomi[sz]ed\\s+controlled\\s+trials?", "rcts?"])
        coh_regex = typed_from_regex("cohort", ["cohort\\s+stud(?:y|ies)"])
        cc_regex  = typed_from_regex("case_control", ["case[-\\s]?control\\s+stud(?:y|ies)"])
        cs_regex  = typed_from_regex("cross_sectional", ["cross[-\\s]?sectional\\s+stud(?:y|ies)"])
        nrsi_regex= typed_from_regex("nrsi", ["non[-\\s]?randomi[sz]ed\\s+stud(?:y|ies)", "observational\\s+stud(?:y|ies)"])

        # QA as backup
        rct_qa, ev_rct = (rct_regex, {}) if rct_regex is not None else self._ask_int(self.q_rct, contexts_ranked[:8])
        coh_qa, ev_coh = (coh_regex, {}) if coh_regex is not None else self._ask_int(self.q_coh, contexts_ranked[:8])
        cc_qa,  ev_cc  = (cc_regex,  {}) if cc_regex  is not None else self._ask_int(self.q_cc,  contexts_ranked[:8])
        cs_qa,  ev_cs  = (cs_regex,  {}) if cs_regex  is not None else self._ask_int(self.q_cs,  contexts_ranked[:8])
        nrsi_qa,ev_nrsi= (nrsi_regex,{}) if nrsi_regex is not None else self._ask_int(self.q_nrsi,contexts_ranked[:8])

        for k, ev in [("rct",ev_rct), ("cohort",ev_coh), ("case_control",ev_cc), ("cross_sectional",ev_cs), ("nrsi",ev_nrsi)]:
            if ev.get("context"):
                evid.append({"field": f"studies.{k}", "text": ev["context"]})

        parts = {
            "rct": rct_qa or 0,
            "cohort": coh_qa or 0,
            "case_control": cc_qa or 0,
            "cross_sectional": cs_qa or 0,
            "nrsi": nrsi_qa or 0
        }

        # Sanity reconcile: if typed sum > total, downscale proportionally
        sum_parts = sum(v for v in parts.values() if isinstance(v, int))
        if total and sum_parts > total:
            factor = total / sum_parts if sum_parts else 1.0
            for k in parts:
                parts[k] = int(round((parts[k] or 0) * factor))

        # Optional: try to capture number of countries mentioned (e.g., "from 25 countries were included")
        num_countries = None
        pat_c = re.compile(
            r"\bfrom\s+([A-Za-z][A-Za-z\- ]+|\d[\d,]*)\s+countries?\b[^.]{0,80}?\b(included|analy[sz]ed|synthesi[sz]ed|used)\b",
            flags=re.I
        )
        for s in contexts_ranked[:12]:
            m = pat_c.search(s)
            if m:
                nc = _int_from_strnum(m.group(1))
                if nc:
                    num_countries = nc
                    evid.append({"field": "studies.num_countries_included", "text": s})
                    break

        out = {"total": total, **parts}
        if num_countries:
            out["num_countries_included"] = num_countries
        return out, evid

    # ---------- Country extraction (global + contextual, aliases, multiple patterns) ----------

    def extract_country_counts(self, text: str) -> Tuple[Dict[str, Dict[str, int]], List[Dict]]:
        """
        Extract both:
          • study_counts: Italy (1), Germany (10) ...
          • sample_sizes: Country (n = 35,812) and variants
        """

        evid: List[Dict] = []

        # --- alias normalization ---
        def norm_country(name: str) -> Optional[str]:
            name = (name or "").strip()
            alias = {
                "US": "United States", "U.S.": "United States", "U.S": "United States", "USA": "United States",
                "UK": "United Kingdom", "U.K.": "United Kingdom", "U.K": "United Kingdom",
                "UAE": "United Arab Emirates",
                "Korea": "South Korea",  # often ambiguous in abstracts; adjust if needed
            }
            if name in alias:
                name = alias[name]
            if name.lower().startswith("the "):
                name = name[4:]
            # keep only well-formed proper name tokens
            if not re.match(r"^[A-Z][A-Za-z]+(?:\s[A-Z][A-Za-z]+)*$", name):
                return None
            return name if name in self.countries else name  # if a clean proper noun, keep; you can restrict to known set if preferred

        # --- collectors ---
        study_counts: Dict[str, int] = defaultdict(int)
        sample_sizes: Dict[str, int] = defaultdict(int)

        # --- patterns ---
        # (1) Explicit parentheses counts: "Italy (1), Germany (10)"
        pat_counts = re.compile(r"\b([A-Z][A-Za-z]*(?:\s[A-Z][A-Za-z]*)*)\s*\(\s*(\d{1,4})\s*\)")

        # (2) Parenthesized sample sizes: "Country (n = 35,812)" or "Country (N=7670)" with spaces
        pat_samples = re.compile(r"\b([A-Z][A-Za-z]*(?:\s[A-Z][A-Za-z]*)*)\s*\(\s*[nN]\s*=\s*([\d\s,]+)\s*\)")

        # (3) “<number> participants from Country” or “… in Country”
        pat_participants_from = re.compile(
            r"\b([\d][\d,\s]*)\s+(?:participants|subjects)\s+(?:from|in)\s+([A-Z][A-Za-z]*(?:\s[A-Z][A-Za-z]*)*)",
            flags=re.I
        )

        # (4) “in Country (n=...)” is covered by (2); also capture “Country had 35,812 participants”
        pat_ctry_had_participants = re.compile(
            r"\b([A-Z][A-Za-z]*(?:\s[A-Z][A-Za-z]*)*)\b[^.]{0,40}?\b(?:had|with)\s+([\d][\d,\s]*)\s+(?:participants|subjects)\b",
            flags=re.I
        )

        # (5) “X studies in/from Country”
        pat_studies_in_ctry = re.compile(
            r"\b([A-Za-z][A-Za-z\- ]+|\d[\d,]*)\s+stud(?:y|ies)\s+(?:in|from)\s+([A-Z][A-Za-z]*(?:\s[A-Z][A-Za-z]*)*)",
            flags=re.I
        )

        # --- 1) GLOBAL SCAN (high recall) ---
        for m in pat_samples.finditer(text):
            c_raw, n_raw = m.group(1), m.group(2)
            c = norm_country(c_raw)
            n = _strip_commas_int(n_raw.replace(" ", ""))
            if c and n is not None:
                sample_sizes[c] = max(sample_sizes[c], n)
                evid.append({"field": f"countries.sample_sizes.{c}", "text": m.group(0)})

        for m in pat_participants_from.finditer(text):
            n_raw, c_raw = m.group(1), m.group(2)
            c = norm_country(c_raw)
            n = _strip_commas_int(n_raw.replace(" ", ""))
            if c and n is not None:
                sample_sizes[c] = max(sample_sizes[c], n)
                evid.append({"field": f"countries.sample_sizes.{c}", "text": m.group(0)})

        for m in pat_ctry_had_participants.finditer(text):
            c_raw, n_raw = m.group(1), m.group(2)
            c = norm_country(c_raw)
            n = _strip_commas_int(n_raw.replace(" ", ""))
            if c and n is not None:
                sample_sizes[c] = max(sample_sizes[c], n)
                evid.append({"field": f"countries.sample_sizes.{c}", "text": m.group(0)})

        for m in pat_counts.finditer(text):
            c_raw, n_raw = m.group(1), m.group(2)
            c = norm_country(c_raw)
            n = _int_from_strnum(n_raw)
            if c and n is not None:
                study_counts[c] += n
                evid.append({"field": f"countries.study_counts.{c}", "text": m.group(0)})

        for m in pat_studies_in_ctry.finditer(text):
            n_raw, c_raw = m.group(1), m.group(2)
            c = norm_country(c_raw)
            n = _int_from_strnum(n_raw)
            if c and n is not None:
                study_counts[c] += n
                evid.append({"field": f"countries.study_counts.{c}", "text": m.group(0)})

        # --- 2) CONTEXTUAL SCAN (top-K sentences) for compact evidence / missed items ---
        sents = split_sentences(text)
        cand = self._top_k_sents(sents, self.p_ctry, k=max(12, self.top_k))
        contexts = [s for _, s in cand if s and s.strip()]

        for s in contexts:
            for m in pat_samples.finditer(s):
                c_raw, n_raw = m.group(1), m.group(2)
                c = norm_country(c_raw)
                n = _strip_commas_int(n_raw.replace(" ", ""))
                if c and n is not None:
                    sample_sizes[c] = max(sample_sizes[c], n)
                    evid.append({"field": f"countries.sample_sizes.{c}", "text": s})

            for m in pat_participants_from.finditer(s):
                n_raw, c_raw = m.group(1), m.group(2)
                c = norm_country(c_raw)
                n = _strip_commas_int(n_raw.replace(" ", ""))
                if c and n is not None:
                    sample_sizes[c] = max(sample_sizes[c], n)
                    evid.append({"field": f"countries.sample_sizes.{c}", "text": s})

            for m in pat_ctry_had_participants.finditer(s):
                c_raw, n_raw = m.group(1), m.group(2)
                c = norm_country(c_raw)
                n = _strip_commas_int(n_raw.replace(" ", ""))
                if c and n is not None:
                    sample_sizes[c] = max(sample_sizes[c], n)
                    evid.append({"field": f"countries.sample_sizes.{c}", "text": s})

            for m in pat_counts.finditer(s):
                c_raw, n_raw = m.group(1), m.group(2)
                c = norm_country(c_raw)
                n = _int_from_strnum(n_raw)
                if c and n is not None:
                    study_counts[c] += n
                    evid.append({"field": f"countries.study_counts.{c}", "text": s})

            for m in pat_studies_in_ctry.finditer(s):
                n_raw, c_raw = m.group(1), m.group(2)
                c = norm_country(c_raw)
                n = _int_from_strnum(n_raw)
                if c and n is not None:
                    study_counts[c] += n
                    evid.append({"field": f"countries.study_counts.{c}", "text": s})

        # sort outputs
        study_counts_sorted = dict(sorted(study_counts.items(), key=lambda x: (-x[1], x[0])))
        sample_sizes_sorted = dict(sorted(sample_sizes.items(), key=lambda x: (-x[1], x[0])))

        return {"study_counts": study_counts_sorted, "sample_sizes": sample_sizes_sorted}, evid

    def extract_all(self, text: str) -> Dict:
        lit_date, ev_date = self.extract_lit_search_date(text)
        studies, ev_stud  = self.extract_study_counts(text)
        countries, ev_cty = self.extract_country_counts(text)

        # Sort study_counts by desc
        if "study_counts" in countries:
            countries["study_counts"] = dict(sorted(countries["study_counts"].items(), key=lambda x: (-x[1], x[0])))
        # Sort sample_sizes by desc
        if "sample_sizes" in countries:
            countries["sample_sizes"] = dict(sorted(countries["sample_sizes"].items(), key=lambda x: (-x[1], x[0])))

        evidence = []
        if ev_date.get("evidence"):
            evidence.append({"field":"lit_search_date", "text": ev_date["evidence"]})
        evidence.extend(ev_stud)
        evidence.extend(ev_cty)
        # cap evidence list for readability
        evidence = evidence[:40]

        return {
            "lit_search_date": lit_date,
            "studies": studies,
            "countries": countries,
            "_evidence": evidence
        }

# -------------------- quick usage example --------------------
if __name__ == "__main__":
    text = """
    
    Abstract
Objectives
To assess the impact of restricting systematic reviews of conventional or alternative medical treatments or diagnostic tests to English-language publications.
Study design and setting
We systematically searched MEDLINE (Ovid), the Science Citation Index Expanded (Web of Science), and Current Contents Connect (Web of Science) up to April 24, 2020. Eligible methods studies assessed the impact of restricting systematic reviews to English-language publications on effect estimates and conclusions. Two reviewers independently screened the literature; one investigator performed the data extraction, a second investigator checked for completeness and accuracy. We synthesized the findings narratively.
Results
Eight methods studies (10 publications) met the inclusion criteria; none addressed language restrictions in diagnostic test accuracy reviews. The included studies analyzed nine to 147 meta-analyses and/or systematic reviews. The proportions of non-English-language publications ranged from 2% to 100%. Based on five methods studies, restricting literature searches or inclusion criteria to English-language publications led to a change in statistical significance in 23/259 meta-analyses (9%). Most commonly, the statistical significance was lost, but had no impact on the conclusions of systematic reviews.
    
    """

    # One variable feeds both the log line and the constructor, so the message can no
    # longer claim "cpu" while the extractor is actually built for another device.
    device = "mps"
    print(f"Device set to use {device}")
    extractor = FastContextExtractor(device=device)
    result = extractor.extract_all(text)
    print(result)


2025-11-12 13:26:09,115 - INFO - Load pretrained SentenceTransformer: allenai/specter2_base


Device set to use cpu


Device set to use mps:0
Batches: 100%|██████████| 1/1 [00:00<00:00, 40.55it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 38.55it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 100.10it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 61.21it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 104.16it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 28.17it/s]

{'lit_search_date': 'April 24, 2020', 'studies': {'total': 8, 'rct': 1, 'cohort': 3, 'case_control': 1, 'cross_sectional': 3, 'nrsi': 1}, 'countries': {'study_counts': {}, 'sample_sizes': {}}, '_evidence': [{'field': 'lit_search_date', 'text': 'Study design and setting\nWe systematically searched MEDLINE (Ovid), the Science Citation Index Expanded (Web of Science), and Current Contents Connect (Web of Science) up to April 24, 2020.'}, {'field': 'studies.total_qa', 'text': 'Results\nEight methods studies (10 publications) met the inclusion criteria; none addressed language restrictions in diagnostic test accuracy reviews.'}, {'field': 'studies.rct', 'text': 'Results\nEight methods studies (10 publications) met the inclusion criteria; none addressed language restrictions in diagnostic test accuracy reviews.'}, {'field': 'studies.cohort', 'text': 'Based on five methods studies, restricting literature searches or inclusion criteria to English-language publications led to a change in statis




In [1]:
# --------- quick usage ---------
from src.AIModels.Inference import SRPredictor


if __name__ == "__main__":
    # Local import kept at the top of the block instead of between statements.
    from pprint import pprint

    text = """
    We searched MEDLINE and EMBASE through March 2024. Results: Sixty-seven studies (n = 88,345 participants)
    from 25 countries were included. The largest sample sizes were from China (n = 35,812), the United States (n = 10,437),
    and Germany (n = 7,670). Other notable sample sizes included Bangladesh (n = 2,909), Ethiopia (n = 2,033),
    Egypt (n = 3,204), and Vietnam (n = 2,945). In total, 14 randomized controlled trials, 10 cohort studies,
    and three case-control studies were analyzed; many cross-sectional studies were also included.
    Outcomes included infection, hospitalization, ICU admission, and death. We studied influenza and HPV vaccines,
    including quadrivalent (4vHPV) and bivalent (2vHPV) options; both live and non-live formulations were reported.
    Populations included adolescents, adults, and elderly; specific groups included nurses and pregnant women, with
    some immunocompromised patients. Studies were conducted in Europe, Asia, and Sub-Saharan Africa.
    We searched MEDLINE (Ovid), Embase, Web of Science, and Cochrane Library from inception to March 2024.
    Google Scholar and ClinicalTrials.gov were also screened; grey literature included OpenGrey.
    We included sixty-seven studies. The intervention was given as two-dose schedule at 0, 1, and 6 months,
    0.5 mL per dose (500 μg). Follow-up of 12 months. Participants were compared with placebo and usual care groups.
    We considered COVID-19 infection, hospitalization, ICU admission, and death as outcomes.

    """

    # Sentence-embedding checkpoint handed to the predictor; "allenai/specter2_base" is the alternative.
    model_path = "sentence-transformers/all-MiniLM-L6-v2"
    # device=None is passed below; the message indicates the predictor then chooses the device itself.
    print("Device set to use: auto")
    pred = SRPredictor(model_path=model_path, device=None, top_k=12)
    out = pred.predict_all(text)
    pprint(out)

  from .autonotebook import tqdm as notebook_tqdm


Device set to use: auto
{'age_groups': {'ado_10__17': 1, 'adu_18__64': 1, 'eld_65__10000': 1},
 'articles': {'articles_included': 67,
              'case_control': 3,
              'cohort': 10,
              'cross_sectional': 0,
              'nrsi': 0,
              'num_countries_included': 25,
              'rct': 14,
              'total': 67,
              'unique_studies': 0},
 'countries': {'sample_sizes': {'Bangladesh': 2909,
                                'China': 35812,
                                'Egypt': 3204,
                                'Ethiopia': 2033,
                                'Germany': 7670,
                                'United States': 10437,
                                'Vietnam': 2945},
               'study_counts': {}},
 'databases': {'database_list': ['ClinicalTrials.gov',
                                 'Cochrane Library',
                                 'Embase',
                                 'Google Scholar',
                      

In [2]:
# NOTE(review): `topics_terms` is not defined in any cell visible here, so on a fresh
# kernel this line raises NameError (see the recorded output of this cell).
# Presumably it is a dict of topic -> terms, since `.items()` is called — TODO: define
# or load it in an earlier cell so the notebook survives Restart & Run All.
# Builds one "key:value" string per entry of `topics_terms`.
topics_term_list = [f"{k}:{v}" for k,v in topics_terms.items()]


NameError: name 'topics_terms' is not defined

In [None]:
from src.AIModels.Training import run_training


if __name__ == "__main__":
    # Fine-tune the retrieval model on the auto-generated JSONL training set.
    #
    # The backend is detected instead of hard-coded so the cell also runs on
    # machines without Apple-silicon (MPS) or CUDA support, and the message
    # printed below always matches the device that is actually passed on.
    import torch  # local import: only needed here for backend detection

    # `torch.backends.mps` is missing on older torch builds, hence the getattr.
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        device = "cuda"
    elif mps_backend is not None and mps_backend.is_available():
        device = "mps"
    else:
        device = "cpu"
    print(f"Device set to use: {device}")

    # NOTE(review): the recorded run of this cell on "mps" ended with an MPS
    # out-of-memory RuntimeError; if that recurs, a smaller batch_size (or
    # device="cpu") is presumably the first thing to try -- confirm.
    run_training(
        jsonl_path="sr_auto_train.jsonl",
        device=device,
        epochs=3,
        batch_size=32,
        lr=2e-5,
    )


  from .autonotebook import tqdm as notebook_tqdm
2025-11-12 17:16:39,128 - INFO - Loaded 1368 rows from sr_auto_train.jsonl
2025-11-12 17:16:39,128 - INFO - Train: 1162 | Eval: 206
2025-11-12 17:16:39,128 - INFO - Loading native SentenceTransformer: sentence-transformers/all-MiniLM-L6-v2
2025-11-12 17:16:39,130 - INFO - Load pretrained SentenceTransformer: sentence-transformers/all-MiniLM-L6-v2


Device set via arg or default CPU


2025-11-12 17:16:41,204 - INFO - Eval corpus size: 2206 docs | queries: 206
                                                                     

Step,Training Loss,Validation Loss,Sr Ir Eval Minilm Cosine Accuracy@1,Sr Ir Eval Minilm Cosine Accuracy@3,Sr Ir Eval Minilm Cosine Accuracy@5,Sr Ir Eval Minilm Cosine Accuracy@10,Sr Ir Eval Minilm Cosine Precision@1,Sr Ir Eval Minilm Cosine Precision@3,Sr Ir Eval Minilm Cosine Precision@5,Sr Ir Eval Minilm Cosine Precision@10,Sr Ir Eval Minilm Cosine Recall@1,Sr Ir Eval Minilm Cosine Recall@3,Sr Ir Eval Minilm Cosine Recall@5,Sr Ir Eval Minilm Cosine Recall@10,Sr Ir Eval Minilm Cosine Ndcg@10,Sr Ir Eval Minilm Cosine Mrr@10,Sr Ir Eval Minilm Cosine Map@1,Sr Ir Eval Minilm Cosine Map@3,Sr Ir Eval Minilm Cosine Map@5,Sr Ir Eval Minilm Cosine Map@10
37,No log,No log,0.019417,0.048544,0.053398,0.067961,0.019417,0.016181,0.01068,0.006796,0.019417,0.048544,0.053398,0.067961,0.044128,0.036494,0.019417,0.033172,0.034385,0.036494
74,No log,No log,0.009709,0.038835,0.053398,0.063107,0.009709,0.012945,0.01068,0.006311,0.009709,0.038835,0.053398,0.063107,0.03739,0.029045,0.009709,0.024272,0.027427,0.029045
111,No log,No log,0.014563,0.033981,0.048544,0.067961,0.014563,0.011327,0.009709,0.006796,0.014563,0.033981,0.048544,0.067961,0.039355,0.030432,0.014563,0.024272,0.027427,0.030432


2025-11-12 17:16:54,374 - INFO - Information Retrieval Evaluation of the model on the sr_ir_eval_minilm dataset in epoch 1.0 after 37 steps:
Batches: 100%|██████████| 7/7 [00:00<00:00,  8.19it/s]
Batches: 100%|██████████| 69/69 [00:02<00:00, 28.11it/s]
Corpus Chunks: 100%|██████████| 1/1 [00:03<00:00,  3.74s/it]
2025-11-12 17:16:58,975 - INFO - Queries: 206
2025-11-12 17:16:58,976 - INFO - Corpus: 2206

2025-11-12 17:16:58,979 - INFO - Score-Function: cosine
2025-11-12 17:16:58,980 - INFO - Accuracy@1: 1.94%
2025-11-12 17:16:58,980 - INFO - Accuracy@3: 4.85%
2025-11-12 17:16:58,980 - INFO - Accuracy@5: 5.34%
2025-11-12 17:16:58,980 - INFO - Accuracy@10: 6.80%
2025-11-12 17:16:58,981 - INFO - Precision@1: 1.94%
2025-11-12 17:16:58,981 - INFO - Precision@3: 1.62%
2025-11-12 17:16:58,981 - INFO - Precision@5: 1.07%
2025-11-12 17:16:58,981 - INFO - Precision@10: 0.68%
2025-11-12 17:16:58,981 - INFO - Recall@1: 1.94%
2025-11-12 17:16:58,981 - INFO - Recall@3: 4.85%
2025-11-12 17:16:58,982 -

Step,Training Loss,Validation Loss


RuntimeError: MPS backend out of memory (MPS allocated: 25.48 GB, other allocations: 1.55 GB, max allowed: 27.20 GB). Tried to allocate 384.00 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [None]:
from src.AIModels.DatasetGeneration import SRTrainingSetBuilder


if __name__ == "__main__":
    # Build the MiniLM-based training set from the plain-text corpus.
    builder = SRTrainingSetBuilder(
        # Pass None instead of a model name to skip semantic hard negatives.
        embedder_name="sentence-transformers/all-MiniLM-L6-v2",
        # One of "cuda", "mps", or "cpu".
        device="cpu",
        seed=13,
        use_pycountry=True,
    )

    builder.build_from_dir(
        # Folder holding the cleaned *.txt full texts / sections.
        txt_dir="corpus_txt",
        out_jsonl="sr_auto_train.jsonl",
        # Use an int here to balance the number of rows per document.
        per_doc_limit=None,
        add_augmentations=True,
        semantic_negs_k=4,
    )


2025-11-12 16:58:01,264 - INFO - Load pretrained SentenceTransformer: sentence-transformers/all-MiniLM-L6-v2
Batches: 100%|██████████| 12/12 [00:00<00:00, 21.90it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 158.33it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 214.50it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 176.86it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 159.62it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 100.04it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 110.02it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 103.92it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 84.93it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 70.42it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 78.75it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 77.82it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 85.24it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 150.13it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 223.55it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 207.26it/s]

Wrote 1368 training rows to sr_auto_train.jsonl
Slot counts (before dedupe): {'date': 20, 'country': 20, 'region': 8, 'intervention_any': 344, 'intervention_covid': 186, 'design_any': 24, 'topic': 648, 'outcome_infection': 54, 'age_numeric': 18, 'age_groups': 21, 'specific_groups': 6, 'studies': 16, 'design_counts': 3}


In [None]:
from src.AIModels.DatasetGeneration import SRTrainingSetBuilder

# Same corpus as the MiniLM run, but embedded with Specter2 and written to a
# separate JSONL file.
builder = SRTrainingSetBuilder(
    # Specter2 base checkpoint.
    embedder_name="allenai/specter2_base",
    # Alternatives: "cuda" or "cpu".
    device="mps",
    seed=13,
    use_pycountry=True,
)

builder.build_from_dir(
    txt_dir="corpus_txt",
    out_jsonl="sr_auto_train_allenai.jsonl",
    per_doc_limit=None,
    add_augmentations=True,
    semantic_negs_k=4,
)


Batches: 100%|██████████| 12/12 [00:03<00:00,  3.61it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 19.19it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 101.21it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 97.04it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 103.44it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 18.34it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 100.50it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 107.10it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 103.45it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 13.13it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 102.23it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 103.34it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 90.83it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 14.63it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 88.64it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 100.84it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 109.22it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 22.02it/s]

Wrote 1368 training rows to sr_auto_train_allenai.jsonl
Slot counts (before dedupe): {'date': 20, 'country': 20, 'region': 8, 'intervention_any': 344, 'intervention_covid': 186, 'design_any': 24, 'topic': 648, 'outcome_infection': 54, 'age_numeric': 18, 'age_groups': 21, 'specific_groups': 6, 'studies': 16, 'design_counts': 3}
