In [4]:
%matplotlib inline
import matplotlib.pyplot as plt

import sys
import numpy as np
import torch
import cv2

sys.path.insert(0, '..')
from isegm.utils import vis, exp

from isegm.inference import utils
from isegm.inference.predictors import get_predictor
from isegm.inference.clicker import Clicker, Click

device = torch.device('cuda:0')
cfg = exp.load_config_file('../config.yml', return_edict=True)

In [5]:
# Possible choices: 'GrabCut', 'Berkeley', 'DAVIS', 'COCO_MVal', 'SBD', 'BRATS', 'LIDC', 'LIDC_val'
DATASET = 'LIDC_2D_VAL'
dataset = utils.get_dataset(DATASET, cfg)

EVAL_MAX_CLICKS = 20
MODEL_THRESH = 0.49

# checkpoint_path = utils.find_checkpoint(cfg.INTERACTIVE_MODELS_PATH, 'coco_lvis_h18s_itermask')

checkpoint_path = utils.find_checkpoint(cfg.INTERACTIVE_MODELS_bPATH, 'fullres_20')
model = utils.load_is_model(checkpoint_path, device)

# Possible choices: 'NoBRS', 'f-BRS-A', 'f-BRS-B', 'f-BRS-C', 'RGB-BRS', 'DistMap-BRS', 'FocalClick'
brs_mode = 'FocalClick'
predictor = get_predictor(model, brs_mode, device, prob_thresh=MODEL_THRESH)

AssertionError: 

In [6]:
cfg

{'INTERACTIVE_MODELS_PATH': './weights',
 'EXPS_PATH': './experiments',
 'GRABCUT_PATH': '/home/admin/workspace/project/data/datasets/GrabCut',
 'BERKELEY_PATH': '/home/admin/workspace/project/data/datasets/Berkeley',
 'DAVIS_PATH': '/home/admin/workspace/project/data/datasets/DAVIS',
 'COCO_MVAL_PATH': '/home/admin/workspace/project/data/datasets/COCO_MVal',
 'PASCALVOC_PATH': '/home/admin/workspace/project/data/datasets/VOC2012/VOCdevkit/VOC2012',
 'DAVIS585_PATH': '/home/admin/workspace/project/data/datasets/InterDavis/Selected_480P',
 'BRATS_VAL_PATH': '/gpfs/space/projects/PerkinElmer/donatasv_experiments/datasets/processed_datasets/BraTS_2d/val',
 'LIDC_VAL_PATH': '/gpfs/space/projects/PerkinElmer/donatasv_experiments/datasets/processed_datasets/LIDC-2D/val',
 'LIDC_256_VAL_PATH': '/gpfs/space/projects/PerkinElmer/donatasv_experiments/datasets/processed_datasets/LIDC-2D-256/val',
 'SBD_PATH': '/home/admin/workspace/project/data/datasets/SBD/dataset',
 'COCO_PATH': '/home/xizhi.cx

In [3]:
sample_id = 3
TARGET_IOU = 0.95

sample = dataset.get_sample(sample_id)

gt_mask = sample.gt_mask
image = sample.image

# clicks_list, ious_arr, pred, pred_list = evaluate_sample(image, gt_mask, np.zeros_like(gt_mask), predictor, 
#                                               pred_thr=MODEL_THRESH, 
#                                               max_iou_thr=TARGET_IOU, max_clicks=EVAL_MAX_CLICKS,
#                                               vis=False)
# pred_mask = pred > MODEL_THRESH

# draw = vis.draw_with_blend_and_clicks(image[:, :, :3], mask=pred_mask, clicks_list=clicks_list)
# draw = np.concatenate((draw,
#     255 * pred_mask[:, :, np.newaxis].repeat(3, axis=2),
#     255 * (gt_mask > 0)[:, :, np.newaxis].repeat(3, axis=2)
# ), axis=1)

# print(ious_arr)

# plt.figure(figsize=(20, 30))
# plt.imshow(draw)
# plt.show()

In [7]:
def Progressive_Merge(pred_mask, previous_mask, y, x):
    diff_regions = np.logical_xor(previous_mask, pred_mask)
    num, labels = cv2.connectedComponents(diff_regions.astype(np.uint8))
    label = labels[y,x]
    corr_mask = labels == label
    if previous_mask[y,x] == 1:
        progressive_mask = np.logical_and( previous_mask, np.logical_not(corr_mask))
    else:
        progressive_mask = np.logical_or( previous_mask, corr_mask)
    return progressive_mask

In [27]:
clicker = Clicker(gt_mask)

In [31]:
pred_mask = np.zeros_like(gt_mask)
prev_mask = pred_mask
init_mask = pred_mask
ious_list = []
pred_mask_list = []
clicks_list = []
progressive_mode = True
callback = None

with torch.no_grad():
    predictor.set_input_image(image)
    if init_mask is not None:
        predictor.set_prev_mask(init_mask)
        pred_mask = init_mask
        prev_mask = init_mask
        num_pm = 0
    else:
        num_pm = 999
    vis_pred = prev_mask
    
    # clicker.make_next_click(pred_mask)
    pred_probs = predictor.get_prediction(clicker)
    pred_mask = pred_probs > MODEL_THRESH
    if progressive_mode:
        clicks = clicker.get_clicks()
        if len(clicks) >= num_pm: 
            last_click = clicks[-1]
            last_y, last_x = last_click.coords[0], last_click.coords[1]
            pred_mask = Progressive_Merge(pred_mask, prev_mask,last_y, last_x)
            predictor.transforms[0]._prev_probs = np.expand_dims(np.expand_dims(pred_mask,0),0)

    pred_mask_list.append(pred_mask)
    iou = utils.get_iou(gt_mask, pred_mask)
    ious_list.append(iou)
    prev_mask = pred_mask
    
fig, ax = plt.subplots(1, 2, figsize=(10, 5))

ax[0].imshow(image[:, :, 0], cmap='gray')
y = [click.coords[0] for click in clicker.clicks_list]
x = [click.coords[1] for click in clicker.clicks_list]
c = ['lime' if click.is_positive else 'r' for click in clicker.clicks_list]
ax[0].scatter(x, y, c=c, s=5)
ax[0].set_title('Clicks')
ax[0].contour(gt_mask, colors='blue', alpha=0.5, linewidths=0.5)
ax[1].set_title(f'IOU: {iou:.3f}')
ax[1].imshow(pred_mask)

IndexError: list index out of range