In [None]:
# Mount Google Drive at /content/drive; the dataset paths configured further
# down in this notebook live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install the third-party packages used below: the OpenAI client (upgraded to
# the latest release) and jiwer for the error-rate metric.
# NOTE(review): nltk is imported later but not installed here — presumably it
# is already provided by the runtime; confirm.
!pip install --upgrade openai 
!pip install jiwer



In [None]:
# Work from the Drive root so the relative output path used later ("./...")
# is written to Drive.
%cd /content/drive/MyDrive

/content/drive/MyDrive


In [None]:
# Same as above, but with CER

import os
import json
import base64
import re
from tqdm import tqdm
from openai import OpenAI
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import jiwer

# Initialize OpenAI client (new API version).
# The key is taken from the OPENAI_API_KEY environment variable when it is
# set, so a real secret never has to be pasted into the notebook; the literal
# placeholder is kept as the fallback for backward compatibility.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "your_api_key"))

# === CONFIGURATION ===
# Folder of test images; file names are expected to be integers plus extension.
image_folder = "/content/drive/MyDrive/Printed_Split_set/test/Images"
# Ground-truth LaTeX, one expression per line, in the same order as the
# numerically sorted images.
gt_txt_file = "/content/drive/MyDrive/Printed_Split_set/test/Latex.txt"
# Per-sample results and the final summary are written to this file.
output_txt = "./LATEST_PRINTED_gpt4o_detailed_output.txt"
max_samples = None  # Set to integer to limit, or None for all
# =====================

# === Tokenizer for LaTeX math
def tokenize_latex(expr):
    """Split a LaTeX math string into a flat list of token strings.

    Recognised tokens, tried in this order at each position:
      * a backslash followed by one or more ASCII letters (e.g. ``\\frac``),
      * one of the single characters ``{ } _ ^ = + - * / ( ) ,``,
      * a run of ASCII letters,
      * a run of digits.

    Any other character (whitespace, ``&``, ``.``, a bare backslash, ...) is
    skipped and produces no token.
    """
    token_pattern = re.compile(r'(\\[a-zA-Z]+|[{}_^=+\-*/(),]|[a-zA-Z]+|\d+)')
    # The whole pattern is one capturing group, so group(1) is the full match.
    return [match.group(1) for match in token_pattern.finditer(expr)]

# === Tokenizer Transform for CER
class TokenizeTransform(jiwer.transforms.AbstractTransform):
    """jiwer transform that turns each input string into its LaTeX token list."""

    def process_string(self, s: str):
        # A single string becomes a flat list of LaTeX tokens.
        return tokenize_latex(s)

    def process_list(self, tokens: list[str]):
        # Despite the parameter name, every element is a whole string that is
        # tokenized on its own, so the result is a list of token lists.
        return list(map(self.process_string, tokens))

# === CER Calculation
def compute_cer(truth_and_output: list[tuple[str, str]]):
    """Compute the jiwer error rate over all (ground truth, prediction) pairs.

    Both sides are passed through ``TokenizeTransform``, so jiwer receives
    LaTeX token lists rather than raw strings (presumably making this a
    token-level rate rather than a per-character one — confirm against jiwer).

    Args:
        truth_and_output: ``(ground_truth, model_output)`` string pairs.

    Returns:
        The value returned by ``jiwer.cer`` for the whole set, or ``0.0`` when
        no pairs were supplied.
    """
    # zip(*[]) yields nothing, so the unpacking below would raise ValueError;
    # report 0.0 instead, matching how the BLEU average treats an empty run.
    if not truth_and_output:
        return 0.0

    ground_truth, model_output = zip(*truth_and_output)
    # Use `reference=`: it is the keyword that pairs with `reference_transform`;
    # the older `truth=` spelling is deprecated and is not accepted by newer
    # jiwer releases.
    return jiwer.cer(
        reference=list(ground_truth),
        hypothesis=list(model_output),
        reference_transform=TokenizeTransform(),
        hypothesis_transform=TokenizeTransform(),
    )

# === BLEU per sample
def compute_individual_metrics(pred, gt):
    """Return the smoothed sentence-level BLEU of ``pred`` against ``gt``.

    Both strings are tokenized with ``tokenize_latex``; ``gt`` is the single
    reference. Only BLEU is computed here (no edit distance).
    """
    reference_tokens = tokenize_latex(gt)
    candidate_tokens = tokenize_latex(pred)
    return sentence_bleu(
        [reference_tokens],
        candidate_tokens,
        smoothing_function=SmoothingFunction().method4,
    )

# === Read image bytes (Base64 encoding is done by the caller)
def encode_image(image_path):
    """Return the raw bytes of the file at ``image_path``.

    Despite the name, no encoding happens here: the file is read in binary
    mode and returned unchanged.
    """
    with open(image_path, mode="rb") as image_file:
        raw_bytes = image_file.read()
    return raw_bytes

# === GPT-4o Inference
def get_latex_from_gpt(image_path):
    """Ask GPT-4o to transcribe the optimization problem in an image as LaTeX.

    The image is sent inline as a Base64 data URL together with a fixed text
    prompt.

    Args:
        image_path: Path of the image file to send.

    Returns:
        The model's reply with surrounding whitespace stripped, or ``""`` if
        the API call (or reading its response) raised; the error is printed.

    Raises:
        Whatever ``encode_image`` raises when the file cannot be read — that
        call is deliberately outside the ``try`` so a missing image is not
        silently scored as an empty prediction.
    """
    import mimetypes  # local import keeps this block self-contained

    image_data = encode_image(image_path)

    # Declare the MIME type that matches the file (PNG inputs are accepted by
    # the loader too); fall back to the previous hard-coded JPEG otherwise.
    guessed_type = mimetypes.guess_type(image_path)[0]
    if guessed_type is None or not guessed_type.startswith("image/"):
        guessed_type = "image/jpeg"
    encoded_image = base64.b64encode(image_data).decode("utf-8")
    data_url = "data:" + guessed_type + ";base64," + encoded_image

    try:
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Extract the optimization problem in this image and return only the LaTeX code. I want clean latex in a single line without any begin and end tags, without dollar sign. Example: \\text{max} \\quad & 6x_1 + \\log(x_2 + 1) \\\\ \\text{st} \\quad & e^{x_1} + 2x_2 + w_1 = 17 \\\\ & x_1, x_2, w_1, w_2, w_3 \\geq 0."},
                        {"type": "image_url", "image_url": {"url": data_url}},
                    ],
                }
            ],
            max_tokens=1000,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        # Best-effort: one failed request must not abort the whole evaluation.
        print(f"Error for {image_path}: {e}")
        return ""

# === Load Files
# Collect the image files and order them numerically by file stem, so that
# "10.png" comes after "9.png". Every matching file name must therefore be an
# integer plus extension; anything else makes int() raise ValueError.
image_files = sorted(
    (
        f for f in os.listdir(image_folder)
        if f.lower().endswith((".png", ".jpg", ".jpeg"))
    ),
    key=lambda x: int(os.path.splitext(x)[0]),
)

# One ground-truth expression per line, matched to the images by position.
with open(gt_txt_file, "r", encoding="utf-8") as f:
    gt_lines = [line.strip() for line in f]

# Explicit check instead of `assert`, which is skipped under `python -O` and
# would let images and ground truth go silently out of step.
if len(image_files) != len(gt_lines):
    raise ValueError(
        "Mismatch between images and ground truth "
        f"({len(image_files)} images vs {len(gt_lines)} ground-truth lines)"
    )

# A falsy max_samples (None or 0) means "use everything".
if max_samples:
    image_files = image_files[:max_samples]
    gt_lines = gt_lines[:max_samples]

# === Run GPT-4o + Evaluation
# For every image: query the model, score the prediction against the
# ground-truth line at the same index, and append a record to the output file.
bleu_scores = []
truth_and_preds = []

with open(output_txt, "w", encoding="utf-8") as fout:
    for i, fname in enumerate(tqdm(image_files)):
        gt = gt_lines[i]
        img_path = os.path.join(image_folder, fname)
        pred = get_latex_from_gpt(img_path)
        bleu = compute_individual_metrics(pred, gt)

        # Keep the pairs for the corpus-level error rate computed afterwards.
        truth_and_preds.append((gt, pred))
        bleu_scores.append(bleu)

        # One record per sample, terminated by a dashed separator line.
        fout.writelines([
            f"image: {fname}\n",
            f"gt    : {gt}\n",
            f"prediction: {pred}\n",
            f"bleu  : {bleu:.4f}\n",
            "-" * 60 + "\n",
        ])

# === Final Metrics
# Mean of the per-sample BLEU scores (0.0 when nothing was evaluated) and a
# single error rate computed over all collected (ground truth, prediction) pairs.
if bleu_scores:
    avg_bleu = sum(bleu_scores) / len(bleu_scores)
else:
    avg_bleu = 0.0
cer_score = compute_cer(truth_and_preds)

# Append the summary after the per-sample records already in the output file.
with open(output_txt, "a", encoding="utf-8") as fout:
    fout.writelines([
        "\n=== Summary ===\n",
        f"Total Samples     : {len(bleu_scores)}\n",
        f"Average BLEU      : {avg_bleu:.4f}\n",
        f"Average CER       : {cer_score:.4f}\n",
    ])

# Echo the same figures to the notebook output.
for report_line in (
    "\n=== Evaluation Complete ===",
    f"Samples evaluated : {len(bleu_scores)}",
    f"Average BLEU      : {avg_bleu:.4f}",
    f"Average CER       : {cer_score:.4f}",
    f"Output saved to   : {output_txt}",
):
    print(report_line)