In [1]:
import cv2
import random
import json
import time

import torch
import numpy as np
from tqdm import tqdm

import clip
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from segment_anything.utils.transforms import ResizeLongestSide 

from PIL import Image  
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image

from sam_caf import hyper_params_tuning, get_crops, retrieve_relevant_crop, retrieve_relevant_crop_biomed, get_sam_prompts, sam_predicton, retrieve_relevant_crop_biomed_topk

In [2]:
#config
class DictToObject:
    """Expose the entries of a mapping as instance attributes.

    Every key of ``dict_obj`` becomes an attribute of the new object, so
    ``DictToObject({"a": 1}).a == 1``. This lets the config dict below be
    passed to helpers that read settings as ``config.<field>``.
    """

    def __init__(self, dict_obj):
        for attr_name in dict_obj:
            setattr(self, attr_name, dict_obj[attr_name])

# NOTE(review): the boolean-looking flags below are *strings* ("True"/"False"),
# so they are truthy under a bare `if`. Presumably the sam_caf helpers compare
# them against the literal strings -- confirm before changing their type.
config_dict = dict(
    model_name="SAM",
    model_type="vit_h",
    source="False",
    refine="False",
    pre_trained="True",
    sam_ckpt="/data/aofei/LLM/SAM/sam_vit_h_4b8939.pth",
    clip_prompts="./clip_prompts/abd_seg.json",
)

config = DictToObject(config_dict)

# prompt_mode is forwarded to get_crops; mode is only referenced in a commented-out branch.
prompt_mode = "crops"
mode = "sam_clip"

def preprocess_image(image_path, size=(256, 256), to_rgb=False):
    """Load an image from disk and resize it for SAM mask generation.

    Args:
        image_path: Path of the image file, read with ``cv2.imread``.
        size: Target ``(width, height)`` passed to ``cv2.resize``. The default
            ``(256, 256)`` matches the original hard-coded size, which the
            border/size thresholds in ``filter_sam_results`` appear to assume.
        to_rgb: If True, convert OpenCV's BGR channel order to RGB. Defaults to
            False so the original behaviour (BGR array) is reproduced.

    Returns:
        The resized image as a NumPy array (BGR unless ``to_rgb`` is True).

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path)
    # cv2.imread reports failure by returning None rather than raising; fail
    # here with a clear message instead of an opaque error from cv2.resize.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    # NOTE(review): SAM's automatic-mask-generator example feeds RGB images,
    # while this pipeline has always passed BGR. The default is kept for
    # reproducibility -- confirm whether get_crops / the CLIP preprocessing
    # expect RGB before switching to_rgb on.
    if to_rgb:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return cv2.resize(image, size)

In [3]:
import os
# os.environ["TRANSFORMERS_CACHE"]="/data/aofei/huggingface_cache/transformers"
# Keep HF_HOME assignment ABOVE the open_clip import: presumably the Hugging Face
# hub cache location is read when the library is imported -- do not reorder.
os.environ["HF_HOME"]="/data/aofei/huggingface_cache/transformers"
from open_clip import create_model_from_pretrained, get_tokenizer # works on open-clip-torch>=2.23.0, timm>=0.9.8

# BiomedCLIP model, preprocessing transform and tokenizer; these are what
# generate_segments passes to retrieve_relevant_crop_biomed_topk to rank SAM crops.
biomed_clip_model, biomed_preprocess = create_model_from_pretrained('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224', device="cuda")
tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')

# NOTE(review): clip_model / preprocess (OpenAI CLIP ViT-L/14) are not used by any
# cell in this notebook; loading them only costs GPU memory -- confirm and drop.
clip_model, preprocess = clip.load("ViT-L/14", device="cuda")
sam_checkpoint = config.sam_ckpt

# SAM backbone selected by config.model_type ("vit_h"), moved to the GPU.
sam = sam_model_registry[config.model_type](checkpoint=sam_checkpoint)
sam.to("cuda")
# NOTE(review): resize_transform is not used in the cells shown.
resize_transform = ResizeLongestSide(sam.image_encoder.img_size)

# NOTE(review): dice_scores is never appended to in the cells shown.
dice_scores = []
# mask_generator and the area threshold are globals consumed by sam_generation.
mask_generator, area = hyper_params_tuning(sam)

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [4]:
def sam_generation(image_path):
    """Run SAM automatic mask generation on one image and crop each kept mask.

    Relies on the module-level ``mask_generator`` and ``area`` (from
    ``hyper_params_tuning``) and on the global ``prompt_mode``.

    Returns:
        ``(masks, img_crops)``: the SAM mask records whose ``"area"`` is strictly
        below ``area``, and whatever ``get_crops`` returns for them.
    """
    resized = preprocess_image(image_path=image_path)
    with torch.no_grad():
        raw_masks = mask_generator.generate(resized)
        # Drop masks whose area reaches the threshold chosen by hyper_params_tuning.
        kept_masks = list(filter(lambda record: record["area"] < area, raw_masks))
        crops = get_crops(resized, kept_masks, prompt_mode)
    return kept_masks, crops

def filter_sam_results(masks, img_crops, min_side=12, max_coord=253):
    """Drop SAM masks that touch the image border or are too small.

    ``masks`` and ``img_crops`` are treated as parallel sequences (crop ``i``
    belongs to mask ``i``); the surviving positions are kept in both outputs.
    Each mask's ``"bbox"`` is read as ``[x, y, w, h]``, so the right edge is
    ``x + w`` and the bottom edge is ``y + h``.

    A mask is discarded when any of these holds:
      * its box starts at column 0 or row 0 (left/top border);
      * its width or height is ``<= min_side``;
      * its right or bottom edge is ``> max_coord`` (right/bottom border).

    The defaults reproduce the original hard-coded thresholds, which appear to
    assume the 256x256 images produced by ``preprocess_image``; pass other
    values when working at a different resolution.

    Args:
        masks: SAM mask records, each with a ``"bbox"`` entry.
        img_crops: Crops aligned index-for-index with ``masks``.
        min_side: Boxes whose width or height is at most this are dropped.
        max_coord: Boxes whose ``x + w`` or ``y + h`` exceeds this are dropped.

    Returns:
        ``(new_masks, new_img_crops)``: the surviving masks and their crops.

    Raises:
        IndexError: If a surviving mask's index is past the end of ``img_crops``.
    """
    new_masks, new_img_crops = [], []
    for idx, mask in enumerate(masks):
        bbox = mask['bbox']
        x, y, w, h = bbox[0], bbox[1], bbox[2], bbox[3]
        if x == 0 or y == 0:
            continue  # touches the left/top border
        if w <= min_side or h <= min_side:
            continue  # too thin or too small to be a useful region
        if y + h > max_coord or x + w > max_coord:
            continue  # touches the right/bottom border
        new_masks.append(mask)
        # NOTE(review): assumes get_crops returns exactly one crop per mask. If it
        # ever skips a mask, crops and masks silently become misaligned -- confirm.
        new_img_crops.append(img_crops[idx])
    return new_masks, new_img_crops

def get_topk_similar(k, crop_scores):
    """Return the ``k`` highest-scoring crops as ``(index, score)`` pairs.

    Pairs are ordered by descending score. Ties keep their original index order
    because Python's sort is stable even with ``reverse=True``.
    """
    ranked = sorted(enumerate(crop_scores), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

def get_compelete_contour(masks):
    """Return the largest bounding-box width among ``masks``.

    Each mask's ``"bbox"`` is read as ``[x, y, w, h]``; only ``w`` is compared.

    Args:
        masks: Non-empty sequence of SAM mask records with a ``"bbox"`` entry.

    Returns:
        The maximum ``bbox[2]`` value.

    Raises:
        ValueError: If ``masks`` is empty.
    """
    # TODO (carried over): "need to consider the chest xray" -- heights are ignored.
    widths = [mask['bbox'][2] for mask in masks]
    if not widths:
        raise ValueError("get_compelete_contour() requires at least one mask")
    # A single max() pass replaces the previous sort-then-take-first.
    return max(widths)

def judge_inner_boxes(bboxes):
    """Placeholder for deciding which boxes lie inside larger ones.

    Currently a no-op: it only iterates over ``bboxes`` and returns ``None``.
    TODO: implement the intended nesting check (e.g. favour smaller boxes that
    sit inside a bigger one).
    """
    for _bbox in bboxes:
        pass

In [5]:
# Per-image caches keyed by image path, so repeated queries on the same image
# reuse earlier SAM results instead of re-running mask generation.
masks_image_dict, masks_all_image_dict, crops_image_dict = {}, {}, {}
def generate_segments(query, image_path, topk=4):
    """Return the SAM segments of an image that BiomedCLIP ranks as most relevant to ``query``.

    SAM masks (after ``filter_sam_results``) and their crops are cached per
    ``image_path`` in ``masks_image_dict`` / ``crops_image_dict``. The SAM pass
    therefore runs once per image even when several questions share that image.

    Args:
        query: Free-text question used as the retrieval prompt.
        image_path: Path of the image to segment.
        topk: Number of crops requested from ``retrieve_relevant_crop_biomed_topk``
            (default 4, the previously hard-coded value).

    Returns:
        ``(bboxes, segs, indices)``: the ``"bbox"`` and ``"segmentation"`` of each
        selected mask, plus the retriever's selected indices into the cached mask
        list. All three are empty lists when the retriever returns ``None``.
    """
    if image_path in masks_image_dict:
        masks = masks_image_dict[image_path]
        img_crops = crops_image_dict[image_path]
    else:
        masks, img_crops = sam_generation(image_path=image_path)
        masks, img_crops = filter_sam_results(masks, img_crops)
        masks_image_dict[image_path] = masks
        crops_image_dict[image_path] = img_crops

    prompts = {"query": [query]}
    max_indices, _scores = retrieve_relevant_crop_biomed_topk(
        img_crops, prompts, biomed_clip_model, biomed_preprocess, config,
        tokenizer=tokenizer, topk=topk,
    )
    # TODO: planned selection rules on top of the top-k retrieval --
    #   * if the query names no explicit organ, fall back to the whole segmentation;
    #   * if smaller boxes lie inside a bigger one, keep all of them but weight
    #     the smaller inner boxes higher.
    if max_indices is None:
        return [], [], []

    selected = max_indices["query"]
    bboxes, segs = [], []
    for idx in selected:
        bboxes.append(masks[idx]["bbox"])
        segs.append(masks[idx]["segmentation"])
    return bboxes, segs, selected


def generate_all_segments(image_path):
    """Return every SAM mask kept by ``filter_sam_results`` for an image (no CLIP ranking).

    Results are cached per ``image_path`` in ``masks_all_image_dict``.
    NOTE(review): this cache is separate from ``masks_image_dict``, so an image
    already processed by ``generate_segments`` is segmented again here.
    """
    if image_path not in masks_all_image_dict:
        masks, crops = sam_generation(image_path=image_path)
        masks, _crops = filter_sam_results(masks, crops)
        masks_all_image_dict[image_path] = masks
    return masks_all_image_dict[image_path]

In [6]:
import json

TRAIN_JSON_PATH = "/data/aofei/hallucination/CARES/OmniMedVQA/training.json"

with open(TRAIN_JSON_PATH, "r", encoding="utf-8") as f:
    all_train_data = json.load(f)

# Show the record count (a mid-cell `len(...)` is evaluated but never displayed).
len(all_train_data)

In [7]:
# Inspect one raw record: image path, id, and a human/gpt conversation pair.
all_train_data[0]

{'image': 'Images/Adam Challenge/AMD/A0017.jpg',
 'id': 'Adam Challenge_0000_train0',
 'conversations': [{'from': 'human',
   'value': '<image>\nWhat imaging technique is employed to acquire this fundus image? options: X-ray imaging, Fundus photography, Ultrasound imaging, Magnetic resonance imaging (MRI)'},
  {'from': 'gpt', 'value': 'Fundus photography'}]}

In [8]:
# NOTE(review): plain alias of the same list object -- no filtering is applied here.
all_train_data_en = all_train_data

In [9]:
import copy
# Deep copy because the segmentation loop adds "bbox"/"mask"/"bbox_indices" keys
# to each record; the records in all_train_data stay untouched.
train_data_seg = copy.deepcopy(all_train_data_en)
len(train_data_seg)

6155

In [10]:
# Number of distinct images (several questions can share one image).
len({record['image'] for record in train_data_seg})

5731

In [None]:
# len(set([i['img_name'] for i in train_data]))

450

In [16]:
# Segment every training image and attach the top-k BiomedCLIP-selected SAM
# boxes/masks to its record. Records whose segmentation raises are left without
# "bbox"/"mask"/"bbox_indices" keys (they are counted in a later cell).
# failure case: 432-3 + 100   (original note; meaning unclear -- TODO clarify)
IMAGE_ROOT = '/data/aofei/hallucination/OmniMedVQA/VQA/raw/OmniMedVQA'

segmentation_failures = []  # (index, image_path, repr(error)) for records that raised
for idx, record in enumerate(tqdm(train_data_seg)):
    image_path = os.path.join(IMAGE_ROOT, record["image"])
    query = record["conversations"][0]["value"].replace("<image>\n", "")
    try:
        bbox, segs, max_indices = generate_segments(query, image_path)
    except Exception as exc:
        # Catch Exception rather than using a bare except, so Ctrl-C (KeyboardInterrupt)
        # still stops this multi-hour loop, and keep the reason each sample was skipped.
        segmentation_failures.append((idx, image_path, repr(exc)))
        continue
    record['bbox'] = bbox
    record['mask'] = segs
    record["bbox_indices"] = max_indices
    

  3%|▎         | 177/6155 [06:40<3:22:00,  2.03s/it]

Skipping zero-sized bounding box.


  5%|▍         | 291/6155 [11:21<3:02:56,  1.87s/it]

Skipping zero-sized bounding box.


  5%|▌         | 312/6155 [12:10<4:15:39,  2.63s/it]

Skipping zero-sized bounding box.


  5%|▌         | 313/6155 [12:13<4:19:01,  2.66s/it]

Skipping zero-sized bounding box.


  6%|▌         | 347/6155 [13:15<2:41:33,  1.67s/it]

Skipping zero-sized bounding box.


  7%|▋         | 414/6155 [15:50<4:18:48,  2.70s/it]

Skipping zero-sized bounding box.


  7%|▋         | 442/6155 [17:07<4:22:47,  2.76s/it]

Skipping zero-sized bounding box.


  8%|▊         | 472/6155 [18:29<4:21:11,  2.76s/it]

Skipping zero-sized bounding box.


  8%|▊         | 482/6155 [18:56<4:20:30,  2.76s/it]

Skipping zero-sized bounding box.
Skipping zero-sized bounding box.


  8%|▊         | 487/6155 [19:10<4:17:08,  2.72s/it]

Skipping zero-sized bounding box.


  9%|▊         | 529/6155 [21:03<4:17:30,  2.75s/it]

Skipping zero-sized bounding box.


  9%|▊         | 533/6155 [21:14<4:15:11,  2.72s/it]

Skipping zero-sized bounding box.


 12%|█▏        | 712/6155 [28:12<3:07:57,  2.07s/it]

Skipping zero-sized bounding box.


 12%|█▏        | 718/6155 [28:28<4:00:01,  2.65s/it]

Skipping zero-sized bounding box.


 12%|█▏        | 750/6155 [29:42<2:49:25,  1.88s/it]

Skipping zero-sized bounding box.


 13%|█▎        | 776/6155 [30:34<3:24:38,  2.28s/it]

Skipping zero-sized bounding box.


 13%|█▎        | 791/6155 [31:07<3:24:37,  2.29s/it]

Skipping zero-sized bounding box.


 16%|█▌        | 981/6155 [39:02<3:42:35,  2.58s/it]

Skipping zero-sized bounding box.


 20%|██        | 1244/6155 [49:40<3:47:17,  2.78s/it]

Skipping zero-sized bounding box.


 21%|██        | 1263/6155 [50:33<3:45:49,  2.77s/it]

Skipping zero-sized bounding box.


 21%|██        | 1302/6155 [52:21<3:42:49,  2.75s/it]

Skipping zero-sized bounding box.


 25%|██▍       | 1527/6155 [1:02:43<3:30:54,  2.73s/it]

Skipping zero-sized bounding box.
Skipping zero-sized bounding box.


 26%|██▌       | 1604/6155 [1:06:16<3:28:01,  2.74s/it]

Skipping zero-sized bounding box.


 26%|██▋       | 1625/6155 [1:07:14<3:30:52,  2.79s/it]

Skipping zero-sized bounding box.


 27%|██▋       | 1692/6155 [1:10:21<3:29:17,  2.81s/it]

Skipping zero-sized bounding box.


 30%|██▉       | 1838/6155 [1:17:11<3:21:26,  2.80s/it]

Skipping zero-sized bounding box.


 31%|███       | 1887/6155 [1:19:28<3:20:33,  2.82s/it]

Skipping zero-sized bounding box.


 32%|███▏      | 1965/6155 [1:23:04<3:15:48,  2.80s/it]

Skipping zero-sized bounding box.


 35%|███▍      | 2133/6155 [1:30:49<3:08:11,  2.81s/it]

Skipping zero-sized bounding box.
Skipping zero-sized bounding box.


 37%|███▋      | 2258/6155 [1:36:36<3:00:43,  2.78s/it]

Skipping zero-sized bounding box.


 37%|███▋      | 2294/6155 [1:38:16<2:59:01,  2.78s/it]

Skipping zero-sized bounding box.


 39%|███▊      | 2377/6155 [1:42:07<2:54:29,  2.77s/it]

Skipping zero-sized bounding box.


 40%|████      | 2474/6155 [1:46:37<2:47:16,  2.73s/it]

Skipping zero-sized bounding box.


 42%|████▏     | 2592/6155 [1:52:03<2:43:41,  2.76s/it]

Skipping zero-sized bounding box.


 44%|████▍     | 2694/6155 [1:56:47<2:36:35,  2.71s/it]

Skipping zero-sized bounding box.


 45%|████▌     | 2778/6155 [2:00:40<2:40:32,  2.85s/it]

Skipping zero-sized bounding box.


 48%|████▊     | 2973/6155 [2:09:43<2:27:06,  2.77s/it]

Skipping zero-sized bounding box.


 49%|████▉     | 3037/6155 [2:12:41<2:24:59,  2.79s/it]

Skipping zero-sized bounding box.


 50%|████▉     | 3047/6155 [2:13:09<2:24:01,  2.78s/it]

Skipping zero-sized bounding box.


 50%|█████     | 3108/6155 [2:15:59<2:23:16,  2.82s/it]

Skipping zero-sized bounding box.


 52%|█████▏    | 3174/6155 [2:19:02<2:19:00,  2.80s/it]

Skipping zero-sized bounding box.
Skipping zero-sized bounding box.


 60%|█████▉    | 3679/6155 [2:42:20<1:53:45,  2.76s/it]

Skipping zero-sized bounding box.


 60%|██████    | 3708/6155 [2:43:41<1:54:42,  2.81s/it]

Skipping zero-sized bounding box.


 61%|██████    | 3758/6155 [2:46:00<1:50:05,  2.76s/it]

Skipping zero-sized bounding box.


 61%|██████    | 3763/6155 [2:46:14<1:51:30,  2.80s/it]

Skipping zero-sized bounding box.


 62%|██████▏   | 3840/6155 [2:49:48<1:46:32,  2.76s/it]

Skipping zero-sized bounding box.


 65%|██████▍   | 3970/6155 [2:55:46<1:41:30,  2.79s/it]

Skipping zero-sized bounding box.


 67%|██████▋   | 4127/6155 [3:02:55<1:34:23,  2.79s/it]

Skipping zero-sized bounding box.


 69%|██████▉   | 4236/6155 [3:07:57<1:27:51,  2.75s/it]

Skipping zero-sized bounding box.


 71%|███████   | 4353/6155 [3:13:19<1:23:31,  2.78s/it]

Skipping zero-sized bounding box.


 72%|███████▏  | 4458/6155 [3:18:10<1:18:46,  2.79s/it]

Skipping zero-sized bounding box.


 74%|███████▍  | 4570/6155 [3:23:22<1:13:56,  2.80s/it]

Skipping zero-sized bounding box.


 74%|███████▍  | 4584/6155 [3:24:01<1:13:02,  2.79s/it]

Skipping zero-sized bounding box.


 98%|█████████▊| 6053/6155 [4:20:13<04:33,  2.68s/it]  

Skipping zero-sized bounding box.


 99%|█████████▊| 6066/6155 [4:20:45<03:56,  2.65s/it]

Skipping zero-sized bounding box.
Skipping zero-sized bounding box.
Skipping zero-sized bounding box.


100%|██████████| 6155/6155 [4:21:59<00:00,  2.55s/it]


In [22]:
# Sanity check: pixel count of the second selected mask of the first record
# (masks are boolean arrays; astype(int) turns the sum into an integer count).
np.sum(train_data_seg[0]['mask'][1].astype(int))

211

In [17]:
# Count records that never received segmentation results (generate_segments raised).
s = sum(1 for record in train_data_seg if "mask" not in record)
s

10

In [18]:
# Inspect a segmented record: it now carries "bbox", "mask" and "bbox_indices".
train_data_seg[3]

{'image': 'Images/Adam Challenge/Non-AMD/N0051.jpg',
 'id': 'Adam Challenge_0003_train3',
 'conversations': [{'from': 'human',
   'value': '<image>\nWhat method is employed to obtain this fundus image? options: Ultrasound imaging, Magnetic resonance imaging (MRI), Spirometry, Fundus photography'},
  {'from': 'gpt', 'value': 'Fundus photography'}],
 'bbox': [[95, 150, 27, 26],
  [25, 87, 35, 40],
  [41, 94, 19, 31],
  [196, 180, 47, 45]],
 'mask': [array([[False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         ...,
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False]]),
  array([[False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         ...

In [19]:
# training2: build JSON-serialisable records (boxes only). The boolean masks are
# kept separately in segments_dict, keyed by str(record id), for the .npz file.
new_train_data = []
segments_dict = dict()
for record in train_data_seg:
    segments_dict[str(record['id'])] = record.get("mask", [])
    new_train_data.append({
        'image': record['image'],
        'id': record['id'],
        # Shares the conversation list with train_data_seg (no copy), as before.
        'conversations': record['conversations'],
        'bboxes': record.get("bbox", []),
        # Masks go to the .npz file, not the JSON, so this field stays empty.
        'masks': [],
    })

new_train_data[-6]

{'image': 'Images/MIAS/mdb188.png',
 'id': 'MIAS_0134_train6149',
 'conversations': [{'from': 'human',
   'value': '<image>\nWhat type of abnormality is depicted in this image? options: Spiculated masses, Metastatic tumors, Scar tissue formations, Cystic lesions'},
  {'from': 'gpt', 'value': 'Spiculated masses'}],
 'bboxes': [[65, 39, 102, 183]],
 'masks': []}

In [20]:
# Keep at most four boxes per record (shallow copies; the other fields are shared).
new_train_data_top4 = [
    {**record, "bboxes": record["bboxes"][:4]} for record in new_train_data
]

In [22]:
# Spot-check one truncated record.
new_train_data_top4[100]

{'image': 'Images/DeepDRiD/regular-fundus-test/463/463_r1.jpg',
 'id': 'DeepDRiD_0054_train100',
 'conversations': [{'from': 'human',
   'value': '<image>\nWhat type of imaging technique was used to capture this image? options: fundus photography., Electrocardiogram (ECG)., Endoscopy., MRI.'},
  {'from': 'gpt', 'value': 'fundus photography.'}],
 'bboxes': [[167, 43, 46, 42]],
 'masks': []}

In [23]:
# Records with no saved masks (segmentation failed, or nothing was selected).
ed = sum(1 for segments in segments_dict.values() if len(segments) == 0)
ed

411

In [24]:
# Save the per-record boolean masks to a compressed .npz, one entry per record
# keyed by str(record id), matching the "id" field of the JSON written later.
# NOTE(review): each list of masks is converted to a single array by NumPy;
# records with no masks are presumably stored as an empty array -- confirm loaders handle that.

np.savez_compressed("/data/aofei/hallucination/CARES/OmniMedVQA/training_segments_top4.npz", **segments_dict)

In [25]:
# Record count after conversion (one entry per training question).
len(new_train_data)

6155

In [26]:
# Write the box-annotated records; the masks themselves live in the .npz file, keyed by id.
with open('/data/aofei/hallucination/CARES/OmniMedVQA/training_masks_top4.json', 'w') as json_file:
    json.dump(new_train_data_top4, json_file, indent=4)