The main difference is that this notebook operates on `normalized_enhanced.xml` (by Ryan).

In [1]:
import os
import re
import json
import xml.etree.ElementTree as ET
from pathlib import Path
import requests

import keys

In [2]:
# ------------------------------
# CONFIGURATION
# ------------------------------
# Input dataset parsed by the main pipeline (relative path).
INPUT_XML = "normalized_enhanced.xml"
# Directory that receives one XML file per case; created here if it is missing.
OUTPUT_DIR = "usc_cases_output"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# The API key is read from the imported `keys` module instead of being
# hardcoded in this notebook.
OPENROUTER_API_KEY = keys.OPENROUTER_KEY  # must be defined in the `keys` module

# OpenRouter model identifiers: the first generates the raw Q/A pairs,
# the second filters/refines them.
GENERATION_MODEL = "google/gemini-2.5-pro"
FILTER_MODEL      = "google/gemini-2.5-flash"

# Default sampling temperature used by openrouter_request.
TEMPERATURE = 0.1

In [3]:
# ------------------------------
# OpenRouter request wrapper
# ------------------------------
def openrouter_request(model, prompt, temperature=TEMPERATURE, timeout=300):
    """Send a single-turn chat completion to OpenRouter and return the reply text.

    Args:
        model: OpenRouter model identifier (e.g. ``GENERATION_MODEL``).
        prompt: Text sent as the only ``user`` message.
        temperature: Sampling temperature; defaults to ``TEMPERATURE``.
        timeout: Seconds handed to ``requests.post``. Without a timeout a
            stalled connection would block the whole batch indefinitely.

    Returns:
        The ``content`` string of the first choice in the response.

    Raises:
        requests.Timeout: If the server does not respond within ``timeout``.
        requests.HTTPError: If the response status is an error status.
        KeyError, IndexError: If the response body does not contain
            ``choices[0].message.content``.
    """
    url = "https://openrouter.ai/api/v1/chat/completions"

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
    }

    payload = {
        "model": model,
        "messages": [
            {"role": "user", "content": prompt}
        ],
        "temperature": temperature,
    }

    # requests never times out by default, so the limit must be passed explicitly.
    r = requests.post(url, headers=headers, json=payload, timeout=timeout)
    r.raise_for_status()
    return r.json()["choices"][0]["message"]["content"]

In [4]:
# ------------------------------
# XML helpers
# ------------------------------
def pretty_xml(elem):
    """Return ``elem`` serialized as an indented XML document string.

    The element is encoded to UTF-8 bytes with ElementTree, re-parsed with
    ``minidom`` and re-emitted using a two-space indent.
    """
    from xml.dom.minidom import parseString

    serialized = ET.tostring(elem, encoding="utf-8")
    document = parseString(serialized)
    return document.toprettyxml(indent="  ")


def clean_json(raw):
    """Extract the JSON object embedded in a model reply.

    Everything before the first ``{`` and after the last ``}`` is discarded,
    which removes markdown code fences, language tags and any commentary the
    model added around the payload.

    Args:
        raw: Raw text returned by the model.

    Returns:
        The substring from the first ``{`` through the last ``}``.

    Raises:
        ValueError: If ``raw`` contains no ``{ ... }`` span. (Previously a
            missing ``{`` made ``find`` return -1 and the function silently
            returned the last character of the input.)
    """
    text = raw.strip()
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end < start:
        raise ValueError("No JSON object found in model output")
    return text[start:end + 1]

In [5]:
# ------------------------------
# CASE TEXT EXTRACTION
# ------------------------------
def extract_case_text(case_elem):
    """Concatenate all relevant text fields for grounding the LLM.

    For each of the ``Metadata``, ``Facts``, ``Issues``, ``Holding`` and
    ``Decision`` children of ``case_elem`` (in that order), all text inside
    the section is collected and emitted under an upper-cased heading.

    ``itertext()`` is used so that tail text in mixed content (text that
    follows an inline child element) is kept, and whitespace-only fragments
    such as pretty-print indentation are dropped instead of being joined
    into the prompt.

    Args:
        case_elem: ElementTree element for a single case.

    Returns:
        Sections formatted as ``"TAG:\\n<text>"`` joined by blank lines;
        an empty string when no section contains text.
    """
    parts = []

    for tag in ("Metadata", "Facts", "Issues", "Holding", "Decision"):
        section = case_elem.find(tag)
        if section is None:
            continue

        # Trim each fragment and skip the ones that were only whitespace.
        fragments = (frag.strip() for frag in section.itertext())
        txt = " ".join(frag for frag in fragments if frag)
        if txt:
            parts.append(f"{tag.upper()}:\n{txt}")

    return "\n\n".join(parts)

In [6]:
# ------------------------------
# LLM PROMPTS
# ------------------------------

def generation_prompt(context_text):
    """Build the stage-1 prompt asking for ten Q/A pairs about one case.

    ``context_text`` is inserted verbatim between the separator lines under
    ``CASE TEXT:``; the doubled braces render as literal JSON braces.
    """
    prompt_text = f"""
You are an expert legal analyst.

Generate *exactly 10* high-quality, professional-grade,
difficult, varied Question/Answer pairs.

RULES:
- 100% grounded in the case text
- No external knowledge
- Lawyer-level realism
- Difficult questions requiring synthesis
- Varied: facts, reasoning, issues, holding, procedure, etc.
- Output ONLY valid JSON:
{{
  "qa_pairs": [
    {{"question": "...", "answer": "..."}}
  ]
}}

CASE TEXT:
=====================
{context_text}
=====================
"""
    return prompt_text

In [7]:
def filter_prompt(raw_json_text):
    """Build the stage-2 prompt that asks a model to filter and improve Q/A pairs.

    ``raw_json_text`` (the stage-1 model output) is inserted verbatim between
    the separator lines under ``INPUT QAs:``; the doubled braces render as
    literal JSON braces.
    """
    prompt_text = f"""
You are a legal question quality-control judge.

Your task is to FILTER and IMPROVE the Q&A pairs.

Rules:
1. Keep ONLY realistic, lawyer-level professional questions.
2. Must require reasoning/synthesis, not trivia.
3. Must cover varied aspects: facts, procedural posture,
   issues, holding, reasoning, consequences.
4. Enforce ZERO hallucination: 100% grounded in source.
5. Ensure exactly 10 Q&A pairs.
6. Output ONLY valid JSON in this structure:

{{
  "qa_pairs": [
    {{"question": "...", "answer": "..."}}
  ]
}}

INPUT QAs:
=====================
{raw_json_text}
=====================
"""
    return prompt_text

In [8]:
# ------------------------------
# GENERATION PIPELINE
# ------------------------------

def generate_qa_pairs(context_text):
    """Run the two-stage generate → filter pipeline for one case.

    Stage 1 asks ``GENERATION_MODEL`` for raw Q/A pairs. Stage 2 asks
    ``FILTER_MODEL`` to filter and refine them. The filter prompt demands
    answers that are grounded in the source, so the case text is appended to
    it; without the text the filter model has nothing to check against.

    Args:
        context_text: Case text produced by ``extract_case_text``.

    Returns:
        The parsed JSON object, a dict whose ``"qa_pairs"`` entry is a list
        of dicts each holding ``"question"`` and ``"answer"``.

    Raises:
        ValueError: If the filter reply is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError``) or does not have
            the expected ``qa_pairs`` structure.
    """
    # Stage 1 → generate raw pairs
    raw = openrouter_request(
        GENERATION_MODEL,
        generation_prompt(context_text)
    )

    # Stage 2 → filter/refine/validate, with the source attached for grounding
    grounded_filter_prompt = (
        f"{filter_prompt(raw)}\n"
        "SOURCE CASE TEXT (the only ground truth; check every answer against it):\n"
        "=====================\n"
        f"{context_text}\n"
        "=====================\n"
    )
    filtered = openrouter_request(FILTER_MODEL, grounded_filter_prompt)

    # Parse JSON result
    data = json.loads(clean_json(filtered))

    # Fail here with a clear message rather than later while writing XML.
    pairs = data.get("qa_pairs") if isinstance(data, dict) else None
    if not isinstance(pairs, list):
        raise ValueError("Filter model output has no 'qa_pairs' list")
    for pair in pairs:
        if not (isinstance(pair, dict) and "question" in pair and "answer" in pair):
            raise ValueError(
                "Filter model output contains a pair without 'question'/'answer'"
            )
    return data

In [9]:
# ------------------------------
# OUTPUT XML
# ------------------------------
def write_case_output(case_elem, qa_data, index):
    """Write one case and its Q/A pairs to an XML file in ``OUTPUT_DIR``.

    The document root is ``LegalDataEngineeringOutput`` with two children:
    ``SourceDataset`` (wrapping the original case element) and ``QAPairs``
    (one ``Pair`` with ``Question`` and ``Answer`` per entry).

    Args:
        case_elem: ElementTree element of the case. Its ``slug`` and ``name``
            attributes are used, falling back to ``"UNKNOWN"`` and
            ``"Unknown_Case"``.
        qa_data: Dict with a ``"qa_pairs"`` list of dicts that each have
            ``"question"`` and ``"answer"``.
        index: Integer used as the zero-padded (two-digit minimum) filename
            prefix.

    Returns:
        The path of the written file as a string.

    Raises:
        KeyError: If ``qa_data`` or one of its pairs lacks an expected key.
    """
    case_id   = case_elem.get("slug", "UNKNOWN")
    case_name = case_elem.get("name", "Unknown_Case")

    root = ET.Element("LegalDataEngineeringOutput")

    src = ET.SubElement(root, "SourceDataset", {"docId": case_id})
    # Embed original case XML node (the same element object, not a copy)
    src.append(case_elem)

    qas = ET.SubElement(root, "QAPairs", {
        "caseId": case_id,
        "caseName": case_name
    })

    for pair in qa_data["qa_pairs"]:
        p = ET.SubElement(qas, "Pair")
        q = ET.SubElement(p, "Question")
        q.text = pair["question"]
        a = ET.SubElement(p, "Answer")
        a.text = pair["answer"]

    # Filename: numbered + safe (runs of other characters collapse to "_")
    safe_name = re.sub(r"[^a-zA-Z0-9_]+", "_", case_name)
    filename = f"{index:02d}_{safe_name}.xml"
    path = Path(OUTPUT_DIR) / filename

    # Serialize before opening the file: opening with "w" truncates at once,
    # so a serialization failure would otherwise leave an empty output file.
    xml_text = pretty_xml(root)

    with open(path, "w", encoding="utf-8") as f:
        f.write(xml_text)

    return str(path)

In [10]:
# ------------------------------
# MAIN PIPELINE
# ------------------------------

# Load the dataset; every direct <Case> child of the root is processed independently.
tree = ET.parse(INPUT_XML)
root = tree.getroot()
cases = root.findall("Case")

print(f"Found {len(cases)} cases")

output_paths = []   # paths of the files written successfully
failed_cases = []   # (index, case name, error repr) for cases that could not be processed

for idx, case_elem in enumerate(cases, start=1):
    case_name = case_elem.get("name")
    print(f"\nProcessing case {idx}: {case_name}")

    try:
        context = extract_case_text(case_elem)
        qa_data = generate_qa_pairs(context)
        path = write_case_output(case_elem, qa_data, idx)
    except (requests.RequestException, ValueError, KeyError, IndexError) as exc:
        # One bad API reply or malformed JSON payload must not abort the whole
        # batch; record the failure and move on so it can be retried later.
        failed_cases.append((idx, case_name, repr(exc)))
        print(f" ✗ Failed: {exc!r}")
        continue

    output_paths.append(path)
    print(" → Saved:", path)

print("\nAll cases completed.")
if failed_cases:
    print(f"{len(failed_cases)} case(s) failed:")
    for failed_idx, failed_name, failed_error in failed_cases:
        print(f"  {failed_idx}: {failed_name} ({failed_error})")

Found 3303 cases

Processing case 1: Fletcher v. Peck
 → Saved: usc_cases_output/01_Fletcher_v_Peck.xml

Processing case 2: Martin v. Hunter's Lessee
 → Saved: usc_cases_output/02_Martin_v_Hunter_s_Lessee.xml

Processing case 3: McCulloch v. Maryland


KeyboardInterrupt: 