In [None]:
# Step 1: Import libraries
import json
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


# Step 2: Load JSON dataset
# Pin the encoding: JSON interchange text is UTF-8, and without an explicit
# encoding open() falls back to the platform locale, which can fail or
# mis-decode non-ASCII text on some systems.
with open("cad_questions_combined.json", "r", encoding="utf-8") as f:
    dataset = json.load(f)

# Quick sanity check on how many entries were read.
print(f"Loaded {len(dataset)} questions.")

# Step 3: Load LLaMA model and tokenizer
# Change this value if your locally deployed checkpoint lives at a different path.
# NOTE(review): confirm this identifier resolves in your environment; the
# official Hub repo id may be spelled differently.
model_name = "meta-llama/Llama-3-8b-instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the weights in half precision and let the loader decide device placement.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
)
# Put the module in evaluation mode before running inference.
model.eval()

# Step 4: Define prompt template
def build_prompt(entry):
    """Render one dataset entry as a multiple-choice prompt.

    Args:
        entry: Mapping with a "question" value and a "choices" mapping of
            option label to option text.

    Returns:
        The prompt string. It lists the question and one "<label>. <text>"
        line per choice, and ends with "Answer:" so that the model's
        continuation starts with its answer.
    """
    question_text = entry["question"]
    options = entry["choices"]
    option_lines = [f"{label}. {text}" for label, text in options.items()]

    # Each element is one line of the prompt; empty strings are blank lines.
    sections = (
        "You are a helpful medical assistant. Please answer the following multiple-choice question:",
        "",
        "Question:",
        f"{question_text}",
        "",
        "Choices:",
        "\n".join(option_lines),
        "",
        "Answer:",
    )
    return "\n".join(sections)

# Step 5: Run inference
# For every question: build the prompt, greedily generate a short
# continuation, and take the first non-whitespace character of that
# continuation as the predicted option letter.
results = []
for item in tqdm(dataset[:50]):  # change the slice to [:N] to control how many samples are tested
    prompt = build_prompt(item)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Number of prompt tokens, used below to separate prompt from continuation.
    prompt_length = inputs["input_ids"].shape[1]
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=10,
            # Greedy decoding. `temperature` is deliberately not passed: it
            # only applies when sampling, so combining it with
            # do_sample=False has no effect other than a warning.
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    # The returned sequence starts with the prompt tokens. Decode only what
    # was generated after them, so the prompt text is never parsed and a
    # stray "Answer:" in the continuation cannot shift the extraction point.
    generated_ids = outputs[0][prompt_length:]
    output_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
    answer = output_text.strip()[:1]  # first option letter, or "" if nothing was generated

    results.append({
        "id": item["id"],
        "question": item["question"],
        "true_answer": item["answer"],
        "predicted_answer": answer,
        # Case-insensitive comparison against the reference answer.
        "correct": answer.upper() == item["answer"].upper()
    })

# Step 6: Evaluate accuracy
# Guard against an empty result list (e.g. an empty dataset or slice), which
# would otherwise raise ZeroDivisionError.
if results:
    accuracy = sum(r["correct"] for r in results) / len(results)
    print(f"Accuracy on {len(results)} questions: {accuracy:.2%}")
else:
    accuracy = 0.0  # keep the name defined for any later use
    print("No questions were evaluated; accuracy is undefined.")

# Optional: Save predictions
# The encoding is explicit so the output does not depend on the platform locale.
with open("model_predictions.json", "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2)

print("✅ Inference complete. Results saved to model_predictions.json")

: 