# POC — International Agreements Database Mining

This notebook implements a **hybrid extraction pipeline** for noisy OCR legal agreements for the following project tasks:
- **8: Conditions for extending** (automatic vs mutual decision vs optional)
- **11: Evaluation of implementation** (review/audit/reporting) + basic attributes

POC Pipeline:

1. **Candidate retrieval** (regex triggers + neighbor expansion for recall under OCR noise)  
2. **Neural clause detection** (finetuned legal model if available, else MNLI zero-shot baseline)  
3. **NLI verification** (ContractNLI-style hypothesis test)  
4. **Normalization & structured output** (dates/durations, renewal type + notice, evaluation attributes)  



## 1) Setup & Utilities

### 1.1 Configuration and memory‑safe model loading
- Set device/CPU defaults, thresholds, and model names.
- Define shared helpers (random seed, caps, lightweight data structures).

In [None]:
# =========================
# 0) CONFIG + MEMORY-SAFE MODEL LOADING 
# =========================
import platform

# Device: GPU if available, else CPU
try:
    import torch
    DEVICE = 0 if torch.cuda.is_available() else -1
except Exception:
    DEVICE = -1

# --- Switches (set these as needed) ---
USE_FINETUNED_SEQCLS = False
FINETUNED_SEQCLS_MODEL = None  # e.g. "nlpaueb/legal-bert-base-uncased" or our finetuned checkpoint

USE_ZEROSHOT_FALLBACK = True
# Default to SMALL on Windows/CPU to avoid paging-file OSError 1455
ZEROSHOT_MODEL = "typeform/distilbert-base-uncased-mnli"

USE_NLI_VERIFIER = True
NLI_MODEL = "typeform/distilbert-base-uncased-mnli"

# Thresholds (tune on a dev set)
THRESH_RENEWAL = 0.60
THRESH_EVAL = 0.60

# Retrieval
NEIGHBOR_K = 2
MAX_CANDIDATES = 50

# Optional HeidelTime hook (off by default)
USE_HEIDELTIME = False
HEIDELTIME_JAR = None
HEIDELTIME_CONFIG = None

# Keep backward-compatible names used later in the notebook
CLAUSE_MODEL = ZEROSHOT_MODEL

# -------------------------
# Memory-safe pipeline loader
# -------------------------
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification

def make_zsc(model_name: str, device: int):
    """
    Create zero-shot-classification pipeline with safe fallbacks.
    Avoids Windows paging-file OSError by falling back to smaller MNLI models.
    """
    fallbacks = [
        model_name,
        "typeform/distilbert-base-uncased-mnli",
        "valhalla/distilbart-mnli-12-1",
    ]
    last_err = None
    for name in fallbacks:
        try:
            print(f"Loading ZSC model: {name}")
            return pipeline(
                "zero-shot-classification",
                model=name,
                device=device,
                model_kwargs={"low_cpu_mem_usage": True},
            )
        except Exception as e:
            last_err = e
            print(f"⚠️ Failed loading {name}: {e}")
    raise RuntimeError(f"Failed to load any ZSC model. Last error: {last_err}")

print("✅ Config loaded | Device:", "GPU" if DEVICE==0 else "CPU")


import os
os.environ["TRANSFORMERS_NO_TF"] = "1"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import re
import json
import random
from dataclasses import dataclass, asdict
from typing import List, Dict, Optional, Tuple, Any

import pandas as pd
from tqdm import tqdm
import dateparser
from datetime import datetime
from dateutil.relativedelta import relativedelta

from transformers import pipeline


  from .autonotebook import tqdm as notebook_tqdm


✅ Config loaded | Device: CPU


In [None]:
# -----------------------------
# CONFIG (CPU)
# -----------------------------
DEVICE = -1
RANDOM_SEED = 42

# Retrieval caps
MAX_BLOCKS = 24
MAX_SENTS  = 80
NEIGHBOR_K = 2

# Thresholds
THRESH_RENEWAL  = 0.60   # renewal type
THRESH_EVAL     = 0.60   # eval present

DERIVE_END_DATE = True

# ------------- MODEL CHOICES -------------
# 1) Clause classifier (LegalBERT fine-tuned token/sequence classification)
#    For CPU POC, use a smaller sequence classifier OR keep ZSC as fallback.
#
# Recommended: swap this to a LegalBERT-like finetune when we have one.
CLAUSE_MODEL = "typeform/distilbert-base-uncased-mnli"  # fallback verifier/classifier

# NLI verifier (ContractNLI-like verification layer)
USE_NLI_VERIFIER = True
NLI_MODEL = "typeform/distilbert-base-uncased-mnli"  # swap to stronger MNLI later (DeBERTa MNLI)

# If we later have a real finetuned model for "temporal/renewal/eval", plug it here:
# SEQCLS_MODEL = "our-finetuned-legalbert-clause-classifier"
USE_SEQCLS = False
SEQCLS_MODEL = None


### 1.2 Data I/O
- Load OCR text (plain text or JSON) into a single normalized string.
- Preserve any page markers used later for evidence and page‑level outputs.

In [3]:
def load_txt(path: str) -> str:
    with open(path, "r", encoding="utf-8", errors="ignore") as f:
        return f.read()

def load_ocr_json_with_pages(path: str) -> str:
    """
    Reconstruct text with explicit page markers to preserve provenance:
    [[PAGE=1]] ... [[PAGE=2]] ...
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    out_lines = []
    pages = data.get("pages", [])
    for p_idx, page in enumerate(pages, start=1):
        out_lines.append(f"[[PAGE={p_idx}]]")
        for block in page.get("blocks", []):
            for line in block.get("lines", []):
                words = [w.get("value", "") for w in line.get("words", []) if w.get("value")]
                if words:
                    out_lines.append(" ".join(words))
        out_lines.append("")  # blank line between pages

    return "\n".join(out_lines)

def normalize_text(text: str) -> str:
    text = text.replace("\x0c", "\n")
    # Keep page markers intact
    text = re.sub(r"[ \t]+", " ", text)
    text = re.sub(r"\n{3,}", "\n\n", text)
    return text.strip()


### 1.3 Page‑aware segmentation
- Split OCR text into pages, then into blocks and sentences.
- Keep `(page, block_id, sent_id)` so every extracted field can cite evidence.

In [4]:
PAGE_MARKER = re.compile(r"\[\[PAGE=(\d+)\]\]")

def split_into_pages(text: str) -> List[Tuple[int, str]]:
    """
    Returns [(page_num, page_text), ...]
    If no markers exist, treat as page 1.
    """
    chunks = []
    matches = list(PAGE_MARKER.finditer(text))
    if not matches:
        return [(1, text)]

    for i, m in enumerate(matches):
        page_num = int(m.group(1))
        start = m.end()
        end = matches[i+1].start() if i+1 < len(matches) else len(text)
        page_text = text[start:end].strip()
        chunks.append((page_num, page_text))
    return chunks

# OCR-friendly sentence split: don't rely on capitalization
_SENT_SPLIT = re.compile(r"(?<=[\.\?\!])\s+|\n+")

@dataclass
class SentItem:
    sid: int
    page: int
    text: str

@dataclass
class BlockItem:
    bid: int
    page: int
    text: str

def split_sentences_with_meta(text: str, max_len: int = 1200) -> List[SentItem]:
    pages = split_into_pages(text)
    sents: List[SentItem] = []
    sid = 0
    for page_num, page_text in pages:
        raw = [s.strip() for s in _SENT_SPLIT.split(page_text) if s and s.strip()]
        # Light merge for very short OCR fragments
        merged, buf = [], ""
        for s in raw:
            if not buf:
                buf = s
            elif len(buf) < 120 and len(s) < 260:
                buf = buf + " " + s
            else:
                merged.append(buf.strip())
                buf = s
        if buf.strip():
            merged.append(buf.strip())

        for m in merged:
            sents.append(SentItem(sid=sid, page=page_num, text=m[:max_len]))
            sid += 1
    return sents

def split_blocks_with_meta(text: str, max_len: int = 2500) -> List[BlockItem]:
    pages = split_into_pages(text)
    blocks: List[BlockItem] = []
    bid = 0
    for page_num, page_text in pages:
        paras = [b.strip() for b in re.split(r"\n\s*\n+", page_text) if b and b.strip()]
        for b in paras:
            b = re.sub(r"\s+", " ", b).strip()
            if len(b) <= max_len:
                blocks.append(BlockItem(bid=bid, page=page_num, text=b))
                bid += 1
            else:
                for i in range(0, len(b), 2000):
                    blocks.append(BlockItem(bid=bid, page=page_num, text=b[i:i+2000]))
                    bid += 1
    return blocks


## 2) Candidate retrieval (high‑recall)
- Use regex patterns to over‑generate candidate clauses for **renewal**, and **evaluation**.
- Retrieve the top candidate sentences/blocks + neighboring context for downstream scoring.

In [None]:
RENEWAL_PATTERNS = [
    r"\brenew(al|ed|s|ing)?\b",
    r"\bextend(ed|s|ing)?\b",
    r"\bextension\b",
    r"\bautomatic(ally)?\s+renew\b",
    r"\bshall\s+be\s+renewed\b",
    r"\bmay\s+be\s+renewed\b",
    r"\bnon-?renewal\b",
    r"\bunless\s+terminated\b",
    r"\bnotice\b",
    r"\botherwise\s+agreed\s+upon\b",
    r"\bby\s+mutual\s+agreement\b",
    r"\bmutually\s+agreed\b",
]

EVAL_PATTERNS = [
    r"\bevaluat(e|ion|ing)\b",
    r"\breview(s|ed|ing)?\b",
    r"\bassess(ment|es|ed|ing)?\b",
    r"\bmonitor(ing|ed|s)?\b",
    r"\baudit(s|ed|ing)?\b",
    r"\breport(s|ed|ing)?\b",
    r"\bprogress\s+report\b",
    r"\bimplementation\b.*\b(review|evaluation|assessment|monitor|audit|report)\b",
]

def any_match(text: str, patterns: List[str]) -> bool:
    lt = text.lower()
    return any(re.search(p, lt) for p in patterns)

# HeidelTime-inspired temporal extraction (rules baseline)
MONTHY_DATE = re.compile(
    r"\b(?:jan(?:uary)?|feb(?:ruary)?|mar(?:ch)?|apr(?:il)?|may|jun(?:e)?|"
    r"jul(?:y)?|aug(?:ust)?|sep(?:tember)?|oct(?:ober)?|nov(?:ember)?|dec(?:ember)?)"
    r"\s+\d{1,2},?\s+\d{4}\b", re.IGNORECASE
)
ORDINAL_MONTHY_DATE = re.compile(
    r"\b(?:jan(?:uary)?|feb(?:ruary)?|mar(?:ch)?|apr(?:il)?|may|jun(?:e)?|"
    r"jul(?:y)?|aug(?:ust)?|sep(?:tember)?|oct(?:ober)?|nov(?:ember)?|dec(?:ember)?)"
    r"\s+\d{1,2}(?:st|nd|rd|th)?\s*,?\s*\d{4}\b", re.IGNORECASE
)
DAY_OF_MONTH_DATE = re.compile(
    r"\b\d{1,2}(?:st|nd|rd|th)?\s+of\s+"
    r"(?:jan(?:uary)?|feb(?:ruary)?|mar(?:ch)?|apr(?:il)?|may|jun(?:e)?|"
    r"jul(?:y)?|aug(?:ust)?|sep(?:tember)?|oct(?:ober)?|nov(?:ember)?|dec(?:ember)?)"
    r"(?:\s+in\s+the\s+year\s+of)?\s+\d{4}\b", re.IGNORECASE
)
NUM_DATE = re.compile(r"\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b")

DURATION_NUMERIC = re.compile(r"\b(\d+)\s*(?:\(\s*\d+\s*\)\s*)?(years?|months?|days?)\b", re.IGNORECASE)
NOTICE_PERIOD = re.compile(r"\b(\d+)\s*(?:\(\s*\d+\s*\)\s*)?(days?|months?|years?)\s+(?:prior|before)\b", re.IGNORECASE)

NUMBER_WORDS = {
    "one": 1, "two": 2, "three": 3, "four": 4, "five": 5,
    "six": 6, "seven": 7, "eight": 8, "nine": 9, "ten": 10,
    "eleven": 11, "twelve": 12
}
WORD_DURATION = re.compile(r"\b(" + "|".join(NUMBER_WORDS.keys()) + r")\s+(years?|months?|days?)\b", re.IGNORECASE)
WORD_PAREN_DURATION = re.compile(r"\b(" + "|".join(NUMBER_WORDS.keys()) + r")\s*\(\s*(\d+)\s*\)\s*(years?|months?|days?)\b", re.IGNORECASE)

# Anchors for safe derivation
ANCHOR_START = re.compile(r"\beffective\s+date\b|\beffective\s+upon\b|\benter\s+into\s+(force|effect)\b|\bupon\s+signature\b", re.IGNORECASE)
ANCHOR_TERM  = re.compile(r"\bterm\b|\bfor\s+a\s+period\s+of\b|\bshall\s+remain\s+in\s+(force|effect)\b|\buntil\b|\bexpires?\b|\bexpiration\b", re.IGNORECASE)

def parse_dates_from_text(text: str) -> List[str]:
    found = []
    for rgx in [MONTHY_DATE, ORDINAL_MONTHY_DATE, DAY_OF_MONTH_DATE, NUM_DATE]:
        for m in rgx.findall(text):
            # Try MDY then DMY (OCR/intl ambiguity)
            dt = dateparser.parse(m, settings={"DATE_ORDER": "MDY"})
            if not dt:
                dt = dateparser.parse(m, settings={"DATE_ORDER": "DMY"})
            if dt:
                found.append(dt.date().isoformat())
    # de-dup stable order
    return sorted(set(found))

def parse_duration(text: str) -> Optional[str]:
    m = DURATION_NUMERIC.search(text)
    if m:
        return f"{m.group(1)} {m.group(2).lower()}"
    m2 = WORD_PAREN_DURATION.search(text)
    if m2:
        return f"{int(m2.group(2))} {m2.group(3).lower()}"
    m3 = WORD_DURATION.search(text)
    if m3:
        return f"{NUMBER_WORDS[m3.group(1).lower()]} {m3.group(2).lower()}"
    return None

def parse_notice_period(text: str) -> Optional[str]:
    m = NOTICE_PERIOD.search(text)
    if m:
        return f"{m.group(1)} {m.group(2).lower()}"
    return None

def derive_end_date_safe(effective_date: Optional[str], duration: Optional[str], evidence_text: str) -> Optional[str]:
    """
    Only derive end_date if:
    - we have effective_date + duration
    - and the evidence has both start anchor AND term anchor (avoid mixing notice/training periods)
    """
    if not effective_date or not duration:
        return None
    if not (ANCHOR_START.search(evidence_text) and ANCHOR_TERM.search(evidence_text)):
        return None

    m = re.match(r"^\s*(\d+)\s+(years?|months?|days?)\s*$", duration.strip(), re.IGNORECASE)
    if not m:
        return None
    n = int(m.group(1))
    unit = m.group(2).lower()

    try:
        start = datetime.fromisoformat(effective_date).date()
    except Exception:
        return None

    if unit.startswith("year"):
        end = start + relativedelta(years=n)
    elif unit.startswith("month"):
        end = start + relativedelta(months=n)
    else:
        end = start + relativedelta(days=n)

    return end.isoformat()


In [6]:
def retrieve_candidates_with_meta(
    text: str,
    patterns: List[str],
    max_blocks: int = MAX_BLOCKS,
    max_sents: int = MAX_SENTS,
    neighbor_k: int = NEIGHBOR_K
) -> Dict[str, Any]:
    blocks = split_blocks_with_meta(text)
    sents  = split_sentences_with_meta(text)

    block_hits: List[BlockItem] = []
    for b in blocks:
        if any_match(b.text, patterns):
            block_hits.append(b)
            if len(block_hits) >= max_blocks:
                break

    hit_sids = []
    for s in sents:
        if any_match(s.text, patterns):
            hit_sids.append(s.sid)
            if len(hit_sids) >= max_sents:
                break

    expanded = set()
    for sid in hit_sids:
        for j in range(max(0, sid-neighbor_k), min(len(sents), sid+neighbor_k+1)):
            expanded.add(j)

    sent_hits = [sents[i] for i in sorted(expanded)]
    return {"blocks": block_hits, "sentences": sent_hits, "all_sentences": sents}


## 2.2 Semantic filtering
- Load a zero‑shot **NLI/MNLI** model as a proxy for fine‑tuned **LegalBERT/Longformer** clause classifiers.
- Run a second NLI verification pass (ContractNLI‑style hypothesis check).

In [7]:
# Fallback: use MNLI zero-shot as "clause classifier" (POC)
zsc = make_zsc(CLAUSE_MODEL, DEVICE)

if USE_NLI_VERIFIER:
    nli = make_zsc(NLI_MODEL, DEVICE)
else:
    nli = None

def zsc_best(texts: List[str], labels: List[str]) -> List[Tuple[str, float]]:
    if not texts:
        return []
    res = zsc(texts, candidate_labels=labels, multi_label=False)
    return [(r["labels"][0], float(r["scores"][0])) for r in res]

def nli_verify(texts: List[str], hypothesis_pos: str, hypothesis_neg: str) -> List[Tuple[bool, float]]:
    """
    ContractNLI-like verification: "entailed or not".
    Using ZSC MNLI as proxy:
    - candidate_labels = [pos, neg]
    - choose pos => verified True
    """
    if not USE_NLI_VERIFIER or nli is None:
        return [(True, 1.0) for _ in texts]  # no-op verifier for ablation

    res = nli(texts, candidate_labels=[hypothesis_pos, hypothesis_neg], multi_label=False)
    out = []
    for r in res:
        lab = r["labels"][0]
        sc  = float(r["scores"][0])
        out.append((lab == hypothesis_pos, sc))
    return out


Loading ZSC model: typeform/distilbert-base-uncased-mnli


Device set to use cpu


Loading ZSC model: typeform/distilbert-base-uncased-mnli


Device set to use cpu


### 2.3 Clause labels and scoring
- Define label sets for each task (renewal and evaluation).
- Pick the best label per candidate and keep calibrated confidence scores.

In [None]:
# Stage 1: clause type classification (cheap)
RENEWAL_LABELS = [
    "Automatic renewal unless terminated or notice is given",
    "Renewal/extension requires mutual agreement",
    "Unilateral renewal/extension option",
    "Other"
]
EVAL_LABELS = [
    "Evaluation/monitoring/reporting/audit/review obligation",
    "Other"
]

# Stage 2: verification (ContractNLI-like)
H_EVAL_POS  = "This text requires evaluation, monitoring, reporting, auditing, or review of implementation."
H_EVAL_NEG  = "This text is not about evaluation, monitoring, or reporting obligations."

H_REN_POS   = "This text describes how the agreement is renewed or extended (automatic, mutual, or unilateral)."
H_REN_NEG   = "This text is not about renewal or extension."


### 2.4 Evidence schema
- Standardize evidence items: text span, page number, source (sentence/block), label, score.
- Store whether the span was NLI‑verified and the verification score.

In [9]:
@dataclass
class EvidenceItem:
    text: str
    page: int
    sid: Optional[int]     # sentence id if sentence evidence
    bid: Optional[int]     # block id if block evidence
    source: str            # "sentence" | "block"
    label: str
    score: float
    verified: bool
    verify_score: float

def build_evidence(
    sent_items: List[SentItem],
    block_items: List[BlockItem],
    labels: List[str],
    hyp_pos: str,
    hyp_neg: str,
    threshold: float
) -> List[EvidenceItem]:
    # Stage 1 classification (type)
    sent_texts  = [s.text for s in sent_items]
    block_texts = [b.text for b in block_items]

    sent_preds  = zsc_best(sent_texts, labels)
    block_preds = zsc_best(block_texts, labels)

    # Stage 2 verification
    sent_ver = nli_verify(sent_texts, hyp_pos, hyp_neg)
    block_ver = nli_verify(block_texts, hyp_pos, hyp_neg)

    out: List[EvidenceItem] = []
    for s, (lab, sc), (ok, vsc) in zip(sent_items, sent_preds, sent_ver):
        out.append(EvidenceItem(
            text=s.text, page=s.page, sid=s.sid, bid=None, source="sentence",
            label=lab, score=sc, verified=ok and vsc >= threshold, verify_score=vsc
        ))
    for b, (lab, sc), (ok, vsc) in zip(block_items, block_preds, block_ver):
        out.append(EvidenceItem(
            text=b.text, page=b.page, sid=None, bid=b.bid, source="block",
            label=lab, score=sc, verified=ok and vsc >= threshold, verify_score=vsc
        ))

    out.sort(key=lambda x: (x.verified, x.verify_score, x.score), reverse=True)
    return out


In [None]:
def extract_renewal(text: str) -> Dict[str, Any]:
    seg = retrieve_candidates_with_meta(text, RENEWAL_PATTERNS)
    sents = seg["sentences"]
    blocks = seg["blocks"]

    if not (sents or blocks):
        return {
            "renewal_type": "absent",
            "notice_period": None,
            "renewal_status": "absent",
            "renewal_evidence": []
        }

    evidence = build_evidence(
        sent_items=sents,
        block_items=blocks,
        labels=RENEWAL_LABELS,
        hyp_pos=H_REN_POS,
        hyp_neg=H_REN_NEG,
        threshold=THRESH_RENEWAL
    )
    verified = [e for e in evidence if e.verified]
    top = (verified[:10] if verified else evidence[:6])

    # Document-level decision: any verified evidence triggers type
    # Priority: automatic > mutual > unilateral
    renewal_type = "uncertain"
    if any(e.label == RENEWAL_LABELS[0] for e in verified):
        renewal_type = "automatic"
    elif any(e.label == RENEWAL_LABELS[1] for e in verified):
        renewal_type = "by_mutual_agreement"
    elif any(e.label == RENEWAL_LABELS[2] for e in verified):
        renewal_type = "unilateral_option"

    # Notice period extraction: scan verified evidence, then top evidence
    notice = None
    for e in (verified[:12] + top[:8]):
        notice = notice or parse_notice_period(e.text)
        if notice:
            break

    # If patterns hit but no verified => uncertain
    status = "found" if (renewal_type != "uncertain") else "uncertain"

    return {
        "renewal_type": renewal_type,
        "notice_period": notice,
        "renewal_status": status,
        "renewal_evidence": [asdict(e) for e in top]
    }


In [11]:
def extract_evaluation(text: str) -> Dict[str, Any]:
    seg = retrieve_candidates_with_meta(text, EVAL_PATTERNS)
    sents = seg["sentences"]
    blocks = seg["blocks"]

    if not (sents or blocks):
        return {
            "evaluation": "absent",
            "evaluation_status": "absent",
            "evaluation_evidence": []
        }

    evidence = build_evidence(
        sent_items=sents,
        block_items=blocks,
        labels=EVAL_LABELS,
        hyp_pos=H_EVAL_POS,
        hyp_neg=H_EVAL_NEG,
        threshold=THRESH_EVAL
    )
    verified = [e for e in evidence if e.verified]
    top = (verified[:8] if verified else evidence[:6])

    if verified:
        return {
            "evaluation": "present",
            "evaluation_status": "found",
            "evaluation_evidence": [asdict(e) for e in top]
        }
    else:
        # candidates existed but verifier didn't confirm -> uncertain (not absent)
        return {
            "evaluation": "uncertain",
            "evaluation_status": "uncertain",
            "evaluation_evidence": [asdict(e) for e in top]
        }


In [12]:
def baseline_keyword(text: str, patterns: List[str]) -> bool:
    return any_match(text, patterns)

In [None]:

# ============================================================
# 4) Processing — Tasks 8 (Renewal/Extension) and 11 (Evaluation)
# ============================================================

def process_document_8_11(doc_id: str, source_path: str, raw_text: str) -> Dict[str, Any]:
    """Run ONLY Task 8 (renewal/extension conditions) and Task 11 (evaluation)."""
    text = normalize_text(raw_text)

    # Baselines (keyword presence)
    b_kw_ren  = baseline_keyword(text, RENEWAL_PATTERNS)
    b_kw_eval = baseline_keyword(text, EVAL_PATTERNS)

    # Hybrid: ZSC + NLI verification
    renewal  = extract_renewal(text)     # Task 8
    evalcl   = extract_evaluation(text)  # Task 11

    return {
        "doc_id": doc_id,
        "source_path": source_path,

        # Task 8 outputs
        **renewal,

        # Task 11 outputs
        **evalcl,

        # Baselines
        "baseline_keyword_renewal": b_kw_ren,
        "baseline_keyword_eval": b_kw_eval,
    }


def list_agreements(root: str) -> List[Tuple[str, str, str]]:
    """List agreement files under `root`.

    Returns tuples (doc_uid, path, state) where:
    - doc_uid is UNIQUE across the whole corpus (derived from relative path without extension)
    - path is the chosen content source (.txt preferred, otherwise .json)
    - state is the immediate subfolder under OCR_root when possible

    This avoids collisions when multiple states contain files with the same filename.
    """
    found: Dict[str, Dict[str, str]] = {}

    root_abs = os.path.abspath(root)
    for dp, _, fnames in os.walk(root_abs):
        for fn in fnames:
            low = fn.lower()
            if not (low.endswith(".txt") or low.endswith(".json")):
                continue
            path = os.path.join(dp, fn)
            rel = os.path.relpath(path, root_abs)
            base_rel, ext = os.path.splitext(rel)
            ext = ext.lower()

            found.setdefault(base_rel, {})
            found[base_rel][ext] = path

    out: List[Tuple[str, str, str]] = []
    for base_rel, paths in found.items():
        chosen = paths.get(".txt") or paths.get(".json")
        # Derive state as the first path component under OCR_root
        state = base_rel.split(os.sep, 1)[0] if os.sep in base_rel else ""
        doc_uid = base_rel.replace(os.sep, "/")  # stable uid across OS
        out.append((doc_uid, chosen, state))

    out.sort(key=lambda x: x[0])
    return out


def run_poc_8_11(root: str, n: int = None, strategy: str = "all", seed: int = RANDOM_SEED, profiler=None) -> pd.DataFrame:
    """Run Tasks 8 & 11 over a corpus root.

    - If n is None or strategy == 'all', process ALL agreements.
    - If n is provided, process either first-n or random-n depending on strategy.

    Note: doc_id is kept as the *original filename stem* for readability,
    while doc_uid is the unique identifier across the whole dataset.
    """
    docs = list_agreements(root)
    if not docs:
        raise ValueError(f"No .txt/.json files found under: {root}")

    # Choose subset or all
    if n is None or strategy == "all" or len(docs) <= (n or len(docs)):
        chosen = docs
    else:
        if strategy == "random":
            random.seed(seed)
            chosen = random.sample(docs, n)
        else:
            chosen = docs[:n]

    rows = []
    for doc_uid, path, state in tqdm(chosen, desc=f"Tasks 8+11 ({len(chosen)} agreements)"):
        # Keep a human-readable doc_id (stem only)
        doc_id = os.path.splitext(os.path.basename(path))[0]

        if path.lower().endswith(".txt"):
            raw_text = load_txt(path)
        else:
            raw_text = load_ocr_json_with_pages(path)

        row = process_document_8_11(doc_id, path, raw_text)
        row["doc_uid"] = doc_uid
        row["state"] = state
        rows.append(row)
        if profiler is not None:
            profiler.tick()

    # Put key identifiers first
    df = pd.DataFrame(rows)
    front = [c for c in ["doc_uid", "state", "doc_id", "source_path"] if c in df.columns]
    rest = [c for c in df.columns if c not in front]
    return df[front + rest]


In [14]:
# ============================================================
# 5) Run on FULL dataset (all states) ONCE + shorten evidence + save results
#    Includes runtime & memory metrics captured during the same run
# ============================================================

import os, json, time
from pathlib import Path

# --- Output folders
out_dir = Path("tables")
out_dir.mkdir(parents=True, exist_ok=True)

OCR_ROOT = r"OCR_output"  # root folder containing state subfolders

# --- Minimal profiler (wall time + CPU + RSS peak + tracemalloc peak)
try:
    import psutil
except ImportError:
    !pip -q install psutil
    import psutil

import tracemalloc

PROC = psutil.Process(os.getpid())

def _rss_mb() -> float:
    return PROC.memory_info().rss / (1024 * 1024)

class RunProfile:
    def __init__(self, label: str):
        self.label = label
        self.t0 = None
        self.cpu0 = None
        self.rss0 = None
        self.rss_peak = None
        self.wall_s = None
        self.cpu_s = None
        self.rss_end = None
        self.rss_delta_mb = None
        self.py_heap_peak_mb_tracemalloc = None

    def __enter__(self):
        self.t0 = time.perf_counter()
        self.cpu0 = time.process_time()
        self.rss0 = _rss_mb()
        self.rss_peak = self.rss0
        tracemalloc.start()
        return self

    def tick(self):
        self.rss_peak = max(self.rss_peak, _rss_mb())

    def __exit__(self, exc_type, exc, tb):
        self.wall_s = time.perf_counter() - self.t0
        self.cpu_s = time.process_time() - self.cpu0
        self.rss_end = _rss_mb()
        self.rss_delta_mb = self.rss_end - self.rss0
        _, peak = tracemalloc.get_traced_memory()
        self.py_heap_peak_mb_tracemalloc = peak / (1024 * 1024)
        tracemalloc.stop()

    def to_dict(self):
        return {
            "label": self.label,
            "n_docs": None,  # filled after run
            "wall_s": self.wall_s,
            "cpu_s": self.cpu_s,
            "rss_start_mb": self.rss0,
            "rss_end_mb": self.rss_end,
            "rss_peak_mb": self.rss_peak,
            "rss_delta_mb": self.rss_delta_mb,
            "py_heap_peak_mb_tracemalloc": self.py_heap_peak_mb_tracemalloc,
            "root": str(OCR_ROOT),
        }

# --- Evidence shortening (Excel-friendly)
def _shorten_text(s: str, max_chars: int = 240) -> str:
    if s is None:
        return ""
    s = " ".join(str(s).split())  # collapse whitespace
    if len(s) <= max_chars:
        return s
    return s[: max_chars - 1] + "…"

def evidence_to_short_string(evidence, max_items: int = 2, max_chars: int = 240) -> str:
    """
    Convert evidence (list[dict] or list[str] or dict or str) into a short string:
    keep up to max_items snippets, each trimmed to max_chars.
    """
    if evidence is None:
        return ""
    # If it's already a string
    if isinstance(evidence, str):
        return _shorten_text(evidence, max_chars=max_chars)

    items = []
    if isinstance(evidence, dict):
        # common schema: {"text": "..."} or similar
        txt = evidence.get("text", str(evidence))
        items = [txt]
    elif isinstance(evidence, list):
        for e in evidence:
            if isinstance(e, dict):
                items.append(e.get("text", str(e)))
            else:
                items.append(str(e))
    else:
        items = [str(evidence)]

    items = [t for t in items if t and t.strip()]
    items = items[:max_items]
    items = [_shorten_text(t, max_chars=max_chars) for t in items]
    return " || ".join(items)

# ============================================================
# RUN (single pass over full dataset)
# ============================================================

with RunProfile(label="Tasks8_11_full_dataset") as prof:
    final_df = run_poc_8_11(OCR_ROOT, n=None, strategy="all", profiler=prof)

# --- Save FULL evidence separately (JSONL), so Excel stays small
full_evidence_path = out_dir / "tasks8_11_all_agreements_evidence_full.jsonl"
with open(full_evidence_path, "w", encoding="utf-8") as f:
    for _, row in final_df.iterrows():
        rec = {
            "doc_uid": row.get("doc_uid"),
            "state": row.get("state"),
            "doc_id": row.get("doc_id"),
            "source_path": row.get("source_path"),
        }
        # Store only evidence-like columns in the JSONL file
        for col in final_df.columns:
            if col.endswith("_evidence"):
                rec[col] = row.get(col)
        f.write(json.dumps(rec, ensure_ascii=False) + "\n")

# --- Shorten evidence columns in the main table (replace with short strings)
for col in [c for c in final_df.columns if c.endswith("_evidence")]:
    final_df[col] = final_df[col].apply(lambda x: evidence_to_short_string(x, max_items=2, max_chars=240))

# ============================================================
# SAVE RESULTS (CSV + Excel)
# ============================================================

csv_path = out_dir / "tasks8_11_all_agreements.csv"
xlsx_path = out_dir / "tasks8_11_all_agreements.xlsx"
final_df.to_csv(csv_path, index=False, encoding="utf-8")
final_df.to_excel(xlsx_path, index=False)

# ============================================================
# SAVE PROFILING METRICS (captured during the SAME run)
# ============================================================

profile = prof.to_dict()
profile["n_docs"] = len(final_df)
profile_df = pd.DataFrame([profile])

profile_csv = out_dir / "profiling_tasks8_11_full_dataset.csv"
profile_xlsx = out_dir / "profiling_tasks8_11_full_dataset.xlsx"
profile_df.to_csv(profile_csv, index=False, encoding="utf-8")
profile_df.to_excel(profile_xlsx, index=False)

print("Saved results to:")
print(" -", csv_path)
print(" -", xlsx_path)
print("Saved full evidence to:")
print(" -", full_evidence_path)
print("Saved profiling to:")
print(" -", profile_csv)
print(" -", profile_xlsx)

display(final_df.head(5))
display(profile_df)


Tasks 8+11 (298 agreements): 100%|██████████| 298/298 [34:34<00:00,  6.96s/it] 


Saved results to:
 - tables\tasks8_11_all_agreements.csv
 - tables\tasks8_11_all_agreements.xlsx
Saved full evidence to:
 - tables\tasks8_11_all_agreements_evidence_full.jsonl
Saved profiling to:
 - tables\profiling_tasks8_11_full_dataset.csv
 - tables\profiling_tasks8_11_full_dataset.xlsx


Unnamed: 0,doc_uid,state,doc_id,source_path,renewal_type,notice_period,renewal_status,renewal_evidence,evaluation,evaluation_status,evaluation_evidence,baseline_keyword_renewal,baseline_keyword_eval
0,Alabama/Alabama_1,Alabama,Alabama_1,d:\NLP_Project_tasks_6_8_11\OCR_output\Alabama...,absent,,absent,,absent,absent,,False,False
1,Alabama/Alabama_10,Alabama,Alabama_10,d:\NLP_Project_tasks_6_8_11\OCR_output\Alabama...,uncertain,,uncertain,1. Hotel Selection and Assignment - The host s...,present,found,and Provincial Coordinators through the Steeri...,True,True
2,Alabama/Alabama_2,Alabama,Alabama_2,d:\NLP_Project_tasks_6_8_11\OCR_output\Alabama...,absent,,absent,,absent,absent,,False,False
3,Alabama/Alabama_3,Alabama,Alabama_3,d:\NLP_Project_tasks_6_8_11\OCR_output\Alabama...,uncertain,,uncertain,7. To promote the exchange of visits of compan...,absent,absent,,True,False
4,Alabama/Alabama_4,Alabama,Alabama_4,d:\NLP_Project_tasks_6_8_11\OCR_output\Alabama...,by_mutual_agreement,,found,"4, Sponsor and promote exchanges of visits by ...",absent,absent,,True,False


Unnamed: 0,label,n_docs,wall_s,cpu_s,rss_start_mb,rss_end_mb,rss_peak_mb,rss_delta_mb,py_heap_peak_mb_tracemalloc,root
0,Tasks8_11_full_dataset,298,2075.120893,13605.25,444.109375,533.722656,856.804688,89.613281,2.924397,OCR_output
