In [1]:
# Notebook setup: NumPy plus the Bokeh pieces used by the plotting helper below.
import bokeh  # NOTE(review): the top-level `bokeh` module is not referenced in the visible cells
import numpy as np
from bokeh.plotting import figure, show
from bokeh.io import output_notebook, export_svgs  # NOTE(review): export_svgs is imported but never called in the visible cells
from bokeh.models import LabelSet, ColumnDataSource
output_notebook()  # send Bokeh output inline into this notebook
import os
# NOTE(review): presumably lets a Qt-based headless tool run without a display
# (e.g. for SVG/PNG export) — confirm; nothing in the visible cells reads it.
os.environ['QT_QPA_PLATFORM']='offscreen'

In [30]:
from bokeh.models import LinearColorMapper

def plot_matrix(matrix, x_tick_text=None, y_tick_text=None,
                low=0.0, high=1.0, fmt='{:.2f}'):
    """Render a 2-D matrix as a Bokeh heatmap annotated with each cell's value.

    Cell colours come from a ``LinearColorMapper`` over ``[low, high]`` with the
    'Blues8' palette. Row 0 of ``matrix`` is drawn at the bottom of the figure,
    so ``y_tick_text[0]`` labels the lowest row.

    Args:
        matrix: 2-D array-like of shape (n_rows, n_cols); converted with
            ``np.asarray``.
        x_tick_text: optional sequence with at least ``n_cols`` labels, placed
            at the column centres (extra labels are ignored).
        y_tick_text: optional sequence with at least ``n_rows`` labels, placed
            at the row centres (extra labels are ignored).
        low: value mapped to the first palette colour (default 0.0).
        high: value mapped to the last palette colour (default 1.0).
        fmt: ``str.format`` pattern used for the per-cell value labels.

    Returns:
        The configured ``bokeh.plotting.figure``; the caller decides whether to
        show or export it.

    Raises:
        ValueError: if ``matrix`` is not two-dimensional.
    """
    matrix = np.asarray(matrix)
    if matrix.ndim != 2:
        raise ValueError('plot_matrix expects a 2-D matrix, got shape {}'
                         .format(matrix.shape))
    n_rows, n_cols = matrix.shape

    # Canvas grows with the matrix: 75px per column, 30px per row, plus a margin
    # for the axis labels.
    plot = figure(width=n_cols * 75 + 150,
                  height=n_rows * 30 + 50,
                  x_range=(0, n_cols),
                  y_range=(0, n_rows),
                  toolbar_location='below')

    # Cell i spans [i, i + 1), so its centre is at i + 0.5. Plain Python floats
    # are used for the tick positions / override keys.
    col_centers = (np.arange(n_cols) + 0.5).tolist()
    row_centers = (np.arange(n_rows) + 0.5).tolist()

    # One tick per cell centre, relabelled with the caller's text.
    if x_tick_text is not None:
        plot.xaxis.ticker = col_centers
        plot.xaxis.major_label_overrides = {
            center: x_tick_text[i] for i, center in enumerate(col_centers)}
    if y_tick_text is not None:
        plot.yaxis.ticker = row_centers
        plot.yaxis.major_label_overrides = {
            center: y_tick_text[i] for i, center in enumerate(row_centers)}
    plot.xaxis.major_label_text_font_size = "12pt"
    plot.yaxis.major_label_text_font_size = "12pt"

    # Coloured grid: the whole matrix as a single image covering the plot area.
    color_mapper = LinearColorMapper(high=high, low=low, palette='Blues8')
    plot.image(x=[0], y=[0], dw=[n_cols], dh=[n_rows],
               image=[matrix],
               color_mapper=color_mapper)

    # Bokeh's 'Blues8' runs dark -> light, so values above the midpoint of the
    # colour range sit on light cells (black text) and the rest on dark cells
    # (white text). Tying the threshold to low/high keeps text readable when
    # the range is changed.
    threshold = (low + high) / 2.0
    xv, yv = np.meshgrid(col_centers, row_centers)
    source = ColumnDataSource({
        'x': xv.ravel(),
        'y': yv.ravel(),
        'text': [fmt.format(value) for value in matrix.ravel()],
        'color': np.where(matrix > threshold, 'black', 'white').ravel().tolist(),
    })

    numbers = LabelSet(x='x', y='y', text='text',
                       text_align='center',
                       text_baseline='middle',
                       text_font_size='12pt',
                       text_color='color',
                       source=source)
    plot.add_layout(numbers)
    return plot

In [5]:
% cd ../src/

/home/tinray/DSTC7/src


In [3]:
import pickle, torch, json

In [71]:
def _load_pickle(path):
    """Open ``path`` in binary mode and return the unpickled object.

    NOTE(review): unpickling can execute arbitrary code — only point this at
    trusted, locally produced artifacts such as the files below.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


# Model dump (its keys are inspected in a later cell).
dump = _load_pickle('../models/task1/advising/recurrent_transformer_pool/dump.pkl.5')

# Preprocessed dev split.
# NOTE(review): presumably raising min_context_len disables context
# truncation/filtering inside the dataset object — confirm.
data = _load_pickle('../data/task1/advising_dev_p.pkl')
data.min_context_len = 10000

# Embedding object; only its word -> id dictionary is used here.
embedding = _load_pickle('../data/task1/advising_embeddings_e1.pkl')

# Inverse vocabulary: token id -> word.
rev_dict = dict((idx, word) for word, idx in embedding.word_dict.items())
    
def get_text(data, index, vocab=None):
    """Decode example ``index`` of ``data`` into utterance tokens and option tokens.

    The flat ``context`` id sequence is split at the positions listed in
    ``utterance_ends`` (each an inclusive end index into the context); tokens
    after the last end position are not returned.

    Args:
        data: indexable dataset; ``data[index]`` must provide ``'context'``
            (token ids), ``'utterance_ends'`` (inclusive end positions into the
            context) and ``'options'`` (a sequence of token-id sequences).
        index: position of the example in ``data``.
        vocab: optional id -> word mapping; defaults to the notebook-level
            ``rev_dict`` when omitted.

    Returns:
        ``(utterances, options)``: a list of word lists, one per utterance, and
        a list of word lists, one per option.

    Raises:
        KeyError: if a token id is missing from a dict ``vocab``.
    """
    if vocab is None:
        vocab = rev_dict
    # Fetch the example once: a dataset that builds items on the fly (possibly
    # with randomness) would otherwise be re-evaluated per field, and the
    # fields could come from different draws.
    example = data[index]
    # list() so tuples/arrays also work with the `[-1] + ends` concatenation.
    ends = list(example['utterance_ends'])
    context = [vocab[w] for w in example['context']]
    # Utterance k spans context[previous_end + 1 : end + 1]; -1 seeds the first.
    utterances = [context[start + 1:end + 1]
                  for start, end in zip([-1] + ends, ends)]
    options = [[vocab[w] for w in option] for option in example['options']]
    return utterances, options


In [20]:
# List the top-level entries stored in the model dump.
dump.keys()

dict_keys(['connection'])

In [75]:
# Visualise the "connection" weights for one dev-set example.
# NOTE(review): the original cell was tagged only with "# 9"; its meaning is unclear.
index, offset = 0, 7
# NOTE(review): assumes dump['connection'] is ordered like the dev set, so the
# same `index` addresses both — confirm.
attn = dump['connection'][index][offset].numpy()
utterances, options = get_text(data, index)

# Columns: tokens of utterance `offset` plus a trailing 'self' slot;
# rows: tokens of the following utterance. attn[0, 0] is the matrix drawn.
col_labels = utterances[offset] + ['self']
row_labels = utterances[offset + 1]
plot = plot_matrix(attn[0, 0], x_tick_text=col_labels, y_tick_text=row_labels)

plot.output_backend = "svg"  # render with Bokeh's SVG backend instead of canvas
show(plot)