<a href="https://colab.research.google.com/github/abhimshr08/OCR/blob/main/OCR_Main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import openai
import pandas as pd
import re
import base64
import requests
!pip install pdf2image
!apt-get install -y poppler-utils
from pdf2image import convert_from_path
from PIL import Image
from IPython.display import display, HTML, Audio
from google.colab import files
from getpass import getpass
from tqdm.notebook import tqdm
from termcolor import colored
import os
import traceback
import json
import ast

def cprint(msg, state='info'):
    """Print *msg* in colour, prefixed with an icon that reflects *state*.

    Args:
        msg: Text to print; it is interpolated into an f-string, so any
            object with a string form is accepted.
        state: One of 'info', 'success', 'error', 'debug' or 'warn'. Any
            other value falls back to the information icon and cyan.
    """
    # The icons are spelled as Unicode escapes so this source stays pure
    # ASCII; the previous literal emoji had been corrupted into mojibake by
    # a wrong-encoding round trip.
    icons = {
        'info': '\U0001F537',     # large blue diamond
        'success': '\u2705',      # white heavy check mark
        'error': '\u274C',        # cross mark
        'debug': '\U0001F440',    # eyes
        'warn': '\u26A0\uFE0F',   # warning sign, emoji presentation
    }
    colors = {
        'info': 'cyan',
        'success': 'green',
        'error': 'red',
        'debug': 'magenta',
        'warn': 'yellow',
    }
    fallback_icon = '\u2139\uFE0F'  # information source, emoji presentation
    print(colored(f"{icons.get(state, fallback_icon)} {msg}", colors.get(state, 'cyan')))

# Interactive setup: ask for the input CSV, then for the OpenAI API key.
cprint("Please upload your CSV (1st column should have Google Drive links):", 'info')
uploaded = files.upload()
# The first key of the upload result is used as the CSV path below. When the
# result is empty (presumably because the dialog was cancelled), fail with a
# readable message instead of a bare StopIteration from next().
if not uploaded:
    raise RuntimeError("No file was uploaded; re-run this cell and choose a CSV file.")
csv_path = next(iter(uploaded))
cprint("Now, please enter your OpenAI API Key (input box below is HIDDEN):", 'info')
# getpass() keeps the key out of the notebook output.
openai.api_key = getpass()

def extract_drive_file_id(url):
    """Return the file id embedded in a Google Drive style link, or None.

    Two link shapes are recognised, tried in this order:
      * a path segment ``/d/<id>`` (e.g. ``.../file/d/<id>/view``)
      * a query parameter ``id=<id>`` (e.g. ``.../open?id=<id>`` or
        ``.../uc?export=download&id=<id>``)

    Args:
        url: The link text. Anything that is not a ``str`` yields None.

    Returns:
        The id (letters, digits, ``_`` and ``-``) or None when no id is found.
    """
    if not isinstance(url, str):
        return None
    # A separate '/open\?id=' pattern is unnecessary: every such link is
    # already matched by the plain 'id=' pattern.
    for pattern in (r'/d/([\w-]+)', r'id=([\w-]+)'):
        match = re.search(pattern, url)
        if match:
            return match.group(1)
    return None

def _stream_to_file(response, dest, chunk_size=64 * 1024):
    """Write the body of a streamed *response* to the file *dest* in chunks."""
    with open(dest, "wb") as f:
        for chunk in response.iter_content(chunk_size):
            if chunk:  # skip empty chunks
                f.write(chunk)


def download_from_drive(file_id, dest, timeout=60):
    """Download a Google Drive file by id and save it to *dest*.

    The file is requested from ``https://drive.google.com/uc`` with
    ``export=download``. If the reply carries a ``Content-Disposition``
    header its body is saved directly. Otherwise the cookies are searched
    for one whose name starts with ``download_warning`` and the request is
    repeated with that cookie's value as the ``confirm`` parameter
    (presumably Drive's confirmation step for files it will not serve
    directly -- verify against current Drive behaviour).

    Args:
        file_id: Google Drive file id.
        dest: Local path to write the downloaded bytes to.
        timeout: Seconds passed to each HTTP request as its ``timeout``.

    Returns:
        *dest*, once the file has been written.

    Raises:
        requests.RequestException: On network failure, timeout or an HTTP
            error status.
        RuntimeError: When the reply has neither a ``Content-Disposition``
            header nor a ``download_warning`` cookie.
    """
    cprint(f"Downloading Google Drive file_id {file_id} ...", 'info')
    base_url = "https://drive.google.com/uc"
    params = {'export': 'download', 'id': file_id}

    # The context manager closes the session (and its pooled connections)
    # on every exit path.
    with requests.Session() as session:
        response = session.get(base_url, params=params, stream=True, timeout=timeout)
        response.raise_for_status()

        if 'Content-Disposition' not in response.headers:
            confirm_token = next(
                (value for key, value in response.cookies.items()
                 if key.startswith('download_warning')),
                None,
            )
            if confirm_token is None:
                raise RuntimeError("Failed to download from Google Drive.")
            response.close()
            response = session.get(
                base_url,
                params={**params, 'confirm': confirm_token},
                stream=True,
                timeout=timeout,
            )
            response.raise_for_status()

        _stream_to_file(response, dest)

    cprint(f"File saved as {dest}", 'debug')
    return dest

def pdf_to_images(pdf_path):
    """Render every page of a PDF to a PNG file and return the PNG paths.

    Each page is saved as ``/content/page_<index>_<pdf file name>.png`` where
    ``<index>`` starts at 0. Progress is reported through ``cprint``.

    Args:
        pdf_path: Path of the PDF to convert.

    Returns:
        List of the saved image paths, in page order.

    Raises:
        Exception: Any error from conversion or saving is logged with its
            traceback and then re-raised unchanged.
    """
    pdf_name = os.path.basename(pdf_path)
    cprint(f"Converting PDF '{pdf_name}' to images...", 'info')
    try:
        saved_paths = []
        for page_no, page in enumerate(convert_from_path(pdf_path), start=1):
            # File names use the zero-based index; the log line is one-based.
            target = f'/content/page_{page_no - 1}_{pdf_name}.png'
            page.save(target)
            cprint(f"Page {page_no}: {target}", 'debug')
            saved_paths.append(target)
        return saved_paths
    except Exception as err:
        cprint(f"PDF conversion failed: {err}", 'error')
        traceback.print_exc()
        raise

def image_to_base64(image_path):
    """Read a file and return its contents as base64 text.

    Args:
        image_path: Path of the file to encode.

    Returns:
        The base64 string, or None when reading or encoding fails (the
        failure is logged together with its traceback).
    """
    try:
        with open(image_path, "rb") as handle:
            raw_bytes = handle.read()
        encoded = base64.b64encode(raw_bytes)
        return encoded.decode()
    except Exception as err:
        cprint(f"Couldn't base64 encode image {image_path}: {err}", 'error')
        traceback.print_exc()
        return None

def robust_json_extract(text):
    """Extract the first JSON-like object from model output as a dict.

    Strategies, tried in order until one yields a dict:
      1. Strict JSON starting at the first ``{``; anything after the end of
         that object is ignored.
      2. The span from the first ``{`` to the last ``}`` with every single
         quote replaced by a double quote, parsed as JSON.
      3. That span parsed as a Python literal with ``ast.literal_eval``.
      4. The quote-replaced span parsed as a Python literal.

    Args:
        text: Raw model reply. Anything that is not a ``str`` yields ``{}``.

    Returns:
        The parsed dict, or ``{}`` when nothing parses to a dict. The result
        is always a dict, so callers can use ``.get`` / ``.values`` safely.
    """
    if not isinstance(text, str):
        return {}
    start = text.find('{')
    if start == -1:
        return {}

    # raw_decode stops at the end of the first complete JSON value, so a
    # reply such as '{"a": 1} (see {note})' still parses.
    try:
        parsed, _ = json.JSONDecoder().raw_decode(text, start)
    except (ValueError, RecursionError):
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    end = text.rfind('}')
    if end < start + 2:  # need at least one character between the braces
        return {}
    chunk = text[start:end + 1]
    fixed = chunk.replace("'", '"')

    attempts = (
        (json.loads, fixed),
        (ast.literal_eval, chunk),
        (ast.literal_eval, fixed),
    )
    for parse, candidate in attempts:
        try:
            parsed = parse(candidate)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            continue
        # literal_eval can yield a set for input like "{1, 2}"; only a dict
        # is a usable result.
        if isinstance(parsed, dict):
            return parsed
    return {}

def classify_document(image_path):
    """Ask the GPT vision model which document type an image shows.

    The image is base64-encoded and sent with a fixed classification prompt;
    the reply is parsed with ``robust_json_extract`` and its ``'type'`` value
    is returned.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        The stripped type label, or None when the image cannot be encoded,
        the reply has no usable ``'type'`` value, or the request fails (the
        failure is logged with its traceback).
    """
    encoded_image = image_to_base64(image_path)
    if not encoded_image:
        return None

    instructions = (
        "This is an Indian legal/official/scanned document image. "
        "Classify the document type as one of the following: "
        "['Aadhar Card', 'PAN Card', 'Driving License', 'Cancelled Cheque', "
        "'Education Certificate', 'Degree Certificate', 'Marksheet', 'GST Certificate', "
        "'Passport', 'Bank Statement', 'Passbook', 'Other']. "
        "If it has a large table with transactions, bank name/logo and account info, call it 'Bank Statement' or 'Passbook'. "
        "If it's a government tax certificate with GSTIN, it's a 'GST Certificate'. "
        "If it has marks, grades, university/school info, or says 'degree', it's an 'Education Certificate', 'Degree Certificate', or 'Marksheet'. "
        "Return strict JSON: {'type': '...'}."
    )

    try:
        user_message = {
            "role": "user",
            "content": [
                {"type": "text", "text": instructions},
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_image}"}},
            ],
        }
        completion = openai.chat.completions.create(
            model="gpt-4o",
            messages=[user_message],
            max_tokens=100,
        )
        raw_reply = completion.choices[0].message.content
        cprint(f"Classify raw: {raw_reply}", "debug")
        label = robust_json_extract(raw_reply).get('type', None)
        if label:
            return label.strip()
        return None
    except Exception as err:
        cprint(f"Classification error: {err}", 'error')
        traceback.print_exc()
        return None

def query_gpt_with_template(image_path, extraction_prompt):
    """Send an image plus an extraction prompt to GPT and parse the reply.

    Args:
        image_path: Path of the image file to send.
        extraction_prompt: Prompt text that tells the model which fields to
            return as JSON.

    Returns:
        Whatever ``robust_json_extract`` produces from the reply, or ``{}``
        when the image cannot be encoded or the request fails (the failure
        is logged with its traceback).
    """
    encoded_image = image_to_base64(image_path)
    if not encoded_image:
        return {}

    try:
        user_message = {
            "role": "user",
            "content": [
                {"type": "text", "text": extraction_prompt},
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_image}"}},
            ],
        }
        completion = openai.chat.completions.create(
            model="gpt-4o",
            messages=[user_message],
            max_tokens=400,
        )
        raw_reply = completion.choices[0].message.content
        cprint(f"OCR raw: {raw_reply}", "debug")
        return robust_json_extract(raw_reply)
    except Exception as err:
        cprint(f"OCR error: {err}", 'error')
        traceback.print_exc()
        return {}

# Extraction prompt for each document type.
#
# The keys are the same labels that classify_document() asks the model to
# choose from; the main loop looks up the classified label here and uses the
# "Other" entry when the label is not a key. Each prompt names the JSON
# fields to return, and those field names match entries in result_columns,
# which is how extracted values are copied into the output row.
PROMPT_TEMPLATES = {
    "Aadhar Card": (
        "From this Indian Aadhaar card, extract as JSON: "
        "{'name': ..., 'dob': ..., 'aadhar_number': ..., 'permanent_address': ...} "
        "(Aadhaar number is a 12-digit number, address is under 'Address'. If a field is not present, leave blank. Use dd-mm-yyyy date format.)"
    ),
    "PAN Card": (
        "From this Indian PAN card, extract as JSON: {'name': ..., 'dob': ..., 'pan_number': ...} "
        "(PAN is a 10-character alphanumeric code, use dd-mm-yyyy for dob)."
    ),
    "Driving License": (
        "From this Indian Driving License, extract as JSON: "
        "{'dl_number': ..., 'name': ..., 'dob': ..., 'address': ...} "
        "(Use dd-mm-yyyy for DOB. DL number is the driving license number.)"
    ),
    "Cancelled Cheque": (
        "From this Indian cancelled cheque or bank image, extract as JSON: "
        "{'bank_name': ..., 'account_holder_name': ..., 'account_number': ..., 'ifsc_code': ...} "
        "Use blank for any missing fields."
    ),
    "GST Certificate": (
        "From this Indian GST Certificate, extract as JSON: "
        "{'registration_number': ..., 'legal_name': ..., 'trade_name': ..., 'gst_validity': ...} "
        "(Registration number is GSTIN; trade name may be blank. GST validity is the registration date or period.)"
    ),
    "Passport": (
        "From this Indian Passport, extract as JSON: {'name': ..., 'dob': ...} (Use dd-mm-yyyy for dob.)"
    ),
    "Education Certificate": (
        "From this Indian Education Certificate, extract as JSON: "
        "{'degree_name': ..., 'learner_name': ..., 'institute_name': ..., 'year_of_passing': ...} "
        "(The degree name/type, learner's name, university/board/institute name, and passing year. Leave blank if not found.)"
    ),
    "Degree Certificate": (
        "From this Degree or Diploma certificate, extract as JSON: {'degree_name': ..., 'learner_name': ..., 'institute_name': ..., 'year_of_passing': ...} "
        "(Degree/qualification name, student's name, university/institute, year of passing. Leave blank if not found)."
    ),
    "Marksheet": (
        "From this Indian Marksheet, extract as JSON: {'learner_name': ..., 'institute_name': ..., 'exam_name': ..., 'year': ...} "
        "(Exam name can be 'CBSE', 'ICSE', subject etc; year is year of examination. Leave blank if not found.)"
    ),
    "Bank Statement": (
        "From this bank statement (may be first page), extract as JSON: "
        "{'account_holder_name': ..., 'account_number': ..., 'bank_name': ..., 'ifsc_code': ...} "
        "(These fields are usually in the header. Leave blank if not found.)"
    ),
    "Passbook": (
        "From this Indian bank passbook, extract as JSON: "
        "{'account_holder_name': ..., 'account_number': ..., 'bank_name': ..., 'ifsc_code': ...} "
        "(Top section or table header. Leave blank if not found.)"
    ),
    "Other": (
        "Extract all possible important details as JSON: "
        "{'name': ..., 'dob': ..., 'aadhar_number': ..., 'pan_number': ..., 'account_number': ..., 'registration_number': ..., 'degree_name': ..., 'learner_name': ...} "
        "If a field is not found, use blank."
    )
}

# --- Main processing ---
# For every link in the uploaded CSV: download the file, turn it into one or
# more images, classify it, extract fields with the matching prompt, and
# collect exactly one output row per input row. Results are written to a
# single combined CSV.
df = pd.read_csv(csv_path)
output = []

result_columns = [
    'link', 'document_type', 'error',
    'name', 'dob', 'aadhar_number', 'pan_number', 'permanent_address',
    'degree_name', 'learner_name', 'institute_name', 'year_of_passing',
    'exam_name', 'year',
    'dl_number', 'address',
    'bank_name', 'account_holder_name', 'account_number', 'ifsc_code',
    'registration_number', 'legal_name', 'trade_name', 'gst_validity'
]
# The first three columns are bookkeeping set by this script; only the
# remaining ones may be filled from the model's reply.
field_columns = result_columns[3:]


def _failed_row(link, message):
    """Return an all-None result row that carries only the link and an error."""
    failed = {col: None for col in result_columns}
    failed['link'] = link
    failed['error'] = message
    return failed


def _has_content(fields):
    """Return True when at least one value in *fields* is non-blank text."""
    return any(str(v).strip() for v in fields.values())


cprint(f"Processing {len(df)} files...", 'info')

# "\u23f3" is the hourglass emoji, written as an escape to keep the source ASCII.
for idx, row in tqdm(df.iterrows(), total=len(df), desc="\u23f3 Records"):
    cprint(f"\n==============\nRecord {idx+1}/{len(df)}", 'info')
    # Prefer a column literally named 'link'; otherwise use the first column.
    drive_url = row['link'] if 'link' in row.index else row.iloc[0]
    cprint(f"Google Drive link: {drive_url}", 'debug')

    file_id = extract_drive_file_id(str(drive_url))
    if not file_id:
        cprint("No file ID found in link!", 'error')
        output.append(_failed_row(drive_url, 'No file id found'))
        continue

    try:
        local_path = f"/content/file_{file_id}"
        download_from_drive(file_id, local_path)
        # Sniff the first bytes to tell a PDF from an image.
        with open(local_path, 'rb') as f:
            header = f.read(4)
        if header == b'%PDF':
            images = pdf_to_images(local_path)
        else:
            img_ext = os.path.splitext(str(drive_url))[1] or '.png'
            image_path = local_path + img_ext
            # os.replace overwrites a copy left behind by an earlier run, so
            # the file that was just downloaded is always the one processed.
            os.replace(local_path, image_path)
            images = [image_path]

        if not images:
            raise ValueError("No pages or images to process")

        # Walk the pages until one yields at least one non-blank field.
        extracted = {}
        doc_type_guess = None
        for page_idx, img in enumerate(images):
            cprint(f"Processing image/page {page_idx+1} of {len(images)}...", 'info')
            doc_type_guess = classify_document(img)
            cprint(f"Classified as: {doc_type_guess}", 'debug')
            use_type = doc_type_guess if doc_type_guess in PROMPT_TEMPLATES else "Other"
            result = query_gpt_with_template(img, PROMPT_TEMPLATES[use_type])
            if result and _has_content(result):
                extracted = result
                break

        # If nothing meaningful was extracted, fallback to generic extraction
        if not extracted:
            generic_prompt = (
                "From this document, extract as JSON: "
                "{'name': ..., 'dob': ..., 'aadhar_number': ..., 'pan_number': ..., 'account_number': ..., 'degree_name': ..., 'learner_name': ..., 'registration_number': ...}. "
                "If not found, use blank fields."
            )
            fallback = query_gpt_with_template(images[0], generic_prompt)
            extracted = fallback if _has_content(fallback) else {}

        row_data = {col: '' for col in result_columns}
        row_data['link'] = drive_url
        row_data['document_type'] = doc_type_guess
        row_data['error'] = None

        # Populate from extracted. Only data columns are copied so a reply
        # containing a 'link' or 'error' key cannot clobber the bookkeeping.
        for k in field_columns:
            if extracted.get(k) is not None:
                row_data[k] = extracted[k]

        # Log before appending: if serialisation raised after the append, the
        # except branch would add a second row for the same record.
        # default=str is for display only, covering values json cannot encode.
        cprint(f"RESULT \u2192 {json.dumps(row_data, indent=2, default=str)}", 'success')
        output.append(row_data)
    except Exception as e:
        # Batch boundary: one bad file is recorded and the loop carries on.
        cprint(f"File processing failed: {e}", 'error')
        traceback.print_exc()
        output.append(_failed_row(drive_url, str(e)))

result_df = pd.DataFrame(output, columns=result_columns)
result_path = '/content/extracted_combined_info.csv'
result_df.to_csv(result_path, index=False)

cprint("\n========== SUMMARY ==========", 'info')

def color_fail(val):
    """Return the CSS background for one result cell: red if empty, green if filled.

    A scalar counts as empty when it is null (None, NaN, ...) or its string
    form is '' or 'none' in any letter case. A non-scalar value (list, dict,
    ...) counts as empty when it has length zero; handling these separately
    avoids the ambiguous-truth-value error that ``pd.isnull`` on a sequence
    would trigger.

    Args:
        val: The cell value.

    Returns:
        A ``'background-color: <hex>'`` CSS declaration.
    """
    if pd.api.types.is_scalar(val):
        empty = bool(pd.isnull(val)) or str(val).lower() in ('none', '')
    else:
        try:
            empty = len(val) == 0
        except TypeError:
            # Unsized objects are treated as present.
            empty = False
    color = '#fbe7e7' if empty else '#bdf7b7'
    return f'background-color: {color}'

# Show the colour-coded results table, then offer the CSV for download.
# Styler.applymap was renamed to Styler.map in pandas 2.1 and the old name is
# deprecated, so use `map` when the installed pandas provides it.
styler = result_df.style
cellwise = getattr(styler, 'map', None) or styler.applymap
# NOTE(review): columns[2:] starts at 'error', so a successful row (error is
# None) gets a red 'error' cell -- confirm whether that column should be
# excluded from the colouring.
styled = cellwise(color_fail, subset=result_df.columns[2:]).set_caption("Extraction Results")
display(HTML(styled.to_html()))
files.download(result_path)
# "\U0001F389" is the party-popper emoji, written as an escape to keep the source ASCII.
cprint("Extraction completed \U0001F389. File is ready for download.", 'success')

# The completion sound is best-effort: a failure is reported, never fatal.
try:
    display(Audio(url="https://actions.google.com/sounds/v1/alarms/beep_short.ogg", autoplay=True))
except Exception as e:
    cprint(f"Could not play completion sound: {e}", 'warn')