In [None]:
from google.colab import drive
from google.colab import files
import shutil
import os
# Step 1: mount Google Drive so evaluation outputs can be stored there and
# outlive the Colab runtime.
# Reference: https://stackoverflow.com/a/79469725
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# !pip install -U -q trl bitsandbytes peft hf_xet tensorboard
!pip install  -U -q transformers==4.51.3 git+https://github.com/huggingface/trl.git datasets bitsandbytes peft qwen-vl-utils wandb accelerate
# Tested with transformers==4.47.0.dev0, trl==0.12.0.dev0, datasets==3.0.2, bitsandbytes==0.44.1, peft==0.13.2, qwen-vl-utils==0.0.8, wandb==0.18.5, accelerate==1.0.1
!pip install -q torch==2.4.1+cu121 torchvision==0.19.1+cu121 torchaudio==2.4.1+cu121 --extra-index-url https://download.pytorch.org/whl/cu121

In [None]:
# Clone just LFS pointers for Med-R1 models
!GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/yuxianglai117/Med-R1

In [None]:
# Pull only the Med-R1/VQA_CT model.
# This should take about 6 minutes with the basic Colab runtime.
!cd Med-R1/ && git lfs pull --include="VQA_CT"

In [None]:
# Pull the OmniMedVQA dataset.
!wget https://huggingface.co/datasets/foreverbeliever/OmniMedVQA/resolve/main/OmniMedVQA.zip
!unzip OmniMedVQA
!rm OmniMedVQA.zip

In [None]:
!mkdir gh
!cd gh && git clone https://github.com/Yuxiang-Lai117/Med-R1.git

In [None]:
import argparse
import glob
import json
import os
import re
import torch
from tqdm import tqdm

from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoTokenizer

# ---- Evaluation configuration ----
MODEL_PATH = "Med-R1/VQA_CT"  # checkpoint folder pulled with git-lfs
BSZ = 1  # number of prompts per model.generate() call
OUTPUT_PATH = "output.json"

# Test-split JSON files. glob() returns paths in arbitrary filesystem order, so
# sort them to make the sample order (and therefore the order of results in the
# output file) reproducible across runs and machines.
# NOTE(review): these globs look inside the Hugging Face clone (./Med-R1); if the
# Splits directory only ships with the GitHub repo, point them at
# gh/Med-R1/Splits instead -- confirm. No matches means no samples are evaluated.
modalities = sorted(glob.glob("Med-R1/Splits/modality/test/*.json"))
question_types = sorted(glob.glob("Med-R1/Splits/question_type/test/*.json"))

all_outputs = []  # decoded model outputs, one per prompt, in prompt order

def extract_option_answer(output_str):
    """Return the option token the model wrote inside ``<answer>...</answer>``.

    Whitespace around the token inside the tags is ignored; the token itself
    must be a single run of word characters (letters, digits, underscore),
    e.g. ``"A"``.

    Args:
        output_str: Decoded model output.

    Returns:
        The first such token, or ``None`` if no well-formed answer tag pair
        is present.
    """
    found = re.search(r'<answer>\s*(\w+)\s*</answer>', output_str)
    return found.group(1) if found else None

def extract_think_text(output_str):
    """Return the reasoning text the model wrote before closing ``</think>``.

    The reasoning may span several lines. If an opening ``<think>`` tag
    appears before the closing tag, only the text after the last such opening
    tag is returned, so the tag itself is never part of the result. Outputs
    that lack the opening tag are also handled; in that case everything
    before ``</think>`` is returned. This happens when generation resumes from
    a prompt that already ends in ``<think>...``.

    Args:
        output_str: Decoded model output.

    Returns:
        The reasoning text (possibly empty), or ``None`` if the output has no
        ``</think>`` tag, i.e. the thinking section was never closed.
    """
    close_at = output_str.find('</think>')
    if close_at == -1:
        return None
    think_text = output_str[:close_at]
    # Drop the opening tag (and anything before it) so callers that re-wrap
    # this text in "<think>" do not end up with a doubled tag.
    open_at = think_text.rfind('<think>')
    if open_at != -1:
        think_text = think_text[open_at + len('<think>'):]
    return think_text

def remove_special_tokens(output_str, patterns):
    """Delete every match of the regular expression ``patterns`` from the text.

    Args:
        output_str: Text to clean.
        patterns: Regular expression, e.g. an alternation of tag markers such
            as ``r"<think>|</think>"``.

    Returns:
        ``output_str`` with all matches replaced by the empty string.
    """
    compiled = re.compile(patterns)
    return compiled.sub('', output_str)

# Accumulators filled by the scoring step at the end of this cell.
final_output = []   # per-sample result records written to OUTPUT_PATH
correct_number = 0  # number of samples whose extracted answer matches the label

# Load the checkpoint in bfloat16; device_map="auto" lets accelerate place the
# weights on the available device(s). To use FlashAttention, pass
# attn_implementation="flash_attention_2" (requires the flash-attn package).
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

processor = AutoProcessor.from_pretrained(MODEL_PATH)
# Batched generation with a decoder-only model needs left padding, so every
# prompt ends exactly where generation begins. The processor's tokenizer pads
# the generation batch, so set it explicitly rather than relying on the
# checkpoint's default. This has no effect when BSZ == 1.
processor.tokenizer.padding_side = "left"
# Separate tokenizer; in this cell it is only used to count generated tokens.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side="left")

# Root directory of the extracted OmniMedVQA data; each sample's "image" path
# is appended to it.
BASE_PATH = "OmniMedVQA/"

# ---- Budget-forcing ("s1"-style) settings used by the generation loop ----
# Total reasoning-token budget per sample.
MAX_THINK_TOKENS = 2048
# Floor applied when max_new_tokens is sized from a sample's remaining budget.
MIN_TOKENS_PER_ITER = 0
# Text appended after the model's reasoning to push it to keep thinking.
IGNORE_STR = "Wait"
# Maximum number of extra continuation passes; 0 means a single generation pass.
MAX_NUM_IGNORES = 0
# Regex matching the tag markers that are stripped before counting tokens.
SPECIAL_TOKENS = r"<think>|</think>|<answer>|</answer>"

# Prompt used for every sample; {Question} is filled with the sample's
# "problem" text.
QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> and final choice (A, B, C, D ...) in <answer> </answer> tags."

messages = []  # one chat conversation per sample

from qwen_vl_utils import process_vision_info

data = []
for split_path in [*modalities, *question_types]:
    with open(split_path, "r", encoding="utf-8") as split_file:
        split_records = json.load(split_file)
    # Tag every record with an ID built from its split file name and its
    # position in that file (spaces -> "_", dots -> "-", parentheses removed)
    # so each result can be traced back to its source sample.
    split_name = os.path.basename(split_path)
    for position, record in enumerate(split_records):
        raw_id = f"{split_name}_{position}"
        record['id'] = raw_id.replace(" ", "_").replace(".", "-").replace("(", "").replace(")", "")
    data.extend(split_records)

# Build one single-turn user conversation per sample: the image followed by the
# templated question. This is the structure passed to
# processor.apply_chat_template and process_vision_info in the generation loop.
for sample in data:
    image_part = {
        "type": "image",
        "image": f"file://{BASE_PATH}{sample['image']}",
    }
    question_part = {
        "type": "text",
        "text": QUESTION_TEMPLATE.format(Question=sample['problem']),
    }
    messages.append([{"role": "user", "content": [image_part, question_part]}])

# Running count of reasoning tokens generated per prompt (same order as messages).
num_tokens = [0] * len(messages)

# Generate answers batch by batch. A batch may be generated several times
# ("s1"-style budget forcing). After each pass, a sample that closed its
# reasoning and is still under MAX_THINK_TOKENS has that reasoning plus
# IGNORE_STR appended after the generation prompt, and generation resumes.
# This repeats until every sample hits the budget or MAX_NUM_IGNORES extra
# passes have run.
# NOTE(review): only the last pass's decoded text is stored in all_outputs, so
# with MAX_NUM_IGNORES > 0 the saved output omits reasoning that was carried
# over in the prompt from earlier passes.
for i in tqdm(range(0, len(messages), BSZ)):
    batch_messages = messages[i:i + BSZ]
    # Most recent decoded output for every sample in this batch.
    latest_outputs = [None] * len(batch_messages)
    continue_generating = True
    num_ignores = 0
    while continue_generating:
        # Only samples with reasoning budget left are (re)generated. Prompts,
        # images and decoded outputs are all built from this same index list,
        # so they stay aligned when BSZ > 1.
        active = [b for b in range(len(batch_messages)) if num_tokens[i + b] < MAX_THINK_TOKENS]
        active_messages = [batch_messages[b] for b in active]

        # Preparation for inference
        text = []
        max_new_tokens = 256  # per-call floor; NOTE(review): possibly meant to be MIN_TOKENS_PER_ITER
        for batch_idx in active:
            msg = batch_messages[batch_idx]
            # Resume a previously interrupted reasoning section, if any.
            if msg[0].get("new_msg"):
                print(f"MSG new: {msg[0]['new_msg']}")
                new_msg = "<think>" + msg[0]["new_msg"]
            else:
                new_msg = ""
            text.append(processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) + new_msg)
            print(text)
            # Raise the cap to the sample's remaining budget when that is larger.
            remaining = MAX_THINK_TOKENS - num_tokens[i + batch_idx]
            if remaining > max_new_tokens:
                max_new_tokens = max(remaining, MIN_TOKENS_PER_ITER)

        print(f"Max Tokens to generate: {max_new_tokens}")
        image_inputs, video_inputs = process_vision_info(active_messages)
        inputs = processor(
            text=text,
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        # Follow the model's placement (device_map="auto") instead of assuming "cuda".
        inputs = inputs.to(model.device)

        # Inference: greedy decoding (do_sample=False), so each pass is deterministic.
        generated_ids = model.generate(
            **inputs,
            use_cache=True,
            max_new_tokens=max_new_tokens,
            do_sample=False)

        # Strip the prompt tokens so only newly generated text is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        batch_output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        # Apply the s1 methodology: charge each active sample for the reasoning
        # tokens it just produced.
        for batch_idx, output in zip(active, batch_output_text):
            latest_outputs[batch_idx] = output
            print(output)
            model_think = extract_think_text(output)
            # If the <think> section has not completed, count the whole output.
            if not model_think:
                print("Using output!")
                model_think = output

            # Remove any <think>, </think>, <answer>, </answer>, ... markers
            model_think = remove_special_tokens(model_think, SPECIAL_TOKENS)
            num_tokens[i + batch_idx] += len(tokenizer.encode(model_think, add_special_tokens=False))
            print(f"Tokens generated: {num_tokens[i+batch_idx]}")

        # Keep going while any sample in the batch is still under its budget...
        continue_generating = any(n < MAX_THINK_TOKENS for n in num_tokens[i:i + BSZ])
        # ...unless the maximum number of forced continuations has been used.
        if num_ignores >= MAX_NUM_IGNORES:
            continue_generating = False

        # Prepare the next pass. Each sample still under budget carries its
        # reasoning forward with IGNORE_STR appended to force more thinking.
        for batch_idx, output in zip(active, batch_output_text):
            if num_tokens[i + batch_idx] >= MAX_THINK_TOKENS:
                continue
            msg = batch_messages[batch_idx][0]
            print(f"MSG: {msg}")
            old_msg = msg.get("new_msg") or ""

            model_think = extract_think_text(output)
            # Don't keep going if there is no thinking section.
            if not model_think:
                print("NO THINK!")
                print(model_think)
                print("-----------")
                continue
            new_msg = f"{old_msg}{model_think.split('</think>')[0]} {IGNORE_STR} "
            msg["new_msg"] = new_msg
            print(f"{batch_idx}: {new_msg}")
        num_ignores += 1

    all_outputs.extend(latest_outputs)
    print(f"Processed batch {i//BSZ + 1}/{(len(messages) + BSZ - 1)//BSZ}")

# Score each sample: extract the option letter from the model output and from
# the label, and record a per-sample result for the output file.
for input_example, model_output in zip(data, all_outputs):
    original_output = model_output
    ground_truth = input_example['solution']
    model_answer = extract_option_answer(original_output)

    # Create a result dictionary for this example
    result = {
        'id': input_example['id'],
        'question': input_example,
        'ground_truth': ground_truth,
        'model_output': original_output,
        'extracted_answer': model_answer
    }
    final_output.append(result)

    # Labels are expected to look like "<answer>X</answer>". If a label is
    # malformed, compare against the stripped raw label rather than crashing
    # the whole run on `None.group(1)`.
    ground_truth_match = re.search(r'<answer>\s*(\w+)\s*</answer>', ground_truth)
    ground_truth = ground_truth_match.group(1) if ground_truth_match else ground_truth.strip()

    # Count correct answers
    print(f"model_answer: {model_answer}, ground_truth: {ground_truth}")
    if model_answer is not None and model_answer == ground_truth:
        correct_number += 1

# Calculate and print accuracy. Guard against an empty run (e.g. the split globs
# matched no files), which would otherwise raise ZeroDivisionError.
if all_outputs:
    accuracy = (correct_number / len(all_outputs)) * 100
else:
    print("WARNING: no samples were evaluated; reporting 0% accuracy.")
    accuracy = 0.0
print(f"\nAccuracy: {accuracy:.2f}%")

# Save results to a JSON file
output_path = OUTPUT_PATH
with open(output_path, "w", encoding="utf-8") as f:
    json.dump({
        'accuracy': accuracy,
        'results': final_output
    }, f, indent=2)

print(f"Results saved to {output_path}")


In [None]:
from google.colab import drive
from google.colab import files
import shutil
import os

# Move the results file onto the mounted Google Drive so it outlives the Colab
# runtime. `output_path` is set by the evaluation cell.
filename = output_path

# Destination folder on Drive. Create it if needed; otherwise shutil.move
# raises FileNotFoundError on a fresh Drive.
destination_dir = '/content/drive/My Drive/Colab_Output'
os.makedirs(destination_dir, exist_ok=True)
# Keep only the file name, so a path-like output_path such as "results/out.json"
# does not create missing sub-folders under Drive. An absolute output_path would
# otherwise replace the Drive folder entirely.
destination_path = os.path.join(destination_dir, os.path.basename(filename))

# Move (not copy): the local file no longer exists afterwards.
shutil.move(filename, destination_path)

print(f'File "{filename}" moved to "{destination_path}" successfully.')
