In [1]:
import os
import sys
import pickle
import csv
import copy

import numpy as np
import cv2
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

In [3]:
sys.path.append("../../../tfobjdetect/lib")
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util

  import matplotlib; matplotlib.use('Agg')  # pylint: disable=multiple-statements


## Loading and preparing label maps

In [4]:
with open('../../../wsod/metadata/ont_m18/class_names_all.pkl', 'rb') as fin:
    mid2name_all = pickle.load(fin)

In [5]:
with open('../../results/det_results_concat_26.pkl', 'rb') as fin:
    det_results_concat = pickle.load(fin)

In [6]:
with open('../../results/det_results_merged_26.pkl', 'rb') as fin:
    det_results_merged = pickle.load(fin)

In [7]:
with open('../../results/det_results_postproc_26.pkl', 'rb') as fin:
    det_results_postproc = pickle.load(fin)

## Choosing examples for each class to visualize

In [8]:
label_to_img = {key: [] for key in mid2name_all}
for key, val in det_results_postproc.items():
    for det in val:
        if det['score'] >= 0.1:
            label_to_img[det['label']].append(key)

In [9]:
all_labels = list(label_to_img.keys())

In [10]:
sort_idx = np.argsort([-len(label_to_img[label]) for label in all_labels])

In [11]:
for i in sort_idx:
    print(all_labels[i], mid2name_all[all_labels[i]], len(label_to_img[all_labels[i]]))

/m/01g317 PER (Person) 4
/m/07yv9 VEH (Vehicle) 2
/m/07cmd VEH.MilitaryVehicle.Tank (Tank) 1
/m/03jm5 FAC.Building.House (House) 1
/m/09ct_ VEH.Aircraft.Helicopter (Helicopter) 1
/m/0cgh4 FAC.Building (Building) 1
/m/01c648 COM (Electronic_device.Laptop) 0
/m/0bh6t6c PER.ProfessionalPosition (Rescuer) 0
/m/015kr FAC.Structure.Bridge (Bridge) 0
/m/02p0zyj Conflict.Attack (Riot) 0
/m/01jpn4 FAC.Building.StoreShop (Grocery_store) 0
/m/01nd_n Conflict.Demonstrate.MarchProtestPoliticalGathering (Protest) 0
/m/06wwc COM.Equipment.Satellite (Satellite) 0
/m/02_g0 Contact.FuneralVigil.Meet (Funeral) 0
/m/04zjc WEA.Gun.Firearm (Machine_gun) 0
/m/03w1r4 ORG.MilitaryOrganization.GovernmentArmedForces (Military_uniform) 0
/m/02bb1s WEA.Gun.Firearm (Sniper_rifle) 0
/m/014zdl Disaster.FireExplosion.FireExplosion (Explosion) 0
/m/026yq0z Conflict.Demonstrate (Demonstration) 0
/m/0g_k0 Manufacture.Artifact.CreateManufacture (Mass_production) 0
/m/012n4x PER.ProfessionalPosition.Firefighter (Firefighte

In [None]:
#select_labels = ['/m/01fnck', '/m/019jd', '/m/03120', '/m/02_41', '/m/07cmd', '/m/06q40', '/m/03qtwd', '/m/09x0r', '/m/01xgg_', '/m/08qrwn', '/m/012n4x', '/m/02p16m6', '/m/02lbcq', '/m/09ct_', '/m/0g54v5d', '/m/0bg2p', '/m/0ct4f', '/m/01nd_n', '/m/09rvcxw', '/m/01nl4x', '/m/04ctx', '/m/04ylt', '/m/0gvss07', '/m/01rzcn', '/m/06nrc', '/m/0cyfs', '/m/0f5lx', '/m/04zjc', '/m/0dhz0', '/m/01lcw4', '/m/03htg', '/m/0gxl3', '/m/01bq8v', '/m/0lt4_', '/m/02gzp', '/m/02yjc', '/m/012n7d', '/m/0jb3', '/m/0dwx7']
select_labels = [all_labels[i] for i in sort_idx]

In [21]:
keyframe_vid_id_to_seedling_vid_id = {}
seedling_vid_id_to_keyframe_vid_id = {}
keyframe_img_filename_to_seedling_img_id = {}
seedling_img_id_to_keyframe_img_filename = {}

with open('../../../../data/eval_m9/msb.txt', 'r') as fin:
    for line in fin:
        row = line.split()

        keyframe_vid_id = row[0]
        seedling_vid_id = row[1].split('_')[0]
        
        keyframe_vid_id_to_seedling_vid_id[keyframe_vid_id] = seedling_vid_id
        seedling_vid_id_to_keyframe_vid_id[seedling_vid_id] = keyframe_vid_id
        
        seedling_img_id = row[1]
        keyframe_img_filename = row[0] + '_' + seedling_img_id.split('_')[-1] + '.png'
        
        keyframe_img_filename_to_seedling_img_id[keyframe_img_filename] = seedling_img_id
        seedling_img_id_to_keyframe_img_filename[seedling_img_id] = keyframe_img_filename


In [22]:
with open('../../../../data/eval_m9/keyframes.txt', 'r') as fin:
    test_img_subpaths = [line.strip() for line in fin]

In [23]:
seedling_img_id_to_keyframe_img_path = {}
for i, item in enumerate(test_img_subpaths):
    keyframe_img_filename = item.split('/')[-1]
    if keyframe_img_filename in keyframe_img_filename_to_seedling_img_id:
        seedling_img_id_to_keyframe_img_path[keyframe_img_filename_to_seedling_img_id[keyframe_img_filename]] = item

## Visualizing Results

In [12]:
mid2name_extended = {}
extended_classes = set()
for key in all_labels:
    for model, abv in [('coco', 'CO'), ('oi', 'OI'), ('ws', 'WS'), ('pascal','PA'), ('coco/J', 'CO/J'), ('oi/J', 'OI/J'), ('ws/J', 'WS/J'), ('pascal/J','PA/J')]:
        extended_classes.add(key + '/' + model)
        mid2name_extended[key + '/' + model] = mid2name_all[key] + f' ({abv})'
extended_classes = list(extended_classes)
mid2idx_extended = {}
for i, key in enumerate(extended_classes):
    mid2idx_extended[key] = i
category_index_extended = {mid2idx_extended[key]:{'id': mid2idx_extended[key], 'name': mid2name_extended[key]} for key in mid2name_extended}

In [13]:
def preproc(im):
    target_size = 400
    max_size = 1024
    im_size_min = np.min(im.shape[0:2])
    im_size_max = np.max(im.shape[0:2])
    im_scale = float(target_size) / float(im_size_min)
    # Prevent the biggest axis from being more than MAX_SIZE
    if np.round(im_scale * im_size_max) > max_size:
        im_scale = float(max_size) / float(im_size_max)
    im = cv2.resize(im, None, None, fx=im_scale, fy=im_scale,
                    interpolation=cv2.INTER_LINEAR)    
    return im

In [14]:
result_dict = {
    'Not Merged': det_results_concat,
    'Merged': det_results_merged,
    'Post-Processed': det_results_postproc,
}

In [15]:
path_prefix = '../../../../data/eval_m9/jpg/' 

In [16]:
model_abv_dict = {
    'coco': 'CO',
    'pascal': 'PA',
    'oi': 'OI',
    'ws': 'WS',
    'coco/J': 'CO/J',
    'pascal/J': 'PA/J',
    'oi/J': 'OI/J',
    'ws/J': 'WS/J',
}

In [34]:
def show(imgid, source, thresh):
    filename = '../../../../data/eval_m9/keyframes/' + seedling_img_id_to_keyframe_img_path[imgid]
    print(imgid)
    with open(filename, 'rb') as fin:
        #_ = fin.read(1024)
        imgbin = fin.read()
    imgbgr = cv2.imdecode(np.fromstring(imgbin, dtype='uint8'), cv2.IMREAD_COLOR)
    image_np = imgbgr[:,:,[2,1,0]]
    image_np = preproc(image_np)
    try:
        detections = result_dict[source][imgid]
        boxes = np.asarray([det['bbox_normalized'] for det in detections])[:,[1,0,3,2]]
        scores = [det['score'] for det in detections]
        label_idx = [mid2idx_extended[det['label'] + '/' + det['model']] for det in detections]

        # print([(mid2name_extended[extended_classes[item]], boxes[i]) for i, item in enumerate(label_idx) if scores[i] > thresh])

        vis_util.visualize_boxes_and_labels_on_image_array(
          image_np,
          boxes,
          label_idx,
          scores,
          category_index_extended,
          use_normalized_coordinates=True,
          min_score_thresh=thresh,
          max_boxes_to_draw=1000,
          line_thickness=2)
    except:
        print('No Detections')
    plt.figure(figsize=(12, 8))
    plt.imshow(image_np)
    plt.axis('off')
    plt.show()    

In [None]:
for label in select_labels:
    if len(label_to_img[label]) == 0:
        continue
    print(mid2name_all[label])
    imgid = np.random.choice(label_to_img[label])
    print(imgid)
    w = interactive(show, 
        imgid=fixed(imgid),
        source=widgets.RadioButtons(options=['Post-Processed', 'Merged', 'Not Merged'], value='Post-Processed'), 
        thresh=widgets.FloatSlider(min=0.0, max=1.0, step=0.01, value=0.1, continuous_update=False), 
    )
    w.children[-1].layout.height = '600px'
    display(w)
    

In [None]:
query_docs = [  'HC00002ZO',
                'HC00002ZR',
                'HC00002ZQ',
                'HC000T67L',
                'HC00002ZN',
                'HC000T6B0',
                'HC000T69B',
                'HC000T6J5',
                'IC001JWOR',
                'HC000T6J3',
                'HC00002ZQ',
                'HC000T6IF',
                'HC00002ZS',
                'HC00002ZQ',
                'HC00002ZD',
                'HC000031G',
                'HC000T65T',
                'HC00002ZG',
                'HC000T6IT',
                'HC000T69A',
                'HC000Q7NJ',
                'HC000ZXRT',
                'HC000T65W',
                'HC0007DBW',
                'HC000Q6RT',
                'IC001JNOC',
                'IC001JNOL',
                'IC001L4P9',
                'HC000T6J5',
                'HC00038S5']

In [None]:
p2c = {}
c2p = {}
with open('../../../../data/eval_m9/parent_children.tab', 'r') as fin:
    flag = False
    for row in csv.reader(fin, delimiter='\t'):
        if not flag:
            flag = True
            continue
        if row[2] not in p2c:
            p2c[row[2]] = []
        p2c[row[2]].append(row[3])
        c2p[row[3]] = row[2]

In [None]:
parent_query_docs = set([c2p[item] for item in query_docs])

In [None]:
query_docs_expanded = set()
for item in parent_query_docs:
    query_docs_expanded |= set(p2c[item])

In [None]:
query_docs_expanded

In [None]:
query_docs_expanded_jpg = set()
for item in query_docs_expanded:
    if item in det_results_postproc:
        query_docs_expanded_jpg.add(item)

In [None]:
query_docs_expanded_jpg

In [None]:
for imgid in query_docs_expanded_jpg:
    print(imgid)
    w = interactive(show, 
        imgid=fixed(imgid),
        source=widgets.RadioButtons(options=['Post-Processed', 'Merged', 'Not Merged'], value='Post-Processed'), 
        thresh=widgets.FloatSlider(min=0.0, max=1.0, step=0.01, value=0.1, continuous_update=False), 
    )
    w.children[-1].layout.height = '600px'
    display(w)


In [20]:
det_results_postproc.keys()

dict_keys(['HC000SYGU_71', 'HC000Q8MF_40', 'IC0015YFV_54', 'IC0019MX2_1'])

In [35]:
for imgid in det_results_postproc:
    print(imgid)
    w = interactive(show, 
        imgid=fixed(imgid),
        source=widgets.RadioButtons(options=['Post-Processed', 'Merged', 'Not Merged'], value='Post-Processed'), 
        thresh=widgets.FloatSlider(min=0.0, max=1.0, step=0.01, value=0.1, continuous_update=False), 
    )
    w.children[-1].layout.height = '600px'
    display(w)


HC000SYGU_71


interactive(children=(RadioButtons(description='source', options=('Post-Processed', 'Merged', 'Not Merged'), v…

HC000Q8MF_40


interactive(children=(RadioButtons(description='source', options=('Post-Processed', 'Merged', 'Not Merged'), v…

IC0015YFV_54


interactive(children=(RadioButtons(description='source', options=('Post-Processed', 'Merged', 'Not Merged'), v…

IC0019MX2_1


interactive(children=(RadioButtons(description='source', options=('Post-Processed', 'Merged', 'Not Merged'), v…

In [33]:
det_results_postproc['HC000Q8MF_40']

[]