In [None]:
from PIL import Image
import requests
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration


In [None]:
%load_ext autoreload
%autoreload 2
from models.ad_llava import ADLlavaModel
from models.condition import conditional_prompt

In [None]:

from peft import PeftConfig, PeftModel

# Token budget shared by the processor calls below (padding / truncation length).
MAX_LENGTH = 1024

# Checkpoint locations (left blank in this copy).
# NOTE(review): the templates contain no `{}` field, so `.format(MAX_LENGTH)` is
# currently a no-op — presumably the real paths embed MAX_LENGTH; confirm.
model_path = " ".format(MAX_LENGTH)
lora_path = " ".format(MAX_LENGTH)

# Base model first, then the processor that tokenizes text and preprocesses images.
model = ADLlavaModel.from_pretrained(model_path)
processor = AutoProcessor.from_pretrained("")

# LoRA adapter: the config is read on its own (kept for inspection) and the
# adapter weights are attached on top of the base model.
config = PeftConfig.from_pretrained(lora_path)
lora_model = PeftModel.from_pretrained(
    model,
    lora_path,
    # NOTE(review): the attention backend is normally selected when the base
    # model is loaded, not when the adapter is attached; this kwarg probably
    # belongs on ADLlavaModel.from_pretrained — confirm flash attention is active.
    attn_implementation="flash_attention_2",
)

# Single call: move the adapted model to the GPU and cast floating weights to bf16.
model = lora_model.to(device='cuda', dtype=torch.bfloat16)


In [None]:
import json
# Test split in JSON Lines format: one JSON object per line (path left blank here).
# Later cells read each record's 'mri_label', 'pet_label', 'image' and
# 'conversations' keys.
TEST_DATA_PATH = ''

# JSON text is UTF-8 (RFC 8259); don't depend on the platform's default encoding.
with open(TEST_DATA_PATH, 'r', encoding='utf-8') as f:
    # Skip blank / whitespace-only lines (e.g. a stray empty line at the end),
    # which json.loads would otherwise reject with a JSONDecodeError.
    test_datas = [json.loads(line) for line in f if line.strip()]

In [None]:
# Imported here as well so this cell does not depend on a cell further down.
from tqdm import tqdm

# Pre-load every test sample's images as RGB PIL images:
# raw_images[i] is the list of images for test_datas[i], in the order of its
# 'image' paths.
raw_images = []
for test_data in tqdm(test_datas):
    image = []
    for image_path in test_data['image']:
        # The context manager closes the file handle; convert() returns a loaded,
        # independent copy, so the image stays usable after the file is closed.
        with Image.open(image_path) as img:
            image.append(img.convert("RGB"))
    raw_images.append(image)

In [None]:
import torch
import time
from tqdm import tqdm
from models.condition import conditional_prompt

label_dict = {0: 'CN', 1: 'MCI', 2: 'AD'}

# Fixed instruction prepended to every generation prompt (identical for all samples).
system_prompts = "You are a expert in the field of Alzheimer's Disease diagnosis. Your task is to diagnose based on the image information, demographic information, and neuropsychological scales data. Please give the final clinical answer."

# One record per test sample: ground truth next to the model's predictions.
results = []

start_time = time.perf_counter()

# Pure inference: no autograd graph is needed for the classification forward pass
# (generate() already disables gradients internally, forward() does not).
with torch.no_grad():
    for test_data in tqdm(test_datas):
        true_mri_label = test_data['mri_label']
        true_pet_label = test_data['pet_label']
        image_paths = test_data['image']

        # The last USER turn is the question, the last ASSISTANT turn the reference
        # answer. Reset per sample so a malformed record cannot silently reuse the
        # previous sample's text.
        question = None
        answer = None
        for conversation in test_data['conversations']:
            if conversation['from'] == 'USER':
                question = conversation['value']
            if conversation['from'] == 'ASSISTANT':
                answer = conversation['value']
        if question is None:
            raise ValueError(f"test sample has no USER turn: {image_paths}")

        image = []
        for image_path in image_paths:
            with Image.open(image_path) as img:
                image.append(img.convert("RGB"))

        prompts = f"{question}\nASSISTANT:"

        # Stage 1: classification-only forward pass predicting the MRI / PET levels.
        processed_example = processor(text=prompts, padding="max_length", truncation=True,
                                      max_length=MAX_LENGTH, images=image, return_tensors="pt")

        inputs = {}
        # Extra leading dim: the model receives all of this sample's images as one item.
        inputs['pixel_values'] = processed_example['pixel_values'].unsqueeze(0).float().to('cuda')
        inputs['input_ids'] = processed_example['input_ids'].long().to('cuda')
        inputs['labels'] = None
        inputs['mri_label'] = None
        inputs['pet_label'] = None
        inputs['return_cls_only'] = True

        loss, mri_logits, pet_logits, image_feature, ad_image_feature = model.forward(**inputs)
        mri_pred_level = mri_logits[0].argmax().item()
        pet_pred_level = pet_logits[0].argmax().item()

        # Stage 2: condition the generation prompt on the predicted levels.
        additional_prompt = conditional_prompt(mri_pred_level, pet_pred_level)
        text = system_prompts + "\n<image>\nUser:" + additional_prompt + prompts

        processed_example = processor(text, padding="max_length", truncation=True,
                                      max_length=MAX_LENGTH, images=image, return_tensors="pt")

        # Only the token ids change; pixel_values are reused from stage 1.
        # NOTE(review): `inputs` still carries return_cls_only=True and no
        # attention_mask is passed although the prompt is padded to MAX_LENGTH —
        # confirm ADLlavaModel.generate handles both as intended.
        inputs['input_ids'] = processed_example['input_ids'].long().to('cuda')

        output = model.generate(**inputs, max_new_tokens=1024)
        # batch_decode returns one string per sequence; this batch holds exactly one.
        generated_text = processor.batch_decode(output, skip_special_tokens=True)[0]
        pred_answer = generated_text.split("ASSISTANT:")[-1]

        results.append({
            'image': image_paths,
            'true_mri_label': true_mri_label,
            'true_pet_label': true_pet_label,
            'pred_mri_level': mri_pred_level,
            'pred_pet_level': pet_pred_level,
            # None if the predicted level is outside label_dict.
            'pred_mri_name': label_dict.get(mri_pred_level),
            'pred_pet_name': label_dict.get(pet_pred_level),
            'answer': answer,
            'pred_answer': pred_answer,
        })

end_time = time.perf_counter()
elapsed_time = end_time - start_time
# Reports total wall-clock run time in seconds (message text is Chinese).
print(f"程序运行时间：{elapsed_time}秒")
    