In [1]:
import json
import logging
from pathlib import Path
from typing import Dict, List, Optional
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

class QwenInference:
    def __init__(self, model_name: str = "/share/ssddata/sarimhashmi/Qwen-VL-Chat"):
        """Initialize the QwenInference class."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")
        
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map=self.device,
                trust_remote_code=True
            ).eval()
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise

    def process_single_query(
        self,
        question: str,
        image_path: str,
        history: Optional[List] = None
    ) -> Dict:
        """Process a single query with image."""
        try:
            # Format query according to template
            query = [
                {'image': image_path},
                {'text': question}
            ]
            
            # Convert to model format
            formatted_query = self.tokenizer.from_list_format(query)
            
            # Get response
            response, _ = self.model.chat(
                self.tokenizer,
                query=formatted_query,
                history=history
            )
            
            return {
                'question': question,
                'image_path': image_path,
                'response': response
            }
            
        except Exception as e:
            logger.error(f"Error processing query: {str(e)}")
            return {
                'question': question,
                'image_path': image_path,
                'response': f"Error: {str(e)}"
            }

def process_jsonl_file(
    input_path: str,
    output_path: str,
    batch_size: int = 1
):
    """Process a JSONL file and save results."""
    # Initialize model
    qwen = QwenInference()
    results = []
    
    try:
        # Create output directory if it doesn't exist
        output_file = Path(output_path)
        output_file.parent.mkdir(parents=True, exist_ok=True)
        
        # Read and process JSONL file
        with open(input_path, 'r') as f:
            for line in f:
                try:
                    # Parse each line as JSON
                    entry = json.loads(line)
                    
                    # Extract question and image path
                    question = entry['question'].replace('<image>', '<image>').strip()
                    image_path = entry['image']
                    
                    # Process query
                    logger.info(f"Processing question: {question[:50]}...")
                    result = qwen.process_single_query(question, image_path)
                    
                    # Add ground truth answer
                    result['ground_truth'] = entry['answer']
                    results.append(result)
                    
                    # Save results periodically
                    if len(results) % batch_size == 0:
                        with open(output_path, 'w') as out_f:
                            json.dump(results, out_f, indent=2)
                            logger.info(f"Saved {len(results)} results to {output_path}")
                            
                except json.JSONDecodeError as e:
                    logger.error(f"Error parsing JSON line: {str(e)}")
                except Exception as e:
                    logger.error(f"Error processing entry: {str(e)}")
                    
        # Save final results
        with open(output_path, 'w') as out_f:
            json.dump(results, out_f, indent=2)
            logger.info(f"Saved final {len(results)} results to {output_path}")
            
    except Exception as e:
        logger.error(f"Error processing file: {str(e)}")
        raise

if __name__ == "__main__":
    # Example usage
    input_file = "/share/ssddata/sarimhashmi/iuxray/factuality/qwenvl_verify/unnecessary_stuff_processed_data.jsonl"
    output_file = "/share/ssddata/sarimhashmi/iuxray/factuality/qwenvl_verify/output.json"
    
    process_jsonl_file(
        input_path=input_file,
        output_path=output_file,
        batch_size=5  # Save results every 5 entries
    )

  from .autonotebook import tqdm as notebook_tqdm
2024-12-04 21:45:45,272 - __main__ - INFO - Using device: cuda
  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(
  return torch.load(checkpoint_file, map_location=map_location)
Loading checkpoint shards: 100%|██████████| 10/10 [00:15<00:00,  1.56s/it]
2024-12-04 21:46:10,034 - __main__ - INFO - Model loaded successfully
2024-12-04 21:46:10,044 - __main__ - INFO - Processing question: Is the heart size reported to be abnormal? Please ...
2024-12-04 21:46:11,190 - __main__ - INFO - Processing question: Is the heart size within normal limits on the X-ra...
2024-12-04 21:46:11,430 - __main__ - INFO - Processing question: Does the chest X-ray show a pneumothorax? Please c...
2024-12-04 21:46:11,669 - __main__ - INFO - Processing question: Is there any evidence of pleural effusion on the c...
2024-12-04 21:46:11,907 - __main__ - INFO - Processing question: Are the pulmonary vasculature and contours appe

In [2]:
import json

def normalize_answer(text):
    if text.find('.') != -1:
        text = text.split('.')[0]
    text = text.replace(',', '')
    text = text.lower()
    words = text.split()
    if 'no' in words or 'not' in words:
        return 0
    return 1

def evaluate_yes_no(predictions_path):
    # Load predictions
    with open(predictions_path, 'r') as f:
        data = json.load(f)
    
    successful_cases = []
    failed_cases = []
    pred_list = []
    labels = []
    
    # Process predictions and collect cases
    for entry in data:
        pred = normalize_answer(entry['response'])
        label = normalize_answer(entry['ground_truth'])
        
        pred_list.append(pred)
        labels.append(label)
        
        # Store case information
        case = {
            'question': entry['question'],
            'predicted': entry['response'],
            'ground_truth': entry['ground_truth']
        }
        
        if pred == label:
            successful_cases.append(case)
        else:
            failed_cases.append(case)

    # Calculate metrics
    TP, TN, FP, FN = 0, 0, 0, 0
    for pred, label in zip(pred_list, labels):
        if pred == 1 and label == 1:
            TP += 1
        elif pred == 1 and label == 0:
            FP += 1
        elif pred == 0 and label == 0:
            TN += 1
        elif pred == 0 and label == 1:
            FN += 1

    # Print confusion matrix
    print('\nConfusion Matrix:')
    print('TP\tFP\tTN\tFN')
    print(f'{TP}\t{FP}\t{TN}\t{FN}')

    # Calculate and print metrics
    precision = float(TP) / float(TP + FP) if (TP + FP) > 0 else 0
    recall = float(TP) / float(TP + FN) if (TP + FN) > 0 else 0
    f1 = 2*precision*recall / (precision + recall) if (precision + recall) > 0 else 0
    acc = (TP + TN) / (TP + TN + FP + FN)
    
    print(f'\nMetrics:')
    print(f'Accuracy:  {acc:.4f}')
    print(f'Precision: {precision:.4f}')
    print(f'Recall:    {recall:.4f}')
    print(f'F1 score:  {f1:.4f}')

    # Print example cases
    print('\n=== 10 Successful Cases ===')
    for i, case in enumerate(successful_cases[:10], 1):
        print(f'\n{i}. Question: {case["question"]}')
        print(f'   Predicted: {case["predicted"]}')
        print(f'   Ground truth: {case["ground_truth"]}')

    print('\n=== 10 Failed Cases ===')
    for i, case in enumerate(failed_cases[:10], 1):
        print(f'\n{i}. Question: {case["question"]}')
        print(f'   Predicted: {case["predicted"]}')
        print(f'   Ground truth: {case["ground_truth"]}')

# Use like this:
json_path = "/share/ssddata/sarimhashmi/iuxray/factuality/qwenvl_verify/output.json"  # replace with your json path
evaluate_yes_no(json_path)


Confusion Matrix:
TP	FP	TN	FN
485	1606	233	153

Metrics:
Accuracy:  0.2899
Precision: 0.2319
Recall:    0.7602
F1 score:  0.3554

=== 10 Successful Cases ===

1. Question: Is the heart size within normal limits on the X-ray? Please choose from the following two options: [yes, no]
<image>
   Predicted: yes
   Ground truth: - Yes.

2. Question: Are the pulmonary vasculature and contours appearing normal? Please choose from the following two options: [yes, no]
<image>
   Predicted: yes
   Ground truth: Yes

3. Question: Are the lung fields clear on the X-ray? Please choose from the following two options: [yes, no]
<image>
   Predicted: yes
   Ground truth: Yes

4. Question: Is the cardiomediastinal silhouette of normal size on the chest X-ray? - Yes Please choose from the following two options: [yes, no]
<image>
   Predicted: yes
   Ground truth: 2. Are the lungs well aerated according to the chest X-ray? - Yes

5. Question: Is there any evidence of congestive heart failure on the chest 