<a href="https://colab.research.google.com/github/Deep7285/Invoice-Extractor/blob/main/LayoutLMV3_invoice.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!apt-get install -y poppler-utils tesseract-ocr

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
tesseract-ocr is already the newest version (4.1.1-2.1build1).
poppler-utils is already the newest version (22.02.0-2ubuntu0.9).
0 upgraded, 0 newly installed, 0 to remove and 35 not upgraded.


In [7]:
!pip -q install transformers pdf2image pillow python-docx pandas openpyxl pytesseract python-dateutil


In [8]:
%%writefile extractor_colab.py
import os, re
from typing import Dict, List, Tuple
from PIL import Image
import pdf2image
import docx
import torch
import pytesseract
from pytesseract import Output
from dateutil import parser as dateparser
from transformers import AutoProcessor, LayoutLMv3ForTokenClassification

# -------------------- CONFIG --------------------
MODEL_CHECKPOINT = "oussama/layoutlmv3-finetuned-invoice"  # public HF model
TESS_CONFIG = "--oem 3 --psm 6"  # balanced OCR setting for invoices
LANG = "eng"                      # set to "eng" (add others if needed)
MAX_LEN = 512                     # LMv3 sequence cap
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# India-specific regexes
GSTIN_RE = re.compile(r"\b[0-9]{2}[A-Z]{5}[0-9]{4}[A-Z1-9]Z[0-9A-Z]\b")
DATE_RE  = re.compile(r"\b(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})\b")
AMT_RE   = re.compile(r"(?:₹|\bINR\b|Rs\.?)\s*([0-9]{1,3}(?:,[0-9]{3})*(?:\.[0-9]{2})|[0-9]+(?:\.[0-9]{2}))", re.I)
INVNO_NEAR_RE = re.compile(r"(?:Invoice\s*(?:No|#|Number)\s*[:\-]?\s*([A-Za-z0-9\-\/]+))", re.I)

# ----------------- GLOBAL SINGLETONS -----------------
_processor = None
_model = None

def _load():
    global _processor, _model
    if _processor is None or _model is None:
        _processor = AutoProcessor.from_pretrained(MODEL_CHECKPOINT)
        _model = LayoutLMv3ForTokenClassification.from_pretrained(MODEL_CHECKPOINT)
        _model.to(DEVICE).eval()
    return _processor, _model

# -------------------- OCR HELPERS --------------------
def _ocr_words_boxes(img: Image.Image) -> Tuple[List[str], List[List[int]]]:
    """
    Tesseract OCR for word-level boxes; return (words, boxes_norm0_1000)
    """
    data = pytesseract.image_to_data(img, lang=LANG, config=TESS_CONFIG, output_type=Output.DICT)
    W, H = img.size
    words, boxes = [], []
    n = len(data["text"])
    for i in range(n):
        text = data["text"][i]
        conf = data["conf"][i]
        if text is None:
            continue
        text = text.strip()
        try:
            conf_val = float(conf)
        except Exception:
            conf_val = -1
        if text and conf_val >= 0:
            x, y, w, h = data["left"][i], data["top"][i], data["width"][i], data["height"][i]
            # normalize to 0..1000 as expected by LayoutLM family
            x0 = int(1000 * x / max(1, W))
            y0 = int(1000 * y / max(1, H))
            x1 = int(1000 * (x + w) / max(1, W))
            y1 = int(1000 * (y + h) / max(1, H))
            words.append(text)
            boxes.append([x0, y0, x1, y1])
    # Safe fallback
    if not words:
        words = [" "]
        boxes = [[0, 0, 10, 10]]
    return words, boxes

def _ocr_full_text(img: Image.Image) -> str:
    """
    Full-page OCR text for robust regex fallbacks
    """
    txt = pytesseract.image_to_string(img, lang=LANG, config=TESS_CONFIG)
    # compress excessive whitespace
    return re.sub(r"[ \t]+", " ", txt).strip()

# --------------- MODEL INFERENCE HELPERS ---------------
def _extract_entities_on_image(img: Image.Image) -> Dict[str, str]:
    """
    Run LayoutLMv3 token classification and aggregate labels → strings.
    Also returns 'ocr_text' for downstream regex postprocessing.
    """
    processor, model = _load()
    words, boxes = _ocr_words_boxes(img)
    ocr_text = _ocr_full_text(img)

    # Encode for LMv3
    encoding = processor(
        images=img,
        text=words,
        boxes=boxes,
        truncation=True,
        padding="max_length",
        max_length=MAX_LEN,
        return_tensors="pt",
    )
    encoding = {k: v.to(DEVICE) for k, v in encoding.items()}

    with torch.no_grad():
        outputs = model(**encoding)

    # Token predictions
    preds = outputs.logits.argmax(-1).squeeze().tolist()
    id2label = model.config.id2label

    # try to map input_ids back to token strings (best-effort)
    try:
        tokens = processor.tokenizer.convert_ids_to_tokens(encoding["input_ids"].squeeze().cpu())
    except Exception:
        tokens = [""] * len(preds)

    # naive token aggregation
    raw = {}
    for tok, p in zip(tokens, preds):
        label = id2label[p]
        if label == "O":
            continue
        tok = str(tok).lstrip("Ġ")
        raw[label] = (raw.get(label, "") + " " + tok).strip()

    # normalize: collapse spaces in alphanumerics like GSTIN, trim repeated spaces
    def _collapse_alnum(s: str) -> str:
        # remove spaces inside alphanumeric runs: "33 AALC A..." -> "33AALCA..."
        return re.sub(r"(?<=\w)\s+(?=\w)", "", s)

    clean = {k: _collapse_alnum(v) for k, v in raw.items()}

    return {"model_fields": clean, "ocr_text": ocr_text}

# ----------------- REGEX / SCHEMA MAPPER -----------------
def _pick_total_from_text(text: str, existing: str = "") -> str:
    # choose the largest plausible currency-like number
    nums = []
    for m in AMT_RE.finditer(text):
        amt_str = m.group(1)
        try:
            nums.append(float(amt_str.replace(",", "")))
        except Exception:
            pass
    if nums:
        return f"{max(nums):.2f}"
    return existing

def _normalize_date(s: str) -> str:
    try:
        dt = dateparser.parse(s, dayfirst=True)
        return dt.strftime("%d-%m-%Y")
    except Exception:
        return s

def _map_to_target_schema(model_fields: Dict[str, str], ocr_text: str) -> Dict[str, str]:
    """
    Map model labels + OCR text → target schema:
      GSTIN, INVOICE_NUMBER, DATE, TOTAL
    Preference order:
      1) Regex on OCR text (robust)
      2) Heuristic from model_fields (fallback)
    """
    out: Dict[str, str] = {}

    # 1) GSTIN via regex (most reliable)
    m = GSTIN_RE.search(ocr_text)
    if m:
        out["GSTIN"] = m.group(0)

    # 2) DATE via regex (first parseable date)
    m = DATE_RE.search(ocr_text)
    if m:
        out["DATE"] = _normalize_date(m.group(1))

    # 3) TOTAL: pick largest amount
    out["TOTAL"] = _pick_total_from_text(ocr_text, existing=model_fields.get("B-TOTAL", ""))

    # 4) INVOICE_NUMBER: near-keyword or fallback to any plausible model field
    m = INVNO_NEAR_RE.search(ocr_text.replace("\n", " "))
    if m:
        out["INVOICE_NUMBER"] = m.group(1)
    else:
        # fallbacks: try common model keys the checkpoint might emit
        for k in ["B-INVOICE_NUMBER", "I-INVOICE_NUMBER", "INVOICE_NUMBER", "B-BILL_NUMBER"]:
            if k in model_fields and model_fields[k]:
                out["INVOICE_NUMBER"] = model_fields[k]
                break

    # If any are missing, try to salvage from model_fields generically
    if "GSTIN" not in out:
        # search any model value that matches GSTIN pattern
        for v in model_fields.values():
            mm = GSTIN_RE.search(v)
            if mm:
                out["GSTIN"] = mm.group(0); break

    if "DATE" not in out:
        # look for date-ish values in model fields
        for v in model_fields.values():
            mm = DATE_RE.search(v)
            if mm:
                out["DATE"] = _normalize_date(mm.group(1)); break

    if "TOTAL" not in out or not out["TOTAL"]:
        # look for numeric candidates in model fields
        out["TOTAL"] = _pick_total_from_text(" ".join(model_fields.values()))

    # final cleanup (remove stray punctuation)
    for k in list(out.keys()):
        out[k] = out[k].strip().strip(":,;")
    return out

# -------------------- PUBLIC API --------------------
def extract_invoice_fields(path: str) -> Dict[str, str]:
    """
    Accepts PDF/JPG/PNG/DOCX; returns a dict with standardized keys:
      GSTIN, INVOICE_NUMBER, DATE, TOTAL
    """
    ext = os.path.splitext(path)[1].lower()
    if ext in [".jpg", ".jpeg", ".png"]:
        img = Image.open(path).convert("RGB")
        res = _extract_entities_on_image(img)
        return _map_to_target_schema(res["model_fields"], res["ocr_text"])

    if ext == ".pdf":
        pages = pdf2image.convert_from_path(path, dpi=300)
        combined_model_fields = {}
        full_text = ""
        for img in pages:
            res = _extract_entities_on_image(img)
            # merge model fields across pages (concat strings)
            for k, v in res["model_fields"].items():
                combined_model_fields[k] = (combined_model_fields.get(k, "") + " " + v).strip()
            full_text += "\n" + res["ocr_text"]
        return _map_to_target_schema(combined_model_fields, full_text)

    if ext == ".docx":
        # use DOCX text + regex
        d = docx.Document(path)
        txt = "\n".join(p.text for p in d.paragraphs)
        # crude model_fields from docx (none) → rely on regex
        return _map_to_target_schema({}, txt)

    raise ValueError(f"Unsupported file type: {ext}")

def to_excel(data: Dict[str, str], outfile: str = "invoice_extracted.xlsx"):
    import pandas as pd
    df = pd.DataFrame([data], columns=["GSTIN", "INVOICE_NUMBER", "DATE", "TOTAL"])
    df.to_excel(outfile, index=False)
    return outfile

Overwriting extractor_colab.py


In [9]:
from google.colab import files
from extractor_colab import extract_invoice_fields, to_excel

uploaded = files.upload()  # choose your PDF/JPG/PNG/DOCX
assert uploaded, "No file uploaded."
local_path = list(uploaded.keys())[0]
print("Uploaded:", local_path)

fields = extract_invoice_fields(local_path)
print("Extracted fields (LayoutLMv3 + OCR + regex):\n", fields)

xlsx = to_excel(fields, "invoice_extracted.xlsx")
files.download(xlsx)



Saving Invoice.pdf to Invoice.pdf
Uploaded: Invoice.pdf
Extracted fields (LayoutLMv3 + OCR + regex):
 {'B-BILLER': 'am azon', 'I-BILLER_ADDRESS': '. in wee ]', 'B-ABN': 'CA 01 CA 305 49 75 - 2 . 02 . 20 12 ,', 'B-BILLER_POST_CODE': '71 01 71 E 1 Z 6 17 29', 'B-DUE_DATE': '05 05', 'B-INVOICE_NUMBER': 'AA 4 - 46 M AA 4 -', 'B-SUBTOTAL': '10', 'B-GST': '34 122 12 1 999 304 94 999', 'B-TOTAL': '% . | % 1 , 00'}




<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>