In [1]:
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import os
import ast
import torch
import random
import pandas as pd
from transformers import TrainingArguments, AutoTokenizer, HfArgumentParser
from utils.my_trainer import CustomTrainer
from utils.utils import my_compute_metrics,seed_everything
from typing import Optional
from dataclasses import dataclass, field
from model.qformer import Blip2QformerPathInstruct
from peft import LoraConfig
from datasets import load_dataset, load_from_disk, concatenate_datasets
from utils.data_collator import MyDataCollatorForQFormerPatchInstruct

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = 'cuda'
llm_name = 'meta-llama/Meta-Llama-3-8B-Instruct'

# set up tokenizer
llm_tokenizer = AutoTokenizer.from_pretrained(llm_name)
llm_tokenizer.pad_token = llm_tokenizer.eos_token
llm_tokenizer.padding_side = "right"
llm_tokenizer.truncation_side = 'left'

new_tokens = ['<Question>',  '<Answer>', '<Image>']  
num_added_toks = llm_tokenizer.add_tokens(new_tokens)
new_tokens_ids = llm_tokenizer.convert_tokens_to_ids(new_tokens)
print("new_tokens_ids: ", new_tokens_ids)

questions = pd.read_csv('./utils/question_list.csv', header=None)  
questions = questions[0].tolist()

new_tokens_ids:  [128256, 128257, 128258]


In [3]:
def formatting_func_ytb(examples):
    text = examples['conversations'].replace("<image>\n", "").replace("<image>", "")
    question = ast.literal_eval(text[1:-1].split('\n')[0])['value'].replace("\n", "")
    answer = ast.literal_eval(text[1:-1].split('\n')[1])['value'].replace("\n", "")
    text = f"<Question> {question}{llm_tokenizer.eos_token}" # + f"<Answer> {answer}{llm_tokenizer.eos_token}\n"
    examples["text_input"] = question
    examples["text"] = text
    return examples

In [6]:
# set up dataset for wsi
data_cache_dir = "/bask/projects/p/phwq4930-gbm/Zeyu/PathVLM/.cache"
# select_data_num = 100
# dataset_local_path = "/home/shared/su123/YoutubePathQA/pretrain_data_all"
# split_text = "train[:{}]".format(select_data_num) # [:10000]

# dataset = load_from_disk(dataset_local_path)
# dataset = dataset.map(formatting_func_ytb, num_proc=4, remove_columns=['id','conversations'])

In [5]:
dataset = dataset.select(range(100))

In [7]:
model = Blip2QformerPathInstruct(
                                    clip_name = 'conch',
                                    num_query_token = 16,
                                    cross_attention_freq = 2,
                                    pretrain_name = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext',
                                    llm_requires_grad = False, 
                                    load_in_8bit = False, 
                                    load_in_4bit = False, 
                                    llm_name = llm_name, 
                                    trust_remote_code = False, 
                                    token = None, 
                                    llm_tokenizer = llm_tokenizer,
                                    image_token_id = new_tokens_ids[-1],
                                    data_cache_dir = data_cache_dir,
                                )

vision_encoder loading ...
llm loading ...


Loading checkpoint shards: 100%|██████████| 4/4 [00:47<00:00, 11.78s/it]


In [8]:
ckpt_path = "/bask/homes/a/asiw9691/PathVLM/source/PathLLM/output/Conch_Bert_Llama3_PatchInstruct/ckpt500.bin"
model.load_state_dict(torch.load(ckpt_path, map_location=device), strict=False)
model.to(device)
# model = model.to(torch.bfloat16)

Blip2QformerPathInstruct(
  (vision_encoder): CoCa(
    (text): TextTransformer(
      (token_embedding): Embedding(32007, 768)
      (transformer): Transformer(
        (resblocks): ModuleList(
          (0-11): 12 x ResidualAttentionBlock(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ls_1): Identity()
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): GELU(approximate='none')
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ls_2): Identity()
          )
        )
      )
      (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (visual): VisualModel(
      (tr

In [13]:
from torch.utils.data import DataLoader

batch_size = 4  # 你可以根据需要调整批处理大小

data_collator = MyDataCollatorForQFormerPatchInstruct(image_processor=model.image_processor, tokenizer=llm_tokenizer, test=True)

In [9]:
tokenized_dataset = dataset.map(
            llm_tokenizer,
            batched=False,
            # remove_columns=['text'],
            num_proc=10,
            batch_size=batch_size,
            input_columns=['text'],
       )

dataloader_params = {
            "batch_size": batch_size,
            "collate_fn": data_collator,
        }

eval_dataloader = DataLoader(tokenized_dataset, **dataloader_params)

Map (num_proc=10):   0%|          | 0/100 [00:00<?, ? examples/s]

In [10]:
from transformers import GenerationConfig

generation_config = GenerationConfig(
                max_length=512,
                temperature=1.0,
                top_k=50,
                top_p=0.95,
                num_return_sequences=1,
                repetition_penalty=1.1,
                do_sample=True,
                pad_token_id=llm_tokenizer.eos_token_id,
                bos_token_id=llm_tokenizer.bos_token_id,
            )

In [11]:
# for batch in tqdm(eval_dataloader):
eval_data_iter = iter(eval_dataloader)
batch = next(eval_data_iter)
input_ids = batch['input_ids'].to(device)
attention_masks = batch['attention_mask'].to(device)
image = batch['image'].to(device)
text = batch['text']

In [11]:
model = model.to(torch.bfloat16)

In [None]:
with torch.no_grad():
    llm_input_ids = batch["input_ids"].to(device)
    image = batch["image"].to(device).to(torch.bfloat16)
    llm_input_attention_mask = batch["attention_mask"].to(device)
    p_num = batch["patch_num"]

    image_embeds = model.vision_encoder.encode_image(image, normalize=False, proj_contrast=False)
    image_embeds, image_atts = model._split_and_pad(image_embeds, p_num)

    image_embeds = image_embeds.to(image.device).to(torch.bfloat16) # batch x max_length x 512
    image_atts = image_atts.to(image.device) # batch x max_length

    query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1).to(torch.bfloat16) # learnable query
    
    text_Qformer = model.bert_tokenizer(
                        text,
                        padding='longest',
                        truncation=True,
                        max_length=model.max_txt_len,
                        return_tensors="pt",
                    ).to(image.device)
    
    query_atts = torch.ones(query_tokens.size()[:-1]).to(image.device)
    Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask],dim=1)

    query_output = model.Qformer.bert(
                        text_Qformer.input_ids,
                        attention_mask=Qformer_atts,
                        query_embeds=query_tokens,
                        encoder_hidden_states=image_embeds,
                        encoder_attention_mask=image_atts,
                        return_dict=True,
                    )

    llm_query_input = model.resampler_layer(query_output.last_hidden_state[:,:query_tokens.size(1),:])

    fusion_embs = model.get_fusion_embedding(llm_input_ids, llm_query_input)

    fusion_attention_mask = model.pad_attention_fusion(fusion_embs.size(1), llm_input_attention_mask)
    
    res = model.llm.generate(inputs_embeds=fusion_embs, attention_mask=fusion_attention_mask, generation_config=generation_config)

generate_list = []
for item in res:
    generation = model.llm_tokenizer.decode(item, skip_special_tokens=True)
    generate_list.append(generation)

In [10]:
# 对于每个批次的数据
import numpy as np
from utils.eval_utils import calculate_f1score 
from tqdm import tqdm

close_ques_acc = 0
close_ques_num = 0
open_ques_f1 = []

open_candidate = []
open_reference = []

close_candidate = []
close_reference = []

max_seq = 0

for batch in tqdm(eval_dataloader):
# eval_data_iter = iter(eval_dataloader)
# batch = next(eval_data_iter)
    input_ids = batch['input_ids'].to(device)
    attention_masks = batch['attention_mask'].to(device)
    labels = batch['labels'].to(device)
    fea1 = batch['fea1'].to(device)
    fea2 = batch['fea2'].to(device)
    fea3 = batch['fea3'].to(device)
    mask1 = batch['mask1'].to(device)
    mask2 = batch['mask2'].to(device)
    mask3 = batch['mask3'].to(device)
    answers = batch['answers']
    if fea1.shape[1] > max_seq:
        max_seq = fea1.shape[1]

print(max_seq)
#     break
#     # 执行模型推断
#     res = model.generate(input_ids=input_ids,
#                          attention_mask=attention_masks,
#                          fea1 = fea1,
#                          fea2 = fea2,
#                          fea3 = fea3,
#                          mask1 = mask1,
#                          mask2 = mask2,
#                          mask3 = mask3,
#                         )
    
#     break
    
#     for i in range(len(answers)):
#         if answers[i] in ['yes','no']:
#             close_candidate.append(res[i])
#             close_reference.append(answers[i])
#         else:
#             open_candidate.append(res[i])
#             open_reference.append(answers[i])

100%|██████████| 624/624 [18:31<00:00,  1.78s/it]

12786





In [11]:
res = model.generate(input_ids=input_ids,
                     attention_mask=attention_masks,
                     fea1 = fea1,
                     fea2 = fea2,
                     fea3 = fea3,
                     mask1 = mask1,
                     mask2 = mask2,
                     mask3 = mask3,
                    )

In [12]:
answers

['The primary diagnosis is transitional cell carcinoma originating from the bladder. Histological examination of the H&E stained WSI from the tumor center reveals a diffusely distributed cell pattern with mosaic and streaming necrosis, alongside stellate-form lymphocytic infiltration. Fibrosis and myxoid changes are present. Cellular features include spindle cell proliferation, scattered pleomorphism with lipoblasts, and inflammatory cells. There is variable cytologic atypia with hyperchromatic and multinucleated giant cells, and notable mitotic activity, contributing to a nuclear grade assessment of G2. The histological diagnosis is urothelial carcinoma G2. Recommended related IHC tests include CK7, CK20, and GATA3 to further characterize the tumor.',
 'The original site of the tumor is the lateral wall of the bladder, diagnosed as high-grade urothelial carcinoma, specifically transitional cell carcinoma. The cancer is staged at AJCC/UICC pT3aN0MX. Microscopic examination of the H&E-s

In [13]:
res

[' The primary diagnosis is transitional cell carcinoma originating from the bladder. Histopathological examination of the tumor center reveals a Grade 3 poorly differentiated transitional cell carcinoma with invasion into perivesical adipose tissue, extensively infiltrating the right and left perivesicles. The carcinoma also invades the trigone and periurethral prostatic ducts. Focal dystrophic calcification is noted within the tumor. No angiolymphatic or perineural invasion is identified. The tumor is staged as T3a N0 Mx. It is recommended to perform immunohistochemistry tests for markers such as CK7, CK20, p63, GATA3, and uroplakin to further characterize the tumor. Additionally, related IHC tests that may be beneficial include PSA (Prostate-Specific Antigen) and PSAP (Prostatic Acid Phosphatase) due to the involvement of prostatic structures. Recommend related IHC tests',
 ' The pathology report describes a case of invasive urothelial carcinoma originating from the anterior wall of

In [11]:
from rouge import Rouge 
rouge = Rouge()
open_ques_rouge = []

for i in range(len(close_reference)):
    close_ques_num += 1
    if close_reference[i] in close_candidate[i]:
        close_ques_acc += 1
        
for i in range(len(open_reference)):
    f1_score = calculate_f1score(open_candidate[i], open_reference[i])
    open_ques_f1.append(f1_score)
    
open_ques_rouge = rouge.get_scores(open_candidate, open_reference, avg=True) 
open_ques_f1 = np.mean(open_ques_f1)
close_ques_acc = close_ques_acc/close_ques_num

print(open_ques_rouge)
print(open_ques_f1)
print(close_ques_acc)

{'rouge-1': {'r': 0.27001985016993246, 'p': 0.017507060508135417, 'f': 0.03129752911526289}, 'rouge-2': {'r': 0.05284908234126984, 'p': 0.0023333627239040663, 'f': 0.004323536432612713}, 'rouge-l': {'r': 0.2615928062936698, 'p': 0.016450324600026073, 'f': 0.02950558446856401}}
0.04000689721539458
0.9538043478260869


In [20]:
scores = compute_bleu_scores(open_candidate, open_reference, avg=True)

In [21]:
scores

0.23677331586807712

### For one image QA test

In [37]:
from PIL import Image

question = "What is your final diagnosis?"
text = f"<Question> {question}{tokenizer.eos_token}"
image = Image.open("./test_images/test1.jpeg")

# i = 1
# text = dataset[i]['text']
# image = dataset[i]['image']
# answer = dataset[i]['answer']

input_dic = tokenizer(text, return_tensors="pt")
map_image_data = model.image_processor(image)
input_dic["image"] = map_image_data

res = model.generate(input_ids = input_dic["input_ids"].to(device),
                    attention_mask = input_dic["attention_mask"].to(device),
                    labels = input_dic["input_ids"].to(device),
                    image = input_dic["image"].unsqueeze(0).to(device),
                    temperature = 0.7,
                    top_p = 0.9,
                    num_return_sequences = 3,
                )

In [28]:
from glob import glob
from PIL import ImageFile, Image
test_image_dir = "./test_images/*"

image_paths = glob(test_image_dir)

patch_list = []
num_list = []
text_list = []

for image_path in image_paths:
    image = Image.open(image_path)
    image = data_collator._resize_image(image)
    patches = data_collator._crop_image(image) # [448x448]
    patches = [model.image_processor(patch) for patch in patches]
    num_list.append(len(patches))
    patch_list += patches
    text_list.append(f"<Question> What is the final pathological diagnosis for this image?{llm_tokenizer.eos_token}")

patch_list = torch.stack(patch_list) # [448x448]

input_dic = llm_tokenizer(text_list, return_tensors="pt")
input_dic["image"] = patch_list
input_dic["patch_num"] = num_list


In [29]:
with torch.no_grad():
    llm_input_ids = input_dic["input_ids"].to(device)
    image = input_dic["image"].to(device).to(torch.bfloat16)
    llm_input_attention_mask = input_dic["attention_mask"].to(device)
    p_num = input_dic["patch_num"]

    image_embeds = model.vision_encoder.encode_image(image, normalize=False, proj_contrast=False)
    image_embeds, image_atts = model._split_and_pad(image_embeds, p_num)

    image_embeds = image_embeds.to(image.device).to(torch.bfloat16) # batch x max_length x 512
    image_atts = image_atts.to(image.device) # batch x max_length

    query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1).to(torch.bfloat16) # learnable query
    
    text_Qformer = model.bert_tokenizer(
                        text_list,
                        padding='longest',
                        truncation=True,
                        max_length=model.max_txt_len,
                        return_tensors="pt",
                    ).to(image.device)
    
    query_atts = torch.ones(query_tokens.size()[:-1]).to(image.device)
    Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask],dim=1)

    query_output = model.Qformer.bert(
                        text_Qformer.input_ids,
                        attention_mask=Qformer_atts,
                        query_embeds=query_tokens,
                        encoder_hidden_states=image_embeds,
                        encoder_attention_mask=image_atts,
                        return_dict=True,
                    )

    llm_query_input = model.resampler_layer(query_output.last_hidden_state[:,:query_tokens.size(1),:])

    fusion_embs = model.get_fusion_embedding(llm_input_ids, llm_query_input)

    fusion_attention_mask = model.pad_attention_fusion(fusion_embs.size(1), llm_input_attention_mask)
    
    res = model.llm.generate(inputs_embeds=fusion_embs, attention_mask=fusion_attention_mask, generation_config=generation_config)

generate_list = []
for item in res:
    generation = model.llm_tokenizer.decode(item, skip_special_tokens=True)
    generate_list.append(generation)

In [36]:
image_paths

['./test_images/test4.jpeg',
 './test_images/test1.jpeg',
 './test_images/test2.jpeg',
 './test_images/test3.jpeg']

In [31]:
generate_list

["<Answer> The final diagnosis is metaplastic meningothelial hyperplasia. The nuclear changes observed can be seen in a variety of conditions and must always be interpreted in the context of the overall histological pattern. The presence of ghost nuclei may suggest some form of cellular degeneration or pathology. However, this interpretation must also be carefully considered in the context of the entire histopathological examination. The absence of an atypia and mitotic activity further supports the likelihood that these are normal cells under stress rather than indicative of a malignant transformation. Further immunohistochemical studies can provide additional context and aid in confirming this preliminary interpretation. This image is a good example to highlight the importance of considering the entire histopathological context when interpreting these findings. It underscores the value of additional stains and other forms of ancillary testing in confirming these preliminary observati

In [24]:
answer

'yes'

In [13]:
!nvidia-smi

Mon May 13 11:51:22 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.154.05             Driver Version: 535.154.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          On  | 00000000:CA:00.0 Off |                    0 |
| N/A   28C    P0              55W / 400W |  40338MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    