In [1]:
import cv2
import random
import json
import time

import torch
import numpy as np
from tqdm import tqdm

import clip
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from segment_anything.utils.transforms import ResizeLongestSide 

from PIL import Image  
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image

from sam_caf import hyper_params_tuning, get_crops, retrieve_relevant_crop, retrieve_relevant_crop_biomed, get_sam_prompts, sam_predicton, retrieve_relevant_crop_biomed_topk

In [2]:
#config
class DictToObject:
    """Expose the entries of a mapping as instance attributes.

    Every ``key: value`` pair yielded by ``dict_obj.items()`` is copied onto
    the instance with ``setattr``, so ``DictToObject({"a": 1}).a == 1``.
    """

    def __init__(self, dict_obj):
        for attr_name, attr_value in dict_obj.items():
            setattr(self, attr_name, attr_value)

# Settings for the SAM + CLIP pipeline. In this notebook "model_type" selects the
# entry of sam_model_registry and "sam_ckpt" is the checkpoint passed to it; the
# whole config object is also handed to the sam_caf retrieval helper.
# NOTE(review): "source", "refine" and "pre_trained" are the *strings* "False"/"True",
# not booleans. A plain truthiness test would treat "False" as true — confirm how
# sam_caf reads these values before changing them.
config_dict = {
    "model_name" : "SAM",
    "model_type" : "vit_h",
    "source":    "False", 
    "refine" : "False",
    "pre_trained": "True", 
    "sam_ckpt":  "/data/aofei/LLM/SAM/sam_vit_h_4b8939.pth", 
    "clip_prompts": "./clip_prompts/abd_seg.json"
}

# Attribute-style access (config.sam_ckpt, config.model_type, ...).
config = DictToObject(config_dict)

# prompt_mode is forwarded to get_crops(); `mode` is only referenced by a
# commented-out line in the code visible in this notebook.
prompt_mode, mode = "crops", "sam_clip"

def preprocess_image(image_path):
    """Read an image from disk and resize it to 256x256 pixels.

    Args:
        image_path: Path of the image file to load.

    Returns:
        The resized image as returned by ``cv2.resize``.

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot read the file. OpenCV
            signals a missing or undecodable file by returning ``None`` rather
            than raising, which previously surfaced as an opaque error inside
            ``cv2.resize``.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"cv2.imread could not read image: {image_path!r}")
    # NOTE(review): cv2.imread yields BGR channel order and no colour conversion
    # is applied here; confirm that the downstream SAM / CLIP helpers expect that.
    return cv2.resize(image, (256, 256))

In [3]:
import os
# Point the Hugging Face cache at a local directory. The variable is set before
# open_clip is imported and before the BiomedCLIP weights are requested below.
# NOTE(review): presumably so the download lands in this cache — confirm.
# os.environ["TRANSFORMERS_CACHE"]="/data/aofei/huggingface_cache/transformers"
os.environ["HF_HOME"]="/data/aofei/huggingface_cache/transformers"
from open_clip import create_model_from_pretrained, get_tokenizer # works on open-clip-torch>=2.23.0, timm>=0.9.8

# BiomedCLIP model, its image preprocessing and tokenizer; these three objects are
# what generate_segments() passes to retrieve_relevant_crop_biomed_topk().
biomed_clip_model, biomed_preprocess = create_model_from_pretrained('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224', device="cuda")
tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')

# OpenAI CLIP ViT-L/14. Neither clip_model nor preprocess is referenced again by
# the code visible in this notebook.
clip_model, preprocess = clip.load("ViT-L/14", device="cuda")
sam_checkpoint = config.sam_ckpt

# Build the SAM variant named by config.model_type from the checkpoint and move it to the GPU.
sam = sam_model_registry[config.model_type](checkpoint=sam_checkpoint)
sam.to("cuda")
# resize_transform and dice_scores are not used by the code visible in this notebook.
resize_transform = ResizeLongestSide(sam.image_encoder.img_size)

dice_scores = []
# hyper_params_tuning() returns the mask generator plus an area threshold; both are
# read as globals by sam_generation().
mask_generator, area = hyper_params_tuning(sam)

  _torch_pytree._register_pytree_node(


  _torch_pytree._register_pytree_node(




  _torch_pytree._register_pytree_node(




In [4]:
def sam_generation(image_path):
    """Run the SAM mask generator on one image and crop the surviving masks.

    The image is loaded through ``preprocess_image``; masks whose ``"area"`` is
    not strictly below the global ``area`` threshold (obtained from
    ``hyper_params_tuning``) are dropped before ``get_crops`` is called with
    the global ``prompt_mode``.

    Returns:
        ``(masks, img_crops)``: the area-filtered mask records and the value
        returned by ``get_crops`` for them.
    """
    image = preprocess_image(image_path=image_path)
    with torch.no_grad():
        candidates = mask_generator.generate(image)
        # Keep only masks smaller than the tuned area threshold.
        masks = list(filter(lambda candidate: candidate["area"] < area, candidates))
        img_crops = get_crops(image, masks, prompt_mode)

    return masks, img_crops

def filter_sam_results(masks, img_crops, min_side=12, max_extent=253):
    """Drop masks that touch the image border or are too small.

    Each mask's ``'bbox'`` is read as ``[x, y, width, height]``. A mask (and
    the crop at the same index) is kept only if all of the following hold:

    * ``x != 0`` and ``y != 0`` (not on the left/top edge),
    * ``width > min_side`` and ``height > min_side``,
    * ``x + width <= max_extent`` and ``y + height <= max_extent``.

    Args:
        masks: Sequence of mask records with a ``'bbox'`` entry.
        img_crops: Sequence aligned with ``masks`` (same indices).
        min_side: Largest width/height that is still rejected. Default 12.
        max_extent: Largest allowed right/bottom coordinate. The default 253
            matches the 256x256 images produced by ``preprocess_image``.

    Returns:
        ``(kept_masks, kept_crops)`` as two new lists in the original order.
    """
    kept_masks, kept_crops = [], []
    for idx, mask in enumerate(masks):
        box = mask['bbox']
        x, y, width, height = box[0], box[1], box[2], box[3]
        if x == 0 or y == 0:
            continue
        if width <= min_side or height <= min_side:
            continue
        if y + height > max_extent or x + width > max_extent:
            continue
        kept_masks.append(mask)
        kept_crops.append(img_crops[idx])
    return kept_masks, kept_crops

def get_topk_similar(k, crop_scores):
    """Return the ``k`` best ``(index, score)`` pairs, highest score first.

    Ties keep their original relative order because the sort is stable.
    """
    ranked = list(enumerate(crop_scores))
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

def get_compelete_contour(masks):
    """Return the largest bounding-box width found in ``masks``.

    Args:
        masks: Sequence of mask records whose ``'bbox'`` entry is indexed as
            ``[x, y, width, height]``.

    Returns:
        The maximum ``bbox[2]`` (width) over all masks.

    Raises:
        IndexError: If ``masks`` is empty.
    """
    # need to consider the chest xray
    widths = [mask['bbox'][2] for mask in masks]
    if not widths:
        # Keep IndexError (what the original sort-and-index code raised) but
        # with an explanatory message.
        raise IndexError("get_compelete_contour() received no masks")
    return max(widths)

def judge_inner_boxes(bboxes):
    """Unfinished placeholder: walks ``bboxes`` and returns ``None``.

    TODO: no containment logic is implemented yet; the loop body does nothing.
    """
    for _box in bboxes:
        pass

In [5]:
# Per-image caches keyed by image path, so SAM runs once per image even though
# many questions share an image.
masks_image_dict = {}       # filtered masks, used by generate_segments()
masks_all_image_dict = {}   # filtered masks, used by generate_all_segments()
crops_image_dict = {}       # crops aligned with masks_image_dict entries
def generate_segments(query, image_path, topk=8):
    """Select the SAM segments of an image that best match a text query.

    SAM masks and crops are computed once per ``image_path`` and cached in
    ``masks_image_dict`` / ``crops_image_dict``. The crops are then ranked
    against ``query`` with ``retrieve_relevant_crop_biomed_topk``.

    Args:
        query: Text used as the single ``"query"`` prompt.
        image_path: Path of the image to segment (also the cache key).
        topk: Forwarded as ``topk`` to the retrieval helper. Default 8.

    Returns:
        ``(bboxes, segs, indices)``: the ``'bbox'`` and ``'segmentation'``
        entries of the selected masks, plus the helper's ``"query"`` indices.
        Three empty lists are returned when no mask survives filtering.
    """
    if image_path in masks_image_dict:
        masks = masks_image_dict[image_path]
        img_crops = crops_image_dict[image_path]
    else:
        masks, img_crops = sam_generation(image_path=image_path)
        masks, img_crops = filter_sam_results(masks, img_crops)
        masks_image_dict[image_path] = masks
        crops_image_dict[image_path] = img_crops

    # Nothing to rank: skip the CLIP call instead of handing it an empty batch.
    if len(img_crops) == 0:
        return [], [], []

    prompts = {"query": [query]}
    max_indices, scores = retrieve_relevant_crop_biomed_topk(img_crops, prompts, biomed_clip_model, biomed_preprocess, config, tokenizer=tokenizer, topk=topk)
    # define a set of rules, firstly return top3
    # if there is no explicit organs to be used as query, then just use the whole segmentation
    # if the smaller boxes are in the bigger box, then use all of them but assign higher weights on smaller inner boxes
    bboxes = []
    segs = []
    for i in max_indices["query"]:
        bboxes.append(masks[i]["bbox"])
        segs.append(masks[i]["segmentation"])
    return bboxes, segs, max_indices["query"]


def generate_all_segments(image_path):
    """Return every filtered SAM mask of an image, computing it at most once.

    Results are memoised in ``masks_all_image_dict`` under ``image_path``; on a
    cache miss the image goes through ``sam_generation`` and
    ``filter_sam_results`` (the crops are discarded).
    """
    if image_path not in masks_all_image_dict:
        raw_masks, raw_crops = sam_generation(image_path=image_path)
        kept_masks, _kept_crops = filter_sam_results(raw_masks, raw_crops)
        masks_all_image_dict[image_path] = kept_masks
    return masks_all_image_dict[image_path]

In [6]:
# Preprocessing for PathVQA (the commented-out blocks in this cell are the
# earlier Slake / VQA-RAD variants of the same loading step).
import json
# # with open(r"/data/aofei/hallucination/Slake/data/training_contours.json", "r") as f:
# #     train_data = json.load(f)
# # len(train_data)

# Load the PathVQA training annotations (a list of question/answer records)
# and display how many there are.
with open(r"/data/aofei/hallucination/PathVQA/pvqa/train_copy.json", "r") as f:
    all_train_data = json.load(f)
len(all_train_data)

# with open(r"/data/aofei/hallucination/Slake/data/test.json", "r") as f:
#     test_data = json.load(f)
# len(test_data)

# with open(r"/data/aofei/hallucination/Slake/test.json", "r") as f:
#     all_test_data = json.load(f)
# len(all_test_data)

# # Preprocessing for VQA-RAD
# import json
# with open(r"/data/aofei/hallucination/VQA_RAD/MED_RAD_test.json", "r") as f:
#     rad_data = json.load(f)
# len(rad_data)

# train_rad_data = []
# for i in rad_data:
#     if not i['phrase_type'].startswith("test"):
#         train_rad_data.append(i)
# len(train_rad_data)

19755

In [7]:
# Peek at the first raw record to see which fields an annotation carries.
all_train_data[0]

{'qid': 100422000,
 'image_name': 'train_0422.jpg',
 'answer': 'in the canals of hering',
 'answer_type': 'other',
 'question_type': 'where',
 'question': 'Where are liver stem cells (oval cells) located?'}

In [8]:
# Alias only (no copy): both names refer to the same list object.
# NOTE(review): the `_en` suffix presumably comes from a language filter used
# for another dataset; no filtering happens here — confirm.
all_train_data_en = all_train_data

In [9]:
import copy
# Work on a deep copy so that adding 'bbox' / 'mask' / 'bbox_indices' to the
# records later does not modify all_train_data.
train_data_seg = copy.deepcopy(all_train_data_en)
len(train_data_seg)

19755

In [10]:
# Number of distinct images referenced by the questions.
len({record['image_name'] for record in train_data_seg})

2599

In [11]:
# len(set([i['img_name'] for i in train_data]))

In [None]:
# for i in range(len(train_rad_data)):
#failure case: 432-3 + 100
import os
from tqdm import tqdm
# Attach the query-relevant SAM boxes / masks to every training record, in place.
# For a quick smoke test iterate over train_data_seg[:20] instead.
# NOTE(review): the outputs saved in this notebook (20-step progress bar, 19735
# records without a mask) come from such a 20-record run, not from the full loop.
for data in tqdm(train_data_seg):
    image_path = os.path.join("/data/aofei/hallucination/PathVQA/pvqa/images/train", data["image_name"])
    # The prompt is the fixed prefix immediately followed by the question text.
    query = "Medical image of pathology." + data["question"]
    bbox, segs, max_indices = generate_segments(query, image_path)
    data['bbox'] = bbox
    data['mask'] = segs
    data["bbox_indices"] = max_indices
    


  0%|                                                                                                                                                                                                                                                   | 0/20 [00:00<?, ?it/s]


  5%|███████████▊                                                                                                                                                                                                                               | 1/20 [00:03<00:57,  3.02s/it]


 20%|███████████████████████████████████████████████                                                                                                                                                                                            | 4/20 [00:03<00:09,  1.67it/s]


 30%|██████████████████████████████████████████████████████████████████████▌                                                                                                                                                                    | 6/20 [00:05<00:12,  1.09it/s]


 45%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                                                                                 | 9/20 [00:08<00:10,  1.05it/s]


 50%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                                                                     | 10/20 [00:08<00:08,  1.25it/s]


 55%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                                                                         | 11/20 [00:09<00:05,  1.52it/s]


 60%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                                                             | 12/20 [00:09<00:04,  1.87it/s]


 65%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                                  | 13/20 [00:09<00:03,  2.31it/s]


 70%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                      | 14/20 [00:09<00:02,  2.84it/s]


 75%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                          | 15/20 [00:12<00:04,  1.01it/s]


 85%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                   | 17/20 [00:12<00:01,  1.72it/s]


 95%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎           | 19/20 [00:12<00:00,  2.62it/s]


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:12<00:00,  1.61it/s]




In [13]:
# Inspect the first record after annotation; it now also carries 'bbox' and 'mask'.
train_data_seg[0]

{'qid': 100422000,
 'image_name': 'train_0422.jpg',
 'answer': 'in the canals of hering',
 'answer_type': 'other',
 'question_type': 'where',
 'question': 'Where are liver stem cells (oval cells) located?',
 'bbox': [[225, 102, 21, 24],
  [188, 88, 14, 28],
  [71, 138, 20, 23],
  [178, 89, 45, 36],
  [104, 80, 21, 40],
  [96, 49, 46, 38]],
 'mask': [array([[False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         ...,
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False]]),
  array([[False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         ...,
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., Fals

In [14]:
# How many records did not receive a 'mask' entry from the annotation loop.
s = sum(1 for record in train_data_seg if "mask" not in record)
s

19735

In [15]:
# Inspect the record at index 3 (another question about the same image as record 0).
train_data_seg[3]

{'qid': 100422003,
 'image_name': 'train_0422.jpg',
 'answer': 'yes',
 'answer_type': 'yes/no',
 'question_type': 'are',
 'question': 'Are bile duct cells and canals of Hering stained here with an immunohistochemical stain for cytokeratin 7?',
 'bbox': [[225, 102, 21, 24],
  [188, 88, 14, 28],
  [71, 138, 20, 23],
  [178, 89, 45, 36],
  [104, 80, 21, 40],
  [96, 49, 46, 38]],
 'mask': [array([[False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         ...,
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False]]),
  array([[False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         ...,
         [False, False, False, ..., False, False, False],
   

In [16]:
#training2
# Convert each annotated record into a conversation-style training sample.
# Boxes go into the sample itself; the segmentation arrays are collected in
# segments_dict (keyed by str(qid)) and 'masks' stays empty in the sample.
new_train_data = []
segments_dict = dict()
for record in train_data_seg:
    segments_dict[str(record['qid'])] = record.get("mask", [])

    human_turn = {"from": "human", "value": "<image>\n" + record['question']}
    gpt_turn = {"from": "gpt", "value": str(record['answer'])}
    new_train_data.append({
        'image': record['image_name'],
        'id': record['qid'],
        'conversations': [human_turn, gpt_turn],
        'bboxes': record.get("bbox", []),
        'masks': [],
    })

new_train_data[-6]

{'image': 'train_2794.jpg',
 'id': 102794002,
 'conversations': [{'from': 'human',
   'value': '<image>\nIs bone nearly completely filled with tumor primary present?'},
  {'from': 'gpt', 'value': 'no'}],
 'bboxes': [],
 'masks': []}

In [17]:
# Shallow copies of the samples that keep only the first four boxes each;
# new_train_data itself is left untouched.
new_train_data_top4 = [
    {**sample, "bboxes": sample["bboxes"][:4]}
    for sample in new_train_data
]

In [18]:
# Check one truncated sample: 'bboxes' should hold at most four entries.
new_train_data_top4[3]

{'image': 'train_0422.jpg',
 'id': 100422003,
 'conversations': [{'from': 'human',
   'value': '<image>\nAre bile duct cells and canals of Hering stained here with an immunohistochemical stain for cytokeratin 7?'},
  {'from': 'gpt', 'value': 'yes'}],
 'bboxes': [[225, 102, 21, 24],
  [188, 88, 14, 28],
  [71, 138, 20, 23],
  [178, 89, 45, 36]],
 'masks': []}

In [19]:
# segments_dict['0']
# Number of question ids whose segment list is empty.
ed = sum(1 for segments in segments_dict.values() if len(segments) == 0)
ed

19735

In [20]:
# Save the masks to a compressed .npz archive: every entry of segments_dict is
# passed as a keyword argument, so each str(qid) becomes one array name in the file.
# Values are the per-question lists of segmentation arrays (or empty lists).
# NOTE(review): this assumes all masks of one question share the same shape so the
# list converts to a regular array — confirm.

np.savez_compressed("/data/aofei/hallucination/PathVQA/pvqa/training_segments_top8.npz", **segments_dict)

In [21]:
# Sanity check: one converted sample per input record.
len(new_train_data)

19755

In [22]:
# loaded_segments = np.load("/data/aofei/hallucination/Slake/data/training_segments.npz", allow_pickle=True)  # Allow pickle to handle lists

# # Access a list of segment arrays by its ID, e.g., "id_1"
# segments_list_id_1 = loaded_segments["0"]

# # Each item in segments_list_id_1 is a 256x256 numpy array
# for segment in segments_list_id_1:
#     print(segment.shape)

In [23]:
# Write the top-4 samples next to the PathVQA segments archive. The previous
# target lived under the VQA_RAD directory (left over from the VQA-RAD version
# of this notebook), so PathVQA samples would have overwritten that dataset's
# training_masks_top4.json.
# NOTE(review): update any downstream config that reads the old location.
with open('/data/aofei/hallucination/PathVQA/pvqa/training_masks_top4.json', 'w') as json_file:
    json.dump(new_train_data_top4, json_file, indent=4)

In [24]:
# with open('/data/aofei/hallucination/VQA_RAD/data/training_masks_top8.json', 'w') as json_file:
#     json.dump(new_train_data, json_file, indent=4)