In [1]:
from dotenv import load_dotenv
import os
import getpass


# Load environment variables from .env file (if one exists) into os.environ.
load_dotenv()

# Ask for the key interactively only when neither the shell nor .env supplied
# a non-empty value. The prompt is a fixed label: the original passed the
# (missing) key itself as the prompt, which displayed "None" to the user.
if not os.environ.get("GROQ_API_KEY"):
    os.environ["GROQ_API_KEY"] = getpass.getpass("Enter your Groq API key: ")

# Read the key *after* the prompt so groq_key is never left as None.
groq_key = os.environ["GROQ_API_KEY"]

In [None]:
# 1. Imports
import gradio as gr
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import arabic_reshaper
from bidi.algorithm import get_display
from surya.recognition import RecognitionPredictor
from surya.detection import DetectionPredictor
from surya.layout import LayoutPredictor
from langchain_groq import ChatGroq
from langchain_core.prompts import PromptTemplate

# 2. Initialize predictors
# Built once at cell execution and reused by process_image() for every upload,
# so the predictors are not re-created per request.
det_predictor = DetectionPredictor()
rec_predictor = RecognitionPredictor()
layout_predictor = LayoutPredictor()

# 3. Initialize Groq LLaMA-3
# No API key is passed explicitly here; presumably ChatGroq picks up the
# GROQ_API_KEY environment variable set in the first cell — verify against the
# langchain_groq documentation.
llm = ChatGroq(
    model="llama-3.1-8b-instant",
    temperature=0,      # 0 requests the least-random sampling
    max_tokens=None,    # no explicit cap on response length
    timeout=None,       # no explicit request timeout
    max_retries=2,
)

In [None]:
# 4. Global OCR context store
# Module-level dict holding the results of the most recent process_image()
# call: process_image() writes all three keys, answer_question() reads
# "text_output". Being a single shared dict, it only ever reflects the latest
# processed image.
ocr_context = {"text_output": "", "layout": None, "detection": None}

# 5. Process image function
def _load_label_font(size=16):
    """Return the font used to draw recognised text onto the detection image.

    Tries "arial.ttf" first; if Pillow cannot open it (OSError, e.g. the font
    is not installed on this machine) falls back to Pillow's built-in default
    font so OCR still completes instead of crashing.
    """
    try:
        return ImageFont.truetype("arial.ttf", size)
    except OSError:
        return ImageFont.load_default()


def process_image(image_pil):
    """Run layout analysis and OCR on an image and build annotated previews.

    Args:
        image_pil: PIL image to analyse, or None when no image is available.

    Returns:
        A tuple ``(layout_rgb, detection_rgb, text_output)``:
        - layout_rgb: RGB numpy array with layout polygons and
          "label (confidence)" tags drawn in green.
        - detection_rgb: RGB numpy array with text-line polygons and the
          recognised text drawn in red.
        - text_output: recognised lines joined with a trailing newline each.
        When ``image_pil`` is None the result is ``(None, None, "")``.

    Side effects:
        Overwrites the "text_output", "layout" and "detection" entries of the
        module-level ``ocr_context`` dict (reset to empty when the image is
        None so later questions are not answered from a stale document).
    """
    if image_pil is None:
        ocr_context["text_output"] = ""
        ocr_context["layout"] = None
        ocr_context["detection"] = None
        return None, None, ""

    # Normalise the mode so the RGB<->BGR conversions below always see three
    # channels (grayscale / palette images would otherwise fail in cvtColor).
    image_pil = image_pil.convert("RGB")

    # Convert to OpenCV format (BGR); each preview draws on its own copy.
    image_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)

    ### Layout Predictions ---
    layout_predictions = layout_predictor([image_pil])

    layout_image = image_cv.copy()
    for box in layout_predictions[0].bboxes:
        pts = np.array(box.polygon, np.int32).reshape((-1, 1, 2))
        cv2.polylines(layout_image, [pts], isClosed=True, color=(0, 255, 0), thickness=2)

        # Filled green tag just above the polygon's first vertex, black text on top.
        label_text = f"{box.label} ({box.confidence:.2f})"
        x, y = int(box.polygon[0][0]), int(box.polygon[0][1]) - 10
        (text_w, text_h), baseline = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        cv2.rectangle(layout_image, (x, y - text_h - baseline), (x + text_w, y + baseline), (0, 255, 0), -1)
        cv2.putText(layout_image, label_text, (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)

    ### Detection + Recognition Predictions ---
    predictions = rec_predictor([image_pil], det_predictor=det_predictor)
    text_lines = predictions[0].text_lines

    # Pass 1 (OpenCV): outline every recognised line in red (BGR order).
    detection_image = image_cv.copy()
    for line in text_lines:
        pts = np.array(line.polygon, np.int32).reshape((-1, 1, 2))
        cv2.polylines(detection_image, [pts], isClosed=True, color=(0, 0, 255), thickness=2)

    # Pass 2 (Pillow): cv2.putText cannot render Arabic glyphs, so the text is
    # drawn with Pillow. The image is converted and the font loaded ONCE here
    # rather than once per line, which previously cost two full-image colour
    # conversions and a font load for every text line.
    overlay = Image.fromarray(cv2.cvtColor(detection_image, cv2.COLOR_BGR2RGB))
    draw = ImageDraw.Draw(overlay)
    font = _load_label_font()
    for line in text_lines:
        # Reshape + reorder so right-to-left text displays correctly.
        bidi_text = get_display(arabic_reshaper.reshape(line.text))
        x, y = int(line.polygon[0][0]), int(line.polygon[0][1])
        draw.text((x, y - 25), bidi_text, font=font, fill=(255, 0, 0))

    text_output = "".join(f"{line.text}\n" for line in text_lines)

    ### Store OCR context for QA ---
    ocr_context["text_output"] = text_output
    ocr_context["layout"] = layout_predictions
    ocr_context["detection"] = predictions

    ### --- Convert images to RGB for Gradio display ---
    layout_rgb = cv2.cvtColor(layout_image, cv2.COLOR_BGR2RGB)
    detection_rgb = np.array(overlay)  # the Pillow overlay is already RGB

    return layout_rgb, detection_rgb, text_output

# 6. QA function using LLaMA-3
def answer_question(question):
    """Answer a question about the most recently processed document.

    Args:
        question: The user's question as a string (may be empty or None).

    Returns:
        The LLM's answer text, or a short instruction message when the
        question is blank or no OCR text has been stored yet. In those two
        cases the LLM is not called, avoiding a wasted API request and an
        answer made up from an empty document.
    """
    if not question or not question.strip():
        return "Please enter a question."

    # Text stored by process_image() for the latest uploaded image.
    ocr_text = ocr_context.get("text_output", "")
    if not ocr_text.strip():
        return "No OCR text is available yet. Please upload a document image first."

    template = """
You are a document QA assistant.

Here is the extracted OCR text from the uploaded document:

{ocr_text}

Question: {question}

Answer:
"""

    prompt = PromptTemplate(
        input_variables=["ocr_text", "question"],
        template=template,
    )

    final_prompt = prompt.format(
        ocr_text=ocr_text,
        question=question,
    )

    response = llm.invoke(final_prompt)
    return response.content

# 7. Gradio Interface
# Components are declared in the order they should appear on the page; the
# nesting of the `with` blocks defines the layout, so statement order matters.
with gr.Blocks() as iface:
    gr.Markdown("# 📄 OCR + LLaMA-3 QA Demo")
    gr.Markdown("Upload an image, view OCR outputs, and ask questions using LLaMA-3 via Groq.")

    # One row: the upload widget next to the two annotated previews.
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Document Image")
        layout_output = gr.Image(type="numpy", label="Layout Predictions")
        detection_output = gr.Image(type="numpy", label="Detection Predictions")

    text_output = gr.Textbox(label="OCR Text Output", lines=10)

    # Uploading an image runs process_image; its three return values fill the
    # two preview images and the OCR text box, in that order.
    image_input.upload(process_image, inputs=image_input, outputs=[layout_output, detection_output, text_output])

    gr.Markdown("### ❓ Ask a Question about the document")
    question_input = gr.Textbox(label="Your Question")
    answer_output = gr.Textbox(label="LLaMA-3 Answer")

    # Submitting the question box calls answer_question, which reads the OCR
    # text from the module-level ocr_context rather than from a UI component.
    question_input.submit(answer_question, inputs=question_input, outputs=answer_output)


# Start the app; share is left at its default, so only a local URL is served.
iface.launch()


* Running on local URL:  http://127.0.0.1:7863
* To create a public link, set `share=True` in `launch()`.




Recognizing layout: 100%|██████████| 1/1 [00:05<00:00,  5.33s/it]
Detecting bboxes: 100%|██████████| 1/1 [00:05<00:00,  5.69s/it]
Recognizing Text: 100%|██████████| 10/10 [00:32<00:00,  3.23s/it]
Recognizing layout: 100%|██████████| 1/1 [00:03<00:00,  3.92s/it]
Detecting bboxes: 100%|██████████| 1/1 [00:03<00:00,  3.82s/it]
Recognizing Text: 100%|██████████| 10/10 [00:32<00:00,  3.26s/it]
Recognizing layout: 100%|██████████| 1/1 [00:05<00:00,  5.16s/it]
Detecting bboxes: 100%|██████████| 1/1 [00:09<00:00,  9.14s/it]
Recognizing Text: 100%|██████████| 35/35 [00:48<00:00,  1.38s/it]
Recognizing layout: 100%|██████████| 1/1 [00:03<00:00,  3.42s/it]
Detecting bboxes: 100%|██████████| 1/1 [00:04<00:00,  4.07s/it]
Recognizing Text: 100%|██████████| 102/102 [01:52<00:00,  1.10s/it]
Recognizing layout: 100%|██████████| 1/1 [00:02<00:00,  2.82s/it]
Detecting bboxes: 100%|██████████| 1/1 [00:03<00:00,  3.46s/it]
Recognizing Text: 100%|██████████| 172/172 [04:08<00:00,  1.45s/it]
