In [1]:
%%capture

!pip install -q transformers datasets peft bitsandbytes wandb trl evaluate

In [2]:
!git clone https://github.com/microsoft/LLaVA-Med.git LLaVA_Med

Cloning into 'LLaVA_Med'...
remote: Enumerating objects: 429, done.[K
remote: Counting objects: 100% (41/41), done.[K
remote: Compressing objects: 100% (36/36), done.[K
remote: Total 429 (delta 1), reused 31 (delta 1), pack-reused 388 (from 1)[K
Receiving objects: 100% (429/429), 77.09 MiB | 12.23 MiB/s, done.
Resolving deltas: 100% (122/122), done.


In [3]:
import os
# Work from inside the cloned LLaVA-Med repo: later cells use paths relative to
# it (`dataset/...`, `llava/eval/model_vqa.py`, `pip install -e .`). Using an
# absolute path keeps this cell safe to re-run (a relative chdir would not be).
os.chdir("/content/LLaVA_Med")

os.getcwd()

'/content/LLaVA_Med'

In [4]:
import warnings
warnings.filterwarnings("ignore")

import torch
from torch.utils.data.dataset import Dataset
from torch.nn.utils.rnn import pad_sequence
from datasets import load_dataset
from PIL import Image
from io import BytesIO
import requests
import json
import uuid

from transformers import Trainer, TrainingArguments
from peft import LoraConfig, LoraModel, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer

from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import Conversation
from llava.mm_utils import tokenizer_image_token, process_images
from llava.model.builder import load_pretrained_model
from llava.conversation import conv_templates

from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [5]:
def rename_medvqa_columns(dataset):
    """Map MedVQA's plural column names onto the singular names used downstream.

    ``process_and_save`` reads ``image``, ``question`` and ``answer`` from each
    row, whereas MedVQA provides ``images``, ``questions`` and ``answers``.
    ``image_names`` is listed as an identity entry and is left unchanged.

    Args:
        dataset: Object exposing ``rename_columns`` (e.g. a ``DatasetDict``).

    Returns:
        Whatever ``dataset.rename_columns`` returns for the mapping.
    """
    column_mapping = dict(
        image_names="image_names",
        images="image",
        questions="question",
        answers="answer",
    )
    return dataset.rename_columns(column_mapping)

# Modes Pillow's JPEG writer can store directly; anything else (RGBA, LA, P, ...)
# makes ``image.save("x.jpg")`` raise, so such images are converted to RGB first.
_JPEG_SAVE_MODES = {"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"}


def _load_image(image_field):
    """Return an image object for a dataset ``image`` field.

    A ``str`` is treated as a URL and downloaded; anything else is assumed to
    already be a decoded image object and is returned as-is.
    """
    if not isinstance(image_field, str):
        return image_field
    # A timeout stops one stalled download from hanging the whole export, and
    # raise_for_status() keeps an HTTP error page from being decoded as an image.
    response = requests.get(image_field, timeout=30)
    response.raise_for_status()
    return Image.open(BytesIO(response.content))


def _format_answer(answers):
    """Collapse a dataset ``answer`` field into a single string.

    Strings are returned unchanged. Lists are joined with ``", "`` so separate
    answers are not glued together (``["yes", "no"]`` would otherwise become
    ``"yesno"``).
    """
    if isinstance(answers, str):
        return answers
    return ", ".join(str(answer) for answer in answers)


def process_and_save(dataset, output_folder, subset_name):
    """Export a dataset split as a LLaVA-style ``dataset.json`` plus JPEG images.

    Files written::

        <output_folder>/<subset_name>/dataset.json
        <output_folder>/<subset_name>/images/<uuid>.jpg

    Args:
        dataset: Iterable of rows with ``"image"`` (URL string or an image
            object exposing ``save``), ``"question"`` and ``"answer"`` keys.
        output_folder: Root output directory.
        subset_name: Split name; used as the sub-directory name.

    Each row becomes one record with a random UUID ``id``, the image file name
    (relative to ``images/``) and a two-turn ``human``/``gpt`` conversation.
    """
    subset_folder = os.path.join(output_folder, subset_name)
    image_subfolder = os.path.join(subset_folder, "images")
    # makedirs creates subset_folder as a parent, so one call covers both.
    os.makedirs(image_subfolder, exist_ok=True)

    json_data_list = []
    for item in dataset:
        image = _load_image(item["image"])
        if getattr(image, "mode", None) not in _JPEG_SAVE_MODES:
            image = image.convert("RGB")

        unique_id = str(uuid.uuid4())
        image.save(os.path.join(image_subfolder, f"{unique_id}.jpg"))

        json_data_list.append({
            "id": unique_id,
            "image": f"{unique_id}.jpg",
            "conversations": [
                {"from": "human", "value": item["question"]},
                {"from": "gpt", "value": _format_answer(item["answer"])},
            ],
        })

    json_output_path = os.path.join(subset_folder, "dataset.json")
    with open(json_output_path, "w", encoding="utf-8") as json_file:
        json.dump(json_data_list, json_file, indent=4)


def save_dataset(dataset_name, output_folder, subset_name):
    """Load a Hugging Face dataset and export one split with ``process_and_save``.

    Args:
        dataset_name: Hub id passed to ``load_dataset``.
        output_folder: Root folder the split is exported under.
        subset_name: Split key to export (e.g. ``"train"`` or ``"test"``).
    """
    needs_rename = dataset_name == "agupte/MedVQA"
    splits = load_dataset(dataset_name)
    if needs_rename:
        # MedVQA uses plural column names; map them to the ones process_and_save reads.
        splits = rename_medvqa_columns(splits)
    process_and_save(splits[subset_name], output_folder, subset_name)

In [6]:
# Dataset to export. Other datasets previously wired up here (swap DATASET_NAME
# to use one): "mdwiratathya/SLAKE-vqa-english", "agupte/MedVQA" (the latter has
# its columns renamed inside save_dataset).
DATASET_NAME = "flaviagiammarino/vqa-rad"
output_folder = "dataset"

for split in ("train", "test"):
    save_dataset(DATASET_NAME, output_folder, split)

README.md:   0%|          | 0.00/3.91k [00:00<?, ?B/s]

(…)-00000-of-00001-eb8844602202be60.parquet:   0%|          | 0.00/24.2M [00:00<?, ?B/s]

(…)-00000-of-00001-e5bc3d208bb4deeb.parquet:   0%|          | 0.00/10.3M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1793 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/451 [00:00<?, ? examples/s]

In [20]:
import json

# Input and output file paths
input_file = "dataset/test/dataset.json"  # LLaVA-style export written by process_and_save
output_file = "dataset/test/question.jsonl"


def to_question_record(question_id, entry):
    """Build one question line for ``llava/eval/model_vqa.py`` from a dataset.json entry.

    The last ``human`` turn becomes the prompt (followed by an ``<image>``
    placeholder) and the last ``gpt`` turn becomes the reference answer,
    stored under ``gpt4_answer``.

    Raises:
        ValueError: if the entry lacks a human or gpt turn, rather than
            silently writing the string ``"None"`` into the output.
    """
    question = answer = None
    for turn in entry["conversations"]:
        if turn["from"] == "human":
            question = turn["value"]
        elif turn["from"] == "gpt":
            answer = turn["value"]
    if question is None or answer is None:
        raise ValueError(f"Entry {entry.get('id', question_id)!r} is missing a human or gpt turn")
    return {
        "question_id": question_id,
        "image": entry["image"],
        "text": f"{question}\n<image>",
        "gpt4_answer": answer,
    }


with open(input_file, "r", encoding="utf-8") as f:
    input_data = json.load(f)

# question_id is the entry's position in dataset.json; answers are matched back
# to questions by this id when scoring.
output_list = [to_question_record(qid, entry) for qid, entry in enumerate(input_data)]

with open(output_file, "w", encoding="utf-8") as f:
    for entry in output_list:
        f.write(json.dumps(entry) + "\n")

print(f"Conversion complete! File saved as '{output_file}'.")


Conversion complete! File saved as 'dataset/test/question.jsonl'.


In [8]:
# Confirm the working directory is still the repo root: the editable install and
# the `llava/eval/model_vqa.py` run that follow use paths relative to it.
os.getcwd()

'/content/LLaVA_Med'

In [9]:
%%capture
!pip install -q -e .

In [22]:
# Batch inference with the fine-tuned LLaVA-Med checkpoint on the exported test split.
# Questions come from question.jsonl (one JSON object per line); each record's
# `image` file name is resolved against dataset/test/images. Model outputs go to
# answers.jsonl, which the scoring cell reads (matching on question_id, using `text`).
# NOTE(review): `mistral_instruct` is presumably the prompt template matching this
# Mistral-based checkpoint, and `--temperature 0.0` presumably means greedy decoding —
# confirm in llava/eval/model_vqa.py.
!python llava/eval/model_vqa.py \
    --conv-mode mistral_instruct \
    --model-path Veda0718/llava-med-v1.5-mistral-7b-finetuned \
    --question-file dataset/test/question.jsonl \
    --image-folder dataset/test/images \
    --answers-file dataset/test/answers.jsonl \
    --temperature 0.0

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
2024-12-09 06:01:32.669904: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-12-09 06:01:32.691302: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-12-09 06:01:32.697828: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
  _torch_pytree._register_pytree_node(
Loading checkpoint shards: 100% 6/6 [00:04<00:00,  1.40it/s]
  return torch.load(checkpoint_file, map_location=map_location)
100% 451/451 [01:41<00:00,  4.46it/s]


In [39]:
!pip install jsonlines

Collecting jsonlines
  Downloading jsonlines-4.0.0-py3-none-any.whl.metadata (1.6 kB)
Downloading jsonlines-4.0.0-py3-none-any.whl (8.7 kB)
Installing collected packages: jsonlines
Successfully installed jsonlines-4.0.0


In [40]:
import jsonlines
from typing import List, Dict

def load_data(questions_file: str, answers_file: str) -> List[Dict]:
    """Pair each question's reference answer with the model's answer.

    Both inputs are JSON Lines files. Question records need ``question_id`` and
    ``gpt4_answer``; answer records need ``question_id`` and ``text``.

    Returns:
        One ``{'gpt4_answer': ..., 'model_answer': ...}`` dict per question that
        has an answer, in question-file order. Questions without an answer are
        skipped; if an id appears more than once in the answers file, the last
        occurrence wins.
    """
    import json

    def read_jsonl(path: str) -> List[Dict]:
        # Stdlib JSON Lines reader; blank lines (e.g. a trailing newline) are ignored.
        with open(path, "r", encoding="utf-8") as handle:
            return [json.loads(line) for line in handle if line.strip()]

    questions = read_jsonl(questions_file)
    answer_by_qid = {ans['question_id']: ans for ans in read_jsonl(answers_file)}

    return [
        {
            'gpt4_answer': question['gpt4_answer'],
            'model_answer': answer_by_qid[question['question_id']]['text'],
        }
        for question in questions
        if question['question_id'] in answer_by_qid
    ]

def calculate_accuracy(matched_pairs: List[Dict], ignore_punctuation: bool = False) -> float:
    """Fraction of pairs whose model answer exactly matches the reference answer.

    Both answers are lower-cased and stripped of surrounding whitespace before
    comparison. With ``ignore_punctuation=True``, ASCII punctuation is also
    removed, so e.g. ``"Yes."`` matches ``"yes"``. A pair where either answer
    is ``None`` counts as a mismatch instead of raising.

    Args:
        matched_pairs: Dicts with ``'gpt4_answer'`` and ``'model_answer'`` keys.
        ignore_punctuation: Drop ASCII punctuation before comparing (default off,
            which keeps the original strict exact-match metric).

    Returns:
        Accuracy in ``[0, 1]``; ``0.0`` for an empty input.
    """
    import string

    strip_table = str.maketrans("", "", string.punctuation) if ignore_punctuation else None

    def normalize(text: str) -> str:
        text = text.lower().strip()
        if strip_table is not None:
            text = text.translate(strip_table).strip()
        return text

    if not matched_pairs:
        return 0.0

    accurate_matches = sum(
        1
        for pair in matched_pairs
        if pair['gpt4_answer'] is not None
        and pair['model_answer'] is not None
        and normalize(pair['gpt4_answer']) == normalize(pair['model_answer'])
    )
    return accurate_matches / len(matched_pairs)

def main(questions_file: str, answers_file: str):
    """Score a question file against an answer file and print the results.

    Prints the exact-match accuracy (4 decimals) followed by the number of
    question/answer pairs that could be matched by ``question_id``.
    """
    pairs = load_data(questions_file, answers_file)
    accuracy = calculate_accuracy(pairs)
    report = (
        f"Exact Match Accuracy: {accuracy:.4f}",
        f"Total QA pairs: {len(pairs)}",
    )
    for line in report:
        print(line)


# In a notebook kernel __name__ is '__main__', so this runs when the cell executes.
if __name__ == '__main__':
    main(
        questions_file='dataset/test/question.jsonl',
        answers_file='dataset/test/answers.jsonl',
    )

Exact Match Accuracy: 0.4945
Total QA pairs: 451
