In [69]:
from __future__ import annotations
import os, sys, subprocess, json, re, time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Any
from collections import defaultdict
from sentence_transformers import CrossEncoder

In [70]:
OUT_DIR = "/content/faiss_per_number"
PROMPT_TOKENIZER_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

def ensure(pkgs: List[str]):
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q"] + pkgs)

try:
    import torch, faiss
    from sentence_transformers import SentenceTransformer
except Exception:
    ensure(["torch", "faiss-cpu", "sentence-transformers"])
    import torch, faiss
    from sentence_transformers import SentenceTransformer

try:
    from transformers import AutoTokenizer
except Exception:
    ensure(["transformers", "accelerate", "sentencepiece"])
    from transformers import AutoTokenizer

try:
    import groq
except Exception:
    ensure(["groq"])
    import groq
from groq import Groq

import numpy as np

In [71]:
@dataclass
class EncoderConfig:
    model_name: str
    max_seq_length: int = 384
    batch_size: int = 64
    device: Optional[str] = None

In [4]:
@dataclass
class GroqConfig:
    model: str = "llama-3.1-8b-instant"
    temperature: float = 0.0
    top_p: float = 1.0
    max_completion_tokens: int = 700
    service_tier: Optional[str] = "auto"
    timeout: Optional[float] = 60.0

In [5]:
@dataclass
class BuildPaths:
    out_dir: str
    chunk_faiss: str
    chunk_docstore: str
    manifest: str
    numbers_faiss: str
    numbers_meta: str

In [6]:
def make_paths(out_dir: str) -> BuildPaths:
    p = Path(out_dir); p.mkdir(parents=True, exist_ok=True)
    return BuildPaths(
        out_dir=str(p),
        chunk_faiss=str(p / "chunks.faiss"),
        chunk_docstore=str(p / "docstore.json"),
        manifest=str(p / "manifest.json"),
        numbers_faiss=str(p / "numbers.faiss"),
        numbers_meta=str(p / "numbers_meta.json"),
    )

In [7]:
class STEncoder:
    def __init__(self, cfg: EncoderConfig):
        dev = cfg.device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SentenceTransformer(cfg.model_name, device=dev)
        self.model.max_seq_length = cfg.max_seq_length
        self.batch_size = cfg.batch_size
    def encode(self, texts: List[str]) -> np.ndarray:
        try:
            X = self.model.encode(texts, batch_size=self.batch_size, convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=False).astype(np.float32)
        except TypeError:
            X = self.model.encode(texts, batch_size=self.batch_size, convert_to_numpy=True, show_progress_bar=False).astype(np.float32)
            faiss.normalize_L2(X)
        return X

In [8]:
class ChunkFaissIndex:
    def __init__(self):
        self.index: faiss.Index = None
        self.ids: List[str] = []
        self.texts: List[str] = []
        self.numbers: List[str] = []
        self.sections: List[str] = []
    @classmethod
    def load(cls, paths: BuildPaths) -> "ChunkFaissIndex":
        obj = cls()
        obj.index = faiss.read_index(paths.chunk_faiss)
        doc = json.loads(Path(paths.chunk_docstore).read_text(encoding="utf-8"))
        obj.ids = doc["ids"]; obj.texts = doc["texts"]; obj.numbers = doc["numbers"]; obj.sections = doc["sections"]
        return obj
    def search(self, Q: np.ndarray, top_k: int) -> Tuple[np.ndarray, np.ndarray]:
        faiss.normalize_L2(Q)
        return self.index.search(Q.astype(np.float32), top_k)

In [9]:
class NumberFaissIndex:
    def __init__(self):
        self.index: faiss.Index = None
        self.numbers: List[str] = []
        self.groups: Dict[str, List[int]] = {}
    @classmethod
    def load(cls, paths: BuildPaths) -> Optional["NumberFaissIndex"]:
        if not Path(paths.numbers_faiss).exists() or not Path(paths.numbers_meta).exists():
            return None
        obj = cls()
        obj.index = faiss.read_index(paths.numbers_faiss)
        meta = json.loads(Path(paths.numbers_meta).read_text(encoding="utf-8"))
        obj.numbers = meta["numbers"]; obj.groups = {k: list(v) for k, v in meta["groups"].items()}
        return obj
    def search(self, Q: np.ndarray, top_k: int) -> Tuple[np.ndarray, np.ndarray]:
        faiss.normalize_L2(Q)
        return self.index.search(Q.astype(np.float32), top_k)

In [62]:
SECTION_WEIGHT = {
    "enunciado": 1.40,
    "enunciado_mini": 1.30,
    "referencias_legislativas": 1.05,
    "excertos_precedentes": 1.00,
    "orgao_data_fonte": 0.90,
    "header": 0.85,
}

In [11]:
def aggregate_by_number(js: List[int], sims: np.ndarray, numbers: List[str], sections: List[str], pool: str = "sum_sqrt"):
    if pool == "max":
        tmp = defaultdict(float)
        for j, s in zip(js, sims):
            w = SECTION_WEIGHT.get(sections[j], 1.0)
            tmp[numbers[j]] = max(tmp[numbers[j]], float(s) * w)
        return sorted(tmp.items(), key=lambda x: x[1], reverse=True)
    agg = defaultdict(float); counts = defaultdict(int)
    for j, s in zip(js, sims):
        w = SECTION_WEIGHT.get(sections[j], 1.0)
        agg[numbers[j]] += float(s) * w
        counts[numbers[j]] += 1
    for n in list(agg.keys()):
        agg[n] = agg[n] / max(1.0, np.sqrt(counts[n]))
    return sorted(agg.items(), key=lambda x: x[1], reverse=True)

In [12]:
class ChunkIdxView:
    def __init__(self, index: ChunkFaissIndex):
        self.index = index.index
        self.ids = index.ids
        self.texts = index.texts
        self.numbers = index.numbers
        self.sections = index.sections
    def search(self, Q: np.ndarray, top_k: int):
        return self.index.search(Q.astype(np.float32), top_k)

In [72]:
def _restrict_idxs_to_numbers(idxs: np.ndarray, numbers_list: List[str], allow_set: set) -> np.ndarray:
    out = []
    for row in idxs:
        keep = [j for j in row if j != -1 and numbers_list[j] in allow_set]
        if not keep:
            keep = [row[0]] if row.size > 0 else [-1]
        out.append(np.array(keep, dtype=np.int64))
    return np.stack(out)

In [73]:
class PerNumberRetriever:
    def __init__(self, out_dir: str):
        self.paths = make_paths(out_dir)
        self.manifest = json.loads(Path(self.paths.manifest).read_text(encoding="utf-8"))
        self.chunk_idx = ChunkIdxView(ChunkFaissIndex.load(self.paths))
        self.number_idx = NumberFaissIndex.load(self.paths)
    def encode_queries(self, texts: List[str], enc_cfg: Optional[EncoderConfig] = None) -> np.ndarray:
        cfg = enc_cfg or EncoderConfig(model_name=self.manifest["model_name"])
        cfg = EncoderConfig(model_name=self.manifest.get("model_name", cfg.model_name), max_seq_length=int(self.manifest.get("max_seq_length", 384)), batch_size=cfg.batch_size, device=cfg.device)
        enc = STEncoder(cfg)
        return enc.encode(texts).astype(np.float32)
    def search_two_stage(self, Q: np.ndarray, top_numbers: int = 20, top_chunks_per_query: int = 200, pool: str = "sum_sqrt") -> List[List[Tuple[str, float]]]:
        sims, idxs = self.chunk_idx.search(Q, top_chunks_per_query)
        out = []
        for i in range(Q.shape[0]):
            js = idxs[i].tolist()
            sc = sims[i].tolist()
            pairs = aggregate_by_number(js, sc, self.chunk_idx.numbers, self.chunk_idx.sections, pool=pool)
            out.append(pairs[:top_numbers])
        return out
    def search_two_stage_hybrid(
        self,
        Q: np.ndarray,
        coarse_top_numbers: int = 50,
        refine_top_chunks: int = 300,
        final_top_numbers: int = 20,
        pool: str = "sum_sqrt"
    ) -> List[List[Tuple[str, float]]]:
        if self.number_idx is None:
            return self.search_two_stage(Q, top_numbers=final_top_numbers, top_chunks_per_query=refine_top_chunks, pool=pool)

        sims_num, idxs_num = self.number_idx.search(Q, coarse_top_numbers)
        out = []
        for i in range(Q.shape[0]):
            allow = { self.number_idx.numbers[j] for j in idxs_num[i] if j != -1 }
            sims_ck, idxs_ck = self.chunk_idx.search(Q[i:i+1], refine_top_chunks)
            idxs_filt = _restrict_idxs_to_numbers(idxs_ck, self.chunk_idx.numbers, allow)
            sc = sims_ck[0].tolist()
            js = idxs_filt[0].tolist()
            pairs = aggregate_by_number(js, sc, self.chunk_idx.numbers, self.chunk_idx.sections, pool=pool)
            out.append(pairs[:final_top_numbers])
        return out

In [14]:
def pick_evidence_from_idxs(retriever: PerNumberRetriever, number: str, top_chunks_global_idxs, max_chunks: int = 3):
    prefer = ["enunciado", "enunciado_mini", "excertos_precedentes", "referencias_legislativas", "orgao_data_fonte", "header"]
    selected, seen = [], set()
    for sec in prefer:
        for j in top_chunks_global_idxs:
            if j == -1 or j in seen:
                continue
            if retriever.chunk_idx.numbers[j] == number and retriever.chunk_idx.sections[j] == sec:
                selected.append({"id": retriever.chunk_idx.ids[j], "section": retriever.chunk_idx.sections[j], "text": retriever.chunk_idx.texts[j]})
                seen.add(j)
                if len(selected) >= max_chunks:
                    return selected
    return selected

In [74]:
class Reranker:
    def __init__(self, model_id: str = "mixedbread-ai/mxbai-rerank-xsmall-v1", device: Optional[str] = None, max_length: int = 512):
        self.model = CrossEncoder(model_id, device=device or ("cuda" if torch.cuda.is_available() else "cpu"), max_length=max_length)
    def rerank_numbers(self, query: str, hits: List[Dict[str, Any]], k: int = 5) -> List[Dict[str, Any]]:
        if not hits: return hits
        pairs = []
        for h in hits:
            text = (h.get("enunciado") or "").strip()
            if not text and h.get("evidences"):
                text = h["evidences"][0].get("text","")
            pairs.append([query, text])
        scores = self.model.predict(pairs, convert_to_numpy=True)
        order = np.argsort(-scores)[:k]
        return [hits[i] | {"rerank_score": float(scores[i])} for i in order]

In [67]:
def gather_hits_for_query(query: str, retriever: PerNumberRetriever, enc_cfg: Optional[EncoderConfig] = None, top_numbers: int = 5, top_chunks_per_query: int = 200, max_evidence_per_number: int = 3) -> List[Dict[str, Any]]:
    Q = retriever.encode_queries([query], enc_cfg=enc_cfg)
    ranked = retriever.search_two_stage_hybrid(
        Q,
        coarse_top_numbers=60,
        refine_top_chunks=300,
        final_top_numbers=top_numbers,
        pool="sum_sqrt"
    )[0]
    sims, idxs = retriever.chunk_idx.search(Q, top_k=top_chunks_per_query)
    results = []
    for number, score in ranked[:top_numbers]:
        enun = next((retriever.chunk_idx.texts[i] for i, num in enumerate(retriever.chunk_idx.numbers) if num == number and retriever.chunk_idx.sections[i] == "enunciado"), "")
        ev = pick_evidence_from_idxs(retriever, number=number, top_chunks_global_idxs=idxs[0], max_chunks=max_evidence_per_number)
        results.append({"number": number, "score": float(score), "enunciado": enun, "evidences": ev})
    return results

In [16]:
def count_tokens(text: str, tokenizer) -> int:
    return len(tokenizer.encode(text or "", add_special_tokens=False))

In [17]:

def compact_hits_autobudget(hits: List[Dict[str, Any]], tokenizer, token_budget: int = 12000, system_overhead: int = 300, query_overhead: int = 600, max_numbers_cap: int = 12, max_evidence_per_number_cap: int = 6) -> List[Dict[str, Any]]:
    used = system_overhead + query_overhead
    out: List[Dict[str, Any]] = []
    for h in hits[:max_numbers_cap]:
        enun = (h.get("enunciado") or "").strip()
        enun_cost = count_tokens(f"SÚMULA {h['number']} — ENUNCIADO:\n{enun}", tokenizer)
        evs = h.get("evidences", [])[:max_evidence_per_number_cap]
        ev_costs = []
        for ev in evs:
            sec = ev.get("section", "")
            txt = (ev.get("text") or "").strip()
            ev_costs.append((ev, count_tokens(f"[{ev['id']} | {sec}]\n{txt}", tokenizer)))
        need = enun_cost + sum(c for _, c in ev_costs)
        if used + need <= token_budget:
            out.append({"number": h["number"], "score": float(h.get("score", 0.0)), "enunciado": enun, "evidences": evs})
            used += need
            continue
        if used + enun_cost > token_budget:
            break
        acc_evs, acc_cost = [], enun_cost
        for ev, c in ev_costs:
            if used + acc_cost + c <= token_budget:
                acc_evs.append(ev); acc_cost += c
            else:
                room = token_budget - (used + acc_cost)
                if room > 30:
                    ids = tokenizer.encode(ev.get("text") or "", add_special_tokens=False)[:room]
                    txt_trunc = tokenizer.decode(ids)
                    if txt_trunc.strip():
                        ev_trunc = {**ev, "text": txt_trunc}
                        acc_evs.append(ev_trunc)
                        acc_cost += room
                break
        out.append({"number": h["number"], "score": float(h.get("score", 0.0)), "enunciado": enun, "evidences": acc_evs})
        used += acc_cost
        if used >= token_budget:
            break
    return out

In [18]:
_SECTION_PT = {
    "enunciado": "Enunciado",
    "enunciado_mini": "Enunciado (mini)",
    "excertos_precedentes": "Excerto dos Precedentes",
    "referencias_legislativas": "Referências Legislativas",
    "orgao_data_fonte": "Órgão Julgador / Data / Fonte",
    "header": "Cabeçalho",
}

In [19]:
def _sec_label(sec: str) -> str:
    return _SECTION_PT.get(sec, sec)

In [20]:
_CASE_RE = re.compile(r'\b(REsp|AgRg|HC|EDcl|AgInt|RMS|AREsp)\b[^)\]\n]{0,120}')
def extract_case_citation(text: str) -> str:
    m = _CASE_RE.search(text or "");
    if not m: return ""
    span_start = m.start(); cut = text[span_start: span_start + 160]
    cut = re.split(r'[\)\n]', cut, maxsplit=1)[0]
    return cut.strip()

In [21]:
def build_prompt_from_compact_narrative(query: str, compact_hits_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
    system = (
        "Você é um assistente jurídico especializado em Súmulas do STJ. "
        "Identifique a(s) Súmula(s) aplicável(is) ao trecho, cite o número e explique brevemente. "
        "Ao citar evidências, indique a seção e o [chunk_id]; quando houver, mencione a referência do precedente."
    )
    ctx_lines = []
    for h in compact_hits_list:
        num = h["number"]; enun = (h.get("enunciado") or "").strip()
        ctx_lines.append(f"SÚMULA {num} — ENUNCIADO:\n{(enun if enun else '(sem enunciado)')}")
        for ev in h.get("evidences", []):
            sec = _sec_label(ev.get("section", "")); text = (ev.get("text") or "").strip()
            case = extract_case_citation(text)
            bits = [f"seção: {sec}", f"[{ev['id']}"] + ([f"originado de {case}"] if case else [])
            ctx_lines.append(f"Evidência (Súmula {num}) — {'; '.join(bits)}:\n{text}")
    context_block = "\n\n".join(ctx_lines)
    user = (
        f"Trecho do usuário:\n---\n{query.strip()}\n---\n\n"
        f"Contexto:\n{context_block}\n\n"
        "Responda:\n"
        "1) Diga qual(is) Súmula(s) se alinham ao trecho, no formato textual: "
        "\"Este é um trecho da Súmula NNN (seção: ...) que veio do [chunk_id]\".\n"
        "2) Explique em 2–4 frases, citando trechos entre aspas com [chunk_id].\n"
        "3) Se houver incerteza, reporte top-3 candidatas.\n"
        "4) Final: 'Súmula aplicável: NNN'."
    )
    return [{"role": "system", "content": system}, {"role": "user", "content": user}]

In [None]:
import os
print("ok" if os.environ.get("GROQ_API_KEY") else "missing")

ok


In [85]:
class GroqChat:
    def __init__(self, cfg: GroqConfig):
        api_key = os.environ.get("GROQ_API_KEY")
        if not api_key:
            raise RuntimeError("Set GROQ_API_KEY in environment.")
        self.client = Groq(api_key=api_key, timeout=cfg.timeout)
        self.cfg = cfg
        self.tokenizer = AutoTokenizer.from_pretrained(PROMPT_TOKENIZER_ID, use_fast=True)

    def chat(self, messages):
        kwargs = dict(
            model=self.cfg.model,
            messages=messages,
            temperature=self.cfg.temperature,
            top_p=self.cfg.top_p,
            max_completion_tokens=self.cfg.max_completion_tokens,
            stream=False,
        )
        if self.cfg.service_tier:
            kwargs["service_tier"] = self.cfg.service_tier
        resp = self.client.chat.completions.create(**kwargs)
        return resp.choices[0].message.content.strip()

In [86]:
_LLM_CACHE: Dict[str, Any] = {}
def get_llm(gen_cfg: GroqConfig) -> GroqChat:
    key = f"{gen_cfg.model}"
    if key not in _LLM_CACHE:
        _LLM_CACHE[key] = GroqChat(gen_cfg)
    return _LLM_CACHE[key]

In [88]:
def count_messages_tokens(messages, tokenizer) -> int:
    text = ""
    for m in messages:
        role = m.get("role","")
        content = m.get("content","")
        text += f"<|{role}|>\n{content}\n"
    return len(tokenizer.encode(text, add_special_tokens=False))

In [89]:
def load_retriever(out_dir: str = OUT_DIR) -> Tuple[PerNumberRetriever, EncoderConfig]:
    paths = make_paths(out_dir)
    mani = json.loads(Path(paths.manifest).read_text(encoding="utf-8"))
    enc_cfg = EncoderConfig(model_name=mani["model_name"], max_seq_length=int(mani.get("max_seq_length", 384)))
    retriever = PerNumberRetriever(out_dir)
    return retriever, enc_cfg

In [90]:
retriever, enc_cfg = load_retriever(OUT_DIR)

In [91]:
def rag_answer(
    query: str,
    retriever: PerNumberRetriever,
    enc_cfg: Optional[EncoderConfig] = None,
    gen_cfg: GroqConfig = GroqConfig(),
    top_numbers: int = 8,
    top_chunks_per_query: int = 160,
    max_evidence_per_number: int = 3,
    token_budget: Optional[int] = None,
    tpm_limit: int = 6000,
    safety_tokens: int = 200,
    use_reranker: bool = True,
    rerank_keep: int = 5
) -> Dict[str, Any]:
    hits_all = gather_hits_for_query(query, retriever, enc_cfg, top_numbers, top_chunks_per_query, max_evidence_per_number)
    if use_reranker:
        rr = Reranker()
        hits_all = rr.rerank_numbers(query, hits_all, k=min(rerank_keep, len(hits_all)))

    llm = get_llm(gen_cfg)
    prompt_limit = token_budget if token_budget is not None else max(512, tpm_limit - int(gen_cfg.max_completion_tokens) - safety_tokens)
    hits_compact = compact_hits_autobudget(hits_all, tokenizer=llm.tokenizer, token_budget=prompt_limit)
    messages = build_prompt_from_compact_narrative(query, hits_compact)
    if count_messages_tokens(messages, llm.tokenizer) > prompt_limit:
        hits_compact = compact_hits_autobudget(hits_all, tokenizer=llm.tokenizer, token_budget=max(512, int(prompt_limit*0.9)))
        messages = build_prompt_from_compact_narrative(query, hits_compact)
    answer = llm.chat(messages)
    return {"answer": answer, "hits": hits_compact, "messages": messages}

In [92]:
def load_retriever(out_dir: str = OUT_DIR) -> Tuple[PerNumberRetriever, EncoderConfig]:
    paths = make_paths(out_dir)
    mani = json.loads(Path(paths.manifest).read_text(encoding="utf-8"))
    enc_cfg = EncoderConfig(model_name=mani["model_name"], max_seq_length=int(mani.get("max_seq_length", 384)))
    retriever = PerNumberRetriever(out_dir)
    return retriever, enc_cfg

In [99]:
QUERY = """"[...] REFIS. LEGITIMIDADE DA EXCLUSÃO POR MEIO DO DIÁRIO OFICIAL E DA INTERNET.
AFASTAMENTO DA LEGISLAÇÃO SUBSIDIÁRIA (LEI 9.784/99). [...] Nos termos do art. 69 da Lei 9.784/99,
'os processos administrativos específicos continuarão a reger-se por lei própria, aplicando-se-lhes
apenas subsidiariamente os preceitos desta Lei'. Considerando que o REFIS é regido especificamente
pela Lei 9.964/2000, a sua incidência afasta a aplicação da norma subsidiária (Lei 9.784/99). 2. Não há
ilegalidade na exclusão do REFIS sem a intimação pessoal do contribuinte, efetuando-se a notificação
por meio do Diário Oficial e da Internet, nos termos do art. 9º, III, da Lei 9.964/2000, c/c o art. 5º da
Resolução 20/2001 do Comitê Gestor do Programa. [...]" (AgRg no Ag 902614 PR, Rel. Ministra DENISE
ARRUDA, PRIMEIRA TURMA, julgado em 13/11/2007, DJ 12/12/2007, p. 397)"""
QUERY = " ".join(QUERY.split())

In [100]:
retriever, enc_cfg = load_retriever(OUT_DIR)
gen_cfg = GroqConfig(model="llama-3.1-8b-instant", temperature=0.0, max_completion_tokens=400, service_tier="on_demand")

res = rag_answer(
    QUERY, retriever, enc_cfg, gen_cfg,
    top_numbers=10,
    top_chunks_per_query=200,
    max_evidence_per_number=3,
    tpm_limit=6000,
    use_reranker=True,
    rerank_keep=5
)

print(res["answer"])
print("\nCandidates:")
for h in res["hits"]:
    print(f"Súmula {h['number']} | score={h.get('score',0):.3f} | rerank={h.get('rerank_score',0):.3f}")
    for ev in h["evidences"][:2]:
        print("  -", ev["id"], "|", ev["section"])


1. Este é um trecho da Súmula 355 (seção: Excerto dos Precedentes; [355#excertos_precedentes@272-592; originado de AgRg no Ag 902614 PR, Rel. Ministra DENISE ARRUDA, PRIMEIRA TURMA, julgado em 13/11/2007, DJ 12/12/2007, p. 397]) que veio do [272-592].
Este é um trecho da Súmula 355 (seção: Excerto dos Precedentes; [355#excertos_precedentes@0-320; originado de REsp 842906 DF, Rel. Ministra ELIANA CALMON, SEGUNDA TURMA, julgado em 06/05/2008, DJe: 19/05/2008]) que veio do [0-320].
Este é um trecho da Súmula 355 (seção: Excerto dos Precedentes; [355#excertos_precedentes@816-1136; originado de REsp 761128 RS, Rel. Ministro CASTRO MEIRA, SEGUNDA TURMA, julgado em 17/05/2007, DJ 29/05/2007, p. 274]) que veio do [816-1136].

2. A Súmula 355 se alinha ao trecho, pois aborda a validade da notificação do ato de exclusão do programa de recuperação fiscal do Refis pelo Diário Oficial ou pela Internet. O trecho cita a jurisprudência pacífica da Primeira e da Segunda Turma do STJ, que entende que a 