In [None]:
from transformers import AutoProcessor, Idefics3ForConditionalGeneration
from fuzzywuzzy import fuzz
from io import BytesIO
from PIL import Image
from tqdm import tqdm
import base64
import torch
import json
import os
import re

Using TensorFlow backend.


In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Local path from which both the processor and the model weights are loaded
model_path = "./SmolDocling-256M-preview-30--NoLoc"

# Load the processor first, then the model; the prints only report progress
print("Loading processor...")
processor = AutoProcessor.from_pretrained(model_path)

print("Loading model...")
# Move the weights to DEVICE right after loading so later inputs can be sent to the same device.
model = Idefics3ForConditionalGeneration.from_pretrained(model_path).to(DEVICE)
# Put the module in evaluation mode before it is used for generation.
model.eval()
print("Model loaded.")



Loading processor...
Loading model...
Model loaded.


#### Loading Data

In [None]:
def getJsonFiles(DataFolder,Type):
    """Collect the raw JSON files of one dataset split.

    The layout searched is exactly ``DataFolder/Type/<subfolder>/<name>.json``
    (one level of subfolders, no deeper recursion). Files whose name ends with
    ``_processed.json`` are excluded.

    Args:
        DataFolder: Root folder of the dataset.
        Type: Name of the split folder inside ``DataFolder`` (e.g. ``'test'``).

    Returns:
        A list of file paths, sorted by subfolder name and then by file name so
        that positional indexing into the result is reproducible across runs
        (``os.listdir`` itself returns entries in arbitrary order).

    Raises:
        OSError: If ``DataFolder/Type`` does not exist or cannot be listed.
    """
    split_dir = os.path.join(DataFolder, Type)
    Files = []
    for folder in sorted(os.listdir(split_dir)):
        folder_path = os.path.join(split_dir, folder)
        # Stray files next to the subfolders (e.g. .DS_Store) would make
        # os.listdir() raise NotADirectoryError, so skip anything that is not a folder.
        if not os.path.isdir(folder_path):
            continue
        for file in sorted(os.listdir(folder_path)):
            if file.endswith('.json') and not file.endswith('_processed.json'):
                Files.append(os.path.join(folder_path, file))
    return Files

def get_image_from_json(json_file_path):
    """Return the first page's image from a JSON document as an RGB PIL image.

    The image is read from ``data["pages"][0]["image"]["content"]``, which is
    expected to hold base64-encoded image bytes.
    """
    with open(json_file_path, 'r', encoding="utf-8") as handle:
        document = json.load(handle)

    first_page = document["pages"][0]
    raw_bytes = base64.b64decode(first_page["image"]["content"])
    # convert() forces the lazy PIL loader to decode the bytes and normalises the mode.
    return Image.open(BytesIO(raw_bytes)).convert("RGB")

'./Data/test/24c6b7320051721a/1435963b-528c-4d02-9d7e-2f75da33d9d2.json'

In [None]:
# Collect every raw (non-"_processed") JSON file of the 'test' split under ./Data.
Files = getJsonFiles("./Data",'test')
# Pick a single sample file for the one-off cells below; this needs at least 9 files.
# NOTE(review): index 8 looks like an arbitrary choice — confirm.
json_path = Files[8]

#### Formatting functions

In [None]:

def doctag2Json(doctag, labels):
    """Convert a doctag string to a dict keyed by the given labels.

    The input is split on ``</text>``; each non-blank piece is one entity. For
    every entity the text after the first ``:`` is kept as its value (further
    colons stay part of the value). Entities are matched to ``labels`` by
    position: the i-th entity becomes the value of the i-th label.

    Parsing is best-effort so that one malformed entity does not discard the
    whole prediction:

    * an entity without ``:`` keeps its full text as the value;
    * if there are fewer entities than labels, only the leading labels are
      filled and a diagnostic line is printed;
    * surplus entities are ignored.

    Args:
        doctag: String made of ``<text>...</text>`` blocks (may be empty or None).
        labels: Ordered label names.

    Returns:
        Dict mapping each matched label to its value; empty when nothing
        could be parsed.
    """
    result = {}
    if not doctag:
        return result

    values = []
    for chunk in doctag.split("</text>"):
        # Blank pieces (e.g. the tail after the last closing tag) are not entities.
        if not chunk.strip():
            continue
        entity = chunk.replace("<text>", "").strip()
        _prefix, separator, value = entity.partition(":")
        values.append(value.strip() if separator else entity)

    if len(values) < len(labels):
        print(f"doctag2Json: expected {len(labels)} entities, found {len(values)}")

    # zip() stops at the shorter sequence, so neither side can index out of range.
    for label, value in zip(labels, values):
        result[label] = value

    return result

def extract_doctags(model_output):
    """Return every ``<text>...</text>`` block of the model output, one per line.

    Matching is non-greedy and spans newlines; text outside the blocks is
    dropped. An output without any complete block yields an empty string.
    """
    block_pattern = re.compile(r'<text>.*?</text>', flags=re.DOTALL)
    blocks = [match.group(0) for match in block_pattern.finditer(model_output)]
    return "\n".join(blocks)

# Field names of the extracted JSON. doctag2Json assigns the i-th parsed entity
# to the i-th entry of this list, so the order here is significant.
# NOTE(review): presumably this matches the order the model emits its entities — confirm.
labels = [
    "Adresse-prescripteur",
    "Date-de-la-prescription",
    "Nom-du-medecin",
    "Numero-ADELI",
    "Numero-AM-Finess",
    "Numero-RPPS",
    "Signature",
    "Texte-manuscrit",
    "Texte-Signature",
    "Texte-soin-ALD",
    "Texte-soin-sans-ALD",
]

In [None]:
# Decode the page image of the sample file referenced by json_path (RGB PIL image).
image = get_image_from_json(json_path)

In [None]:

def benchmark_json(true_json, pred_json):
    """Score a predicted JSON object against the ground truth, label by label.

    Only the labels present in ``true_json`` are scored. Values are stripped,
    and a missing or empty value is replaced by the literal string "None" on
    both sides before comparison. Each label gets:

    * ``hard_match``: 1 if the two normalised strings are equal, else 0;
    * ``fuzzy_match``: ``fuzz.ratio`` of the two strings scaled to [0, 1] and
      rounded to 4 decimals.
    """
    def _normalise(raw):
        # Falsy values (None, "") become "", then blank text becomes "None".
        text = (raw or "").strip()
        return text if text != "" else "None"

    scores = {}
    for field in true_json:
        expected = _normalise(true_json.get(field))
        predicted = _normalise(pred_json.get(field))

        similarity = fuzz.ratio(expected, predicted) / 100.0

        scores[field] = {
            "hard_match": int(expected == predicted),
            "fuzzy_match": round(similarity, 4),
        }

    return scores
    
def merge_benchmarks(all_step_jsons):
    """Average the per-label scores of several benchmark results.

    Each element of ``all_step_jsons`` maps a label to a dict that may contain
    ``hard_match`` and ``fuzzy_match``; a missing metric counts as 0 / 0.0.
    The result maps every label seen (in order of first appearance) to the
    mean of each metric, rounded to 4 decimals.
    """
    from collections import defaultdict

    # Metric name -> value used when a step result lacks that metric.
    fallbacks = {"hard_match": 0, "fuzzy_match": 0.0}
    collected = defaultdict(lambda: {metric: [] for metric in fallbacks})

    for step_result in all_step_jsons:
        for label, label_scores in step_result.items():
            bucket = collected[label]
            for metric, fallback in fallbacks.items():
                bucket[metric].append(label_scores.get(metric, fallback))

    def _mean(values):
        return round(sum(values) / len(values), 4)

    return {
        label: {metric: _mean(values) for metric, values in bucket.items()}
        for label, bucket in collected.items()
    }



In [None]:
def evaluate_model(model):
    """Evaluate the model on the test dataset and return averaged benchmark results.

    For every path in the notebook-level ``Files`` list the page image is
    decoded, the model generates a response to a fixed prompt, the
    ``<text>`` blocks are parsed into a prediction and scored against the
    ground truth stored next to the input as ``<name>_processed.json``.

    A sample whose ground truth cannot be loaded or scored is reported and
    skipped; it contributes nothing to the averages.

    Relies on the notebook globals ``Files``, ``processor``, ``labels`` and
    ``DEVICE``.

    Args:
        model: Generation model exposing ``generate``; expected to live on ``DEVICE``.

    Returns:
        Dict mapping each label to its averaged ``hard_match`` and
        ``fuzzy_match`` (empty if no sample could be scored).
    """
    # The prompt is the same for every page, so build it once outside the loop.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Convert this page to docling."},
                {"type": "image"}
            ]
        }
    ]
    chat = processor.apply_chat_template(messages, add_generation_prompt=True)

    benchmark_jsons = []
    for json_path in Files:
        image = get_image_from_json(json_path)
        inputs = processor(text=chat, images=[[image]], return_tensors="pt", padding=True).to(DEVICE)

        # Generate response
        with torch.no_grad():
            output_ids = model.generate(**inputs, max_new_tokens=512)

        # Decode result
        # NOTE(review): if the <text> tags are registered as special tokens in this
        # tokenizer, skip_special_tokens=True would strip them before parsing — confirm.
        output_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]

        doctags = extract_doctags(output_text)
        predicted_json = doctag2Json(doctags, labels)

        # Swap only the extension; str.replace would also rewrite ".json"
        # occurring anywhere else in the path.
        path_root, _extension = os.path.splitext(json_path)
        json_processed_path = path_root + "_processed.json"

        try:
            with open(json_processed_path, 'r', encoding="utf-8") as file:
                true_json = json.load(file)
            step_benchmark_json = benchmark_json(true_json, predicted_json)
        except Exception as e:
            # Skip this sample: appending here would either raise NameError (first
            # sample) or silently re-count the previous sample's scores.
            print(f"Error scoring {json_path} against {json_processed_path}: {e}")
            continue

        benchmark_jsons.append(step_benchmark_json)

    return merge_benchmarks(benchmark_jsons)

# Run the evaluation over all test files; the result maps each label to its
# averaged hard_match / fuzzy_match scores.
average_json = evaluate_model(model)

error
error
error
error
error
error
error
error
error
error
