In [1]:
# Mount the user's Google Drive inside the Colab VM at /content/drive.
# The dataset paths configured later in this notebook live under
# /content/drive/MyDrive, so this cell has to run before they are read.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install google-generativeai nltk
!pip install jiwer

In [None]:
import os
import json
import re
import google.generativeai as genai
from PIL import Image
from tqdm import tqdm
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import jiwer

In [None]:
# Never store the API key in the notebook itself: read it from the
# GOOGLE_API_KEY environment variable and, when that is unset or empty, ask for
# it interactively without echoing it into the cell output.
from getpass import getpass

genai.configure(api_key=os.environ.get("GOOGLE_API_KEY") or getpass("Gemini API key: "))

# Shared model handle used by the inference helper further down.
model = genai.GenerativeModel("gemini-2.0-flash")

In [None]:
def tokenize_latex(expr):
    """Split a LaTeX string into a flat list of tokens.

    Recognised tokens, in priority order:
      * a backslash followed by one or more ASCII letters (e.g. ``\\frac``),
      * a single character out of ``{ } _ ^ = + - * / ( ) ,``,
      * a run of ASCII letters,
      * a run of digits.

    Anything else (whitespace, ``.``, ``&``, a bare ``\\\\`` line break, ...)
    matches no alternative and is silently dropped.

    Args:
        expr: LaTeX source text.

    Returns:
        list[str]: tokens in order of appearance; empty for an empty string.
    """
    token_pattern = r'(\\[a-zA-Z]+|[{}_^=+\-*/(),]|[a-zA-Z]+|\d+)'
    return [match.group(1) for match in re.finditer(token_pattern, expr)]

def compute_metrics(preds, gts):
    """Average two smoothed sentence-level BLEU variants over a set of predictions.

    Args:
        preds: predicted LaTeX strings.
        gts: ground-truth LaTeX strings, aligned index-by-index with ``preds``.

    Returns:
        dict with two float entries:
          * ``"bleu"``     - mean BLEU over ``tokenize_latex`` tokens.
          * ``"codebleu"`` - mean BLEU over individual characters. Despite the
            key name this is plain character-level BLEU.
        Both are 0.0 when the inputs are empty.

    Raises:
        AssertionError: if ``preds`` and ``gts`` differ in length.
    """
    assert len(preds) == len(gts)

    # Nothing to average: report zeros instead of dividing by zero below.
    if len(preds) == 0:
        return {"bleu": 0.0, "codebleu": 0.0}

    smoothie = SmoothingFunction().method4
    bleus, codebleus = [], []

    for pred, gt in zip(preds, gts):
        # Token-level score: the reference and hypothesis are LaTeX token lists.
        bleus.append(
            sentence_bleu([tokenize_latex(gt)], tokenize_latex(pred), smoothing_function=smoothie)
        )
        # Character-level score: every character is its own "token".
        codebleus.append(
            sentence_bleu([list(gt)], list(pred), smoothing_function=smoothie)
        )

    return {
        "bleu": sum(bleus) / len(bleus),
        "codebleu": sum(codebleus) / len(codebleus)
    }


In [None]:
import os
import re
import time
from PIL import Image
from tqdm import tqdm
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import jiwer

# === CONFIGURATION ===
# Folder holding the test images. The loader below sorts files by int(stem),
# so every image file name must be an integer followed by its extension.
image_folder = "/content/drive/MyDrive/Printed_Split_set/test/Images"
# Ground-truth LaTeX, one expression per line; the loader asserts that the
# number of lines equals the number of images.
gt_txt_file = "/content/drive/MyDrive/Printed_Split_set/test/Latex.txt"
# Per-sample log plus final summary. It is opened in "w" mode by the inference
# cell, so each run overwrites the previous file.
output_txt_file = "./LATEST_PRINTED_gemini_output.txt"
# Optional cap on the number of evaluated samples. The value is truth-tested,
# so both None and 0 mean "evaluate everything".
max_samples = None  # Set to integer if needed

# === LaTeX Tokenizer
def tokenize_latex(expr):
    """Break a LaTeX expression into the tokens used by the BLEU/CER metrics.

    A token is one of: a ``\\command`` made of ASCII letters, one character from
    ``{ } _ ^ = + - * / ( ) ,``, a run of ASCII letters, or a run of digits.
    Characters that fit none of these (spaces, ``.``, ``&``, ``<``, ...) are
    skipped without error.

    NOTE: an earlier cell of this notebook defines a function with the same
    name and the same pattern; this definition replaces it when the cell runs.

    Args:
        expr: LaTeX source text.

    Returns:
        list[str]: the tokens, left to right.
    """
    scanner = re.compile(r'(\\[a-zA-Z]+|[{}_^=+\-*/(),]|[a-zA-Z]+|\d+)')
    return scanner.findall(expr)

# === CER Tokenizer Transform
class TokenizeTransform(jiwer.transforms.AbstractTransform):
    """jiwer transform that turns each input string into its LaTeX token list.

    Passed to ``jiwer.cer`` as both the reference and the hypothesis transform.
    Because each string becomes a list of ``tokenize_latex`` tokens, the rate
    jiwer reports is presumably counted over tokens rather than characters
    (confirm against the installed jiwer version).
    """

    def process_string(self, s: str):
        """Tokenize one string with ``tokenize_latex``."""
        return tokenize_latex(s)

    def process_list(self, tokens: list[str]):
        """Tokenize every string in ``tokens``, returning one token list per entry.

        Despite its name, ``tokens`` holds whole strings, not individual tokens.
        """
        return list(map(self.process_string, tokens))

def compute_cer(truth_and_output: list[tuple[str, str]]) -> float:
    """Compute an error rate over (ground truth, prediction) pairs with jiwer.

    Each side is coerced to a single-line string (falsy values become ``""``),
    pairs where both sides are empty are dropped, and the remainder is scored by
    ``jiwer.cer`` using ``TokenizeTransform`` on both sides.

    Args:
        truth_and_output: sequence of ``(ground_truth, prediction)`` pairs.

    Returns:
        float: the value returned by ``jiwer.cer``. Falls back to ``1.0`` (after
        printing a message) when no usable pair remains or when jiwer raises.
    """
    ground_truth = []
    model_output = []

    for i, (gt, pred) in enumerate(truth_and_output):
        try:
            # Coerce to string if not already, and flatten to one line.
            gt = str(gt).strip().replace("\n", " ") if gt else ""
            pred = str(pred).strip().replace("\n", " ") if pred else ""
        except Exception as e:
            print(f"[CER Skipped] Sample {i} due to error: {e}")
            continue

        if gt == "" and pred == "":
            continue  # skip blank pairs

        # Both lists grow together, so they always stay aligned.
        ground_truth.append(gt)
        model_output.append(pred)

    if not ground_truth or not model_output:
        print("[CER Warning] No valid pairs found. Returning CER = 1.0")
        return 1.0

    try:
        # Use `reference=`; `truth=` is only a deprecated alias. On jiwer
        # releases without the alias it raises a TypeError that the handler
        # below would turn into a CER of 1.0.
        return jiwer.cer(
            reference=ground_truth,
            hypothesis=model_output,
            reference_transform=TokenizeTransform(),
            hypothesis_transform=TokenizeTransform()
        )
    except Exception as e:
        # Best-effort metric: report the failure and return the worst score.
        print(f"[CER Failure] jiwer.cer failed with: {e}")
        return 1.0


# === Metric Computation
def compute_metrics(preds: list[str], gts: list[str]):
    """Compute mean token-level BLEU and the jiwer-based error rate.

    Args:
        preds: predicted LaTeX strings. Non-string entries are stringified and
            ``None`` becomes ``""``.
        gts: ground-truth LaTeX strings, aligned index-by-index with ``preds``
            and normalised the same way.

    Returns:
        dict with ``"bleu"`` (mean smoothed sentence BLEU over ``tokenize_latex``
        tokens) and ``"cer"`` (the value returned by ``compute_cer``). For empty
        input the result is ``{"bleu": 0.0, "cer": 1.0}``.

    Raises:
        AssertionError: if ``preds`` and ``gts`` differ in length.
    """
    assert len(preds) == len(gts)

    # Nothing to score: avoid dividing by zero below. The CER value mirrors
    # what compute_cer returns when it finds no valid pairs.
    if len(preds) == 0:
        return {"bleu": 0.0, "cer": 1.0}

    def _as_single_line(text) -> str:
        # Coerce to str (None -> "") and collapse the value onto one line.
        if not isinstance(text, str):
            text = str(text) if text is not None else ""
        return text.strip().replace("\n", " ")

    smoothie = SmoothingFunction().method4
    bleus = []
    truth_and_preds = []

    for pred, gt in zip(preds, gts):
        pred = _as_single_line(pred)
        gt = _as_single_line(gt)

        bleus.append(
            sentence_bleu([tokenize_latex(gt)], tokenize_latex(pred), smoothing_function=smoothie)
        )
        truth_and_preds.append((gt, pred))

    cer = compute_cer(truth_and_preds)

    return {
        "bleu": sum(bleus) / len(bleus),
        "cer": cer
    }


# === Retry wrapper for Gemini inference
def get_latex_from_gemini(prompt, image, max_retries=3, retry_delay=2):
    """Ask the module-level Gemini ``model`` for LaTeX, retrying on any error.

    Args:
        prompt: instruction text sent together with the image.
        image: image object passed to ``model.generate_content``.
        max_retries: maximum number of attempts. Values <= 0 mean no attempt.
        retry_delay: seconds to sleep between failed attempts.

    Returns:
        str: the stripped response text, or ``""`` when every attempt failed
        or no attempt was made. The result is always a string so callers can
        tokenize it unconditionally.
    """
    for attempt in range(1, max_retries + 1):
        try:
            response = model.generate_content([prompt, image])
            return response.text.strip()
        except Exception as e:
            # Best-effort inference: log, wait, and try again.
            print(f"[Attempt {attempt}] Error: {e}")
            if attempt < max_retries:
                print(f"Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)

    # Reached after the last failed attempt, and also when max_retries <= 0.
    # Without this fallthrough the function would return None in that case.
    print("Max retries reached. Skipping.")
    return ""

# === Load files
# Collect image file names and order them numerically by stem so that the
# i-th image lines up with the i-th line of the ground-truth file.
image_files = [
    name for name in os.listdir(image_folder)
    if name.lower().endswith((".png", ".jpg", ".jpeg"))
]
image_files.sort(key=lambda name: int(os.path.splitext(name)[0]))

# One stripped ground-truth LaTeX string per line of the label file.
with open(gt_txt_file, "r", encoding="utf-8") as f:
    gt_lines = [raw_line.strip() for raw_line in f]

assert len(image_files) == len(gt_lines), "Mismatch between image and label count"

# Optional truncation for quick runs; a falsy max_samples keeps everything.
if max_samples:
    image_files, gt_lines = image_files[:max_samples], gt_lines[:max_samples]

# === Inference + evaluation
# Exactly one entry per image is appended to `predictions` and
# `truth_and_preds`, in the same order as image_files / gt_lines.
# `bleu_scores` only receives entries for samples whose BLEU was computed.
predictions = []
bleu_scores = []
truth_and_preds = []

# Loop-invariant pieces, built once instead of once per image.
prompt = "Extract the optimization problem in this image and return only the LaTeX code. I want clean latex in a single line without any begin and end tags, without dollar sign. Example: \\text{max} \\quad & 6x_1 + \\log(x_2 + 1) \\\\ \\text{st} \\quad & e^{x_1} + 2x_2 + w_1 = 17 \\\\ & x_1, x_2, w_1, w_2, w_3 \\geq 0."
smoothie = SmoothingFunction().method4

with open(output_txt_file, "w", encoding="utf-8") as fout:
    for i, fname in enumerate(tqdm(image_files)):
        gt = gt_lines[i]
        # Stays "" if anything fails before inference returns. If a later step
        # (scoring, logging) fails, the prediction already obtained is kept.
        pred = ""
        try:
            img_path = os.path.join(image_folder, fname)
            # Release the file handle as soon as the RGB copy exists.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert("RGB")

            pred = get_latex_from_gemini(prompt, image)

            bleu = sentence_bleu([tokenize_latex(gt)], tokenize_latex(pred), smoothing_function=smoothie)
            bleu_scores.append(bleu)

            fout.write(f"image: {fname}\n")
            fout.write(f"gt    : {gt}\n")
            fout.write(f"prediction: {pred}\n")
            # The newline keeps the separator on its own line.
            fout.write(f"bleu  : {bleu:.4f}\n")
            fout.write("-" * 60 + "\n")

        except Exception as e:
            print(f"Unexpected failure on {fname}: {e}")

        # Appended outside the try block so a mid-iteration failure can never
        # record the same image twice and break the length check below.
        predictions.append(pred)
        truth_and_preds.append((gt, pred))

if len(predictions) != len(gt_lines):
    raise ValueError(f"Length mismatch! Predictions: {len(predictions)} vs GT: {len(gt_lines)}")

# === Final Evaluation Summary
# Aggregate metrics over every prediction gathered by the inference loop.
metrics = compute_metrics(predictions, gt_lines)

# Append the summary after the per-sample records already in the output file.
with open(output_txt_file, "a", encoding="utf-8") as fout:
    fout.writelines([
        "\n=== Summary ===\n",
        f"Total Samples    : {len(predictions)}\n",
        f"Average BLEU     : {metrics['bleu']:.4f}\n",
        f"Average CER      : {metrics['cer']:.4f}\n",
    ])

# Echo the same figures to the notebook output.
print("\n".join([
    "\n=== Evaluation Complete ===",
    f"Samples evaluated : {len(predictions)}",
    f"Average BLEU      : {metrics['bleu']:.4f}",
    f"Average CER       : {metrics['cer']:.4f}",
    f"Output saved to   : {output_txt_file}",
]))