In [None]:
import base64
import json
import logging
import mimetypes
import os

import numpy as np
from doctr.models import ocr_predictor
from PIL import Image
from vllm import LLM
from vllm.sampling_params import SamplingParams


def encode_image(image_path: str):
    """Return the raw bytes of the file at ``image_path`` as a base64 string.

    The file is read in binary mode and the base64 output is decoded to a
    ``str`` so it can be embedded directly in text (e.g. a data URL).
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")


class Pixtral:
    """Wrapper around a vLLM ``LLM`` instance serving Pixtral for single-turn chat.

    Each ``generate_*`` call sends exactly one user message and returns the
    text of the first completion.
    """

    def __init__(self, max_model_len=4096, max_tokens=2048, gpu_memory_utilization=0.65, temperature=0.35):
        """Load the model and prepare the sampling parameters.

        Args:
            max_model_len: Forwarded to ``LLM(max_model_len=...)``.
            max_tokens: Forwarded to ``SamplingParams(max_tokens=...)``.
            gpu_memory_utilization: Forwarded to
                ``LLM(gpu_memory_utilization=...)``.
            temperature: Forwarded to ``SamplingParams(temperature=...)``.
        """
        self.model_name = "mistralai/Pixtral-12B-2409"

        self.sampling_params = SamplingParams(max_tokens=max_tokens, temperature=temperature)

        self.llm = LLM(
            model=self.model_name,
            tokenizer_mode="mistral",
            gpu_memory_utilization=gpu_memory_utilization,
            load_format="mistral",
            config_format="mistral",
            max_model_len=max_model_len,
        )

    @staticmethod
    def _guess_image_mime(image_path):
        """Return the image MIME type implied by the file extension.

        Falls back to ``"image/jpeg"`` (the previously hard-coded value) when
        the extension is unknown or does not map to an image type.
        """
        mime_type, _ = mimetypes.guess_type(image_path)
        if mime_type is None or not mime_type.startswith("image/"):
            return "image/jpeg"
        return mime_type

    def _chat(self, content):
        """Send one user message with the given content parts and return the reply text."""
        messages = [{"role": "user", "content": content}]
        outputs = self.llm.chat(messages, sampling_params=self.sampling_params)
        # Only the first completion of the first (and only) request is used.
        return outputs[0].outputs[0].text

    def generate_message_from_image(self, prompt, image_path):
        """Ask the model about an image.

        Args:
            prompt: Text part of the user message.
            image_path: Path of the image file; it is embedded in the message
                as a base64 data URL.

        Returns:
            The generated text of the first completion.
        """
        base64_image = encode_image(image_path)
        # Declare the real media type instead of always claiming JPEG, so the
        # data URL is accurate for PNG (and other) inputs.
        mime_type = self._guess_image_mime(image_path)

        return self._chat([
            {"type": "text", "text": prompt},
            {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{base64_image}"}},
        ])

    def generate_message(self, prompt):
        """Ask the model a text-only question and return the generated text."""
        return self._chat([{"type": "text", "text": prompt}])

class PIIPipeline:
    """OCR an image, ask the vision-language model to tag PII tokens, and
    return the tagged tokens together with their bounding boxes."""

    def __init__(self, pixtral_model: "Pixtral"):
        """Store the language model and load a pretrained docTR OCR predictor.

        Args:
            pixtral_model: Object exposing
                ``generate_message_from_image(prompt, image_path)``.
        """
        self.model = pixtral_model
        self.ocr = ocr_predictor(pretrained=True)

    @staticmethod
    def perform_ocr(ocr_model, image_path):
        """Run OCR on one image.

        Args:
            ocr_model: Callable OCR predictor; called with a list holding one
                RGB array.
            image_path: Path of the image to read.

        Returns:
            ``(words, boxes)``: the recognised word strings and, in the same
            order, their ``[x0, y0, x1, y1]`` boxes in absolute pixels.
        """
        # Image.open keeps the file handle open; convert() yields an
        # independent image, so the handle can be released right away.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        width, height = image.size
        input_page = np.array(image)
        result = ocr_model([input_page])
        words = []
        boxes = []
        for page in result.pages:
            for block in page.blocks:
                for line in block.lines:
                    for word in line.words:
                        # Geometry is unpacked as two relative corner points
                        # and scaled by the image size to get pixels.
                        (rel_x0, rel_y0), (rel_x1, rel_y1) = word.geometry
                        words.append(word.value)
                        boxes.append([
                            int(rel_x0 * width),
                            int(rel_y0 * height),
                            int(rel_x1 * width),
                            int(rel_y1 * height),
                        ])
        return words, boxes

    @staticmethod
    def _parse_entities(response_text):
        """Extract the ``"entities"`` list from the model reply, or return None.

        The reply is parsed as-is first. If that fails, the span from the first
        ``{`` to the last ``}`` is tried, which copes with Markdown code fences
        or stray commentary around the JSON object.
        """
        if not isinstance(response_text, str):
            return None
        text = response_text.strip()
        candidates = [text]
        start, end = text.find("{"), text.rfind("}")
        if 0 <= start < end:
            candidates.append(text[start:end + 1])
        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except ValueError:
                continue
            if isinstance(parsed, dict) and isinstance(parsed.get("entities"), list):
                return parsed["entities"]
        return None

    @staticmethod
    def _build_result(entities, words, bboxes):
        """Join the model's entities with the OCR boxes via the token id.

        Entities that are malformed, carry no usable label, or reference an id
        outside the OCR output are skipped.
        """
        result = {
            "tokens": [],
            "bboxes": [],
            "ner_tags": [],
        }
        for entity in entities:
            if not isinstance(entity, dict):
                continue
            try:
                # Models sometimes echo ids back as strings ("5").
                word_id = int(entity["id"])
            except (KeyError, TypeError, ValueError):
                continue
            label = entity.get("label")
            if not isinstance(label, str) or not 0 <= word_id < len(bboxes):
                continue
            # Fall back to the OCR token when the model omitted the word.
            result["tokens"].append(entity.get("word", words[word_id]))
            result["bboxes"].append(bboxes[word_id])
            result["ner_tags"].append("B-" + label)
        return result

    def process_image(self, image_path: str):
        """Tag the PII tokens of one document image.

        Args:
            image_path: Path of the image to OCR and send to the model.

        Returns:
            A dict with parallel lists ``"tokens"``, ``"bboxes"`` and
            ``"ner_tags"`` (each tag is ``"B-" + label``), or ``None`` when the
            model reply contains no parseable ``{"entities": [...]}`` object.
        """
        words, bboxes = self.perform_ocr(self.ocr, image_path)

        word_data = [{"id": idx, "word": word} for idx, word in enumerate(words)]
        input_json = json.dumps({"words": word_data}, ensure_ascii=False)

        prompt = self.build_prompt(input_json)
        response = self.model.generate_message_from_image(prompt, image_path)

        entities = self._parse_entities(response)
        if entities is None:
            # Keep the best-effort contract (return None) but leave a trace of
            # what the model actually said instead of failing silently.
            logging.getLogger(__name__).warning(
                "Could not parse model response for %s: %r",
                image_path,
                str(response)[:200],
            )
            return None

        return self._build_result(entities, words, bboxes)

    def build_prompt(self, input_json: str) -> str:
        """Return the extraction prompt with ``input_json`` appended as OCR data.

        The label names listed in the prompt are what the model is asked to
        emit; ``process_image`` prefixes them with ``"B-"`` to form the tags.
        """
        return f"""
You are a PII information extraction model. You are given:

1. An image of a scanned document (e.g., invoice, form, letter).
2. A JSON object that contains the OCR result for this image. It consists of a list of tokens (words). Each token includes:
   - "id": a unique ID for the token (integer)
   - "word": the OCR-recognized text (may be a single character, digit, or symbol, like "4", "@" or "-")

Your task:

- Analyze the OCR result and the document image.
- You should keep in mind that there may be OCR errors. Consider the context of the document and the meaning of the words.
- Label each token that belongs to a PII (Personally Identifiable Information) entity with one of the following categories:

"full_name" – a person’s full name, including first name, last name, and middle name (if applicable)
"phone_number" – phone number (including country codes, separators, etc.)  
"address" – any part of a physical address  
"email_address" – any part of an email address  
"company_name" – name of a company or organization  
"payment_information" – IBAN or credit/debit card numbers

📌 Notes:

- You must only label tokens that belong to a PII entity. Do not return tokens labeled "O".
- Each token should be returned along with its original "id" and assigned "label". Check that all the ids match the original input.
- Some entities (like phone numbers or addresses) are composed of multiple tokens — e.g., a phone number "+49 123 456789" may consist of 3 tokens. Each token must be labeled "phone_number".
- If a token is incorrect due to OCR but its intention is clear, you can still label it. For example, if the OCR result is "Mex Mastermann's" but the correct form is "Max Mustermann's", you should label both tokens as "full_name".

Example Input:
```json
{{
    "words": [
        {{"id": 0, "word": "Max"}},
        {{"id": 1, "word": "Mustermann's"}},
        {{"id": 2, "word": "phone"}},
        {{"id": 3, "word": "number"}},
        {{"id": 4, "word": "is"}},
        {{"id": 5, "word": "+49"}},
        {{"id": 6, "word": "123"}},
        {{"id": 7, "word": "456789"}}
    ]
}}
```

✅ Output format:
```json
{{
  "entities": [
    {{"id": 0, "word": "Max", "label": "full_name"}},
    {{"id": 1, "word": "Mustermann", "label": "full_name"}},
    {{"id": 5, "word": "+49", "label": "phone_number"}},
    {{"id": 6, "word": "123", "label": "phone_number"}},
    {{"id": 7, "word": "456789", "label": "phone_number"}}
  ]
}}
```

Now analyze the image and the OCR result provided below. Return only the JSON object WITHOUT ANY COMMENTS.

OCR data:
{input_json}
""".strip()

In [None]:
# Directory the per-image prediction files are written to.
labels_path = "data/pixtral_funsd_benchmark/layoutlm_labels"
# Raw directory listing of the benchmark image folder (file names only).
images = os.listdir("data/funsd_benchmark/images")

# Build the model once and hand it to the pipeline.
model = Pixtral()
pipeline = PIIPipeline(model)

In [None]:
# Create the prediction directory up front so the first write cannot fail
# with FileNotFoundError.
os.makedirs(labels_path, exist_ok=True)

for image in images:
    # Test the individual file name (not the whole `images` list, which has no
    # `endswith`) and skip anything that is not a PNG.
    if not image.endswith(".png"):
        continue
    image_path = os.path.join("data/funsd_benchmark/images", image)
    result = pipeline.process_image(image_path)
    if not result:
        print(f"No result returned for image: {image}")
        continue
    # Swap only the extension; str.replace would also rewrite a ".png" that
    # happens to occur in the middle of the name.
    label_file = os.path.splitext(image)[0] + ".json"
    # The dump uses ensure_ascii=False, so pin the encoding instead of relying
    # on the platform default.
    with open(os.path.join(labels_path, label_file), "w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False, indent=4)

In [None]:
# Evaluation runs handed to the metrics helper as `samples`; one dict per run.
# Presumably "gt_labels" is the ground-truth directory and "predicted_labels"
# the model output directory — confirm against the metrics helper.
# NOTE(review): the extraction prompt also defines a "payment_information"
# label that is not listed in "class_names" — confirm whether it should be
# scored.
test_samples = [
    {
        "test_name": "benchmark",
        "gt_labels": "data/funsd_benchmark/layoutlm_labels",
        "predicted_labels": "data/pixtral_funsd_benchmark/layoutlm_labels",
        "image_views": "data/pixtral_funsd_benchmark/labeled_images",
        "class_names": [
            "full_name",
            "phone_number",
            "address",
            "email_address",
            "company_name",
        ],
    },
]

In [None]:
from src.w_b import count_and_log_all_metrics

# Hand the configured benchmark runs to the metrics helper. The OCR model name
# is passed as an empty string.
count_and_log_all_metrics(
    run_specification="benchmark",
    ocr_model_name="",
    lm_model_name="Pixtral",
    samples=test_samples,
)