In [1]:
import numpy as np
%matplotlib notebook
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
from torchvision import transforms, datasets
import torch.utils.data as data

In [2]:
import deeptriplet.models
import deeptriplet.datasets

In [3]:
print(torch.cuda.is_available())
print(torch.cuda.device_count())
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

True
1


## Load VOC2012 dataset

In [4]:
valset = deeptriplet.datasets.PascalDataset(pascal_root="/scratch/yardima/datasets/voc12/VOCdevkit/VOC2012",
                                            split_file="/home/yardima/Python/experiments/pascal_split/val_obj.txt",
                                            normalize_imagenet=True,
                                            augment=False,
                                            pad_zeros=True,
                                            downsample_label=1)

valloader = data.DataLoader(valset,
                                batch_size=1,
                                num_workers=2,
                                shuffle=True)

trainset = deeptriplet.datasets.PascalDataset(pascal_root="/scratch/yardima/datasets/voc12/VOCdevkit/VOC2012",
                        split_file="/home/yardima/Python/experiments/pascal_split/train_obj.txt",
                        normalize_imagenet=True,
                        augment=True,
                        pad_zeros=True,
                        downsample_label=8,
                        scale_low=0.8,
                        scale_high=1.2)

trainloader = data.DataLoader(trainset,
                                batch_size=10,
                                num_workers=4,
                                shuffle=True)

## Load trained embedding model

In [5]:
model_path_random = "/srv/glusterfs/yardima/runs/deeplabv2/lfov-triplet/run_8/models/class-vgg-pascal_epoch-49.pth"

In [6]:
d1 = torch.load(model_path_random, map_location=lambda storage, loc: storage)

In [7]:
net = deeptriplet.models.DeepLab_VGG(n_classes=45)
net = net.cuda()
net.load_state_dict(d1)

In [8]:
net = net.eval().cuda();

## Click-precision graph

In [12]:
import scipy.ndimage.morphology as morph

In [321]:
%%time
image, label = valset[1]
results = simulate_clicks_v2(net, image, label, 20)
plt.figure()
plt.plot([r["iou"] for r in results])
plt.ylim(0, 1)



<IPython.core.display.Javascript object>

CPU times: user 6.66 s, sys: 916 ms, total: 7.57 s
Wall time: 8.51 s


In [325]:
%%time
valiter = iter(valloader)
trials = 0

n_clicks = 20
mean_acc = np.zeros(n_clicks - 1)
mean_iou = np.zeros(n_clicks - 1)

for i in range(200):
    image, label = next(valiter)
    
    result = simulate_clicks_v2(net, image, label, n_clicks)

    if len(result) == n_clicks - 1:
        trials += 1

        mean_acc += np.array([r["acc"] for r in result])
        mean_iou += np.array([r["iou"] for r in result])
        
        print(trials)

mean_acc /= trials
mean_iou /= trials



1




2




3




4




5




6




7




8




9




10




11




12




13




14




15




16




17




18




19




20




21




22




23




24




25




26




27




28




29




30




31




32




33




34




35




36




37




38




39




40




41




42




43




44




45




46




47




48




49




50




51




52




53




54




55




56




57




58




59




60




61




62




63




64




65




66




67




68




69




70




71




72




73




74




75




76




77




78




79




80




81




82




83




84




85




86




87




88




89




90




91




92




93




94




95




96




97




98




99




100




101




102




103




104




105




106




107




108




109




110




111




112




113




114




115




116




117




118




119




120




121




122




123




124




125




126




127




128




129




130




131




132




133




134




135




136




137




138




139




140




141




142




143




144




145




146




147




148




149




150




151




152




153




154




155




156




157




158




159




160




161




162




163




164




165




166




167




168




169




170




171




172




173




174




175




176




177




178
CPU times: user 19min 28s, sys: 2min 58s, total: 22min 27s
Wall time: 59min 24s


In [326]:
plt.figure()
plt.plot(mean_iou)
plt.ylim(0,1)

<IPython.core.display.Javascript object>

(0, 1)

In [266]:
image, label = valset[1]
image = image.unsqueeze(0)
label = label.unsqueeze(0)
with torch.no_grad():
    image = image.cuda()
    out = net.forward(image)
    print(image.shape)
    fc8_interp_test = nn.UpsamplingBilinear2d(size=(label.shape[1], label.shape[2]))
    out = fc8_interp_test(out)
    
    embeddings = out.cpu().data.numpy()
    embeddings = np.transpose(embeddings.squeeze(), axes=[1, 2, 0]).reshape(-1, net.n_classes)
    label = label.data.numpy()


plt.figure(figsize=(18, 18))
plt.subplot(2,3,1)
plt.imshow(np.transpose(image.squeeze().data.cpu().numpy(), axes=[1, 2, 0]) / 4 + 0.5)
plt.subplot(2,3,2)
label2 = label.copy()
label2[label2 == 255] = 0
plt.imshow(label2.reshape(image.shape[2], image.shape[3]))

plt.subplot(2,3,3)

label = label.squeeze()
border = label == 255
border = morph.binary_dilation(border, iterations=d_margin)
plt.imshow(border)

o = np.random.choice(np.sort(np.unique(label * (1 - border)))[1:])

label = (label == o).astype(np.int32)

click_map = np.zeros(shape=border.shape, dtype=np.int32)

def pick_click(mask):
    candidates = np.transpose(np.nonzero(mask))
    return candidates[np.random.randint(low=0, high=candidates.shape[0]), :]
    
p1 = pick_click((label) * (1 - border))
n1 = pick_click((1 - label) * (1 - border))

d_margin = 20
d_clicks = 35

click_map[max(p1[0]-d_clicks, 0):p1[0]+d_clicks, max(p1[1]-d_clicks,0):p1[1]+d_clicks] = 1
click_map[max(n1[0]-d_clicks, 0):n1[0]+d_clicks, max(n1[1]-d_clicks,0):n1[1]+d_clicks] = 1

plt.subplot(2,3,4)
plt.imshow(click_map)

result = run_clicks(embeddings, label, [p1, n1], flat=False)

plt.subplot(2,3,5)
plt.imshow(pred)



array([0.73767135, 0.7631721 , 0.777656  , 0.79200661, 0.79661856,
       0.8015463 , 0.80793648, 0.8128886 , 0.82395793, 0.82946955,
       0.832834  , 0.83678062, 0.84023911, 0.84060427, 0.84257754,
       0.84235443, 0.84625579, 0.847451  , 0.84902684])

In [219]:
import skimage.transform