In [1]:
from __future__ import annotations
import sys, subprocess, json, time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Any
from collections import defaultdict

OUT_DIR = "/content/faiss_per_number"
GEN_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

In [2]:
def ensure(pkgs: List[str]):
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q"] + pkgs)

try:
    import torch
    import faiss
    from sentence_transformers import SentenceTransformer
except Exception:
    ensure(["torch", "faiss-cpu", "sentence-transformers"])
    import torch, faiss
    from sentence_transformers import SentenceTransformer

try:
    from transformers import AutoTokenizer, AutoModelForCausalLM
except Exception:
    ensure(["transformers", "accelerate", "sentencepiece"])
    from transformers import AutoTokenizer, AutoModelForCausalLM

import numpy as np
import re


In [3]:
@dataclass
class EncoderConfig:
    model_name: str
    max_seq_length: int = 384
    batch_size: int = 64
    device: Optional[str] = None

In [4]:
@dataclass
class GeneratorConfig:
    model_name: str = GEN_MODEL
    temperature: float = 0.1
    max_new_tokens: int = 500
    device: Optional[str] = None
    use_bfloat16: bool = True
    do_sample: bool = False
    top_p: float = 0.9
    top_k: int = 50

In [5]:
@dataclass
class BuildPaths:
    out_dir: str
    chunk_faiss: str
    chunk_docstore: str
    manifest: str
    numbers_faiss: str
    numbers_meta: str

In [6]:
def make_paths(out_dir: str) -> BuildPaths:
    p = Path(out_dir); p.mkdir(parents=True, exist_ok=True)
    return BuildPaths(
        out_dir=str(p),
        chunk_faiss=str(p / "chunks.faiss"),
        chunk_docstore=str(p / "docstore.json"),
        manifest=str(p / "manifest.json"),
        numbers_faiss=str(p / "numbers.faiss"),
        numbers_meta=str(p / "numbers_meta.json"),
    )

In [7]:
class STEncoder:
    def __init__(self, cfg: EncoderConfig):
        dev = cfg.device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SentenceTransformer(cfg.model_name, device=dev)
        self.model.max_seq_length = cfg.max_seq_length
        self.batch_size = cfg.batch_size
    def encode(self, texts: List[str]) -> np.ndarray:
        try:
            X = self.model.encode(texts, batch_size=self.batch_size, convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=False).astype(np.float32)
        except TypeError:
            X = self.model.encode(texts, batch_size=self.batch_size, convert_to_numpy=True, show_progress_bar=False).astype(np.float32)
            faiss.normalize_L2(X)
        return X

In [8]:
class ChunkFaissIndex:
    def __init__(self):
        self.index: faiss.Index = None
        self.ids: List[str] = []
        self.texts: List[str] = []
        self.numbers: List[str] = []
        self.sections: List[str] = []
    @classmethod
    def load(cls, paths: BuildPaths) -> "ChunkFaissIndex":
        obj = cls()
        obj.index = faiss.read_index(paths.chunk_faiss)
        doc = json.loads(Path(paths.chunk_docstore).read_text(encoding="utf-8"))
        obj.ids = doc["ids"]; obj.texts = doc["texts"]; obj.numbers = doc["numbers"]; obj.sections = doc["sections"]
        return obj
    def search(self, Q: np.ndarray, top_k: int) -> Tuple[np.ndarray, np.ndarray]:
        faiss.normalize_L2(Q)
        return self.index.search(Q.astype(np.float32), top_k)

In [9]:
class NumberFaissIndex:
    def __init__(self):
        self.index: faiss.Index = None
        self.numbers: List[str] = []
        self.groups: Dict[str, List[int]] = {}
    @classmethod
    def load(cls, paths: BuildPaths) -> Optional["NumberFaissIndex"]:
        if not Path(paths.numbers_faiss).exists() or not Path(paths.numbers_meta).exists():
            return None
        obj = cls()
        obj.index = faiss.read_index(paths.numbers_faiss)
        meta = json.loads(Path(paths.numbers_meta).read_text(encoding="utf-8"))
        obj.numbers = meta["numbers"]; obj.groups = {k: list(v) for k, v in meta["groups"].items()}
        return obj
    def search(self, Q: np.ndarray, top_k: int) -> Tuple[np.ndarray, np.ndarray]:
        faiss.normalize_L2(Q)
        return self.index.search(Q.astype(np.float32), top_k)

In [10]:
SECTION_WEIGHT = {
    "enunciado": 1.10, "enunciado_mini": 1.10,
    "referencias_legislativas": 1.05,
    "excertos_precedentes": 1.00,
    "orgao_data_fonte": 0.95,
    "header": 0.90,
}

In [11]:
def aggregate_by_number(js: List[int], sims: np.ndarray, numbers: List[str], sections: List[str], pool: str = "sum_sqrt"):
    if pool == "max":
        tmp = defaultdict(float)
        for j, s in zip(js, sims):
            w = SECTION_WEIGHT.get(sections[j], 1.0)
            num = numbers[j]
            tmp[num] = max(tmp[num], float(s) * w)
        return sorted(tmp.items(), key=lambda x: x[1], reverse=True)
    agg = defaultdict(float); counts = defaultdict(int)
    for j, s in zip(js, sims):
        w = SECTION_WEIGHT.get(sections[j], 1.0)
        agg[numbers[j]] += float(s) * w
        counts[numbers[j]] += 1
    for n in list(agg.keys()):
        agg[n] = agg[n] / max(1.0, np.sqrt(counts[n]))
    return sorted(agg.items(), key=lambda x: x[1], reverse=True)

In [12]:
class ChunkIdxView:
    def __init__(self, index: ChunkFaissIndex):
        self.index = index.index
        self.ids = index.ids
        self.texts = index.texts
        self.numbers = index.numbers
        self.sections = index.sections
    def search(self, Q: np.ndarray, top_k: int):
        return self.index.search(Q.astype(np.float32), top_k)

In [13]:
class PerNumberRetriever:
    def __init__(self, out_dir: str):
        self.paths = make_paths(out_dir)
        self.manifest = json.loads(Path(self.paths.manifest).read_text(encoding="utf-8"))
        self.chunk_idx = ChunkIdxView(ChunkFaissIndex.load(self.paths))
        self.number_idx = NumberFaissIndex.load(self.paths)
    def encode_queries(self, texts: List[str], enc_cfg: Optional[EncoderConfig] = None) -> np.ndarray:
        cfg = enc_cfg or EncoderConfig(model_name=self.manifest["model_name"])
        cfg = EncoderConfig(model_name=self.manifest.get("model_name", cfg.model_name), max_seq_length=int(self.manifest.get("max_seq_length", 384)), batch_size=cfg.batch_size, device=cfg.device)
        enc = STEncoder(cfg)
        return enc.encode(texts).astype(np.float32)
    def search_two_stage(self, Q: np.ndarray, top_numbers: int = 20, top_chunks_per_query: int = 200, pool: str = "sum_sqrt") -> List[List[Tuple[str, float]]]:
        sims, idxs = self.chunk_idx.search(Q, top_chunks_per_query)
        out = []
        for i in range(Q.shape[0]):
            js = idxs[i].tolist()
            sc = sims[i].tolist()
            pairs = aggregate_by_number(js, sc, self.chunk_idx.numbers, self.chunk_idx.sections, pool=pool)
            out.append(pairs[:top_numbers])
        return out

In [14]:
def pick_evidence_from_idxs(retriever: PerNumberRetriever, number: str, top_chunks_global_idxs, max_chunks: int = 3):
    prefer = ["enunciado", "enunciado_mini", "excertos_precedentes", "referencias_legislativas", "orgao_data_fonte", "header"]
    selected, seen = [], set()
    for sec in prefer:
        for j in top_chunks_global_idxs:
            if j == -1 or j in seen:
                continue
            if retriever.chunk_idx.numbers[j] == number and retriever.chunk_idx.sections[j] == sec:
                selected.append({"id": retriever.chunk_idx.ids[j], "section": retriever.chunk_idx.sections[j], "text": retriever.chunk_idx.texts[j]})
                seen.add(j)
                if len(selected) >= max_chunks:
                    return selected
    return selected

In [15]:
def gather_hits_for_query(query: str, retriever: PerNumberRetriever, enc_cfg: Optional[EncoderConfig] = None, top_numbers: int = 5, top_chunks_per_query: int = 200, max_evidence_per_number: int = 3) -> List[Dict[str, Any]]:
    Q = retriever.encode_queries([query], enc_cfg=enc_cfg)
    ranked = retriever.search_two_stage(Q, top_numbers=top_numbers, top_chunks_per_query=top_chunks_per_query, pool="sum_sqrt")[0]
    sims, idxs = retriever.chunk_idx.search(Q, top_k=top_chunks_per_query)
    results = []
    for number, score in ranked[:top_numbers]:
        enun = next((retriever.chunk_idx.texts[i] for i, num in enumerate(retriever.chunk_idx.numbers) if num == number and retriever.chunk_idx.sections[i] == "enunciado"), "")
        ev = pick_evidence_from_idxs(retriever, number=number, top_chunks_global_idxs=idxs[0], max_chunks=max_evidence_per_number)
        results.append({"number": number, "score": float(score), "enunciado": enun, "evidences": ev})
    return results

In [16]:
def token_truncate(text: str, tokenizer, max_tokens: int) -> str:
    ids = tokenizer.encode(text, add_special_tokens=False)
    if len(ids) <= max_tokens:
        return text
    return tokenizer.decode(ids[:max_tokens])

In [17]:
def compact_hits(hits: List[Dict[str, Any]], tokenizer, max_numbers: int = 3, max_evidence_per_number: int = 2, max_enunciado_tokens: int = 120, max_evidence_tokens: int = 100) -> List[Dict[str, Any]]:
    compact = []
    for h in hits[:max_numbers]:
        enun = h.get("enunciado") or ""
        enun_snip = token_truncate(enun, tokenizer, max_enunciado_tokens) if enun else ""
        ev_snips = []
        for ev in h.get("evidences", [])[:max_evidence_per_number]:
            txt = ev.get("text") or ""
            ev_snips.append({**ev, "text": token_truncate(txt, tokenizer, max_evidence_tokens) if txt else ""})
        compact.append({"number": h["number"], "score": float(h.get("score", 0.0)), "enunciado": enun_snip, "evidences": ev_snips})
    return compact

In [18]:
_SECTION_PT = {
    "enunciado": "Enunciado",
    "enunciado_mini": "Enunciado (mini)",
    "excertos_precedentes": "Excerto dos Precedentes",
    "referencias_legislativas": "Referências Legislativas",
    "orgao_data_fonte": "Órgão Julgador / Data / Fonte",
    "header": "Cabeçalho",
}

In [19]:
def _sec_label(sec: str) -> str:
    return _SECTION_PT.get(sec, sec)

_CASE_RE = re.compile(r'\b(REsp|AgRg|HC|EDcl|AgInt|RMS|AREsp)\b[^)\]\n]{0,120}')

In [20]:
def extract_case_citation(text: str) -> str:
    m = _CASE_RE.search(text or "")
    if not m:
        return ""
    span_start = m.start()
    cut = text[span_start: span_start + 160]
    cut = re.split(r'[\)\n]', cut, maxsplit=1)[0]
    return cut.strip()

In [21]:
def build_prompt_from_compact_narrative(query: str, compact_hits_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
    system = (
        "Você é um assistente jurídico especializado em Súmulas do STJ.\n"
        "Identifique a(s) Súmula(s) aplicável(is) ao trecho, cite o número e explique brevemente. "
        "Ao citar evidências, indique a seção e o [chunk_id]; quando houver, mencione a referência do precedente."
    )
    ctx_lines = []
    for h in compact_hits_list:
        num = h["number"]
        enun = (h.get("enunciado") or "").strip()
        ctx_lines.append(f"SÚMULA {num} — ENUNCIADO:\n{(enun if enun else '(sem enunciado)')}")
        for ev in h.get("evidences", []):
            sec = _sec_label(ev.get("section", ""))
            text = (ev.get("text") or "").strip()
            case = extract_case_citation(text)
            bits = [f"seção: {sec}", f"[{ev['id']}"] + ([f"originado de {case}"] if case else [])
            ctx_lines.append(f"Evidência (Súmula {num}) — {'; '.join(bits)}:\n{text}")
    context_block = "\n\n".join(ctx_lines)
    user = (
        f"Trecho do usuário:\n---\n{query.strip()}\n---\n\n"
        f"Contexto:\n{context_block}\n\n"
        "Responda:\n"
        "1) Diga qual(is) Súmula(s) se alinham ao trecho, no formato textual: "
        "\"Este é um trecho da Súmula NNN (seção: ...) que veio do [chunk_id]\".\n"
        "2) Explique em 2–4 frases, citando trechos entre aspas com [chunk_id].\n"
        "3) Se houver incerteza, reporte top-3 candidatas.\n"
        "4) Final: 'Súmula aplicável: NNN'."
    )
    return [{"role": "system", "content": system}, {"role": "user", "content": user}]

In [22]:
_LLM_CACHE: Dict[str, Any] = {}

In [23]:
class LocalHFChat:
    def __init__(self, cfg: GeneratorConfig):
        self.cfg = cfg
        self.device = cfg.device or ("cuda" if torch.cuda.is_available() else "cpu")
        dtype = torch.bfloat16 if (self.device == "cuda" and cfg.use_bfloat16) else (torch.float16 if self.device == "cuda" else torch.float32)
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, use_fast=True)
        self.model = AutoModelForCausalLM.from_pretrained(cfg.model_name, torch_dtype=dtype, low_cpu_mem_usage=True, device_map="auto" if self.device == "cuda" else None)
        self.model.eval()
    def _to_prompt(self, messages: List[Dict[str, str]]) -> str:
        if hasattr(self.tokenizer, "apply_chat_template"):
            try:
                return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            except Exception:
                pass
        system = next((m["content"] for m in messages if m["role"] == "system"), "")
        user = "\n\n".join([m["content"] for m in messages if m["role"] == "user"])
        return f"<|system|>\n{system}\n<|user|>\n{user}\n<|assistant|>\n"
    def chat(self, messages: List[Dict[str, str]]) -> str:
        inputs = self.tokenizer(self._to_prompt(messages), return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            out = self.model.generate(**inputs, do_sample=self.cfg.do_sample and (self.cfg.temperature > 0), temperature=self.cfg.temperature, top_p=self.cfg.top_p, top_k=self.cfg.top_k, max_new_tokens=self.cfg.max_new_tokens, pad_token_id=self.tokenizer.eos_token_id, eos_token_id=self.tokenizer.eos_token_id)
        text = self.tokenizer.decode(out[0], skip_special_tokens=True)
        if "<|assistant|>" in text:
            text = text.split("<|assistant|>")[-1]
        return text.strip()

In [24]:
def get_llm(gen_cfg: GeneratorConfig) -> LocalHFChat:
    key = f"{gen_cfg.model_name}|{gen_cfg.device}"
    if key not in _LLM_CACHE:
        _LLM_CACHE[key] = LocalHFChat(gen_cfg)
    return _LLM_CACHE[key]

In [25]:
def rag_answer(query: str, retriever: PerNumberRetriever, enc_cfg: Optional[EncoderConfig] = None, gen_cfg: GeneratorConfig = GeneratorConfig(), top_numbers: int = 8, top_chunks_per_query: int = 160, max_evidence_per_number: int = 3, prompt_max_numbers: int = 3, prompt_evs_per_number: int = 2, enunciado_tok_cap: int = 100, evidence_tok_cap: int = 90) -> Dict[str, Any]:
    hits = gather_hits_for_query(query, retriever, enc_cfg, top_numbers, top_chunks_per_query, max_evidence_per_number)
    llm = get_llm(gen_cfg)
    hits_compact = compact_hits(hits, tokenizer=llm.tokenizer, max_numbers=prompt_max_numbers, max_evidence_per_number=prompt_evs_per_number, max_enunciado_tokens=enunciado_tok_cap, max_evidence_tokens=evidence_tok_cap)
    messages = build_prompt_from_compact_narrative(query, hits_compact)
    answer = llm.chat(messages)
    return {"answer": answer, "hits": hits_compact, "messages": messages}

In [26]:
def load_retriever(out_dir: str = OUT_DIR) -> Tuple[PerNumberRetriever, EncoderConfig]:
    paths = make_paths(out_dir)
    mani = json.loads(Path(paths.manifest).read_text(encoding="utf-8"))
    enc_cfg = EncoderConfig(model_name=mani["model_name"], max_seq_length=int(mani.get("max_seq_length", 384)))
    retriever = PerNumberRetriever(out_dir)
    return retriever, enc_cfg

In [27]:
QUERY = """PRISÃO PREVENTIVA. PORTE ILEGAL DE ARMA DE FOGO DE USO PERMITIDO, TRÁFICO
DE DROGAS E ASSOCIAÇÃO PARA O TRÁFICO. GARANTIA DA ORDEM PÚBLICA. CONVERSÃO EX
OFFICIO DA PRISÃO EM FLAGRANTE EM PREVENTIVA. IMPOSSIBILIDADE. NECESSIDADE DE
REQUERIMENTO PRÉVIO OU PELO MINISTÉRIO PÚBLICO OU PELO QUERELANTE, OU PELO
ASSISTENTE OU, POR FIM, MEDIANTE REPRESENTAÇÃO DA AUTORIDADE POLICIAL. [...] No
caso, a decisão agravada deve ser mantida, uma vez que não é possível a
decretação da prisão preventiva de ofício em face do que dispõe a Lei n.
13.964/2019, mesmo se decorrente de prisão em flagrante e se não tiver ocorrido
audiência de custódia. Isso porque não existe diferença entre a conversão da
prisão em flagrante em preventiva e a decretação da prisão preventiva como uma
primeira prisão (EDcl no AgRg no HC n. 653.425/MG, de minha relatoria, Sexta
Turma, DJe 19/11/2021)"""
QUERY = " ".join(QUERY.split())


In [28]:
retriever, enc_cfg = load_retriever(OUT_DIR)

In [29]:
gen_cfg = GeneratorConfig(
    model_name=GEN_MODEL,   # e.g. "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    temperature=0.0,
    max_new_tokens=400,
    do_sample=False
)

In [30]:
res = rag_answer(
    QUERY,
    retriever, enc_cfg, gen_cfg,
    top_numbers=8,
    top_chunks_per_query=160,
    prompt_max_numbers=3,
    prompt_evs_per_number=2,
    enunciado_tok_cap=100,
    evidence_tok_cap=90
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/723 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.11G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/402 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (2048). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.


In [31]:
print("=== RAG Answer ===")
print(res["answer"])

print("\n=== Top candidates (number, score) + evidence ids ===")
for h in res["hits"]:
    print(f"- Súmula {h['number']} | score={h['score']:.3f}")
    for ev in h["evidences"]:
        print(f"    • {ev['id']} | {ev['section']}")

=== RAG Answer ===
Trecho do usuário:
---
PRISÃO PREVENTIVA. PORTE ILEGAL DE ARMA DE FOGO DE USO PERMITIDO, TRÁFICO DE DROGAS E ASSOCIAÇÃO PARA O TRÁFICO. GARANTIA DA ORDEM PÚBLICA. CONVERSÃO EX OFFICIO DA PRISÃO EM FLAGRANTE EM PREVENTIVA. IMPOSSIBILIDADE. NECESSIDADE DE REQUERIMENTO PRÉVIO OU PELO MINISTÉRIO PÚBLICO OU PELO QUERELANTE, OU PELO ASSISTENTE OU, POR FIM, MEDIANTE REPRESENTAÇÃO DA AUTORIDADE POLICIAL.

Evidência (Súmula 676) — seção: Excerto dos Precedentes; [676#excertos_precedentes@544-864:
SÚMULA 676 — EXCERTOS DOS PRECEDENTES
Enunciado:
Fixada a pena-base no mínimo legal, é vedado o estabelecimento de regime prisional mais gravoso do que o cabível em razão da sanção imposta, com base apenas na gravidade abstrata do delito.
Evidência (Súmula 440) — seção: Excerto dos Precedentes; [440#excertos_precedentes@544-864:
SÚMULA 4

=== Top candidates (number, score) + evidence ids ===
- Súmula 676 | score=1.691
    • 676#excertos_precedentes@272-592 | excertos_precedentes
    