<a href="https://colab.research.google.com/github/Jessietbl/aviation-scsirisk-showcase/blob/main/01_llm_zephyr_7b_clean.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# LLM-based Trade Data Extraction (Zephyr-7B) — Showcase

This demo extracts **monthly trade statistics** (exports, imports, trade_balance, total_trade) from Malaysian trade bulletins.

**What’s inside**
1) Load PDFs from `inputs/`  
2) Extract & clean text  
3) Prompt Zephyr-7B for JSON  
4) Parse + (optionally) benchmark vs ground truth

> Uses **sample PDFs / sample ground truth (GT)** only. Full thesis data/code remain private.


In [None]:
# --- 0. Imports & Config (portable: no !pip, no Colab APIs) ---
from pathlib import Path
import os, re, json, glob, gc
import numpy as np
import pandas as pd
import pdfplumber
import pytesseract
from PIL import Image
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

from src.utils import (
    preprocess_text,
    build_llm_prompt,
    parse_llm_output,
    coerce_billion,
    metrics_table
)

# Folder scanned for the sample bulletin PDFs (put sample PDFs here).
INPUT_DIR = Path("inputs")

# Destination for the generated CSVs; created up front so later writes
# do not fail on a missing folder.
OUT_DIR = Path("outputs")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Optional ground-truth table; the benchmark step is skipped when absent.
GT_CSV = Path("data/sample_ground_truth.csv")

# Hugging Face model identifier -- keep as-is for clarity.
MODEL_ID = "HuggingFaceH4/zephyr-7b-beta"


In [None]:
!mkdir -p src


In [None]:
from __future__ import annotations
import re, json
import numpy as np
import pandas as pd

# ---------- text cleanup ----------
def preprocess_text(text: str) -> str:
    """Light, OCR-aware normalization for bulletin text.

    Collapses all whitespace to single spaces, repairs common OCR misreads
    of "RM", "billion" and "million", removes thousands separators and
    rewrites currency wording to the ``RM <unit>`` form.

    Args:
        text: Raw text from a PDF text layer or OCR. ``None`` or an empty
            string is treated as empty input.

    Returns:
        The normalized, single-line string.
    """
    t = re.sub(r"\s+", " ", (text or "").strip())

    # Common OCR & currency fixes
    t = re.sub(r"RlVl", "RM", t)
    t = re.sub(r"R[Mm]", "RM", t)

    # OCR tends to confuse/drop "i" and "l": accept any run of 2-4 of them
    # between "b" and "on" ("bilion", "bllion", "biliion", ...). Requiring
    # the leading "b" keeps real words such as "lion" untouched.
    t = re.sub(r"\bb[il]{2,4}on\b", "billion", t, flags=re.I)
    t = re.sub(r"\bmill[il]on\b", "million", t, flags=re.I)

    # 123,456 → 123456 (but leave decimals)
    t = re.sub(r"(?<=\d),(?=\d{3}\b)", "", t)

    # normalize currency wording
    t = re.sub(r"ringgit\s+malaysia", "RM", t, flags=re.I)
    # \b stops words that merely end in "rm" ("term million") from being
    # rewritten as "teRM million".
    t = re.sub(r"\brm\s*(billion|million)", r"RM \1", t, flags=re.I)
    return t


# ---------- prompt ----------
def build_llm_prompt(text: str, filename: str) -> str:
    """Build the extraction prompt in Zephyr's chat format.

    Zephyr-7B-beta is trained on ``<|system|>`` / ``<|user|>`` /
    ``<|assistant|>`` turns, each closed by ``</s>``; it does not use the
    Mistral/Llama ``[INST]`` wrapper. No literal ``<s>`` is emitted because
    the tokenizer is expected to add the BOS token itself.

    Args:
        text: Cleaned bulletin text to extract values from.
        filename: Name of the source PDF, included for traceability.

    Returns:
        The prompt string, ending with an open assistant turn so that
        generation starts with the model's answer.
    """
    system_msg = "You are an assistant extracting Malaysian monthly trade values."
    user_msg = (
        "TASK:\n"
        "- Extract ONLY the **latest month** mentioned.\n"
        "- Units: **RM BILLIONS** (millions ÷ 1,000; trillions × 1,000).\n"
        "- Return **JSON only** with keys: exports, imports, trade_balance, total_trade.\n"
        "- Use null for missing/uncertain fields.\n"
        "\n"
        f"SOURCE_FILE: {filename}\n"
        "BULLETIN_TEXT:\n"
        f"{text}"
    )
    return (
        f"<|system|>\n{system_msg}</s>\n"
        f"<|user|>\n{user_msg}</s>\n"
        "<|assistant|>\n"
    )


# ---------- robust JSON parsing ----------
def parse_llm_output(raw: str) -> dict | None:
    """Pull the trade fields out of the first parseable JSON object in ``raw``.

    Brace-delimited candidates (at most one level of nested braces) are
    tried in order of appearance. The first one that ``json.loads`` accepts
    wins, even if it holds none of the expected keys.

    Returns:
        A dict with exactly ``exports``, ``imports``, ``trade_balance`` and
        ``total_trade``; each value is a float, or ``None`` when the field
        is absent, null, or not convertible to float. ``None`` is returned
        when ``raw`` is empty or no candidate parses.
    """
    if not raw:
        return None

    def _as_float(value):
        # Null/missing stays None; anything float() rejects becomes None too.
        if value is None:
            return None
        try:
            return float(value)
        except Exception:
            return None

    candidates = re.finditer(
        r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}", raw, flags=re.DOTALL
    )
    for candidate in candidates:
        try:
            payload = json.loads(candidate.group(0))
        except Exception:
            # Not valid JSON -- move on to the next brace-delimited span.
            continue
        return {
            field: _as_float(payload.get(field, None))
            for field in ("exports", "imports", "trade_balance", "total_trade")
        }
    return None


# ---------- unit coercion ----------
def coerce_billion(x) -> float | None:
    """Normalize values to RM **billions**.

    The unit is guessed from the magnitude: an absolute value above 2e9 is
    treated as raw RM (divided by 1e9) and one above 2e6 as millions
    (divided by 1e3). The magnitude is used so that negative figures, such
    as a trade deficit, are rescaled the same way as positive ones.

    NOTE(review): the thresholds are heuristics; anything with an absolute
    value of 2,000,000 or less is returned unchanged -- confirm this matches
    the units the model actually emits.

    Args:
        x: A number or numeric string (thousands separators allowed), or
            ``None``.

    Returns:
        The value as a float in RM billions, or ``None`` when ``x`` is
        ``None`` or cannot be parsed as a number.
    """
    if x is None:
        return None
    try:
        v = float(str(x).replace(",", ""))
    except (TypeError, ValueError):
        return None

    # if way too large, guess units and convert
    magnitude = abs(v)
    if magnitude > 2_000_000_000:   # raw RM
        v /= 1_000_000_000
    elif magnitude > 2_000_000:     # millions
        v /= 1_000
    return float(v)


# ---------- metrics table ----------
def _metrics(y, yhat):
    y = np.asarray(y, float)
    yhat = np.asarray(yhat, float)
    mask = np.isfinite(y) & np.isfinite(yhat)
    if mask.sum() == 0:
        return dict(MAE=np.nan, RMSE=np.nan, R2=np.nan)
    y, yhat = y[mask], yhat[mask]
    mae  = np.mean(np.abs(yhat - y))
    rmse = np.sqrt(np.mean((yhat - y) ** 2))
    denom = np.sum((y - y.mean()) ** 2)
    r2 = 1 - np.sum((yhat - y) ** 2) / denom if denom > 1e-12 else np.nan
    return dict(MAE=mae, RMSE=rmse, R2=r2)

def metrics_table(df: pd.DataFrame, keys: list[str]) -> pd.DataFrame:
    """Score predicted columns against their ``<key>_true`` counterparts.

    For every key whose prediction column and ``<key>_true`` column both
    exist, rows whose absolute error exceeds Q3 + 3*IQR are dropped and
    MAE / RMSE / R2 are computed on the rest (rounded to 3 decimals).

    Args:
        df: Frame holding prediction columns and ``<key>_true`` columns.
        keys: Column names to evaluate; keys lacking either column are
            skipped.

    Returns:
        A frame indexed by ``Metric`` (the title-cased key) with columns
        ``MAE``, ``RMSE`` and ``R2``. It is empty, but has the same
        columns, when no key could be evaluated.
    """
    rows = []
    for k in keys:
        pred = df.get(k)
        true = df.get(f"{k}_true")
        if pred is None or true is None:
            continue

        # A column made up only of None/strings is object dtype and cannot
        # be subtracted; coerce so unparseable entries become NaN.
        pred = pd.to_numeric(pred, errors="coerce")
        true = pd.to_numeric(true, errors="coerce")

        # light outlier guard (3×IQR on abs error)
        err = (pred - true).astype(float)
        if err.notna().any():
            abs_err = np.abs(err)
            q1, q3 = np.nanpercentile(abs_err, [25, 75])
            thr = q3 + 3 * (q3 - q1)
            # Non-finite errors are kept here; _metrics discards them.
            mask = (abs_err <= thr) | ~np.isfinite(err)
        else:
            # No usable error at all: skip the percentile call and let
            # _metrics report NaN.
            mask = pd.Series(True, index=err.index)

        sc = _metrics(true[mask], pred[mask])
        rounded = {
            name: (round(val, 3) if val == val else np.nan)  # val != val only for NaN
            for name, val in sc.items()
        }
        rows.append({"Metric": k.title(), **rounded})

    # Explicit columns keep set_index from raising KeyError on zero rows.
    return pd.DataFrame(rows, columns=["Metric", "MAE", "RMSE", "R2"]).set_index("Metric")


In [None]:
# --- 1. PDF → text (OCR fallback) ---
# Matches a full English month name followed by a four-digit year,
# e.g. "March 2024"; case-insensitive.
MONTH_RE = re.compile(
    r"\b("
    r"January|February|March|April|May|June|"
    r"July|August|September|October|November|December"
    r")\s+\d{4}\b",
    re.IGNORECASE,
)

def extract_text_enhanced(path: Path, max_pages: int = 6) -> str:
    """Extract bulletin text from the first content pages, with OCR fallback.

    For multi-page PDFs the first page is skipped and up to ``max_pages``
    following pages are read; a single-page PDF is read as-is. A page is
    OCR'd when its text layer is shorter than 60 characters or mentions
    neither "export" nor "trade". Pages matching ``MONTH_RE`` (a month name
    plus a four-digit year) are preferred; if none match, every page that
    produced text is used instead.

    Args:
        path: Location of the PDF.
        max_pages: Maximum number of pages read after the skipped first
            page.

    Returns:
        The text after ``preprocess_text``, truncated to 8000 characters.
    """
    month_pages = []   # pages whose text mentions "<Month> <year>"
    all_pages = []     # every page with any text (fallback pool)
    with pdfplumber.open(path) as pdf:
        skip_first = len(pdf.pages) > 1
        pages = pdf.pages[1:max_pages + 1] if skip_first else pdf.pages
        # Keep the "--- PAGE n ---" labels equal to real page numbers: the
        # first page read is page 2 only when page 1 was skipped.
        first_page_no = 2 if skip_first else 1

        for idx, page in enumerate(pages, start=first_page_no):
            text_layer = page.extract_text() or ""
            raw = text_layer
            lowered = text_layer.lower()
            if len(text_layer.strip()) < 60 or ("export" not in lowered and "trade" not in lowered):
                try:
                    image = page.to_image(resolution=300).original
                    raw = pytesseract.image_to_string(image, lang="eng", config="--psm 6 --oem 3")
                except Exception:
                    # OCR is best-effort: any failure just leaves this page
                    # without OCR text.
                    raw = ""

            if raw and MONTH_RE.search(raw):
                text.append(f"\n--- PAGE {idx} ---\n{raw}") if False else month_pages.append(f"\n--- PAGE {idx} ---\n{raw}")

            # Remember the best text seen for this page (OCR output when
            # there is any, else the text layer) so the fallback neither
            # re-reads the PDF nor loses OCR results on scanned pages.
            best = raw or text_layer
            if best:
                all_pages.append(f"\n--- PAGE {idx} ---\n{best}")

    # fallback: if nothing caught the month, use every page that had text
    chosen = month_pages if month_pages else all_pages
    return preprocess_text("\n".join(chosen))[:8000]


In [None]:
# --- 2. Load Zephyr-7B (always 4-bit NF4-quantized to keep VRAM small) ---
def load_llm(model_id: str = MODEL_ID):
    """Create a text-generation pipeline for ``model_id`` with 4-bit weights.

    The model is always loaded with NF4 4-bit quantization (double
    quantization, bfloat16 compute) and ``device_map="auto"``. The
    tokenizer's padding token is set to its EOS token, and the pipeline
    returns only newly generated text (``return_full_text=False``).

    NOTE(review): ``trust_remote_code=True`` permits running code shipped
    in the model repository -- confirm it is actually needed for this model.

    Args:
        model_id: Hugging Face model identifier.

    Returns:
        The ``transformers`` text-generation pipeline.
    """
    quant_settings = {
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": "bfloat16",
        "bnb_4bit_use_double_quant": True,
    }
    quantization = BitsAndBytesConfig(**quant_settings)

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Reuse EOS as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        quantization_config=quantization,
        trust_remote_code=True,
    )

    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        return_full_text=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    return generator

# Load once at cell execution; the extraction step reuses this pipeline.
pipe = load_llm()


In [None]:
# --- 3. Single-file extraction (PDF → text → LLM JSON → dict) ---
def extract_llm_only(pdf_path: Path) -> dict | None:
    """Run one PDF through text extraction, the LLM and JSON parsing.

    Args:
        pdf_path: Path of the bulletin PDF.

    Returns:
        A dict with ``exports``, ``imports``, ``trade_balance`` and
        ``total_trade`` passed through ``coerce_billion``, plus ``file``,
        ``method`` and ``text_len``; or ``None`` when fewer than 120
        characters of text were extracted or the model output could not be
        parsed.
    """
    source_name = pdf_path.name
    bulletin_text = extract_text_enhanced(pdf_path)
    # Too little text to be worth prompting on.
    if not bulletin_text or len(bulletin_text) < 120:
        return None

    # Greedy decoding (do_sample=False), at most 256 new tokens.
    generations = pipe(
        build_llm_prompt(bulletin_text, source_name),
        max_new_tokens=256,
        do_sample=False,
    )
    record = parse_llm_output(generations[0]["generated_text"])
    if not record:
        return None

    # normalize to RM billions where possible
    record.update({
        field: coerce_billion(record.get(field))
        for field in ("exports", "imports", "trade_balance", "total_trade")
    })
    record["file"] = source_name
    record["method"] = "llm_only"
    record["text_len"] = len(bulletin_text)
    return record


In [None]:
# --- 4. Batch run over inputs/ and save CSV ---
# Collect every PDF under INPUT_DIR (recursively); sorting makes the
# processing order deterministic.
pdfs = sorted(INPUT_DIR.rglob("*.pdf"))
results = []
for p in pdfs:
    r = extract_llm_only(p)
    # extract_llm_only returns None when too little text was extracted or
    # the model output could not be parsed; such files are left out.
    if r: results.append(r)
    # Run a garbage-collection pass after each file.
    gc.collect()

# One row per successfully parsed bulletin.
df_pred = pd.DataFrame(results)
pred_csv = OUT_DIR / "llm_zephyr7b_extraction.csv"
df_pred.to_csv(pred_csv, index=False)
# Bare trailing expression -- presumably so the notebook echoes the path.
pred_csv


In [None]:
# --- 5. Benchmark against sample ground truth ---
# Benchmark only when a ground-truth file exists and at least one
# prediction row was produced.
if GT_CSV.exists() and not df_pred.empty:
    gt = pd.read_csv(GT_CSV)
    # try to align on filename column
    # (the first of these candidate names present in the CSV is used)
    key = next((c for c in ["file","filename","Source_File","PDF_Name","PDF"] if c in gt.columns), None)
    if key is None:
        raise ValueError("Ground truth CSV must include a filename column.")

    # Left join keeps every prediction row. Prediction columns keep their
    # names (empty suffix); ground-truth columns that share a name with a
    # prediction column get the "_true" suffix.
    merged = pd.merge(df_pred, gt, left_on="file", right_on=key, how="left", suffixes=("","_true"))

    # compute metrics table if *_true exist
    keep = ["exports","imports","trade_balance","total_trade"]
    have_truth = [k for k in keep if f"{k}_true" in merged.columns]
    if have_truth:
        table = metrics_table(merged, have_truth)
        display(table)
        out_bench = OUT_DIR / "llm_benchmark_comparison.csv"
        merged.to_csv(out_bench, index=False)
        # NOTE(review): a bare expression nested inside an `if` is probably
        # not echoed by the notebook -- confirm whether a print was intended.
        out_bench


## ✅ Notes
- This demo favors **clarity and portability** over maximum accuracy.
- Full thesis notebooks (data wrangling, advanced prompts, hybrid regex+LLM) remain private.
