In [1]:
# Run from the repository root so the relative paths used below
# ('script/', 'output/checkpoint/...', 'data/test_images/') resolve.
% cd ../

/home/otani_mayu/Experiments/ads2018_release


In [3]:
import os
import json
from html import escape

import chainer
import chainer.functions as F
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display_html

In [4]:
import sys
sys.path.append('script/')
from train import DatasetOCR, TextCNN, TextLSTM, NonVisualNet, Net, my_converter, AttentionNetWTL, word_attention

Using Theano backend.


[nltk_data] Downloading package stopwords to
[nltk_data]     /home/otani_mayu/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [5]:
import matplotlib.colors
from matplotlib.colors import Normalize

# Diverging colormap that get_word_vis_html uses to turn attention weights
# into background colours.
cmap = plt.get_cmap('coolwarm')

# Normalisation over [0, 2]. NOTE(review): no visible cell uses this;
# get_word_vis_html passes the raw weights straight to `cmap`.
norm = Normalize(vmin=0.0, vmax=2.0)


    
def get_word_vis_html(att_val, words):
    """Render words as inline HTML paragraphs coloured by attention weight.

    Each weight is mapped through the module-level ``cmap`` colormap and
    used as the background colour of the corresponding word.

    Args:
        att_val: attention weights, one per word position. ``zip`` stops at
            the shorter of ``att_val`` and ``words``, so extra trailing
            weights (e.g. for padded positions) are ignored.
            NOTE(review): this assumes any padding comes *after* the real
            words; confirm against ``my_converter``.
        words: the words to display, aligned position-by-position with
            ``att_val``.

    Returns:
        str: one ``<p style="display:inline;...">`` snippet per word,
        concatenated. Returns an empty string when either input is empty.
    """
    word_tmp = '''
    <p style="display:inline;background-color:{0};font-size:12pt;color:#F2F2F2;font-weight:lighter;font-family:verdana;">{1}</p>
    '''
    # Escape the word so OCR tokens such as '<' or '&' are shown literally
    # instead of being parsed as markup.
    return ''.join(
        word_tmp.format(matplotlib.colors.to_hex(cmap(a)), escape(str(w)))
        for a, w in zip(att_val, words)
    )

In [6]:
def get_ocr_att(ocr, act, rsn, n_candidates=15):
    """Attention weights of OCR words w.r.t. the action and reason texts.

    Uses the module-level ``model`` (its ``lng_net.word_embedding``) and
    ``word_attention`` imported from ``train``.

    Args:
        ocr: batch of OCR word-index sequences, one row per image.
        act: batch of action-statement word indices. Presumably
            ``n_candidates`` consecutive rows per image; TODO confirm
            against ``my_converter``.
        rsn: batch of reason-statement word indices, laid out like ``act``.
        n_candidates: number of candidate statements per image. Each OCR
            embedding row is repeated this many times (consecutively) so it
            lines up row-for-row with ``act``/``rsn``. It must match the
            reshape the caller applies to the returned weights.

    Returns:
        tuple: ``(a_att, r_att)``, the second output of ``word_attention``
        for the action and reason texts respectively.
    """
    embed = model.lng_net.word_embedding
    ya_emb = embed(act)
    yr_emb = embed(rsn)
    # F.repeat along axis 0 repeats each row in place (numpy.repeat
    # semantics), giving one copy of an image's OCR embedding per candidate.
    yocr_emb = F.repeat(embed(ocr), n_candidates, axis=0)

    # Only the word-level attention weights are needed here; the sentence
    # encodings from model.lng_net are not, so that forward pass is skipped.
    _, a_att = word_attention(yocr_emb, ya_emb)
    _, r_att = word_attention(yocr_emb, yr_emb)
    return a_att, r_att
    

In [7]:
def get_result_html(img_id, a_att_array, r_att_array, raw_ocr, desc, sort_i,
                    img_base_url='http://localhost:8888/files/data/test_images/'):
    """Build an HTML report of the top-ranked candidate for each image.

    For every image, a left-aligned table with three rows is emitted:
    1. the image thumbnail;
    2. the top-ranked candidate description;
    3. the OCR words coloured by action attention (first line) and by reason
       attention (second line) for that candidate.

    Args:
        img_id: image file names, appended to ``img_base_url``.
        a_att_array: per-image action attention, indexed
            ``[candidate][word]``.
        r_att_array: per-image reason attention, with the same layout.
        raw_ocr: per-image list of OCR words.
        desc: per-image list of candidate descriptions.
        sort_i: per-image candidate indices in ranked order. ``s_i[0]`` is
            treated as the top-1 prediction.
        img_base_url: URL prefix for the images. The default serves
            ``data/test_images`` through a local Jupyter server on port 8888.

    Returns:
        str: the concatenated HTML tables (empty when there are no images).
    """
    # Attributes must be whitespace-separated. With a comma after align="..."
    # the HTML tokenizer reads ',width' as a separate attribute and the width
    # is dropped. '%%' survives %-formatting as a literal '%'.
    table_tmp = '''
    <table align="left" width="50%%">
      <tr>
        <th>%s</th>
      </tr>
      <tr>
        <th>%s</th>
      </tr>
      <tr>
        <th>%s</th>
      </tr>
    </table>
    '''
    desc_tmp = '<p style="display:inline;font-size:12pt;">%s</p>'

    tables = []
    for im_i, act_a, rsn_a, words, cands, ranks in zip(
            img_id, a_att_array, r_att_array, raw_ocr, desc, sort_i):
        top = ranks[0]  # index of the top-ranked candidate for this image
        img_html = '<img src="{}{}" width="200" height="200">'.format(
            img_base_url, escape(str(im_i)))
        # Descriptions are free text; escape so '<' / '&' render literally.
        desc_html = desc_tmp % escape(str(cands[top]))
        word_html = (get_word_vis_html(act_a[top], words) + '<br>'
                     + get_word_vis_html(rsn_a[top], words))
        tables.append(table_tmp % (img_html, desc_html, word_html))
    return ''.join(tables)

In [8]:
# Checkpoint to evaluate, and the GPU id to run on (None skips GPU setup).
model_dir, device = 'output/checkpoint/ocr+vis20180628-153515/', 0

In [9]:
# Restore the training-time arguments saved next to the checkpoint.
# `with` closes the file handle instead of leaking it.
with open(os.path.join(model_dir, 'args')) as args_file:
    args = json.load(args_file)

chainer.config.remove_stopwords = args['remove_stopwords']

test = DatasetOCR('test')

# +1 presumably reserves index 0 (tokenizer indices start at 1); this must
# match the vocabulary size used at training time.
vocab_size = len(test.tokenizer.word_index) + 1

text_net = args['text_net']
if text_net == 'cnn':
    lng_net = TextCNN(vocab_size, None)
elif text_net == 'lstm':
    lng_net = TextLSTM(vocab_size, None)
else:
    raise RuntimeError('invalid text_net')

h_size = args['h_size']
margin = args['margin']
model_name = args['model_name']

if model_name == 'ocr':
    model = NonVisualNet(lng_net, h_size=h_size, margin=margin)
elif model_name == 'ocr+vis':
    # NOTE(review): built with a fixed h_size=100 rather than args['h_size'].
    # Presumably this mirrors how the checkpoint was trained. Confirm before
    # changing it, since a size mismatch would make load_npz fail.
    att_net = AttentionNetWTL(h_size=100)
    model = Net(lng_net, att_net)
else:
    raise RuntimeError('invalid model_name: {!r}'.format(model_name))

chainer.serializers.load_npz(os.path.join(model_dir, 'model'), model)


load data/tokenizer.pickle


In [11]:
# Stand-alone HTML page wrapper with a single %s slot for the body.
# NOTE(review): not referenced by the cells shown here.
html_tmp = '''
 <!DOCTYPE html>
<html>
<head>
<title>Page Title</title>
</head>
<body>

%s

</body>
</html> 
'''

# Invert the tokenizer's {word: index} mapping into {index: word}.
reverse_word_map = {index: word for word, index in test.tokenizer.word_index.items()}
def get_raw_ocr(item, word_map=None):
    """Convert each example's OCR word indices back into words.

    Args:
        item: sequence of dataset examples. Element 4 of each example is
            its OCR word-index sequence.
        word_map: optional ``{index: word}`` mapping. Defaults to the
            module-level ``reverse_word_map`` built from the test tokenizer.
            Passing it explicitly avoids depending on notebook global state.

    Returns:
        list[list[str]]: one list of words per example, in input order.

    Raises:
        KeyError: if an index has no entry in ``word_map`` (e.g. index 0,
            if it were used for padding).
    """
    if word_map is None:
        word_map = reverse_word_map
    return [[word_map[idx] for idx in example[4]] for example in item]

# GPU setup is skipped entirely when device is None.
if device is not None:
    chainer.cuda.get_device_from_id(device).use()
    model.to_gpu()

results = {}
b_size = 100
# Candidate statements scored per image. The flat `dist` and attention
# outputs are reshaped to (n_images, n_candidates, ...) with it. This must
# match the repeat count used inside get_ocr_att (15).
n_candidates = 15
# Only the first `max_batches` batches are processed and rendered (the
# rendered HTML is large). Set to None to run over the whole test split.
max_batches = 1

with chainer.using_config('train', False), chainer.using_config('test', True), chainer.using_config('enable_backprop', False):
    for batch_idx, i in enumerate(range(0, len(test), b_size)):
        if max_batches is not None and batch_idx >= max_batches:
            break

        item = test[i: i + b_size]
        img_id = test.images[i:i+b_size]

        # Presumably fetches descriptions without the dataset's text cleaning,
        # for display.
        with chainer.using_config('clean_description', False):
            # Clamp to len(test), the same bound the slices above use. The
            # old bound len(test) - 1 dropped the split's last example, so
            # `desc` ended up shorter than `img_id` and zip() truncated it.
            desc = [test.get_row_answer(j) for j in range(i, min(len(test), i + b_size))]

        raw_ocr = get_raw_ocr(item)

        batch = my_converter(item, device=device)
        outputs = model.predict(*batch, layers=['dist'])
        dist = outputs['dist']
        dist.to_cpu()
        dist = np.reshape(dist.data.ravel(), (-1, n_candidates))

        # Ascending sort: column 0 is the candidate with the smallest `dist`,
        # which is used as the top-1 prediction below.
        sort_i = dist.argsort(axis=1)

        # OCR-word attention w.r.t. the action and reason statements.
        action, reason, ocr = batch[1], batch[2], batch[4]
        a_att, r_att = get_ocr_att(ocr, action, reason)
        a_att.to_cpu()
        r_att.to_cpu()

        wc = a_att.shape[-2]
        a_att_array = a_att.data.squeeze().reshape(-1, n_candidates, wc)
        r_att_array = r_att.data.squeeze().reshape(-1, n_candidates, wc)

        for im_i, des, s_i in zip(img_id, desc, sort_i):
            results[im_i] = des[s_i[0]]

        body = get_result_html(img_id, a_att_array, r_att_array, raw_ocr, desc, sort_i)
        display_html(body, raw=True)
        