In [1]:
%matplotlib notebook
import torch
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from config import *
import Data.get_data as _Data
import matplotlib.pyplot as plt
from ipywidgets import widgets
from IPython.display import display
import warnings
from utils import get_spx_pools, launch_cuda, load_checkpoint, ungroup_batches, iou_metrics
from utils import AverageMeter, merge_spx_label
import Models.get_model as _Models
from Data.augmentation import remove_small_objcts
warnings.filterwarnings("ignore")

import os

config = set_config(jup_notebook=True, dataset='davis')
# Path to the model weights. The default below is a machine-specific absolute path;
# set the RESUME_MODEL_PATH environment variable to point elsewhere instead of editing it.
# Alternative checkpoint used previously:
#   /media/marcelo/SSD/Py_all/lixo-rsync/logs_pinha/exp_pinha_msra10k_03/my_net_best.pth
config.resume_model_path = os.environ.get(
    'RESUME_MODEL_PATH',
    '/media/marcelo/SSD/Py_all/lixo-rsync/logs_sdumont/exp_sdumont_msra10k_ddp_02/my_net_best.pth')
# Path to the folder containing the dataset (override with the DATA_ROOT environment variable).
config.data_root = os.environ.get('DATA_ROOT', '../databases')
# batch_size=1 so that a single sample is shown per button click
config.train_batch_size = 1
config.test_batch_size = 1
# Set True only if pre-computed superpixels are available
config.pre_computed_spx = False

model, loss_fun = _Models.get_model_loss(config)
model = launch_cuda(model)
model, _ = load_checkpoint(config, model)

train_loader, test_loader = _Data.get_data(config)
# Split to visualise: use iter(train_loader) instead to browse the training split.
_iter = iter(test_loader)
knn = KNeighborsClassifier(n_neighbors=config.knn_neighbors)
# Running IoU meters; show_sample resets them at the first frame of each video.
iou = AverageMeter()
iiou = AverageMeter()
ag_iou = AverageMeter()
# Last painted label maps, kept at module level so they can be inspected after a click.
train_spx = torch.tensor([])
test_spx = torch.tensor([])

def _plot_sample(disp_img, obj_label, spx, label_map, map_title, out_str, bi):
    """Render the 2x2 summary figure for batch element ``bi``.

    Panels: input frame, ground-truth object labels, superpixel map, and the
    per-pixel map painted from superpixel labels (``label_map``), titled
    ``'<map_title> #<max label>'``. ``out_str`` becomes the figure suptitle.

    Figure 0 is reused on every click and ``plt.subplot`` returns the existing
    axes, so the figure is cleared first; otherwise each call stacks a new
    image on top of the previous ones.
    """
    plt.rcParams['figure.dpi'] = 200
    fig = plt.figure(num=0, figsize=(5, 3))
    fig.clf()
    panels = (
        ('Frame', disp_img[bi].permute(1, 2, 0).cpu(), 'brg'),
        ('Ground truth #{}'.format(obj_label[bi, 0].max()), obj_label[bi, 0].cpu(), 'jet'),
        ('Spx #{}'.format(spx[bi, 0].max()), spx[bi, 0].cpu(), 'jet'),
        ('{} #{}'.format(map_title, label_map.max().int()), label_map, 'jet'),
    )
    for pos, (title, image, cmap) in enumerate(panels, start=1):
        plt.subplot(2, 2, pos)
        plt.axis('off')
        plt.title(title, fontsize=6)
        plt.imshow(image, cmap=cmap)
    plt.suptitle(out_str, fontsize=6)


def show_sample(b):
    """Button callback: take the next sample from ``_iter``, run it, and plot it.

    On the first frame of a video (``info['frame_idx'] == 0``) the IoU meters
    are reset and the KNN classifier is fitted on that frame's superpixel
    features/labels. On later frames the fitted KNN predicts a label for each
    superpixel and the running IoU metrics are updated.

    Args:
        b: the clicked ``widgets.Button`` (unused; required by ``Button.on_click``).
    """
    global train_spx, test_spx
    with output:
        try:
            img, spx, obj_label, _num_obj, info = next(_iter)
        except StopIteration:
            # Loader exhausted: report it instead of letting the exception escape
            # the widget callback.
            print('No more samples: re-run the setup cell to restart the loader.')
            return

        first_frame = info['frame_idx'].item() == 0
        model.eval()
        with torch.no_grad():
            disp_img = img.clone()  # untouched CPU copy used for display
            img = img.float().cuda()
            spx = spx.cuda()
            obj_label = obj_label.cuda()

            if first_frame:
                # New video: reset the running metrics and rewrite the first-frame
                # superpixel map via merge_spx_label (presumably aligning superpixels
                # with the ground-truth objects -- see utils.merge_spx_label).
                iou.reset()
                iiou.reset()
                ag_iou.reset()
                _, spx2label = get_spx_pools(spx, obj_label)
                spx[0][0] = merge_spx_label(spx[0][0], obj_label[0][0], spx2label[0, 0])

            spx_pools, _ = get_spx_pools(spx, obj_label)
            super_feat = model(img, spx.float())
            spx_pools, super_feat = ungroup_batches(spx_pools, super_feat)

            # `bi` rather than `b` so the button argument is not shadowed.
            for bi in range(len(super_feat)):
                x = super_feat[bi].clone().detach().cpu().numpy()
                y = spx_pools[bi].clone().detach().cpu().numpy()
                # Row n of x / y is painted onto the pixels whose superpixel id is n + 1.
                idx = np.arange(1, y.shape[0] + 1)
                aux_spx = spx[bi][0].cpu()
                # Pixels whose id is not in idx keep the fill value 1.
                label_map = torch.zeros_like(aux_spx) + 1

                if first_frame:
                    knn.fit(x, y)
                    for n, spx_id in enumerate(idx):
                        label_map[aux_spx == spx_id] = y[n]
                    train_spx = label_map
                    out_str = '\n Video: {}, Frame: {}/{} (knn fitting)'.format(
                        info['name'][bi], info['frame_idx'].item() + 1, info['num_frames'].item())
                    _plot_sample(disp_img, obj_label, spx, train_spx, 'spx train', out_str, bi)
                else:
                    pred = knn.predict(x)
                    out_str = 'Img: {}, knn_score: {:^7.3f}'.format(info['name'][bi], knn.score(x, y))
                    for n, spx_id in enumerate(idx):
                        label_map[aux_spx == spx_id] = pred[n]
                    test_spx = label_map

                    _iou, _iiou, _ag_iou = iou_metrics(obj_label[bi, 0].clone(), spx[bi, 0].clone(),
                                                       pred, [], [], idx)
                    iou.update(_iou.item())
                    iiou.update(_iiou.item())
                    ag_iou.update(_ag_iou.item())

                    out_str += '\n IoU: {:^7.3f}, iIoU: {:^7.3f} (Mean: {:^7.3f})'.format(
                        iou.avg, iiou.avg, ag_iou.avg)
                    out_str += '\n Video: {}, Frame: {}/{}'.format(
                        info['name'][bi], info['frame_idx'].item() + 1, info['num_frames'].item())
                    _plot_sample(disp_img, obj_label, spx, test_spx, 'spx test', out_str, bi)
                print()

# Output area that show_sample writes its prints and figures into, plus the
# "Next sample" button; the click handler is attached before both are displayed.
output = widgets.Output()
button = widgets.Button(description="Next sample")
button.on_click(show_sample)
display(button, output)

CUDA version 10.1 [1 device(s) available].
Model sent to cuda.
Loaded checkpoing at: /media/marcelo/SSD/Py_all/lixo-rsync/logs_sdumont/exp_sdumont_msra10k_ddp_02/my_net_best.pth


Button(description='Next sample', style=ButtonStyle())

Output()