In [1]:
import html
import json
import os

import ipywidgets as widgets
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from IPython.display import display, clear_output, Markdown, Image, HTML

import data_util

In [2]:
# Data locations. DATA_DIR can be overridden with the EOL_DATA_DIR environment
# variable; the default reproduces the original absolute paths exactly.
DATA_DIR = os.environ.get('EOL_DATA_DIR', '/home/psd2120/research/data')
PID_MAP = os.path.join(DATA_DIR, 'page_id_map.json')
EVAL1_TXT = os.path.join(DATA_DIR, 'eval.txt')
EVAL1_IMG_DIR = '../data/eval1/'
LABELS_TXT = os.path.join(DATA_DIR, 'labels.txt')

# One eval image path per line. Blank lines are skipped; otherwise they would
# crash the page-ID parsing below with an IndexError.
with open(EVAL1_TXT, 'r') as f:
    eval_fnames = [line for line in f.read().splitlines() if line.strip()]

# Page-ID metadata. JSON object keys are always strings, so this is indexed
# with the (string) labels from labels.txt; entries carry 'canonicalName'.
with open(PID_MAP, 'r') as f:
    pid_map = json.load(f)

# page ID (int, parsed from the parent directory name) -> image file name.
# If a page has several eval images, the last one listed wins.
pid2img = {}
for fname in eval_fnames:
    parts = fname.split('/')
    pid2img[int(parts[-2])] = parts[-1]

# GFile (rather than open) also accepts non-local paths such as gs:// URLs.
with tf.io.gfile.GFile(LABELS_TXT, 'r') as f:
    labels = f.read().splitlines()

# Model class index -> label string (used as a page ID downstream), in
# labels.txt line order. Indices start at 1; presumably logit 0 is a reserved
# (e.g. background) class with no label — TODO confirm against training setup.
label_pid = dict(enumerate(labels, start=1))

In [4]:
# Load the exported model from GCS (the path suggests a fine-tuned ResNet50
# checkpoint). The prediction cell calls its 'default' signature and reads
# the 'logits_sup' output.
model = hub.load(
    'gs://eol-tfrc-tpu/chkpts/eol2020/finetune/eval/ResNet50_2048/hub/63940/'
)

Instructions for updating:
If using Keras pass *_constraint arguments to layers.


In [5]:
# Number of top predictions shown, and the width/height (px) of every thumbnail.
TOP_K = 5
THUMB_PX = 300

# UI: upload one .jpg, then "Get Preds" shows the top-k predicted classes, each
# with a reference eval image. (`labels_pid` is the upload widget; the name is
# kept because the display cell refers to it.)
labels_pid = widgets.FileUpload(accept='.jpg', multiple=False)
show_preds = widgets.Button(description="Get Preds")
clear = widgets.Button(description="Clear")
output = widgets.Output()


def on_clear_clicked(b):
    """Clear everything previously rendered in the shared `output` widget.

    Args:
        b: the clicked Button (unused; required by the `on_click` callback signature).
    """
    with output:
        clear_output()


def _uploaded_jpeg_bytes(upload):
    """Return the bytes of the first file held by `upload`, or None if it is empty.

    The widget value is read without being mutated, so "Get Preds" can be clicked
    repeatedly on the same upload. Two `FileUpload.value` layouts are accepted:
    a dict keyed by file name (ipywidgets 7.x, as the `value={}` repr in this
    notebook's output shows) and a tuple of dicts (ipywidgets >= 8).
    """
    uploads = upload.value
    if not uploads:
        return None
    if isinstance(uploads, dict):
        entry = next(iter(uploads.values()))
    else:
        entry = uploads[0]
    # ipywidgets 8 stores a memoryview here; bytes() normalizes both versions.
    return bytes(entry['content'])


def _predict_top_k(jpeg_bytes, k=TOP_K):
    """Run the loaded model on one JPEG and return its top-k softmax predictions.

    Args:
        jpeg_bytes: encoded JPEG image.
        k: number of predictions to return.

    Returns:
        (confidences, class_indices): two lists of length k, most probable first.
    """
    img = tf.image.decode_jpeg(jpeg_bytes, channels=3)
    img = data_util.preprocess_image(img, 224, 224, is_training=False,
                                     color_distort=True, test_crop=True)
    img = tf.expand_dims(img, axis=0)  # batch of one; row 0 is read back below
    logits = model.signatures['default'](tf.convert_to_tensor(img))['logits_sup']
    preds_conf, preds_idx = tf.nn.top_k(tf.nn.softmax(logits), k=k)
    return preds_conf.numpy().tolist()[0], preds_idx.numpy().tolist()[0]


def _html_row(header, cells):
    """Return one <tr> with a header cell followed by pre-rendered <td> cells."""
    return '<tr><td>' + header + '</td>' + ''.join(cells) + '</tr>'


def _results_table_html(preds_conf, preds_idx):
    """Build an HTML table with one column per prediction.

    Rows hold the following, in order:
        1. a reference eval image
        2. the predicted page ID
        3. its canonicalName from `pid_map`
        4. the softmax probability, rounded to 3 places

    Predictions that cannot be resolved get a placeholder cell instead of
    raising KeyError. This covers a class index missing from `label_pid` and
    a page ID with no eval image or no `pid_map` entry. All data-derived text
    is HTML-escaped.
    """
    img_cells, pid_cells, name_cells, conf_cells = [], [], [], []
    for conf, idx in zip(preds_conf, preds_idx):
        # label_pid keys start at 1, so e.g. index 0 has no label and lands in
        # the placeholder branches below.
        pid = label_pid.get(idx)
        img_name = pid2img.get(int(pid)) if pid is not None else None

        if img_name is None:
            img_cells.append('<td>(no eval image)</td>')
        else:
            src = html.escape(os.path.join(EVAL1_IMG_DIR, img_name), quote=True)
            img_cells.append(
                "<td><img src='{0}' width='{1}' height='{1}'></td>".format(src, THUMB_PX))

        pid_text = str(pid) if pid is not None else '(index {})'.format(idx)
        pid_cells.append('<td>' + html.escape(pid_text) + '</td>')

        name = pid_map.get(pid, {}).get('canonicalName', '?')
        name_cells.append('<td>' + html.escape(str(name)) + '</td>')

        conf_cells.append('<td>' + str(round(conf, 3)) + '</td>')

    return ('<table>'
            + _html_row('Preds->', img_cells)
            + _html_row('Pred PID->', pid_cells)
            + _html_row('canonicalName->', name_cells)
            + _html_row('Softmax Prob. ->', conf_cells)
            + '</table>')


def on_show_preds_clicked(b):
    """Predict on the uploaded image and render it alongside the top-k results.

    Clears `output` first. If nothing has been uploaded yet, a hint is shown
    instead of raising.

    Args:
        b: the clicked Button (unused; required by the `on_click` callback signature).
    """
    on_clear_clicked(b)
    with output:
        jpeg_bytes = _uploaded_jpeg_bytes(labels_pid)
        if jpeg_bytes is None:
            display(Markdown('**Upload a `.jpg` image first, then click "Get Preds".**'))
            return
        preds_conf, preds_idx = _predict_top_k(jpeg_bytes)
        # format='jpeg' so the front end labels the blob correctly (default is 'png').
        display(widgets.Image(value=jpeg_bytes, format='jpeg',
                              width=THUMB_PX, height=THUMB_PX))
        display(HTML(_results_table_html(preds_conf, preds_idx)))


show_preds.on_click(on_show_preds_clicked)
clear.on_click(on_clear_clicked)

In [6]:
# Show the upload / predict / clear controls on one row, then the results area.
display(widgets.HBox([labels_pid, show_preds, clear]), output)

HBox(children=(FileUpload(value={}, accept='.jpg', description='Upload'), Button(description='Get Preds', styl…

Output()