In [8]:
!pip install matplotlib
!pip install -q transformers torch accelerate bitsandbytes sentencepiece pillow pandas matplotlib tqdm
!pip install ipywidgets
!pip install jupyterlab-widgets
from PIL import Image
!pip install bitsandbytes


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1

In [9]:
import json
import re
import os # Added os import here for use in save_jsonl

# A minimal dictionary to replace the one from utils.py
sys_prompt = {
    'blip2 style': "Analyze the chart and answer the question: {}",
    'chartqa': "Answer the following question based on the chart: {}"
}

# A minimal base class to replace the one from utils.py
class ChartBenchTester:
    def __init__(self, test_index, sys_prompt_acc, sys_prompt_nqa, **kwargs):
        self.test_index = test_index
        self.system_prompt_acc = sys_prompt_acc
        self.system_prompt_nqa = sys_prompt_nqa
        self.image_root = ''
        for key, value in kwargs.items():
            setattr(self, key, value)

    def reset_image_root(self, image_root):
        self.image_root = image_root

    def load_test_file(self, file_path, mode='r'):
        """Loads a JSON or JSONL file."""
        try:
            with open(file_path, mode, encoding='utf-8') as f:
                if file_path.endswith('.jsonl'):
                    print(f"Loading JSONL file: {file_path}")
                    return [json.loads(line) for line in f]
                elif file_path.endswith('.json'):
                    print(f"Loading JSON file: {file_path}")
                    return json.load(f) # Assumes a standard JSON list
                else:
                    print(f"Error: Unknown file type {file_path}. Must be .json or .jsonl")
                    return []
        except FileNotFoundError:
            print(f"Error: The file {file_path} was not found.")
            return []
        except Exception as e:
            print(f"Error loading {file_path}: {e}")
            return []

    def save_jsonl(self, file_path, data, mode='a+'):
        """Saves data to a JSONL file."""
        # Ensure the directory exists before writing
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        with open(file_path, mode, encoding='utf-8') as f:
            for item in data:
                f.write(json.dumps(item) + '\n')

In [10]:
import sys, copy
import numpy as np
from tqdm import tqdm
import time
import warnings
from io import BytesIO
from PIL import Image

import torch
from transformers import (
    AutoProcessor,
    Pix2StructForConditionalGeneration,
    AutoModelForVision2Seq,
    BitsAndBytesConfig
)

warnings.filterwarnings("ignore")

In [11]:
# --- CONSTANTS ---

# DePlot model (using the fixed Hugging Face Hub ID to avoid local path errors)
CKPT_PATH = 'google/deplot'

# Qwen model (from your snippet)
QWEN_MODEL_PATH = "Qwen/Qwen2-VL-7B-Instruct"

# Image directory
IMG_ROOT = "/home/g2/Downloads/png"

# --- MODIFIED: Paths for both test files ---
TEST_AUGMENTED_FILE = "/home/g2/Downloads/test_augmented.json"
TEST_HUMAN_FILE = "/home/g2/Downloads/test_human.json"

# --- MODIFIED: Output file paths for each test ---
# Base path for saving results
SAVE_DIR = "/home/g2/Chart Classifier"
SAVE_PATH_AUGMENTED = os.path.join(SAVE_DIR, "qwen_base_only_augmented.jsonl")
SAVE_PATH_HUMAN = os.path.join(SAVE_DIR, "qwen_base_only_human.jsonl")
# --- ADDED: Intermediate DePlot Output Paths ---
# DePlot generates the data table from the image, which is an intermediate step.
DEPLOT_OUTPUT_AUGMENTED = "/home/g2/Downloads/deplot_augmented_output.jsonl"
DEPLOT_OUTPUT_HUMAN = "/home/g2/Downloads/deplot_human_output.jsonl"

# --- ADDED: Final Qwen Output Paths (Aliasing existing SAVE_PATH) ---
# QWEN_OUTPUT_HUMAN is the final save path, which should match the existing constant.
QWEN_OUTPUT_AUGMENTED = SAVE_PATH_AUGMENTED
QWEN_OUTPUT_HUMAN = SAVE_PATH_HUMAN

In [12]:
class QwenChartBenchTester(ChartBenchTester):
    """
    This class now runs in two stages:
    1. run_deplot_stage: Generates data tables from images.
    2. run_qwen_stage: Runs ONLY the Base model (CoT) evaluation.
    """

    # --- Part 1: Model Loading & Inference ---

    def load_model(self):
        """Loads both DePlot and Qwen models."""
        print("Loading DePlot model...")
        self.deplot_processor = AutoProcessor.from_pretrained(CKPT_PATH)
        self.deplot_model = Pix2StructForConditionalGeneration.from_pretrained(CKPT_PATH).to("cuda")
        print("DePlot model loaded.")

        print("Loading Qwen-VL model...")
        self.qwen_processor = AutoProcessor.from_pretrained(QWEN_MODEL_PATH, trust_remote_code=True)
        self.qwen_model = AutoModelForVision2Seq.from_pretrained(
            QWEN_MODEL_PATH,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            device_map="auto",
            trust_remote_code=True
        )
        print("Qwen-VL model loaded.")

    def get_model_response(self, image: Image.Image, prompt_text: str) -> str:
        """Helper function to run inference with Qwen-VL."""
        messages = [
            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
        ]
        text = self.qwen_processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.qwen_processor(
            text=text, images=[image], return_tensors="pt"
        )
        inputs = inputs.to(self.qwen_model.device, torch.float16)

        with torch.inference_mode():
            output_ids = self.qwen_model.generate(**inputs, max_new_tokens=512, do_sample=False)

        response_ids = output_ids[:, inputs.input_ids.shape[1]:]
        response = self.qwen_processor.decode(response_ids[0], skip_special_tokens=True).strip()
        return response

    # --- "BASE MODEL" (CoT) ---
    def run_cot(self, image: Image.Image, question: str) -> str:
        """Runs the base Chain-of-Thought (CoT) model."""
        prompt = f"""You are an expert chart analyst.
Your task is to answer a question about the provided chart.
Analyze the chart step-by-step to arrive at the correct answer.
Provide your reasoning first, then end with: "The final answer is: [Your Answer]".
Question: {question}
Let's think step by step:
"""
        return self.get_model_response(image, prompt)

    # --- GoT Methods (REMOVED) ---
    # _generate_thoughts, _evaluate_thoughts, _aggregate_thoughts, and run_got are omitted.

    def extract_final_answer(self, reasoning_text: str) -> str:
        """Extracts the final answer from the reasoning text."""
        match = re.search(r"The final answer is:\s*(.*)", reasoning_text, re.IGNORECASE | re.DOTALL)
        if match:
            return match.group(1).strip()
        else:
            # Fallback: return the last non-empty line
            lines = [line.strip() for line in reasoning_text.split('\n') if line.strip()]
            return lines[-1] if lines else ""

    def calculate_relaxed_accuracy(self, predicted: str, ground_truth: str) -> bool:
        """Calculates relaxed accuracy (substring match)."""
        pred_norm = re.sub(r"[^\w\s\.]", "", str(predicted)).lower().strip()
        gt_norm = re.sub(r"[^\w\s\.]", "", str(ground_truth)).lower().strip()
        if not pred_norm or not gt_norm:
            return False
        return (pred_norm in gt_norm) or (gt_norm in pred_norm)

    # --- Part 2: Evaluation Stages ---

    # ### --- STAGE 1: DePLOT EXECUTION --- ###
    def run_deplot_stage(self, test_file_path, output_deplot_path):
        """
        Runs ONLY the DePlot stage to generate data tables from images.
        Saves results to an intermediate JSONL file.
        """
        print(f"--- Starting DePlot Stage for: {test_file_path} ---")
        test_data = self.load_test_file(test_file_path)
        if not test_data:
            return

        ckpt_index = 0
        if os.path.exists(output_deplot_path):
            try:
                ckpt_index = len(self.load_test_file(output_deplot_path, mode='r'))
                if ckpt_index > 0:
                    print(f"DePlot progress file found. Resuming from sample {ckpt_index}...")
            except Exception as e:
                print(f"Could not read DePlot checkpoint, starting from 0. Error: {e}")
        else:
            print("No DePlot progress file found. Starting from sample 0.")

        deplot_prompt = "Generate the underlying data table for this chart:"
        deplot_task_prefix = "<GRAPH_TO_TEXT>"

        for i in tqdm(range(ckpt_index, len(test_data)), desc="Stage 1: DePlot"):
            sample = test_data[i]
            im_path = os.path.join(self.image_root, sample['imgname'])
            result_entry = copy.deepcopy(sample)

            try:
                image = Image.open(im_path).convert('RGB')
                inputs = self.deplot_processor(images=image, text=deplot_prompt, return_tensors="pt").to("cuda")

                with torch.inference_mode():
                    generated_ids = self.deplot_model.generate(**inputs, max_new_tokens=1024)

                generated_text = self.deplot_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
                result_entry['deplot_table'] = generated_text.replace(deplot_task_prefix, "").strip()

            except (KeyError, TypeError) as e:
                print(f"Skipping item {i}. Reason: Key error '{e}'. Check JSON format.")
            except FileNotFoundError:
                print(f"Skipping item {i}. Reason: Image file not found at {im_path}")
            except Exception as e:
                print(f"ERROR processing item {im_path} in DePlot stage: {e}")

            # Save progress immediately
            self.save_jsonl(output_deplot_path, [result_entry], mode='a+')

        print(f"--- DePlot Stage Complete. Results saved to: {output_deplot_path} ---")

    # ### --- STAGE 2: QWEN EXECUTION (MODIFIED) --- ###
    def run_qwen_stage(self, deplot_file_path, output_qwen_path, run_name):
        """
        Runs the Qwen (Base model) comparison using pre-computed DePlot results.
        """
        print(f"--- Starting Qwen Stage for: {run_name} ---")
        print(f"Loading DePlot results from: {deplot_file_path}")

        deplot_results = self.load_test_file(deplot_file_path, mode='r')
        if not deplot_results:
            print(f"No DePlot results found at {deplot_file_path}. Cannot run Qwen stage.")
            return

        ckpt_index = 0
        if os.path.exists(output_qwen_path):
            try:
                ckpt_index = len(self.load_test_file(output_qwen_path, mode='r'))
                if ckpt_index > 0:
                    print(f"Qwen progress file found. Resuming from sample {ckpt_index}...")
            except Exception as e:
                print(f"Could not read Qwen checkpoint, starting from 0. Error: {e}")
        else:
            print("No Qwen progress file found. Starting from sample 0.")

        hint = 'The key information in the chart has been extracted as below:\n{}\n'
        all_results = []
        if ckpt_index > 0:
            print("Loading previous Qwen results for final summary...")
            all_results = self.load_test_file(output_qwen_path, mode='r')[:ckpt_index]

        for i in tqdm(range(ckpt_index, len(deplot_results)), desc="Stage 2: Qwen (Base Model)"):
            sample = deplot_results[i]
            im_path = os.path.join(self.image_root, sample['imgname'])

            # Use a deepcopy to ensure original sample data is preserved
            result_entry = copy.deepcopy(sample)

            try:
                question = sample['question']
                ground_truth = str(sample['label'])
                deplot_table = sample.get('deplot_table', '[Data table not available]') # Use .get for safety

                question_with_hint = f"{hint.format(deplot_table)}\nQuestion: {question}"

                image = Image.open(im_path).convert('RGB')

                # --- BASE MODEL (CoT) ---
                start_time = time.time()
                base_reasoning = self.run_cot(image, question_with_hint)
                base_time = time.time() - start_time
                base_answer = self.extract_final_answer(base_reasoning)
                base_correct = self.calculate_relaxed_accuracy(base_answer, ground_truth)

                # --- GoT (REMOVED) ---

                result_entry.update({
                    "base_reasoning": base_reasoning, "base_answer": base_answer,
                    "base_time_s": base_time, "base_is_correct": base_correct,
                })

            except FileNotFoundError:
                print(f"Skipping item {i}. Reason: Image file not found at {im_path}")
                continue
            except Exception as e:
                print(f"ERROR processing item {im_path} in Qwen stage: {e}")
                result_entry.update({"base_answer": f"Error: {e}", "base_reasoning": f"Error: {e}"})

            # Save Qwen progress immediately
            self.save_jsonl(output_qwen_path, [result_entry], mode='a+')
            all_results.append(result_entry)

        # --- Print Final Summary ---
        self._print_summary_for_type(run_name, all_results)

    # ### --- Helper: Final Summary Printer (MODIFIED) --- ###
    def _print_summary_for_type(self, title, results_list):
        """Helper function to print a formatted summary."""
        print("\n" + "="*50)
        print(f"--- FINAL SUMMARY FOR: {title} ---")

        if not results_list:
            print("No samples processed for this type.")
            print("="*50)
            return

        total_samples = len(results_list)

        # Calculate Base metrics
        base_correct_list = [r['base_is_correct'] for r in results_list if 'base_is_correct' in r]
        base_time_list = [r['base_time_s'] for r in results_list if 'base_time_s' in r]

        base_accuracy = np.mean(base_correct_list) * 100 if base_correct_list else 0.0
        avg_base_time = np.mean(base_time_list) if base_time_list else 0.0

        print(f"Total samples processed: {total_samples}")
        print("-" * 30)
        print(f"Base Model (CoT) Accuracy: {base_accuracy:.2f}%")
        print(f"Average Base Model Time: {avg_base_time:.2f} s")
        print("="*50)

In [13]:
from transformers import AutoTokenizer, AutoModelForCausalLM
!pip install transformers==4.30.2 accelerate einops pillow bitsandbytes tiktoken transformers_stream_generator
class QwenChartBenchTester(ChartBenchTester):
    """
    This class now runs in two stages:
    1. run_deplot_stage: Generates data tables from images.
    2. run_qwen_stage: Runs Base model (CoT) with a math-solver step.
    """

    # --- Part 1: Model Loading & Inference ---

    def load_model(self):
        """Loads both DePlot and Qwen models."""
        print("Loading DePlot model...")
        self.deplot_processor = AutoProcessor.from_pretrained(CKPT_PATH)
        self.deplot_model = Pix2StructForConditionalGeneration.from_pretrained(CKPT_PATH).to("cuda")
        print("DePlot model loaded.")

        print("Loading Qwen-VL model...")
        self.qwen_processor = AutoProcessor.from_pretrained(QWEN_MODEL_PATH, trust_remote_code=True)
        self.qwen_model = AutoModelForVision2Seq.from_pretrained(
            QWEN_MODEL_PATH,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            device_map="auto",
            trust_remote_code=True
        )
        print("Qwen-VL model loaded.")

    def get_model_response(self, image: Image.Image, prompt_text: str) -> str:
        """Helper function to run inference with Qwen-VL."""
        messages = [
            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
        ]
        text = self.qwen_processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.qwen_processor(
            text=text, images=[image], return_tensors="pt"
        )
        inputs = inputs.to(self.qwen_model.device, torch.float16)

        with torch.inference_mode():
            output_ids = self.qwen_model.generate(**inputs, max_new_tokens=512, do_sample=False)

        response_ids = output_ids[:, inputs.input_ids.shape[1]:]
        response = self.qwen_processor.decode(response_ids[0], skip_special_tokens=True).strip()
        return response

    # --- ADDED: "Math Solver" Simulation ---
    def _solve_math_expression(self, expression_str: str) -> str:
        """
        Simulates a 'math solver small llm' by safely evaluating
        a mathematical expression.
        """
        try:
            # Sanitize: Allow only digits, dots, +, -, *, /, (, )
            allowed_chars = "0123456789.+-*/() "
            # Check if any character in the expression is NOT in the allowed list
            if any(char not in allowed_chars for char in expression_str):
                # Remove whitespace from the offending string for a cleaner error
                offending_chars = "".join(set(c for c in expression_str if c not in allowed_chars))
                return f"[Error: Invalid characters in expression: {offending_chars}]"

            # WARNING: eval() is used here as a placeholder for a "math solver".
            # In a real-world scenario, a safer parsing library should be used.
            result = eval(expression_str)
            return str(result)
        except Exception as e:
            return f"[Error: Could not solve '{expression_str}'. Details: {e}]"


    # --- "BASE MODEL" (CoT) - MODIFIED for Math Solver ---
    def run_cot(self, image: Image.Image, question: str) -> str:
        """
        Runs the base Chain-of-Thought (CoT) model, with a simulated
        call to a 'math solver llm' if a calculation is detected.
        """

        # --- Step 1: Initial Reasoning Prompt ---
        # Ask the model to reason and output a specific [CALCULATE: ...] tag
        # if math is needed.
        prompt_step_1 = f"""You are an expert chart analyst.
Your task is to answer a question about the provided chart.
Analyze the chart step-by-step.
If you need to perform a mathematical calculation to answer the question,
output ONLY the tag `[CALCULATE: expression]`.
For example: `[CALCULATE: 100 - (25 + 30)]`.
Do not solve it yourself.
If no calculation is needed, or if you can answer directly, provide your
full reasoning and then end with: "The final answer is: [Your Answer]".

Question: {question}
Let's think step by step:
"""
        initial_reasoning = self.get_model_response(image, prompt_step_1)

        # --- Step 2: Check for Math Tag ---
        # Use re.DOTALL to make '.' match newlines, in case the tag is split
        calc_match = re.search(r"\[CALCULATE:\s*(.*?)\s*\]", initial_reasoning, re.IGNORECASE | re.DOTALL)

        if calc_match:
            # --- Step 3: Call "Math Solver" (Simulation) ---
            expression_to_solve = calc_match.group(1).strip()
            print(f"  [run_cot] Detected math expression: {expression_to_solve}")
            math_result = self._solve_math_expression(expression_to_solve)
            print(f"  [run_cot] Math result: {math_result}")

            if "[Error:" in math_result:
                # If math solver fails, return the error
                return f"Step 1 Reasoning: {initial_reasoning}\nMath Solver Error: {math_result}\nThe final answer is: [Math Error]"

            # --- Step 4: Final Answer Prompt ---
            # Feed the result back to the model to get the final answer
            prompt_step_2 = f"""You are an expert chart analyst.
You were asked the following question about a chart:
{question}

Your initial analysis determined that a calculation was needed.
Calculation: `{expression_to_solve}`
Result: `{math_result}`

Now, use this result to provide the final answer.
Provide your reasoning and end with: "The final answer is: [Your Answer]".

Reasoning using the calculation:
"""
            final_reasoning = self.get_model_response(image, prompt_step_2)

            # Combine the reasoning for logging
            combined_reasoning = f"--- Step 1: Math Detection ---\n{initial_reasoning}\n\n--- Step 2: Math Solver ---\nInput: {expression_to_solve}\nOutput: {math_result}\n\n--- Step 3: Final Answer ---\n{final_reasoning}"
            return combined_reasoning
        else:
            # No calculation tag was found.
            # Return the initial reasoning as-is.
            return initial_reasoning

    # --- GoT Methods (REMOVED) ---

    def extract_final_answer(self, reasoning_text: str) -> str:
        """Extracts the final answer from the reasoning text."""
        match = re.search(r"The final answer is:\s*(.*)", reasoning_text, re.IGNORECASE | re.DOTALL)
        if match:
            return match.group(1).strip()
        else:
            # Fallback: return the last non-empty line
            lines = [line.strip() for line in reasoning_text.split('\n') if line.strip()]
            return lines[-1] if lines else ""

    def calculate_relaxed_accuracy(self, predicted: str, ground_truth: str) -> bool:
        """Calculates relaxed accuracy (substring match)."""
        pred_norm = re.sub(r"[^\w\s\.]", "", str(predicted)).lower().strip()
        gt_norm = re.sub(r"[^\w\s\.]", "", str(ground_truth)).lower().strip()
        if not pred_norm or not gt_norm:
            return False
        return (pred_norm in gt_norm) or (gt_norm in pred_norm)

    # --- Part 2: Evaluation Stages ---

    # ### --- STAGE 1: DePLOT EXECUTION --- ###
    def run_deplot_stage(self, test_file_path, output_deplot_path):
        """
        Runs ONLY the DePlot stage to generate data tables from images.
        Saves results to an intermediate JSONL file.
        """
        print(f"--- Starting DePlot Stage for: {test_file_path} ---")
        test_data = self.load_test_file(test_file_path)
        if not test_data:
            return

        ckpt_index = 0
        if os.path.exists(output_deplot_path):
            try:
                ckpt_index = len(self.load_test_file(output_deplot_path, mode='r'))
                if ckpt_index > 0:
                    print(f"DePlot progress file found. Resuming from sample {ckpt_index}...")
            except Exception as e:
                print(f"Could not read DePlot checkpoint, starting from 0. Error: {e}")
        else:
            print("No DePlot progress file found. Starting from sample 0.")

        deplot_prompt = "Generate the underlying data table for this chart:"
        deplot_task_prefix = "<GRAPH_TO_TEXT>"

        for i in tqdm(range(ckpt_index, len(test_data)), desc="Stage 1: DePlot"):
            sample = test_data[i]
            im_path = os.path.join(self.image_root, sample['imgname'])
            result_entry = copy.deepcopy(sample)

            try:
                image = Image.open(im_path).convert('RGB')
                inputs = self.deplot_processor(images=image, text=deplot_prompt, return_tensors="pt").to("cuda")

                with torch.inference_mode():
                    generated_ids = self.deplot_model.generate(**inputs, max_new_tokens=1024)

                generated_text = self.deplot_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
                result_entry['deplot_table'] = generated_text.replace(deplot_task_prefix, "").strip()

            except (KeyError, TypeError) as e:
                print(f"Skipping item {i}. Reason: Key error '{e}'. Check JSON format.")
            except FileNotFoundError:
                print(f"Skipping item {i}. Reason: Image file not found at {im_path}")
            except Exception as e:
                print(f"ERROR processing item {im_path} in DePlot stage: {e}")

            # Save progress immediately
            self.save_jsonl(output_deplot_path, [result_entry], mode='a+')

        print(f"--- DePlot Stage Complete. Results saved to: {output_deplot_path} ---")

    # ### --- STAGE 2: QWEN EXECUTION (Unchanged) --- ###
    # This method is unchanged, but its call to `self.run_cot`
    # will now use the new math-solver logic.
    def run_qwen_stage(self, deplot_file_path, output_qwen_path, run_name):
        """
        Runs the Qwen (Base model) comparison using pre-computed DePlot results.
        """
        print(f"--- Starting Qwen Stage for: {run_name} ---")
        print(f"Loading DePlot results from: {deplot_file_path}")

        deplot_results = self.load_test_file(deplot_file_path, mode='r')
        if not deplot_results:
            print(f"No DePlot results found at {deplot_file_path}. Cannot run Qwen stage.")
            return

        ckpt_index = 0
        if os.path.exists(output_qwen_path):
            try:
                ckpt_index = len(self.load_test_file(output_qwen_path, mode='r'))
                if ckpt_index > 0:
                    print(f"Qwen progress file found. Resuming from sample {ckpt_index}...")
            except Exception as e:
                print(f"Could not read Qwen checkpoint, starting from 0. Error: {e}")
        else:
            print("No Qwen progress file found. Starting from sample 0.")

        hint = 'The key information in the chart has been extracted as below:\n{}\n'
        all_results = []
        if ckpt_index > 0:
            print("Loading previous Qwen results for final summary...")
            all_results = self.load_test_file(output_qwen_path, mode='r')[:ckpt_index]

        for i in tqdm(range(ckpt_index, len(deplot_results)), desc="Stage 2: Qwen (Base Model)"):
            sample = deplot_results[i]
            im_path = os.path.join(self.image_root, sample['imgname'])

            # Use a deepcopy to ensure original sample data is preserved
            result_entry = copy.deepcopy(sample)

            try:
                question = sample['query']
                ground_truth = str(sample['label'])
                deplot_table = sample.get('deplot_table', '[Data table not available]') # Use .get for safety

                question_with_hint = f"{hint.format(deplot_table)}\nQuestion: {question}"

                image = Image.open(im_path).convert('RGB')

                # --- BASE MODEL (CoT) ---
                # This call now invokes the new 2-step math logic
                start_time = time.time()
                base_reasoning = self.run_cot(image, question_with_hint)
                base_time = time.time() - start_time
                base_answer = self.extract_final_answer(base_reasoning)
                base_correct = self.calculate_relaxed_accuracy(base_answer, ground_truth)

                # --- GoT (REMOVED) ---

                result_entry.update({
                    "base_reasoning": base_reasoning, "base_answer": base_answer,
                    "base_time_s": base_time, "base_is_correct": base_correct,
                })

            except FileNotFoundError:
                print(f"Skipping item {i}. Reason: Image file not found at {im_path}")
                continue
            except Exception as e:
                print(f"ERROR processing item {im_path} in Qwen stage: {e}")
                result_entry.update({"base_answer": f"Error: {e}", "base_reasoning": f"Error: {e}"})

            # Save Qwen progress immediately
            self.save_jsonl(output_qwen_path, [result_entry], mode='a+')
            all_results.append(result_entry)

        # --- Print Final Summary ---
        self._print_summary_for_type(run_name, all_results)

    # ### --- Helper: Final Summary Printer (Unchanged) --- ###
    def _print_summary_for_type(self, title, results_list):
        """Helper function to print a formatted summary."""
        print("\n" + "="*50)
        print(f"--- FINAL SUMMARY FOR: {title} ---")

        if not results_list:
            print("No samples processed for this type.")
            print("="*50)
            return

        total_samples = len(results_list)

        # Calculate Base metrics
        base_correct_list = [r['base_is_correct'] for r in results_list if 'base_is_correct' in r]
        base_time_list = [r['base_time_s'] for r in results_list if 'base_time_s' in r]

        base_accuracy = np.mean(base_correct_list) * 100 if base_correct_list else 0.0
        avg_base_time = np.mean(base_time_list) if base_time_list else 0.0

        print(f"Total samples processed: {total_samples}")
        print("-" * 30)
        print(f"Base Model (CoT) Accuracy: {base_accuracy:.2f}%")
        print(f"Average Base Model Time: {avg_base_time:.2f} s")
        print("="*50)


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [14]:
def main():
    # --- 1. Initialize Tester and Load Models ---
    tester = QwenChartBenchTester(
        test_index=None, # Will be set per-run
        sys_prompt_acc=sys_prompt['blip2 style'],
        sys_prompt_nqa=sys_prompt['chartqa']
    )
    tester.load_model()
    tester.reset_image_root(IMG_ROOT)

    # --- 2. Define the two test runs ---
    test_runs = [
        {
            "name": "Augmented Test",
            "input_file": TEST_AUGMENTED_FILE,
            "deplot_output_file": DEPLOT_OUTPUT_AUGMENTED,
            "qwen_output_file": QWEN_OUTPUT_AUGMENTED
        },
        {
            "name": "Human Test",
            "input_file": TEST_HUMAN_FILE,
            "deplot_output_file": DEPLOT_OUTPUT_HUMAN,
            "qwen_output_file": QWEN_OUTPUT_HUMAN
        }
    ]

    # # --- 3. Execute STAGE 1 (DePlot) for both ---
    # print("\n" + "#"*70)
    # print("### STARTING STAGE 1: DePLOT (Data Table Generation) ###")
    # print("#"*70 + "\n")
    # for run in test_runs:
    #     print(f"--- Running DePlot for: {run['name']} ---")
    #     tester.test_index = run['input_file'] # Set the correct input
    #     tester.run_deplot_stage(run['input_file'], run['deplot_output_file'])
    #     print(f"--- Finished DePlot for: {run['name']} ---")

    # --- 4. Execute STAGE 2 (Qwen) for both ---
    print("\n" + "#"*70)
    print("### STARTING STAGE 2: Qwen (Base Model) ###")
    print("#"*70 + "\n")
    for run in test_runs:
        print(f"--- Running Qwen for: {run['name']} ---")
        # Pass the intermediate DePlot file, the final save path, and the run name
        tester.run_qwen_stage(run['deplot_output_file'], run['qwen_output_file'], run['name'])
        print(f"\n### FINISHED RUN: {run['name']} ###")

    print("\nAll evaluations finished.")

if __name__ == "__main__":
    main()

Loading DePlot model...


ImportError: cannot import name 'is_flax_available' from 'transformers.utils' (/home/g2/.pyenv/versions/3.10.13/lib/python3.10/site-packages/transformers/utils/__init__.py)

In [None]:
!pip uninstall transformers tokenizers -y
!pip install -U git+https://github.com/huggingface/transformers
!pip install accelerate tokenizers