In [1]:
from classes.SAMMed2DClass_unstable import SAMMed2DInferer
from utils.base_classes import Points

import utils.analysis as anUt
import utils.prompt as prUt
from utils.image import read_im_gt

import numpy as np

# Obtain model, image, gt
device = 'cuda'
sammed2d_checkpoint_path = "/home/t722s/Desktop/UniversalModels/TrainedModels/sam-med2d_b.pth"
sammed2d_inferer = SAMMed2DInferer(sammed2d_checkpoint_path, device)

img_path = '/home/t722s/Desktop/Datasets/Dataset350_AbdomenAtlasJHU_2img/imagesTr/BDMAP_00000001_0000.nii.gz'
gt_path = '/home/t722s/Desktop/Datasets/Dataset350_AbdomenAtlasJHU_2img/labelsTr/BDMAP_00000001.nii.gz'
img, gt = read_im_gt(img_path, gt_path, 3)


INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.10 (you have 1.4.8). Upgrade using: pip install --upgrade albumentations


True
*******load /home/t722s/Desktop/UniversalModels/TrainedModels/sam-med2d_b.pth


In [2]:
# Experiment: n randomly sampled points from foreground
seed = 11121
n_clicks = 5
point_prompt = prUt.get_pos_clicks2D_row_major(gt, n_clicks, seed = seed)
segmentation = sammed2d_inferer.predict(img, point_prompt)
anUt.compute_dice(segmentation, gt)

Performing inference on slices: 100%|██████████| 45/45 [00:02<00:00, 18.22it/s]


0.9143276197755548

In [5]:
# Iteratively improve
from utils.interactivity import iterate_2d
initial_prompt = point_prompt
condition = 'dof'
dof_bound = 90
seed_sub = np.random.randint(10**5)
segmentation, dof, segmentations, prompts, max_fp_idxs = iterate_2d(sammed2d_inferer, img, gt, segmentation, initial_prompt, 
                                                                         condition = 'dof', init_dof = 5, dof_bound = dof_bound, seed = seed_sub, detailed = True)

0.9163881373017136
Using previously generated image embeddings


Performing inference on slices: 100%|██████████| 6/6 [00:00<00:00, 133.65it/s]


0.9160622758587516
Using previously generated image embeddings


Performing inference on slices: 100%|██████████| 6/6 [00:00<00:00, 152.03it/s]


0.9165641151858233
Using previously generated image embeddings


Performing inference on slices: 100%|██████████| 6/6 [00:00<00:00, 155.12it/s]


0.9194610507487476
Using previously generated image embeddings


Performing inference on slices: 100%|██████████| 3/3 [00:00<00:00, 151.53it/s]


0.9199503386245078
Using previously generated image embeddings


Performing inference on slices: 100%|██████████| 5/5 [00:00<00:00, 131.33it/s]


0.9202966253263359
Using previously generated image embeddings


Performing inference on slices: 100%|██████████| 2/2 [00:00<00:00, 137.61it/s]


0.9230876507914372
Using previously generated image embeddings


Performing inference on slices: 100%|██████████| 7/7 [00:00<00:00, 118.54it/s]


0.9227941374065495
degrees of freedom bound met; terminating with performance 0.9227941374065495


In [3]:
# Experiment: 2d bounding box per slice with foreground
box_prompt = prUt.get_minimal_boxes_row_major(gt, 3, 3)
segmentation = sammed2d_inferer.predict(img, box_prompt)
anUt.compute_dice(segmentation, gt)

Using previously generated image embeddings


Performing inference on slices: 100%|██████████| 45/45 [00:00<00:00, 152.53it/s]


0.9396739793941389

In [4]:
# Experiment: get 3d bbox and slice it to feed it in 2d axially
box_prompt = prUt.get_bbox3d_sliced(gt)

segmentation = sammed2d_inferer.predict(img, box_prompt)
anUt.compute_dice(segmentation, gt)

Using previously generated image embeddings


Performing inference on slices: 100%|██████████| 45/45 [00:00<00:00, 178.06it/s]


0.7636174284813798

In [7]:
# Experiment: line interpolation
n_slices = 5
interpolation = 'linear'
simulated_clicks = prUt.get_fg_points_from_cc_centers(gt, n_slices)
coords = prUt.interpolate_points(simulated_clicks, kind = interpolation).astype(int)

def line_interpolation(gt, n_slices):
    simulated_clicks = prUt.get_fg_points_from_cc_centers(gt, n_slices)
    coords = prUt.interpolate_points(simulated_clicks, kind = interpolation).astype(int)
    point_prompt = Points(coords = coords, labels = [1]*len(coords))
    return(point_prompt)


segmentation = sammed2d_inferer.predict(img, point_prompt)
anUt.compute_dice(segmentation, gt)

Using previously generated image embeddings


Performing inference on slices: 100%|██████████| 45/45 [00:00<00:00, 156.65it/s]


0.9143276197755548

In [2]:
# Experiment: box interpolation
n_boxes = 5

seed_boxes = prUt.get_seed_boxes(gt, n_boxes)
box_prompt = prUt.box_interpolation(seed_boxes)
segmentation = sammed2d_inferer.predict(img, box_prompt)
anUt.compute_dice(segmentation, gt)

Performing inference on slices: 100%|██████████| 45/45 [00:01<00:00, 22.83it/s]


0.9463972842750766

In [14]:
# Experiment: Point propagation

seed = 11121
n_clicks = 5

# Get seed prompt and bounds
seed_point = prUt.get_seed_point(gt, n_clicks, seed)
slices_to_infer = np.where(np.any(gt, axis=(1,2)))[0]

segmentation = prUt.point_propagation(sammed2d_inferer, img, seed_point, slices_to_infer, seed, n_clicks)
print(anUt.compute_dice(gt,segmentation))


Propagating down: 100%|██████████| 22/22 [00:11<00:00,  1.99it/s]
Propagating up: 100%|██████████| 22/22 [00:10<00:00,  2.04it/s]


0.7984751851037314


In [2]:
# Experiment: Box propagation

seed_box = prUt.get_seed_box(gt)
slices_to_infer = np.where(np.any(gt, axis=(1,2)))[0]

segmentation = prUt.box_propagation(sammed2d_inferer, img, seed_box, slices_to_infer)

print(anUt.compute_dice(gt,segmentation))

Propagating down: 100%|██████████| 48/48 [00:02<00:00, 22.50it/s]
Terminate early: no fg generated
Propagating up:  71%|███████   | 34/48 [00:01<00:00, 22.79it/s]


0.4829657225423463


In [10]:
# Experiment: interaction to performance, start with points per slice

seed = 11121
n_seed_clicks = 1 # For initial segmentation
fix_worst_slice = False

target_performance = 0.95

## Generate initial segmentation to improve
np.random.seed(seed)
point_prompt = prUt.get_pos_clicks2D_row_major(gt, n_seed_clicks, seed = seed)
segmentation = sammed2d_inferer.predict(img, point_prompt)

initial_dof = len(point_prompt.coords)*2+2
dice, revised_slices, dof, segmentations = prUt.iter_improve_perf_sammed2d(img, gt, segmentation, sammed2d_inferer, 
                                                                           point_prompt, initial_dof = initial_dof, target_performance = target_performance,
                                                                           fix_worst_slice = fix_worst_slice, seed = None)
dice


Using previously generated image embeddings


Performing inference on slices: 100%|██████████| 45/45 [00:00<00:00, 139.74it/s]


Improving 144
Improving 158
Improving 158
Improving 168
Improving 166
Improving 141
Improving 135
Improving 137
Improving 127
Improving 143
Improving 164
Improving 125
Improving 155
Improving 164
Improving 154
Improving 151
Improving 136
Improving 166
Improving 132
Target performance not achieved within 20 iterations. Final dice 0.9269


[0.9245013037118691,
 0.9245615987535573,
 0.9248437889202754,
 0.9249542985825937,
 0.9252582647365938,
 0.9251518266897012,
 0.9251948685068118,
 0.9253897373146662,
 0.9253781939912709,
 0.9254822205980489,
 0.9256681602949122,
 0.9257787666878576,
 0.9259259259259259,
 0.9261276920468035,
 0.9261940599096903,
 0.9263074233031314,
 0.9265982218458934,
 0.9265966680073668,
 0.9269000275289584,
 0.9269413402645802]