In [30]:
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])

    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

In [162]:
# Get initial segmentation
from classes.SAMMed2DClass_unstable import SAMMed2DInferer
from utils.base_classes import Points
import torch

import utils.analysisUtils as anUt
import utils.promptUtils as prUt
from utils.imageUtils import read_im_gt

import numpy as np
import pickle
import matplotlib.pyplot as plt
from utils.interactivity import gen_contour_fp_scribble

# Obtain model, image, gt
device = 'cuda'
sammed2d_checkpoint_path = "/home/t722s/Desktop/UniversalModels/TrainedModels/sam-med2d_b.pth"
sammed2d_inferer = SAMMed2DInferer(sammed2d_checkpoint_path, device)

img_path = '/home/t722s/Desktop/Datasets/Dataset350_AbdomenAtlasJHU_2img/imagesTr/BDMAP_00000001_0000.nii.gz'
gt_path = '/home/t722s/Desktop/Datasets/Dataset350_AbdomenAtlasJHU_2img/labelsTr/BDMAP_00000001.nii.gz'
img, gt = read_im_gt(img_path, gt_path, 3)

# Obtain initial segmentation
# Experiment: Point propagation

seed = 11121
n_clicks = 5

# Get seed prompt and bounds
seed_point = prUt.get_seed_point(gt, n_clicks, seed)
slices_to_infer = np.where(np.any(gt, axis=(1,2)))[0]

segmentation, all_prompts = prUt.point_propagation(sammed2d_inferer, img, seed_point, slices_to_infer, seed, n_clicks)
print(anUt.compute_dice(gt,segmentation))

anUt.compute_dice(segmentation, gt)


with open('/home/t722s/Desktop/test/test_segjhu10.pkl', 'wb') as f:
    pickle.dump(segmentation, f)

with open('/home/t722s/Desktop/test/test_segjhu10.pkl', 'rb') as f:
    segmentation = pickle.load(f)
segmentation, prompt = prUt.point_propagation(sammed2d_inferer, img, seed_point, slices_to_infer, seed, n_clicks)
condition = 'dof'
dof_bound = 60
perf_bound = 1 # Place holder, only needed when condition = 'perf'
init_dof = 5
contour_distance = 2
disk_size_range = (0,0)
scribble_length = 0.3


# Initialise low res masks to provide for interactivity
verbosity = sammed2d_inferer.verbose 
sammed2d_inferer.verbose = False
dof = init_dof
slices_inferred = np.unique(prompt.value['coords'][:,0])
low_res_masks = sammed2d_inferer.slice_lowres_dict.copy()
low_res_masks = {k:torch.sigmoid(v).squeeze().cpu().numpy() for k,v in low_res_masks.items()}
max_fp_slices = []
has_generated_positive_prompt = False

prompts = [prompt]
segmentations = [segmentation]
dice_scores = [prUt.compute_dice(segmentation, gt)]


while True:
    # Determine whether to give positive prompts or attempt negative prompt
    fn_mask = (segmentation == 0) & (gt == 1)
    fn_count = np.sum(fn_mask)

    fg_count = np.sum(segmentation)

    generate_positive_prompts_prob = fn_count/fg_count # Generate positive prompts when much of the foreground isn't segmented
    generate_positive_prompts = np.random.binomial(1,generate_positive_prompts_prob)

    if not generate_positive_prompts:
        # Obtain contour scribble on worst sagittal slice
        fp_mask = (segmentation == 1) & (gt == 0)
        axis = 1
        fp_sums = np.sum(fp_mask, axis=tuple({0,1,2} - {axis}))
        max_fp_idx = np.argmax(fp_sums)
        max_fp_slice = gt[:, max_fp_idx]
        max_fp_slices.append(max_fp_slice)
        slice_seg = segmentation[:, max_fp_idx]

        scribble = gen_contour_fp_scribble(gt[:, max_fp_idx], slice_seg, contour_distance, disk_size_range, scribble_length)
        if scribble is None:
            generate_positive_prompts = 1 # Generate positive prompts instead
        else: 
            scribble_coords = np.where(scribble)
            scribble_coords = np.stack(scribble_coords, axis = 1)

            # Obtain false positive points and make new prompt
            is_fp_mask = slice_seg[*scribble_coords.T].astype(bool)
            fp_coords = scribble_coords[is_fp_mask]

            ## Position fp_coords back into original 3d coordinate system
            missing_axis = np.repeat(max_fp_idx, len(fp_coords))
            fp_coords_3d = np.vstack([fp_coords[:,0], missing_axis, fp_coords[:,1]]).T
            improve_slices = fp_coords_3d[:,0]
            dof += 3*4 # To dicuss: assume drawing a scribble is as difficult as drawing four points

            ## Add to old prompt
            coords = np.concatenate([prompt.value['coords'], fp_coords_3d], axis = 0)
            labels = np.concatenate([prompt.value['labels'], [0]*len(fp_coords_3d)])
            prompt = Points(value = {'coords': coords, 'labels': labels})

            ## Subset to prompts only on the slices with new prompts
            coords, labels = prompt.value.values()
            fix_slice_mask = np.isin(prompt.value['coords'][:,0], improve_slices)
            new_prompt = Points({'coords': coords[fix_slice_mask], 'labels': labels[fix_slice_mask]})

    if generate_positive_prompts:
        if not has_generated_positive_prompt: 
            dof+=6 # If first time generating positive prompts, generate a bottom and top point, taking 4 degrees of freedom: (4 dof even though there are 6 coordinates since the coordinate of the lowest and highest slice is fixed) 
            bottom_seed_prompt, _, top_seed_prompt = prUt.get_fg_points_from_cc_centers(gt, 3)
            has_generated_positive_prompt = True

        # Find fp coord from the middle axial range of the image
        lower, upper = np.percentile(slices_inferred, [30, 70 ])
        fp_coords = np.vstack(np.where(fn_mask)).T
        middle_mask = (lower < fp_coords[:, 0]) & (fp_coords[:,0] < upper) # Mask to determine which false negatives lie between the 30th to 70th percentile
        if np.sum(middle_mask) == 0:
            middle_mask = np.ones(len(fp_coords), bool) # If there are no false negatives in the middle, draw from all coordinates (unlikely given that there must be many)
        fp_coords = fp_coords[middle_mask, :]
        new_middle_seed_prompt = fp_coords[np.random.choice(len(fp_coords), 1)]
        dof += 3

        # Obtain top and bottom prompts and then interpolate a line of coordinates in between
        new_seed_prompt = np.vstack([bottom_seed_prompt, new_middle_seed_prompt, top_seed_prompt])
        new_coords =  prUt.interpolate_points(new_seed_prompt, kind = 'linear').astype(int)

        # Add to old prompt
        coords = np.concatenate([prompt.value['coords'], new_coords], axis = 0)
        labels = np.concatenate([prompt.value['labels'], [1]*len(new_coords)])
        new_prompt = Points(value = {'coords': coords, 'labels': labels})
        improve_slices = slices_inferred # improve all slices

    # Generate new segmentation and integrate into old one
    new_seg = sammed2d_inferer.predict(img, new_prompt)
    prompts.append(new_prompt)
    segmentation[improve_slices] = new_seg[improve_slices]
    segmentations.append(segmentation.copy())
    # Update the dictionary
    low_res_masks.update({fix_slice_idx: torch.sigmoid(sammed2d_inferer.slice_lowres_dict[fix_slice_idx]).squeeze().cpu().numpy() for fix_slice_idx in improve_slices})
    dice_scores = [prUt.compute_dice(segmentation, gt)]
    print(dice_scores[-1])

    # Check break conditions
    if condition == 'dof' and dof >= dof_bound:
        print(f'degrees of freedom bound met; terminating with performance {dice_scores[-1]}')
        break
    elif condition == 'perf' and dice_scores[-1] >= perf_bound:
        print(f'performance bound met; terminating with performance {dice_scores[-1]}')
        break
    elif condition == 'perf' and len(dice_scores) == 10:
        print(f'Could not achieve desired performance within 10 steps; terminating with performance {dice_scores[-1]}')

sammed2d_inferer.verbose = verbosity # return verbosity to initial state

True
*******load /home/t722s/Desktop/UniversalModels/TrainedModels/sam-med2d_b.pth


Propagating down: 100%|██████████| 22/22 [00:01<00:00, 19.16it/s]
Propagating up: 100%|██████████| 22/22 [00:01<00:00, 20.31it/s]


0.8177772392806715


Propagating down: 100%|██████████| 22/22 [00:00<00:00, 73.20it/s]
Propagating up: 100%|██████████| 22/22 [00:00<00:00, 73.27it/s]


0.8263593581674888
0.8269173423455405
0.8302757994359734
0.8294789514803222
0.8293258507806317
degrees of freedom bound met; terminating with performance 0.8293258507806317


In [168]:
coords = [p.value['coords'] for p in prompts]

def find_new_rows(coords):
    """
    Generates a list of arrays, where each array contains rows from t[i] 
    that are not present in t[i-1].

    Args:
    t (list of np.array): A list of m nx3 numpy arrays.

    Returns:
    list of np.array: A list of length m-1, where each entry contains rows
                      from t[i] that are not in t[i-1].
    """
    diffs = []
    for i in range(1, len(coords)):
        # Convert arrays to sets of tuples for comparison
        set_prev = set(map(tuple, coords[i-1]))
        set_curr = set(map(tuple, coords[i]))
        
        # Find difference and convert it back to an array
        new_rows = np.array(list(set_curr - set_prev))
        diffs.append(new_rows)
    
    return diffs

t = find_new_rows(coords)
t[0]


array([[144, 145, 178],
       [148, 145, 173],
       [147, 145, 174],
       [146, 145, 167],
       [148, 145, 170],
       [143, 145, 162],
       [146, 145, 176],
       [146, 145, 182],
       [145, 145, 164],
       [145, 145, 176],
       [148, 145, 172],
       [144, 145, 180],
       [144, 145, 177],
       [147, 145, 170],
       [146, 145, 166],
       [146, 145, 169],
       [146, 145, 175],
       [146, 145, 181],
       [148, 145, 171],
       [144, 145, 179],
       [146, 145, 165],
       [146, 145, 168],
       [146, 145, 174],
       [146, 145, 183],
       [144, 145, 163],
       [145, 145, 180]])

In [None]:
# Visualise
plt.figure(figsize = (13,13))
plt.imshow(max_fp_slice, cmap = 'gray')
show_mask(slice_seg, plt.gca())
# contour_rounded = np.round(contour).astype(int)
# plt.scatter(contour_rounded[:,1], contour_rounded[:,0], c= 'green', s = 10)
plt.scatter(scribble_coords[:,1], scribble_coords[:,0], c = 'red', s = 2)

In [None]:
for seg in segmentations:
    # Visualise
    plt.figure(figsize = (13,13))
    plt.imshow(max_fp_slice, cmap = 'gray')
    show_mask(slice_seg, plt.gca())
    # contour_rounded = np.round(contour).astype(int)
    # plt.scatter(contour_rounded[:,1], contour_rounded[:,0], c= 'green', s = 10)
    plt.scatter(scribble_coords[:,1], scribble_coords[:,0], c = 'red', s = 2)