## 1. Load the model

In [15]:
import os
import warnings

# Set environment variables BEFORE importing torch / transformers:
# huggingface_hub resolves HF_HOME / TRANSFORMERS_CACHE when it is imported,
# and CUDA_VISIBLE_DEVICES only takes effect if set before CUDA is initialised.
os.environ['TRANSFORMERS_CACHE'] = './HFCache'
os.environ['HF_HOME'] = './HFCache'
os.environ['CUDA_VISIBLE_DEVICES'] = '4'
warnings.filterwarnings("ignore")

import numpy as np
import torch
import nibabel as nib
from transformers import AutoTokenizer, AutoModelForCausalLM

device = torch.device('cuda')  # 'cpu' or 'cuda'
dtype = torch.bfloat16  # bfloat16, float16 or float32

# model_name_or_path = '/mnt/ccvl15/qwu59/checkpoints/m3d/M3D-LaMed-Phi-3-4B'
model_name_or_path = '/mnt/sdh/qwu59/ckpts/m3d/M3D-LaMed-Phi-3-4B'
# Number of "<im_patch>" placeholder tokens prepended to every prompt.
# Presumably this must match the vision projector's output length; confirm against the checkpoint.
proj_out_num = 256

# Prepare your 3D medical image:
# 1. The image shape needs to be processed as 1*32*256*256, consider resize and other methods.
# 2. The image needs to be normalized to 0-1, consider Min-Max Normalization.
# 3. The image format needs to be converted to .npy
# 4. Although we did not train on 2D images, in theory, the 2D image can be interpolated to the shape of 1*32*256*256 for input.

model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=dtype,
    device_map='auto',
    trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    model_max_length=512,
    padding_side="right",
    use_fast=False,
    trust_remote_code=True
)

# NOTE(review): this is redundant with device_map='auto' when only one GPU is visible.
# It may fail if accelerate ever splits the model across several devices.
model = model.to(device=device)

build_sam_vit_3d...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [16]:
def inference(question, ct_pro, seg_enable=False, branch='ct',
              max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0):
    """Ask the model one question about one preprocessed CT case.

    Args:
        question: Text prompt appended after the image placeholder tokens.
        ct_pro: Case processor. For branch 'ct' it must expose ``transformed["image"]``;
            for branch 'ctmin' it must expose ``ctmin_transformed["image"]``.
        seg_enable: If True, the model also returns segmentation logits.
        branch: Which preprocessed image to feed, 'ct' or 'ctmin'.
        max_new_tokens, do_sample, top_p, temperature: Forwarded to
            ``model.generate``. The defaults reproduce the original settings.

    Returns:
        The decoded answer string. If ``seg_enable`` is True, returns
        ``(answer, seg_mask)``, where ``seg_mask`` is the float mask
        ``sigmoid(logits) > 0.5`` with the leading batch dim squeezed.

    Raises:
        ValueError: If ``branch`` is not 'ct' or 'ctmin'.
    """
    input_txt = "<im_patch>" * proj_out_num + question
    # Tokenize once. For a single sequence, padding is a no-op, so this mask
    # equals the one obtained from a second padded call.
    encoded = tokenizer(input_txt, return_tensors="pt")
    input_id = encoded['input_ids'].to(device=device)
    attention_mask = encoded['attention_mask'].to(device=device)

    if branch == 'ct':
        image = ct_pro.transformed["image"]
    elif branch == 'ctmin':
        image = ct_pro.ctmin_transformed["image"]
    else:
        raise ValueError(f"Unknown branch {branch!r}; expected 'ct' or 'ctmin'.")
    image_pt = image.unsqueeze(0).to(device=device).type(dtype)  # add batch dim

    with torch.no_grad():
        outputs = model.generate(
            image_pt,
            input_id,
            attention_mask=attention_mask,
            seg_enable=seg_enable,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            top_p=top_p,
            temperature=temperature
        )

    if seg_enable:
        generation, seg_logit = outputs
        seg_mask = ((torch.sigmoid(seg_logit) > 0.5) * 1.0).squeeze(0)
        return tokenizer.batch_decode(generation, skip_special_tokens=True)[0], seg_mask
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

import random
def batch_inference(question, ct_pro, seg_enable=False, branch='ct', num_samples=3):
    """Sample several answers to one question, each with a randomly chosen prompt prefix.

    Args:
        question: Question text. It is combined with one of a few fixed prefixes per sample.
        ct_pro: Case processor. For branch 'ct' it must expose ``transformed["image"]``;
            for branch 'ctmin' it must expose ``ctmin_transformed["image"]``.
        seg_enable: If True, the model also returns segmentation logits.
        branch: Which preprocessed image to feed, 'ct' or 'ctmin'.
        num_samples: Number of prompts (batch rows) to generate. Defaults to 3,
            as before.

    Returns:
        A list of decoded answer strings. If ``seg_enable`` is True, returns
        ``(answers, seg_mask)``.

    Raises:
        ValueError: If ``branch`` is not 'ct' or 'ctmin'.
    """
    image_tokens = "<im_patch>" * proj_out_num
    # Prefixes end with a separator so they don't run into the question text.
    prefixes = ['Please answer the question: ', 'Question: ', '']
    questions = [prefix + question for prefix in prefixes]
    input_txts = [image_tokens + random.choice(questions) for _ in range(num_samples)]

    # Left padding keeps every prompt flush against the generated tokens.
    encoded = tokenizer(
        input_txts,
        return_tensors="pt",
        padding=True,
        padding_side="left",
    )
    input_ids = encoded['input_ids'].to(device=device)
    attention_masks = encoded['attention_mask'].to(device=device)

    if branch == 'ct':
        image = ct_pro.transformed["image"]
    elif branch == 'ctmin':
        image = ct_pro.ctmin_transformed["image"]
    else:
        raise ValueError(f"Unknown branch {branch!r}; expected 'ct' or 'ctmin'.")
    image_pt = image.unsqueeze(0).to(device=device).type(dtype)
    # The same image is repeated for every prompt in the batch.
    image_pts = torch.cat([image_pt] * num_samples, dim=0)

    with torch.no_grad():
        outputs = model.generate(
            image_pts,
            input_ids,
            attention_mask=attention_masks,
            seg_enable=seg_enable,
            max_new_tokens=64,
            do_sample=True,
            top_p=0.9,
            temperature=1.0
        )

    if seg_enable:
        generation, seg_logit = outputs
        seg_mask = ((torch.sigmoid(seg_logit) > 0.5) * 1.0).squeeze(0)
        return tokenizer.batch_decode(generation, skip_special_tokens=True), seg_mask
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)

## 2. Test a single prompt on one case

In [19]:
def step_1_q(organ):
    """Return the step-1 yes/no question asking whether ``organ`` is in the image."""
    return f"Does the bone ct image contain the {organ}? Answer yes or no."

def step_2_q(organ):
    """Return the step-2 yes/no question asking whether the highlighted mask is ``organ``.

    The adjacent string literals are concatenated, so each must carry its own
    trailing space. Without it, the prompt reads "...ct imageis the...".
    """
    return (
        "The lowest intensity white area within the body in this bone ct image "
        f"is the {organ} mask annotation. "
        "Is it correct? Only answer yes or no."
    )

In [22]:
import time
from importlib import reload

import opti_tf as ctf
# import custom_tf as ctf
reload(ctf)  # pick up edits to the local module without restarting the kernel

case_path = "/mnt/sdh/qwu59/data/s0015"
organ = "kidneys"

# Measure how long preprocessing a single case takes (perf_counter is the
# monotonic clock intended for durations).
start = time.perf_counter()
ct_pro = ctf.CTImageProcessor(case_path, ct_name="ct", mask_name=organ)
print("Time:", time.perf_counter() - start)

question1 = step_1_q(organ)
question2 = step_2_q(organ)
text1 = inference(question1, ct_pro, branch='ct')  # branch='ct' or 'ctmin'
text2 = inference(question2, ct_pro, branch='ctmin')
print("Question 1:", question1)
print("Answer 1:", text1)
print("*" * 80)
print("Question 2:", question2)
print("Answer 2:", text2)

Attempting to load CT image from: /mnt/sdh/qwu59/data/s0015/ct.nii.gz


Time: 4.606508255004883
Question 1: Does the bone ct image contain the kidneys? Answer yes or no.
Answer 1: Yes
********************************************************************************
Question 2: The lowest intensity white area within the body in this bone ct imageis the kidneys mask annotation. Is it correct? Only answer yes or no.
Answer 2: Yes


: 

## 3. Run the evaluation over the cases listed in the data JSON files

In [4]:
import json, csv
from tqdm import tqdm

def append_dict_to_csv(dict_data, csv_path):
    """Append one row to a CSV file, writing the header first if the file is new or empty.

    Args:
        dict_data: Mapping of column name to value. Its key order defines the
            column order and should match the header of an existing file.
        csv_path: Destination file. Missing parent directories are created.
    """
    parent = os.path.dirname(csv_path)
    if parent:  # os.makedirs('') raises, so skip it for bare filenames
        os.makedirs(parent, exist_ok=True)
    with open(csv_path, mode='a', newline='') as file:
        writer = csv.DictWriter(file, fieldnames=list(dict_data))
        # In append mode the stream starts at EOF, so tell() == 0 means the file is empty.
        if file.tell() == 0:
            writer.writeheader()
        writer.writerow(dict_data)
        
def step_1_q(organ):
    """Return the step-1 yes/no question asking whether ``organ`` is in the image.

    Same wording as the earlier definition in section 2; redefined here so this
    section can run on its own.
    """
    return f"Does the bone ct image contain the {organ}? Answer yes or no."

def step_2_q(organ):
    """Return the step-2 yes/no question used for the batch evaluation.

    This shadows the earlier definition in section 2 with a slightly longer
    prompt. The adjacent string literals are concatenated, so each must carry
    its own trailing space. Without it, the prompt reads "...ct imageis the...".
    """
    return (
        "The lowest intensity white area within the body in this bone ct image "
        f"is the {organ} mask annotation. "
        "What do you think of this? Is it correct? Only answer yes or no."
    )
    
# Root folder for the raw/ and final/ result CSVs.
result_path = "results/m3d/"

# Tasks 1 and 3: AbdomenAtlas-Beta cases. The bad- and good-label lists share one data root.
task1_path = task3_path = "/mnt/sdh/pedro/AbdomenAtlasBeta/"
task1 = "bad_labels_AbdomenAtlasBeta.json"
task3 = "good_labels_AbdomenAtlasBeta.json"

# Task 2: nnU-Net results. The "_liver" sibling folder is presumably liver-specific;
# verify where it is meant to be used.
task2_path = "/mnt/sdc/pedro/ErrorDetection/cropped_nnunet_results_250Epch"
task2_path_ = task2_path + "_liver"
task2 = "bad_labels_nnUnet.json"


In [5]:
# Task 1: AbdomenAtlas-Beta cases whose labels are known to be wrong.
# The cell is resumable: (sample, organ) pairs already in the final CSV are
# skipped, so an interrupted run can simply be re-executed.
with open(task1) as f:
    task1_data = json.load(f)

task1_part = "errors_beta_full"
task1_raw_csv = os.path.join(result_path, "raw", task1_part + ".csv")
task1_final_csv = os.path.join(result_path, "final", task1_part + ".csv")

# Read the finished pairs once, instead of re-scanning the whole CSV for every case.
task1_done = set()
if os.path.exists(task1_final_csv):
    with open(task1_final_csv, mode='r', newline='') as file:
        task1_done = {(row["sample"], row["organ"]) for row in csv.DictReader(file)}

for organ in tqdm(task1_data):
    print("Organ:", organ)
    question1 = step_1_q(organ)
    question2 = step_2_q(organ)
    for case in tqdm(task1_data[organ]):
        if (case, organ) in task1_done:
            continue
        case_path = os.path.join(task1_path, case)
        ct_pro = ctf.CTImageProcessor(case_path, ct_name="ct", mask_name=organ)
        text1 = inference(question1, ct_pro, branch='ct')
        text2 = inference(question2, ct_pro, branch='ctmin')
        task1_raw = {
            "sample": case,
            "organ": organ,
            "part": task1_part,
            "question1": question1,
            "answer1": text1,
            "question2": question2,
            "answer2": text2
        }
        task1_single = {
            "sample": case,
            "organ": organ,
            "part": task1_part,
            "result step 1": "present" if "yes" in text1.lower() else "no",
            "label step 1": "present" if ct_pro.mask_present else "no",
            "result step 2": "Correct" if "yes" in text2.lower() else "Incorrect",
            "label step 2": "Incorrect",  # every task-1 label is known to be bad
        }
        print(task1_single)
        append_dict_to_csv(task1_raw, task1_raw_csv)
        append_dict_to_csv(task1_single, task1_final_csv)
        task1_done.add((case, organ))


  0%|          | 0/8 [00:00<?, ?it/s]

Organ: kidneys


100%|██████████| 14/14 [00:00<00:00, 856.88it/s]


Organ: pancreas


100%|██████████| 3/3 [00:00<00:00, 1232.77it/s]


Organ: gall_bladder


100%|██████████| 2/2 [00:00<00:00, 890.23it/s]


Organ: postcava


100%|██████████| 6/6 [00:00<00:00, 1464.92it/s]


Organ: stomach


100%|██████████| 2/2 [00:00<00:00, 852.50it/s]


Organ: spleen


100%|██████████| 3/3 [00:00<00:00, 1907.37it/s]


Organ: aorta




The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


{'sample': 'BDMAP_00000031', 'organ': 'aorta', 'part': 'errors_beta_full', 'result step 1': 'present', 'label step 1': 'present', 'result step 2': 'Correct', 'label step 2': 'Incorrect'}




{'sample': 'BDMAP_00004290', 'organ': 'aorta', 'part': 'errors_beta_full', 'result step 1': 'present', 'label step 1': 'present', 'result step 2': 'Correct', 'label step 2': 'Incorrect'}




{'sample': 'BDMAP_00004236', 'organ': 'aorta', 'part': 'errors_beta_full', 'result step 1': 'present', 'label step 1': 'present', 'result step 2': 'Correct', 'label step 2': 'Incorrect'}




{'sample': 'BDMAP_00001045', 'organ': 'aorta', 'part': 'errors_beta_full', 'result step 1': 'present', 'label step 1': 'no', 'result step 2': 'Correct', 'label step 2': 'Incorrect'}




{'sample': 'BDMAP_00001483', 'organ': 'aorta', 'part': 'errors_beta_full', 'result step 1': 'present', 'label step 1': 'present', 'result step 2': 'Incorrect', 'label step 2': 'Incorrect'}




{'sample': 'BDMAP_00004834', 'organ': 'aorta', 'part': 'errors_beta_full', 'result step 1': 'present', 'label step 1': 'present', 'result step 2': 'Incorrect', 'label step 2': 'Incorrect'}




{'sample': 'BDMAP_00000003', 'organ': 'aorta', 'part': 'errors_beta_full', 'result step 1': 'present', 'label step 1': 'present', 'result step 2': 'Correct', 'label step 2': 'Incorrect'}
{'sample': 'BDMAP_00001230', 'organ': 'aorta', 'part': 'errors_beta_full', 'result step 1': 'present', 'label step 1': 'present', 'result step 2': 'Correct', 'label step 2': 'Incorrect'}




{'sample': 'BDMAP_00000233', 'organ': 'aorta', 'part': 'errors_beta_full', 'result step 1': 'present', 'label step 1': 'present', 'result step 2': 'Incorrect', 'label step 2': 'Incorrect'}




{'sample': 'BDMAP_00004457', 'organ': 'aorta', 'part': 'errors_beta_full', 'result step 1': 'present', 'label step 1': 'present', 'result step 2': 'Incorrect', 'label step 2': 'Incorrect'}




{'sample': 'BDMAP_00003935', 'organ': 'aorta', 'part': 'errors_beta_full', 'result step 1': 'present', 'label step 1': 'present', 'result step 2': 'Incorrect', 'label step 2': 'Incorrect'}


100%|██████████| 12/12 [05:40<00:00, 28.37s/it]
 88%|████████▊ | 7/8 [05:40<00:48, 48.64s/it]

{'sample': 'BDMAP_00003976', 'organ': 'aorta', 'part': 'errors_beta_full', 'result step 1': 'present', 'label step 1': 'present', 'result step 2': 'Correct', 'label step 2': 'Incorrect'}
Organ: liver




{'sample': 'BDMAP_00001808', 'organ': 'liver', 'part': 'errors_beta_full', 'result step 1': 'no', 'label step 1': 'present', 'result step 2': 'Correct', 'label step 2': 'Incorrect'}




{'sample': 'BDMAP_00001992', 'organ': 'liver', 'part': 'errors_beta_full', 'result step 1': 'no', 'label step 1': 'present', 'result step 2': 'Correct', 'label step 2': 'Incorrect'}




{'sample': 'BDMAP_00001537', 'organ': 'liver', 'part': 'errors_beta_full', 'result step 1': 'no', 'label step 1': 'present', 'result step 2': 'Incorrect', 'label step 2': 'Incorrect'}




{'sample': 'BDMAP_00001032', 'organ': 'liver', 'part': 'errors_beta_full', 'result step 1': 'no', 'label step 1': 'present', 'result step 2': 'Correct', 'label step 2': 'Incorrect'}




{'sample': 'BDMAP_00002316', 'organ': 'liver', 'part': 'errors_beta_full', 'result step 1': 'no', 'label step 1': 'present', 'result step 2': 'Correct', 'label step 2': 'Incorrect'}




{'sample': 'BDMAP_00001399', 'organ': 'liver', 'part': 'errors_beta_full', 'result step 1': 'no', 'label step 1': 'present', 'result step 2': 'Correct', 'label step 2': 'Incorrect'}




{'sample': 'BDMAP_00001095', 'organ': 'liver', 'part': 'errors_beta_full', 'result step 1': 'no', 'label step 1': 'present', 'result step 2': 'Correct', 'label step 2': 'Incorrect'}




{'sample': 'BDMAP_00001044', 'organ': 'liver', 'part': 'errors_beta_full', 'result step 1': 'no', 'label step 1': 'present', 'result step 2': 'Correct', 'label step 2': 'Incorrect'}




{'sample': 'BDMAP_00002727', 'organ': 'liver', 'part': 'errors_beta_full', 'result step 1': 'no', 'label step 1': 'present', 'result step 2': 'Incorrect', 'label step 2': 'Incorrect'}




{'sample': 'BDMAP_00003467', 'organ': 'liver', 'part': 'errors_beta_full', 'result step 1': 'no', 'label step 1': 'present', 'result step 2': 'Incorrect', 'label step 2': 'Incorrect'}


100%|██████████| 11/11 [01:24<00:00,  7.70s/it]
100%|██████████| 8/8 [07:05<00:00, 53.15s/it]

{'sample': 'BDMAP_00003023', 'organ': 'liver', 'part': 'errors_beta_full', 'result step 1': 'no', 'label step 1': 'present', 'result step 2': 'Correct', 'label step 2': 'Incorrect'}





In [None]:
# Task 2: nnU-Net predictions whose labels are known to be wrong.
# Resumable like task 1: (sample, organ) pairs already in the final CSV are skipped,
# so re-running the cell no longer appends duplicate rows.
# NOTE(review): task2_path_ (the "_liver" results folder) is defined in the config
# cell but never used here. Confirm whether liver cases should be read from it.
with open(task2) as f:
    task2_data = json.load(f)

task2_part = "errors_nnUnet_full"
task2_raw_csv = os.path.join(result_path, "raw", task2_part + ".csv")
task2_final_csv = os.path.join(result_path, "final", task2_part + ".csv")

task2_done = set()
if os.path.exists(task2_final_csv):
    with open(task2_final_csv, mode='r', newline='') as file:
        task2_done = {(row["sample"], row["organ"]) for row in csv.DictReader(file)}

for organ in tqdm(task2_data):
    print("Organ:", organ)
    question1 = step_1_q(organ)
    question2 = step_2_q(organ)
    for case in tqdm(task2_data[organ]):
        if (case, organ) in task2_done:
            continue
        case_path = os.path.join(task2_path, case)
        ct_pro = ctf.CTImageProcessor(case_path, ct_name="ct", mask_name=organ)
        text1 = inference(question1, ct_pro, branch='ct')
        text2 = inference(question2, ct_pro, branch='ctmin')
        task2_raw = {
            "sample": case,
            "organ": organ,
            "part": task2_part,
            "question1": question1,
            "answer1": text1,
            "question2": question2,
            "answer2": text2
        }
        task2_single = {
            "sample": case,
            "organ": organ,
            "part": task2_part,
            "result step 1": "present" if "yes" in text1.lower() else "no",
            "label step 1": "present" if ct_pro.mask_present else "no",
            "result step 2": "Correct" if "yes" in text2.lower() else "Incorrect",
            "label step 2": "Incorrect",  # every task-2 label is known to be bad
        }
        append_dict_to_csv(task2_raw, task2_raw_csv)
        append_dict_to_csv(task2_single, task2_final_csv)
        task2_done.add((case, organ))


In [None]:
# Task 3: AbdomenAtlas-Beta cases whose labels are known to be correct.
# Iterates task3_data (the original looped over task1_data, i.e. the bad-label
# organ list) and tags final rows with this task's own part name.
# Resumable like task 1: (sample, organ) pairs already in the final CSV are skipped.
with open(task3) as f:
    task3_data = json.load(f)

task3_part = "good_labels_beta_full"
task3_raw_csv = os.path.join(result_path, "raw", task3_part + ".csv")
task3_final_csv = os.path.join(result_path, "final", task3_part + ".csv")

task3_done = set()
if os.path.exists(task3_final_csv):
    with open(task3_final_csv, mode='r', newline='') as file:
        task3_done = {(row["sample"], row["organ"]) for row in csv.DictReader(file)}

for organ in tqdm(task3_data):
    print("Organ:", organ)
    question1 = step_1_q(organ)
    question2 = step_2_q(organ)
    for case in tqdm(task3_data[organ]):
        if (case, organ) in task3_done:
            continue
        case_path = os.path.join(task3_path, case)
        ct_pro = ctf.CTImageProcessor(case_path, ct_name="ct", mask_name=organ)
        text1 = inference(question1, ct_pro, branch='ct')
        text2 = inference(question2, ct_pro, branch='ctmin')
        task3_raw = {
            "sample": case,
            "organ": organ,
            "part": task3_part,
            "question1": question1,
            "answer1": text1,
            "question2": question2,
            "answer2": text2
        }
        task3_single = {
            "sample": case,
            "organ": organ,
            "part": task3_part,
            "result step 1": "present" if "yes" in text1.lower() else "no",
            "label step 1": "present" if ct_pro.mask_present else "no",
            "result step 2": "Correct" if "yes" in text2.lower() else "Incorrect",
            "label step 2": "Correct",  # every task-3 label is known to be good
        }
        append_dict_to_csv(task3_raw, task3_raw_csv)
        append_dict_to_csv(task3_single, task3_final_csv)
        task3_done.add((case, organ))
