SETUP

Document loading

In [ ]:
# LLaVA-Llama3 Vision Language Model Setup and Evaluation
import base64
import csv
import json
import os
import re
import string
from collections import Counter
from io import BytesIO

import numpy as np
from langchain_ollama import OllamaLLM
from PIL import Image

# --- Configuration -----------------------------------------------------------
# NOTE(review): this only changes the environment of this Python process; it
# may have no effect on an already-running Ollama server — confirm it does
# what is intended.
os.environ["OLLAMA_NUM_GPU_LAYERS"] = "40"

data_dir = "docvqa_samples_300"                          # dataset root folder
image_dir = os.path.join(data_dir, "images")             # per-sample image files
metadata_file = os.path.join(data_dir, "metadata.json")  # sample descriptions
output_csv = "vlm_results.csv"                           # per-sample results go here

# Load the evaluation samples. The loop below reads each entry's "id",
# "image_filename", "question" and "answers" keys.
with open(metadata_file, mode="r", encoding="utf-8") as f:
    metadata = json.load(f)

# Evaluation functions
def preprocess_answer(text):
    """Normalize an answer string so that answers can be compared.

    The text is lower-cased and trimmed. Every character that is neither a
    word character nor whitespace is deleted. Runs of whitespace collapse to
    a single space.

    Args:
        text: The raw answer. Anything that is not a non-empty ``str``
            (``None``, numbers, empty string, ...) is treated as no answer.

    Returns:
        The normalized string, or ``""`` for missing or non-string input.
    """
    if not (text and isinstance(text, str)):
        return ""
    # Collapse whitespace first so punctuation removal sees single spaces.
    normalized = re.sub(r'\s+', ' ', text.lower().strip())
    normalized = re.sub(r'[^\w\s]', '', normalized)
    # Deleting punctuation can leave double or edge spaces; squeeze them out.
    return ' '.join(normalized.split())

def exact_match(pred, ground_truths):
    """Check whether a prediction matches any reference answer exactly.

    Both sides are compared after ``preprocess_answer`` normalization.

    Args:
        pred: The model's predicted answer.
        ground_truths: Iterable of acceptable reference answers.

    Returns:
        ``True`` as soon as one normalized reference equals the normalized
        prediction, ``False`` otherwise (including for an empty iterable).
    """
    target = preprocess_answer(pred)
    for gt in ground_truths:
        if preprocess_answer(gt) == target:
            return True
    return False

def f1(pred, ground_truths):
    """Return the best token-level F1 score of *pred* against the references.

    Prediction and references are normalized with ``preprocess_answer`` and
    split on whitespace. Overlap is counted as a multiset intersection
    (``Counter``), so a token repeated in both answers is matched as many
    times as it occurs in both, not just once.

    Args:
        pred: The model's predicted answer.
        ground_truths: Iterable of acceptable reference answers.

    Returns:
        The maximum F1 over all references, in ``[0.0, 1.0]``. Returns
        ``0.0`` when there is no token overlap or ``ground_truths`` is empty.
    """
    pred_tokens = preprocess_answer(pred).split()
    pred_counts = Counter(pred_tokens)  # computed once, reused per reference

    def score(gt):
        gt_tokens = preprocess_answer(gt).split()
        num_same = sum((pred_counts & Counter(gt_tokens)).values())
        if num_same == 0:
            return 0.0
        # num_same > 0 implies both token lists are non-empty.
        precision = num_same / len(pred_tokens)
        recall = num_same / len(gt_tokens)
        return 2 * precision * recall / (precision + recall)

    # default= avoids ValueError from max() when no references are given.
    return max((score(gt) for gt in ground_truths), default=0.0)

def pil_to_base64(pil_img):
    """Encode an image as PNG and return the bytes as base64 text.

    Args:
        pil_img: An object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        The base64 encoding of the PNG bytes, decoded to a ``str``.
    """
    with BytesIO() as png_buffer:
        pil_img.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("utf-8")

# LLM Setup: build the client and send one trivial request so that a missing
# server or model is reported here, before the evaluation loop starts.
try:
    llm = OllamaLLM(model="llava-llama3")
    test_response = llm.invoke("Hello")  # connectivity check; reply not used
    print("✅ Ollama connection successful with llava-llama3")
except Exception as e:
    print(f"❌ Error connecting to Ollama: {e}")
    # Stop here. Without a working client, the per-sample try/except in the
    # processing loop would catch the failure (or a NameError for `llm`) on
    # every sample and silently record empty predictions.
    raise

# Main processing pipeline: ask the model each question, score its answer and
# write one CSV row per scored sample (rows are written as they are produced).
with open(output_csv, "w", newline="", encoding="utf-8") as csvfile:
    fieldnames = ["id", "image_filename", "question", "ground_truth", "predicted_answer", "exact_match", "f1_score"]
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    writer.writeheader()

    # Per-sample scores for the summary; skipped samples are not recorded.
    em_scores = []
    f1_scores = []

    for i, sample in enumerate(metadata):
        print(f"Processing sample {i+1}/{len(metadata)}")

        idx = sample["id"]
        img_path = os.path.join(image_dir, sample["image_filename"])

        if not os.path.exists(img_path):
            print(f"⚠️ Image not found: {img_path}")
            continue

        question = sample["question"]
        ground_truths = sample["answers"]

        # Open the image in a context manager so its file handle is released
        # as soon as it has been re-encoded (previously one handle leaked per
        # sample). An image that cannot be decoded or re-encoded as PNG
        # raises OSError; skip it like a missing image instead of aborting
        # the whole run.
        try:
            with Image.open(img_path) as image:
                image_b64 = pil_to_base64(image)
        except OSError as e:
            print(f"⚠️ Could not read image {img_path}: {e}")
            continue

        # Create prompt for vision model
        prompt = f"Question: {question}\n\nAnswer this question based on what you see in the image. Provide only the specific answer requested, be concise."

        try:
            response = llm.invoke(prompt, images=[image_b64])
            pred_answer = str(response).strip()
        except Exception as e:
            # Best-effort: a failed request is recorded as an empty
            # prediction and scored as usual, so one bad call does not stop
            # the run.
            pred_answer = ""
            print(f"⚠️ Error processing {sample['image_filename']}: {e}")

        em = exact_match(pred_answer, ground_truths)
        f1_val = f1(pred_answer, ground_truths)

        em_scores.append(int(em))
        f1_scores.append(f1_val)

        writer.writerow({
            "id": idx,
            "image_filename": sample["image_filename"],
            "question": question,
            # NOTE(review): assumes "answers" is a list of strings.
            "ground_truth": " | ".join(ground_truths),
            "predicted_answer": pred_answer,
            "exact_match": em,
            "f1_score": round(f1_val, 2)
        })

# Summary: the averages cover only the samples that were actually scored;
# samples skipped in the processing loop have no entry in em_scores/f1_scores.
if em_scores and f1_scores:
    n_scored = len(em_scores)
    print(f"Evaluation Summary on {n_scored} samples:")
    if n_scored < len(metadata):
        print(f"({len(metadata) - n_scored} of {len(metadata)} samples skipped)")
    print(f"Avg Exact Match: {np.mean(em_scores)*100:.2f}%")
    print(f"Avg F1 Score: {np.mean(f1_scores)*100:.2f}%")
    print(f"Results saved to: {output_csv}")
else:
    print("❌ No samples were processed successfully")