<a href="https://colab.research.google.com/github/JAshinflame/AI-Agents/blob/main/wir_extraction_donut_llama.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# 📦 Install dependencies
# Notebook shell magic (`!`): runs pip in the kernel's environment to install
# the packages listed below; `--quiet` reduces pip's console output.
!pip install transformers torch torchvision torchaudio pillow requests --quiet

In [None]:
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import json, requests
from pathlib import Path

# === Step 1: Donut Extraction ===
# Load the DocVQA-finetuned Donut checkpoint once and share the identifier so
# the processor (image preprocessing + tokenizer) and the model always match.
DONUT_CHECKPOINT = 'naver-clova-ix/donut-base-finetuned-docvqa'
processor = DonutProcessor.from_pretrained(DONUT_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(DONUT_CHECKPOINT)

wir_image_path = Path('/mnt/data/Test 10 WIR.png')
# Raise explicitly rather than `assert`: assertions are removed when Python
# runs with -O, which would let a missing file slip through to Image.open.
if not wir_image_path.exists():
    raise FileNotFoundError(f"WIR image not found: {wir_image_path}")
print('✅ Image found:', wir_image_path)

# convert() returns a fully loaded copy, so the underlying file handle can be
# released as soon as the `with` block exits.
with Image.open(wir_image_path) as source_image:
    image = source_image.convert('RGB')
pixel_values = processor(image, return_tensors='pt').pixel_values

# The DocVQA checkpoint is prompted with a question wrapped in its task tags;
# the prompt is tokenized without extra special tokens and fed to the decoder.
task_prompt = '<s_docvqa><s_question>Extract all key details from this WIR form.</s_question><s_answer>'
decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors='pt').input_ids

# Beam search decoding. `temperature` is deliberately not passed: it only
# affects sampling (do_sample=True) and is ignored by beam search, where it
# merely triggers a generation-config warning.
outputs = model.generate(
    pixel_values,
    decoder_input_ids=decoder_input_ids,
    max_length=512,
    num_beams=3,
    repetition_penalty=2.5,
    no_repeat_ngram_size=3,
)

# Decode the best sequence to plain text; `raw_text` is consumed by Step 2.
raw_text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
print('--- Raw Donut Output (first 1000 chars) ---\n', raw_text[:1000])

In [None]:
# === Step 2: Clean & Structure with Llama/Mistral via Ollama ===
import subprocess

# Choose model: 'llama3' or 'mistral'
OLLAMA_MODEL = 'llama3'


def _parse_llm_json(text):
    """Return the JSON value contained in an LLM reply.

    The reply is first parsed as-is. If that fails, the span from the first
    '{' to the last '}' is tried, which covers replies wrapped in Markdown
    code fences or surrounded by explanatory prose. When no valid JSON can
    be recovered, the reply is preserved under the 'raw_response' key so
    nothing is lost.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    start, end = text.find('{'), text.rfind('}')
    if start != -1 and end > start:
        try:
            return json.loads(text[start:end + 1])
        except json.JSONDecodeError:
            pass

    return {'raw_response': text}


# Prompt sent to the model on stdin; `raw_text` comes from the Donut step.
prompt = f'''
You are an AI document parser. Clean and extract structured information from the raw text below.
Identify and normalize exactly these fields:
- Project
- WIR No
- Date
- Activity
- Contractor
- Remarks

Output **only valid JSON** with these keys and their values.

--- Raw Text ---
{raw_text}
'''

# Run Ollama model locally
try:
    result = subprocess.run(
        ['ollama', 'run', OLLAMA_MODEL],
        input=prompt.encode('utf-8'),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
except FileNotFoundError as err:
    # subprocess raises FileNotFoundError when the executable is missing;
    # re-raise with a message that says what to do about it.
    raise RuntimeError(
        "The 'ollama' executable was not found on PATH; install Ollama first."
    ) from err

# A failed run would otherwise be saved as an empty 'raw_response', hiding
# the real error that Ollama reported on stderr.
if result.returncode != 0:
    error_text = result.stderr.decode('utf-8', errors='replace').strip()
    raise RuntimeError(
        f"'ollama run {OLLAMA_MODEL}' failed with exit code "
        f"{result.returncode}: {error_text}"
    )

llm_output = result.stdout.decode('utf-8', errors='replace').strip()
print('--- Ollama Output ---\n', llm_output)

# Attempt to parse as JSON (falls back to {'raw_response': ...})
structured = _parse_llm_json(llm_output)

output_path = Path('/mnt/data/wir_extracted_llama_v9.json')
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(structured, f, indent=2)

print('\n✅ Saved structured WIR data to:', output_path)