In [1]:
import numpy as np
%matplotlib notebook
import matplotlib.pyplot as plt

In [2]:
import scipy.ndimage.morphology as morph
import skimage
import skimage.segmentation
from scipy import ndimage
from sklearn.neighbors import KNeighborsClassifier

In [3]:
from tqdm import tqdm
import os
import tarfile

import shutil

if not os.path.isdir("/scratch/yardima/"):
    os.mkdir("/scratch/yardima/")
    
if not os.path.isdir("/scratch/yardima/embed_temp"):
    os.mkdir("/scratch/yardima/embed_temp")

src = "/srv/glusterfs/yardima/embed_temp"
src_files = os.listdir(src)
dest_folder = "/scratch/yardima/embed_temp"

for file_name in src_files:
    full_file_name = os.path.join(src, file_name)
    if (os.path.isfile(full_file_name)):
        shutil.copy(full_file_name, dest_folder)

KeyboardInterrupt: 

# Old simulations

In [4]:
def _fast_hist(label_true, label_pred, n_class):
        mask = (label_true >= 0) & (label_true < n_class)
        hist = np.bincount(
            n_class * label_true[mask].astype(int) + label_pred[mask],
            minlength=n_class ** 2,
        ).reshape(n_class, n_class)
        return hist

In [5]:
import skimage.segmentation

def simulate_clicks_v2_e(embeddings, label, n_clicks, *, n_channels=50, d_clicks=40, d_margin=20):
    
    def run_clicks(embeddings, labels, clicks, flat=True):
        dim1, dim2 = labels.squeeze().shape

        if not flat:
            clicks = np.array([c[0] * dim2 + c[1] for c in clicks], dtype=np.int32)
            
            embeddings = embeddings.reshape(dim1 * dim2, -1)
            labels = labels.reshape(-1)

        knn = KNeighborsClassifier(n_neighbors=1)
        knn.fit(embeddings[clicks,:], labels[clicks])

        pred = knn.predict(embeddings)

        cf_matrix = _fast_hist(labels[valid], pred[valid], 2)

        iou = cf_matrix[1,1] / (cf_matrix[0,1] + cf_matrix[1,1] + cf_matrix[1,0])
        acc = np.sum(labels[valid] == pred[valid]) / len(valid)

        return {"pred":pred.reshape(dim1,dim2), 
                "cf":cf_matrix, 
                "iou":iou, 
                "acc":acc}
    
    
    def pick_click(mask):
        candidates = np.transpose(np.nonzero(mask))
        return candidates[np.random.randint(low=0, high=candidates.shape[0]), :]
    
    
#     if image.dim() == 3:
#         image = image.unsqueeze(0)
    
#     if label.dim() == 2:
#         label = label.unsqueeze(0)
        
#     with torch.no_grad():
#         image = image.cuda()
#         out = net.forward(image)
        
#         embeddings = out.cpu().data.numpy()
#         embeddings = np.transpose(embeddings.squeeze(), axes=[1, 2, 0]).reshape(-1, n_channels)
#         label = label.data.numpy()

    
    label = label.squeeze()
    border = label == 255
    valid = np.nonzero((label.reshape(-1) != 255))[0]
    border2 = morph.binary_dilation(border, iterations=d_margin)
    
    if len(np.unique(label * (1 - border2))) > 1:
        o = np.random.choice(np.sort(np.unique(label * (1 - border2)))[1:])
    else:
        return []
    
    label_obj = (label == o).astype(np.int32)
    
    border = border + morph.binary_dilation(skimage.segmentation.find_boundaries(label_obj), iterations=d_margin)
    
    if not np.any((1 - label_obj) * (1 - border)) or not np.any(label_obj * (1 - border)):
        return []

    click_map = np.zeros(shape=border.shape, dtype=np.int32)

    p1 = pick_click((label_obj) * (1 - border))
    n1 = pick_click((1 - label_obj) * (1 - border))

    click_map[max(p1[0]-d_clicks, 0):p1[0]+d_clicks, max(p1[1]-d_clicks,0):p1[1]+d_clicks] = 1
    click_map[max(n1[0]-d_clicks, 0):n1[0]+d_clicks, max(n1[1]-d_clicks,0):n1[1]+d_clicks] = 1

    clicks = [p1, n1]
    results = []
    results.append(run_clicks(embeddings, label_obj, clicks, flat=False))
    lbl_out = results[-1]['pred']
    err_pixels = np.logical_xor(lbl_out, label_obj)

    for i in range(2, n_clicks):
        
        if not np.any((1 - click_map) * (1 - border)):
            return results
        
        if np.any(err_pixels * (1 - click_map) * (1 - border)):
            c = pick_click(err_pixels * (1 - click_map) * (1 - border))
            click_map[max(c[0]-d_clicks, 0):c[0]+d_clicks, max(c[1]-d_clicks,0):c[1]+d_clicks] = 1
        else:
            c = pick_click((1 - click_map) * (1 - border))
            click_map[max(c[0]-d_clicks, 0):c[0]+d_clicks, max(c[1]-d_clicks,0):c[1]+d_clicks] = 1

        clicks.append(c)
        
        results.append(run_clicks(embeddings, label_obj, clicks, flat=False))
        
        lbl_out = results[-1]['pred']
        err_pixels = np.logical_xor(lbl_out, label_obj)
        
    return results

In [79]:
%%time
trials = 0

n_clicks = 20
mean_acc = np.zeros(n_clicks - 1)
mean_iou = np.zeros(n_clicks - 1)
clicks = 0.0

for i in range(50):
    embedding = np.load("/scratch/yardima/embed_temp/{}.embed.aug.npy".format(i))
    label = np.load("/scratch/yardima/embed_temp/{}.label.aug.npy".format(i))
    
    result = simulate_clicks_v2_e(embedding, label, n_clicks)

    if len(result) == n_clicks - 1:
        trials += 1

        mean_acc += np.array([r["acc"] for r in result])
        mean_iou += np.array([r["iou"] for r in result])
        
        iou = np.array([r["iou"] for r in result])
        t = iou > 0.9
        
        if np.any(t):
            clicks += np.argmax(t) + 2
            print(np.argmax(t) + 2)
        else:
            clicks += 20.0
            print(20)

        print(trials)

mean_acc /= trials
mean_iou /= trials
clicks /= trials

2
1
4
2
20
3
20
4
2
5
20
6
3
7
2
8
20
9
2
10
5
11
3
12
20
13
2
14
3
15
2
16
2
17
2
18
2
19
8
20
3
21
20
22
20
23
3
24
3
25
20
26
10
27
20
28
4
29
2
30
2
31
5
32
20
33
2
34
11
35
6
36
2
37
20
38
2
39
3
40
2
41
2
42
2
43
15
44
20
45
2
46
6
47
10
48
20
49
20
50
CPU times: user 4min 32s, sys: 2min 16s, total: 6min 48s
Wall time: 4min 32s


In [5]:
graphcut = [0, 0.41647, 0.52051, 0.61135, 0.59292 , 0.68515 , 0.73828 , 0.73773 , 0.75957 , 0.76526 , 0.80299 , 0.80789 , 0.83517 , 0.82191 , 0.85346  , 0.8519 , 0.86424 , 0.87155 , 0.87143 , 0.87907 , 0.89611]
geodesicmatting = [0 , 0.23718 , 0.38133  , 0.4326 , 0.55617  , 0.6602 , 0.63506 , 0.68842 , 0.73264  , 0.7581 , 0.79506 , 0.81119 , 0.81285 , 0.82494 , 0.85251 , 0.85958 , 0.86015 , 0.88313 , 0.88004  , 0.8849 , 0.87809]
randomwalker= [0 , 0.24353 , 0.37081 , 0.47507 , 0.56905 , 0.62045 , 0.65946 , 0.73991 , 0.78886 , 0.80488 , 0.81368 , 0.82177 , 0.83732 , 0.87131  , 0.8701  , 0.9013 , 0.88533 , 0.88293 , 0.91703 , 0.90852 , 0.91907]
euclideanstarconvexity = [0 , 0.52075  , 0.5707 , 0.65062 , 0.71197 , 0.77218  , 0.7821  , 0.8171 , 0.83639 , 0.88536 , 0.87501 , 0.87949 , 0.91244 , 0.91596 , 0.90658 , 0.93117  , 0.9282 , 0.90621 , 0.91999 , 0.94405 , 0.94914]
geodesicstarconvexity = [0 , 0.48836 , 0.56403 , 0.66126 , 0.68438 , 0.72558  , 0.7596 , 0.82642 , 0.85716 , 0.85727 , 0.87142 , 0.86295 , 0.91054 , 0.92736 , 0.91775 , 0.93614 , 0.93673 , 0.93886 , 0.94633 , 0.93902  , 0.9558]
growcut = [0 , 0.26565 , 0.36187 , 0.47439 , 0.53135 , 0.58615 , 0.64972 , 0.67848 , 0.68794 , 0.71877 , 0.74829 , 0.76001 , 0.77618 , 0.78652 , 0.80577 , 0.82243 , 0.83948 , 0.85783 , 0.83632 , 0.86645 , 0.85449]
deepios = [0 , 0.62889 , 0.74285 , 0.79842 , 0.83959 , 0.83684  , 0.8625 , 0.88998 , 0.88087 , 0.88309 , 0.91481 , 0.91363 , 0.92075 , 0.92631 , 0.91706 , 0.93541 , 0.93936 , 0.95019 , 0.93832 , 0.95318 , 0.94784]

In [81]:
print(clicks)
print(mean_iou[-1])

plt.figure()
plt.plot(range(2, 21),mean_iou, label="Ours", linewidth=3.0)

plt.plot(range(0, 21),graphcut, label="GraphCut")
plt.plot(range(0, 21),geodesicmatting, label="GraphCut")
plt.plot(range(0, 21),randomwalker, label="GeodesicMatting")
plt.plot(range(0, 21),euclideanstarconvexity, label="EuclideanSC")
plt.plot(range(0, 21),geodesicstarconvexity, label="GeodesicSC")
plt.plot(range(0, 21),growcut, label="GrowCut")
plt.plot(range(0, 21),deepios, label="DeepIOS")

plt.ylim(0,1)
plt.xlim(0,20)
plt.title("Click vs. IoU on 50 GrabCut images")
plt.xlabel("Clicks")
plt.ylabel("Mean IoU")
plt.yticks(np.arange(0.0, 1.0, 0.05))
plt.xticks(np.arange(0.0, 20.0, 1))
plt.legend()

8.42
0.9051275554719386


<IPython.core.display.Javascript object>

<matplotlib.legend.Legend at 0x2b7ebf9b63c8>

# New method

In [6]:
from scipy import ndimage

In [7]:
def run_clicks(embeddings, labels, clicks, valid, flat_indices=True):
    dim1, dim2 = labels.squeeze().shape

    if not flat_indices:
        clicks = np.array([c[0] * dim2 + c[1] for c in clicks], dtype=np.int32)
        
    embeddings = embeddings.reshape(dim1 * dim2, -1)
    labels = labels.reshape(-1)

    knn = KNeighborsClassifier(n_neighbors=1, n_jobs=1)
    knn.fit(embeddings[clicks,:], labels[clicks])

    pred = knn.predict(embeddings)

    cf_matrix = _fast_hist(labels[valid], pred[valid], 2)

    iou = cf_matrix[1,1] / (cf_matrix[0,1] + cf_matrix[1,1] + cf_matrix[1,0])
    acc = np.sum(labels[valid] == pred[valid]) / len(valid)

    return {"pred":pred.reshape(dim1,dim2), 
            "cf":cf_matrix, 
            "iou":iou, 
            "acc":acc}

In [9]:
i = 0
computer_borders_manually = False
d_margin = 15
d_click = 3

embedding = np.load("/scratch/yardima/embed_temp/pascal{}.embed.aug.npy".format(i)).squeeze()
label = np.load("/scratch/yardima/embed_temp/pascal{}.label.aug.npy".format(i)).squeeze()

In [98]:
# if len(np.unique(label * (1 - border2))) > 1:
## pick an object at random
o = np.random.choice(np.sort(np.unique(label[label != 255]))[1:])
label_obj = (label == o).astype(np.int32)

In [99]:
%%time
if computer_borders_manually:
    border = label == 255
    border = border + morph.binary_dilation(skimage.segmentation.find_boundaries(label_obj), iterations=d_margin)
else:
    if np.any(label == 255):
        temp = (label == 255) + skimage.segmentation.find_boundaries(label_obj)
        border = morph.binary_dilation(label_obj, iterations=6) & temp
        border = morph.binary_dilation(border, iterations=max(d_margin-3, 1))
    else:
        border = label == 255
        border = border + morph.binary_dilation(skimage.segmentation.find_boundaries(label_obj), iterations=d_margin)
        
valid = np.nonzero((label.reshape(-1) != 255))[0]

CPU times: user 32 ms, sys: 0 ns, total: 32 ms
Wall time: 31 ms


In [100]:
# inner_part = label_obj & np.logical_not(border)
# outer_part = np.logical_not(label_obj) & np.logical_not(border)
label_obj_temp = label_obj.copy()
label_obj_temp[:, 0] = 1
label_obj_temp[0, :] = 1
label_obj_temp[-1, :] = 1
label_obj_temp[:, -1] = 1

inner_part = label_obj
outer_part = np.logical_not(label_obj_temp)

In [101]:
plt.figure()
plt.imshow(border)

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x2b79b53baf60>

In [102]:
%%time
## pick initial points
dist_inner = ndimage.distance_transform_edt(inner_part)
dist_outer = ndimage.distance_transform_edt(outer_part)
c1 = np.random.choice(np.nonzero(dist_inner.flatten() == dist_inner.max())[0])
c2 = np.random.choice(np.nonzero(dist_outer.flatten() == dist_outer.max())[0])

CPU times: user 72 ms, sys: 0 ns, total: 72 ms
Wall time: 67.2 ms


In [103]:
click_map = np.zeros_like(label_obj)
click_map.ravel()[c1] = 1
click_map.ravel()[c2] = 1

In [104]:
valid_clicks = ndimage.distance_transform_edt(np.logical_not(click_map)) > d_click

In [105]:
plt.figure()
plt.imshow(valid_clicks)

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x2b7965bb4e10>

In [106]:
# plot initial clicks
def plot_point(i, img):
    h = img.shape[1]

    plt.scatter([i % h],[i // h])
    
plt.figure()
plt.imshow(label_obj)
plot_point(c1, label)
plot_point(c2, label)

<IPython.core.display.Javascript object>

In [107]:
clicks = [c1, c2]

In [108]:
embedding.shape

(513, 513, 50)

In [109]:
results = [run_clicks(embedding, label_obj, clicks, valid, flat_indices=True)]
last_pred = results[-1]['pred']

(263169,)


In [110]:
print(results[-1]['iou'])
plt.figure()
plt.imshow(results[-1]['pred'])

0.9594306287085805


<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x2b79af0da4e0>

In [113]:
mislabeled = last_pred != label_obj
mislabeled_border = mislabeled & ~border

In [114]:
plt.figure()
plt.imshow(mislabeled_border)

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x2b79c42acfd0>

In [73]:
d = ndimage.distance_transform_edt(~label_obj)
mislabeled_positive = mislabeled_border & label_obj
mislabeled_negative = mislabeled_border & ~label_obj
dp = d * mislabeled_positive
dn = d * mislabeled_negative

if dn.max() >= dp.max():
    new_click = np.random.choice(np.nonzero(dn.flatten() == dn.max())[0])
else:
    new_click = np.random.choice(np.nonzero(dp.flatten() == dp.max())[0])

In [74]:
if mislabeled_border.any():
    mislabeled_positive = mislabeled_border & label_obj
    mislabeled_negative = mislabeled_border & ~label_obj
    
    dist_mlp = ndimage.distance_transform_edt(mislabeled_positive)
    dist_mln = ndimage.distance_transform_edt(mislabeled_negative)
    
    if dist_mlp.max() > dist_mln.max():
        new_click = np.random.choice(np.nonzero(dist_mlp.flatten() == dist_mlp.max())[0])
    else:
        new_click = np.random.choice(np.nonzero(dist_mln.flatten() == dist_mln.max())[0]) 
else:
    dist = ndimage.distance_transform_edt(~mislabeled)
    dist = dist * (~border)
    new_click = np.random.choice(np.nonzero((dist == dist[~border].min()).flatten())[0])

In [75]:
clicks.append(new_click)

In [82]:
plt.figure()
plt.imshow(mislabeled_border)
#plot_point(new_click, label)
mislabeled_border.flatten()[new_click]

<IPython.core.display.Javascript object>

True

In [78]:
results.append(run_clicks(embedding, label_obj, clicks, valid, flat_indices=True))
last_pred = results[-1]['pred']

(263169,)


In [79]:
print(results[-1]['iou'])
plt.figure()
plt.imshow(last_pred)

0.1621484671902711


<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x2b79b535a748>

## Run the algorithm on all images

In [8]:
def run_click_sim(embeddings, labels, n_clicks, computer_borders_manually, d_margin=15, d_clicks=5):
    label = labels.squeeze()
    embedding = embeddings.squeeze()
    
    # if len(np.unique(label * (1 - border2))) > 1:
    ## pick an object at random
    o = np.random.choice(np.sort(np.unique(label[label != 255]))[1:])
    label_obj = (label == o).astype(np.int32)
    
    valid = np.nonzero((label.reshape(-1) != 255))[0]

    if computer_borders_manually:
        border = label == 255
        border = border + morph.binary_dilation(skimage.segmentation.find_boundaries(label_obj), iterations=d_margin)
    else:
        if np.any(label == 255):
            temp = (label == 255) + skimage.segmentation.find_boundaries(label_obj)
            border = morph.binary_dilation(label_obj, iterations=6) & temp
            border = morph.binary_dilation(border, iterations=max(d_margin-3, 1))
        else:
            border = label == 255
            border = border + morph.binary_dilation(skimage.segmentation.find_boundaries(label_obj), iterations=d_margin)
        
        

    # inner_part = label_obj & np.logical_not(border)
    # outer_part = np.logical_not(label_obj) & np.logical_not(border)
    label_obj_temp = label_obj.copy()
    label_obj_temp[:, 0] = 1
    label_obj_temp[0, :] = 1
    label_obj_temp[-1, :] = 1
    label_obj_temp[:, -1] = 1
    
    inner_part = label_obj
    outer_part = np.logical_not(label_obj_temp)

    ## pick initial points
    dist_inner = ndimage.distance_transform_edt(inner_part)
    dist_outer = ndimage.distance_transform_edt(outer_part)
    
#     dist_outer[:, 0:20] = 0.
#     dist_outer[0:20, :] = 0.
#     dist_outer[-20:, :] = 0.
#     dist_outer[:, -20:] = 0.
    
    c1 = np.random.choice(np.nonzero(dist_inner.ravel() == dist_inner.max())[0])
    c2 = np.random.choice(np.nonzero(dist_outer.ravel() == dist_outer.max())[0])

    clicks = [c1, c2]

    results = [run_clicks(embedding, label_obj, clicks, valid, flat_indices=True)]
    last_pred = results[-1]['pred']
    
    click_map = np.zeros_like(label_obj)
    click_map.ravel()[c1] = 1
    click_map.ravel()[c2] = 1

    for k in range(2, n_clicks):
        mislabeled = last_pred != label_obj
        mislabeled_border = mislabeled & ~border

        if mislabeled_border.any():
            mislabeled_positive = mislabeled_border & label_obj
            mislabeled_negative = mislabeled_border & ~label_obj

#             d = ndimage.distance_transform_edt(~label_obj)
#             dp = d * mislabeled_positive
#             dn = d * mislabeled_negative
            
#             if dn.max() >= dp.max():
#                 new_click = np.random.choice(np.nonzero(dn.flatten() == dn.max())[0])
#             else:
#                 new_click = np.random.choice(np.nonzero(dp.flatten() == dp.max())[0])
                      
            dist_mlp = ndimage.distance_transform_edt(mislabeled_positive)
            dist_mln = ndimage.distance_transform_edt(mislabeled_negative)

            if dist_mlp.max() > dist_mln.max():
                new_click = np.random.choice(np.nonzero(dist_mlp.ravel() == dist_mlp.max())[0])
            else:
                new_click = np.random.choice(np.nonzero(dist_mln.ravel() == dist_mln.max())[0]) 
        else:
            valid_clicks = ndimage.distance_transform_edt(np.logical_not(click_map)) > d_clicks
            
            dist = ndimage.distance_transform_edt(~mislabeled)
            dist = dist * (~border) * valid_clicks
            new_click = np.random.choice(np.nonzero((dist == dist[~border * valid_clicks].min()).ravel())[0])

        clicks.append(new_click)
        
        click_map.ravel()[new_click] = 1

        results.append(run_clicks(embedding, label_obj, clicks, valid, flat_indices=True))
        last_pred = results[-1]['pred']
        
    return results

In [9]:
def run_image(i, d_margin=5, d_clicks=10):
    i = i % 50
    
    embedding = np.load("/srv/glusterfs/yardima/embed_temp/{}.embed.aug.npy".format(i))
    label = np.load("/srv/glusterfs/yardima/embed_temp/{}.label.aug.npy".format(i))
    
    return run_click_sim(embedding, label, 20, True, d_margin=d_clicks, d_clicks=d_margin)

In [10]:
import gc
gc.collect()

0

In [11]:
import multiprocessing as mp

In [34]:
%%time
pool = mp.Pool(processes=16)
results = pool.map(run_image, range(0,200))
pool.close()
pool.join()
del pool

CPU times: user 1.16 s, sys: 8.94 s, total: 10.1 s
Wall time: 2min 22s


In [35]:
n_clicks = 20
mean_acc = np.zeros(n_clicks - 1)
mean_iou = np.zeros(n_clicks - 1)
clicks = 0.0
c = []

for result in results:
    mean_acc += np.array([r["acc"] for r in result])
    mean_iou += np.array([r["iou"] for r in result])

    iou = np.array([r["iou"] for r in result])
    t = iou > 0.9

    if np.any(t):
        clicks += np.argmax(t) + 2.
        c.append( np.argmax(t) + 2.)
    else:
        clicks += 20.0
        c.append(20)
            
mean_acc /= len(results)
mean_iou /= len(results)
clicks /= len(results)

In [36]:
graphcut = [0, 0.41647, 0.52051, 0.61135, 0.59292 , 0.68515 , 0.73828 , 0.73773 , 0.75957 , 0.76526 , 0.80299 , 0.80789 , 0.83517 , 0.82191 , 0.85346  , 0.8519 , 0.86424 , 0.87155 , 0.87143 , 0.87907 , 0.89611]
geodesicmatting = [0 , 0.23718 , 0.38133  , 0.4326 , 0.55617  , 0.6602 , 0.63506 , 0.68842 , 0.73264  , 0.7581 , 0.79506 , 0.81119 , 0.81285 , 0.82494 , 0.85251 , 0.85958 , 0.86015 , 0.88313 , 0.88004  , 0.8849 , 0.87809]
randomwalker= [0 , 0.24353 , 0.37081 , 0.47507 , 0.56905 , 0.62045 , 0.65946 , 0.73991 , 0.78886 , 0.80488 , 0.81368 , 0.82177 , 0.83732 , 0.87131  , 0.8701  , 0.9013 , 0.88533 , 0.88293 , 0.91703 , 0.90852 , 0.91907]
euclideanstarconvexity = [0 , 0.52075  , 0.5707 , 0.65062 , 0.71197 , 0.77218  , 0.7821  , 0.8171 , 0.83639 , 0.88536 , 0.87501 , 0.87949 , 0.91244 , 0.91596 , 0.90658 , 0.93117  , 0.9282 , 0.90621 , 0.91999 , 0.94405 , 0.94914]
geodesicstarconvexity = [0 , 0.48836 , 0.56403 , 0.66126 , 0.68438 , 0.72558  , 0.7596 , 0.82642 , 0.85716 , 0.85727 , 0.87142 , 0.86295 , 0.91054 , 0.92736 , 0.91775 , 0.93614 , 0.93673 , 0.93886 , 0.94633 , 0.93902  , 0.9558]
growcut = [0 , 0.26565 , 0.36187 , 0.47439 , 0.53135 , 0.58615 , 0.64972 , 0.67848 , 0.68794 , 0.71877 , 0.74829 , 0.76001 , 0.77618 , 0.78652 , 0.80577 , 0.82243 , 0.83948 , 0.85783 , 0.83632 , 0.86645 , 0.85449]
deepios = [0 , 0.62889 , 0.74285 , 0.79842 , 0.83959 , 0.83684  , 0.8625 , 0.88998 , 0.88087 , 0.88309 , 0.91481 , 0.91363 , 0.92075 , 0.92631 , 0.91706 , 0.93541 , 0.93936 , 0.95019 , 0.93832 , 0.95318 , 0.94784]

In [37]:
print(clicks)
print(mean_iou[-1])

plt.figure()
plt.plot(range(2, 21),mean_iou, label="Ours", linewidth=3.0)

plt.plot(range(0, 21),graphcut, label="GraphCut")
plt.plot(range(0, 21),geodesicmatting, label="GraphCut")
plt.plot(range(0, 21),randomwalker, label="GeodesicMatting")
plt.plot(range(0, 21),euclideanstarconvexity, label="EuclideanSC")
plt.plot(range(0, 21),geodesicstarconvexity, label="GeodesicSC")
plt.plot(range(0, 21),growcut, label="GrowCut")
plt.plot(range(0, 21),deepios, label="DeepIOS")

plt.ylim(0,1)
plt.xlim(0,20)
plt.title("Click vs. IoU on 50 GrabCut images")
plt.xlabel("Clicks")
plt.ylabel("Mean IoU")
plt.yticks(np.arange(0.0, 1.0, 0.05))
plt.xticks(np.arange(0.0, 20.0, 1))
plt.legend()

6.755
0.9382674536214222


<IPython.core.display.Javascript object>

<matplotlib.legend.Legend at 0x2b48628d72b0>

In [None]:
import itertools

r = []
for d_margin, d_clicks in itertools.product(range(1, 23, 3), range(1, 32, 3)):
    def func(img):
        return run_image(img, d_margin=d_margin, d_clicks=d_clicks)
    
    pool = mp.Pool(processes=8)
    results = pool.map(func, range(0,200))
    pool.close()
    pool.join()
    del pool
    
    n_clicks = 20
    mean_acc = np.zeros(n_clicks - 1)
    mean_iou = np.zeros(n_clicks - 1)
    clicks = 0.0
    c = []

    for result in results:
        mean_acc += np.array([r["acc"] for r in result])
        mean_iou += np.array([r["iou"] for r in result])

        iou = np.array([r["iou"] for r in result])
        t = iou > 0.9

        if np.any(t):
            clicks += np.argmax(t) + 2.
            c.append( np.argmax(t) + 2.)
        else:
            clicks += 20.0
            c.append(20)

    mean_acc /= len(results)
    mean_iou /= len(results)
    clicks /= len(results)
    
    del results
    r.append((mean_iou, clicks))

In [None]:
import pickle

with open('/home/yardima/results12.pkl', 'wb') as f:
    pickle.dump(r, f)

## On PASCAL

In [11]:
def run_image_pascal(i):
    i = i % 200
    
    embedding = np.load("/srv/glusterfs/yardima/embed_temp/sl-obj/run_1pascal{}.embed.aug.npy".format(i))
    label = np.load("/srv/glusterfs/yardima/embed_temp/sl-obj/run_1pascal{}.label.aug.npy".format(i))
    
    return run_click_sim(embedding, label, 20, False, d_margin=5, d_clicks=10)

In [12]:
%%time
pool = mp.Pool(processes=16)
results = pool.map(run_image_pascal, range(0,400))
pool.close()
pool.join()
del pool

CPU times: user 2.64 s, sys: 13.5 s, total: 16.1 s
Wall time: 5min 35s


In [13]:
n_clicks = 20
mean_acc = np.zeros(n_clicks - 1)
mean_iou = np.zeros(n_clicks - 1)
clicks = 0.0
c = []

for result in results:
    mean_acc += np.array([r["acc"] for r in result])
    mean_iou += np.array([r["iou"] for r in result])

    iou = np.array([r["iou"] for r in result])
    t = iou > 0.85

    if np.any(t):
        clicks += np.argmax(t) + 2.
        c.append( np.argmax(t) + 2.)
    else:
        clicks += 20.0
        c.append(20)
            
mean_acc /= len(results)
mean_iou /= len(results)
clicks /= len(results)

In [14]:
graphcut = [0   , 0.27681   , 0.32977   , 0.37455   , 0.41053   , 0.43497   , 0.45993   , 0.5112   , 0.54717   , 0.57843   , 0.59827   , 0.61407   , 0.62586   , 0.65327   , 0.66878   , 0.67121   , 0.68808   , 0.70618   , 0.72117   , 0.72616   , 0.73358]
geodesicmatting = [0   , 0.23837   , 0.31603   , 0.39958   , 0.45853   , 0.50365   , 0.54942   , 0.6036   , 0.62801   , 0.65857   , 0.69147   , 0.70598   , 0.7251   , 0.76067   , 0.76306   , 0.77118   , 0.79357   , 0.80196   , 0.81664   , 0.81001   , 0.81565]
randomwalker = [0   , 0.25499   , 0.38225   , 0.47944   , 0.55064   , 0.61289   , 0.66075   , 0.70061   , 0.71498   , 0.75836   , 0.78338   , 0.79171   , 0.81012   , 0.82623   , 0.82689   , 0.84386   , 0.84943    , 0.861   , 0.8613   , 0.87658   , 0.87359]
euclideanstarconvexity = [0   , 0.31757   , 0.37173   , 0.43203   , 0.49685   , 0.56124   , 0.60921   , 0.64424   , 0.6719   , 0.72457   , 0.75115   , 0.76781   , 0.78669   , 0.80184   , 0.8173   , 0.83749   , 0.85186   , 0.85061   , 0.86692   , 0.86684   , 0.87823]
geodesicstarconvexity = [0   , 0.31026   , 0.39227   , 0.44247   , 0.48026   , 0.5363   , 0.59881   , 0.64341   , 0.66498   , 0.69703   , 0.71841   , 0.76724   , 0.77042   , 0.79595   , 0.82272   , 0.83416   , 0.84581   , 0.85118   , 0.85446   , 0.85961   , 0.86994]
growcut = [0   , 0.25874   , 0.37407   , 0.46186   , 0.5239   , 0.58875   , 0.62349   , 0.66604   , 0.6842   , 0.70787   , 0.72344   , 0.75468   , 0.78347   , 0.81573   , 0.82886   , 0.83437   , 0.82817   , 0.84007   , 0.85472   , 0.86841   , 0.86967]
deepios = [0   , 0.53573   , 0.63382   , 0.69688   , 0.75223   , 0.77911   , 0.7974   , 0.82285   , 0.8351   , 0.85559   , 0.85738   , 0.86204   , 0.87519   , 0.88335   , 0.8921   , 0.8905   , 0.89625   , 0.90147   , 0.89738   , 0.90696   , 0.91054]

In [15]:
print(clicks)
print(mean_iou[-1])

plt.figure()
plt.plot(range(2, 21),mean_iou, label="Ours", linewidth=3.0)

plt.plot(range(0, 21),graphcut, label="GraphCut")
plt.plot(range(0, 21),geodesicmatting, label="GraphCut")
plt.plot(range(0, 21),randomwalker, label="GeodesicMatting")
plt.plot(range(0, 21),euclideanstarconvexity, label="EuclideanSC")
plt.plot(range(0, 21),geodesicstarconvexity, label="GeodesicSC")
plt.plot(range(0, 21),growcut, label="GrowCut")
plt.plot(range(0, 21),deepios, label="DeepIOS")

plt.ylim(0,1)
plt.xlim(0,20)
plt.title("Click vs. IoU on 200 Pascal images")
plt.xlabel("Clicks")
plt.ylabel("Mean IoU")
plt.yticks(np.arange(0.0, 1.0, 0.05))
plt.xticks(np.arange(0.0, 20.0, 1))
plt.legend()

11.585
0.8950843470446301


<IPython.core.display.Javascript object>

<matplotlib.legend.Legend at 0x2b4866e88908>

In [21]:
print(clicks)
print(mean_iou[-1])

plt.figure()
plt.plot(range(2, 21),mean_iou, label="Ours", linewidth=3.0)

plt.plot(range(0, 21),graphcut, label="GraphCut")
plt.plot(range(0, 21),geodesicmatting, label="GraphCut")
plt.plot(range(0, 21),randomwalker, label="GeodesicMatting")
plt.plot(range(0, 21),euclideanstarconvexity, label="EuclideanSC")
plt.plot(range(0, 21),geodesicstarconvexity, label="GeodesicSC")
plt.plot(range(0, 21),growcut, label="GrowCut")
plt.plot(range(0, 21),deepios, label="DeepIOS")

plt.ylim(0,1)
plt.xlim(0,20)
plt.title("Click vs. IoU on 200 Pascal images")
plt.xlabel("Clicks")
plt.ylabel("Mean IoU")
plt.yticks(np.arange(0.0, 1.0, 0.05))
plt.xticks(np.arange(0.0, 20.0, 1))
plt.legend()

6.4825
0.9241827835731171


<IPython.core.display.Javascript object>

<matplotlib.legend.Legend at 0x2b9294cf7668>