In [1]:
import cv2
import random
import json
import time

import torch
import numpy as np
from tqdm import tqdm

import clip
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from segment_anything.utils.transforms import ResizeLongestSide 

from PIL import Image  
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image

from sam_caf import hyper_params_tuning, get_crops, retrieve_relevant_crop, retrieve_relevant_crop_biomed, get_sam_prompts, sam_predicton, retrieve_relevant_crop_biomed_topk

In [2]:
#config
class DictToObject:
    """Expose the entries of a mapping as attributes of a plain object.

    Every ``key: value`` pair of ``dict_obj`` becomes an instance attribute
    named ``key``. Values are stored as-is (no copy, no recursion into
    nested dictionaries).
    """

    def __init__(self, dict_obj):
        for entry in dict_obj.items():
            # entry is a (name, value) pair; unpack it straight into setattr.
            setattr(self, *entry)

# Run settings. `model_type` selects the entry of sam_model_registry, `sam_ckpt` is the
# SAM checkpoint path, and the whole object is also handed to the sam_caf retrieval helper.
# NOTE(review): "source", "refine" and "pre_trained" hold the *strings* "False"/"True",
# not booleans — a plain truthiness test on "False" evaluates to True. How sam_caf reads
# them is not visible here; confirm before relying on them.
config_dict = {
    "model_name" : "SAM",
    "model_type" : "vit_h",
    "source":    "False", 
    "refine" : "False",
    "pre_trained": "True", 
    "sam_ckpt":  "/data/aofei/LLM/SAM/sam_vit_h_4b8939.pth", 
    "clip_prompts": "./clip_prompts/abd_seg.json"
}

# Attribute-style view of the settings (config.model_type, config.sam_ckpt, ...).
config = DictToObject(config_dict)

# prompt_mode is forwarded to get_crops; `mode` is only mentioned in a commented-out
# line of the visible cells.
prompt_mode, mode = "crops", "sam_clip"

def preprocess_image(image_path, size=(256, 256)):
    """Load an image from disk and resize it to a fixed resolution.

    Args:
        image_path: Path of the image file to read with ``cv2.imread``.
        size: Target ``(width, height)`` handed to ``cv2.resize``. Defaults to
            ``(256, 256)``, the resolution the rest of this notebook works at.

    Returns:
        The resized image array produced by ``cv2.resize``.

    Raises:
        FileNotFoundError: If OpenCV could not read ``image_path`` (missing or
            undecodable file).
    """
    image = cv2.imread(image_path)
    # cv2.imread reports failure by returning None rather than raising; without this
    # check the problem only surfaces as an opaque assertion error inside cv2.resize.
    if image is None:
        raise FileNotFoundError(f"Could not read image (missing or unreadable): {image_path}")
    return cv2.resize(image, size)

In [3]:
import os
# os.environ["TRANSFORMERS_CACHE"]="/data/aofei/huggingface_cache/transformers"
# HF_HOME is set before open_clip is imported — presumably so the hub download of the
# BiomedCLIP weights lands in this cache directory (TODO confirm).
os.environ["HF_HOME"]="/data/aofei/huggingface_cache/transformers"
from open_clip import create_model_from_pretrained, get_tokenizer # works on open-clip-torch>=2.23.0, timm>=0.9.8

# BiomedCLIP model, its image preprocessing and its tokenizer; these three are what
# generate_segments passes to the crop-retrieval helper.
biomed_clip_model, biomed_preprocess = create_model_from_pretrained('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224', device="cuda")
tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')

# NOTE(review): clip_model / preprocess are not referenced by any visible cell; loading
# ViT-L/14 onto the GPU may be unnecessary here.
clip_model, preprocess = clip.load("ViT-L/14", device="cuda")
sam_checkpoint = config.sam_ckpt

# Build the SAM variant named by config.model_type from the checkpoint and move it to the GPU.
sam = sam_model_registry[config.model_type](checkpoint=sam_checkpoint)
sam.to("cuda")
# NOTE(review): resize_transform and dice_scores are not used in the visible cells.
resize_transform = ResizeLongestSide(sam.image_encoder.img_size)

dice_scores = []
# mask_generator produces the masks in sam_generation; `area` is the upper bound used
# there to discard large masks.
mask_generator, area = hyper_params_tuning(sam)

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [4]:
def sam_generation(image_path):
    """Segment one image with SAM and cut out a crop for every retained mask.

    Relies on the module-level ``mask_generator``, ``area`` and ``prompt_mode``.

    Args:
        image_path: Path of the image, loaded through ``preprocess_image``.

    Returns:
        ``(masks, img_crops)``: the generated masks whose ``"area"`` is strictly
        smaller than ``area``, and whatever ``get_crops`` returns for them.
    """
    resized = preprocess_image(image_path=image_path)
    with torch.no_grad():
        candidates = mask_generator.generate(resized)
        # Area filtering based on the area value from hyper-params tuning.
        masks = list(filter(lambda candidate: candidate["area"] < area, candidates))
        img_crops = get_crops(resized, masks, prompt_mode)

    return masks, img_crops

def filter_sam_results(masks, img_crops, min_side=12, max_extent=253):
    """Drop SAM masks whose bounding box touches the image border or is too small.

    Each mask's ``'bbox'`` is read as ``[x, y, width, height]``. A mask is kept only if
    all of the following hold:

    * ``x != 0`` and ``y != 0`` (box does not start on the left/top edge),
    * ``width > min_side`` and ``height > min_side``,
    * ``x + width <= max_extent`` and ``y + height <= max_extent``.

    Args:
        masks: Sequence of mask dicts with a ``'bbox'`` entry.
        img_crops: Sequence indexed in parallel with ``masks``.
        min_side: Boxes with a width or height less than or equal to this are dropped.
            Defaults to 12.
        max_extent: Largest allowed right/bottom coordinate. The default of 253 keeps
            boxes away from the far edge of the 256x256 images used in this notebook.

    Returns:
        ``(kept_masks, kept_crops)`` — two new lists, in the original order.

    Raises:
        IndexError: If a kept mask has no crop at the same index in ``img_crops``.
    """
    kept_masks, kept_crops = [], []
    for index, mask in enumerate(masks):
        x, y, width, height = mask['bbox']
        if x == 0 or y == 0:
            continue
        if width <= min_side or height <= min_side:
            continue
        if x + width > max_extent or y + height > max_extent:
            continue
        kept_masks.append(mask)
        # NOTE(review): assumes img_crops[index] belongs to masks[index]. If the crop
        # helper ever skips a mask (the run log prints "Skipping zero-sized bounding
        # box."), the two lists would drift out of alignment — verify against get_crops.
        kept_crops.append(img_crops[index])
    return kept_masks, kept_crops

def get_topk_similar(k, crop_scores):
    """Return the ``k`` highest-scoring entries as ``(index, score)`` pairs, best first.

    Entries with equal scores keep their original relative order. Fewer than ``k``
    pairs are returned when ``crop_scores`` is shorter than ``k``.
    """
    ranked = list(enumerate(crop_scores))
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

def get_compelete_contour(masks):
    """Return the largest bounding-box width found among ``masks``.

    Each mask's ``'bbox'`` is read with the width at index 2 and the height at
    index 3; only the width takes part in the result. An empty ``masks`` raises
    ``IndexError``.
    """
    # need to consider the chest xray
    widths = []
    for mask in masks:
        box = mask['bbox']
        box_width, _box_height = box[2], box[3]
        widths.append(box_width)
    widths.sort(reverse=True)
    return widths[0]

def judge_inner_boxes(bboxes):
    """Unimplemented placeholder: walks ``bboxes`` without doing anything, returns ``None``.

    NOTE(review): presumably meant to detect boxes nested inside larger ones — nothing
    calls it in the visible cells; implement or remove.
    """
    for bbox in bboxes:
        # Bare expression statement: evaluates the name and discards it (no effect).
        bbox

In [5]:
# Per-image caches keyed by image path, so SAM is not re-run for every question that
# shares an image.
# Filtered masks, filled by generate_segments.
masks_image_dict = dict()
# Filtered masks, filled by generate_all_segments (kept separate from the cache above).
masks_all_image_dict = dict()
# Crops matching masks_image_dict, filled by generate_segments.
crops_image_dict = dict()
def generate_segments(query, image_path):
    """Select the SAM segments of an image that best match a text query.

    Masks and crops are computed once per ``image_path`` and memoised in the
    module-level ``masks_image_dict`` / ``crops_image_dict``.

    Args:
        query: Text passed to ``retrieve_relevant_crop_biomed_topk`` under the
            ``"query"`` prompt key.
        image_path: Path of the image to segment.

    Returns:
        ``(bboxes, segs, indices)``: the ``"bbox"`` and ``"segmentation"`` of every
        selected mask, plus the selected indices into the cached mask list.
    """
    if image_path not in masks_image_dict:
        fresh_masks, fresh_crops = sam_generation(image_path=image_path)
        fresh_masks, fresh_crops = filter_sam_results(fresh_masks, fresh_crops)
        masks_image_dict[image_path] = fresh_masks
        crops_image_dict[image_path] = fresh_crops
    masks = masks_image_dict[image_path]
    img_crops = crops_image_dict[image_path]

    max_indices, _scores = retrieve_relevant_crop_biomed_topk(
        img_crops,
        {"query": [query]},
        biomed_clip_model,
        biomed_preprocess,
        config,
        tokenizer=tokenizer,
        topk=8,
    )
    # Selection rules still to be defined (currently the retrieved top-k is returned as-is):
    # - if there is no explicit organ to be used as query, use the whole segmentation
    # - if smaller boxes lie inside a bigger box, use all of them but assign higher
    #   weights to the smaller inner boxes
    selected = max_indices["query"]
    bboxes, segs = [], []
    for mask_index in selected:
        chosen = masks[mask_index]
        bboxes.append(chosen["bbox"])
        segs.append(chosen["segmentation"])
    return bboxes, segs, selected


def generate_all_segments(image_path):
    """Return every filtered SAM mask of an image, computing it at most once.

    Results are memoised per ``image_path`` in the module-level
    ``masks_all_image_dict``; the crops produced alongside the masks are discarded.
    """
    try:
        return masks_all_image_dict[image_path]
    except KeyError:
        pass  # not cached yet — compute below

    masks, _crops = filter_sam_results(*sam_generation(image_path=image_path))
    masks_all_image_dict[image_path] = masks
    return masks

In [None]:
# Preprocessing for Slake
import json
# with open(r"/data/aofei/hallucination/Slake/data/training_contours.json", "r") as f:
#     train_data = json.load(f)
# len(train_data)

# Raw training annotations (filtered by 'q_lang' in a later cell).
with open(r"/data/aofei/hallucination/Slake/train.json", "r") as f:
    all_train_data = json.load(f)
len(all_train_data)

# NOTE(review): Slake/data/test.json is also the file a later cell of this notebook
# writes, so on a clean checkout this read depends on a previous run.
with open(r"/data/aofei/hallucination/Slake/data/test.json", "r") as f:
    test_data = json.load(f)
len(test_data)

# Raw test annotations.
with open(r"/data/aofei/hallucination/Slake/test.json", "r") as f:
    all_test_data = json.load(f)
# Only this last expression is shown as the cell output; the earlier bare len(...) calls
# are evaluated and discarded.
len(all_test_data)

2094

In [8]:
# Inspect one raw record to see which fields are available.
all_train_data[0]

{'img_id': 1,
 'img_name': 'xmlab1/source.jpg',
 'question': 'What modality is used to take this image?',
 'answer': 'MRI',
 'q_lang': 'en',
 'location': 'Abdomen',
 'modality': 'MRI',
 'answer_type': 'OPEN',
 'base_type': 'vqa',
 'content_type': 'Modality',
 'triple': ['vhead', '_', '_'],
 'qid': 0}

In [9]:
# Keep only the records whose 'q_lang' is "en".
all_train_data_en = list(filter(lambda record: record['q_lang'] == "en", all_train_data))
len(all_train_data_en)

4919

In [10]:
import copy
# Work on a deep copy: the segmentation loop adds 'bbox' / 'mask' / 'bbox_indices' to
# these records in place, and all_train_data_en should stay untouched.
train_data_seg = copy.deepcopy(all_train_data_en)
len(train_data_seg)

4919

In [16]:
# Number of distinct images referenced by the records in train_data_seg.
len(set([i['img_name'] for i in train_data_seg]))

450

In [None]:
# len(set([i['img_name'] for i in train_data]))

450

In [14]:
# Attach the top-8 query-relevant SAM segments to every record of train_data_seg.
#failure case: 432-3 + 100
from tqdm import tqdm

# (qid, error) for every record whose segmentation raised. The original bare
# `except: continue` hid these failures (and also swallowed KeyboardInterrupt).
segmentation_failures = []
for data in tqdm(train_data_seg):
    image_path = os.path.join("/data/aofei/hallucination/Slake/imgs", data["img_name"])
    location = data['location']
    if "chest" in location.lower():
        query = f"Medical image of {location} and lungs. " + data["question"]
    elif "abd" in location.lower():
        # NOTE(review): "imgaing" is a typo and no space separates the prefix from the
        # question. Left byte-identical on purpose: changing the prompt changes which
        # segments are retrieved, i.e. the generated dataset.
        query = f"Medical imgaing of abdomen." + data["question"]
    else:
        query = f"Medical image of {location}." + data["question"]
    try:
        bbox, segs, max_indices = generate_segments(query, image_path)
    except Exception as exc:
        # Best effort: leave the record without 'bbox'/'mask' but remember why.
        segmentation_failures.append((data.get("qid"), repr(exc)))
        continue
    data['bbox'] = bbox
    data['mask'] = segs
    data["bbox_indices"] = max_indices

print(f"Segmentation failed for {len(segmentation_failures)} of {len(train_data_seg)} records")
    

  1%|▏         | 73/4919 [00:20<41:19,  1.95it/s]

Skipping zero-sized bounding box.
Skipping zero-sized bounding box.
Skipping zero-sized bounding box.
Skipping zero-sized bounding box.


  2%|▏         | 75/4919 [00:26<1:27:25,  1.08s/it]

Skipping zero-sized bounding box.
Skipping zero-sized bounding box.


  2%|▏         | 76/4919 [00:28<1:47:44,  1.33s/it]

Skipping zero-sized bounding box.
Skipping zero-sized bounding box.


  2%|▏         | 77/4919 [00:31<2:07:45,  1.58s/it]

Skipping zero-sized bounding box.
Skipping zero-sized bounding box.


  2%|▏         | 78/4919 [00:34<2:26:09,  1.81s/it]

Skipping zero-sized bounding box.
Skipping zero-sized bounding box.


  2%|▏         | 79/4919 [00:36<2:42:14,  2.01s/it]

Skipping zero-sized bounding box.
Skipping zero-sized bounding box.


  2%|▏         | 80/4919 [00:39<2:55:36,  2.18s/it]

Skipping zero-sized bounding box.
Skipping zero-sized bounding box.


  2%|▏         | 81/4919 [00:42<3:06:12,  2.31s/it]

Skipping zero-sized bounding box.
Skipping zero-sized bounding box.


  2%|▏         | 82/4919 [00:44<3:14:26,  2.41s/it]

Skipping zero-sized bounding box.
Skipping zero-sized bounding box.


  2%|▏         | 83/4919 [00:47<3:20:40,  2.49s/it]

Skipping zero-sized bounding box.
Skipping zero-sized bounding box.


 22%|██▏       | 1084/4919 [06:34<32:14,  1.98it/s]  

Skipping zero-sized bounding box.


 22%|██▏       | 1085/4919 [06:37<57:02,  1.12it/s]

Skipping zero-sized bounding box.


 22%|██▏       | 1086/4919 [06:39<1:20:36,  1.26s/it]

Skipping zero-sized bounding box.


 22%|██▏       | 1087/4919 [06:42<1:41:10,  1.58s/it]

Skipping zero-sized bounding box.


 22%|██▏       | 1088/4919 [06:45<1:58:33,  1.86s/it]

Skipping zero-sized bounding box.


 22%|██▏       | 1089/4919 [06:48<2:12:26,  2.07s/it]

Skipping zero-sized bounding box.


 22%|██▏       | 1090/4919 [06:50<2:23:13,  2.24s/it]

Skipping zero-sized bounding box.


 22%|██▏       | 1091/4919 [06:53<2:31:27,  2.37s/it]

Skipping zero-sized bounding box.


 22%|██▏       | 1092/4919 [06:56<2:37:27,  2.47s/it]

Skipping zero-sized bounding box.


 22%|██▏       | 1093/4919 [06:58<2:41:46,  2.54s/it]

Skipping zero-sized bounding box.


 22%|██▏       | 1094/4919 [07:01<2:45:09,  2.59s/it]

Skipping zero-sized bounding box.


 22%|██▏       | 1095/4919 [07:04<2:48:03,  2.64s/it]

Skipping zero-sized bounding box.


 22%|██▏       | 1096/4919 [07:07<2:49:46,  2.66s/it]

Skipping zero-sized bounding box.


 22%|██▏       | 1097/4919 [07:09<2:50:39,  2.68s/it]

Skipping zero-sized bounding box.


 22%|██▏       | 1098/4919 [07:12<2:52:17,  2.71s/it]

Skipping zero-sized bounding box.


 22%|██▏       | 1099/4919 [07:15<2:52:15,  2.71s/it]

Skipping zero-sized bounding box.


 39%|███▉      | 1911/4919 [10:54<05:55,  8.47it/s]  

Skipping zero-sized bounding box.


 39%|███▉      | 1914/4919 [10:59<33:17,  1.50it/s]

Skipping zero-sized bounding box.
Skipping zero-sized bounding box.


 39%|███▉      | 1916/4919 [11:05<57:21,  1.15s/it]

Skipping zero-sized bounding box.
Skipping zero-sized bounding box.


 39%|███▉      | 1918/4919 [11:10<1:17:17,  1.55s/it]

Skipping zero-sized bounding box.


 39%|███▉      | 1919/4919 [11:13<1:26:20,  1.73s/it]

Skipping zero-sized bounding box.


 39%|███▉      | 1920/4919 [11:16<1:35:24,  1.91s/it]

Skipping zero-sized bounding box.


 39%|███▉      | 1921/4919 [11:19<1:43:48,  2.08s/it]

Skipping zero-sized bounding box.


 39%|███▉      | 1922/4919 [11:21<1:51:15,  2.23s/it]

Skipping zero-sized bounding box.


 39%|███▉      | 1923/4919 [11:24<1:58:08,  2.37s/it]

Skipping zero-sized bounding box.


 39%|███▉      | 1924/4919 [11:27<2:02:48,  2.46s/it]

Skipping zero-sized bounding box.


 39%|███▉      | 1925/4919 [11:30<2:06:30,  2.54s/it]

Skipping zero-sized bounding box.


 39%|███▉      | 1926/4919 [11:32<2:09:20,  2.59s/it]

Skipping zero-sized bounding box.


 39%|███▉      | 1927/4919 [11:35<2:11:56,  2.65s/it]

Skipping zero-sized bounding box.


 57%|█████▋    | 2810/4919 [15:14<11:11,  3.14it/s]  

Skipping zero-sized bounding box.


 58%|█████▊    | 2840/4919 [15:21<09:20,  3.71it/s]

Skipping zero-sized bounding box.


 58%|█████▊    | 2855/4919 [15:24<11:49,  2.91it/s]

Skipping zero-sized bounding box.


 58%|█████▊    | 2870/4919 [15:28<12:17,  2.78it/s]

Skipping zero-sized bounding box.


 59%|█████▊    | 2885/4919 [15:31<12:12,  2.78it/s]

Skipping zero-sized bounding box.
Skipping zero-sized bounding box.


 59%|█████▊    | 2887/4919 [15:36<29:41,  1.14it/s]

Skipping zero-sized bounding box.
Skipping zero-sized bounding box.


 59%|█████▊    | 2889/4919 [15:42<44:23,  1.31s/it]

Skipping zero-sized bounding box.


 59%|█████▉    | 2890/4919 [15:44<51:11,  1.51s/it]

Skipping zero-sized bounding box.


 59%|█████▉    | 2891/4919 [15:47<58:03,  1.72s/it]

Skipping zero-sized bounding box.


 59%|█████▉    | 2892/4919 [15:50<1:04:30,  1.91s/it]

Skipping zero-sized bounding box.


 59%|█████▉    | 2893/4919 [15:52<1:10:15,  2.08s/it]

Skipping zero-sized bounding box.


 59%|█████▉    | 2894/4919 [15:55<1:15:06,  2.23s/it]

Skipping zero-sized bounding box.


 59%|█████▉    | 2895/4919 [15:58<1:19:03,  2.34s/it]

Skipping zero-sized bounding box.


 59%|█████▉    | 2896/4919 [16:00<1:22:29,  2.45s/it]

Skipping zero-sized bounding box.


 59%|█████▉    | 2897/4919 [16:03<1:24:37,  2.51s/it]

Skipping zero-sized bounding box.


 59%|█████▉    | 2898/4919 [16:06<1:26:10,  2.56s/it]

Skipping zero-sized bounding box.


 59%|█████▉    | 2899/4919 [16:09<1:27:23,  2.60s/it]

Skipping zero-sized bounding box.


 59%|█████▉    | 2903/4919 [16:11<40:23,  1.20s/it]  

Skipping zero-sized bounding box.


 59%|█████▉    | 2925/4919 [16:18<16:37,  2.00it/s]

Skipping zero-sized bounding box.
Skipping zero-sized bounding box.


 60%|█████▉    | 2927/4919 [16:23<36:03,  1.09s/it]

Skipping zero-sized bounding box.


 60%|█████▉    | 2928/4919 [16:26<44:34,  1.34s/it]

Skipping zero-sized bounding box.


 60%|█████▉    | 2929/4919 [16:29<52:53,  1.59s/it]

Skipping zero-sized bounding box.


 60%|█████▉    | 2930/4919 [16:31<1:00:31,  1.83s/it]

Skipping zero-sized bounding box.


 60%|█████▉    | 2931/4919 [16:34<1:07:06,  2.03s/it]

Skipping zero-sized bounding box.


 60%|█████▉    | 2932/4919 [16:37<1:12:31,  2.19s/it]

Skipping zero-sized bounding box.


 60%|█████▉    | 2933/4919 [16:39<1:16:49,  2.32s/it]

Skipping zero-sized bounding box.


 60%|█████▉    | 2934/4919 [16:42<1:20:06,  2.42s/it]

Skipping zero-sized bounding box.


 60%|█████▉    | 2935/4919 [16:45<1:22:31,  2.50s/it]

Skipping zero-sized bounding box.


 60%|█████▉    | 2936/4919 [16:47<1:24:18,  2.55s/it]

Skipping zero-sized bounding box.


 60%|█████▉    | 2937/4919 [16:50<1:25:30,  2.59s/it]

Skipping zero-sized bounding box.


 60%|█████▉    | 2938/4919 [16:53<1:26:48,  2.63s/it]

Skipping zero-sized bounding box.


 60%|█████▉    | 2939/4919 [16:55<1:27:20,  2.65s/it]

Skipping zero-sized bounding box.


 60%|██████    | 2970/4919 [17:05<14:01,  2.32it/s]  

Skipping zero-sized bounding box.


 91%|█████████▏| 4492/4919 [25:14<01:06,  6.43it/s]

Skipping zero-sized bounding box.


 91%|█████████▏| 4494/4919 [25:19<06:32,  1.08it/s]

Skipping zero-sized bounding box.


 91%|█████████▏| 4495/4919 [25:22<08:45,  1.24s/it]

Skipping zero-sized bounding box.


 91%|█████████▏| 4496/4919 [25:25<10:49,  1.54s/it]

Skipping zero-sized bounding box.


 91%|█████████▏| 4497/4919 [25:27<12:39,  1.80s/it]

Skipping zero-sized bounding box.


 91%|█████████▏| 4498/4919 [25:30<14:10,  2.02s/it]

Skipping zero-sized bounding box.


 91%|█████████▏| 4499/4919 [25:33<15:23,  2.20s/it]

Skipping zero-sized bounding box.


 91%|█████████▏| 4500/4919 [25:36<16:19,  2.34s/it]

Skipping zero-sized bounding box.


 92%|█████████▏| 4501/4919 [25:38<17:00,  2.44s/it]

Skipping zero-sized bounding box.


 98%|█████████▊| 4826/4919 [27:50<00:13,  6.88it/s]

Skipping zero-sized bounding box.


 98%|█████████▊| 4828/4919 [27:56<01:23,  1.09it/s]

Skipping zero-sized bounding box.
Skipping zero-sized bounding box.


 98%|█████████▊| 4830/4919 [28:01<02:09,  1.45s/it]

Skipping zero-sized bounding box.


 98%|█████████▊| 4831/4919 [28:04<02:27,  1.67s/it]

Skipping zero-sized bounding box.


 98%|█████████▊| 4832/4919 [28:07<02:43,  1.88s/it]

Skipping zero-sized bounding box.


 98%|█████████▊| 4833/4919 [28:09<02:57,  2.07s/it]

Skipping zero-sized bounding box.


 98%|█████████▊| 4834/4919 [28:12<03:08,  2.22s/it]

Skipping zero-sized bounding box.


 98%|█████████▊| 4835/4919 [28:15<03:17,  2.35s/it]

Skipping zero-sized bounding box.


 99%|█████████▊| 4856/4919 [28:24<00:26,  2.35it/s]

Skipping zero-sized bounding box.


100%|██████████| 4919/4919 [28:45<00:00,  2.85it/s]


In [15]:
# Spot-check the last record: it should now carry 'bbox' and 'mask' entries.
train_data_seg[-1]

{'img_id': 99,
 'img_name': 'xmlab99/source.jpg',
 'question': 'Is the lung healthy?',
 'answer': 'Yes',
 'q_lang': 'en',
 'location': 'Lung',
 'modality': 'CT',
 'answer_type': 'CLOSED',
 'base_type': 'vqa',
 'triple': ['vhead', '_', '_'],
 'qid': 4918,
 'content_type': 'Abnormality',
 'bbox': [[5, 58, 244, 183],
  [28, 80, 89, 132],
  [169, 78, 53, 59],
  [147, 140, 77, 72],
  [97, 173, 51, 56],
  [40, 118, 47, 77],
  [28, 78, 198, 135],
  [40, 78, 149, 151]],
 'mask': [array([[False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         ...,
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False]]),
  array([[False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, Fa

In [None]:
import os
import cv2
import matplotlib.pyplot as plt

output_folder = "/data/aofei/hallucination/Slake/imgs_masks"
os.makedirs(output_folder, exist_ok=True)

# Draw the contours of every retained SAM mask on each training image and save the result.
# `train_data_seg` is iterated because `train_data` is not bound before this cell on a
# fresh kernel (its load from training_contours.json is commented out). Many questions
# share one image, so every image name is visited once, in first-seen order.
image_names = list(dict.fromkeys(data["img_name"] for data in train_data_seg))

for img_name in tqdm(image_names):
    image_path = os.path.join("/data/aofei/hallucination/Slake/imgs", img_name)
    output_path = os.path.join(output_folder, img_name)

    # Already rendered by a previous run.
    if os.path.exists(output_path):
        continue

    masks = generate_all_segments(image_path=image_path)
    image_with_contours = preprocess_image(image_path=image_path)

    for mask in masks:
        segmentation = mask['segmentation'].astype(np.uint8)
        contours, _ = cv2.findContours(segmentation, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        # (0, 0, 255) is red in OpenCV's BGR channel order.
        cv2.drawContours(image_with_contours, contours, -1, (0, 0, 255), 1)

    # img_name contains a sub-directory (e.g. "xmlab1/source.jpg"); create it first.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # Visualize the image with contours before saving
    # plt.figure(figsize=(6, 6))
    # plt.imshow(cv2.cvtColor(image_with_contours, cv2.COLOR_BGR2RGB))
    # plt.axis('off')
    # plt.show()

    # cv2.imwrite reports failure through its return value; only claim success when it
    # actually succeeded (the original printed "Saved" unconditionally).
    if cv2.imwrite(output_path, image_with_contours):
        print(f"Saved image with contours: {output_path}")
    else:
        print(f"Failed to save image: {output_path}")

    

In [None]:
# NOTE(review): displays the whole cache (every mask dict of every cached image), which
# can produce a very large output; len(masks_all_image_dict) may be all that is needed.
masks_all_image_dict

In [None]:
# Fraction of the 256x256 image covered by the third selected mask of the first record.
# The masks are stored on `train_data_seg`; `train_data` is not bound at this point and,
# once loaded later, holds raw records without a 'mask' field.
np.sum(train_data_seg[0]['mask'][2]) / (256*256)

In [16]:
# Count the records that never received a 'mask' entry (their segmentation was skipped).
s = sum(1 for record in train_data_seg if "mask" not in record)
s

90

In [17]:
#training2
# Convert every English record to the conversation-style training format. Bounding boxes
# go into the JSON-bound template; the masks are collected separately in segments_dict
# (keyed by the qid as a string) and 'masks' stays an empty list in the template.
new_train_data = []
segments_dict = dict()
for record in train_data_seg:
    if record['q_lang'] != "en":
        continue

    template = {
        'image': record['img_name'],
        'id': record['qid'],
        'conversations': [
            {"from": "human", "value": "<image>\n" + record['question']},
            {"from": "gpt", "value": str(record['answer'])},
        ],
        # Records whose segmentation was skipped have neither 'bbox' nor 'mask'.
        'bboxes': record.get("bbox", []),
        'masks': [],
    }
    segments_dict[str(record['qid'])] = record.get("mask", [])
    new_train_data.append(template)

new_train_data[-1]

{'image': 'xmlab99/source.jpg',
 'id': 4918,
 'conversations': [{'from': 'human', 'value': '<image>\nIs the lung healthy?'},
  {'from': 'gpt', 'value': 'Yes'}],
 'bboxes': [[5, 58, 244, 183],
  [28, 80, 89, 132],
  [169, 78, 53, 59],
  [147, 140, 77, 72],
  [97, 173, 51, 56],
  [40, 118, 47, 77],
  [28, 78, 198, 135],
  [40, 78, 149, 151]],
 'masks': []}

In [18]:
# segments_dict['0']
# Number of qids whose segment list is empty.
ed = sum(len(segments) == 0 for segments in segments_dict.values())
ed

90

In [19]:
# save the masks to npz file
# One archive entry per qid (string key), holding the mask list collected for that question.

np.savez_compressed("/data/aofei/hallucination/Slake/data/training_segments_top8.npz", **segments_dict)

In [20]:
# Number of converted training samples.
len(new_train_data)

4919

In [None]:
# Sanity-check the archive this notebook writes (training_segments_top8.npz); the previous
# path pointed at training_segments.npz, a file no cell here produces.
# allow_pickle=True lets NumPy unpickle object arrays — only use it on trusted files.
loaded_segments = np.load("/data/aofei/hallucination/Slake/data/training_segments_top8.npz", allow_pickle=True)  # Allow pickle to handle lists

# Access the segments stored for one question by its qid string, e.g. "0"
segments_list_id_1 = loaded_segments["0"]

# Each item in segments_list_id_1 is a 256x256 numpy array
for segment in segments_list_id_1:
    print(segment.shape)

In [21]:
# Write the converted training samples (conversations + bounding boxes) as indented JSON.
with open('/data/aofei/hallucination/Slake/data/training_masks_top8.json', 'w') as json_file:
    json.dump(new_train_data, json_file, indent=4)

In [None]:
with open(r"/data/aofei/hallucination/Slake/test.json", "r") as f:
    test_data = json.load(f)

#training2
# English test questions in the conversation format; unlike the training samples they
# keep 'answer_type' and carry no boxes or masks.
new_test_data = [
    {
        'image': record['img_name'],
        'id': record['qid'],
        'answer_type': record['answer_type'],
        'conversations': [
            {"from": "human", "value": "<image>\n" + record['question']},
            {"from": "gpt", "value": str(record['answer'])},
        ],
    }
    for record in test_data
    if record['q_lang'] == "en"
]

In [None]:
# NOTE(review): prints the size of the raw test set, not of new_test_data (the list that
# is actually written below).
print(len(test_data))
# NOTE(review): this overwrites Slake/data/test.json, the same path an earlier cell reads
# into `test_data`.
with open('/data/aofei/hallucination/Slake/data/test.json', 'w') as json_file:
    json.dump(new_test_data, json_file, indent=4)

In [22]:
import json
with open(r"/data/aofei/hallucination/Slake/data/training_masks_top8.json", "r") as f:
    seg_train_data = json.load(f)
# Index the converted samples by their 'id' so they can be looked up per question id.
seg_train_dict = {sample['id']: sample for sample in seg_train_data}

In [23]:
# Reload the raw, unfiltered training annotations; this binds `train_data` for the
# per-location statistics that follow.
with open(r"/data/aofei/hallucination/Slake/train.json", "r") as f:
    train_data = json.load(f)
len(train_data)

9835

In [24]:
# Inspect a raw record ('location' and 'qid' are the fields used next).
train_data[1]

{'img_id': 1,
 'img_name': 'xmlab1/source.jpg',
 'question': 'Which part of the body does this image belong to?',
 'answer': 'Abdomen',
 'q_lang': 'en',
 'location': 'Abdomen',
 'modality': 'MRI',
 'answer_type': 'OPEN',
 'base_type': 'vqa',
 'content_type': 'Position',
 'triple': ['vhead', '_', '_'],
 'qid': 1}

In [25]:
# Per-location statistics over the raw training records:
#   num_dict: location -> number of questions
#   ids_dict: location -> list of qids, in file order
num_dict = dict()
ids_dict = dict()
for record in train_data:
    organ = record['location']
    # Named `qid`, not `id`: a module-level `id` would shadow the builtin in every later cell.
    qid = record['qid']
    num_dict[organ] = num_dict.get(organ, 0) + 1
    ids_dict.setdefault(organ, []).append(qid)

In [26]:
# Question count per 'location' value.
num_dict

{'Abdomen': 3041,
 'Lung': 3406,
 'Chest_heart': 187,
 'Chest_lung': 283,
 'Brain_Tissue': 1394,
 'Brain_Face': 250,
 'Brain': 543,
 'Neck': 264,
 'Chest_mediastinal': 33,
 'Pelvic Cavity': 434}

In [27]:
# Converted samples for the 'Lung' location; qids absent from seg_train_dict are skipped.
train_data_lungs = [
    seg_train_dict[_id]
    for _id in ids_dict['Lung']
    if _id in seg_train_dict
]

In [28]:
# Number of converted samples in the lung subset.
len(train_data_lungs)

1710

In [29]:
# Write the lung subset as indented JSON.
with open('/data/aofei/hallucination/Slake/data/training_masks_top8_lung.json', 'w') as json_file:
    json.dump(train_data_lungs, json_file, indent=4)

In [30]:
# Abdomen subset. Kept in its own variable: the copy-pasted original reused
# `train_data_lungs`, silently replacing the lung subset with abdomen samples.
train_data_abd = []
for _id in ids_dict['Abdomen']:
    if _id in seg_train_dict:
        train_data_abd.append(seg_train_dict[_id])
with open('/data/aofei/hallucination/Slake/data/training_masks_top8_abd.json', 'w') as json_file:
    json.dump(train_data_abd, json_file, indent=4)

In [None]:
# Inspect the first raw training record.
train_data[0]

In [None]:
# NOTE(review): visualize_with_indices is defined in a later cell of this notebook, so
# that cell has to run first; the image must also already be cached in masks_image_dict.
visualize_with_indices(image_name='xmlab1/source.jpg', indices_list=[3,0,2,1], fig_width=5)

In [None]:
# Visualise the selected segments for a handful of records.
# Uses train_data_seg / 'img_name': `train_rad_data` and an 'image_name' key do not exist
# in this notebook.
for sample in train_data_seg[-20:-10]:
    # Records whose segmentation was skipped carry no 'bbox_indices'.
    if 'bbox_indices' not in sample:
        continue
    visualize_with_indices(image_name=sample['img_name'], indices_list=sample['bbox_indices'])

In [None]:
# visualize_with_indices(image_name='synpic28602.jpg', indices_list=[3, 8, 9], fig_width=5)

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt

def visualize_with_indices(image_name, indices_list:list = None, indices:tuple = None, width_threshold=260, fig_width=6):
    """Show an image with the bounding boxes and masks of selected SAM segments.

    The segments come from the module-level ``masks_image_dict``, so the image must
    already have been processed by ``generate_segments``.

    Args:
        image_name: Image path relative to the Slake ``imgs`` directory.
        indices_list: Indices of the cached masks to draw. Ignored when ``indices``
            is given.
        indices: ``(start, stop)`` slice bounds into the cached masks; takes
            precedence over ``indices_list``.
        width_threshold: Segments whose box width is >= this value are not drawn.
        fig_width: Side length (inches) of the square matplotlib figure.

    Raises:
        KeyError: If no masks are cached for the image.
        FileNotFoundError: If the image cannot be read.
    """
    image_path = os.path.join("/data/aofei/hallucination/Slake/imgs", image_name)
    if image_path not in masks_image_dict:
        raise KeyError(f"No cached masks for {image_path}; run generate_segments on it first")
    masks = masks_image_dict[image_path]

    image = cv2.imread(image_path)
    # cv2.imread returns None instead of raising when the file cannot be read.
    if image is None:
        raise FileNotFoundError(f"Could not read image (missing or unreadable): {image_path}")
    image = cv2.resize(image, (256, 256))
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_with_boxes = image_rgb.copy()

    # Overlay that accumulates every selected mask.
    combined_mask = np.zeros_like(image_rgb, dtype=np.uint8)

    if indices is not None:
        selected_masks = masks[indices[0]:indices[1]]
    elif indices_list is not None:
        selected_masks = [masks[i] for i in indices_list]
    else:
        selected_masks = []

    for seg in selected_masks:
        x, y, w, h = seg['bbox']
        if w >= width_threshold:
            continue

        # Bounding box in red (the image is RGB at this point).
        cv2.rectangle(image_with_boxes, (x, y), (x + w, y + h), (255, 0, 0), 1)

        # Nearest-neighbour keeps the mask strictly binary if it ever needs rescaling;
        # linear interpolation would blend values along the mask border.
        mask = seg['segmentation'].astype(np.uint8)
        mask = cv2.resize(mask, (image_rgb.shape[1], image_rgb.shape[0]), interpolation=cv2.INTER_NEAREST)

        # All masks share one colour, so painting in place equals taking the maximum.
        combined_mask[mask == 1] = [0, 255, 0]

    # Blend the original image with the combined mask once.
    alpha = 0.5  # Transparency factor
    image_with_masks = cv2.addWeighted(image_with_boxes, 1 - alpha, combined_mask, alpha, 0)

    # Display the image with bounding boxes and masks.
    plt.figure(figsize=(fig_width, fig_width))
    plt.imshow(image_with_masks)
    plt.axis('off')  # Turn off axis for clean visualization
    plt.show()