In [1]:
!pip install gradio sentence-transformers faiss-cpu requests rank-bm25



In [2]:
import calendar
import json
import os
import re
import tempfile
import unicodedata
import zipfile
from dataclasses import dataclass
from datetime import date, datetime, timezone
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple

import faiss
import gradio as gr
import numpy as np
import requests
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer


# ============================================================
# BM25 tokenizer (keeps diacritics; Latvian letters are NOT stripped)
# ============================================================
# A token is any run of Unicode letters/digits ("[^\W_]" = word characters minus "_").
# Unlike an explicit list of modern Latvian letters, this also keeps words written in
# the pre-1946 orthography (ŗ, ō) used by 1930s leaflets, and Cyrillic or other-script
# text, together as single tokens instead of splitting or dropping them.
_TOKEN_RE = re.compile(r"[^\W_]+", re.UNICODE)


def bm25_tokenize(text: str) -> List[str]:
    """Lower-case *text* and split it into BM25 tokens.

    The text is NFC-normalised first, so a decomposed letter such as "a" followed by a
    combining macron is treated like the precomposed "ā" rather than being split at
    the combining mark. ``None`` or an empty string yields ``[]``.
    """
    normalized = unicodedata.normalize("NFC", text or "")
    return _TOKEN_RE.findall(normalized.lower())


AUTO_MODE = "AUTO"
MANUAL_MODE = "MANUAL"


def sem_percent_to_alpha(sem_percent: float) -> float:
    """Turn a semantic-weight percentage into a fusion weight in ``[0, 1]``.

    The input is converted with ``float()`` and clamped to ``0..100`` before being
    divided by 100.
    """
    clamped = min(max(float(sem_percent), 0.0), 100.0)
    return clamped / 100.0


# ===========================
# OpenRouter
# ===========================

DEFAULT_MODEL_ID = "meta-llama/llama-3.3-70b-instruct:free" # "anthropic/claude-3.5-sonnet"
OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions"
LOG_PATH = "lkp_rag_log.json"


def _extract_openrouter_content(payload) -> Optional[str]:
    """Return the first choice's message ``content`` from a chat-completion payload.

    Raises RuntimeError when the payload carries an ``error`` object (OpenRouter can
    report failures this way) or contains no choices, instead of an opaque KeyError.
    """
    if not isinstance(payload, dict):
        raise RuntimeError(f"Negaidīta OpenRouter atbilde: {payload!r}")
    error = payload.get("error")
    if error:
        message = error.get("message") if isinstance(error, dict) else error
        raise RuntimeError(f"OpenRouter kļūda: {message}")
    choices = payload.get("choices") or []
    if not choices:
        raise RuntimeError("OpenRouter atbildē nav 'choices'.")
    message = choices[0].get("message") or {}
    return message.get("content")


def ask(
    prompt: str,
    api_key: str,
    model_id: Optional[str] = None,
    temperature: float = 0.0,
    timeout: float = 120.0,
) -> Optional[str]:
    """Send *prompt* to the OpenRouter chat-completions endpoint and return the reply.

    Args:
        prompt: User message; a fixed "critical historian" system message is prepended.
        api_key: OpenRouter API key (surrounding whitespace is ignored).
        model_id: OpenRouter model id; ``DEFAULT_MODEL_ID`` when None/blank.
        temperature: Sampling temperature forwarded to the API.
        timeout: HTTP timeout in seconds (previously fixed at 120).

    Returns:
        The assistant message ``content``; may be ``None`` if the API omits it.

    Raises:
        RuntimeError: missing API key, or an error payload / no choices in the reply.
        requests.HTTPError: non-2xx status; the response body is appended to the message.
        requests.RequestException: network failures and timeouts.
    """
    key = (api_key or "").strip()
    if not key:
        raise RuntimeError("OpenRouter API key nav norādīts.")

    # A blank model id must fall back to the default instead of being sent as "".
    model_id = (model_id or "").strip() or DEFAULT_MODEL_ID

    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }

    data = {
        "model": model_id,
        "messages": [
            {
                "role": "system",
                "content": (
                    "Tu esi precīzs, kritisks vēsturnieks, kas analizē "
                    "Latvijas komunistisko pagrīdes organizāciju skrejlapas (1934–1940)."
                ),
            },
            {"role": "user", "content": prompt},
        ],
        "temperature": float(temperature),
    }

    resp = requests.post(OPENROUTER_URL, headers=headers, json=data, timeout=timeout)
    try:
        resp.raise_for_status()
    except requests.HTTPError as err:
        # OpenRouter explains failures (bad key, rate limit, unknown model) in the body;
        # raise_for_status() alone would drop that explanation.
        detail = (resp.text or "").strip()[:500]
        raise requests.HTTPError(f"{err} | {detail}", response=resp) from err
    return _extract_openrouter_content(resp.json())


# ===========================
# JSON LOG (NEW FORMAT)
# ===========================
def _serialize_chunk_for_log(chunk: Dict) -> Dict:
    # STRICTLY pēc tavas BM25 programmas: pilns teksts + score_semantic/score_bm25
    return {
        "doc_id": chunk.get("doc_id"),
        "file_name": chunk.get("file_name"),
        "title": chunk.get("title"),
        "chunk_id": chunk.get("chunk_id"),
        "score": chunk.get("score"),
        "score_semantic": chunk.get("score_semantic"),
        "score_bm25": chunk.get("score_bm25"),
        "text": chunk.get("text", "") or "",
    }


def _load_log_events(path: str) -> List:
    """Return the JSON event list stored at *path* (``[]`` if absent or empty).

    A non-list JSON document is wrapped into a one-element list. An unreadable or
    corrupt file is renamed to ``<path>.corrupt-<UTC timestamp>`` so the next write
    cannot silently wipe the existing history; if that rename fails the OSError
    propagates and the caller skips writing.
    """
    if not os.path.exists(path):
        return []
    try:
        with open(path, "r", encoding="utf-8") as f:
            content = f.read().strip()
        data = json.loads(content) if content else []
    except (OSError, ValueError) as err:
        stamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S%fZ")
        backup_path = f"{path}.corrupt-{stamp}"
        os.replace(path, backup_path)
        print(f"Warning: unreadable log moved to {backup_path}: {err}")
        return []
    return data if isinstance(data, list) else [data]


def _write_log_events_atomically(path: str, events: List) -> None:
    """Write *events* as indented JSON to *path* via a temp file and ``os.replace``.

    A crash or serialisation error mid-write leaves the previous file intact instead
    of a truncated JSON document.
    """
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(prefix=".lkp_rag_log_", suffix=".tmp", dir=directory)
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(events, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, path)
    except BaseException:
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise


def log_qa_event(
    question: str,
    answer: str,
    retrieved_chunks: List[Dict],
    model_id: str,
    top_k: int,
    temperature: float,
    query_profile: Optional["QueryProfile"] = None,
    weight_info: Optional[Dict] = None,
) -> None:
    """Append one question/answer event to the JSON array stored at ``LOG_PATH``.

    The event records the UTC timestamp (ISO-8601 with a trailing ``Z``), question,
    answer, model settings, the AUTO query profile (plus the effective semantic weight
    from *weight_info*), the retrieval weights and every chunk shown to the model.
    Logging is best-effort: any I/O failure is reported with a printed warning.
    """
    event = {
        # Aware UTC time rendered exactly like the former utcnow().isoformat() + "Z".
        "timestamp_utc": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        "question": question,
        "answer": answer,
        "model_id": model_id,
        "top_k": int(top_k),
        "temperature": float(temperature),

        "query_profile": None if query_profile is None else {
            "qtype": query_profile.qtype,
            "alpha_auto_suggested": float(query_profile.alpha),
            "alpha_effective": None if weight_info is None else float(weight_info["alpha_semantic"]),
            "sem_candidates": int(query_profile.sem_candidates),
            "bm25_candidates": int(query_profile.bm25_candidates),
        },

        "retrieval_weights": weight_info,

        "retrieved_chunks": [
            _serialize_chunk_for_log(c) for c in (retrieved_chunks or [])
        ],
    }

    try:
        data = _load_log_events(LOG_PATH)
        data.append(event)
        _write_log_events_atomically(LOG_PATH, data)
    except Exception as e:
        print(f"Warning: could not write log: {e}")


def get_log_file_for_gui():
    """Return ``LOG_PATH`` for download, first creating it as an empty JSON list if missing."""
    if os.path.exists(LOG_PATH):
        return LOG_PATH
    with open(LOG_PATH, "w", encoding="utf-8") as handle:
        json.dump([], handle, ensure_ascii=False, indent=2)
    return LOG_PATH


# ============================================================
# 1) Datumi / tirāža (no leaflet programmas)
# ============================================================
def _parse_plain_date_to_range(s: str) -> Tuple[Optional[date], Optional[date]]:
    s = (s or "").strip()
    if not s:
        return None, None

    parts = s.split("-")
    try:
        if len(parts) == 3:
            y, m, d = map(int, parts)
            dt = date(y, m, d)
            return dt, dt
        elif len(parts) == 2:
            y, m = map(int, parts)
            last_day = calendar.monthrange(y, m)[1]
            return date(y, m, 1), date(y, m, last_day)
        else:
            y = int(parts[0])
            return date(y, 1, 1), date(y, 12, 31)
    except Exception:
        return None, None


def parse_metadata_date_to_range(date_str: str) -> Tuple[Optional[date], Optional[date]]:
    if not date_str:
        return None, None

    date_str = date_str.strip()

    if not (date_str.startswith("[") and date_str.endswith("]")):
        return _parse_plain_date_to_range(date_str)

    inner = date_str[1:-1].strip()

    if "..." in inner:
        if inner.endswith("..."):
            left = inner[:-3].strip()
            start, end = _parse_plain_date_to_range(left)
            return start, None
        elif inner.startswith("..."):
            right = inner[3:].strip()
            start, end = _parse_plain_date_to_range(right)
            return None, end

    if ".." in inner:
        left, right = inner.split("..", 1)
        left = left.strip()
        right = right.strip()
        s1, e1 = _parse_plain_date_to_range(left)
        s2, e2 = _parse_plain_date_to_range(right)
        start = s1
        end = e2
        return start, end

    return _parse_plain_date_to_range(inner)


# First number in a print-run string. "1-3 digits followed by 3-digit groups" separated
# by space / no-break space / "." / "," (e.g. "2 000", "1.500") is read as one number.
_PRINT_RUN_NUMBER_RE = re.compile(r"\d{1,3}(?:[ \u00a0.,]\d{3})+(?!\d)|\d+")


def parse_print_run_value(v: str) -> Optional[int]:
    """Extract the print run (tirāža) as an int from a free-text metadata value.

    Only the FIRST number is used, so ranges ("1000-2000") and annotated values
    ("500 eks., 2. izd.") no longer glue all their digits together (10002000, 5002).
    Thousands separators inside a single number are accepted.
    Returns ``None`` for empty input, ``"unk"`` or text without digits.
    """
    if not v:
        return None
    v_low = v.strip().lower()
    if v_low == "unk":
        return None
    match = _PRINT_RUN_NUMBER_RE.search(v_low)
    if match is None:
        return None
    digits = "".join(ch for ch in match.group(0) if ch.isdigit())
    return int(digits)


def parse_user_date_box(s: str, is_start: bool) -> Optional[date]:
    """Parse a GUI date filter box into a single boundary date.

    ``YYYY-MM-DD`` is used as-is; ``YYYY-MM`` and ``YYYY`` expand to the first day
    (``is_start=True``) or the last day (``is_start=False``) of that month/year.
    With more than three "-" parts only the first (the year) is used.

    Returns:
        The boundary date, or ``None`` for an empty box.

    Raises:
        ValueError: the text is not a valid date; the message echoes the input and
            names the accepted formats so the GUI user can correct it.
    """
    s = (s or "").strip()
    if not s:
        return None
    parts = s.split("-")
    try:
        if len(parts) == 3:
            y, m, d = map(int, parts)
            return date(y, m, d)
        if len(parts) == 2:
            y, m = map(int, parts)
            if is_start:
                return date(y, m, 1)
            last_day = calendar.monthrange(y, m)[1]
            return date(y, m, last_day)
        y = int(parts[0])
        return date(y, 1, 1) if is_start else date(y, 12, 31)
    except (ValueError, OverflowError) as err:
        raise ValueError(
            f"Nederīgs datums '{s}': izmanto formātu GGGG, GGGG-MM vai GGGG-MM-DD."
        ) from err


# ============================================================
# 1b) Loading leaflets from ZIP (no leaflet programmas)
# ============================================================
# Marker that starts the leaflet body. Anchored to a line start so that a metadata
# value merely containing "text:" (e.g. "...context: ...") is not mistaken for it.
_BODY_MARKER_RE = re.compile(r"^[ \t]*text:", re.MULTILINE)


def parse_metadata(content: str) -> Dict:
    """Parse one leaflet file: ``key: value`` header lines, then ``text:`` and the body.

    Known header keys (id, file_name, title, author, date, print_run,
    typography_name, source) are copied; ``id`` is converted to int (``None`` if not
    numeric); unknown keys are ignored. Derived fields are added: ``date_start`` /
    ``date_end`` from ``parse_metadata_date_to_range`` and ``print_run_value`` from
    ``parse_print_run_value``.
    """
    # Strip a UTF-8 byte-order mark: otherwise it sticks to the first key (e.g. "id")
    # and that key is silently ignored.
    content = (content or "").lstrip("\ufeff")

    parts = _BODY_MARKER_RE.split(content, maxsplit=1)
    if len(parts) == 1:
        # No line-initial marker: fall back to the first "text:" anywhere (old behavior).
        parts = content.split("text:", 1)
    metadata_text = parts[0]

    metadata: Dict = {
        "id": None,
        "file_name": "",
        "title": "",
        "author": "",
        "date": "",
        "print_run": "",
        "typography_name": "",
        "source": "",
        "text": "",
    }

    for line in metadata_text.split("\n"):
        line = line.strip()
        if not line or ":" not in line:
            continue

        key, value = line.split(":", 1)
        key = key.strip()
        value = value.strip()

        if key == "id":
            try:
                metadata[key] = int(value)
            except ValueError:
                metadata[key] = None
        elif key in metadata:
            metadata[key] = value

    if len(parts) > 1:
        metadata["text"] = parts[1].strip()

    raw_date = metadata.get("date", "")
    try:
        d_start, d_end = parse_metadata_date_to_range(raw_date)
    except Exception:
        # Date parsing is best-effort; an odd value just means "unknown date".
        d_start, d_end = None, None

    metadata["date_start"] = d_start
    metadata["date_end"] = d_end
    metadata["print_run_value"] = parse_print_run_value(metadata.get("print_run", ""))

    return metadata


def _is_corpus_txt(path: Path) -> bool:
    """True for a regular ``.txt`` file (any case) that is not a macOS ``._*`` resource file."""
    return path.is_file() and path.suffix.lower() == ".txt" and not path.name.startswith("._")


def _find_corpus_dir(root: str) -> Optional[Path]:
    """Return *root* if it holds corpus .txt files, else the first such visible subdirectory.

    Subdirectories are checked in sorted order; ``__MACOSX`` and hidden directories are
    skipped so macOS-created archives do not make their metadata folder the "corpus".
    """
    root_path = Path(root)
    if any(_is_corpus_txt(p) for p in root_path.iterdir()):
        return root_path
    for item in sorted(root_path.iterdir()):
        if not item.is_dir() or item.name == "__MACOSX" or item.name.startswith("."):
            continue
        if any(_is_corpus_txt(p) for p in item.iterdir()):
            return item
    return None


def load_leaflets_from_zip(zip_path: str) -> List[Dict]:
    """Extract *zip_path* to a temporary directory and parse every leaflet ``.txt`` in it.

    The corpus is the archive root if it holds .txt files, otherwise the first
    subdirectory that does. Files are read as UTF-8 (a BOM is tolerated) in sorted
    name order, so chunk numbering is reproducible between runs. ``file_name`` and
    ``title`` fall back to the file's name/stem. Files that fail to parse are
    reported with a printed message and skipped.

    Note: each result's ``"path"`` points into the temporary extraction directory,
    which is deleted before this function returns.

    Raises:
        ValueError: no directory with .txt files was found in the archive.
    """
    results: List[Dict] = []

    with tempfile.TemporaryDirectory() as temp_dir:
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(temp_dir)

        corpus_dir = _find_corpus_dir(temp_dir)
        if corpus_dir is None:
            raise ValueError("Cannot find corpus directory with .txt files in ZIP file")

        for file_path in sorted(p for p in corpus_dir.iterdir() if _is_corpus_txt(p)):
            try:
                # "utf-8-sig" decodes plain UTF-8 identically and drops a leading BOM.
                with open(file_path, "r", encoding="utf-8-sig") as f:
                    content = f.read()

                leaflet_data = parse_metadata(content)
                leaflet_data["path"] = str(file_path)
                if not leaflet_data.get("file_name"):
                    leaflet_data["file_name"] = file_path.name
                if not leaflet_data.get("title"):
                    leaflet_data["title"] = file_path.stem

                results.append(leaflet_data)

            except Exception as e:
                print(f"Error processing {file_path}: {e}")

    return results


# ============================================================
# Chunking
# ============================================================
def chunk_text(text: str, max_words: int = 300, overlap_words: int = 60) -> List[str]:
    """Split *text* into whitespace-delimited word windows of at most *max_words* words.

    Consecutive windows share *overlap_words* words. A negative overlap is treated as
    0; an overlap >= *max_words* is replaced by ``max_words // 3`` so the window always
    advances. The last window ends exactly at the final word.

    Returns:
        The chunk strings (words re-joined with single spaces); ``[]`` for empty/None text.

    Raises:
        ValueError: *max_words* < 1 (previously this looped forever).
    """
    words = (text or "").split()
    if not words:
        return []

    if max_words < 1:
        raise ValueError(f"max_words must be >= 1, got {max_words}")
    if overlap_words < 0:
        overlap_words = 0
    if overlap_words >= max_words:
        overlap_words = max_words // 3

    step = max_words - overlap_words
    n = len(words)
    chunks: List[str] = []
    for start in range(0, n, step):
        end = min(start + max_words, n)
        chunks.append(" ".join(words[start:end]))
        if end == n:
            break

    return chunks


# ============================================================
# Filters (no leaflet programmas)
# ============================================================
# Chunk fields searched by the organisation substring filter.
_ORG_FILTER_FIELDS = ("author", "source", "title", "file_name")


def chunk_matches_filters(chunk: Dict, filters: Dict) -> bool:
    """Return True when *chunk* passes every active leaflet filter in *filters*.

    * Date (``date_from`` / ``date_to``): the chunk's [date_start, date_end] range
      must overlap the filter range; chunks with no date at all are excluded, and a
      single missing bound is treated as open.
    * Print run (``print_run_min`` / ``print_run_max``): unknown values are excluded
      unless ``include_unk_print_run`` is truthy.
    * Organisation (``org_substrings``): at least one substring must occur,
      case-insensitively, in author + source + title + file_name.
    A ``None``/empty *filters* accepts everything.
    """
    filters = filters or {}

    # Date filter
    df = filters.get("date_from")
    dt = filters.get("date_to")
    if df or dt:
        s = chunk.get("date_start")
        e = chunk.get("date_end")
        if s is None and e is None:
            return False
        if df and e is not None and e < df:
            return False
        if dt and s is not None and s > dt:
            return False

    # Print run filter
    pr_min = filters.get("print_run_min")
    pr_max = filters.get("print_run_max")
    include_unk = bool(filters.get("include_unk_print_run", False))

    if pr_min is not None or pr_max is not None:
        pr = chunk.get("print_run_value")
        if pr is None:
            if not include_unk:
                return False
        else:
            if pr_min is not None and pr < pr_min:
                return False
            if pr_max is not None and pr > pr_max:
                return False

    # Org substring filter; None-valued metadata counts as "" instead of raising TypeError.
    org_subs = [str(sub).lower() for sub in (filters.get("org_substrings") or []) if sub]
    if org_subs:
        org_meta = " ".join(str(chunk.get(field) or "") for field in _ORG_FILTER_FIELDS).lower()
        if not any(sub in org_meta for sub in org_subs):
            return False

    return True


# ============================================================
# QueryProfile (taken from the BM25 program)
# ============================================================
@dataclass
class QueryProfile:
    """Per-query retrieval settings produced by ``classify_query_rule_based``.

    ``LeafletRAGHybrid.retrieve`` uses ``alpha`` as the semantic weight in AUTO mode
    and the two candidate counts as default pool sizes; ``log_qa_event`` records
    every field.
    """

    qtype: str  # label set by the classifier in this file: "factoid" or "general"
    alpha: float  # weight of the semantic score in the fused score; BM25 gets 1 - alpha
    sem_candidates: int  # how many top semantic (FAISS) hits to take as candidates
    bm25_candidates: int  # how many top BM25 hits to take as candidates

_FACTOID_STARTS = {
    "kas", "kur", "kad", "cik",
    "kāds", "kāda", "kādi",
    "kurš", "kura", "kuru", "kuras",
}

def _has_year_or_number(q: str) -> bool:
    return bool(re.search(r"\b(\d{3,4}|\d+)\b", q))

def _has_quotes(q: str) -> bool:
    return any(ch in q for ch in ['"', "“", "”", "«", "»"])


# Tokens that signal an analytical / explanatory question (why, explain, compare, ...).
_GENERAL_CUES = frozenset({
    "kāpēc", "kādēļ", "paskaidro", "apraksti", "analizē", "novērtē",
    "raksturo", "salīdzini", "nozīme", "loma", "sekas", "iemesli",
})


def classify_query_rule_based(q: str, top_k: int) -> QueryProfile:
    """Pick a retrieval profile for *q* with simple lexical heuristics.

    Factoid evidence: a question word as first token (+2), a number/year (+2),
    quotes (+1), at most 6 tokens (+1). General evidence: an analytical cue word
    (+2), 10 or more tokens (+1). Strictly more factoid evidence gives a
    BM25-leaning "factoid" profile (alpha 0.40, wide BM25 pool); otherwise a
    semantic-leaning "general" profile (alpha 0.70, wide semantic pool). A query
    without tokens gets the "general" pools with alpha 0.75.
    """
    normalized = (q or "").strip().lower()
    tokens = bm25_tokenize(normalized)

    wide_pool = max(top_k * 8, top_k)
    narrow_pool = max(top_k * 3, top_k)

    if not tokens:
        return QueryProfile(
            qtype="general",
            alpha=0.75,
            sem_candidates=wide_pool,
            bm25_candidates=narrow_pool,
        )

    token_count = len(tokens)
    factoid_score = (
        (2 if tokens[0] in _FACTOID_STARTS else 0)
        + (2 if _has_year_or_number(normalized) else 0)
        + (1 if _has_quotes(normalized) else 0)
        + (1 if token_count <= 6 else 0)
    )
    general_score = (
        (2 if not _GENERAL_CUES.isdisjoint(tokens) else 0)
        + (1 if token_count >= 10 else 0)
    )

    if factoid_score > general_score:
        return QueryProfile(
            qtype="factoid",
            alpha=0.40,
            sem_candidates=narrow_pool,
            bm25_candidates=wide_pool,
        )

    return QueryProfile(
        qtype="general",
        alpha=0.70,
        sem_candidates=wide_pool,
        bm25_candidates=narrow_pool,
    )


# ============================================================
# LeafletRAG HYBRID (SEM + BM25) + filters
# ============================================================
class LeafletRAGHybrid:
    """Hybrid (semantic + BM25) retriever over leaflet text chunks.

    ``build_index`` splits every leaflet into overlapping word chunks and indexes them
    twice: as SentenceTransformer embeddings in a FAISS inner-product index, and as
    ``bm25_tokenize`` token lists in a ``BM25Okapi`` model. ``retrieve`` scores a
    query with both, min-max normalises each score over the candidate set and fuses
    them as ``alpha * semantic + (1 - alpha) * bm25``.
    """

    def __init__(self, model_name: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"):
        self.model = SentenceTransformer(model_name)
        self.index: Optional[faiss.IndexFlatIP] = None
        self.chunks: List[Dict] = []
        self.embedding_dim: Optional[int] = None
        self.bm25: Optional[BM25Okapi] = None
        self.bm25_tokens: List[List[str]] = []

    def build_index(
        self,
        leaflets: List[Dict],
        max_words_per_chunk: int = 260,
        overlap_words: int = 60,
        normalize_embeddings: bool = True,
    ) -> None:
        """Chunk *leaflets* and (re)build the BM25 model and the FAISS index.

        Every chunk dict carries the leaflet metadata used by filters, prompts and logs
        (``doc_id`` = ``leaflet_id`` = the leaflet ``id``), its ``chunk_id`` within the
        leaflet and its ``text``. Leaflets with blank text are skipped. New state is
        assigned only after everything has been built, so a failure (e.g. during
        encoding) leaves the previous index and chunk list consistent.

        Raises:
            ValueError: no leaflet produced any text chunk.
        """
        chunks: List[Dict] = []
        all_texts: List[str] = []

        for leaflet in leaflets:
            full_text = leaflet.get("text", "") or ""
            if not full_text.strip():
                continue

            leaflet_id = leaflet.get("id")
            chunk_list = chunk_text(full_text, max_words=max_words_per_chunk, overlap_words=overlap_words)
            for i, chunk in enumerate(chunk_list):
                chunks.append(
                    {
                        "doc_id": leaflet_id,  # exposed as doc_id in the JSON log
                        "leaflet_id": leaflet_id,  # original id kept as well
                        "file_name": leaflet.get("file_name", ""),
                        "title": leaflet.get("title", ""),
                        "date": leaflet.get("date", ""),
                        "date_start": leaflet.get("date_start"),
                        "date_end": leaflet.get("date_end"),
                        "print_run": leaflet.get("print_run", ""),
                        "print_run_value": leaflet.get("print_run_value"),
                        "author": leaflet.get("author", ""),
                        "source": leaflet.get("source", ""),
                        "chunk_id": i,
                        "text": chunk,
                    }
                )
                all_texts.append(chunk)

        if not all_texts:
            raise ValueError("No text chunks found. Cannot build index.")

        # BM25
        bm25_tokens = [bm25_tokenize(t) for t in all_texts]
        bm25 = BM25Okapi(bm25_tokens)

        # Semantic embeddings + FAISS (inner product == cosine when L2-normalised).
        embeddings = self.model.encode(all_texts, convert_to_numpy=True)
        if normalize_embeddings:
            embeddings = embeddings / (np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-12)
        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)  # FAISS needs float32

        index = faiss.IndexFlatIP(embeddings.shape[1])
        index.add(embeddings)

        self.chunks = chunks
        self.bm25_tokens = bm25_tokens
        self.bm25 = bm25
        self.embedding_dim = embeddings.shape[1]
        self.index = index

    def retrieve(
        self,
        query: str,
        top_k: int = 5,
        filters: Optional[Dict] = None,
        normalize_embeddings: bool = True,
        sem_candidates: Optional[int] = None,
        bm25_candidates: Optional[int] = None,
        weight_mode: str = AUTO_MODE,
        sem_percent: float = 40.0,
    ) -> Tuple[List[Dict], QueryProfile, Dict]:
        """Return the *top_k* best chunks for *query* that pass the metadata *filters*.

        Steps: keep only chunks passing ``chunk_matches_filters``; among them take the
        top ``sem_candidates`` by semantic score and the top ``bm25_candidates`` by
        (positive) BM25 score; min-max normalise both true scores over the union and
        fuse them with ``alpha`` (from the AUTO profile, or ``sem_percent`` in MANUAL
        mode).

        Returns:
            ``(results, profile, weight_info)``: chunk copies with ``score`` (fused),
            ``score_semantic`` and ``score_bm25`` added, best first (ties by chunk
            order); the AUTO ``QueryProfile`` (always computed, for logging); and the
            effective weights.

        Raises:
            RuntimeError: the index has not been built.
        """
        if self.index is None or not self.chunks:
            raise RuntimeError("Index is not built. Call build_index() first.")
        if self.bm25 is None or not self.bm25_tokens:
            raise RuntimeError("BM25 is not built. Check build_index().")

        if filters is None:
            filters = {}

        # The AUTO profile is always computed so it can be logged, even in MANUAL mode.
        profile = classify_query_rule_based(query, top_k=top_k)

        mode = (weight_mode or AUTO_MODE).upper().strip()
        alpha = sem_percent_to_alpha(sem_percent) if mode == MANUAL_MODE else profile.alpha

        weight_info = {
            "mode": "manual" if mode == MANUAL_MODE else "auto",
            "alpha_semantic": float(alpha),
            "alpha_bm25": float(1.0 - alpha),
            "sem_percent": float(alpha * 100.0),
            "bm25_percent": float((1.0 - alpha) * 100.0),
        }

        if sem_candidates is None:
            sem_candidates = profile.sem_candidates
        if bm25_candidates is None:
            bm25_candidates = profile.bm25_candidates

        # GUI sliders may deliver floats; FAISS and argpartition need ints.
        top_k = int(top_k)
        if top_k <= 0:
            return [], profile, weight_info

        # Apply the metadata filters BEFORE choosing candidates. Filtering a fixed-size
        # candidate pool afterwards returned too few (often zero) results for selective
        # filters even when many matching chunks existed.
        allowed = np.array(
            [i for i, chunk in enumerate(self.chunks) if chunk_matches_filters(chunk, filters)],
            dtype=np.int64,
        )
        if allowed.size == 0:
            return [], profile, weight_info

        sem_all = self._semantic_scores(query, normalize_embeddings)
        bm25_all = np.asarray(self.bm25.get_scores(bm25_tokenize(query)), dtype=np.float64)

        sem_top = allowed[self._top_indices(sem_all[allowed], sem_candidates)]
        bm25_top = allowed[self._top_indices(bm25_all[allowed], bm25_candidates)]
        # A zero BM25 score means no query term occurs in the chunk: no lexical evidence.
        bm25_top = bm25_top[bm25_all[bm25_top] > 0.0]

        cand = np.union1d(sem_top, bm25_top)
        if cand.size == 0:
            return [], profile, weight_info

        # Both raw scores are known for every chunk, so a candidate found by only one
        # retriever gets its true score from the other one (not a 0.0 placeholder).
        sem_raw = sem_all[cand]
        bm25_raw = bm25_all[cand]
        combo = alpha * self._min_max(sem_raw) + (1.0 - alpha) * self._min_max(bm25_raw)

        # Highest fused score first; ties broken by chunk index for reproducible output.
        order = np.lexsort((cand, -combo))[:top_k]

        results: List[Dict] = []
        for pos in order:
            chunk_info = self.chunks[int(cand[pos])].copy()
            chunk_info["score"] = float(combo[pos])
            chunk_info["score_semantic"] = float(sem_raw[pos])
            chunk_info["score_bm25"] = float(bm25_raw[pos])
            results.append(chunk_info)

        return results, profile, weight_info

    def _semantic_scores(self, query: str, normalize_embeddings: bool) -> np.ndarray:
        """Return the FAISS inner-product score of *query* for every chunk, in chunk order."""
        query_emb = self.model.encode([query], convert_to_numpy=True)
        if normalize_embeddings:
            query_emb = query_emb / (np.linalg.norm(query_emb, axis=1, keepdims=True) + 1e-12)
        query_emb = np.ascontiguousarray(query_emb, dtype=np.float32)

        n_total = len(self.chunks)
        scores, indices = self.index.search(query_emb, n_total)
        full = np.zeros(n_total, dtype=np.float64)
        valid = (indices[0] >= 0) & (indices[0] < n_total)  # FAISS pads missing hits with -1
        full[indices[0][valid]] = scores[0][valid]
        return full

    @staticmethod
    def _top_indices(scores: np.ndarray, k) -> np.ndarray:
        """Positions of the ``k`` largest values in *scores* (unordered; ``k`` is clamped)."""
        k = max(0, min(int(k), scores.size))
        if k == 0:
            return np.empty(0, dtype=np.int64)
        if k == scores.size:
            return np.arange(scores.size)
        return np.argpartition(-scores, k - 1)[:k]

    @staticmethod
    def _min_max(values: np.ndarray) -> np.ndarray:
        """Scale *values* to [0, 1]; a (near-)constant array maps to all zeros."""
        lo = float(values.min())
        hi = float(values.max())
        if hi - lo < 1e-12:
            return np.zeros_like(values, dtype=np.float64)
        return (values - lo) / (hi - lo)


# ============================================================
# Prompt builder (leaflet versija)
# ============================================================
def simple_llm_prompt_builder(query: str, chunks: List[Dict]) -> str:
    """Build the Latvian RAG prompt: method rules, numbered fragments, question, answer format.

    Each chunk becomes a block headed "[n] (title=..., date=..., file=..., chunk_id=...)"
    followed by its text; blocks are separated by a "---" line. The block numbers are
    what the model is told to cite (e.g. "[3]"). Leading/trailing whitespace of the
    final prompt is stripped.
    """

    def render_block(number: int, chunk: Dict) -> str:
        header = ", ".join(
            [
                f"title={chunk.get('title', '')}",
                f"date={chunk.get('date', '')}",
                f"file={chunk.get('file_name', '')}",
                f"chunk_id={chunk.get('chunk_id', '')}",
            ]
        )
        return f"[{number}] ({header})\n{chunk.get('text', '')}"

    context_str = "\n\n---\n\n".join(
        render_block(number, chunk) for number, chunk in enumerate(chunks, start=1)
    )

    template = f"""
Tu esi vēsturnieks, kurš analizē Latvijas komunistisko pagrīdes organizāciju skrejlapas (1934–1940).
Tev ir pieejami tikai zemāk dotie skrejlapu fragmenti.

Tavi metodoloģiskie principi:

1. Atbildi uz jautājumu, balstoties TIKAI uz dotajiem fragmentiem. Nekādu ārēju zināšanu.
2. NEIZDOMĀT faktus. Ja avotos nav tiešas norādes, tas jāklasificē kā nezināms.

I. **Droši fakti (tieši avotos)**
– Iekļauj tikai informāciju, kas skaidri minēta tekstā.
– Katram faktam pievieno atsauci uz fragmentu (piem., “[3]”).
– Ja iespējams, pievieno īsu citātu no avota.

II. **Piesardzīgie secinājumi (netieši, bet atļauti)**
– Atļauts tikai tad, ja secinājums loģiski izriet no fragmentu formulējumiem.
– Vienmēr norādi, ka tas ir NETIEŠS secinājums.

III. **Nezināmais**
– Skaidri norādi visu, ko no avotiem noteikt NAV iespējams.
– Šo sadaļu vienmēr iekļauj.
– Ja atbilde nav nosakāma, skaidri uzraksti:
  **"To nav iespējams noteikt, balstoties tikai uz šeit dotajiem avotiem."**

4. Stilam jābūt akadēmiski precīzam, konspektīvam, bez retorikas un vispārinājumiem.
5. Atsaucies tikai uz informāciju, kas patiešām ir fragmentos.

---

Skrejlapu fragmenti:
{context_str}

Jautājums:
{query}

Tagad sniedz īsu, stingri strukturētu atbildi LATVIEŠU valodā tieši šādā formā:

**I. Droši fakti (tieši avotos)**
– ...

**II. Piesardzīgie secinājumi (no fragmentiem izrietoši)**
– ...

**III. Nezināmais**
– ...
"""
    return template.strip()


# ============================================================
# GUI helpers (leaflet filters)
# ============================================================
def _parse_positive_int(value) -> Optional[int]:
    """Convert a GUI number/text value to a positive int; ``None`` if blank, invalid or <= 0."""
    if value is None or str(value).strip() == "":
        return None
    try:
        number = int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int(float("inf")); ValueError: text or NaN; TypeError: odd types.
        return None
    return number if number > 0 else None


def build_filters_from_inputs(
    date_from_str: str,
    date_to_str: str,
    print_run_min,
    print_run_max,
    org_custom: str,
    include_unk_print_run: bool,
) -> Dict:
    """Translate raw GUI inputs into the filter dict consumed by ``chunk_matches_filters``.

    Dates go through ``parse_user_date_box`` (start/end-of-period expansion); the
    print-run bounds keep only positive integers; the organisation text is trimmed
    and lower-cased into a (possibly empty) substring list.

    Raises:
        ValueError: a non-empty date box is not a valid date (from ``parse_user_date_box``).
    """
    df = parse_user_date_box(date_from_str, is_start=True) if date_from_str else None
    dt = parse_user_date_box(date_to_str, is_start=False) if date_to_str else None

    custom = (org_custom or "").strip().lower()
    org_substrings: List[str] = [custom] if custom else []

    return {
        "date_from": df,
        "date_to": dt,
        "print_run_min": _parse_positive_int(print_run_min),
        "print_run_max": _parse_positive_int(print_run_max),
        "include_unk_print_run": bool(include_unk_print_run),
        "org_substrings": org_substrings,
    }


# ============================================================
# Global state
# ============================================================
# Index shared by the Gradio callbacks; replaced by build_rag_from_zip_gui after a
# successful build and read by the question/retrieval callbacks.
global_rag: Optional["LeafletRAGHybrid"] = None
# Leaflet dicts (as returned by load_leaflets_from_zip) that global_rag was built from.
global_leaflets: Optional[List[Dict]] = None


def build_rag_from_zip_gui(zip_file):
    """Gradio callback: load leaflets from an uploaded ZIP and build the global index.

    *zip_file* may be a filepath string / ``os.PathLike`` (Gradio 4+ ``gr.File``
    default) or a tempfile-like object with ``.name`` (older Gradio).

    Returns:
        ``(status_message, "")``; the globals are replaced only after a successful build.
    """
    global global_rag, global_leaflets

    if zip_file is None:
        return "Nav augšupielādēts ZIP fails.", ""

    # Gradio >= 4 passes a path string, which has no ``.name`` attribute.
    if isinstance(zip_file, (str, os.PathLike)):
        zip_path = os.fspath(zip_file)
    else:
        zip_path = zip_file.name

    try:
        leaflets = load_leaflets_from_zip(zip_path)
    except Exception as e:
        return f"Kļūda, lasot ZIP: {e}", ""

    if not leaflets:
        return "Neizdevās nolasīt nevienu skrejlapu no ZIP.", ""

    try:
        rag = LeafletRAGHybrid()
        rag.build_index(leaflets, max_words_per_chunk=260, overlap_words=60)
    except Exception as e:
        # GUI boundary: report model-loading / indexing failures instead of crashing.
        return f"Kļūda, veidojot indeksu: {e}", ""

    global_rag = rag
    global_leaflets = leaflets

    info = (
        f"Indekss uzbūvēts. Ielādētas {len(leaflets)} skrejlapas. "
        f"Kopējais fragmentu skaits: {len(rag.chunks)}."
    )
    return info, ""


def qa_on_corpus_gui(
    api_key: str,
    question: str,
    top_k: int,
    preview_chars: int,
    min_score: float,
    model_choice: str,
    date_from_str: str,
    date_to_str: str,
    print_run_min,
    print_run_max,
    org_custom: str,
    include_unk_print_run: bool,
    show_full_chunks: bool,
    temperature: float,
    weight_mode: str,
    sem_percent: float,
):
    """Gradio callback: retrieve fragments, ask the LLM, log the exchange.

    Only fragments whose fused score is >= *min_score* are sent to the model; the
    same fragments are logged and previewed.

    Returns:
        ``(answer_text, preview_block)``. *answer_text* is the model reply, a marker
        such as ``[EMPTY_OR_WHITESPACE_ANSWER]`` when the reply is blank, or a Latvian
        status/error message. *preview_block* describes the weights, query profile and
        the fragments given to the model.
    """
    global global_rag

    if global_rag is None:
        return (
            "Indekss vēl nav uzbūvēts. Lūdzu augšupielādē ZIP un nospied 'Izveidot indeksu'.",
            "",
        )

    if not (api_key or "").strip():
        return "Nav norādīts OpenRouter API key. Lūdzu ievadi savu API key.", ""

    if not (question or "").strip():
        return "Lūdzu ievadi jautājumu.", ""

    # Gradio sliders / number boxes may deliver floats; slicing and top-k logic need ints.
    top_k = int(top_k)
    preview_chars = max(0, int(preview_chars or 0))

    effective_model_id = (model_choice or "").strip() or DEFAULT_MODEL_ID

    try:
        filters = build_filters_from_inputs(
            date_from_str,
            date_to_str,
            print_run_min,
            print_run_max,
            org_custom,
            include_unk_print_run,
        )
    except ValueError as e:
        # A malformed date in a filter box is a user error, not a crash.
        return f"Nederīgs filtrs: {e}", ""

    # (A) Retrieve (semantic + BM25) with the leaflet metadata filters applied.
    retrieved, profile, weight_info = global_rag.retrieve(
        query=question,
        top_k=top_k,
        filters=filters,
        weight_mode=weight_mode,
        sem_percent=sem_percent,
    )

    # (B) min_score threshold BEFORE the LLM call (as in the BM25 program).
    effective_min_score = max(0.0, float(min_score)) if min_score is not None else 0.0
    if effective_min_score > 0.0:
        retrieved = [c for c in retrieved if c.get("score", 0.0) >= effective_min_score]

    # No fragments: do not call the LLM; log an empty retrieved_chunks list.
    if not retrieved:
        answer_text = (
            f"Netika atrasts neviens fragments ar līdzības score "
            f"≥ {effective_min_score:.2f}."
        )
        log_qa_event(
            question=question,
            answer=answer_text,
            retrieved_chunks=[],
            model_id=effective_model_id,
            top_k=top_k,
            temperature=temperature,
            query_profile=profile,
            weight_info=weight_info,
        )
        return answer_text, ""

    # (C) Preview of exactly the fragments handed to the model; built before the LLM
    # call so it can still be shown when the call fails.
    preview_lines = [
        f"[WEIGHTS] mode={weight_info['mode']} | sem={weight_info['sem_percent']:.0f}% | bm25={weight_info['bm25_percent']:.0f}%",
        f"[QUERY_TYPE] {profile.qtype} | alpha_auto={profile.alpha:.2f} | "
        f"semCand={profile.sem_candidates} | bm25Cand={profile.bm25_candidates}",
    ]
    if effective_min_score > 0.0:
        preview_lines.append(
            f"[INFO] Pēc min score {effective_min_score:.2f} filtrēšanas izmantoti "
            f"{len(retrieved)} fragmenti (no top_k={top_k})."
        )
    for i, c in enumerate(retrieved, start=1):
        text = c.get("text", "") or ""
        if not show_full_chunks and len(text) > preview_chars:
            text = text[:preview_chars] + "..."
        meta = (
            f"[{i}] score={c.get('score', 0):.4f} | "
            f"sem={c.get('score_semantic', 0):.4f} | "
            f"bm25={c.get('score_bm25', 0):.4f} | "
            f"title={c.get('title','')} "
            f"date={c.get('date','')} "
            f"print_run={c.get('print_run','')} "
            f"author={c.get('author','')} "
            f"file={c.get('file_name','')} "
            f"chunk_id={c.get('chunk_id','')}"
        )
        preview_lines.append(meta + "\n" + text)
    preview_block = "\n\n---\n\n".join(preview_lines)

    # (D) LLM call; the prompt contains only the retrieved fragments.
    prompt = simple_llm_prompt_builder(question, retrieved)
    try:
        answer_text = ask(prompt, api_key, model_id=effective_model_id, temperature=temperature)
    except Exception as e:
        # GUI boundary: network/API failures are reported and logged, not raised.
        log_qa_event(
            question=question,
            answer=f"[LLM_ERROR] {e}",
            retrieved_chunks=retrieved,
            model_id=effective_model_id,
            top_k=top_k,
            temperature=temperature,
            query_profile=profile,
            weight_info=weight_info,
        )
        return f"Kļūda, izsaucot LLM modeli '{effective_model_id}': {e}", preview_block

    # (E) Normalise missing/blank answers into explicit markers.
    if answer_text is None:
        answer_for_log = "[NONE_ANSWER_FROM_MODEL]"
    elif answer_text.strip() == "":
        answer_for_log = "[EMPTY_OR_WHITESPACE_ANSWER]"
    else:
        answer_for_log = answer_text

    # (F) Log only the chunks that were put into the prompt (= retrieved).
    log_qa_event(
        question=question,
        answer=answer_for_log,
        retrieved_chunks=retrieved,
        model_id=effective_model_id,
        top_k=top_k,
        temperature=temperature,
        query_profile=profile,
        weight_info=weight_info,
    )

    # A blank model reply is shown as its marker so the user sees what happened.
    return answer_for_log, preview_block


def retrieve_only_gui(
    question: str,
    top_k: int,
    preview_chars: int,
    min_score: float,
    date_from_str: str,
    date_to_str: str,
    print_run_min,
    print_run_max,
    org_custom: str,
    include_unk_print_run: bool,
    show_full_chunks: bool,
    weight_mode: str,
    sem_percent: float,
):
    """Run retrieval only (no LLM call) and return a text preview of the hits.

    Builds the leaflet metadata filters from the GUI inputs, queries
    ``global_rag.retrieve`` with the selected SEM/BM25 weighting, optionally
    drops chunks whose combined ``score`` is below ``min_score``, and formats
    the remaining chunks together with their scores and metadata.

    Args:
        question: User query text; ``None`` is treated as an empty question.
        top_k: Number of chunks requested from the retriever. Coerced to
            ``int`` because Gradio sliders may deliver numeric values as floats.
        preview_chars: Maximum characters shown per chunk when
            ``show_full_chunks`` is False. Coerced to a non-negative ``int`` so
            it is valid as a slice bound.
        min_score: Minimum combined score. ``None`` or a value <= 0 disables
            the filter.
        date_from_str, date_to_str, print_run_min, print_run_max, org_custom,
        include_unk_print_run: Forwarded unchanged to
            ``build_filters_from_inputs``.
        show_full_chunks: If True, chunk text is shown untruncated.
        weight_mode: Weighting mode forwarded to ``retrieve`` (the UI offers
            ``AUTO_MODE`` / ``MANUAL_MODE``).
        sem_percent: Semantic weight in percent, forwarded to ``retrieve``.

    Returns:
        A single string. It is either a status message (in Latvian), or
        header lines ([WEIGHTS], [QUERY_TYPE], optional [INFO]) followed by
        one block per chunk, all separated by ``---``.
    """
    # Only read here, so no ``global`` statement is needed.
    if global_rag is None:
        return "Indekss vēl nav uzbūvēts. Lūdzu vispirms uzbūvē indeksu."

    if not (question or "").strip():
        return "Lūdzu ievadi jautājumu."

    # Gradio sliders may hand numbers over as floats. Slicing (and most
    # top-k implementations) require ints.
    top_k = int(top_k)
    preview_chars = max(0, int(preview_chars))

    filters = build_filters_from_inputs(
        date_from_str,
        date_to_str,
        print_run_min,
        print_run_max,
        org_custom,
        include_unk_print_run,
    )

    # retrieve() returns (chunk dicts, query profile, weight info dict).
    retrieved, profile, weight_info = global_rag.retrieve(
        query=question,
        top_k=top_k,
        filters=filters,
        weight_mode=weight_mode,
        sem_percent=sem_percent,
    )

    # Negative thresholds are meaningless; clamp them to "filter disabled".
    effective_min_score = max(0.0, float(min_score)) if min_score is not None else 0.0
    if effective_min_score > 0.0:
        retrieved = [c for c in retrieved if c.get("score", 0.0) >= effective_min_score]

    if not retrieved:
        if effective_min_score > 0.0:
            return (
                f"Nav atrasts neviens fragments ar līdzības score "
                f"≥ {effective_min_score:.2f} (no top_k={top_k})."
            )
        return "Nav atrasts neviens atbilstošs fragments."

    lines = [
        f"[WEIGHTS] mode={weight_info['mode']} | sem={weight_info['sem_percent']:.0f}% | bm25={weight_info['bm25_percent']:.0f}%",
        f"[QUERY_TYPE] {profile.qtype} | alpha_auto={profile.alpha:.2f} | "
        f"semCand={profile.sem_candidates} | bm25Cand={profile.bm25_candidates}",
    ]
    if effective_min_score > 0.0:
        lines.append(
            f"[INFO] Pēc min score {effective_min_score:.2f} filtrēšanas izmantoti "
            f"{len(retrieved)} fragmenti (no top_k={top_k})."
        )

    for i, c in enumerate(retrieved, start=1):
        text = c.get("text", "") or ""
        # Truncate only in preview mode; "..." marks that text was cut.
        if not show_full_chunks and len(text) > preview_chars:
            text = text[:preview_chars] + "..."

        meta = (
            f"[{i}] score={c.get('score', 0):.4f} | "
            f"sem={c.get('score_semantic', 0):.4f} | "
            f"bm25={c.get('score_bm25', 0):.4f} | "
            f"title={c.get('title','')} "
            f"date={c.get('date','')} "
            f"print_run={c.get('print_run','')} "
            f"author={c.get('author','')} "
            f"file={c.get('file_name','')} "
            f"chunk_id={c.get('chunk_id','')}"
        )
        lines.append(meta + "\n" + text)

    return "\n\n---\n\n".join(lines)


# ============================================================
# GRADIO UI
# ============================================================
# Layout: a left column with settings (API key, model, index building,
# retrieval weights, display options, leaflet metadata filters) and a
# right column with the question, action buttons and outputs.
# NOTE: inside gr.Blocks the order of component creation determines the
# on-screen layout, so statements here are order-sensitive.
with gr.Blocks() as gui:
    gr.Markdown(
        "## Latvijas komunistisko organizāciju skrejlapu RAG asistents (1934–1940)\n"
        "ZIP ar skrejlapu .txt (ar metadatiem un `text:`) + jautājumi.\n"
        "Retrieval: **SEM + BM25** (AUTO/MANUAL). Atbildes balstītas TIKAI uz fragmentiem."
    )

    with gr.Row():
        # ---------------- Left column: settings ----------------
        with gr.Column():
            # Masked input for the OpenRouter key used by the LLM call.
            api_key_box = gr.Textbox(
                label="OpenRouter API key",
                type="password",
                placeholder="ievadi savu OpenRouter API key šeit",
            )

            # Model id sent to OpenRouter. allow_custom_value=True lets the
            # user type an id that is not in this list.
            model_choice_box = gr.Dropdown(
                label="OpenRouter modelis (vari izvēlēties vai ierakstīt pats)",
                choices=[
                    DEFAULT_MODEL_ID,
                    "anthropic/claude-3.5-sonnet",
                    "anthropic/claude-3.5-haiku",

                    "openai/gpt-4.1",
                    "openai/gpt-4.1-mini",
                    "openai/gpt-4o",
                    "openai/gpt-4o-mini",

                    "qwen/qwen-2.5-7b-instruct",

                    "deepseek/deepseek-chat",

                    "mistralai/mistral-large-2512",
                    "mistralai/mistral-small-3.2-24b-instruct",
                    "mistralai/mistral-nemo",

                    "meta-llama/llama-3.1-70b-instruct",
                    "meta-llama/llama-3.1-8b-instruct",

                    "google/gemini-2.5-flash",
                    "google/gemini-2.5-flash-lite",
                    "google/gemini-2.5-pro",

                    "amazon/nova-2-lite-v1:free",
                    "mistralai/mistral-7b-instruct:free",
                    "kwaipilot/kat-coder-pro:free",
                    "tngtech/deepseek-r1t2-chimera:free",
                ],
                value=DEFAULT_MODEL_ID,
                allow_custom_value=True,
            )

            # LLM sampling temperature; the default 0.0 matches ask()'s default.
            temperature_inp = gr.Slider(
                label="Temperature (0.0 = mazāka variācija, 1.0 = lielāka variācija)",
                minimum=0.0,
                maximum=1.0,
                value=0.0,
                step=0.05,
            )

            # Index building: upload a ZIP of leaflet .txt files, then build.
            zip_input = gr.File(label="ZIP ar LKP skrejlapu .txt failiem")
            build_btn = gr.Button("Izveidot indeksu")
            build_status = gr.Textbox(label="Status", interactive=False)

            # Number of chunks requested from the retriever.
            top_k_inp = gr.Slider(
                label="Cik fragmentus izmantot (top_k)?",
                minimum=1,
                maximum=30,
                value=8,
                step=1,
            )

            # Retrieval weighting: AUTO lets the retriever pick the SEM/BM25
            # mix; MANUAL uses the percentage slider below.
            weight_mode_box = gr.Radio(
                label="Svara režīms (AUTO vai MANUAL)",
                choices=[AUTO_MODE, MANUAL_MODE],
                value=AUTO_MODE,
            )

            # Semantic share in percent (BM25 gets the rest). The label marks it
            # as a MANUAL-mode setting; the slider itself stays interactive in AUTO.
            sem_weight_slider = gr.Slider(
                label="Semantika (%) [MANUAL] (BM25 = 100 - Semantika)",
                minimum=0,
                maximum=100,
                value=40,
                step=1,
            )

            # Per-chunk preview length; only used when full chunks are hidden.
            preview_chars_inp = gr.Slider(
                label="Cik simbolus rādīt katrā fragmenta preview?",
                minimum=50,
                maximum=1000,
                value=300,
                step=50,
            )

            # Post-retrieval threshold on the combined chunk score; 0 disables it.
            min_score_inp = gr.Slider(
                label="Minimālais līdzības score (0 = izslēgts)",
                minimum=0.0,
                maximum=1.0,
                value=0.0,
                step=0.01,
            )

            # Defaults to True, so previews are not truncated unless unchecked.
            show_full_chunks_box = gr.Checkbox(
                label="Rādīt pilnus fragmentus (nevis tikai preview)",
                value=True,
            )

            # Leaflet metadata filters (unchanged), passed to the handlers
            # and from there to build_filters_from_inputs.
            date_from_box = gr.Textbox(
                label="Datums no (YYYY, YYYY-MM vai YYYY-MM-DD, tukšs – nav filtra)",
                placeholder="piem., 1934-01",
            )
            date_to_box = gr.Textbox(
                label="Datums līdz (YYYY, YYYY-MM vai YYYY-MM-DD, tukšs – nav filtra)",
                placeholder="piem., 1936-12",
            )

            # Print-run bounds; value=None means "no filter".
            print_run_min_box = gr.Number(
                label="Tirāža no (>=, tukšs – nav filtra)",
                value=None,
                precision=0,
            )
            print_run_max_box = gr.Number(
                label="Tirāža līdz (<=, tukšs – nav filtra)",
                value=None,
                precision=0,
            )

            # Whether leaflets with an unknown ("unk") print run pass a
            # print-run filter.
            include_unk_print_run_box = gr.Checkbox(
                label="Iekļaut skrejlapas ar nezināmu tirāžu (unk), ja ir tirāžas filtrs",
                value=True,
            )

            # Free-text organization filter (substring match per the label).
            org_custom_box = gr.Textbox(
                label="Papildu organizācijas filtrs (brīvs teksts, pēc apakšvirknes)",
                placeholder="piem., LKP CK, Rīgas komiteja, Sarkanā palīdzība, VEF, Daugavpils",
            )

        # ---------------- Right column: question and outputs ----------------
        with gr.Column():
            question_box = gr.Textbox(
                label="Jautājums par Latvijas komunistisko organizāciju skrejlapu korpusu",
                lines=3,
                placeholder="Piemēram: Nosauc, kuriem komunistiem piesprieda nāvessodu Ulmaņa režīma laikā!",
            )
            ask_btn = gr.Button("Uzdot jautājumu")
            retrieve_btn = gr.Button("Rādīt tikai fragmentus (bez LLM)")

            # LLM answer (rendered as Markdown) and the chunk preview text.
            # chunks_out is shared by the build, ask and retrieve handlers.
            answer_out = gr.Markdown(label="Atbilde")
            chunks_out = gr.Textbox(
                label="Izmantotie fragmenti (preview vai pilni)",
                lines=20,
            )

            with gr.Row():
                # Exposes the Q&A log file (LOG_PATH) for download.
                log_btn = gr.Button("Izveidot un lejupielādēt žurnālu (JSON)")
                log_file_out = gr.File(
                    label=f"Žurnāla fails ({LOG_PATH})",
                    interactive=False,
                )

    # ---------------- Event wiring ----------------
    # Gradio passes `inputs` positionally, so each list must follow the
    # handler's parameter order; return values map onto `outputs` in order.

    # Build the index from the uploaded ZIP; the handler must return two
    # values (status text, chunk panel text).
    build_btn.click(
        fn=build_rag_from_zip_gui,
        inputs=[zip_input],
        outputs=[build_status, chunks_out],
    )

    # Retrieval + LLM answer. Expects (answer, preview) back.
    ask_btn.click(
        fn=qa_on_corpus_gui,
        inputs=[
            api_key_box,
            question_box,
            top_k_inp,
            preview_chars_inp,
            min_score_inp,
            model_choice_box,
            date_from_box,
            date_to_box,
            print_run_min_box,
            print_run_max_box,
            org_custom_box,
            include_unk_print_run_box,
            show_full_chunks_box,
            temperature_inp,
            weight_mode_box,
            sem_weight_slider,
        ],
        outputs=[answer_out, chunks_out],
    )

    # Retrieval only (no LLM); order matches retrieve_only_gui's signature.
    retrieve_btn.click(
        fn=retrieve_only_gui,
        inputs=[
            question_box,
            top_k_inp,
            preview_chars_inp,
            min_score_inp,
            date_from_box,
            date_to_box,
            print_run_min_box,
            print_run_max_box,
            org_custom_box,
            include_unk_print_run_box,
            show_full_chunks_box,
            weight_mode_box,
            sem_weight_slider,
        ],
        outputs=[chunks_out],
    )

    # Takes no inputs; the handler supplies the file for the download slot.
    log_btn.click(
        fn=get_log_file_for_gui,
        inputs=[],
        outputs=[log_file_out],
    )


if __name__ == "__main__":
    # Script / notebook entry point: serve the UI and ask Gradio to open it
    # in the default browser. In hosted notebooks (e.g. Colab) Gradio itself
    # switches to share=True, as its startup message reports.
    _launch_options = {"inbrowser": True}
    gui.launch(**_launch_options)



It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://62f3448933d042e20a.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
