In [1]:
import numpy as np
import nibabel as nib

from utils.class_SAMMed2D import SAMMed2DInferer
import utils.prompt as prUt
import utils.analysis as anUt
from utils.image import read_im_gt
from utils.interactivity import iterate_2d

# --- Model -------------------------------------------------------------------
# Build the SAM-Med2D inferer from a local checkpoint on the chosen device.
device = 'cuda'
checkpoint_path = "/home/t722s/Desktop/UniversalModels/TrainedModels/sam-med2d_b.pth"
sammed2d_inferer = SAMMed2DInferer(checkpoint_path, device)

# --- Data --------------------------------------------------------------------
# One CT volume and its label map; `class_label` selects the structure whose
# binary mask serves as ground truth for every experiment below.
img_path = '/home/t722s/Desktop/Datasets/Dataset350_AbdomenAtlasJHU_2img/imagesTr/BDMAP_00000001_0000.nii.gz'
gt_path = '/home/t722s/Desktop/Datasets/Dataset350_AbdomenAtlasJHU_2img/labelsTr/BDMAP_00000001.nii.gz'
class_label = 3

# Binary mask of the label array exactly as stored in the NIfTI file; it is the
# reference for predictions converted back via `.get_fdata()`.
gt_unprocessed = np.where(nib.load(gt_path).get_fdata() == class_label, 1, 0)

# Image/mask pair as prepared by read_im_gt; the mask `gt` is what the prompt
# generators consume (possibly reoriented — TODO confirm in utils.image).
img, gt = read_im_gt(img_path, gt_path, class_label)

# Compute the image embeddings once; later predict() calls reuse them
# ("Using previously generated image embeddings").
sammed2d_inferer.set_image(img_path)


INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.13 (you have 1.4.8). Upgrade using: pip install --upgrade albumentations


True
*******load /home/t722s/Desktop/UniversalModels/TrainedModels/sam-med2d_b.pth


In [2]:
# Experiment: n_clicks points sampled at random from the foreground, with a
# fixed seed for reproducibility (n_clicks = 1 was also tried).
seed = 11121
n_clicks = 5

point_prompt = prUt.get_pos_clicks2D_row_major(gt, n_clicks, seed=seed)
segmentation = sammed2d_inferer.predict(point_prompt).get_fdata()
anUt.compute_dice(segmentation, gt_unprocessed)

Performing inference on slices: 100%|██████████| 45/45 [00:02<00:00, 18.05it/s]


0.9141744000992084

In [4]:
# Experiment: the minimal 2D bounding box on every slice that has foreground.
# The two trailing zeros are passed positionally — presumably a zero box
# margin/perturbation (TODO confirm against prompt.get_minimal_boxes_row_major).
box_prompt = prUt.get_minimal_boxes_row_major(gt, 0, 0)

segmentation = sammed2d_inferer.predict(box_prompt).get_fdata()
anUt.compute_dice(segmentation, gt_unprocessed)

Using previously generated image embeddings


Performing inference on slices: 100%|██████████| 45/45 [00:00<00:00, 161.71it/s]


0.9479291027642575

In [5]:
# Experiment: one 3D bounding box around the foreground, cut into per-slice 2D
# boxes and fed to the 2D model axially.
box_prompt = prUt.get_bbox3d_sliced(gt)

segmentation = sammed2d_inferer.predict(box_prompt).get_fdata()
anUt.compute_dice(segmentation, gt_unprocessed)

Using previously generated image embeddings


Performing inference on slices: 100%|██████████| 45/45 [00:00<00:00, 159.22it/s]


0.7618183641651699

In [6]:
# Experiment: line interpolation — point prompts derived from `n_slices` slices
# and interpolated with the chosen scheme (exact semantics live in
# prompt.point_interpolation — TODO confirm).
n_slices = 5
interpolation = 'linear'

point_prompt = prUt.point_interpolation(gt, n_slices, interpolation)
segmentation = sammed2d_inferer.predict(point_prompt).get_fdata()
anUt.compute_dice(segmentation, gt_unprocessed)

Using previously generated image embeddings


Performing inference on slices: 100%|██████████| 45/45 [00:00<00:00, 151.37it/s]


0.9110163505559189

In [4]:
# Dev helper: pick up edits to utils/class_SAMMed2D.py without restarting the
# kernel by reloading the module and rebuilding the inferer.
import gc
from importlib import reload

import torch
import utils.class_SAMMed2D as c

# Drop the reference to the old inferer first: empty_cache() can only return
# memory that is no longer referenced. Rebinding to None (instead of `del`)
# keeps this cell re-runnable.
sammed2d_inferer = None
gc.collect()
torch.cuda.empty_cache()

reload(c)
# NOTE: `SAMMed2DInferer` imported in the first cell still points at the
# pre-reload class; use `c.SAMMed2DInferer` after reloading.
sammed2d_inferer = c.SAMMed2DInferer(checkpoint_path, device)

# A fresh instance has no image embeddings; later cells call predict()
# directly, so the image has to be set again.
sammed2d_inferer.set_image(img_path);

True
*******load /home/t722s/Desktop/UniversalModels/TrainedModels/sam-med2d_b.pth


In [7]:
# Iteratively improve the line-interpolation segmentation with simulated
# corrective interactions (scribble/contour parameters below are forwarded to
# utils.interactivity.iterate_2d).
n_slices = 5
interpolation = 'linear'
point_prompt = prUt.point_interpolation(gt, n_slices, interpolation)

# Keep the low-resolution logits as well; iterate_2d receives them together
# with the initial segmentation.
segmentation, low_res_logits = sammed2d_inferer.predict(point_prompt, return_low_res_logits=True, transform=False)

initial_prompt = point_prompt
condition = 'dof'   # not passed to iterate_2d; both bounds are given explicitly
dof_bound = 90
perf_bound = 0.85
# Pass the named bounds (previously perf_bound was shadowed by a literal, so
# editing the variable had no effect).
dice_scores, dof, segmentations, prompts = iterate_2d(
    sammed2d_inferer, gt, segmentation, low_res_logits, initial_prompt, sammed2d_inferer.pass_prev_prompts,
    scribble_length=0.6, contour_distance=3, disk_size_range=(0, 3),
    init_dof=5, perf_bound=perf_bound, dof_bound=dof_bound, seed=seed, detailed=True,
)

Using previously generated image embeddings


Performing inference on slices: 100%|██████████| 45/45 [00:00<00:00, 143.35it/s]


0.9110163505559189
0.9073761479256036
0.9101709423880054
0.91098359026401
0.9119163930907016
0.9207136409493973
0.9136357753041677
0.9142206464077366
0.9141466468086943
0.9141327750227983


In [9]:
# Experiment: box interpolation — `n_boxes` seed boxes are taken from the mask
# and interpolated into a box prompt (presumably for the slices in between —
# TODO confirm in prompt.box_interpolation).
n_boxes = 5

seed_boxes = prUt.get_seed_boxes(gt, n_boxes)
box_prompt = prUt.box_interpolation(seed_boxes)

segmentation = sammed2d_inferer.predict(box_prompt).get_fdata()
anUt.compute_dice(segmentation, gt_unprocessed)

Using previously generated image embeddings


Performing inference on slices: 100%|██████████| 45/45 [00:00<00:00, 136.89it/s]


0.9463246814431747

In [12]:
# Experiment: point propagation — start from a seed point prompt and let
# prompt.point_propagation carry it through the foreground slices (its
# progress bars report a downward and an upward pass).
seed = 11121
n_clicks = 5

# Seed prompt, plus indices along axis 0 of every slice containing foreground.
seed_point = prUt.get_seed_point(gt, n_clicks, seed)
slices_to_infer = np.flatnonzero(np.any(gt, axis=(1, 2)))

segmentation, prompt = prUt.point_propagation(sammed2d_inferer, seed_point, slices_to_infer, seed, n_clicks)
# Compared against the processed mask `gt` (not gt_unprocessed): the result is
# not converted via get_fdata() here — presumably already in gt's layout.
print(anUt.compute_dice(gt, segmentation))


Propagating down: 100%|██████████| 22/22 [00:00<00:00, 91.94it/s] 
Propagating up: 100%|██████████| 22/22 [00:00<00:00, 107.68it/s]


0.834450084939546


In [20]:
# Iteratively improve the point-propagation segmentation with simulated
# corrective interactions.
seed = 11121
n_clicks = 5

# Seed prompt, plus indices along axis 0 of every slice containing foreground.
seed_point = prUt.get_seed_point(gt, n_clicks, seed)
slices_to_infer = np.where(np.any(gt, axis=(1, 2)))[0]

segmentation, low_res_logits, prompt = prUt.point_propagation(
    sammed2d_inferer, seed_point, slices_to_infer, seed, n_clicks, return_low_res_logits=True
)

initial_prompt = prompt
condition = 'dof'   # not passed to iterate_2d; both bounds are given explicitly
dof_bound = 90
perf_bound = 0.85
# Call the directly imported iterate_2d (the former `i.iterate_2d` relied on a
# module alias that no cell defines, so it failed on a fresh kernel).
# Sixth argument: False instead of sammed2d_inferer.pass_prev_prompts —
# presumably disables passing earlier prompts back to the model (TODO confirm).
dice_scores, dof, segmentations, prompts = iterate_2d(
    sammed2d_inferer, gt, segmentation, low_res_logits, initial_prompt, False,
    scribble_length=0.6, contour_distance=3, disk_size_range=(0, 3),
    init_dof=5, perf_bound=perf_bound, dof_bound=dof_bound, seed=seed, detailed=True,
)

Propagating down: 100%|██████████| 22/22 [00:00<00:00, 96.04it/s]
Propagating up: 100%|██████████| 22/22 [00:00<00:00, 100.16it/s]


0.834450084939546
0.8617375067128478
0.8637177066134087
0.8820878916181261
0.887469382362208
0.8538703613970255
0.8190438980827729
0.8085717529518619
0.8086292543075307
0.8011691820383827


In [22]:
# Experiment: box propagation — a single seed box is carried by
# prompt.box_propagation through every slice that contains foreground.
seed_box = prUt.get_seed_box(gt)
slices_to_infer = np.flatnonzero(np.any(gt, axis=(1, 2)))

segmentation, prompt = prUt.box_propagation(sammed2d_inferer, seed_box, slices_to_infer)
print(anUt.compute_dice(gt, segmentation))

Propagating down: 100%|██████████| 22/22 [00:00<00:00, 127.64it/s]
Propagating up: 100%|██████████| 22/22 [00:00<00:00, 75.59it/s]


0.94016673002479
