In [20]:
# Re-import edited project modules (models, utils, callbacks, ...) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# Interactive figures via the classic-notebook backend.
# NOTE(review): this backend does not render in JupyterLab; `%matplotlib widget` (ipympl) is the replacement there.
%matplotlib notebook

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [56]:
import tqdm
import os
import glob
import torch
import matplotlib.pyplot as plt
import pandas as pd
import cv2

from mirror.visualisations.core import GradCam
from models import *
from utils import get_learner, get_probs_and_labels_from_preds, get_patches_form_df, load_model_from_name, device
from callbacks import StoreBestWorstAndSample, ROC_AUC
from os import path
from datasets.TraversabilityDataset import TraversabilityDataset, get_transform
from Config import Config
from patches import *

class StorePredictions():
    """Run a model over TraversabilityDatasets and cache the per-patch predictions as CSV files.

    Calling the instance on a list of datasets adds `out_0`, `out_1` (the columns of the
    probabilities returned by `get_probs_and_labels_from_preds`) and `prediction` (the predicted
    label) to each dataset's dataframe. `store()` writes those dataframes to
    `<store_dir>/<map_name>/<file_name>` and `restore()` reads them back, so the inference
    step can be skipped on a later run.
    """

    def __init__(self, model_name, model_dir, store_dir):
        self.model_name = model_name
        self.model_dir = model_dir
        self.store_dir = store_dir
        # Dataframes in the order they were predicted / restored (accumulates across calls).
        self.dfs = []
        # Source csv path -> dataframe (the same objects that are in `self.dfs`).
        self.df_path2df = {}

    def handle_dataset(self, dataset):
        """Predict on `dataset` and return its dataframe with the prediction columns added.

        NOTE: the columns are written into `dataset.df` itself; no copy is made.
        """
        df = dataset.df
        learner, _ = get_learner(self.model_name, self.model_dir, callbacks=[], dataset=dataset)
        preds = learner.get_preds(learner.data.test_dl)
        probs, labels = get_probs_and_labels_from_preds(preds)

        df['out_0'] = probs[:, 0]
        df['out_1'] = probs[:, 1]
        df['prediction'] = labels.cpu().tolist()

        return df

    def store(self):
        """Write every dataframe to `<store_dir>/<map_name>/<file_name>`.

        `map_name` is the name of the directory containing the source csv, and `file_name`
        is the source csv's own file name. Missing directories are created.
        """
        for df_path, df in tqdm.tqdm(self.df_path2df.items()):
            map_name = path.basename(path.dirname(df_path))
            out_dir = path.normpath(path.join(self.store_dir, map_name))
            os.makedirs(out_dir, exist_ok=True)
            df.to_csv(path.join(out_dir, path.basename(df_path)))

    def restore(self):
        """Load the csv files previously written by `store()`.

        Files are read in sorted path order, so `self.dfs` has a reproducible order across runs.
        """
        # `store()` writes exactly one directory level below `store_dir`. Without
        # recursive=True, glob treats '**' like '*', so spell the pattern out explicitly.
        dfs_path = sorted(glob.glob(path.join(self.store_dir, '*', '*.csv')))

        for df_path in tqdm.tqdm(dfs_path):
            # `to_csv` in `store()` writes the index as the first column. Read it back as the
            # index rather than as an extra 'Unnamed: 0' data column.
            df = pd.read_csv(df_path, index_col=0)
            self.df_path2df[df_path] = df
            self.dfs.append(df)

    def __call__(self, datasets):
        """Predict on every non-empty dataset and return the accumulated list of dataframes.

        Raises:
            ValueError: if an element of `datasets` is not a TraversabilityDataset.
        """
        bar = tqdm.tqdm(datasets)
        for dataset in bar:
            if not isinstance(dataset, TraversabilityDataset):
                raise ValueError('inputs must be of type TraversabilityDataset')
            bar.set_description('[INFO] Reading {}'.format(dataset.df_path))
            if len(dataset) == 0:
                continue
            df = self.handle_dataset(dataset)
            self.dfs.append(df)
            self.df_path2df[dataset.df_path] = df

        return self.dfs






In [10]:
# Where cached predictions live. Override with the STORE_PREDICTIONS_DIR environment variable;
# the fallback is the original author's local path.
STORE_DIR = os.environ.get('STORE_PREDICTIONS_DIR', '/home/francesco/Desktop/store-test/')
# True -> reload predictions cached by `store.store()` instead of re-running the model.
RESTORE_PREDICTIONS = False

# NOTE(review): tr=0.45 is presumably the traversability threshold used for labelling; confirm.
concat = TraversabilityDataset.from_paths(Config.DATA_ROOT, [Config.DATA_DIR], tr=0.45, transform=get_transform(scale=10))
store = StorePredictions(Config.BEST_MODEL_NAME, Config.BEST_MODEL_DIR, STORE_DIR)
if RESTORE_PREDICTIONS:
    store.restore()
    dfs = store.dfs
else:
    dfs = store(concat.datasets)




[INFO] Reading /media/francesco/saetta/no-shift-88-750/train//df/bars1/1550614988.2771952-patch.csv: 100%|██████████| 1/1 [00:00<00:00,  2.06it/s][A[A[A

In [13]:
def _most_confident_of_class(df, cls):
    """Rows whose ground-truth `label` equals `cls`, sorted by `out_<cls>` from highest to lowest."""
    of_class = df[df['label'] == cls]
    return of_class.sort_values(by='out_{}'.format(cls), ascending=False)


class Best():
    """Filter: label-1 patches, ranked by the model's `out_1` score (highest first)."""
    name = 'best'

    def __call__(self, df):
        return _most_confident_of_class(df, 1)


class Worst():
    """Filter: label-0 patches, ranked by the model's `out_0` score (highest first)."""
    name = 'worst'

    def __call__(self, df):
        return _most_confident_of_class(df, 0)

class FalseNegative():
    """Filter: patches whose ground truth is 1 but whose prediction is not 1.

    Class 1 is taken as the positive class (`Best` likewise ranks label-1 patches by `out_1`).
    The previous version selected label-0 patches here, i.e. the false positives.
    """
    name = 'false_negative'

    def __call__(self, df):
        return false_something(df, 1)


class FalsePositive():
    """Filter: patches whose ground truth is 0 but whose prediction is not 0 (class 1 is positive)."""
    name = 'false_positive'

    def __call__(self, df):
        return false_something(df, 0)


def false_something(df, something):
    """Return the rows whose `label` equals `something` but whose `prediction` does not.

    The original row order and index are preserved.
    """
    mask = (df['label'] == something) & (df['prediction'] != something)
    return df.loc[mask]



In [108]:
class FilterPatches():
    """Select rows of a prediction dataframe with a filter and fetch their (transformed) patches."""

    def __init__(self, transform=None):
        # NOTE(review): `df` is never read or written by the methods below.
        self.df = None
        # Callable applied to every fetched patch; None means "keep patches unchanged".
        self.transform = transform

    def transform_patches(self, patches):
        """Apply `self.transform` to each patch; identity when no transform was given."""
        if self.transform is None:
            return list(patches)
        return [self.transform(patch) for patch in patches]

    def filter_patches(self, df, image_dir):
        """Fetch the patches for the rows of `df` via `get_patches_form_df`, then transform them."""
        return self.transform_patches(get_patches_form_df(df, image_dir))

    def __call__(self, df, image_dir, filter_fn, n=2):
        """Apply `filter_fn` to `df`, keep its first `n` rows and fetch their patches.

        Returns:
            (filtered_df, patches): the kept rows and their transformed patches.
        """
        filtered_df = filter_fn(df).head(n)
        return filtered_df, self.filter_patches(filtered_df, image_dir)

class Convert2Patches():
    """Wrap the output of `GradCamVisualization` into `Patch` objects for plotting."""

    def __call__(self, patches):
        """Convert a `(df, patch_tensors, grad_cams)` triple into Patch objects.

        Args:
            patches: the flat triple returned by `GradCamVisualization.__call__`. The parameter
                keeps its original name for compatibility with existing callers.

        Returns:
            (df, Patch.from_tensors(patch_tensors), Patch.from_hms(grad_cams))
        """
        # Previously this read an undefined name `data` (or silently picked up a notebook global)
        # and unpacked a nested pair that no producer returns.
        df, tensors, grad_cams = patches
        # The Grad-CAMs are numpy heatmaps, not tensors, so they go through `from_hms`,
        # the same conversion this notebook applies to them by hand.
        return df, Patch.from_tensors(tensors), Patch.from_hms(grad_cams)
    
class GradCamVisualization():
    """Compute a Grad-CAM heatmap for each patch using mirror's `GradCam`."""

    def __init__(self, model, device):
        self.model = model
        self.device = device
        # The model is moved to `device` once here; each patch is moved in `get_grad_cam`.
        self.grad_cam = GradCam(model.to(self.device), self.device)

    @staticmethod
    def _normalize(cam):
        """Rescale `cam` to [0, 1]. A constant map becomes all zeros instead of NaN."""
        lo, hi = cam.min(), cam.max()
        shifted = cam - lo
        if hi == lo:
            return shifted
        return shifted / (hi - lo)

    def get_grad_cam(self, patch):
        """Return the [0, 1]-normalised Grad-CAM of one patch, resized to the patch's spatial size.

        Assumes `patch` is a (channels, height, width) tensor; it is batched with unsqueeze(0).
        TODO: confirm the layout.
        """
        img = patch.unsqueeze(0).to(self.device)

        _, info = self.grad_cam(img, None, target_class=None)

        cam = info['cam'].cpu().numpy()
        # cv2.resize expects dsize as (width, height), i.e. (shape[-1], shape[-2]). The previous
        # (shape[1], shape[2]) swapped them, which only went unnoticed on square patches.
        cam = cv2.resize(cam, (patch.shape[-1], patch.shape[-2]))
        return self._normalize(cam)

    def __call__(self, data):
        """Map `(df, patches)` to `(df, patches, grad_cams)`, with one heatmap per patch."""
        df, patches = data
        grad_cams = [self.get_grad_cam(patch) for patch in patches]
        return df, patches, grad_cams

In [109]:
def get_all_interesting_patches(transform, df, image_dir, n=2):
    """Collect the top `n` rows and patches for each category of interesting predictions.

    Args:
        transform: callable applied to every fetched patch.
        df: prediction dataframe; the filters read its `label`, `prediction`, `out_0` and `out_1` columns.
        image_dir: directory passed on to the patch loader.
        n: rows/patches kept per category, counted after each filter's own ordering
            (defaults to 2, FilterPatches' default).

    Returns:
        dict mapping each filter's `name` ('best', 'worst', 'false_negative', 'false_positive')
        to the `(filtered_df, patches)` pair produced by FilterPatches.
    """
    filters = [Best(), Worst(), FalseNegative(), FalsePositive()]
    f_patch = FilterPatches(transform=transform)
    return {f.name: f_patch(df, image_dir, f, n=n) for f in filters}
    

In [110]:
# NOTE(review): the predictions in `store.dfs` were presumably made with get_transform(scale=10)
# (the dataset transform above). These patches use scale=1 and are later fed to the model for
# Grad-CAM; confirm the mismatch is intended.
data = get_all_interesting_patches(transform=get_transform(scale=1), df=store.dfs[0], image_dir=Config.DATA_ROOT)

In [111]:
# Quick look at the 'worst' selection.
# NOTE(review): both names are reassigned by the Grad-CAM cell further down, so this cell only
# serves interactive inspection.
df, patches = data['worst']


In [112]:
# Load the `roc_auc.pth` weights stored in the best model's directory.
model = load_model_from_name('{}/{}'.format(Config.BEST_MODEL_DIR, 'roc_auc.pth'), Config.BEST_MODEL_NAME)

In [113]:
grad_cam_vis = GradCamVisualization(model=model, device=device)

# One Grad-CAM heatmap per 'worst' patch; the result is a (df, patches, grad_cams) triple.
df, patches, grad_cam = grad_cam_vis(data['worst'])

In [114]:
# Wrap the tensors and the Grad-CAM heatmaps into Patch objects for plotting.
# NOTE(review): this rebinds both names, so running the cell twice would pass the already
# converted objects back into Patch.from_*; rerun the Grad-CAM cell first.
patches, grad_cam = Patch.from_tensors(patches), Patch.from_hms(grad_cam)

In [119]:
# For each selected patch: 3-D plot with its Grad-CAM heatmap passed to plot3d, titled with the
# row's advancement, predicted class and ground-truth label.
# NOTE(review): the last run raised "ValueError: could not broadcast input array from shape (88,88)
# into shape (88)". The failing frame is not visible here; check what Patch.plot3d expects for `grad.hm`.
for (idx, row), patch, grad in zip(df.iterrows(), patches, grad_cam):
    patch.plot3d(grad.hm)
    # Alternative (disabled): plot the heatmap on its own with `grad.plot3d()`.
    plt.title('advancement={:.2f} prediction = {} ground truth = {}'.format(row['advancement'], row['prediction'], row['label']))


<IPython.core.display.Javascript object>

ValueError: could not broadcast input array from shape (88,88) into shape (88)

In [116]:
# Quick 2-D view of the first converted patch.
patches[0].plot2d()

<IPython.core.display.Javascript object>

In [None]:
# `patches` already holds Patch objects (converted by Patch.from_tensors a few cells above),
# so take the first one directly rather than converting it a second time.
p = patches[0]

In [None]:
# NOTE(review): absolute path on the author's machine; make it configurable before sharing.
HM_PATH = '/home/francesco/Documents/Master-Thesis/core/maps/test/querry-big-10.png'
hm = cv2.imread(HM_PATH)
# cv2.imread returns None instead of raising when the file is missing or unreadable. Without this
# check that would surface later as an obscure error inside cvtColor.
if hm is None:
    raise FileNotFoundError('could not read map image: {}'.format(HM_PATH))
hm = cv2.cvtColor(hm, cv2.COLOR_BGR2GRAY)

In [None]:
import matplotlib.patches as mpatches


In [None]:
 def plot_box_on_hm(row, hm, patch_size):
     """Show `hm` as an image and outline the patch described by `row` with a red square.

     Args:
         row: dataframe row with `hm_x`, `hm_y` (presumably the patch centre in `hm` pixel
             coordinates; confirm).
         hm: 2-D map image; it is divided by 255 and displayed on a fixed [0, 1] colour scale.
         patch_size: side length of the square, in pixels.
     """
     fig, ax = plt.subplots()
     x, y = row["hm_x"], row["hm_y"]

     # seaborn (`sns`) is never imported in this notebook, so draw with matplotlib directly.
     image = ax.imshow(hm / 255, vmin=0, vmax=1)
     fig.colorbar(image, ax=ax)

     # imshow centres pixel (i, j) on integer coordinates. Shifting by half a pixel makes the box
     # enclose whole pixels, as it did on seaborn.heatmap's cell grid.
     half = patch_size // 2
     rect = mpatches.Rectangle((x - half - 0.5, y - half - 0.5), patch_size, patch_size,
                               linewidth=1, edgecolor='r', facecolor='none')
     ax.add_patch(rect)
     plt.show()

In [None]:
# `df` holds at most n=2 rows (get_all_interesting_patches -> FilterPatches keeps head(n)), so
# `.iloc[10]` would raise IndexError. Use the first row, which presumably belongs to patches[0].
# 88 is the patch side length in pixels (the 88x88 shape seen in the plotting error above).
plot_box_on_hm(df.iloc[0], hm, 88)

In [None]:
p.plot2d()
# `df` holds at most n=2 rows, so `.iloc[10]` would raise IndexError; title with the first row,
# which corresponds to patches[0] (the patch `p` was taken from).
plt.title(df.iloc[0]['advancement'])

In [None]:
# 3-D view of the same patch.
p.plot3d()