In [18]:
%matplotlib notebook
import torch
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from config import *
import Data.get_data as _Data
import matplotlib.pyplot as plt
from ipywidgets import widgets
from IPython.display import display
import warnings
from utils import get_spx_pools, launch_cuda, load_checkpoint, ungroup_batches
import Models.get_model as _Models
# Silence library warnings so the widget output stays readable.
# NOTE(review): this hides *all* warnings (deprecations included) — consider narrowing.
warnings.filterwarnings("ignore")

config = set_config(jup_notebook=True)
# Path to the trained model weights (machine-specific; edit before running).
config.resume_model_path = '/media/marcelo/SSD/Py_all/lixo-rsync/logs_pinha/exp_00/my_net_pinha_2_last.pth'
# Path to the folder containing the dataset.
config.data_root = '../databases'
# batch_size=1 so each button click shows a single sample.
config.train_batch_size = 1
config.test_batch_size = 1

model, loss_fun = _Models.get_model_loss(config)
model = launch_cuda(model)
model = load_checkpoint(config, model)

train_loader, test_loader = _Data.get_data(config)
# Dataset split browsed by the "Next sample" button: 'test' or 'train'.
# An unknown value raises KeyError instead of silently picking a split.
SPLIT_TO_SHOW = 'test'
_iter = iter({'train': train_loader, 'test': test_loader}[SPLIT_TO_SHOW])
# k-NN classifier; show_sample re-fits it on every sample's superpixel features.
knn = KNeighborsClassifier(n_neighbors=config.knn_neighbors)

def show_sample(b):
    with output:
        
        model.eval()    
        with torch.no_grad():

            img, spx, obj_label, num_obj, info = next(_iter)
            disp_img = img.clone()

            img = img.float().cuda()
            spx = spx.cuda()
            obj_label = obj_label.cuda()

            spx_pools, _ = get_spx_pools(spx, obj_label)
            super_feat = model(img, spx.float())
            spx_pools, super_feat = ungroup_batches(spx_pools, super_feat)
            
            for b in range(len(super_feat)):
            
                x = super_feat[b].clone().detach().cpu().numpy()
                y = spx_pools[b].clone().detach().cpu().numpy()
                idx = np.arange(1, y.shape[0]+1)
                
                try:                
                    x_train, x_test, y_train, y_test, idx_train, idx_test = \
                        train_test_split(x, y, idx, test_size=config.knn_test_size, stratify=y)                    
                    knn.fit(x_train,y_train)
                    pred = knn.predict(x_test)                    
                    out_str = 'Img: {}, knn_score: {:^7.3f}'.format(info['name'][b], knn.score(x_test, y_test))
                except:
                    out_str = 'Sample skipped'
                    continue
                    
                aux_spx = spx[b][0] 
                aux_spx = aux_spx.cpu()
                test_spx = torch.zeros_like(aux_spx) +1

                for n, i in enumerate(idx_test):    
                    test_spx[aux_spx==i] = pred[n]
                    
                total_spx = test_spx.clone()
                
                train_spx = torch.zeros_like(aux_spx) +1
                for n, i in enumerate(idx_train):    
                    train_spx[aux_spx==i] = y_train[n]
                    total_spx[aux_spx==i] = y_train[n]
                
                plt.rcParams['figure.dpi'] = 200
                plt.figure(num=1, figsize=(4, 4))
                
                plt.subplot(3,2,1)
                plt.axis('off')
                plt.title('Image', fontsize=6)
                plt.imshow(disp_img[b].permute(1,2,0).cpu(), cmap='brg')                
                plt.subplot(3,2,3)
                plt.axis('off')
                plt.title('obj_labels #{}'.format(obj_label[b,0].max()), fontsize=6)
                plt.imshow(obj_label[b,0].cpu(), cmap='jet')
                plt.subplot(3,2,4)
                plt.axis('off')
                plt.title('spx train #{}'.format(train_spx.max().int()), fontsize=6)
                plt.imshow(train_spx, cmap='jet')
                plt.subplot(3,2,5)
                plt.axis('off')
                plt.title('spx test #{}'.format(test_spx.max().int()), fontsize=6)
                plt.imshow(test_spx, cmap='jet')
                plt.subplot(3,2,6)
                plt.axis('off')
                plt.title('spx total #{}'.format(total_spx.max().int()), fontsize=6)
                plt.imshow(total_spx, cmap='jet')
                plt.suptitle(out_str, fontsize=6)
                print()

# Output area that show_sample writes its prints and plots into.
output = widgets.Output()
# Each click of this button calls show_sample, which advances to the next sample.
button = widgets.Button(description="Next sample")
button.on_click(show_sample)
display(button, output)


CUDA version 10.1 [1 device(s) available].
Model sent to cuda.
Loaded checkpoing at: /media/marcelo/SSD/Py_all/lixo-rsync/logs_pinha/exp_00/my_net_pinha_2_last.pth


Button(description='Next sample', style=ButtonStyle())

Output()