In [1]:
import re
from transformers import DonutProcessor, VisionEncoderDecoderModel
from datasets import load_dataset
import torch
device = torch.device("cuda")
print("dependencies loaded")

  from .autonotebook import tqdm as notebook_tqdm


dependencies loaded


In [2]:
from transformers import AutoProcessor, AutoModelForVision2Seq

# processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
# model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")

# quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
processor = AutoProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
model = AutoModelForVision2Seq.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")

model.eval()
model = model.to(device)
print("processor and model loaded")

# task_prompt = "<s_cord-v2>"
# task_prompt = "<s_docvqa><s_question>Give 3 sets of 12 digit numbers</s_question><s_answer>" # Gives 8800 and 0101 separately but not full number
# task_prompt = "<s_docvqa><s_question>Give a number in XXXX XXXX XXXX format</s_question><s_answer>" # Gives full phone number
# task_prompt = "<s_docvqa><s_question>14 characters with 12 numbers and 2 blank space</s_question><s_answer>"
task_prompt = "<s_docvqa><s_question>What are the main numbers in this?</s_question><s_answer>"
decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids.to(device)
print(f"prompt added : {task_prompt}")



processor and model loaded
prompt added : <s_docvqa><s_question>What are the main numbers in this?</s_question><s_answer>


In [3]:
data_dir = r"Custom-Generated Synthetic Aadhar Card Dataset for Robust Identity Authentication Research\new_generated_aadharcard_images"
output_json = "aadhar_parsed_results.json"

In [None]:
import os
# Collect all image file paths
valid_exts = (".jpg", ".jpeg", ".png", ".tif", ".tiff")
image_paths = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.lower().endswith(valid_exts)]

# Full JSON extraction

In [None]:
import os
import json
from pathlib import Path
from PIL import Image
from tqdm.auto import tqdm
import torch

# Use DonutProcessor + VisionEncoderDecoderModel
from transformers import DonutProcessor, VisionEncoderDecoderModel

# --------------------------
# Config
# --------------------------
MODEL_ID = "sourinkarmakar/kyc_v1-donut-demo"
IMAGE_FOLDER = r"Custom-Generated Synthetic Aadhar Card Dataset for Robust Identity Authentication Research\new_generated_aadharcard_images"
NUM_TEST = 5   # process only first 5 images
MAX_LENGTH = 512

# --------------------------
# Device
# --------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# --------------------------
# Load model + processor
# --------------------------
print(f"Loading processor & model from: {MODEL_ID} ...")
processor = DonutProcessor.from_pretrained(MODEL_ID)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID)
model.to(device)
model.eval()
print("Model loaded.\n")

# --------------------------
# Prepare image list (first NUM_TEST)
# --------------------------
image_dir = Path(IMAGE_FOLDER)
if not image_dir.exists():
    raise FileNotFoundError(f"Image folder not found: {image_dir}")

image_files = sorted([p for p in image_dir.iterdir() if p.suffix.lower() in [".jpg", ".jpeg", ".png", ".tif", ".tiff"]])
if len(image_files) == 0:
    raise FileNotFoundError("No image files found in folder.")

image_files = image_files[:NUM_TEST]
print(f"Processing {len(image_files)} images (first {NUM_TEST}).\n")

# --------------------------
# Task prompt: try kyc-specific prompt, fallback to generic
# --------------------------
# This model was trained for KYC; its repo may use task-specific prompts.
# Try a KYC-style prompt first; if output doesn't parse, we will show raw text.
task_prompts_to_try = [ "<s_cord-v2>"]  # ordered guesses

# Precompute decoder input ids for each prompt (put on device)
decoder_input_map = {}
for tp in task_prompts_to_try:
    if tp == "":
        # empty prompt
        decoder_input_map[tp] = None
    else:
        decoder_input_map[tp] = processor.tokenizer(tp, add_special_tokens=False, return_tensors="pt").input_ids.to(device)

# --------------------------
# Inference loop
# --------------------------
for img_path in tqdm(image_files, desc="KYC Donut inference"):
    print("\n" + "=" * 80)
    print("Image:", img_path.name)
    print("=" * 80)
    try:
        image = Image.open(img_path).convert("RGB")
    except Exception as e:
        print(f"Failed to open image {img_path}: {e}")
        continue

    # Preprocess (pixel_values)
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

    success = False
    for task_prompt in task_prompts_to_try:
        try:
            decoder_input_ids = decoder_input_map[task_prompt]

            # Generate
            gen_kwargs = dict(
                pixel_values=pixel_values,
                max_length=MAX_LENGTH,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                use_cache=True,
                return_dict_in_generate=True,
                output_scores=True,
            )
            if decoder_input_ids is not None:
                gen_kwargs["decoder_input_ids"] = decoder_input_ids

            outputs = model.generate(**gen_kwargs)

            # Raw sequence (includes special tokens)
            raw_seq = processor.batch_decode(outputs.sequences, skip_special_tokens=False)[0]
            # Cleaned text (skip special tokens)
            cleaned = processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0]

            print(f"\n--- Attempt with prompt: '{task_prompt or '<empty>'}' ---")
            print("\nRaw sequence (with special tokens):")
            print(raw_seq)
            print("\nCleaned text (skip special tokens):")
            print(cleaned)

            # Try the convenient token2json if available on the processor
            parsed = None
            try:
                # Some processors expose token2json helper to decode Donut-style output to JSON
                parsed = processor.token2json(cleaned)
            except Exception:
                # fallback: try to locate JSON substring and json.loads
                try:
                    # find first { and last }
                    first = cleaned.find("{")
                    last = cleaned.rfind("}")
                    if first != -1 and last != -1 and last > first:
                        json_str = cleaned[first:last + 1]
                        parsed = json.loads(json_str)
                except Exception:
                    parsed = None

            if parsed is not None:
                print("\n--- Parsed JSON ---")
                print(json.dumps(parsed, indent=4, ensure_ascii=False))
                success = True
                break
            else:
                print("\n Could not parse a JSON object from this prompt's output. Showing raw/cleaned output above.")
                # continue trying other prompts

        except Exception as gen_e:
            print(f"\nGeneration failed for prompt '{task_prompt}': {gen_e}")
            # try next prompt

    if not success:
        print("\n No JSON parsed from any prompt for this image. Save raw cleaned text for manual inspection.")
        print("Final cleaned text (last attempt):")
        try:
            print(cleaned)
        except NameError:
            print("<no cleaned text available>")

    # Optional: token-level scores / logits (if outputs.scores exists)
    try:
        if hasattr(outputs, "scores") and outputs.scores is not None:
            print("\n--- Token-level top token & confidence per generation step ---")
            for i, step_logits in enumerate(outputs.scores):
                # compute softmax to get probabilities
                probs = torch.nn.functional.softmax(step_logits, dim=-1)
                top_idx = torch.argmax(probs, dim=-1).item()
                top_token = processor.tokenizer.decode([top_idx])
                top_prob = probs[0, top_idx].item() if probs.dim() == 2 else probs[top_idx].item()
                print(f"Step {i:03d}: token='{top_token}'  prob={top_prob:.4f}")
    except Exception:
        pass

print("\nDone.")


# QA model

In [None]:
import os
import json
from tqdm.auto import tqdm
from PIL import Image

print(f"Found {len(image_paths)} images. Running test inference on first 10 images...")

# Limit to first 5 images
test_images = image_paths[:10]

for img_path in tqdm(test_images, desc="Processing Aadhaar test images"):
    try:
        image = Image.open(img_path).convert("RGB")

        # Preprocess image
        pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

        # Generate output
        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=model.decoder.config.max_position_embeddings,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

        # Decode text sequence
        sequence = processor.batch_decode(outputs.sequences)[0]
        sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
        sequence = sequence.replace(task_prompt, "").strip()

        # Try to parse JSON output (Donut outputs structured text)
        try:
            parsed = json.loads(sequence)
        except json.JSONDecodeError:
            parsed = {"raw_text": sequence}

        print(f"\n {os.path.basename(img_path)}:")
        print(json.dumps(parsed, indent=4, ensure_ascii=False))

    except Exception as e:
        print(f"Error processing {img_path}: {e}")

print("\nTest inference completed for first 10 images.")


Found 1000 images. Running test inference on first 10 images...


Processing Aadhaar test images:  10%|█         | 1/10 [00:06<00:55,  6.17s/it]


 100backside_blurred.jpg:
{
    "raw_text": "<s_docvqa><s_question> What are the main numbers in this?</s_question><s_answer> 8800</s_answer>"
}


Processing Aadhaar test images:  20%|██        | 2/10 [00:09<00:34,  4.28s/it]


 100backside_contrast_adjusted.jpg:
{
    "raw_text": "<s_docvqa><s_question> What are the main numbers in this?</s_question><s_answer> 8800 6030 0101</s_answer>"
}


Processing Aadhaar test images:  30%|███       | 3/10 [00:12<00:27,  3.88s/it]


 100backside_hue_sat_adjusted.jpg:
{
    "raw_text": "<s_docvqa><s_question> What are the main numbers in this?</s_question><s_answer> 8800 6030 0101</s_answer>"
}


Processing Aadhaar test images:  40%|████      | 4/10 [00:15<00:22,  3.70s/it]


 100backside_scaled_down.jpg:
{
    "raw_text": "<s_docvqa><s_question> What are the main numbers in this?</s_question><s_answer> 8800 6030 0101</s_answer>"
}


Processing Aadhaar test images:  50%|█████     | 5/10 [00:19<00:17,  3.55s/it]


 100backside_scaled_up.jpg:
{
    "raw_text": "<s_docvqa><s_question> What are the main numbers in this?</s_question><s_answer> 8800 6030 0101</s_answer>"
}


Processing Aadhaar test images:  60%|██████    | 6/10 [00:22<00:14,  3.62s/it]


 100front_blurred.jpg:
{
    "raw_text": "<s_docvqa><s_question> What are the main numbers in this?</s_question><s_answer> 8800</s_answer>"
}


Processing Aadhaar test images:  70%|███████   | 7/10 [00:26<00:10,  3.58s/it]


 100front_contrast_adjusted.jpg:
{
    "raw_text": "<s_docvqa><s_question> What are the main numbers in this?</s_question><s_answer> 3Hetr</s_answer>"
}


Processing Aadhaar test images:  80%|████████  | 8/10 [00:29<00:07,  3.54s/it]


 100front_hue_sat_adjusted.jpg:
{
    "raw_text": "<s_docvqa><s_question> What are the main numbers in this?</s_question><s_answer> 3Hetr</s_answer>"
}


Processing Aadhaar test images:  90%|█████████ | 9/10 [00:33<00:03,  3.45s/it]


 100front_scaled_down.jpg:
{
    "raw_text": "<s_docvqa><s_question> What are the main numbers in this?</s_question><s_answer> 3Hetr</s_answer>"
}


Processing Aadhaar test images: 100%|██████████| 10/10 [00:36<00:00,  3.65s/it]


 100front_scaled_up.jpg:
{
    "raw_text": "<s_docvqa><s_question> What are the main numbers in this?</s_question><s_answer> 3Hetr</s_answer>"
}

✅ Test inference completed for first 10 images.





In [None]:
import json
from tqdm.autonotebook import tqdm
from PIL import Image


results = {}

print(f"Found {len(image_paths)} images. Starting inference...")

for img_path in tqdm(image_paths, desc="Processing Aadhaar images"):
    try:
        image = Image.open(img_path).convert("RGB")

        # Preprocess image
        pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

        # Generate output
        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=model.decoder.config.max_position_embeddings,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

        # Decode text sequence
        sequence = processor.batch_decode(outputs.sequences)[0]
        sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
        sequence = sequence.replace(task_prompt, "").strip()

        # Try to parse JSON output (Donut outputs structured text)
        try:
            parsed = json.loads(sequence)
        except json.JSONDecodeError:
            parsed = {"raw_text": sequence}

        results[os.path.basename(img_path)] = parsed

    except Exception as e:
        print(f"Error processing {img_path}: {e}")

#  Save results
with open(output_json, "w", encoding="utf-8") as f:
    json.dump(results, f, indent=4, ensure_ascii=False)

print(f"\n Inference completed. Results saved to: {output_json}")