# Chartqwen Error Bar Detection - Inference

This notebook performs inference using the fine-tuned Chartqwen model for error bar detection in scientific plots.

## Task:
- **Input**: Scientific plot image + data point coordinates (x, y)
- **Output**: Error bar distances (topBarPixelDistance, bottomBarPixelDistance)
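
As a rough illustration of the input and output shapes, the sketch below builds one input point and the matching output record. The field names come from this notebook; the numeric values and the `make_output_record` helper are hypothetical and exist only for this example.

```python
from typing import Dict


def make_output_record(point: Dict[str, float], top: float, bottom: float) -> Dict[str, float]:
    """Attach error bar distances (in pixels) to an input point. Hypothetical helper."""
    return {
        "x": point["x"],
        "y": point["y"],
        "topBarPixelDistance": top,
        "bottomBarPixelDistance": bottom,
    }


# Illustrative values only, not taken from the dataset
input_point = {"x": 120.5, "y": 240.0}
record = make_output_record(input_point, top=12.0, bottom=9.5)
print(record)
```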

## Model:
- **Base**: Qwen2.5-VL-7B-Instruct
- **Fine-tuned**: Sayeem26s/Chartqwen
- **Method**: LoRA adapter loaded on top of base model

## 0. Install Required Packages

In [None]:
# Install libraries for VLM inference
!pip install transformers accelerate bitsandbytes -q
!pip install peft -q
!pip install pandas pillow tqdm -q
!pip install qwen-vl-utils -q
print("All libraries installed successfully!")

## 1. Setup and Imports

In [None]:
# Import Libraries
import torch
import pandas as pd
import os
import gc
import json
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from typing import List, Dict, Tuple, Optional
import warnings
warnings.filterwarnings('ignore')

# For image processing
from PIL import Image
from tqdm import tqdm

# Transformers
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    BitsAndBytesConfig,
)

# PEFT for LoRA
from peft import PeftModel

# Check GPU
print(f"GPU Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU Memory: {gpu_mem:.1f} GB")

print("\nLibraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")

## 2. Configuration and Data Paths

In [None]:
# Data paths (Kaggle format)
BASE_PATH = "/kaggle/input/graph-plots"
TEST_IMAGES = os.path.join(BASE_PATH, "Test", "images")
TEST_INPUT_LABELS = os.path.join(BASE_PATH, "Test", "test_labels")  # Input: x,y only
TEST_GROUND_TRUTH = os.path.join(BASE_PATH, "Test", "labels")       # Ground truth: with error bars

# Model configuration
BASE_MODEL = "Qwen/Qwen2.5-VL-7B-Instruct"
FINETUNED_MODEL = "Sayeem26s/Chartqwen"  # Your fine-tuned model

# Inference settings
IMAGE_MAX_SIZE = 768
MAX_NEW_TOKENS = 2048
TEMPERATURE = 0.0  # Nominal only: decoding is greedy (do_sample=False), so temperature has no effect

print(f"Base Model: {BASE_MODEL}")
print(f"Fine-tuned Model: {FINETUNED_MODEL}")
print(f"Test images: {TEST_IMAGES}")
print(f"Test input labels: {TEST_INPUT_LABELS}")
print(f"Ground truth: {TEST_GROUND_TRUTH}")

## 3. Load Fine-tuned Model with LoRA Adapter

In [None]:
def load_chartqwen_model():
    """
    Load the base model in FP16, apply the Chartqwen LoRA adapter,
    and merge the adapter weights into the base model.

    Returns:
        (model, processor) tuple ready for inference
    """
    print("\n" + "="*60)
    print("LOADING CHARTQWEN MODEL")
    print("="*60)
    
    print(f"\nLoading base model: {BASE_MODEL}")
    print("This may take 2-3 minutes...")
    
    # Load processor
    processor = AutoProcessor.from_pretrained(FINETUNED_MODEL, trust_remote_code=True)
    print("Processor loaded!")
    
    # Load base model with FP16
    base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )
    
    print(f"Loading LoRA adapter from: {FINETUNED_MODEL}")
    
    # Load fine-tuned LoRA adapter
    model = PeftModel.from_pretrained(
        base_model,
        FINETUNED_MODEL,
        torch_dtype=torch.float16,
    )
    
    # Merge LoRA weights for faster inference
    model = model.merge_and_unload()
    
    model.eval()
    print("Model loaded with LoRA adapter!")
    
    # Print memory usage
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1e9
        print(f"GPU Memory Used: {allocated:.2f} GB")
    
    return model, processor


# Load the model
model, processor = load_chartqwen_model()

print("\n" + "="*60)
print("MODEL READY FOR INFERENCE")
print("="*60)
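
To see why merging the adapter does not change the model's outputs, here is a minimal NumPy sketch of the usual LoRA formulation, where an adapted layer computes `W x + (alpha / r) * B A x`. Folding the low-rank update into `W` gives the same result with a single matrix multiply. The shapes, the `alpha` and `r` values, and the random weights are illustrative assumptions, not the settings used to train Chartqwen.

```python
import numpy as np

rng = np.random.default_rng(0)

d_out, d_in, r = 8, 6, 2   # illustrative layer size and LoRA rank
alpha = 4.0                # illustrative LoRA scaling numerator
scale = alpha / r

W = rng.normal(size=(d_out, d_in))   # frozen base weight
A = rng.normal(size=(r, d_in))       # LoRA down-projection
B = rng.normal(size=(d_out, r))      # LoRA up-projection
x = rng.normal(size=(d_in,))

# Unmerged: base path plus the low-rank adapter path
y_adapter = W @ x + scale * (B @ (A @ x))

# Merged: fold the update into the weight once, then use a single matmul
W_merged = W + scale * (B @ A)
y_merged = W_merged @ x

print(np.allclose(y_adapter, y_merged))
```

The merged form avoids the extra adapter matmuls at every forward pass, which is the speed benefit the loading cell relies on.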

## 4. Define System Prompt

In [None]:
SYSTEM_PROMPT = """You are a precise error bar detection system for scientific plots.
Given an image of a scientific plot and data point coordinates, detect the error bars.
For each point, output the pixel distance from the data point to the top and bottom of the error bar.
If no error bar exists for a point, output 0 for both distances."""

print("System prompt defined!")

## 5. Helper Functions

In [None]:
def load_test_input(json_path):
    """Load test input JSON (contains only x,y coordinates)"""
    with open(json_path, 'r') as f:
        return json.load(f)

def load_ground_truth(json_path):
    """Load ground truth JSON (contains error bar distances)"""
    with open(json_path, 'r') as f:
        return json.load(f)

def load_image_as_pil(image_path):
    """Load image as PIL Image"""
    return Image.open(image_path).convert('RGB')

def create_input_prompt(input_points: List[Dict]) -> str:
    """
    Create the input prompt with data point coordinates.
    """
    points_str = json.dumps(input_points, indent=2)
    
    prompt = f"""Analyze this scientific plot image and detect error bars for the following data points:

{points_str}

For each point, measure:
- topBarPixelDistance: pixel distance from data point to top of error bar (0 if none)
- bottomBarPixelDistance: pixel distance from data point to bottom of error bar (0 if none)

Output as JSON array:
[
  {{"x": <x>, "y": <y>, "topBarPixelDistance": <top>, "bottomBarPixelDistance": <bottom>}}
]"""
    
    return prompt

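def parse_response(response_text: str, original_points: List[Dict]) -> Optional[Dict]:
    """
    Parse the model response into a dict of error bar measurements.

    Strips Markdown code fences, extracts the JSON array, and derives the
    error bar endpoints from the reported distances. `original_points` is
    accepted for reference but is not used. Returns None if parsing fails.
    """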
    try:
        # Clean response
        cleaned = response_text.strip()
        
        # Remove markdown code blocks
        if '```json' in cleaned:
            start = cleaned.find('```json') + 7
            end = cleaned.find('```', start)
            if end > start:
                cleaned = cleaned[start:end].strip()
        elif '```' in cleaned:
            start = cleaned.find('```') + 3
            end = cleaned.find('```', start)
            if end > start:
                cleaned = cleaned[start:end].strip()
        
        # Find JSON
        if cleaned.startswith('['):
            json_str = cleaned
        else:
            start_idx = cleaned.find('[')
            end_idx = cleaned.rfind(']') + 1
            if start_idx >= 0 and end_idx > start_idx:
                json_str = cleaned[start_idx:end_idx]
            else:
                return None
        
        # Parse JSON
        parsed = json.loads(json_str)
        
        # Convert to standard format
        measurements = []
        for item in parsed:
            x = float(item.get('x', 0))
            y = float(item.get('y', 0))
            top_dist = float(item.get('topBarPixelDistance', 0))
            bottom_dist = float(item.get('bottomBarPixelDistance', 0))
            
            measurements.append({
                "data_point": {"x": x, "y": y},
                "upper_error_bar": {"x": x, "y": y - top_dist},
                "lower_error_bar": {"x": x, "y": y + bottom_dist},
                "topBarPixelDistance": top_dist,
                "bottomBarPixelDistance": bottom_dist
            })
        
        return {"measurements": measurements}
        
    except json.JSONDecodeError as e:
        print(f"JSON parse error: {e}")
        return None
    except Exception as e:
        print(f"Parse error: {e}")
        return None

print("Helper functions defined!")
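
The fence-stripping and array-extraction steps in `parse_response` can be tried in isolation. The sketch below is a simplified, regex-based variant rather than the exact logic above; `extract_json_array` and the sample response are hypothetical.

```python
import json
import re
from typing import List, Optional

FENCE = "`" * 3  # Markdown code fence, built here to keep this example readable


def extract_json_array(text: str) -> Optional[List[dict]]:
    """Return the first JSON array found in text, or None if parsing fails."""
    # Prefer the contents of a fenced block if one is present
    pattern = re.escape(FENCE) + r"(?:json)?\s*(.*?)" + re.escape(FENCE)
    fence = re.search(pattern, text, flags=re.DOTALL)
    candidate = fence.group(1) if fence else text

    # Fall back to the outermost [...] span
    start, end = candidate.find("["), candidate.rfind("]")
    if start < 0 or end <= start:
        return None
    try:
        parsed = json.loads(candidate[start:end + 1])
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, list) else None


# A made-up response in the shape the prompt asks for
sample = (
    "Here you go:\n" + FENCE + "json\n"
    '[{"x": 10, "y": 20, "topBarPixelDistance": 5, "bottomBarPixelDistance": 4}]\n'
    + FENCE
)
points = extract_json_array(sample)
print(points)
```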

## 6. Model Inference Function

In [None]:
def infer_error_bars(image_path: str, data_points: List[Dict]) -> Optional[Dict]:
    """
    Infer error bars for given data points in an image.
    
    Args:
        image_path: Path to the plot image
        data_points: List of {"x": float, "y": float}
    
    Returns:
        Dict with measurements or None if failed
    """
    try:
        # Load and resize image
        image = Image.open(image_path).convert('RGB')
        if max(image.size) > IMAGE_MAX_SIZE:
            ratio = IMAGE_MAX_SIZE / max(image.size)
            new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio))
            image = image.resize(new_size, Image.BILINEAR)
        
        # Create prompt
        input_prompt = create_input_prompt(data_points)
        
        # Create messages
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": f"{SYSTEM_PROMPT}\n\n{input_prompt}"}
                ]
            }
        ]
        
        # Apply chat template
        text = processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        
        # Process inputs
        inputs = processor(
            text=[text],
            images=[image],
            padding=True,
            return_tensors="pt"
        ).to(model.device)
        
        # Generate
        num_points = len(data_points)
        max_tokens = min(MAX_NEW_TOKENS, max(512, num_points * 80))
        
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                do_sample=False,  # greedy decoding; temperature is not used when sampling is off
                num_beams=1,
                pad_token_id=processor.tokenizer.pad_token_id,
            )
        
        # Decode
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        response = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True
        )[0]
        
        # Parse response
        result = parse_response(response, data_points)
        
        # Cleanup
        del inputs, generated_ids, image
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        
        return result
        
    except Exception as e:
        print(f"Inference error: {e}")
        return None

print("Inference function defined!")
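
Two small calculations inside `infer_error_bars` are easy to check by hand: the resize that caps the longer image side, and the token budget of roughly 80 tokens per point clamped between 512 and `MAX_NEW_TOKENS`. The sketch below reproduces both with standalone helpers; `resized_dims` and `token_budget` are hypothetical names. Note that the inference function resizes only the image, and the coordinates in the prompt are passed through as given.

```python
from typing import Tuple

IMAGE_MAX_SIZE = 768     # same value as the configuration cell
MAX_NEW_TOKENS = 2048    # same value as the configuration cell


def resized_dims(width: int, height: int, max_side: int = IMAGE_MAX_SIZE) -> Tuple[int, int]:
    """Scale (width, height) so the longer side is at most max_side."""
    longest = max(width, height)
    if longest <= max_side:
        return width, height
    ratio = max_side / longest
    return int(width * ratio), int(height * ratio)


def token_budget(num_points: int, cap: int = MAX_NEW_TOKENS) -> int:
    """At least 512 tokens, about 80 per point, never more than cap."""
    return min(cap, max(512, num_points * 80))


print(resized_dims(1536, 1024))   # (768, 512)
print(token_budget(5), token_budget(10), token_budget(40))  # 512 800 2048
```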

## 7. Convert to Output Format

In [None]:
def convert_vlm_to_standard_format(result: Dict, line_name: str) -> Dict:
    """
    Convert VLM measurements to standard prediction format with pixel distances.
    """
    points = []
    
    measurements = result.get('measurements', [])
    
    for measure in measurements:
        data_pt = measure['data_point']
        upper_bar = measure['upper_error_bar']
        lower_bar = measure['lower_error_bar']
        
        x = data_pt['x']
        y = data_pt['y']
        
        # Calculate pixel distances
        top_dist = abs(y - upper_bar['y'])  # Distance to upper error bar
        bottom_dist = abs(lower_bar['y'] - y)  # Distance to lower error bar
        dev_dist = max(top_dist, bottom_dist)
        
        points.append({
            "x": x,
            "y": y,
            "label": "",
            "topBarPixelDistance": float(top_dist),
            "bottomBarPixelDistance": float(bottom_dist),
            "deviationPixelDistance": float(dev_dist)
        })
    
    return {
        "label": {"lineName": line_name},
        "points": points
    }

def convert_to_output_format(image_file: str, predictions: List[Dict]) -> Dict:
    """
    Convert to final output format with error bar endpoints.
    """
    error_bars = []
    
    for pred_line in predictions:
        line_name = pred_line.get('label', {}).get('lineName', '')
        pred_points = [p for p in pred_line.get('points', []) 
                      if p.get('label', '') not in ['xmin', 'xmax', 'ymin', 'ymax']]
        
        points_data = []
        for point in pred_points:
            x = point['x']
            y = point['y']
            top_dist = point['topBarPixelDistance']
            bottom_dist = point['bottomBarPixelDistance']
            
            point_data = {
                "data_point": {"x": x, "y": y},
                "upper_error_bar": {"x": x, "y": y - top_dist},
                "lower_error_bar": {"x": x, "y": y + bottom_dist}
            }
            
            points_data.append(point_data)
        
        line_data = {
            "lineName": line_name,
            "points": points_data
        }
        
        error_bars.append(line_data)
    
    return {
        "image_file": image_file,
        "model": "Chartqwen",
        "error_bars": error_bars
    }

print("Format conversion functions defined!")
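
The endpoint arithmetic above follows image pixel coordinates, where y grows downward: the upper bar end sits at a smaller y than the data point, and the lower end at a larger y. A minimal round-trip sketch, using hypothetical helper names and made-up values:

```python
from typing import Dict, Tuple


def error_bar_endpoints(x: float, y: float, top: float, bottom: float) -> Dict[str, Dict[str, float]]:
    """Convert pixel distances into endpoint coordinates (y grows downward)."""
    return {
        "data_point": {"x": x, "y": y},
        "upper_error_bar": {"x": x, "y": y - top},
        "lower_error_bar": {"x": x, "y": y + bottom},
    }


def distances_from_endpoints(record: Dict[str, Dict[str, float]]) -> Tuple[float, float]:
    """Recover (top, bottom) pixel distances from endpoint coordinates."""
    y = record["data_point"]["y"]
    top = abs(y - record["upper_error_bar"]["y"])
    bottom = abs(record["lower_error_bar"]["y"] - y)
    return top, bottom


rec = error_bar_endpoints(100.0, 200.0, top=15.0, bottom=10.0)
print(rec["upper_error_bar"], rec["lower_error_bar"])
print(distances_from_endpoints(rec))  # (15.0, 10.0)
```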

## 8. Test on Sample Image

In [None]:
# Load test data
test_label_files = sorted([f for f in os.listdir(TEST_INPUT_LABELS) if f.endswith('.json')])[:1]

if test_label_files:
    test_file = test_label_files[0]
    print(f"Testing on: {test_file}\n")
    
    # Load input labels (x, y only)
    test_input = load_test_input(os.path.join(TEST_INPUT_LABELS, test_file))
    
    image_file = test_input['image_file']
    image_path = os.path.join(TEST_IMAGES, image_file)
    
    # Get data points
    data_points = []
    for line_data in test_input.get('data_points', []):
        for pt in line_data.get('points', []):
            data_points.append({"x": round(pt['x'], 1), "y": round(pt['y'], 1)})
    
    print(f"Image: {image_path}")
    print(f"Data points: {len(data_points)}")
    if data_points:
        print(f"First point: {data_points[0]}")
    
    # Run inference
    print("\nRunning inference...")
    result = infer_error_bars(image_path, data_points)
    
    if result and 'measurements' in result:
        print(f"\nInference successful!")
        print(f"Got {len(result['measurements'])} measurements")
        
        print("\nFirst 3 measurements:")
        for i, m in enumerate(result['measurements'][:3]):
            print(f"  [{i+1}] Point: ({m['data_point']['x']:.1f}, {m['data_point']['y']:.1f})")
            print(f"       Top: {m['topBarPixelDistance']:.1f}px, Bottom: {m['bottomBarPixelDistance']:.1f}px")
    else:
        print("Inference failed")
else:
    print("No test files found")
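
The nested loop that collects data points can be checked without the dataset. The sketch below applies the same flattening and rounding to a made-up input; the key names (`image_file`, `data_points`, `points`) mirror how this notebook reads the test labels, while the values and the `flatten_points` helper are hypothetical.

```python
from typing import Dict, List


def flatten_points(test_input: Dict) -> List[Dict[str, float]]:
    """Collect every point from every line into one flat list, rounded to 0.1 px."""
    points = []
    for line in test_input.get("data_points", []):
        for pt in line.get("points", []):
            points.append({"x": round(pt["x"], 1), "y": round(pt["y"], 1)})
    return points


# Illustrative input with two lines and three points in total
sample_input = {
    "image_file": "example.png",
    "data_points": [
        {"points": [{"x": 10.04, "y": 20.26}, {"x": 30.0, "y": 41.5}]},
        {"points": [{"x": 55.52, "y": 60.0}]},
    ],
}
flat = flatten_points(sample_input)
print(len(flat), flat[0])
```

Because all lines are merged into one list, the line a point belongs to is not tracked after this step.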

## 9. Process All Test Data

In [None]:
# Process all test files
all_test_files = sorted([f for f in os.listdir(TEST_INPUT_LABELS) if f.endswith('.json')])
print(f"Processing {len(all_test_files)} test files...\n")

all_predictions = {}
all_results = []
failed_count = 0
processed_count = 0

for i, test_file in enumerate(tqdm(all_test_files, desc="Processing test files")):
    try:
        # Load input
        test_input = load_test_input(os.path.join(TEST_INPUT_LABELS, test_file))
        
        image_file = test_input['image_file']
        image_path = os.path.join(TEST_IMAGES, image_file)
        
        # Get all data points
        all_points = []
        for line_data in test_input.get('data_points', []):
            for pt in line_data.get('points', []):
                all_points.append({"x": round(pt['x'], 1), "y": round(pt['y'], 1)})
        
        # Run inference
        result = infer_error_bars(image_path, all_points)
        
        if result and 'measurements' in result:
            all_predictions[test_file] = {
                'image_file': image_file,
                'measurements': result['measurements']
            }
            processed_count += 1
        else:
            failed_count += 1
        
        # Progress
        if (i + 1) % 10 == 0:
            print(f"✓ Processed {i+1}/{len(all_test_files)} | Success: {processed_count} | Failed: {failed_count}")
        
        # Clear cache periodically
        if (i + 1) % 5 == 0 and torch.cuda.is_available():
            torch.cuda.empty_cache()
            
    except Exception as e:
        failed_count += 1
        if failed_count <= 5:
            print(f"Error on {test_file}: {e}")

print(f"\n{'='*60}")
print(f"PROCESSING COMPLETE")
print(f"Processed: {processed_count}")
print(f"Failed: {failed_count}")
print(f"{'='*60}")

## 10. Evaluation Metrics

In [None]:
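def calculate_metrics(predictions: Dict, ground_truth_dir: str) -> Optional[Dict]:
    """
    Compute absolute pixel errors of predicted error bar distances against ground truth.

    Predictions and ground-truth points are paired by position, so both lists
    are assumed to be in the same order. Returns None if no points could be compared.
    """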
    all_top_errors = []
    all_bottom_errors = []
    
    for json_file, pred_data in predictions.items():
        try:
            # Load ground truth
            gt_path = os.path.join(ground_truth_dir, json_file)
            if not os.path.exists(gt_path):
                continue
            
            with open(gt_path, 'r') as f:
                gt_data = json.load(f)
            
            # Collect all GT points
            gt_points = []
            for line_data in gt_data:
                for pt in line_data.get('points', []):
                    if pt.get('label', '') not in ['xmin', 'xmax', 'ymin', 'ymax']:
                        gt_points.append(pt)
            
            # Compare with predictions
            pred_measurements = pred_data['measurements']
            
            for pred_m, gt_pt in zip(pred_measurements, gt_points):
                pred_top = pred_m.get('topBarPixelDistance', 0)
                pred_bottom = pred_m.get('bottomBarPixelDistance', 0)
                gt_top = gt_pt.get('topBarPixelDistance', 0)
                gt_bottom = gt_pt.get('bottomBarPixelDistance', 0)
                
                all_top_errors.append(abs(pred_top - gt_top))
                all_bottom_errors.append(abs(pred_bottom - gt_bottom))
                
        except Exception as e:
            # Skip files that cannot be evaluated, but report why
            print(f"Skipping {json_file}: {e}")
            continue
    
    if not all_top_errors:
        return None
    
    all_mean_errors = [(t + b) / 2 for t, b in zip(all_top_errors, all_bottom_errors)]
    
    metrics = {
        'num_points': len(all_top_errors),
        'mean_top_error': np.mean(all_top_errors),
        'mean_bottom_error': np.mean(all_bottom_errors),
        'mean_overall_error': np.mean(all_mean_errors),
        'median_top_error': np.median(all_top_errors),
        'median_bottom_error': np.median(all_bottom_errors),
        'std_top_error': np.std(all_top_errors),
        'std_bottom_error': np.std(all_bottom_errors),
        'accuracy_5px': sum(1 for e in all_mean_errors if e <= 5) / len(all_mean_errors) * 100,
        'accuracy_10px': sum(1 for e in all_mean_errors if e <= 10) / len(all_mean_errors) * 100,
        'accuracy_20px': sum(1 for e in all_mean_errors if e <= 20) / len(all_mean_errors) * 100,
    }
    
    return metrics

# Calculate metrics
if all_predictions:
    print("Calculating evaluation metrics...")
    metrics = calculate_metrics(all_predictions, TEST_GROUND_TRUTH)
    
    if metrics:
        print("\n" + "="*60)
        print("EVALUATION RESULTS")
        print("="*60)
        print(f"\nTotal Points Evaluated: {metrics['num_points']}")
        print(f"\nPixel Error:")
        print(f"  Mean Top Error: {metrics['mean_top_error']:.2f} px")
        print(f"  Mean Bottom Error: {metrics['mean_bottom_error']:.2f} px")
        print(f"  Mean Overall Error: {metrics['mean_overall_error']:.2f} px")
        print(f"  Median Top Error: {metrics['median_top_error']:.2f} px")
        print(f"  Median Bottom Error: {metrics['median_bottom_error']:.2f} px")
        print(f"  Std Top Error: {metrics['std_top_error']:.2f} px")
        print(f"  Std Bottom Error: {metrics['std_bottom_error']:.2f} px")
        print(f"\nAccuracy:")
        print(f"  Within 5px: {metrics['accuracy_5px']:.1f}%")
        print(f"  Within 10px: {metrics['accuracy_10px']:.1f}%")
        print(f"  Within 20px: {metrics['accuracy_20px']:.1f}%")
        print("="*60)
    else:
        print("No metrics calculated - check predictions and ground truth")
else:
    print("No predictions to evaluate")
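
To make the metric definitions concrete, here is a small worked example on made-up numbers. It follows the same recipe as `calculate_metrics`: absolute error per bar, the per-point mean of the top and bottom errors, and the share of points whose mean error falls within a threshold. The arrays and the `accuracy_within` helper are illustrative only.

```python
import numpy as np

# Made-up predictions and ground truth for four points (pixel distances)
pred_top = np.array([10.0, 22.0, 0.0, 40.0])
gt_top = np.array([12.0, 20.0, 0.0, 10.0])
pred_bottom = np.array([8.0, 18.0, 0.0, 35.0])
gt_bottom = np.array([9.0, 25.0, 0.0, 15.0])

top_err = np.abs(pred_top - gt_top)           # [2, 2, 0, 30]
bottom_err = np.abs(pred_bottom - gt_bottom)  # [1, 7, 0, 20]
mean_err = (top_err + bottom_err) / 2         # [1.5, 4.5, 0, 25]


def accuracy_within(errors: np.ndarray, threshold: float) -> float:
    """Percentage of points whose error is at or below the threshold."""
    return float(np.mean(errors <= threshold) * 100)


print(f"Mean top error: {top_err.mean():.2f} px")            # 8.50
print(f"Mean bottom error: {bottom_err.mean():.2f} px")      # 7.00
print(f"Within 5px: {accuracy_within(mean_err, 5):.1f}%")    # 75.0
print(f"Within 20px: {accuracy_within(mean_err, 20):.1f}%")  # 75.0
```

A single large miss (the fourth point) pulls the mean well above the median, which is why the notebook reports both.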

## 11. Save Predictions

In [None]:
# Save predictions
OUTPUT_PREDICTIONS_DIR = "/kaggle/working/chartqwen_predictions"
os.makedirs(OUTPUT_PREDICTIONS_DIR, exist_ok=True)

print(f"Saving {len(all_predictions)} prediction files...\n")

for json_file, pred_data in all_predictions.items():
    try:
        # Convert to output format
        output = {
            "image_file": pred_data['image_file'],
            "model": "Chartqwen",
            "error_bars": [{
                "lineName": "",
                "points": [
                    {
                        "data_point": m['data_point'],
                        "upper_error_bar": m['upper_error_bar'],
                        "lower_error_bar": m['lower_error_bar']
                    }
                    for m in pred_data['measurements']
                ]
            }]
        }
        
        output_path = os.path.join(OUTPUT_PREDICTIONS_DIR, json_file)
        with open(output_path, 'w') as f:
            json.dump(output, f, indent=2)
            
    except Exception as e:
        print(f"Error saving {json_file}: {e}")

print(f"Predictions saved to: {OUTPUT_PREDICTIONS_DIR}")

# Create ZIP
import zipfile
from datetime import datetime

zip_path = f"/kaggle/working/chartqwen_predictions_{datetime.now().strftime('%Y%m%d_%H%M%S')}.zip"

with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
    for f in os.listdir(OUTPUT_PREDICTIONS_DIR):
        if f.endswith('.json'):
            zipf.write(os.path.join(OUTPUT_PREDICTIONS_DIR, f), f"predictions/{f}")

print(f"ZIP created: {zip_path}")
print(f"\nDownload the ZIP file to get all predictions!")
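
The archive layout can be verified before downloading. The sketch below writes two placeholder prediction files to a temporary directory, zips them under a `predictions/` prefix as the cell above does, and reads the archive back. The file names and contents are made up for illustration.

```python
import json
import os
import tempfile
import zipfile

with tempfile.TemporaryDirectory() as tmp:
    # Write two placeholder prediction files
    pred_dir = os.path.join(tmp, "preds")
    os.makedirs(pred_dir)
    for name in ("a.json", "b.json"):
        with open(os.path.join(pred_dir, name), "w") as f:
            json.dump({"image_file": name.replace(".json", ".png"), "error_bars": []}, f)

    # Zip them under a predictions/ prefix via arcname
    zip_path = os.path.join(tmp, "preds.zip")
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for name in sorted(os.listdir(pred_dir)):
            if name.endswith(".json"):
                zf.write(os.path.join(pred_dir, name), arcname=f"predictions/{name}")

    # Read the archive back to confirm its layout and contents
    with zipfile.ZipFile(zip_path) as zf:
        archived = zf.namelist()
        first = json.loads(zf.read("predictions/a.json"))

print(archived)
print(first)
```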

## Summary

This notebook demonstrates inference using the fine-tuned **Chartqwen** model for error bar detection.

### Model Details:
- **Base Model**: Qwen2.5-VL-7B-Instruct
- **Fine-tuned**: Sayeem26s/Chartqwen
- **Method**: LoRA adapter merged into the base model, FP16 precision
- **Task**: Error bar detection in scientific plots

### Input Format:
- Image of scientific plot
- Data point coordinates (x, y)

### Output Format:
- Error bar pixel distances (topBarPixelDistance, bottomBarPixelDistance)
- Upper and lower error bar endpoints

### Evaluation:
- Mean pixel error
- Accuracy within 5px, 10px, 20px thresholds

### Model Hub:
- HuggingFace: `Sayeem26s/Chartqwen`