In [None]:
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import re
from pathlib import Path
from typing import Any, Dict, List


# Set up root logging once for the whole notebook: INFO level, "LEVEL: message" lines
logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.INFO)


def load_jsonl(path: Path) -> List[Dict[str, Any]]:
    """
    Load data from a JSON Lines (JSONL) file.

    Each non-blank line in the file must be a valid JSON object. Blank lines
    (including a trailing empty line) are skipped.

    Args:
        path (Path): The file path to the JSONL file (a str path is also accepted).

    Returns:
        List[Dict[str, Any]]: A list of dictionaries parsed from the JSONL file.

    Raises:
        FileNotFoundError: If the specified file does not exist.
        json.JSONDecodeError: If a line in the file is not valid JSON. The message
            names the file and the 1-based line number within the file.
    """
    path = Path(path)
    records: List[Dict[str, Any]] = []
    try:
        with path.open('r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                if not line.strip():
                    # Blank lines carry no record; json.loads would reject them.
                    continue
                try:
                    records.append(json.loads(line))
                except json.JSONDecodeError as e:
                    # e.doc is just this one line, so its own line/column refer to the
                    # line, not the file: report the file line number explicitly.
                    raise json.JSONDecodeError(
                        f"Invalid JSON in {path} (file line {line_no}): {e.msg}", e.doc, e.pos
                    ) from e
    except FileNotFoundError as e:
        raise FileNotFoundError(f"The file {path} was not found.") from e
    return records


def load_json(path: Path) -> Any:
    """
    Load data from a JSON file.

    Args:
        path (Path): The file path to the JSON file (a str path is also accepted).

    Returns:
        Any: The data parsed from the JSON file.

    Raises:
        FileNotFoundError: If the specified file does not exist.
        json.JSONDecodeError: If the file contains invalid JSON; the message names the file.
    """
    # Coerce so plain-string paths work too; str has no .open() method.
    path = Path(path)
    try:
        with path.open('r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError as e:
        raise FileNotFoundError(f"The file {path} was not found.") from e
    except json.JSONDecodeError as e:
        raise json.JSONDecodeError(f"Invalid JSON in {path}: {e.msg}", e.doc, e.pos) from e


def save_json(path: Path, data: Any) -> None:
    """
    Save data to a JSON file with indentation for readability.

    The data is serialized before the file is opened. A non-serializable value
    therefore raises without touching any existing file at ``path``.

    Args:
        path (Path): The file path to save the JSON data (a str path is also accepted).
        data (Any): The data to be serialized and saved.

    Raises:
        TypeError: If ``data`` contains values that are not JSON-serializable.
        IOError: If the file cannot be written.
    """
    path = Path(path)
    # Serialize first: open(..., 'w') truncates the file, so failing halfway through
    # json.dump would otherwise leave a truncated, half-written file behind.
    text = json.dumps(data, indent=4)
    try:
        with path.open('w', encoding='utf-8') as f:
            f.write(text)
    except IOError as e:
        # strerror can be None for OSErrors raised without an errno.
        raise IOError(f"Failed to write to {path}: {e.strerror or e}") from e


def save_jsonl(path: Path, data: List[Dict[str, Any]]) -> None:
    """
    Save a list of dictionaries to a JSON Lines (JSONL) file.

    Each dictionary is serialized as a separate JSON object on its own line,
    each terminated by a newline. All records are serialized before the file is
    opened, so a non-serializable record raises without touching any existing file.

    Args:
        path (Path): The file path to save the JSONL data (a str path is also accepted).
        data (List[Dict[str, Any]]): A list of dictionaries to be serialized.

    Raises:
        TypeError: If a record contains values that are not JSON-serializable.
        IOError: If the file cannot be written.
    """
    path = Path(path)
    # Serialize up front so an error cannot leave a truncated file behind.
    text = ''.join(json.dumps(record) + '\n' for record in data)
    try:
        with path.open('w', encoding='utf-8') as f:
            f.write(text)
    except IOError as e:
        # strerror can be None for OSErrors raised without an errno.
        raise IOError(f"Failed to write to {path}: {e.strerror or e}") from e


# Process the generated JSON files for inference and evaluation

After running the `vqa_json_creator.py` script, we have JSON files in the LLaVA-Med format that contain multiple QA pairs for each image. In this notebook, we process these JSON files to create the finalized JSON files used for inference and evaluation.

## Replace each image ID with its relative image file path

In [None]:
def add_relative_path(input_file: Path, output_file: Path, path_table_file: Path) -> None:
    """
    Rewrite every entry's 'image' field from an image ID to its relative MIMIC path.

    The input JSON file is read, each entry's image ID is looked up in the path
    table, and the result is written to the output file. The path table is a JSON
    object mapping image IDs to relative paths, including subdirectories. Only the
    'image', 'id' and 'conversations' fields are kept in each output entry.

    Args:
        input_file (Path): JSON file containing the list of entries.
        output_file (Path): Destination JSON file for the rewritten entries.
        path_table_file (Path): JSON file mapping image IDs to relative paths.

    Raises:
        KeyError: If an entry has no 'image' value, or its image ID is absent
            from the path table.
        FileNotFoundError: If an input file does not exist.
        json.JSONDecodeError: If an input file contains invalid JSON.
        IOError: If the output file cannot be written.
    """
    records = load_json(input_file)
    relpath_by_id = load_json(path_table_file)

    def with_relative_path(record: Dict[str, Any]) -> Dict[str, Any]:
        # Build the output entry for one record, validating its image ID first.
        image_id = record.get('image')
        if image_id is None:
            raise KeyError(f"Missing 'image' key in entry: {record}")
        if image_id not in relpath_by_id:
            raise KeyError(f"Image ID '{image_id}' not found in path table.")
        return {
            'image': relpath_by_id[image_id],
            'id': record.get('id'),
            'conversations': list(record.get('conversations', [])),
        }

    # Everything is validated before anything is written.
    save_json(output_file, [with_relative_path(record) for record in records])

# Usage
# input and output json files
input_json = Path("input_json") / "llava_med_instruct_mimicvqa_test_expertmodel.json"
output_json = Path("output_json") / "llava_med_instruct_mimicvqa_test_expertmodel_path.json"
# json file mapping each image ID to its relative path
path_table_json = Path("mimic_cxr_relpath.json")
# save_json only opens the file for writing and does not create directories, so make
# sure output_json/ exists; otherwise this cell fails on a fresh checkout/kernel.
output_json.parent.mkdir(parents=True, exist_ok=True)
add_relative_path(input_json, output_json, path_table_json)

## Split QA pairs into separate json files based on question type

In [None]:
def select_conversation(
    old_list: List[Dict[str, Any]],
    qtype: str,
    qtable: Dict[str, str]
) -> List[Dict[str, Any]]:
    """
    Keep only the (question, answer) turns of one image whose question has type ``qtype``.

    Turns are read as consecutive (human, answer) pairs; an unpaired trailing
    turn is ignored. The first human turn must split into 2 or 3 lines on
    newlines. With 3 lines, its first line is treated as an expert-information
    prefix.

    A question is normalized to the text of its last line before the first '?',
    plus '?', before lookup in ``qtable``. The first kept human turn is rewritten
    as prefix + normalized question + ``"\\n<image>"``, so the result always
    starts like a first turn. Later kept turns are shallow copies of the originals.

    Args:
        old_list (List[Dict[str, Any]]): Alternating human/answer turns.
        qtype (str): Question type to keep.
        qtable (Dict[str, str]): Normalized question text -> question type.

    Returns:
        List[Dict[str, Any]]: The kept turns (possibly empty).

    Raises:
        ValueError: If the first human turn does not split into 2 or 3 lines.
        KeyError: If a normalized question is missing from ``qtable``.
    """
    image_suffix = '\n<image>'
    expert_prefix = ''
    selected: List[Dict[str, Any]] = []

    for pair_no in range(len(old_list) // 2):
        human_turn = old_list[2 * pair_no]
        answer_turn = old_list[2 * pair_no + 1]
        raw_question = human_turn.get('value', '')

        if pair_no == 0:
            # Only the first turn may carry expert info on its own leading line.
            pieces = raw_question.split('\n')
            if len(pieces) == 3:
                expert_prefix = f"{pieces[0]}\n"
            elif len(pieces) != 2:
                raise ValueError(f"Invalid prefix format: {raw_question}")

        question_key = raw_question.split('?')[0].split('\n')[-1] + '?'
        if question_key not in qtable:
            raise KeyError(f"Question '{question_key}' not found in question table.")
        if qtable[question_key] != qtype:
            continue

        human_copy = human_turn.copy()
        if not selected:
            human_copy['value'] = f"{expert_prefix}{question_key}{image_suffix}"
        selected.extend([human_copy, answer_turn.copy()])

    return selected


def select_qa(
    input_file: Path,
    output_file: Path,
    question_table_file: Path,
    question_type: str
) -> None:
    """
    Write a copy of ``input_file`` restricted to QA pairs of one question type.

    Each entry's conversations are filtered with ``select_conversation``. Entries
    left with no conversation turns are dropped. The kept entries contain only
    'image', 'id' and the filtered 'conversations'.

    Args:
        input_file (Path): JSON file with the full QA conversations.
        output_file (Path): Destination JSON file for the filtered entries.
        question_table_file (Path): JSON file mapping each question to its type.
        question_type (str): Question type to keep.

    Raises:
        FileNotFoundError: If an input file does not exist.
        json.JSONDecodeError: If an input file contains invalid JSON.
        KeyError: If a question is not found in the question table.
        ValueError: If a first question prompt has an unexpected line layout.
        IOError: If the output file cannot be written.
    """
    entries = load_json(input_file)
    question_to_type = load_json(question_table_file)

    kept: List[Dict[str, Any]] = []
    for entry in entries:
        selected = select_conversation(entry.get('conversations', []), question_type, question_to_type)
        if not selected:
            continue  # this image has no question of the requested type
        kept.append({
            'image': entry.get('image'),
            'id': entry.get('id'),
            'conversations': selected,
        })

    save_json(output_file, kept)

# Usage
question_types = ['abnormality', 'level', 'location', 'presence', 'type', 'view']
question_table_file = Path('mimic_vqa_qtype.json')
root = Path("output_json")

for qtype in question_types:
    # with expert model: every per-type split is taken from the same path-resolved file
    input_file = root / 'llava_med_instruct_mimicvqa_test_expertmodel_path.json'
    output_dir = root / qtype
    output_file = output_dir / 'llava_med_instruct_mimicvqa_test_expert.json'
    if not output_dir.exists():
        output_dir.mkdir(parents=True)
    select_qa(input_file, output_file, question_table_file, qtype)

These JSON files can now be used for first-round inference. However, during training, expert information is included only in the first question of each QA conversation. Therefore, to run inference correctly on the remaining questions, the following changes are needed:

* If a question q is not the first question in a QA conversation, a second-round inference is required.
* In this second-round inference, the first question in that QA conversation, along with its predicted answer from the first-round inference and the expert information, should be inserted before the question q.

In [None]:
def create_dict_id_predictions(input_file: Path) -> Dict[str, Any]:
    """
    Index first-round inference answers by question ID and prompt.

    Each JSONL record's 'text' is stored under the key
    ``f"{question_id}_{prompt}"``. Records missing any of 'question_id',
    'prompt' or 'text' are logged as warnings and skipped rather than raising.
    A later record with the same key overwrites an earlier one.

    Args:
        input_file (Path): Path to the input JSONL file containing inference answers.

    Returns:
        Dict[str, Any]: Mapping from combined key to predicted answer text.

    Raises:
        Errors from ``load_jsonl`` (e.g. missing file, invalid JSON) propagate.
    """
    predictions_by_key: Dict[str, Any] = {}
    for record in load_jsonl(input_file):
        try:
            qid, prompt, answer = record['question_id'], record['prompt'], record['text']
        except KeyError as missing:
            logging.warning(f"Missing key {missing} in entry: {record}")
            continue
        predictions_by_key[f"{qid}_{prompt}"] = answer
    return predictions_by_key


# Usage
# Configure logging (basicConfig does nothing if the root logger already has handlers)
logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.INFO)
question_types = ['abnormality', 'presence', 'view', 'location', 'level', 'type']
predictions_root = Path("../predicted_vqa_json/ckpt_llava_med_all_mimic_expert_1run")

# Merge the first-round answers of every question type into a single lookup table
predictions = {}
for qtype in question_types:
    input_file = predictions_root / qtype / 'llava_med_expert_all_mimic_expert_run1.jsonl'
    per_type = create_dict_id_predictions(input_file)
    predictions.update(per_type)

save_json(Path('predictions.json'), predictions)

In [None]:
def add_expert_prompt(
    image_id: str,
    old_list: List[Dict[str, Any]],
    qtype: str,
    qtable: Dict[str, str],
    predictions: Dict[str, Any]
) -> List[Dict[str, Any]]:
    """
    Build second-round prompts for the QA pairs of one image that match a question type.

    Turns are read as consecutive (human, answer) pairs. The first human turn
    must split into 2 lines (question, then a trailing line) or 3 lines (expert
    information, question, then a trailing line). Questions are looked up in
    ``qtable`` by their exact text; unknown questions simply do not match.

    * If the first question matches ``qtype``, it is kept with its trailing line
      replaced by ``<patch_token_holder>``. Later matching pairs are appended
      unchanged.
    * Otherwise, the first later question that matches is prefixed with the
      expert information, the first question, ``<patch_token_holder>`` and the
      first-round predicted answer to the first question. Later matching pairs
      follow it unchanged.

    The prediction is needed only in the second case. A missing prediction
    therefore does not prevent keeping a matching first question.

    Args:
        image_id (str): The unique identifier for the image.
        old_list (List[Dict[str, Any]]): Alternating human/answer turns.
        qtype (str): The specific question type to filter by.
        qtable (Dict[str, str]): A mapping of questions to their types.
        predictions (Dict[str, Any]): First-round answers keyed by
            ``f"{image_id}_{first_turn_value}"``. ``first_turn_value`` is the raw
            'value' of the first human turn. (Assumes the first-round inference
            used the same id/prompt; verify against the inference output.)

    Returns:
        List[Dict[str, Any]]: A new list of QA pairs with inserted prediction and
        expert information (empty if nothing matches).

    Raises:
        ValueError: If the first question prompt does not split into 2 or 3 lines.
        KeyError: If a non-first question must be prefixed but no first-round
            prediction exists for this image.
    """
    new_list: List[Dict[str, Any]] = []
    total_pairs = len(old_list) // 2
    if total_pairs == 0:
        return new_list

    first_text = old_list[0].get('value', '')
    text_split = first_text.split('\n')
    if len(text_split) == 2:
        expert = ''
        first_question = text_split[0]
    elif len(text_split) == 3:
        expert = f"{text_split[0]}\n"
        first_question = text_split[1]
    else:
        raise ValueError(f"Invalid first question prompt format: {first_text}")

    pred_key = f"{image_id}_{first_text}"

    if qtable.get(first_question) == qtype:
        first_turn = old_list[0].copy()
        first_turn['value'] = f"{expert}{first_question}\n<patch_token_holder>"
        new_list.append(first_turn)
        new_list.append(old_list[1].copy())

    for i in range(1, total_pairs):
        qa_pair = old_list[2 * i]
        question = qa_pair.get('value', '')
        if qtable.get(question) != qtype:
            continue

        if new_list:
            # Continuation of a conversation that already carries the image/expert context.
            new_list.append(qa_pair.copy())
        else:
            # Without the prediction the prompt would lose the image token, the expert
            # info and the previous answer, so fail loudly instead of emitting it.
            if pred_key not in predictions:
                raise KeyError(f"Prediction not found for key: {pred_key}")
            prefix = (
                f"{expert}{first_question}\n<patch_token_holder>\n"
                f"###{predictions[pred_key]}\n###Human: "
            )
            updated_qa = qa_pair.copy()
            updated_qa['value'] = f"{prefix}{question}"
            new_list.append(updated_qa)
        new_list.append(old_list[2 * i + 1].copy())

    return new_list


def add_expert_prompt_json(
    input_file: Path,
    output_file: Path,
    question_table_file: Path,
    question_type: str,
    predictions_file: Path
) -> None:
    """
    Apply expert prompts to all QA pairs in a JSON file based on question type and predictions.

    Each entry's conversations are rewritten with ``add_expert_prompt``. Entries
    with no remaining QA pair of ``question_type`` are dropped. Entries without an
    'id', or whose processing raises, are logged and skipped, so one bad entry
    does not abort the whole file.

    Args:
        input_file (Path): Path to the input JSON file containing QA pairs.
        output_file (Path): Path to the output JSON file to save updated QA pairs.
        question_table_file (Path): Path to the JSON file mapping questions to their types.
        question_type (str): The specific question type to filter by.
        predictions_file (Path): Path to the JSON file containing prediction texts.
        Each path argument may also be given as a str.

    Raises:
        FileNotFoundError: If any of the input files do not exist.
        json.JSONDecodeError: If any of the input files contain invalid JSON.
        IOError: If the output file cannot be written.
    """
    # Accept plain strings as well as Path objects: load_json/save_json call
    # path.open(), which str does not provide.
    input_file = Path(input_file)
    output_file = Path(output_file)
    question_table_file = Path(question_table_file)
    predictions_file = Path(predictions_file)

    try:
        data = load_json(input_file)
        qtable = load_json(question_table_file)
        predictions = load_json(predictions_file)
    except Exception as e:
        logging.error(f"Failed to load JSON files: {e}")
        raise

    new_data = []
    for entry in data:
        image_id = entry.get('id')
        conversations = entry.get('conversations', [])

        if not image_id:
            logging.warning(f"Missing 'id' in entry: {entry}")
            continue

        try:
            updated_conversations = add_expert_prompt(
                image_id=image_id,
                old_list=conversations,
                qtype=question_type,
                qtable=qtable,
                predictions=predictions
            )
        except Exception as e:
            # Best-effort per entry: log and move on instead of aborting the file.
            logging.error(f"Failed to process entry with id {image_id}: {e}")
            continue

        if updated_conversations:
            new_data.append({
                'image': entry.get('image'),
                'id': image_id,
                'conversations': updated_conversations
            })

    try:
        save_json(output_file, new_data)
        logging.info(f"Successfully saved updated data to {output_file}.")
    except Exception as e:
        logging.error(f"Failed to save output file {output_file}: {e}")
        raise

# Usage
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
question_types = ['abnormality', 'presence', 'view', 'location', 'level', 'type']
# Path objects (not plain strings): load_json/save_json call path.open().
question_table_file = Path('mimic_vqa_qtype.json')
predictions_file = Path('predictions.json')
input_file = Path('output_json') / 'llava_med_instruct_mimicvqa_test_expertmodel_path.json'
for qtype in question_types:
    output_dir = Path('output_json') / qtype
    output_dir.mkdir(parents=True, exist_ok=True)

    # with expert model
    output_file = output_dir / 'llava_med_instruct_mimicvqa_test_expert_2run.json'
    add_expert_prompt_json(input_file, output_file, question_table_file, qtype, predictions_file)

## Add question and answer types to the test JSON files

In [None]:
def add_qa_type(
    input_file: Path,
    output_file: Path,
    question_type: str
) -> None:
    """
    Insert question type and determine answer type for each QA pair in a JSON file.

    Only the answer of an entry's first QA pair decides its answer type. The
    answer is lower-cased and stripped before the checks:

    * exactly "yes" or "no" -> 'closed';
    * contains "yes" or "no" as a whole word (e.g. "no.", "yes, left") -> 'closed',
      with a warning because the answer is not a bare yes/no;
    * anything else -> 'open'.

    Whole-word matching is required. A plain substring test would classify open
    answers such as "nodule", "normal" or "abnormality" as closed.

    Entries with fewer than two conversation turns, or whose processing raises,
    are logged and skipped. The updated data is then saved to the output file.

    Args:
        input_file (Path): Path to the input JSON file containing QA pairs.
        output_file (Path): Path to the output JSON file to save updated QA pairs.
        question_type (str): The specific question type to assign to each entry.

    Raises:
        FileNotFoundError: If the input file does not exist.
        json.JSONDecodeError: If the input file contains invalid JSON.
        IOError: If the output file cannot be written.
    """
    try:
        data = load_json(input_file)
    except Exception as e:
        logging.error(f"Failed to load input file {input_file}: {e}")
        raise

    # "yes"/"no" as standalone words only (not inside "nodule", "normal", ...).
    yes_no_word = re.compile(r"\b(?:yes|no)\b")
    new_data = []

    for entry in data:
        try:
            conversations = entry.get('conversations', [])
            if len(conversations) < 2:
                logging.warning(f"Entry with id {entry.get('id')} has insufficient conversations.")
                continue

            # Question text (last line before the first '?'); used for logging only.
            question_full = conversations[0].get('value', '')
            question = question_full.split('?')[0].split('\n')[-1]

            # Extract and process the answer
            answer = conversations[1].get('value', '').lower().strip()

            if answer in {'yes', 'no'}:
                answer_type = 'closed'
            elif yes_no_word.search(answer):
                answer_type = 'closed'
                logging.warning(f"Ambiguous answer for question '{question}': '{answer}'")
            else:
                answer_type = 'open'

            new_data.append({
                'image': entry.get('image'),
                'id': entry.get('id'),
                'conversations': conversations,
                'question_type': question_type,
                'answer_type': answer_type
            })

        except Exception as e:
            logging.error(f"Error processing entry with id {entry.get('id')}: {e}")
            continue

    try:
        save_json(output_file, new_data)
        logging.info(f"Successfully saved updated data to {output_file}.")
    except Exception as e:
        logging.error(f"Failed to save output file {output_file}: {e}")
        raise


# Usage
# Configure logging (basicConfig does nothing if the root logger already has handlers)
logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.INFO)

question_types = ['abnormality', 'level', 'location', 'presence', 'type', 'view']
root = Path('output_json')

for qtype in question_types:
    output_dir = root / qtype
    if not output_dir.exists():
        output_dir.mkdir(parents=True)

    # with expert model: annotate the first-round split of this question type
    input_file = output_dir / 'llava_med_instruct_mimicvqa_test_expert.json'
    output_file = output_dir / 'llava_med_instruct_mimicvqa_test_expert_type.json'
    add_qa_type(input_file, output_file, qtype)