In [2]:
# CLIP + tinyllama + multitoken
import os
import torch
import numpy as np
import torch.nn as nn
from tqdm import tqdm
import re
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, CLIPVisionModel, CLIPImageProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# ======================================================
# 1. 超参数：视觉 token 个数（必须和训练时一致）
# ======================================================
NUM_VISION_TOKENS = 4

# ======================================================
# 2. 加载 CLIP Vision
# ======================================================
clip_name = "openai/clip-vit-base-patch32"
clip_vision = CLIPVisionModel.from_pretrained(clip_name).to(device)
clip_processor = CLIPImageProcessor.from_pretrained(clip_name)

for p in clip_vision.parameters():
    p.requires_grad = False

vision_width = clip_vision.config.hidden_size
print("CLIP hidden_size =", vision_width)

# ======================================================
# 3. 加载 TinyLLaMA（冻结）
# ======================================================
llama_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(llama_name)

llama = AutoModelForCausalLM.from_pretrained(
    llama_name,
    torch_dtype=torch.float16,
).to(device)

for p in llama.parameters():
    p.requires_grad = False

lm_hidden_size = llama.config.hidden_size
print("LLaMA hidden =", lm_hidden_size)

# ======================================================
# 4. 定义 Projector（必须和训练时一致）
# ======================================================
class Projector(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear1 = nn.Linear(in_dim, out_dim)
        self.act = nn.GELU()
        self.linear2 = nn.Linear(out_dim, out_dim)

    def forward(self, x):
        if x.dim() == 2:
            x = x.unsqueeze(1)
        x = self.linear1(x)
        x = self.act(x)
        x = self.linear2(x)
        return x

projector = Projector(vision_width, lm_hidden_size).to(device)

# ======================================================
# 5. 加载你训练好的 projector 权重
# ======================================================
weight_path = "/workspace/CLIP+tinyllama+multitoken/projector_epoch1.pt"
projector.load_state_dict(torch.load(weight_path, map_location=device))
projector.eval()
print(f"Loaded trained projector: {weight_path}")

# ======================================================
# 6. 定义多视觉 token 提取函数（和训练完全一致）
# ======================================================
def extract_vision_embeds(img):
    clip_inputs = clip_processor(images=img, return_tensors="pt").to(device)

    with torch.no_grad():
        vision_outputs = clip_vision(**clip_inputs)
        patch_tokens = vision_outputs.last_hidden_state[:, 1:, :]  # remove CLS

    B, S, C = patch_tokens.shape
    num_tokens = NUM_VISION_TOKENS

    # 均匀分 N 组
    if S < num_tokens:
        repeat_factor = (num_tokens + S - 1)//S
        patch_tokens = patch_tokens.repeat(1, repeat_factor, 1)
        S = patch_tokens.size(1)

    chunk_size = S // num_tokens
    usable_len = chunk_size * num_tokens
    patch_tokens = patch_tokens[:, :usable_len, :]
    patch_tokens = patch_tokens.view(B, num_tokens, chunk_size, C)
    patch_tokens = patch_tokens.mean(dim=2)  # (B, NUM_TOKENS, C)

    embeds_f32 = projector(patch_tokens)
    embeds = embeds_f32.to(torch.float16)
    return embeds

# ======================================================
# 7. PathVQA 测试集
# ======================================================
test_ds = load_dataset("flaviagiammarino/path-vqa", split="test")
N = 200   # 想推理多少条就写多少
test_ds = test_ds.select(range(N))

print("Test samples =", len(test_ds))

# ======================================================
# 8. 指标
# ======================================================
def normalize(s):
    s = s.lower().strip()
    s = re.sub(r"[^a-z0-9]+", " ", s)
    return " ".join(s.split())

def f1(pred, gt):
    pt = normalize(pred).split()
    gt = normalize(gt).split()
    if len(pt)==0 or len(gt)==0:
        return float(pt==gt)
    common = set(pt) & set(gt)
    if not common: return 0.0
    p = len(common)/len(pt)
    r = len(common)/len(gt)
    return 2*p*r/(p+r)

def em(pred, gt):
    return float(normalize(pred)==normalize(gt))

# ======================================================
# 9. 进行推理
# ======================================================
all_f1 = []
all_em = []
predictions = []

for sample in tqdm(test_ds, desc="Infer Test Set"):
    img = sample["image"]
    question = sample["question"]
    answer_gt = sample["answer"]

    # 多视觉 token
    with torch.no_grad():
        vision_embeds = extract_vision_embeds(img)

        prompt = f"Question: {question}\nAnswer:"
        tok = tokenizer(prompt, return_tensors="pt")
        input_ids = tok.input_ids.to(device)

        text_emb = llama.model.embed_tokens(input_ids)

        inputs_embeds = torch.cat([vision_embeds, text_emb], dim=1)

        output = llama.generate(
            inputs_embeds=inputs_embeds,
            max_new_tokens=32,
            do_sample=False,
        )

        pred = tokenizer.decode(output[0], skip_special_tokens=True)
        pred = pred.split("Answer:")[-1].strip().split("\n")[0].strip()

    all_f1.append(f1(pred, answer_gt))
    all_em.append(em(pred, answer_gt))
    predictions.append((question, answer_gt, pred))

# ======================================================
# 10. 输出测试集结果
# ======================================================
print("\n=== Test Set Results ===")
print("Samples:", len(test_ds))
print("F1 =", np.mean(all_f1))
print("EM =", np.mean(all_em))

# 如果你想看到一些预测样例：
print("\n=== Example Predictions ===")
for i in range(20):
    q, gt, pd = predictions[i]
    print(f"Q: {q}\nGT: {gt}\nPD: {pd}\n")


Device: cuda
CLIP hidden_size = 768
LLaMA hidden = 2048
Loaded trained projector: /workspace/CLIP+tinyllama+multitoken/projector_epoch1.pt
Test samples = 200


Infer Test Set: 100%|██████████| 200/200 [01:25<00:00,  2.35it/s]


=== Test Set Results ===
Samples: 200
F1 = 0.30269507270934815
EM = 0.28

=== Example Predictions ===
Q: what are positively charged, thus allowing the compaction of the negatively charged dna?
GT: the histone subunits
PD: the cytoplasm of the cell is the dna, which is the genetic material of the cell, and the cytoplasm of the cell

Q: how are the histone subunits charged?
GT: positively charged
PD: by the phosphatidine phosphatidine phosphatidine phosphatidine phosphatidine phosphatidine

Q: what is showing increased eosinophilia of cytoplasm, and swelling of occasional cells?
GT: early (reversible) ischemic injury
PD: the tumor is a large, irregularly shaped mass of cells, which is the most common type of tumor in the body, and is usually found

Q: does mycobacterium avium infection in a duodenal biopsy from a patient with aids show massive intracellular macrophage infection with acid-fast organisms filamentous and pink in this acid-fast stain preparation?
GT: yes
PD: no

Q: what sh




In [6]:
# CLIP_tinyllama_LORA

# 如果你想看到一些预测样例：
print("\n=== Example Predictions ===")
for i in range(190):
    q, gt, pd = predictions[i]
    print(f"Q: {q}\nGT: {gt}\nPD: {pd}\n")


=== Example Predictions ===
Q: what are positively charged, thus allowing the compaction of the negatively charged dna?
GT: the histone subunits
PD: the cytoplasm of the cell is the dna, which is the genetic material of the cell, and the cytoplasm of the cell

Q: how are the histone subunits charged?
GT: positively charged
PD: by the phosphatidine phosphatidine phosphatidine phosphatidine phosphatidine phosphatidine

Q: what is showing increased eosinophilia of cytoplasm, and swelling of occasional cells?
GT: early (reversible) ischemic injury
PD: the tumor is a large, irregularly shaped mass of cells, which is the most common type of tumor in the body, and is usually found

Q: does mycobacterium avium infection in a duodenal biopsy from a patient with aids show massive intracellular macrophage infection with acid-fast organisms filamentous and pink in this acid-fast stain preparation?
GT: yes
PD: no

Q: what shows branching papillae having flbrovascular stalk covered by a single laye

In [9]:
# ======================================================
# 11. 评估 Yes/No Accuracy
# ======================================================
test_ds = load_dataset("flaviagiammarino/path-vqa", split="test")
# N = 1800   # 想推理多少条就写多少
# test_ds = test_ds.select(range(N))

def is_yes_no(ans):
    ans = ans.lower().strip()
    return ans in ["yes", "no"]

yesno_samples = [s for s in test_ds if is_yes_no(s["answer"])]
print("Yes/No samples =", len(yesno_samples))

yn_correct = 0
yn_total = len(yesno_samples)

yn_predictions = []

for sample in tqdm(yesno_samples, desc="Eval Yes/No"):
    img = sample["image"]
    question = sample["question"]
    answer_gt = sample["answer"].lower().strip()  # "yes" / "no"

    # --------------------------------------------------
    # 推理
    # --------------------------------------------------
    with torch.no_grad():
        vision_embeds = extract_vision_embeds(img)

        prompt = f"Question: {question}\nAnswer:"
        tok = tokenizer(prompt, return_tensors="pt").to(device)

        text_emb = llama.model.embed_tokens(tok.input_ids)
        inputs_embeds = torch.cat([vision_embeds, text_emb], dim=1)

        output = llama.generate(
            inputs_embeds=inputs_embeds,
            max_new_tokens=8,
            do_sample=False,
        )

        pred = tokenizer.decode(output[0], skip_special_tokens=True)
        pred = pred.split("Answer:")[-1].strip().split("\n")[0].strip().lower()

    # 标准化（避免有些模型输出 like "no," or "yes."）
    pred_clean = "yes" if pred.startswith("y") else ("no" if pred.startswith("n") else pred)

    # 判断
    correct = (pred_clean == answer_gt)
    yn_correct += correct

    yn_predictions.append((question, answer_gt, pred_clean, correct))

# ======================================================
# 输出结果
# ======================================================
print("\n=== Yes/No Accuracy ===")
print(f"Total Yes/No samples = {yn_total}")
print(f"Accuracy = {yn_correct / yn_total:.4f}")

print("\n=== Example Yes/No Predictions ===")
for i in range(15):
    q, gt, pd, ok = yn_predictions[i]
    print(f"Q: {q}")
    print(f"GT: {gt} | Pred: {pd} | Correct: {ok}\n")


Yes/No samples = 3362


Eval Yes/No: 100%|██████████| 3362/3362 [05:15<00:00, 10.64it/s]


=== Yes/No Accuracy ===
Total Yes/No samples = 3362
Accuracy = 0.7606

=== Example Yes/No Predictions ===
Q: does mycobacterium avium infection in a duodenal biopsy from a patient with aids show massive intracellular macrophage infection with acid-fast organisms filamentous and pink in this acid-fast stain preparation?
GT: yes | Pred: no | Correct: False

Q: does microscopy show branching papillae having flbrovascular stalk covered by a single layer of cuboidal cells having ground-glass nuclei?
GT: yes | Pred: no | Correct: False

Q: does pbf show branching papillae having flbrovascular stalk covered by a single layer of cuboidal cells having ground-glass nuclei?
GT: no | Pred: no | Correct: True

Q: is chematic mechanisms involved in pathogenesis of two main types of diabetes mellitus?
GT: yes | Pred: no | Correct: False

Q: are numbers in the illustrations involved in pathogenesis of two main types of diabetes mellitus?
GT: no | Pred: no | Correct: True

Q: are there inadequate t ce


