In [2]:
!pip install transformers torch pillow tqdm sentencepiece
import zipfile
import os
import torch
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
from PIL import Image
import os
import json
import re
from tqdm import tqdm
import warnings
import matplotlib.pyplot as plt



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [20]:
# Cell 3: Configuration and Setup

# --- Configuration ---

# *** IMPORTANT: PLEASE UPDATE THESE PATHS ***
IMAGE_DIR = "/home/g2/Downloads/png"
# --- DATASET FILE PATHS ---
# Cell 3: Configuration and Setup (This cell is correct, no changes)

# --- DATASET FILE PATHS ---

# 1. Base files (as .json containing a LIST)
BASE_DATASET_FILES = {
    "human": "/home/g2/Downloads/test_human.json",
    "augmented": "/home/g2/Downloads/test_augmented.json"
}

# 2. DePlot files (as .jsonl)
DEPLOT_DATA_FILES = {
    "human": "/home/g2/Downloads/deplot_human_output.jsonl",
    "augmented": "/home/g2/Downloads/deplot_augmented_output.jsonl"
}

# --- MODEL 1: MAIN QA MODEL (Pix2Struct) ---
MAIN_MODEL_ID = "google/pix2struct-base"

# --- MODEL 2: CHART CLASSIFIER (DinoV2) ---
CLASSIFIER_MODEL_ID = "/home/g2/Chart Classifier/chartqa-dinov2-finetuned"
CLASSIFIER_LABELS_PATH = "/home/g2/Chart Classifier/labels.json"


# --- NEW COMBINED PROMPT TEMPLATE ---
FINAL_COT_PROMPT = (
    "Let's think step by step to solve this problem.\n"
    "Here is the data from the chart, linearized into a table format:\n"
    "--- START OF CHART DATA ---\n"
    "{deplot_data}\n"
    "--- END OF CHART DATA ---\n\n"
    "The image is a **{chart_type}**.\n\n"
    "Using this chart data, chart type, AND the provided image, answer the following question.\n"
    "First, identify the relevant numbers from the provided chart data or the image. "
    "Second, perform the necessary calculations. "
    "Finally, state the final answer clearly.\n\n"
    "Question: {question}"
)

# --- Device Setup ---
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


In [4]:
# Cell 4: Load BOTH Models, Processors, and Labels

from transformers import (
    Pix2StructProcessor, 
    Pix2StructForConditionalGeneration,
    AutoImageProcessor, 
    AutoModelForImageClassification
)

# --- 1. Load Main QA Model (Pix2Struct) ---
print(f"Loading main QA processor: {MAIN_MODEL_ID}...")
try:
    main_processor = Pix2StructProcessor.from_pretrained(MAIN_MODEL_ID)
    print("Main processor loaded successfully.")
except Exception as e:
    print(f"Error loading main processor: {e}")
    main_processor = None

print(f"Loading main QA model: {MAIN_MODEL_ID}...")
try:
    main_model = Pix2StructForConditionalGeneration.from_pretrained(MAIN_MODEL_ID).to(device)
    main_model.eval()
    print("Main QA model loaded successfully and set to eval mode.")
except Exception as e:
    print(f"Error loading main model: {e}")
    main_model = None

# --- 2. Load Classifier Model (DinoV2) ---
print(f"Loading classifier processor: {CLASSIFIER_MODEL_ID}...")
try:
    classifier_processor = AutoImageProcessor.from_pretrained(CLASSIFIER_MODEL_ID)
    print("Classifier processor loaded successfully.")
except Exception as e:
    print(f"Error loading classifier processor: {e}")
    classifier_processor = None

print(f"Loading classifier model: {CLASSIFIER_MODEL_ID}...")
try:
    classifier_model = AutoModelForImageClassification.from_pretrained(CLASSIFIER_MODEL_ID).to(device)
    classifier_model.eval()
    print("Classifier model loaded successfully and set to eval mode.")
except Exception as e:
    print(f"Error loading classifier model: {e}")
    classifier_model = None

# --- 3. Load Classifier Labels ---
print(f"Loading classifier labels: {CLASSIFIER_LABELS_PATH}...")
try:
    with open(CLASSIFIER_LABELS_PATH, "r") as f:
        classifier_labels = json.load(f) # This is expected to be a list
    print(f"Labels loaded successfully. Found {len(classifier_labels)} classes.")
except FileNotFoundError:
    print(f"Error: Labels file not found at {CLASSIFIER_LABELS_PATH}")
    classifier_labels = None
except Exception as e:
    print(f"Error loading labels file: {e}")
    classifier_labels = None

# --- 4. Final Check ---
if not all([main_processor, main_model, classifier_processor, classifier_model, classifier_labels]):
    print("CRITICAL ERROR: Failed to load one or more components.")
    print("Please check all paths and file permissions. Evaluation cannot proceed.")

Loading main QA processor: google/pix2struct-base...
Main processor loaded successfully.
Loading main QA model: google/pix2struct-base...


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Main QA model loaded successfully and set to eval mode.
Loading classifier processor: /home/g2/Chart Classifier/chartqa-dinov2-finetuned...
Classifier processor loaded successfully.
Loading classifier model: /home/g2/Chart Classifier/chartqa-dinov2-finetuned...
Classifier model loaded successfully and set to eval mode.
Loading classifier labels: /home/g2/Chart Classifier/labels.json...
Labels loaded successfully. Found 4 classes.


In [21]:
# Cell 5: Helper Functions (CORRECTED)

import json
import os
from decimal import Decimal

# --- DATA LOADING FUNCTIONS ---

def load_jsonl_file(filepath):
    """Loads a .jsonl file, returning a list of dictionaries."""
    data = []
    if not os.path.exists(filepath):
        print(f"Error: File not found at {filepath}")
        return None
    try:
        with open(filepath, 'r') as f:
            for line in f:
                if line.strip():
                    data.append(json.loads(line))
        print(f"Successfully loaded {len(data)} items from {filepath} (JSONL)")
        return data
    except Exception as e:
        print(f"Error reading {filepath}: {e}")
        return None

def load_json_list_file(filepath):
    """Loads a .json file containing a single list."""
    if not os.path.exists(filepath):
        print(f"Error: File not found at {filepath}")
        return None
    try:
        with open(filepath, 'r') as f:
            data = json.load(f)
        if not isinstance(data, list):
            print(f"Error: File {filepath} does not contain a JSON list.")
            return None
        print(f"Successfully loaded {len(data)} items from {filepath} (JSON List)")
        return data
    except Exception as e:
        print(f"Error reading {filepath}: {e}")
        return None

def merge_datasets(base_data, deplot_data_list):
    """
    Performs a 1-to-1 parallel merge of two datasets.
    """
    if len(base_data) != len(deplot_data_list):
        print(f"Error: Mismatch in dataset lengths. Base data has {len(base_data)} items, DePlot has {len(deplot_data_list)}.")
        return []
    
    print(f"Performing 1-to-1 parallel merge on {len(base_data)} items...")
    merged_data = []
    
    for i in range(len(base_data)):
        base_item = base_data[i]
        deplot_item = deplot_data_list[i]
        
        # Start with all data from the base file (imgname, query, label)
        merged_item = base_item.copy()
        
        # --- THIS IS THE FIX ---
        # Look for "deplot_table" from your .jsonl file
        if "deplot_table" in deplot_item:
            merged_item["deplot_data"] = deplot_item["deplot_table"] # Create a new key "deplot_data"
        else:
            merged_item["deplot_data"] = None
            print(f"Warning: 'deplot_table' key missing for item {i} (imgname: {base_item.get('imgname')}).")
        # -----------------------

        merged_data.append(merged_item)
        
    print(f"Successfully merged {len(merged_data)} items.")
    return merged_data

# --- ACCURACY FUNCTION ---

def calculate_relaxed_accuracy(prediction_str, ground_truth_str):
    """
    Calculates "relaxed" accuracy.
    """
    prediction_str = str(prediction_str).strip().lower()
    ground_truth_str = str(ground_truth_str).strip().lower()

    if prediction_str == ground_truth_str:
        return True

    try:
        pred_num = Decimal(prediction_str)
        gt_num = Decimal(ground_truth_str)
        if pred_num == gt_num:
            return True
    except Exception:
        pass

    return False

In [22]:
# Cell 6: Evaluation Function (CORRECTED)

def evaluate_model(
    main_model, main_processor, 
    classifier_model, classifier_processor, classifier_labels,
    dataset, image_dir, dataset_name="evaluation"
):
    """
    Runs the full evaluation pipeline with all correct keys.
    """
    
    if not all([main_processor, main_model, classifier_processor, classifier_model, classifier_labels]):
        print(f"Skipping evaluation for {dataset_name}: Not all components are loaded.")
        return 0.0

    progress_file = f"evaluation_progress_{dataset_name}.jsonl"
    processed_imgnames = set()
    total_correct = 0
    total_predictions = 0

    # 1. Load existing progress
    if os.path.exists(progress_file):
        try:
            with open(progress_file, 'r') as f:
                for line in f:
                    if line.strip():
                        data = json.loads(line)
                        processed_imgnames.add(data['imgname'])
                        if data['is_correct']:
                            total_correct += 1
                        total_predictions += 1
            print(f"Resuming from {total_predictions} processed items for {dataset_name}...")
        except Exception as e:
            print(f"Error reading progress file {progress_file}: {e}. Starting fresh.")
            processed_imgnames = set()
            total_correct = 0
            total_predictions = 0

    print(f"Starting evaluation... {total_predictions}/{len(dataset)} items already processed.")

    # 2. Loop through remaining items
    try:
        with open(progress_file, 'a') as f_progress:
            for item in tqdm(dataset, desc=f"Evaluating {dataset_name}"):
                try:
                    imgname = item['imgname']
                    if imgname in processed_imgnames:
                        continue 

                    # --- THIS IS THE FIX ---
                    # Use the correct keys from your data files
                    ground_truth_str = str(item['label'])
                    question = item['query'] 
                    deplot_data = item['deplot_data'] # This key is created by our merge function
                    # -----------------------
                    
                    image_path = os.path.join(image_dir, imgname)
                    
                    if deplot_data is None:
                        print(f"Skipping {imgname}: 'deplot_table' was missing from .jsonl file.")
                        continue

                    # 2. Load image
                    try:
                        image = Image.open(image_path).convert("RGB")
                    except FileNotFoundError:
                        print(f"Warning: Image file not found {image_path}. Skipping item.")
                        continue
                    
                    # 3. PIPELINE STEP 1: CLASSIFY CHART TYPE
                    clf_inputs = classifier_processor(images=image, return_tensors="pt").to(device)
                    with torch.no_grad():
                        logits = classifier_model(**clf_inputs).logits
                        pred_id = logits.argmax(-1).item()
                    predicted_chart_type = classifier_labels[pred_id]
                    
                    # 4. PIPELINE STEP 2: FORMAT PROMPT
                    formatted_prompt = FINAL_COT_PROMPT.format(
                        deplot_data=deplot_data,
                        chart_type=predicted_chart_type,
                        question=question
                    )
                    
                    # 5. PIPELINE STEP 3: RUN MAIN QA MODEL
                    inputs = main_processor(images=image, text=formatted_prompt, return_tensors="pt").to(device)
                    with torch.no_grad():
                        generated_ids = main_model.generate(**inputs, max_new_tokens=512)
                    
                    generated_text = main_processor.decode(generated_ids[0], skip_special_tokens=True)

                    # 6. CALCULATE ACCURACY
                    is_correct = calculate_relaxed_accuracy(generated_text, ground_truth_str)
                    
                    if is_correct:
                        total_correct += 1
                    total_predictions += 1
                    
                    # 7. SAVE PROGRESS
                    result_record = {
                        "imgname": imgname,
                        "used_prompt": formatted_prompt,
                        "predicted_chart_type": predicted_chart_type,
                        "prediction_qa": generated_text,
                        "ground_truth_qa": ground_truth_str,
                        "is_correct": is_correct
                    }
                    f_progress.write(json.dumps(result_record) + "\n")
                    f_progress.flush()
                    
                    processed_imgnames.add(imgname)

                except KeyError as e:
                    print(f"Skipping {item.get('imgname')}: Missing expected data key: {e}")
                    continue
                except Exception as e:
                    print(f"Error during evaluation for item {item.get('imgname')}: {e}")
                    continue
                    
    except KeyboardInterrupt:
        print("Evaluation interrupted by user. Progress saved.")

    # 8. Final Accuracy Calculation
    if total_predictions == 0:
        print("No items were processed.")
        return 0.0

    accuracy = (total_correct / total_predictions) * 100
    print(f"Evaluation complete for {dataset_name}.")
    print(f"Processed: {total_predictions} items.")
    print(f"Correct: {total_correct}")
    print(f"Accuracy: {accuracy:.2f}%")

    return accuracy

In [23]:
# Cell 7: Run Evaluation (CORRECTED)

print("--- Starting Evaluation ---")
accuracies = {}

# 1. Load datasets
datasets = {}
for name in BASE_DATASET_FILES.keys():
    print(f"\nLoading dataset: {name}")
    base_data = load_json_list_file(BASE_DATASET_FILES[name])
    deplot_data = load_jsonl_file(DEPLOT_DATA_FILES[name])
    
    if base_data is None or deplot_data is None:
        print(f"Failed to load data for {name}. Skipping.")
        continue
        
    print(f"Merging {name} dataset...")
    merged_data = merge_datasets(base_data, deplot_data)
    
    if merged_data:
        # --- THIS IS THE FIX ---
        # Filter for items that have a "query" key
        filtered_data = [item for item in merged_data if item.get("query")]
        # -----------------------
        
        datasets[name] = filtered_data
        print(f"Dataset {name} loaded with {len(filtered_data)} items.") # <-- This will now be > 0
    else:
        print(f"Failed to merge dataset {name}.")

print("\n--- All Models and Data Loaded, Starting Pipeline Inference ---")

# 2. Run evaluation on each loaded dataset
for dataset_name, dataset in datasets.items():
    print(f"\n--- Evaluating: {dataset_name.upper()} ---")
    
    if not dataset:
        print("Dataset is empty. Skipping.")
        accuracies[dataset_name] = 0.0
        continue
        
    acc = evaluate_model(
        main_model=main_model,
        main_processor=main_processor,
        classifier_model=classifier_model,
        classifier_processor=classifier_processor,
        classifier_labels=classifier_labels,
        dataset=dataset,
        image_dir=IMAGE_DIR,
        dataset_name=dataset_name
    )
    accuracies[dataset_name] = acc

print("\n--- Evaluation Finished ---")
print("Final Accuracies (with Chart Classifier + Detailed CoT Prompt):")
print(json.dumps(accuracies, indent=2))

--- Starting Evaluation ---

Loading dataset: human
Successfully loaded 1250 items from /home/g2/Downloads/test_human.json (JSON List)
Successfully loaded 1250 items from /home/g2/Downloads/deplot_human_output.jsonl (JSONL)
Merging human dataset...
Performing 1-to-1 parallel merge on 1250 items...
Successfully merged 1250 items.
Dataset human loaded with 1250 items.

Loading dataset: augmented
Successfully loaded 1250 items from /home/g2/Downloads/test_augmented.json (JSON List)
Successfully loaded 1250 items from /home/g2/Downloads/deplot_augmented_output.jsonl (JSONL)
Merging augmented dataset...
Performing 1-to-1 parallel merge on 1250 items...
Successfully merged 1250 items.
Dataset augmented loaded with 1250 items.

--- All Models and Data Loaded, Starting Pipeline Inference ---

--- Evaluating: HUMAN ---
Resuming from 0 processed items for human...
Starting evaluation... 0/1250 items already processed.


Evaluating human: 100%|██████████| 1250/1250 [2:30:19<00:00,  7.22s/it] 


Evaluation complete for human.
Processed: 625 items.
Correct: 0
Accuracy: 0.00%

--- Evaluating: AUGMENTED ---
Resuming from 0 processed items for augmented...
Starting evaluation... 0/1250 items already processed.


Evaluating augmented:   2%|▏         | 20/1250 [03:26<3:31:32, 10.32s/it]

Evaluation interrupted by user. Progress saved.
Evaluation complete for augmented.
Processed: 11 items.
Correct: 0
Accuracy: 0.00%

--- Evaluation Finished ---
Final Accuracies (with Chart Classifier + Detailed CoT Prompt):
{
  "human": 0.0,
  "augmented": 0.0
}





In [9]:
# Cell 8: Plot Results

# Use the 'results' dictionary generated in the previous cell
# No need to load from progress file again

if results:
    names = list(results.keys())
    values = list(results.values())

    plt.figure(figsize=(8, 5))
    # Define colors for different datasets if needed, e.g., {'human': '#4A90E2', 'augmented': '#FF6B6B'}
    colors = ['#4A90E2' for _ in names] # Default to blue for all if no specific colors defined

    bars = plt.bar(names, values, color=colors)

    plt.xlabel("Dataset Type", fontsize=12)
    plt.ylabel("Accuracy (%)", fontsize=12)
    # Update title to reflect that it might be complete or partial based on the previous cell's run
    plt.title(f"Model Accuracy: (with DePlot + CoT) - {'Complete' if all(name in results for name in BASE_DATASET_FILES.keys()) else 'Partial'} Results", fontsize=14)
    plt.ylim(0, 100)

    # Add accuracy numbers on top of bars
    for bar in bars:
        yval = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2.0, yval + 1, f'{yval:.2f}%', ha='center', va='bottom', fontsize=11)

    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)
    plt.show()
else:
    print("No results available to plot.")

NameError: name 'results' is not defined