<a href="https://colab.research.google.com/github/anandbarman45/krux_finance_ocr/blob/main/complete_OCR_pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ==============================================================================
# 0️⃣ Setup & Dependencies (Run First)
# ==============================================================================
import os
import sys

def install(package):
    """Install *package* with pip into the interpreter running this notebook.

    Runs ``sys.executable -m pip`` so the package lands in the same
    environment the kernel imports from. Arguments are passed as a list (no
    shell), so a package spec cannot smuggle in extra shell commands.

    Failures are reported but not raised (best-effort, like a plain
    ``os.system`` call).

    Returns:
        The pip process exit code (0 on success).
    """
    import subprocess

    result = subprocess.run([sys.executable, "-m", "pip", "install", package], check=False)
    if result.returncode != 0:
        print(f"pip install {package} failed with exit code {result.returncode}")
    return result.returncode

# Install critical Python libraries. "accelerate" is required by Trainer and
# "scikit-learn" provides train_test_split used by the training cell.
for _package in (
    "transformers",
    "datasets",
    "torch",
    "torchvision",
    "pillow",
    "pytesseract",
    "pdf2image",
    "accelerate",
    "scikit-learn",
):
    install(_package)

# Install system dependencies: Poppler (pdftoppm, used by pdf2image) and the
# Tesseract OCR engine. Check for both binaries so a machine that has one but
# not the other still gets the install.
import shutil

_REQUIRED_BINARIES = ("pdftoppm", "tesseract")
if any(shutil.which(_binary) is None for _binary in _REQUIRED_BINARIES):
    print("‚¨áÔ∏è Installing System Tools (Tesseract + Poppler)...")
    os.system('apt-get update && apt-get install -y poppler-utils tesseract-ocr > /dev/null')

_missing_binaries = [b for b in _REQUIRED_BINARIES if shutil.which(b) is None]
if _missing_binaries:
    print(f"Missing system tools after install: {', '.join(_missing_binaries)}")
else:
    print("‚úÖ Dependencies Installed.")

# ==============================================================================
# 🛠️ STEP 2: Master Synthetic Data Generator (12 Classes)
# ==============================================================================
import random
import numpy as np
from PIL import Image, ImageDraw, ImageFont

# The 12 document classes. List order matters: integer label ids are later
# assigned by position in this list.
classes = [
    "GST",
    "COI",
    "GUMASTA",
    "UDYAM",
    "FSSAI",
    "EKARMIKA",
    "DRUG_LICENSE",
    "IEC",
    "PTEC",
    "TAN",
    "TRADE_LICENSE_WB",
    "PARTNERSHIP_DEED",
]

# One output folder per class under ./dataset; existing folders are kept.
for class_name in classes:
    os.makedirs(os.path.join("dataset", class_name), exist_ok=True)

def get_font(size, bold=False):
    """Return a Liberation Sans font at *size*, bold when *bold* is true.

    If Liberation Sans cannot be loaded, fall back to Pillow's default font.
    On Pillow >= 10.1 that default honours *size*; on older Pillow (or without
    FreeType) it is a fixed-size bitmap font and *size* is ignored.
    """
    font_name = "LiberationSans-Bold.ttf" if bold else "LiberationSans-Regular.ttf"
    try:
        return ImageFont.truetype(font_name, size)
    except (OSError, ImportError):
        # Font file not found/unreadable, or FreeType support unavailable.
        pass
    try:
        return ImageFont.load_default(size=size)
    except (TypeError, ImportError, OSError):
        # Older Pillow: load_default() takes no size argument.
        return ImageFont.load_default()

def create_mock_qr(size=100):
    """Return a fake QR code: a size x size RGB image of random black/white pixels."""
    bits = np.random.randint(0, 2, size=(size, size))
    pixels = (bits * 255).astype(np.uint8)
    grey = Image.fromarray(pixels)
    return grey.convert('RGB')

# --- Generators for Key Docs ---

def generate_gst(filename):
    """Render a synthetic GST registration certificate (Form GST REG-06) to *filename*.

    The GSTIN uses the 15-character layout the pipeline's GST regex expects:
    2 digits, 5 letters, 4 digits, 1 letter, 1 char from 1-9, literal 'Z', and
    1 alphanumeric. The final character is random, not a computed checksum.
    """
    import string

    img = Image.new('RGB', (1000, 1400), (255, 255, 255))
    draw = ImageDraw.Draw(img)
    draw.text((400, 50), "Government of India", fill="black", font=get_font(30, True))
    draw.text((380, 90), "Form GST REG-06", fill="black", font=get_font(24, True))

    letters = string.ascii_uppercase
    middle = (
        "".join(random.choices(letters, k=5))
        + f"{random.randint(0, 9999):04d}"
        + random.choice(letters)
    )
    gstin = f"27{middle}{random.choice('123456789')}Z{random.choice(string.digits + letters)}"
    draw.text((50, 220), f"Registration Number : {gstin}", fill="black", font=get_font(22, True))

    draw.text((600, 930), "DS GOODS AND SERVICES", fill="black", font=get_font(18))
    img.save(filename)

def generate_coi(filename):
    """Write a synthetic MCA Certificate of Incorporation image to *filename*.

    The page carries the ministry heading, a random CIN of the form
    U72900MH<year 2015-2025>PTC<6 digits>, and a yellow signature box.
    """
    page = Image.new('RGB', (1000, 1400), (245, 245, 245))
    pen = ImageDraw.Draw(page)
    heading_font = get_font(24, True)
    pen.text((280, 100), "MINISTRY OF CORPORATE AFFAIRS", fill="black", font=heading_font)

    year = random.randint(2015, 2025)
    serial = random.randint(100000, 999999)
    cin = "U72900MH" + str(year) + "PTC" + str(serial)
    pen.text((50, 500), "Corporate Identity Number: " + cin, fill="black", font=get_font(20, True))

    # Mock digital-signature box.
    pen.rectangle([600, 900, 850, 1000], fill="yellow", outline="black")
    pen.text((620, 920), "DS MINISTRY OF CORPORATE", fill="black", font=get_font(18))

    page.save(filename)

def generate_gumasta(filename):
    """Write a synthetic Maharashtra Shops & Establishments (Form 'F') licence to *filename*.

    Landscape, off-white page with a red registration number MH/MUM/<4 digits>
    and a blue circular seal outline.
    """
    canvas = Image.new('RGB', (1400, 1000), (255, 250, 240))
    pen = ImageDraw.Draw(canvas)
    pen.text((600, 50), "FORM 'F'", fill="black", font=get_font(36, True))
    pen.text((450, 100), "Maharashtra Shops and Establishments Act", fill="black", font=get_font(28, True))

    registration = "MH/MUM/" + str(random.randint(1000, 9999))
    pen.text((100, 250), "Registration No: " + registration, fill="red", font=get_font(28, True))

    # Seal outline.
    pen.ellipse([1000, 600, 1200, 800], outline="blue", width=5)
    canvas.save(filename)

def generate_udyam(filename):
    """Write a synthetic Udyam registration certificate to *filename*.

    Includes a random mock QR code in the top-right corner and a number of the
    form UDYAM-MH-03-<7 digits>.
    """
    sheet = Image.new('RGB', (1000, 1400), (255, 255, 255))
    pen = ImageDraw.Draw(sheet)
    pen.text((400, 180), "UDYAM", fill="darkblue", font=get_font(48, True))
    pen.text((300, 240), "REGISTRATION CERTIFICATE", fill="black", font=get_font(28, True))

    sheet.paste(create_mock_qr(150), (750, 50))

    number = "UDYAM-MH-03-{}".format(random.randint(1000000, 9999999))
    pen.text((250, 350), "UDYAM REGISTRATION NUMBER: " + number, fill="black", font=get_font(24, True))
    sheet.save(filename)

def generate_fssai(filename):
    """Write a synthetic FSSAI licence to *filename*.

    The licence number is 14 digits: '1', a value in 10-20, then 11 random digits.
    """
    page = Image.new('RGB', (1000, 1400), (255, 255, 255))
    pen = ImageDraw.Draw(page)
    pen.text((300, 100), "Food Safety and Standards Authority of India", fill="darkblue", font=get_font(24, True))

    head = random.randint(10, 20)
    tail = random.randint(10000000000, 99999999999)
    licence_no = "1" + str(head) + str(tail)
    pen.text((350, 200), "Lic No: " + licence_no, fill="black", font=get_font(36, True))
    page.save(filename)

# --- Batch 2 Generators ---
def generate_ekarmika(filename):
    """Write a synthetic Karnataka Labour Department registration to *filename*.

    Two bold black headings, then a red registration number KA/BNG/<5 digits>.
    """
    page = Image.new('RGB', (1000, 1400), (255, 255, 255))
    pen = ImageDraw.Draw(page)

    headings = [
        ((350, 50), "GOVERNMENT OF KARNATAKA", 28),
        ((400, 90), "DEPARTMENT OF LABOUR", 24),
    ]
    for xy, caption, points in headings:
        pen.text(xy, caption, fill="black", font=get_font(points, True))

    number = "KA/BNG/" + str(random.randint(10000, 99999))
    pen.text((100, 250), "Registration No: " + number, fill="red", font=get_font(26, True))
    page.save(filename)

def generate_drug_license(filename):
    """Write a synthetic Form 20 drug-sale licence (light green page) to *filename*.

    The page has no ID number, only a title, a heading and a purple footer label.
    """
    page = Image.new('RGB', (1000, 1400), (240, 255, 240))
    pen = ImageDraw.Draw(page)
    for xy, caption, colour, points in (
        ((400, 50), "FORM 20", "black", 32),
        ((200, 150), "LICENCE TO SELL, STOCK OR EXHIBIT DRUGS", "black", 24),
        ((750, 1080), "DRUG CONTROL", "purple", 18),
    ):
        pen.text(xy, caption, fill=colour, font=get_font(points, True))
    page.save(filename)

def generate_iec(filename):
    """Write a synthetic Importer-Exporter Code certificate with a random 10-digit IEC to *filename*."""
    page = Image.new('RGB', (1000, 1400), (255, 255, 255))
    pen = ImageDraw.Draw(page)
    pen.text((300, 250), "IMPORTER-EXPORTER CODE CERTIFICATE", fill="darkblue", font=get_font(26, True))

    code = str(random.randint(1000000000, 9999999999))
    pen.text((100, 350), "IEC Number: " + code, fill="black", font=get_font(30, True))
    page.save(filename)

def generate_ptec(filename):
    """Write a synthetic Form II enrolment certificate to *filename*.

    The certificate number is '99' + 8 random digits + 'P'.
    """
    page = Image.new('RGB', (1000, 1400), (255, 255, 255))
    pen = ImageDraw.Draw(page)
    pen.text((350, 50), "FORM II", fill="black", font=get_font(32, True))

    enrolment = "99{}P".format(random.randint(10000000, 99999999))
    pen.text((100, 250), "Enrollment Certificate No: " + enrolment, fill="black", font=get_font(24, True))
    page.save(filename)

def generate_tan(filename):
    """Write a synthetic TAN allotment letter to *filename*.

    The TAN is 'MUM' + 'A' or 'B' + 5 digits + 'C', drawn inside a box.
    """
    page = Image.new('RGB', (1000, 1400), (255, 255, 255))
    pen = ImageDraw.Draw(page)
    pen.text((400, 100), "TAN ALLOTMENT LETTER", fill="black", font=get_font(24, True))

    series = random.choice(['A', 'B'])
    digits = random.randint(10000, 99999)
    tan = "MUM" + series + str(digits) + "C"

    # Box around the TAN value.
    pen.rectangle([300, 200, 700, 300], outline="black", width=2)
    pen.text((350, 240), "TAN: " + tan, fill="black", font=get_font(36, True))
    page.save(filename)

# --- Batch 3 Generators ---
def generate_trade_license_wb(filename):
    """Write a synthetic Kolkata Municipal Corporation enlistment certificate to *filename*.

    The 10-digit certificate number is printed on its own, without a label.
    """
    page = Image.new('RGB', (1000, 1400), (255, 255, 255))
    pen = ImageDraw.Draw(page)
    pen.text((300, 50), "THE KOLKATA MUNICIPAL CORPORATION", fill="darkblue", font=get_font(26, True))
    pen.text((350, 150), "CERTIFICATE OF ENLISTMENT", fill="red", font=get_font(28, True))

    enlistment_no = str(random.randint(1000000000, 9999999999))
    pen.text((600, 270), enlistment_no, fill="black", font=get_font(24, True))
    page.save(filename)

def generate_partnership_deed(filename):
    """Write a synthetic partnership deed first page (light green) to *filename*.

    A green border frames the "INDIA NON JUDICIAL" header area; the deed
    title sits below it. No ID number is drawn.
    """
    page = Image.new('RGB', (1000, 1400), (240, 255, 240))
    pen = ImageDraw.Draw(page)
    pen.rectangle([50, 50, 950, 300], outline="green", width=5)
    pen.text((400, 80), "INDIA NON JUDICIAL", fill="black", font=get_font(30, True))
    pen.text((350, 350), "DEED OF PARTNERSHIP", fill="black", font=get_font(32, True))
    page.save(filename)

# (generator, class folder, filename prefix), in the order files are produced
# within each iteration.
_GENERATION_PLAN = [
    (generate_gst, "GST", "gst"),
    (generate_coi, "COI", "coi"),
    (generate_gumasta, "GUMASTA", "gumasta"),
    (generate_udyam, "UDYAM", "udyam"),
    (generate_fssai, "FSSAI", "fssai"),
    (generate_ekarmika, "EKARMIKA", "eka"),
    (generate_drug_license, "DRUG_LICENSE", "dl"),
    (generate_iec, "IEC", "iec"),
    (generate_ptec, "PTEC", "ptec"),
    (generate_tan, "TAN", "tan"),
    (generate_trade_license_wb, "TRADE_LICENSE_WB", "tl"),
    (generate_partnership_deed, "PARTNERSHIP_DEED", "deed"),
]

print("üöÄ Generating Dataset for 12 Classes...")
for i in range(25):
    for make_document, folder, prefix in _GENERATION_PLAN:
        make_document(f"dataset/{folder}/{prefix}_{i}.jpg")
print("‚úÖ Dataset Generation Complete.")

# ==============================================================================
# 🔄 STEP 3 & 4: Processing & Training (Simplified)
# ==============================================================================
import torch
import pytesseract
from transformers import LayoutLMv3Processor, LayoutLMv3ForSequenceClassification, TrainingArguments, Trainer, default_data_collator
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split

# 1. Prepare Data
# Class name -> integer label, taken from each name's position in `classes`.
label2id = dict(zip(classes, range(len(classes))))
# OCR is run separately with Tesseract (get_ocr), so the processor's own OCR is
# disabled and words/boxes are passed in explicitly.
processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)

def get_ocr(image):
    """Run Tesseract on *image* and return ``(words, boxes)`` for LayoutLMv3.

    ``boxes`` holds one ``[x0, y0, x1, y1]`` per word. Each box is scaled to
    the 0-1000 coordinate range and clamped to it.

    If Tesseract finds no text, a single placeholder word ``"empty"`` covering
    the whole page is returned, so the processor always receives some input.
    """
    width, height = image.size
    # Output.DICT keeps the text column as strings. With Output.DATAFRAME,
    # pandas infers a float dtype for blank or all-numeric pages, which breaks
    # the ``.str`` accessor and can turn "123" into "123.0".
    data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)

    def scale(value, extent):
        return max(0, min(1000, int(float(value) / extent * 1000)))

    words, boxes = [], []
    for text, left, top, box_w, box_h in zip(
        data.get("text", []),
        data.get("left", []),
        data.get("top", []),
        data.get("width", []),
        data.get("height", []),
    ):
        word = str(text).strip()
        if not word:
            continue  # page/block/line rows carry no text
        left, top = float(left), float(top)
        words.append(word)
        boxes.append([
            scale(left, width),
            scale(top, height),
            scale(left + float(box_w), width),
            scale(top + float(box_h), height),
        ])

    if not words:
        words, boxes = ["empty"], [[0, 0, 1000, 1000]]
    return words, boxes

class DocDataset(Dataset):
    """Torch dataset that OCRs document images and encodes them for LayoutLMv3.

    Args:
        paths: Image file paths.
        labels: Integer class ids, aligned index-for-index with *paths*.

    OCR results are cached per index. Tesseract is the expensive step, and
    the Trainer revisits samples on every epoch.
    """

    def __init__(self, paths, labels):
        if len(paths) != len(labels):
            raise ValueError(f"paths ({len(paths)}) and labels ({len(labels)}) differ in length")
        self.paths, self.labels = paths, labels
        self._ocr_cache = {}  # index -> (words, boxes)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        # The context manager closes the file handle; convert() returns an
        # independent in-memory copy.
        with Image.open(self.paths[i]) as src:
            img = src.convert("RGB")
        if i not in self._ocr_cache:
            self._ocr_cache[i] = get_ocr(img)
        words, boxes = self._ocr_cache[i]
        enc = processor(img, words, boxes=boxes, truncation=True, padding="max_length", max_length=512, return_tensors="pt")
        # Remove only the leading batch dimension added by return_tensors="pt";
        # the data collator re-batches the samples.
        enc = {k: v.squeeze(0) for k, v in enc.items()}
        enc['labels'] = torch.tensor(self.labels[i], dtype=torch.long)
        return enc

# Collect every generated image together with its integer label. Listings are
# sorted so the resulting order (and hence the split) does not depend on the
# filesystem's directory order.
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")
all_files, all_labels = [], []
for c in classes:
    path = f"dataset/{c}"
    fs = sorted(
        os.path.join(path, f)
        for f in os.listdir(path)
        if f.lower().endswith(_IMAGE_EXTENSIONS)
    )
    all_files.extend(fs)
    all_labels.extend([label2id[c]] * len(fs))

# Stratify so every class appears proportionally in both splits, and fix the
# seed so the split is reproducible between runs.
train_f, test_f, train_l, test_l = train_test_split(
    all_files, all_labels, test_size=0.2, stratify=all_labels, random_state=42
)
train_ds = DocDataset(train_f, train_l)
test_ds = DocDataset(test_f, test_l)

# 2. Train Model
os.environ["WANDB_DISABLED"] = "true"
model = LayoutLMv3ForSequenceClassification.from_pretrained(
    "microsoft/layoutlmv3-base",
    num_labels=len(classes),
    # Store class names in the config so the saved model is self-describing.
    id2label={idx: name for idx, name in enumerate(classes)},
    label2id=label2id,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

training_args = TrainingArguments(
    output_dir="./results_12class", max_steps=150, per_device_train_batch_size=4,
    learning_rate=5e-5, remove_unused_columns=False, report_to="none",
    # The default logging_steps (500) is larger than max_steps, so without
    # this no training loss would ever be logged.
    logging_steps=10,
)


def _compute_accuracy(eval_pred):
    """Return top-1 accuracy for a Trainer EvalPrediction."""
    logits = eval_pred.predictions
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = np.argmax(logits, axis=-1)
    return {"accuracy": float((preds == eval_pred.label_ids).mean())}


trainer = Trainer(
    model=model, args=training_args, train_dataset=train_ds, eval_dataset=test_ds,
    processing_class=processor, data_collator=default_data_collator,
    compute_metrics=_compute_accuracy,
)

print("üöÄ Starting Training (12 Classes)...")
trainer.train()
# Report performance on the held-out split.
print(trainer.evaluate())
model.save_pretrained("./saved_12class_model")
processor.save_pretrained("./saved_12class_model")
print("‚úÖ Training Complete & Model Saved.")



‚¨áÔ∏è Installing System Tools (Tesseract + Poppler)...
‚úÖ Dependencies Installed.
üöÄ Generating Dataset for 12 Classes...
‚úÖ Dataset Generation Complete.


preprocessor_config.json:   0%|          | 0.00/275 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/856 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/501M [00:00<?, ?B/s]

Some weights of LayoutLMv3ForSequenceClassification were not initialized from the model checkpoint at microsoft/layoutlmv3-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


üöÄ Starting Training (12 Classes)...




Step,Training Loss




‚úÖ Training Complete & Model Saved.


In [None]:
# ==============================================================================
# 🧠 STEP 6: Unified Pipeline (Strict Validation)
# ==============================================================================
import re
from google.colab import files
from pdf2image import convert_from_path

class DocumentAI:
    """Classify a document image into one of the 12 classes and extract its ID.

    Classification first tries keyword rules (``_heuristic_check``). Only when
    no rule fires is the fine-tuned LayoutLMv3 model consulted. The ID number
    is then pulled from the OCR text with a per-type regex (``_extract``).

    Args:
        model_path: Directory containing the saved model and processor.
    """

    # GSTIN: 2 digits, 5 letters, 4 digits, 1 letter, 1 char from 1-9/A-Z,
    # literal 'Z', 1 alphanumeric.
    _GSTIN = r"[0-9]{2}[A-Z]{5}[0-9]{4}[A-Z][1-9A-Z]Z[0-9A-Z]"

    def __init__(self, model_path="./saved_12class_model"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = LayoutLMv3ForSequenceClassification.from_pretrained(model_path).to(self.device).eval()
        self.processor = LayoutLMv3Processor.from_pretrained(model_path, apply_ocr=False)
        # Model output index -> class name; relies on `classes` having the same
        # order that was used to build the training labels.
        self.id2label = {i: c for i, c in enumerate(classes)}

    def _heuristic_check(self, text):
        """Return a class name if a keyword rule matches *text*, else None.

        Rules are tried in order. Short keywords are matched as whole words, so
        "CIN" inside "MEDICINE" or "IEC" inside "PIECE" does not fire.
        """
        t = text.upper()

        def has(pattern):
            return re.search(pattern, t) is not None

        if "GST" in t and "REG-06" in t:
            return "GST"
        if "CORPORATE IDENTITY NUMBER" in t or has(r"\bCIN\b"):
            # A CIN appearing on a GST document does not make it a COI.
            if "GST" not in t:
                return "COI"
        if has(r"\bUDYAM\b") and "REGISTRATION" in t:
            return "UDYAM"
        if has(r"\bFSSAI\b") or "FOOD SAFETY AND STANDARDS AUTHORITY" in t:
            return "FSSAI"
        # OCR may keep the quotes around the form letter, e.g. FORM 'F'.
        if "GUMASTA" in t or (has(r"\bFORM\s*[\"'\u2018\u2019]?F(?![A-Z0-9])") and "ESTABLISHMENTS" in t):
            return "GUMASTA"
        if "KARNATAKA" in t and has(r"\bFORM\s*[\"'\u2018\u2019]?C(?![A-Z0-9])"):
            return "EKARMIKA"
        if has(r"\bFORM\s*20(?!\d)") or has(r"\bDRUGS?\b"):
            return "DRUG_LICENSE"
        if "IMPORTER-EXPORTER" in t or has(r"\bIEC\b"):
            return "IEC"
        if "PROFESSION TAX" in t or has(r"\bFORM\s*II\b"):
            return "PTEC"
        if has(r"\bTAN\b") and "DEDUCTION" in t:
            return "TAN"
        if "KOLKATA MUNICIPAL" in t and "ENLISTMENT" in t:
            return "TRADE_LICENSE_WB"
        if "DEED OF PARTNERSHIP" in t:
            return "PARTNERSHIP_DEED"
        return None

    def _extract(self, doc_type, text):
        """Extract the ID number for *doc_type* from OCR *text*.

        Returns:
            ``{"type": doc_type, "id_number": <str>}``. ``id_number`` is
            ``"Not Found"`` when no pattern matches or no extractor exists for
            the type (DRUG_LICENSE, PARTNERSHIP_DEED).
        """
        data = {"type": doc_type, "id_number": "Not Found"}
        match = None

        if doc_type == "GST":
            # Prefer a GSTIN right after its label; otherwise any bare GSTIN.
            match = (
                re.search(rf"(?:Number|GSTIN)[\s:\-\.]*({self._GSTIN})", text, re.IGNORECASE)
                or re.search(rf"\b({self._GSTIN})\b", text)
            )

        elif doc_type == "COI":
            match = re.search(r"(?:CIN|Identity\s*Number).*?([LU][0-9]{5}[A-Z]{2}[0-9]{4}[A-Z]{3}[0-9]{6})", text, re.IGNORECASE | re.DOTALL)

        elif doc_type == "UDYAM":
            match = re.search(r"(UDYAM-[A-Z]{2}-\d{2}-\d{7})", text, re.IGNORECASE)

        elif doc_type == "FSSAI":
            match = re.search(r"(?:License|Lic).*?([0-9]{14})", text, re.IGNORECASE)

        elif doc_type in ("GUMASTA", "EKARMIKA"):
            match = re.search(r"(?:Registration\s*No)[\s:\-\.]*([A-Z0-9/]{5,25})", text, re.IGNORECASE)

        elif doc_type == "IEC":
            match = re.search(r"(?:IEC\s*Number|Code).*?([0-9]{10})", text, re.IGNORECASE)

        elif doc_type == "PTEC":
            match = re.search(r"(?:Enrol+ment\s*Certificate\s*No|PTEC)[\s:\-\.]*([0-9]{10,11}P)", text, re.IGNORECASE)

        elif doc_type == "TAN":
            match = re.search(r"([A-Z]{4}[0-9]{5}[A-Z])", text)

        elif doc_type == "TRADE_LICENSE_WB":
            match = re.search(r"(?:CE\s*No|Enlistment).*?([0-9]{10,15})", text, re.IGNORECASE)

        # Every pattern above exposes the ID as capture group 1.
        if match:
            data["id_number"] = match.group(1)
        return data

    def analyze(self, image_path):
        """Classify the image at *image_path* and extract its ID number.

        Returns:
            dict with keys ``Type`` (class name), ``Confidence`` (display
            string), ``Status`` (``"VALID"`` if an ID was extracted, else
            ``"REVIEW_REQUIRED"``) and ``Data`` (the ``_extract`` result).
        """
        # Close the file handle; convert() yields an in-memory copy.
        with Image.open(image_path) as src:
            img = src.convert("RGB")
        words, boxes = get_ocr(img)
        full_text = " ".join(words)

        # 1. Heuristics
        doc_type = self._heuristic_check(full_text)
        conf = "100% (Rule-Based)"

        # 2. AI model fallback, using the processor saved alongside the model.
        if not doc_type:
            enc = self.processor(img, words, boxes=boxes, truncation=True, padding="max_length", max_length=512, return_tensors="pt")
            with torch.no_grad():
                logits = self.model(
                    input_ids=enc.input_ids.to(self.device),
                    bbox=enc.bbox.to(self.device),
                    pixel_values=enc.pixel_values.to(self.device),
                    attention_mask=enc.attention_mask.to(self.device),
                ).logits
            doc_type = self.id2label[logits.argmax(-1).item()]
            conf = f"{torch.softmax(logits, dim=-1).max().item():.2%} (AI)"

        # 3. Extraction
        data = self._extract(doc_type, full_text)
        status = "VALID" if data["id_number"] != "Not Found" else "REVIEW_REQUIRED"

        return {"Type": doc_type, "Confidence": conf, "Status": status, "Data": data}

# ==============================================================================
# 🚀 TEST EXECUTION
# ==============================================================================
pipeline = DocumentAI()

print("\nUpload a document to test (Any of the 12 types):")
uploaded = files.upload()

# files.upload() returns a dict keyed by filename; analyse every uploaded file.
for fname in uploaded:
    print(f"Processing {fname}...")
    if fname.lower().endswith(".pdf"):
        # Only the first page is analysed, so render just that page.
        convert_from_path(fname, first_page=1, last_page=1)[0].save("test.jpg")
        res = pipeline.analyze("test.jpg")
    else:
        res = pipeline.analyze(fname)

    print("\n" + "="*40)
    print(f"üìÑ Type:   {res['Type']}")
    print(f"Confidence: {res['Confidence']}")
    print(f"‚ö†Ô∏è Status: {res['Status']}")
    print(f"üìÇ Data:   {res['Data']}")
    print("="*40)


Upload a document to test (Any of the 12 types):


Saving SPICE + Part B_Approval Letter_AB7922695.pdf to SPICE + Part B_Approval Letter_AB7922695.pdf
Processing SPICE + Part B_Approval Letter_AB7922695.pdf...

üìÑ Type:   COI
‚ö†Ô∏è Status: VALID
üìÇ Data:   {'type': 'COI', 'id_number': 'U63111RJ2025PTC107336'}


In [None]:
import shutil
from google.colab import files

# 1. Zip the dataset folder.
# shutil.make_archive(base_name, format, root_dir) returns the path of the
# archive it wrote (here OCRdataset.zip in the working directory).
print("üóúÔ∏è Zipping dataset...")
archive_path = shutil.make_archive('OCRdataset', 'zip', 'dataset')

# 2. Download the zip file, reporting the name of the archive actually created.
print(f"‚¨áÔ∏è Downloading {os.path.basename(archive_path)}...")
files.download(archive_path)

üóúÔ∏è Zipping dataset...
‚¨áÔ∏è Downloading gst_coi_dataset.zip...


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>