In [None]:
!pip install qwen_vl_utils bitsandbytes accelerate

In [None]:
!pip install flash-attn --no-build-isolation

In [None]:
# Download and extract the IdNet "AZ" subset from Zenodo into ./IdNet_Data.
# Lines starting with "!" are IPython shell escapes, not Python statements.

# 1. Install aria2 (a multi-connection download utility)
!apt-get install -y -qq aria2

# 2. Download AZ.zip using 16 connections
# -x 16: allow up to 16 connections per server
# -s 16: split the download into 16 pieces fetched in parallel
# -o: output filename
print("Starting optimized download...")
!aria2c -x 16 -s 16 -o AZ.zip "https://zenodo.org/records/13852757/files/AZ.zip?download=1"

# 3. Extract into ./IdNet_Data. Later cells read /content/IdNet_Data/AZ/...,
#    which assumes the notebook's working directory is /content.
# NOTE(review): "!" commands do not raise on a non-zero exit status, so the
# prints below run even if aria2c or unzip failed; check their output.
print("Download complete. Extracting...")
!unzip -q AZ.zip -d ./IdNet_Data
print("Done!")

In [None]:
import os, shutil, random

# Working directories for the training data: ID images and their JSON labels.
os.makedirs('images', exist_ok=True)
print("Created 'images' directory.")

os.makedirs('labels', exist_ok=True)
print("Created 'labels' directory.")

# Flat copy of the regular files in each source directory (subdirectories are
# skipped). After the loop, source_dir/dest_dir/files_to_copy refer to the
# labels pair, exactly as when the two copies were written out separately.
for source_dir, dest_dir in (
    ('/content/IdNet_Data/AZ/positive', 'images'),
    ('/content/IdNet_Data/AZ/meta/basic', 'labels'),
):
    os.makedirs(dest_dir, exist_ok=True)

    files_to_copy = []
    for entry in os.listdir(source_dir):
        if os.path.isfile(os.path.join(source_dir, entry)):
            files_to_copy.append(entry)

    print(f"Copying {len(files_to_copy)} files from '{source_dir}' to '{dest_dir}'...")

    for filename in files_to_copy:
        source_path = os.path.join(source_dir, filename)
        destination_path = os.path.join(dest_dir, filename)
        shutil.copy(source_path, destination_path)

    print(f"Finished copying files to '{dest_dir}' directory.")




import os
import json

# Directory whose JSON label files are rewritten in place.
target_dir = 'labels'

# Raw metadata key -> key used by the extraction template and the evaluation.
key_mapping = {
    "first_name": "first_name",
    "last_name": "last_name",
    "address": "address",
    "birthday": "DOB",
    "gender": "SEX",
    "class": "CLASS",
    "issue_date": "ISS",
    "expire_date": "EXP",
    "height": "HGT",
    "weight": "WGT",
    "eye_color": "EYES",
    "license_number": "DLN"
}

print(f"Processing files in '{target_dir}' to rename fields and match template...")

files = [f for f in os.listdir(target_dir) if f.endswith('.json')]
count = 0

for filename in files:
    path = os.path.join(target_dir, filename)
    try:
        with open(path, 'r') as f:
            data = json.load(f)

        # Derive first/last name from a combined "name" when first_name is
        # absent: the first space-separated token is the first name and
        # everything after it is the last name.
        if "first_name" not in data and "name" in data:
            name_parts = data['name'].split(' ', 1)
            data['first_name'] = name_parts[0]
            data['last_name'] = name_parts[1] if len(name_parts) > 1 else ""

        # Keep exactly the template keys. Prefer the raw source key; fall back
        # to the target key (so an already-converted file survives a re-run);
        # otherwise store None.
        new_data = {
            target_key: (
                data[source_key] if source_key in data
                else data[target_key] if target_key in data
                else None
            )
            for source_key, target_key in key_mapping.items()
        }

        # Overwrite the label file with the remapped record.
        with open(path, 'w') as f:
            json.dump(new_data, f, indent=2)

        count += 1

    except Exception as e:
        print(f"Error processing {filename}: {e}")

print(f"Successfully updated {count} label files.")






# --- Hold out 10% of the images (and their labels) for evaluation ---
# NOTE: files are MOVED out of images/ and labels/. Re-running this cell moves
# a further 10% of whatever remains, so run it once per fresh copy of the data.

# Create 'ho_img' directory for holdout images
os.makedirs('ho_img', exist_ok=True)
print("Created 'ho_img' directory.")

# Create 'ho_label' directory for holdout labels
os.makedirs('ho_label', exist_ok=True)
print("Created 'ho_label' directory.")

image_dir = 'images'
label_dir = 'labels'
ho_img_dir = 'ho_img'
ho_label_dir = 'ho_label'

# os.listdir returns entries in arbitrary, filesystem-dependent order. Sort
# before sampling, otherwise random.seed(42) does not make the split reproducible.
all_image_files = sorted(
    f for f in os.listdir(image_dir) if os.path.isfile(os.path.join(image_dir, f))
)

# Calculate 10% for holdout
holdout_percentage = 0.10
num_holdout_files = int(len(all_image_files) * holdout_percentage)

# Randomly select files for holdout
random.seed(42)  # reproducible given the sorted input above
holdout_image_files = random.sample(all_image_files, num_holdout_files)

print(f"Total images: {len(all_image_files)}")
print(f"Number of holdout files to move: {num_holdout_files}")

# Move selected images and, when present, the label with the same base name.
moved_count = 0        # images moved
moved_label_count = 0  # labels moved alongside them
for img_filename in holdout_image_files:
    base_name, _ = os.path.splitext(img_filename)
    label_filename = f"{base_name}.json"

    source_img_path = os.path.join(image_dir, img_filename)
    dest_img_path = os.path.join(ho_img_dir, img_filename)

    source_label_path = os.path.join(label_dir, label_filename)
    dest_label_path = os.path.join(ho_label_dir, label_filename)

    if os.path.exists(source_img_path):
        shutil.move(source_img_path, dest_img_path)
        moved_count += 1
    else:
        print(f"Warning: Image file not found: {source_img_path}")

    if os.path.exists(source_label_path):
        shutil.move(source_label_path, dest_label_path)
        moved_label_count += 1
    else:
        print(f"Warning: Label file not found: {source_label_path} for image {img_filename}")

print(f"Moved {moved_count} images and {moved_label_count} labels to holdout directories.")

In [None]:
import torch
from transformers import AutoProcessor, AutoConfig

# Robust import for the model class handling potential version discrepancies
try:
    from transformers import AutoModelForVision2Seq
except ImportError:
    print("AutoModelForVision2Seq not found. Falling back to AutoModel.")
    from transformers import AutoModel as AutoModelForVision2Seq

model_name = "numind/NuExtract-2.0-4B"

# Check for GPU availability
if torch.cuda.is_available():
    print("GPU detected. Loading model in bfloat16 (optimized for L4)...")
    # L4 supports bfloat16 and Flash Attention 2
    # 4B params in bfloat16 = ~8GB VRAM. L4 has 22.5GB, so this fits comfortably.
    # No device_map is passed, so from_pretrained loads onto the CPU; move the
    # model to the GPU explicitly (Flash Attention 2 only runs on CUDA).
    model = AutoModelForVision2Seq.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        attn_implementation="flash_attention_2",  # Enable Flash Attention 2 for L4
    ).to("cuda")
else:
    print("No GPU detected. Loading model in CPU mode (float32).")
    # Flash Attention 2 is CUDA-only, so use the default attention on CPU.
    model = AutoModelForVision2Seq.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )

print(f"Loading processor for {model_name}...")
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

print("Model and Processor loaded successfully.")

In [None]:
import sys
import os
import json
import re
import torch
import warnings
import subprocess
from tqdm import tqdm
from PIL import Image
from dateutil import parser  # For smart date comparison

# Silence UserWarnings raised from transformers modules (the `module` argument
# is a regex matched against the start of the module name).
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")

# Ensure qwen-vl-utils is installed: on ImportError, pip-install it into the
# running interpreter and retry the import.
try:
    from qwen_vl_utils import process_vision_info
except ImportError:
    print("Installing qwen-vl-utils...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "qwen-vl-utils"])
    from qwen_vl_utils import process_vision_info

from transformers import AutoProcessor, AutoModelForCausalLM

# Use the most specific model class this transformers version provides:
# Qwen2_5_VLForConditionalGeneration, else AutoModelForVision2Seq, else
# AutoModelForCausalLM.
try:
    from transformers import Qwen2_5_VLForConditionalGeneration
    ModelClass = Qwen2_5_VLForConditionalGeneration
except ImportError:
    try:
        from transformers import AutoModelForVision2Seq
        ModelClass = AutoModelForVision2Seq
    except ImportError:
        ModelClass = AutoModelForCausalLM

model_name = "numind/NuExtract-2.0-4B"

# Load Model
# This rebinds the global `model`, replacing any model loaded by an earlier cell.
print(f"Loading model: {model_name}...")

# FORCE GPU LOADING
# GPU: bfloat16 + Flash Attention 2, moved to CUDA explicitly (no device_map).
# CPU: float32 with the default attention implementation.
if torch.cuda.is_available():
    print("GPU detected. Loading model explicitly on CUDA device...")
    model = ModelClass.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        attn_implementation="flash_attention_2"
    ).to("cuda")
else:
    model = ModelClass.from_pretrained(
        model_name,
        device_map="cpu",
        torch_dtype=torch.float32,
        trust_remote_code=True
    )

# Processor with default settings; single-image inference below does not
# depend on the padding side.
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
print("Model loaded.")

# --- Extraction template: one entry per ground-truth label field ---
# Every field is requested as a verbatim string. Key order is preserved in the
# JSON text; the evaluation derives its field order from this template.
template = json.dumps(
    dict.fromkeys(
        (
            "first_name", "last_name", "address", "DOB", "SEX", "CLASS",
            "ISS", "EXP", "HGT", "WGT", "EYES", "DLN",
        ),
        "verbatim-string",
    )
)

# Inference Function
def predict_nu_extract(image_path, model, processor, template):
    """Run NuExtract on one image and return the raw generated text.

    Args:
        image_path: Path of the image; handed to ``process_vision_info``.
        model: Vision-language model exposing ``generate`` and ``device``.
        processor: Matching processor (chat template, tokenizer, image processor).
        template: JSON string listing the fields to extract.

    Returns:
        The decoded completion with the prompt tokens removed. It is expected to
        be JSON but is not validated here.
    """
    # NOTE(review): the text already contains a vision placeholder and the
    # message also carries an image item; confirm the chat template does not
    # emit a second placeholder for the same image.
    vision_tokens = "<|vision_start|><|image_pad|><|vision_end|>"
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": image_path},
            {"type": "text", "text": f"# Template:\n{template}\n# Context:\n{vision_tokens}"},
        ],
    }
    conversation = [user_turn]

    chat_text = processor.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    images, videos = process_vision_info(conversation)
    model_inputs = processor(
        text=[chat_text],
        images=images,
        videos=videos,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        output_ids = model.generate(**model_inputs, max_new_tokens=1024)

    # generate() echoes the prompt; keep only the newly generated tokens.
    completions = [
        generated[len(prompt):]
        for prompt, generated in zip(model_inputs.input_ids, output_ids)
    ]
    decoded = processor.batch_decode(
        completions, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return decoded[0]

def normalize_text(text):
    """Lowercase *text* and strip punctuation for a formatting-insensitive comparison.

    Every character that is neither a word character nor whitespace is removed,
    then the result is lowercased and trimmed. Falsy input yields "".
    """
    if text:
        return re.sub(r'[^\w\s]', '', text).lower().strip()
    return ""

def normalize_date(text):
    """Normalize a date to ``YYYY-MM-DD`` so differently formatted dates compare equal.

    Uses ``dateutil.parser.parse(..., fuzzy=True)``, so e.g. "2020-01-31" and
    "01/31/2020" both become "2020-01-31". Components missing from the input
    are filled from dateutil's default (the current date), which is the same
    for ground truth and prediction within a run.

    Falls back to ``normalize_text`` when the value cannot be parsed.
    Falsy input yields "".
    """
    if not text:
        return ""
    try:
        return parser.parse(str(text), fuzzy=True).strftime("%Y-%m-%d")
    except Exception:
        # Best-effort: any parse failure (dateutil raises ParserError, a
        # ValueError, or OverflowError) falls back to text comparison. Unlike
        # the former bare ``except:``, this lets KeyboardInterrupt/SystemExit
        # propagate so a long evaluation loop can still be interrupted.
        return normalize_text(str(text))

# Evaluation Logic
ho_img_dir = "ho_img"
ho_label_dir = "ho_label"
log_filename = "evaluation_log_exp.txt"

# Initialize Metrics
template_keys = list(json.loads(template).keys())
field_stats = {key: {'correct': 0, 'total': 0} for key in template_keys}

# os.listdir order is arbitrary; sort so the subset and log order are reproducible.
image_files = sorted(f for f in os.listdir(ho_img_dir) if f.endswith(('.png', '.jpg', '.jpeg')))

# --- Run on full holdout set ---
# Upper bound on the number of images evaluated.
subset_size = 597
image_files_subset = image_files[:subset_size]
evaluated_count = 0  # images that had a label and were scored
error_count = 0      # of those, predictions that failed or were not a JSON object

print(f"Starting evaluation on {len(image_files_subset)} images...\n")

with open(log_filename, "w") as log_file:
    log_file.write("EVALUATION DETAILED LOG (Smart Date & Text Comparison)\n")
    log_file.write("="*60 + "\n\n")

    for img_filename in tqdm(image_files_subset):
        img_path = os.path.join(ho_img_dir, img_filename)
        label_filename = os.path.splitext(img_filename)[0] + ".json"
        label_path = os.path.join(ho_label_dir, label_filename)

        if not os.path.exists(label_path):
            continue

        # Load Ground Truth
        with open(label_path, 'r') as f:
            gt_data = json.load(f)

        # Run Prediction. A failure is still scored as an empty prediction, but
        # it is now counted and written to the log instead of vanishing.
        prediction_error = None
        try:
            prediction_str = predict_nu_extract(img_path, model, processor, template)
            pred_data = json.loads(prediction_str)
        except Exception as e:
            pred_data = {}
            prediction_error = repr(e)
        if not isinstance(pred_data, dict):
            # Valid JSON that is not an object (list, string, number) would
            # crash on .get() below.
            prediction_error = f"expected a JSON object, got {type(pred_data).__name__}"
            pred_data = {}
        if prediction_error is not None:
            error_count += 1
        evaluated_count += 1

        # Write details to log
        log_file.write(f"--- Image: {img_filename} ---\n")
        if prediction_error is not None:
            log_file.write(f"PREDICTION ERROR: {prediction_error}\n")
        log_file.write("PREDICTED:\n")
        log_file.write(json.dumps(pred_data, indent=2) + "\n")
        log_file.write("GROUND TRUTH:\n")
        log_file.write(json.dumps(gt_data, indent=2) + "\n")

        # Compare Fields
        mismatches = []
        for key in template_keys:
            gt_val = gt_data.get(key, None)
            pred_val = pred_data.get(key, None)

            # None / missing compares as the empty string.
            gt_str = str(gt_val) if gt_val is not None else ""
            pred_str = str(pred_val) if pred_val is not None else ""

            # Use Date Normalization for date fields, Text for others
            if key in ["DOB", "ISS", "EXP"]:
                gt_norm = normalize_date(gt_str)
                pred_norm = normalize_date(pred_str)
            else:
                gt_norm = normalize_text(gt_str)
                pred_norm = normalize_text(pred_str)

            # Update Stats
            if gt_norm == pred_norm:
                field_stats[key]['correct'] += 1
            else:
                mismatches.append(f"{key}: '{gt_str}' != '{pred_str}'")

            field_stats[key]['total'] += 1

        if mismatches:
            log_file.write("MISMATCHES:\n" + "\n".join(mismatches) + "\n")
        else:
            log_file.write("RESULT: Perfect Match\n")

        log_file.write("\n" + "-"*40 + "\n\n")

    # --- Generate Report ---
    report_lines = []
    report_lines.append("\n" + "="*40)
    # Report the number of images actually scored, not the configured cap.
    report_lines.append(f"EVALUATION RESULTS (Subset: {evaluated_count} images)")
    report_lines.append("="*40)

    report_lines.append(f"{ 'FIELD':<20} | { 'ACCURACY':<10} | { 'CORRECT':<8} / { 'TOTAL'}")
    report_lines.append("-"*50)

    for key in template_keys:
        stats = field_stats[key]
        acc = (stats['correct'] / stats['total']) * 100 if stats['total'] > 0 else 0.0
        report_lines.append(f"{key:<20} | {acc:>9.2f}% | {stats['correct']:<8} / {stats['total']}")

    report_lines.append("-"*50)
    if error_count:
        report_lines.append(f"Prediction errors (scored as empty predictions): {error_count}")
    final_report = "\n".join(report_lines)

    print(final_report)
    log_file.write(final_report)

print(f"\nLog saved to {log_filename}")

In [None]:
import re

log_file_path = '/content/evaluation_log_exp.txt'

total_correct = 0
total_count = 0

try:
    with open(log_file_path, 'r') as f:
        lines = f.readlines()

    print("Parsing evaluation log for total accuracy...")

    # Summary-table rows end with "| <correct> / <total>", e.g. "| 551      / 597".
    # Pooling the counts over all fields gives the aggregate accuracy.
    pattern = re.compile(r'\|\s+(\d+)\s+/\s+(\d+)\s*$')
    row_counts = [
        (int(m.group(1)), int(m.group(2)))
        for m in map(pattern.search, lines)
        if m
    ]
    total_correct += sum(correct for correct, _ in row_counts)
    total_count += sum(total for _, total in row_counts)

    if not total_count > 0:
        print("No accuracy data found in the log file.")
    else:
        overall_accuracy = (total_correct / total_count) * 100
        rule = "=" * 40
        print(rule)
        print("TOTAL AGGREGATE ACCURACY")
        print(rule)
        print(f"Total Fields:  {total_count}")
        print(f"Total Correct: {total_correct}")
        print(f"Accuracy:      {overall_accuracy:.2f}%")
        print(rule)

except FileNotFoundError:
    print(f"Error: File '{log_file_path}' not found. Please ensure the evaluation step completed.")

In [None]:
!pip install -U trl peft

In [None]:
# --- VERSION CHECK & ROBUST IMPORT ---
try:
    from transformers import AutoModelForVision2Seq
except ImportError:
    print("⚠️ AutoModelForVision2Seq not found. Falling back to AutoModelForImageTextToText...")
    from transformers import AutoModelForImageTextToText as AutoModelForVision2Seq

import os
import json
import torch
from PIL import Image
from transformers import AutoProcessor
from qwen_vl_utils import process_vision_info
from trl import SFTConfig, SFTTrainer
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# --- Configuration ---
MODEL_ID = "numind/NuExtract-2.0-4B"
# Training data. Assuming the cells ran in order, the holdout cell has already
# moved ~10% of the pairs into ho_img/ho_label, so those are not trained on.
DATA_DIR = "./labels"
IMAGE_DIR = "./images"
# Trainer checkpoints and the final saved model/processor go here.
OUTPUT_DIR = "./nuextract_id_finetune"

# --- 1. Load Model & Processor ---
print("Loading Model...")
# 4-bit quantized base weights with bfloat16 compute, placed by device_map="auto".
# NOTE(review): quantization_config is a plain dict here, not a
# BitsAndBytesConfig; confirm the installed transformers version accepts it.
model = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
    # use_cache=False,  <-- REMOVED: Caused the TypeError
    quantization_config={"load_in_4bit": True, "bnb_4bit_compute_dtype": torch.bfloat16}
)

# Disable cache in config instead (gradient checkpointing is enabled for
# training below, and the KV cache is not needed while training).
model.config.use_cache = False

# padding_side='right' matters: collate_fn masks each prompt by its
# attention-mask length counted from the start of the row, which is only
# correct when padding is on the right.
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    padding_side='right',
    use_fast=True,
)

# Apply LoRA Adapter: rank 16 on the projection layers matched by these
# module names.
peft_config = LoraConfig(
    r=16, lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
    bias="none"
)
# PEFT helper for training on a k-bit quantized base, applied before the LoRA
# adapters are attached.
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# --- 2. Message Formatting ---
def construct_messages(image_path, template, label_str):
    """Build one supervised NuExtract training example as a chat conversation.

    The image is stored by path and decoded lazily by
    ``qwen_vl_utils.process_vision_info`` in the collator (the same way
    ``predict_nu_extract`` passes a path). Before, every training image was
    decoded to an RGB PIL image up front, so the whole dataset was held in RAM.
    The image processor still receives RGB-convertible images.
    Consequence: an unreadable image now fails when its batch is collated,
    not while the dataset is built.

    Args:
        image_path: Path to the ID image.
        template: JSON extraction template shown in the user turn.
        label_str: Target JSON string for the assistant turn.

    Returns:
        A two-message list: user (template text + image), then assistant (label).
    """
    # NuExtract 2.0 specific placeholder format
    image_placeholder = "<|vision_start|><|image_pad|><|vision_end|>"
    text_content = f"# Template:\n{template}\n# Context:\n{image_placeholder}"

    return [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": text_content},
                {"type": "image", "image": image_path},
            ],
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": label_str}],
        },
    ]

# --- 3. Prepare Dataset ---
print("Scanning Data...")
id_card_template = json.dumps({
    "first_name": "verbatim-string", "last_name": "verbatim-string",
    "address": "verbatim-string", "DOB": "verbatim-string",
    "SEX": "verbatim-string", "CLASS": "verbatim-string",
    "ISS": "verbatim-string", "EXP": "verbatim-string",
    "HGT": "verbatim-string", "WGT": "verbatim-string",
    "EYES": "verbatim-string", "DLN": "verbatim-string"
}, indent=None)

formatted_dataset = []
skipped_samples = []  # (label file, error) pairs that could not be loaded
# Sorted so the dataset order (and hence seeded shuffling) is reproducible.
files = sorted(f for f in os.listdir(DATA_DIR) if f.endswith(".json"))

for json_file in files:
    base_name = os.path.splitext(json_file)[0]
    json_path = os.path.join(DATA_DIR, json_file)
    # Pair with the first existing image that has a supported extension.
    for ext in (".png", ".jpg", ".jpeg"):
        possible = os.path.join(IMAGE_DIR, base_name + ext)
        if not os.path.exists(possible):
            continue
        try:
            with open(json_path, 'r') as f:
                target_str = json.dumps(json.load(f), indent=None)
            msg = construct_messages(os.path.abspath(possible), id_card_template, target_str)
        except Exception as e:
            # Best-effort: skip a broken label/image, but record why instead
            # of silently dropping it (the former bare `except: pass`).
            skipped_samples.append((json_file, repr(e)))
        else:
            formatted_dataset.append(msg)
        break

print(f"Loaded {len(formatted_dataset)} valid samples.")
if skipped_samples:
    print(f"Skipped {len(skipped_samples)} samples; first few: {skipped_samples[:5]}")

# --- 4. Collator ---
def collate_fn(examples):
    """Tokenize a batch of conversations and build prompt-masked LM labels.

    Args:
        examples: List of two-message conversations (user, assistant) as built
            by ``construct_messages``.

    Returns:
        The processor's batch for the full conversations, plus a ``labels``
        tensor in which prompt and padding positions are set to -100.
    """
    user_texts = [processor.apply_chat_template(x[:1], tokenize=False) for x in examples]
    full_texts = [processor.apply_chat_template(x, tokenize=False) for x in examples]
    image_inputs = process_vision_info(examples)[0]

    # The prompt-only batch is tokenized just to measure each prompt's length
    # after the image placeholder has been expanded into vision tokens.
    user_batch = processor(text=user_texts, images=image_inputs, return_tensors="pt", padding=True)
    full_batch = processor(text=full_texts, images=image_inputs, return_tensors="pt", padding=True)

    # Create Labels with Masking
    labels = full_batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    # Relies on right padding (set when the processor is loaded), so each
    # prompt occupies the first `user_len` positions of its row.
    user_lens = user_batch["attention_mask"].sum(dim=1).tolist()
    for i, user_len in enumerate(user_lens):
        # Labels are aligned with input_ids and the model shifts them
        # internally, so mask all `user_len` prompt positions. Masking only
        # `user_len - 1` left the last prompt token as a training target.
        labels[i, :user_len] = -100

    full_batch["labels"] = labels
    return full_batch

# --- 5. Start Training ---
print("Starting Training...")
# NOTE(review): presumably redundant with gradient_checkpointing=True in the
# config below, which the Trainer also applies (with its own kwargs).
model.gradient_checkpointing_enable()

training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    num_train_epochs=3,
    # 2 samples per device step x 4 accumulation steps = 8 samples per optimizer step.
    per_device_train_batch_size=2, # L4 Optimized
    gradient_accumulation_steps=4,
    learning_rate=1e-5,
    lr_scheduler_type="constant",
    logging_steps=10,
    save_strategy="epoch",
    bf16=True,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    # The dataset is a plain list of chat messages and collate_fn does all the
    # tokenization, so SFTTrainer's own dataset preparation is skipped.
    dataset_kwargs={"skip_prepare_dataset": True},
    report_to="none"
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=formatted_dataset,
    processing_class=processor.tokenizer,
)

trainer.train()
# The evaluation cell later loads this directory with PeftModel.from_pretrained,
# i.e. it expects the LoRA adapter to be saved here.
trainer.save_model(OUTPUT_DIR)
processor.save_pretrained(OUTPUT_DIR)
print("Training Complete. Model saved.")

In [None]:
import torch
import gc

# 1. Drop the global references that keep the model/trainer alive.
# Each name is removed independently. With a single `try: del a; del b; ...`,
# the first undefined name raises NameError and every later `del` is skipped.
# For example, `inputs` is only ever created inside functions in the cells
# shown, so `processor` was never freed.
# (Add any other variable names you defined in your previous run.)
for _name in ("model", "trainer", "inputs", "processor"):
    globals().pop(_name, None)
del _name

# 2. Force Python's Garbage Collector to release memory
gc.collect()

# 3. Clear PyTorch's internal CUDA cache
torch.cuda.empty_cache()

# 4. Verify memory is cleared
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
print(f"Reserved:  {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

In [None]:
import sys
import os
import json
import re
import torch
import warnings
from tqdm import tqdm
from PIL import Image
from dateutil import parser

# --- 1. Imports (Robust) ---
try:
    from transformers import AutoModelForVision2Seq
except ImportError:
    from transformers import AutoModelForImageTextToText as AutoModelForVision2Seq

from transformers import AutoProcessor
from peft import PeftModel
from qwen_vl_utils import process_vision_info

# --- 2. Configuration ---
# Only need to reload paths if variables were lost, but safe to redefine
BASE_MODEL_ID = "numind/NuExtract-2.0-4B"
# LoRA adapter directory written by the training cell (its OUTPUT_DIR).
ADAPTER_PATH = "./nuextract_id_finetune"
HO_IMG_DIR = "ho_img"
HO_LABEL_DIR = "ho_label"
BATCH_SIZE = 16  # FAST BATCH PROCESSING (images per generate() call)

# --- 3. Load Model (Only if not already loaded) ---
# NOTE(review): any existing global `model` is reused as-is. For example, if
# the cleanup cell did not remove it, this would be the training cell's PEFT
# model (use_cache disabled), with the processor from that cell (configured
# for right padding). Restart or run the cleanup cell to evaluate a freshly
# loaded adapter. Also assumes `processor` exists whenever `model` does.
if 'model' not in globals():
    print("Loading Model & Adapter...")
    model = AutoModelForVision2Seq.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        attn_implementation="flash_attention_2",
        device_map="auto",
        quantization_config={"load_in_4bit": True, "bnb_4bit_compute_dtype": torch.bfloat16}
    )
    # Attach the trained LoRA adapter to the 4-bit base model.
    model = PeftModel.from_pretrained(model, ADAPTER_PATH)
    model.eval()
    # Processor comes from the base model id, not from ADAPTER_PATH (where the
    # training cell also saved one).
    processor = AutoProcessor.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
else:
    print("Model already loaded. Skipping reload.")

print(f"Using Batch Size: {BATCH_SIZE}")

# --- 4. Batched Inference Function ---
def batch_predict(image_paths, model, processor, template):
    """Run NuExtract on several images in a single ``generate`` call.

    Args:
        image_paths: Paths of the images to extract from.
        model: Vision-language model exposing ``generate`` and ``device``.
        processor: Matching processor (chat template, tokenizer, image processor).
        template: JSON extraction template string.

    Returns:
        One decoded completion (prompt removed) per image, in input order.
    """
    image_placeholder = "<|vision_start|><|image_pad|><|vision_end|>"
    prompt_text = f"# Template:\n{template}\n# Context:\n{image_placeholder}"

    # Decode each image to RGB and close its file handle right away.
    loaded_images = []
    for path in image_paths:
        with Image.open(path) as img:
            loaded_images.append(img.convert("RGB"))

    batch_messages = [
        [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt_text},
                    {"type": "image", "image": rgb_image},
                ],
            }
        ]
        for rgb_image in loaded_images
    ]

    # Process inputs
    texts = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
    image_inputs, video_inputs = process_vision_info(batch_messages)

    # Batched generation with a decoder-only model needs LEFT padding. With
    # right padding, shorter prompts end in pad tokens and generation continues
    # after them. Prompt lengths can differ within a batch (presumably because
    # images of different sizes expand to different numbers of vision tokens).
    # Restore the tokenizer's setting afterwards so other users of this
    # processor (e.g. a right-padded training collator) are unaffected.
    tokenizer = processor.tokenizer
    original_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        inputs = processor(
            text=texts,
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(model.device)
    finally:
        tokenizer.padding_side = original_padding_side

    # Generate
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=1024)

    # Every row has the same padded prompt width, so slicing off len(in_ids)
    # leaves only the newly generated tokens.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    return processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

# --- 5. Helpers ---
def normalize_text(text):
    """Return *text* lowercased, trimmed and without punctuation ("" if falsy)."""
    if not text:
        return ""
    without_punctuation = re.sub(r'[^\w\s]', '', text)
    return without_punctuation.lower().strip()

def normalize_date(text):
    """Normalize a date to ``YYYY-MM-DD`` so differently formatted dates compare equal.

    Uses ``dateutil.parser.parse(..., fuzzy=True)``. Components missing from the
    input are filled from dateutil's default (the current date). Falls back to
    ``normalize_text`` when the value cannot be parsed. Falsy input yields "".
    """
    if not text:
        return ""
    try:
        return parser.parse(str(text), fuzzy=True).strftime("%Y-%m-%d")
    except Exception:
        # Best-effort fallback to text comparison. Unlike the former bare
        # ``except:``, this does not swallow KeyboardInterrupt/SystemExit.
        return normalize_text(str(text))

# --- 6. Main Evaluation Loop (Detailed Logging) ---
template = json.dumps({
    "first_name": "verbatim-string", "last_name": "verbatim-string",
    "address": "verbatim-string", "DOB": "verbatim-string",
    "SEX": "verbatim-string", "CLASS": "verbatim-string",
    "ISS": "verbatim-string", "EXP": "verbatim-string",
    "HGT": "verbatim-string", "WGT": "verbatim-string",
    "EYES": "verbatim-string", "DLN": "verbatim-string"
})

template_keys = list(json.loads(template).keys())
field_stats = {key: {'correct': 0, 'total': 0} for key in template_keys}
# os.listdir order is arbitrary; sort so batches and log order are reproducible.
image_files = sorted(f for f in os.listdir(HO_IMG_DIR) if f.endswith(('.png', '.jpg', '.jpeg')))


def _as_field_string(value):
    """Render a label/prediction field for comparison: None or missing -> "".

    This matches the baseline evaluation, where null values also compare as "".
    Previously `str(d.get(key, ""))` turned a null ground-truth field into
    "None", so it never matched an absent prediction.
    """
    return "" if value is None else str(value)


# Create output log
log_filename = "finetuned_eval_log_detailed.txt"
with open(log_filename, "w") as log_file:
    log_file.write(f"EVALUATION DETAILED LOG (Batch Size {BATCH_SIZE})\n" + "="*60 + "\n\n")

# Process ALL images
subset = image_files
batches = [subset[i:i + BATCH_SIZE] for i in range(0, len(subset), BATCH_SIZE)]
evaluated_count = 0  # images with a label that were scored
failed_batches = 0   # batches whose generation raised

print(f"Starting evaluation on {len(subset)} images ({len(batches)} batches)...")

for batch_files in tqdm(batches):
    batch_img_paths = []
    batch_gt_data = []
    valid_files = []

    # Pre-load labels to ensure we only process valid pairs
    for img_filename in batch_files:
        label_filename = os.path.splitext(img_filename)[0] + ".json"
        label_path = os.path.join(HO_LABEL_DIR, label_filename)

        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                batch_gt_data.append(json.load(f))
            batch_img_paths.append(os.path.join(HO_IMG_DIR, img_filename))
            valid_files.append(img_filename)

    if not batch_img_paths:
        continue

    # A failed batch is scored as empty predictions, as a failed prediction is
    # in the baseline evaluation. Dropping it would silently shrink the
    # denominator and inflate accuracy.
    batch_error = None
    try:
        batch_predictions = batch_predict(batch_img_paths, model, processor, template)
    except Exception as e:
        failed_batches += 1
        batch_error = repr(e)
        print(f"Batch Failed: {e}")
        batch_predictions = [""] * len(batch_img_paths)

    # Process & Log Results
    with open(log_filename, "a") as log_file:
        for img_name, gt_data, pred_str in zip(valid_files, batch_gt_data, batch_predictions):
            try:
                pred_data = json.loads(pred_str)
            except (ValueError, TypeError):
                pred_data = {}
            # Valid JSON that is not an object (list, string, number) has no .get().
            if not isinstance(pred_data, dict):
                pred_data = {}
            evaluated_count += 1

            log_file.write(f"--- Image: {img_name} ---\n")
            if batch_error is not None:
                log_file.write(f"BATCH FAILED: {batch_error}\n")
            log_file.write("PREDICTED:\n")
            log_file.write(json.dumps(pred_data, indent=2) + "\n")
            log_file.write("GROUND TRUTH:\n")
            log_file.write(json.dumps(gt_data, indent=2) + "\n")

            mismatches = []
            for key in template_keys:
                gt_val = _as_field_string(gt_data.get(key))
                pred_val = _as_field_string(pred_data.get(key))

                # Smart Comparison
                if key in ["DOB", "ISS", "EXP"]:
                    match = normalize_date(gt_val) == normalize_date(pred_val)
                else:
                    match = normalize_text(gt_val) == normalize_text(pred_val)

                if match:
                    field_stats[key]['correct'] += 1
                else:
                    mismatches.append(f"{key}: '{gt_val}' != '{pred_val}'")
                field_stats[key]['total'] += 1

            if mismatches:
                log_file.write("MISMATCHES:\n" + "\n".join(mismatches) + "\n")
            else:
                log_file.write("RESULT: Perfect Match\n")

            log_file.write("\n" + "-"*40 + "\n\n")

# Final Report
table_header = f"{ 'FIELD':<20} | { 'ACCURACY':<10} | { 'CORRECT':<8} / { 'TOTAL'}"
print("\n" + "="*50)
print(f"EVALUATION RESULTS (Total: {evaluated_count} images)")
print("="*50)
print(table_header)
print("-" * 50)

with open(log_filename, "a") as log_file:
    log_file.write("\n" + "="*50 + "\nEVALUATION RESULTS\n" + "="*50 + "\n")
    log_file.write(table_header + "\n")
    log_file.write("-" * 50 + "\n")

    for key in template_keys:
        stats = field_stats[key]
        acc = (stats['correct'] / stats['total']) * 100 if stats['total'] > 0 else 0.0

        # Build the line
        line = f"{key:<20} | {acc:>9.2f}% | {stats['correct']:<8} / {stats['total']}"

        print(line)
        log_file.write(line + "\n")

    log_file.write("-" * 50 + "\n")

if failed_batches:
    print(f"Warning: {failed_batches} batch(es) failed; their images were scored as empty predictions.")

print(f"\nDetailed log saved to {log_filename}")

In [None]:
import shutil
from google.colab import files

# Package the fine-tuned output directory as a zip and send it to the browser.
export_dir = "/content/nuextract_id_finetune"

print("Zipping model folder... (This may take a minute)")
# make_archive appends ".zip" to the base name and returns the archive path,
# here "/content/nuextract_id_finetune.zip".
archive_path = shutil.make_archive(export_dir, 'zip', export_dir)
print("Zip created!")

print("Downloading...")
files.download(archive_path)

In [None]:
import re

log_file_path = '/content/finetuned_eval_log_detailed.txt'

total_correct = 0
total_count = 0

try:
    with open(log_file_path, 'r') as f:
        lines = f.readlines()
except FileNotFoundError:
    print(f"Error: File '{log_file_path}' not found. Please ensure the evaluation step completed.")
else:
    print("Parsing evaluation log for total accuracy...")

    # Summary-table rows end with "| <correct> / <total>", e.g. "| 551      / 597".
    pattern = re.compile(r'\|\s+(\d+)\s+/\s+(\d+)\s*$')

    for match in filter(None, map(pattern.search, lines)):
        correct, total = (int(group) for group in match.groups())
        total_correct += correct
        total_count += total

    if total_count == 0:
        print("No accuracy data found in the log file.")
    else:
        overall_accuracy = (total_correct / total_count) * 100
        separator = "=" * 40
        for summary_line in (
            separator,
            "TOTAL AGGREGATE ACCURACY",
            separator,
            f"Total Fields:  {total_count}",
            f"Total Correct: {total_correct}",
            f"Accuracy:      {overall_accuracy:.2f}%",
            separator,
        ):
            print(summary_line)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os
import re

def plot_mismatches(log_file_path, image_dir, max_plots=10):
    """
    Parse a detailed evaluation log and plot the images that had extraction errors.

    Log format:
      - Each image entry starts with a line ``--- Image: <filename> ---`` and
        ends with a line of 40 dashes.
      - Entries with errors contain a ``MISMATCHES:`` section whose lines look
        like ``KEY: 'ground truth' != 'prediction'``.

    Args:
        log_file_path: Path to the evaluation log.
        image_dir: Directory containing the images named in the log.
        max_plots: Maximum number of figures to draw.
    """
    if not os.path.exists(log_file_path):
        print(f"Error: Log file not found at {log_file_path}")
        return

    with open(log_file_path, 'r') as f:
        content = f.read()

    # re.split with one capturing group yields
    # [preamble, name1, body1, name2, body2, ...]; drop the preamble.
    entries = re.split(r'--- Image: (.*?) ---\n', content)[1:]
    entry_separator = "-" * 40

    plot_count = 0
    for raw_name, details in zip(entries[0::2], entries[1::2]):
        if plot_count >= max_plots:
            break

        # We only care about entries with MISMATCHES
        if "MISMATCHES:" not in details:
            continue

        filename = raw_name.strip()
        # Stop at the entry separator, so neither the dashed line nor the
        # summary table that follows the last entry is listed as a mismatch.
        mismatch_section = details.split("MISMATCHES:", 1)[1].split(entry_separator, 1)[0]
        lines = [line.strip() for line in mismatch_section.split('\n') if line.strip()]

        img_path = os.path.join(image_dir, filename)
        if not os.path.exists(img_path):
            print(f"Warning: Image not found at {img_path}")
            continue

        # Setup Plot
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 6), gridspec_kw={'width_ratios': [1, 1]})

        # Show Image
        img = mpimg.imread(img_path)
        ax1.imshow(img)
        ax1.axis('off')
        ax1.set_title(f"Filename: {filename}", fontsize=12, fontweight='bold')

        # Show Mismatches text
        text_str = "MISMATCHED FIELDS:\n" + "-"*30 + "\n"
        for line in lines:
            # Split on the exact delimiters of "KEY: 'gt' != 'pred'" so values
            # containing ':' or quotes are shown intact.
            field_part, has_field, rest = line.partition(": '")
            gt_part, has_sep, pred_part = rest.partition("' != '")
            if has_field and has_sep:
                if pred_part.endswith("'"):
                    pred_part = pred_part[:-1]
                text_str += f"FIELD: {field_part}\n"
                text_str += f"  GT:   {gt_part}\n"
                text_str += f"  PRED: {pred_part}\n\n"
            else:
                # e.g. a continuation line of a multi-line value
                text_str += line + "\n"

        # Add text to the second subplot
        ax2.text(0.05, 0.95, text_str, transform=ax2.transAxes, fontsize=11,
                 verticalalignment='top', fontfamily='monospace',
                 bbox=dict(boxstyle='round', facecolor='#ffeeee', alpha=0.5))
        ax2.axis('off')

        plt.tight_layout()
        plt.show()
        plt.close(fig)  # free the figure once displayed
        plot_count += 1

# --- RUN IT ---
# Adjust paths if your folders are named differently
LOG_FILE = '/content/finetuned_eval_log_detailed.txt'
IMG_DIR = '/content/ho_img'
MAX_PLOTS = 14  # single source for both the message and the plot limit

print(f"Plotting first {MAX_PLOTS} failures from {LOG_FILE}...\n")
plot_mismatches(LOG_FILE, IMG_DIR, max_plots=MAX_PLOTS)