In [None]:
# Enable saving to GDrive for long running jobs.
from google.colab import drive
from google.colab import files
import shutil
import os
# Reference: https://stackoverflow.com/a/79469725
# 1. Mount Google Drive at /content/drive for output storage.
#    The last cell of this notebook moves the results JSON into
#    /content/drive/My Drive/Colab_Output.
drive.mount('/content/drive')

In [None]:
# Install necessary packages for inference.
# transformers is pinned to 4.51.3 and trl is installed from the GitHub repository;
# the remaining packages carry no version pin, so -U upgrades them to whatever is current.
!pip install  -U -q transformers==4.51.3 git+https://github.com/huggingface/trl.git datasets bitsandbytes peft qwen-vl-utils wandb accelerate
# Tested with transformers==4.47.0.dev0, trl==0.12.0.dev0, datasets==3.0.2, bitsandbytes==0.44.1, peft==0.13.2, qwen-vl-utils==0.0.8, wandb==0.18.5, accelerate==1.0.1
# NOTE(review): the "tested with" list above does not match the transformers==4.51.3 pin
# in the install command — confirm which combination is the supported one.
# Install the cu121 builds of torch / torchvision / torchaudio from the PyTorch wheel index.
!pip install -q torch==2.4.1+cu121 torchvision==0.19.1+cu121 torchaudio==2.4.1+cu121 --extra-index-url https://download.pytorch.org/whl/cu121

In [None]:
from datasets import load_dataset

# Dataset identifier handed to `load_dataset` below.
dataset_name = "avuong/vm1"

# Load Dataset
# NOTE(review): none of the cells shown in this notebook read `dataset`; the inference
# cell takes its questions from "test_split.json" instead — confirm this load is needed.
dataset = load_dataset(dataset_name)

In [None]:
# Download the OmniMedVQA archive.
!wget https://huggingface.co/datasets/foreverbeliever/OmniMedVQA/resolve/main/OmniMedVQA.zip
# Extract it; the inference cell builds image paths under BASE_PATH = "OmniMedVQA/".
# NOTE(review): the archive is named here without its .zip suffix — presumably unzip
# resolves it to OmniMedVQA.zip; verify on the target runtime.
!unzip OmniMedVQA
# Delete the archive once it has been extracted.
!rm OmniMedVQA.zip

In [None]:
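# NOTE: this whole cell is a disabled GPU-memory cleanup helper, kept for reference only.
# If it is re-enabled it also needs `import torch`: torch is not imported until the
# inference cell further down.
# import gc
# import time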


# def clear_memory():
#     # Delete variables if they exist in the current global scope
#     if "inputs" in globals():
#         del globals()["inputs"]
#     if "model" in globals():
#         del globals()["model"]
#     if "processor" in globals():
#         del globals()["processor"]
#     if "trainer" in globals():
#         del globals()["trainer"]
#     if "peft_model" in globals():
#         del globals()["peft_model"]
#     if "bnb_config" in globals():
#         del globals()["bnb_config"]
#     time.sleep(2)

#     # Garbage collection and clearing CUDA memory
#     gc.collect()
#     time.sleep(2)
#     torch.cuda.empty_cache()
#     torch.cuda.synchronize()
#     time.sleep(2)
#     gc.collect()
#     time.sleep(2)

#     print(f"GPU allocated memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
#     print(f"GPU reserved memory: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")


# clear_memory()

In [None]:
import argparse
import glob
import json
import os
import re
import torch
from tqdm import tqdm

from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoTokenizer

# Checkpoint identifier passed to from_pretrained() for the model, processor and tokenizer.
MODEL_PATH = "avuong/vm1-sft"
# NOTE(review): BSZ is not read anywhere in the cells shown; the generation loop below
# processes one conversation per call.
BSZ = 1
# Thinking budget in tokens: substituted into the question prompt and used as the
# generation cap / stop threshold in the inference loop.
MAX_THINK_TOKENS = 2048
# Results file name; the thinking budget is embedded in it.
OUTPUT_PATH = f"output_inference_{MAX_THINK_TOKENS}.json"
# JSON file the candidate test questions are read from.
input_json = "test_split.json"

# Final decoded model output for each evaluated question, filled by the inference loop.
all_outputs = []

def extract_option_answer(output_str):
    """Extract the chosen option from a model output string.

    Two output shapes are recognised:

    * a bare single character (for example ``"B"``), optionally surrounded by
      whitespace, which is returned as-is after stripping;
    * text containing ``<answer> X </answer>``, where the first such tag pair
      wins and its word-character content is returned.

    Args:
        output_str: Decoded model output. ``None`` or an empty string is
            treated as "no answer".

    Returns:
        The extracted answer string, or ``None`` when nothing could be extracted.
    """
    if not output_str:
        return None

    # Strip first so that outputs such as " A\n" are still recognised as a bare
    # letter, and so that a lone whitespace character is not reported as an answer.
    stripped = output_str.strip()
    if len(stripped) == 1:
        return stripped

    # Otherwise look for the first <answer>...</answer> pair.
    match = re.search(r'<answer>\s*(\w+)\s*</answer>', stripped)
    return match.group(1) if match else None

def extract_think_text(output_str):
    """Return the text that precedes the closing ``</think>`` tag.

    The match is greedy, so when several ``</think>`` tags are present the text
    up to the last one is returned. Any opening ``<think>`` tag is not removed.

    Args:
        output_str: Decoded model output.

    Returns:
        The text before ``</think>``, or ``None`` if the tag is absent.
    """
    # re.DOTALL lets "." cross newlines; without it only the final line of a
    # multi-line reasoning block would be captured.
    match = re.search(r'(.*)\s*</think>', output_str, flags=re.DOTALL)

    if match:
        return match.group(1)
    return None

def remove_special_tokens(output_str):
    """Keep only the reasoning text that comes before the first ``</think>``.

    The tag itself and everything after it (for example an ``<answer>`` section)
    are dropped. A string without ``</think>`` is returned unchanged.
    """
    reasoning, _tag, _rest = output_str.partition("</think>")
    return reasoning

# Per-question result dicts; written to the results JSON at the end of this cell.
final_output = []
# Number of questions whose extracted answer equals the ground truth.
correct_number = 0

# Load the checkpoint in bfloat16; device_map="auto" leaves device placement to the library.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    #attn_implementation="flash_attention_2",
    device_map="auto",
)

# The processor renders the chat template, builds model inputs and decodes outputs.
processor = AutoProcessor.from_pretrained(MODEL_PATH)
# Separate tokenizer instance; in the cells shown it is only used to count the tokens
# of the decoded output.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side="right")
processor.tokenizer.padding_side = "right"

# Prefix prepended to each example's relative image path.
BASE_PATH="OmniMedVQA/"
# Number of questions sampled from the test split for evaluation.
NUM_TO_EVALUATE = 40
# NOTE(review): MIN_TOKENS_PER_ITER, SPECIAL_TOKENS and MAX_NUM_IGNORES are not read
# anywhere in the cells shown — confirm whether they are still needed.
MIN_TOKENS_PER_ITER = 0
# Text appended to the assistant's reasoning while the thinking budget is not yet used up.
IGNORE_STR = "Wait"
SPECIAL_TOKENS=r"</think>|<answer>|</answer>"
MAX_NUM_IGNORES=0
# Set to True further down when the cached test set no longer has NUM_TO_EVALUATE entries.
repick_questions = False
# Prompt template; {Question} and {ThinkTokens} are filled in per example.
QUESTION_TEMPLATE = "{Question} Output the thinking process in <think> </think>, think for only {ThinkTokens} tokens, and output a single, final single-letter choice (A, B, C, D ...) in <answer> </answer> tags."

# One conversation (list of chat turns) per selected example.
messages = []

from qwen_vl_utils import process_vision_info

def generate_random_test_set(test_data, num_to_evaluate=None, seed=None):
    """Draw a random set of distinct indices into ``test_data``.

    Args:
        test_data: Sized collection of test examples; only its length is used.
        num_to_evaluate: How many indices to draw. Defaults to the module-level
            ``NUM_TO_EVALUATE``. The value is capped at ``len(test_data)`` so
            that asking for more questions than exist no longer raises.
        seed: Optional seed. When given, a private ``random.Random(seed)`` is
            used so the same selection can be reproduced; otherwise the global
            ``random`` state is used.

    Returns:
        A list of distinct integer indices in random order. The list is also
        printed.
    """
    import random

    if num_to_evaluate is None:
        num_to_evaluate = NUM_TO_EVALUATE

    available = len(test_data)
    sample_size = max(0, min(num_to_evaluate, available))
    if sample_size < num_to_evaluate:
        print(f"Only {available} examples available; sampling {sample_size} instead of {num_to_evaluate}.")

    rng = random if seed is None else random.Random(seed)
    # random.sample draws without replacement in a single call, replacing the
    # manual pick-and-remove loop.
    chosen_elements = rng.sample(range(available), sample_size)
    print(chosen_elements)
    return chosen_elements

# Examples selected for evaluation, each augmented with an 'id' key.
data = []

for test_json in [input_json]:
    # Read the candidate questions and close the file straight away; nothing
    # below needs the handle.
    with open(test_json, "r", encoding="utf-8") as test_file:
        test_data = json.load(test_file)

    # Reuse the cached question indices when they exist so repeated runs
    # evaluate the same questions.
    cache_exists = os.path.exists("test_set.json")
    if cache_exists:
        print("Loading a random test set.")
        with open("test_set.json", "r") as cache_file:
            random_test_set = json.load(cache_file)
        if len(random_test_set) != NUM_TO_EVALUATE:
            print("Evaluation number has changed, resampling indices to questions.")
            repick_questions = True
        elif any(
            not isinstance(cached_index, int) or not 0 <= cached_index < len(test_data)
            for cached_index in random_test_set
        ):
            # A cache produced for a different/larger split would otherwise fail
            # with an IndexError (or silently wrap around on negative indices).
            print("Cached test set does not match the current test data, resampling indices to questions.")
            repick_questions = True
    if not cache_exists or repick_questions is True:
        print("Generating a random test set.")
        random_test_set = generate_random_test_set(test_data)
        with open("test_set.json", "w") as cache_file:
            json.dump(random_test_set, cache_file, indent=2)

    # Augment the data with a new ID key so we can keep track.
    t_data = []
    for index in random_test_set:
        req = test_data[index]
        custom_id = f"{os.path.basename(test_json)}_{index}"
        # Normalise the id: spaces -> "_", dots -> "-", parentheses removed.
        req['id'] = custom_id.replace(" ", "_").replace(".", "-").replace("(", "").replace(")", "")
        t_data.append(req)
    data.extend(t_data)

# Build one single-turn conversation per selected example: the image first,
# followed by the question rendered through QUESTION_TEMPLATE.
for example in data:
    image_part = {
        "type": "image",
        "image": f"file://{BASE_PATH}{example['image']}",
    }
    question_text = QUESTION_TEMPLATE.format(
        Question=example['problem'],
        ThinkTokens=MAX_THINK_TOKENS,
    )
    text_part = {"type": "text", "text": question_text}
    messages.append([{"role": "user", "content": [image_part, text_part]}])

# Budget-forced inference, one conversation at a time.
#
# For each question the model first generates up to MAX_THINK_TOKENS tokens.
# While the running token count is below the budget, the assistant's text is
# cut at the first "</think>", IGNORE_STR is appended, and generation is run
# again with the remaining budget. Once the budget is reached, the reasoning is
# closed with "</think><answer>", a user turn asking for the final letter is
# added, and one last short generation (25 tokens) produces the answer that is
# stored in `all_outputs`.
#
# NOTE(review): the prompt is always rendered with add_generation_prompt=True,
# even when the last message is the assistant's partial reasoning. Depending on
# the chat template this may open a new assistant turn instead of continuing
# the previous one — confirm this is the intended prompting scheme.
for i in tqdm(range(0, len(messages))):
    message = messages[i]
    continue_generating = True
    token_count = 0
    max_new_tokens = MAX_THINK_TOKENS
    while continue_generating:
        if message[-1]["role"] == "assistant":
            # Drop the end-of-think tag and anything after it before extending
            # or closing the reasoning.
            message[-1]["content"][0]["text"] = remove_special_tokens(message[-1]["content"][0]["text"])
            if token_count >= MAX_THINK_TOKENS:
                # Prompt the model to provide a final answer.
                message[-1]["content"][0]["text"] += "</think><answer>"
                print("Prompting for final answer!")
                message.append({
                   "role": "user",
                   "content": "Do not reason further. Provide a final, single-letter (A, B, C, D ...) answer in <answer></answer> tags. FINAL ANSWER:"
                })
                max_new_tokens = 25
                continue_generating = False
            else:
                # Append the Wait token if we haven't reached the thinking limit.
                message[-1]["content"][0]["text"] += IGNORE_STR
                max_new_tokens = MAX_THINK_TOKENS - token_count

        # Preparation for inference
        image_inputs, video_inputs = process_vision_info(message)
        text = [processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)]
        inputs = processor(
            text=text,
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to("cuda")

        # Inference: Generation of the output
        generated_ids = model.generate(
            **inputs,
            use_cache=True,
            max_new_tokens=max_new_tokens,
            do_sample=False)

        # Keep only the newly generated ids (strip the prompt prefix).
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        batch_output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        print(json.dumps(batch_output_text, indent=2))
        # Get the token count of the output_text (re-encoded from the decoded string).
        new_token_count = len(tokenizer.encode(batch_output_text[0], add_special_tokens=False))
        token_count += new_token_count
        print(f"Tokens generated: {token_count}")
        if message[-1]["role"] != "assistant":
            message.append({
                "role": "assistant",
                "content": [{
                    "type": "text",
                    # We only have one batch_output, so we can just index by that.
                    "text": batch_output_text[0]
                }]
            })
        else:
            # Fix: keep the continuation that followed IGNORE_STR. Previously this
            # text was counted against the budget but thrown away, so the final
            # prompt contained only the first round of reasoning plus a pile of
            # IGNORE_STR repetitions.
            message[-1]["content"][0]["text"] += batch_output_text[0]
        if continue_generating and new_token_count == 0:
            # Fix: a round that yields no tokens makes no progress towards the
            # budget; treat the budget as spent so the next pass asks for the
            # final answer instead of looping without bound.
            token_count = MAX_THINK_TOKENS
    all_outputs.extend(batch_output_text)
    print(f"Answer: {batch_output_text}")
    print(f"Processed batch {i+1}/{len(messages)}")

# Score every final model output against its ground truth and collect a
# per-question record for the results file.
for input_example, model_output in zip(data,all_outputs):
    original_output = model_output
    ground_truth = input_example['solution']
    model_answer = extract_option_answer(original_output)

    # Create a result dictionary for this example
    result = {
        'id': input_example['id'],
        'question': input_example,
        'ground_truth': ground_truth,
        'model_output': original_output,
        'extracted_answer': model_answer
    }
    final_output.append(result)

    ground_truth_pattern = r'<answer>\s*(\w+)\s*</answer>'
    ground_truth_match = re.search(ground_truth_pattern, ground_truth)
    if ground_truth_match:
        ground_truth = ground_truth_match.group(1)
    else:
        # A solution without <answer> tags used to crash here with an
        # AttributeError; fall back to the raw solution text instead.
        ground_truth = ground_truth.strip()

    # Count correct answers
    print(f"model_answer: {model_answer}, ground_truth: {ground_truth}")
    if model_answer is not None and model_answer == ground_truth:
        correct_number += 1

# Calculate and print accuracy (0.0 when nothing was evaluated, instead of a
# ZeroDivisionError).
accuracy = (correct_number / len(all_outputs)) * 100 if all_outputs else 0.0
print(f"\nAccuracy: {accuracy:.2f}%")

# Save results to a JSON file
output_path = OUTPUT_PATH
with open(output_path, "w") as f:
    json.dump({
        'accuracy': accuracy,
        'results': final_output
    }, f, indent=2)

print(f"Results saved to {output_path}")

In [None]:
from google.colab import drive
from google.colab import files
import shutil
import os

# 3. Get the filename (the results file written by the inference cell)
filename = output_path

# 4. Specify the destination path
destination_dir = '/content/drive/My Drive/Colab_Output'
# Create the folder on first use; the move below fails if it does not exist.
os.makedirs(destination_dir, exist_ok=True)
destination_path = os.path.join(destination_dir, filename)

# 5. Move the results file to the destination.
#    Note: this removes the local copy, so re-running this cell without
#    re-running inference will fail because the source file is gone.
shutil.move(filename, destination_path)

print(f'File "{filename}" uploaded to "{destination_path}" successfully.')