In [None]:
# Evaluate precomputed selective search proposals against ARID ground truth and generate statistics and plots.

In [None]:
import json
import os
import random
from pathlib import Path

import skimage.io
from IPython.display import display
from PIL import ImageDraw, Image
from selective_search import selective_search
from tqdm import tqdm

import arid.arid as arid

# Dataset and precomputed selective-search locations. The defaults are the
# original author's machine; override them via the ARID_DATASET_ROOT and
# SS_RESULTS_ROOT environment variables instead of editing absolute paths.
ARID_DATASET_ROOT = os.environ.get('ARID_DATASET_ROOT', '/home/justin/Desktop/arid-dataset')
ss_root = Path(os.environ.get('SS_RESULTS_ROOT', '/home/justin/Desktop/thesis/ss'))

wps = arid.get_wps(ARID_DATASET_ROOT)

In [None]:
# For every ground-truth object instance, record whether at least one
# selective-search proposal overlaps it with IoU >= 0.5 / 0.7 / 0.9, then write
# the result to 'ss-iou-results.json' (the file read back by the cells below).
#
# results: mode -> '<title>-<img_id>' -> object instance id -> {'0.5': bool, '0.7': bool, '0.9': bool}
IOU_THRESHOLDS = ('0.5', '0.7', '0.9')

results = {}
for mode in ['single', 'fast', 'quality']:
    results[mode] = {}
    for wp in tqdm(wps):
        title = wp.get_title()
        # NOTE(review): the visualization cell reads ss_root / f'ss-{mode}' / f'ss-{mode}-{title}.json'
        # (a per-mode subdirectory) -- confirm which layout is current.
        ss_path = ss_root / f'ss-{mode}-{title}.json'
        with open(ss_path) as f:
            ss_data = json.load(f)  # img number -> {'boxes': [[x1, y1, x2, y2], ...], ...}

        for img_path in wp.rgb_image_paths():
            img_id = Path(img_path).stem
            img_key = f'{title}-{img_id}'
            results[mode][img_key] = {}

            # Proposals are (x1, y1, x2, y2); convert once per image to the
            # 4-corner polygon format expected by arid.compute_bbox_iou.
            ss_polys = [
                [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
                for x1, y1, x2, y2 in ss_data[img_id]['boxes']
            ]

            for gt in wp.get_annotations(img_id)['annotations']:
                if gt['id'] is None:
                    continue  # annotations without an instance id are not evaluated
                x, y, w, h = gt['x'], gt['y'], gt['width'], gt['height']
                obj_coords = [(x, y), (x + w, y), (x + w, y + h), (x, y + h)]

                hits = dict.fromkeys(IOU_THRESHOLDS, False)
                for poly in ss_polys:
                    iou = arid.compute_bbox_iou(obj_coords, poly)
                    for thr in IOU_THRESHOLDS:
                        if iou >= float(thr):
                            hits[thr] = True
                    if all(hits.values()):
                        break  # every threshold satisfied; no proposal can change the result
                results[mode][img_key][gt['id']] = hits

# Persist so the analysis cells below see the results computed here rather
# than a stale file from an earlier run.
with open('ss-iou-results.json', 'w') as outfile:
    json.dump(results, outfile)

print({mode: len(per_img) for mode, per_img in results.items()})  # images evaluated per mode


In [None]:
# Average number of selective-search proposals per image, for each mode.
for mode in ['single', 'fast', 'quality']:
    img_count = 0
    preds = 0
    for wp in tqdm(wps):
        title = wp.get_title()
        ss_path = ss_root / f'ss-{mode}-{title}.json'
        with open(ss_path) as f:
            ss_data = json.load(f)  # img number -> {'boxes': [...], ...}

        for img_path in wp.rgb_image_paths():
            img_id = Path(img_path).stem
            img_count += 1
            preds += len(ss_data[img_id]['boxes'])

    # NaN rather than ZeroDivisionError when a mode has no images.
    avg_preds = preds / img_count if img_count else float('nan')
    print(f'{mode}: {avg_preds}')

In [None]:
# Average selective-search runtime per image (the 'time' field stored with each
# image's proposals), for each mode.
# NOTE(review): the unit of 'time' is not visible here -- presumably seconds; confirm.
for mode in ['single', 'fast', 'quality']:
    total_runtime = 0
    img_count = 0
    for wp in tqdm(wps):
        title = wp.get_title()
        ss_path = ss_root / f'ss-{mode}-{title}.json'
        with open(ss_path) as f:
            ss_data = json.load(f)  # img number -> {'boxes': [...], 'time': ...}

        for img_path in wp.rgb_image_paths():
            img_id = Path(img_path).stem
            img_count += 1
            total_runtime += ss_data[img_id]['time']

    # NaN rather than ZeroDivisionError when a mode has no images.
    avg_runtime = total_runtime / img_count if img_count else float('nan')
    print(f'{mode}: {avg_runtime}')

In [None]:
# Visualize one image: each ground-truth box followed by the selective-search
# proposals whose IoU with it reaches `iou_threshold`.
mode = 'quality'
iou_threshold = 0.5  # same '>=' convention as the IoU cell; try 0.7 / 0.9
title = 'wp_4_4_7'   # or: random.choice(list(_wps.keys()))
img_stem = '005'     # or: random.choice(wp.rgb_image_paths()).stem

_wps = {_wp.get_title(): _wp for _wp in wps}

print(title)
wp = _wps[title]
# NOTE(review): this cell reads from a per-mode subdirectory, while the IoU and
# statistics cells read ss_root / f'ss-{mode}-{title}.json' -- confirm which layout is current.
ss_path = ss_root / f'ss-{mode}' / f'ss-{mode}-{title}.json'
with open(ss_path) as f:
    ss_data = json.load(f)

# Select the image explicitly. The previous loop silently kept a stale `img_path`
# from an earlier cell when no image with the requested stem existed.
img_path = next((p for p in wp.rgb_image_paths() if p.stem == img_stem), None)
if img_path is None:
    raise FileNotFoundError(f'No RGB image with stem {img_stem!r} in {title}')

img_id = Path(img_path).stem
img_key = f'{title}-{img_id}'
new_img_path = arid.annotation_path(img_path, 'selective-search')

print(img_key)
with open('ss-iou-results.json') as f:
    results = json.load(f)
print(results[mode][img_key])

img = Image.open(img_path)

# Proposals are (x1, y1, x2, y2); draw them as 4-corner polygons.
annotations = [
    {
        'id': '.',
        'coords': [(x1, y1), (x2, y1), (x2, y2), (x1, y2)],
        'score': 1.0,
        'colormap': 'spring_r',
    }
    for x1, y1, x2, y2 in ss_data[img_id]['boxes']
]

# Ground-truth boxes are (x, y, width, height); convert to the same polygon format.
gt_annotations = []
for gt in wp.get_annotations(img_id)['annotations']:
    if gt['id'] is None:
        continue
    x, y, w, h = gt['x'], gt['y'], gt['width'], gt['height']
    gt_annotations.append({
        'id': gt['id'],
        'coords': [(x, y), (x + w, y), (x + w, y + h), (x, y + h)],
        'score': 1.0,
        'colormap': 'YlGn',
    })

# Each ground-truth box, followed by the proposals that match it.
matched = []
for gt_annotation in gt_annotations:
    matched.append(gt_annotation)
    matched.extend(
        a for a in annotations
        if arid.compute_bbox_iou(gt_annotation['coords'], a['coords']) >= iou_threshold
    )

img_matched = img.copy()
arid.annotate_img(img_matched, new_img_path, matched, save=False)
display(img_matched)
                

In [None]:
# Aggregate per-instance IoU hits into per-object-class recall for each mode,
# then order the classes ascending by (recall@0.5, recall@0.7, recall@0.9) for plotting.
with open('ss-iou-results.json') as f:
    results = json.load(f)

# scores: mode -> object name -> {'total': instances, '0.5': hits, '0.7': hits, '0.9': hits}
scores = {}
for mode in ['single', 'fast', 'quality']:
    scores[mode] = {}
    for wp in tqdm(wps):
        title = wp.get_title()
        for img_path in wp.rgb_image_paths():
            img_key = f'{title}-{Path(img_path).stem}'
            for obj_instance, hits in results[mode][img_key].items():
                # Drop the text after the last '_' (presumably an instance index) to get
                # the object class; ids without '_' map to '' exactly as before.
                obj_name = obj_instance.rpartition('_')[0]
                counts = scores[mode].setdefault(
                    obj_name, {'total': 0, '0.5': 0, '0.7': 0, '0.9': 0}
                )
                counts['total'] += 1
                for thr in ('0.5', '0.7', '0.9'):
                    if hits[thr]:
                        counts[thr] += 1

# _scores: mode -> object name -> fraction of instances recovered at each threshold.
# 'total' >= 1 for every entry, so the division is safe.
_scores = {
    mode: {
        name: {thr: counts[thr] / counts['total'] for thr in ('0.5', '0.7', '0.9')}
        for name, counts in scores[mode].items()
    }
    for mode in ['single', 'fast', 'quality']
}


def _sorted_by_recall(mode_scores):
    """Return `mode_scores` as a dict ordered ascending by (0.5, 0.7, 0.9) recall."""
    return dict(sorted(
        mode_scores.items(),
        key=lambda item: (item[1]['0.5'], item[1]['0.7'], item[1]['0.9']),
    ))


s_score = _sorted_by_recall(_scores['single'])
f_score = _sorted_by_recall(_scores['fast'])
q_score = _sorted_by_recall(_scores['quality'])


In [None]:
# Create a grouped horizontal bar chart of per-object selective search recall.
def plot_ss_scores(scores, f_name, w=10, t=90):
    """Plot per-object recall at IoU 0.5 / 0.7 / 0.9 as grouped horizontal bars and save it.

    Args:
        scores: ordered mapping object name -> {'0.5': frac, '0.7': frac, '0.9': frac},
            fractions in [0, 1]. Rows are drawn bottom-to-top in mapping order.
        f_name: output path passed to ``savefig``.
        w, t: unused; kept only so existing callers keep working.
    """
    import numpy as np
    import matplotlib.pyplot as plt

    # Set the font before creating the axes so the tick labels pick it up.
    # Note that this changes matplotlib's global rc settings, as before.
    plt.rc('font', family='DejaVu Sans', weight='regular', size=24)
    fig, ax = plt.subplots(figsize=(20, 30))

    names = list(scores)
    s5 = [v['0.5'] * 100 for v in scores.values()]
    s7 = [v['0.7'] * 100 for v in scores.values()]
    s9 = [v['0.9'] * 100 for v in scores.values()]

    ind = np.arange(len(names))  # one group of three bars per object
    height = 0.25                # thickness of each individual bar
    bars5 = ax.barh(ind + 2 * height, s5, height=height, label='0.5')
    bars7 = ax.barh(ind + height, s7, height=height, label='0.7')
    bars9 = ax.barh(ind, s9, height=height, label='0.9')

    ax.set_xlabel('Percent of Object Instances')
    ax.set_ylabel('Object Name')
    # Put each label on the middle bar of its group (ind + height). Placing it at
    # `ind` aligned it with the bottom (0.9) bar.
    ax.set_yticks(ind + height)
    ax.set_yticklabels(names, ha='right')
    ax.legend((bars5, bars7, bars9), ('0.5', '0.7', '0.9'))
    fig.tight_layout()
    fig.savefig(f_name)


In [None]:
# Find the object class where 'fast' mode improves least over 'single' mode in
# recall@0.5. A negative gap means 'fast' recovers fewer instances than 'single'.
min_gap = float('inf')  # gaps lie in [-1, 1], so any real gap beats the sentinel
min_t = ''
for name, single in s_score.items():
    gap = f_score[name]['0.5'] - single['0.5']
    if gap < min_gap:
        min_gap = gap
        min_t = name

# Print the argmin class. The original printed `k`, which was just the last loop key.
print(min_t)
print(min_gap)


In [None]:
# Save one recall chart per selective-search mode.
for mode_scores, out_name in (
    (s_score, 'ss-single.png'),
    (f_score, 'ss-fast.png'),
    (q_score, 'ss-quality.png'),
):
    plot_ss_scores(mode_scores, out_name)

In [None]:
# Overall recall per mode: the percentage of ground-truth object instances
# matched by at least one proposal at IoU >= 0.5 / 0.7 / 0.9.
with open('ss-iou-results.json') as f:
    results = json.load(f)

for mode in ['single', 'fast', 'quality']:
    hit_counts = {'0.5': 0, '0.7': 0, '0.9': 0}
    total = 0
    for wp in tqdm(wps):
        title = wp.get_title()
        for img_path in wp.rgb_image_paths():
            img_key = f'{title}-{Path(img_path).stem}'
            for obj_hits in results[mode][img_key].values():
                total += 1
                for thr in hit_counts:
                    if obj_hits[thr]:
                        hit_counts[thr] += 1

    for thr, n_hits in hit_counts.items():
        # NaN rather than ZeroDivisionError when a mode has no evaluated instances.
        pct = (n_hits / total) * 100 if total else float('nan')
        print(f'{mode} IoU>={thr}: {pct}')
    print('---')
