In [1]:
# Load IPython's autoreload extension and use mode 2, so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import pickle

from IPython.display import display_html
from keras.models import load_model, Model
from matplotlib.colors import rgb2hex
import matplotlib.pyplot as plt
import numpy as np

from masterthesis.features.build_features import file_to_sequence
from masterthesis.utils import ATTENTION_LAYER, DATA_DIR, MODEL_DIR, RESULTS_DIR, load_split

Using TensorFlow backend.


In [3]:
def show_weighted(tokens, weights):
    """Render tokens as one HTML paragraph, coloring each token by its weight.

    Args:
        tokens: Iterable of token strings.
        weights: Iterable of numeric weights, paired with ``tokens`` by
            position. Each weight is passed to the ``coolwarm`` colormap; the
            caller in this notebook normalizes them to the range [0, 1].

    Returns:
        A string holding a ``<p>`` element with one colored ``<span>`` per
        token, joined by single spaces. If the two iterables differ in length
        the output stops at the shorter one (``zip`` semantics).
    """
    # Imported locally on purpose: the notebook later binds a global variable
    # named ``html``, which would shadow a module-level ``import html``.
    from html import escape

    cmap = plt.cm.coolwarm
    ents = ['<p style="max-width: 80ch">']
    for token, weight in zip(tokens, weights):
        col = rgb2hex(cmap(weight))
        # Tokens are raw text; escape them so characters such as '<' or '&'
        # cannot break the generated markup.
        ents.append('<span style="color: %s">%s</span>' % (col, escape(token)))
    ents.append('</p>')
    return ' '.join(ents)

In [4]:
def tokens_in_file(filepath):
    """Lazily yield every whitespace-separated token of a text file, in order.

    Args:
        filepath: Object with an ``open()`` method returning a text file
            handle (e.g. a ``pathlib.Path``).

    Yields:
        Each token as a string, line by line; blank lines yield nothing.
    """
    with filepath.open() as handle:
        for raw_line in handle:
            # str.split() without arguments already discards leading and
            # trailing whitespace, including the newline.
            yield from raw_line.split()

In [5]:
def file_with_attention(filepath, attention_model, w2i, maxlen=None, seq_len=700):
    """Build the attention-colored HTML paragraph for one text file.

    Args:
        filepath: Path of the text file to visualize.
        attention_model: Model whose ``predict`` returns, for a batch of one
            sequence, the per-position attention weights.
        w2i: Mapping used by ``file_to_sequence`` to encode tokens; it must
            contain the key ``'__UNK__'``.
        maxlen: If given, only the first ``maxlen`` weights (and therefore
            tokens) are rendered.
        seq_len: Sequence length passed to ``file_to_sequence``. Defaults to
            700, the value that was previously hard-coded.

    Returns:
        The HTML string produced by ``show_weighted``.
    """
    # NOTE(review): tokens are paired with x[0] by position, which assumes
    # file_to_sequence keeps the file's tokens at the start of the sequence
    # (i.e. no left padding) - confirm against its implementation.
    x = file_to_sequence(seq_len, filepath, w2i)[np.newaxis, :]
    weights = attention_model.predict(x)[0]

    # Min-max normalize to [0, 1] so the weights span the whole colormap.
    weights = weights - weights.min()
    peak = weights.max()
    if peak > 0:
        # Skip the division when all weights are equal; otherwise 0/0 would
        # turn every weight into NaN.
        weights = weights / peak

    if maxlen is not None:
        weights = weights[:maxlen]

    unk_idx = w2i['__UNK__']
    # Show 'UNK' in place of tokens that were encoded as the unknown index.
    tokens = (token if idx != unk_idx else 'UNK'
              for token, idx in zip(tokens_in_file(filepath), x[0]))
    return show_weighted(tokens, weights)

In [6]:
# Base name of the trained run; the model, vocabulary and results file names
# loaded below are all derived from it.
model_name = 'rnn_nli-25740789'
# Alternative run:
# model_name = 'rnn_nli-25717032'

In [7]:
# Load the vocabulary mapping and the trained model saved for `model_name`.
# NOTE: pickle.load can execute arbitrary code - only load files produced by
# trusted training runs.
with (MODEL_DIR / (model_name + '_model_w2i.pkl')).open('rb') as f:
    # The context manager closes the file handle, which was previously leaked.
    w2i = pickle.load(f)
model = load_model(str(MODEL_DIR / (model_name + '_model.h5')))

In [8]:
# Model that shares the loaded model's input but outputs the activations of
# the layer named by ATTENTION_LAYER instead of the final predictions.
attention_model = Model(inputs=model.input, outputs=model.get_layer(ATTENTION_LAYER).output)

In [9]:
dev = load_split('dev')

# NOTE: pickle.load can execute arbitrary code - only load trusted result files.
with (RESULTS_DIR / (model_name + '.pkl')).open('rb') as f:
    # The context manager closes the file handle, which was previously leaked.
    results = pickle.load(f)

# Keep only the dev rows whose predicted label matches the true label.
correct_predictions = dev.iloc[np.where(results.true == results.predictions)]
langs = ['russisk', 'vietnamesisk', 'engelsk', 'somali']
# For each language take the first correctly predicted row; raises IndexError
# if a language has no correct prediction.
sample_indices = [correct_predictions[correct_predictions.lang == lang].index[0]
                  for lang in langs]
samples = correct_predictions.loc[sample_indices]
display(samples)
filenames = [DATA_DIR / 'txt' / (fn + '.txt') for fn in samples.filename]

Unnamed: 0,age,cefr,filename,gender,lang,num_tokens,split,testlevel,title,topic
183,29,B2,h0186,kvinne,russisk,551,dev,Høyere nivå,Helse og livskvalitet,helse
1014,29,A2/B1,s0180,kvinne,vietnamesisk,380,dev,Språkprøven,Røyking,helse røyking
186,51,C1,h0189,kvinne,engelsk,523,dev,Høyere nivå,Helse og livskvalitet,helse
1450,21,A2/B1,s0621,kvinne,somali,284,dev,Språkprøven,En hyggelig opplevelse,opplevelse


In [10]:
# For every sample file: show a heading and the attention-colored text inline,
# and also write a standalone HTML page named '<stem>.html' to the working
# directory.
for filename in filenames:
    display_html('<h2>%s</h2>' % filename.stem, raw=True)
    html = file_with_attention(filename, attention_model, w2i, maxlen=300)
    # Write as UTF-8 explicitly: the page declares charset utf-8, so relying on
    # the platform default encoding could garble non-ASCII characters.
    with open(filename.stem + '.html', 'w', encoding='utf-8') as f:
        print("""
<!DOCTYPE html>
<html lang="nb">
<head>
  <meta charset="utf-8">
  <title>%s</title>
  <style>
  body {
    font-family: "Helvetica Neue", Helvetica, Arial, sans-serif;
    font-size: 14px;
    line-height: 20px;
    text-align: left;
  }
  </style>
</head>
<body>
""" % filename.stem, file=f)
        print(html, file=f)
        print('</body></html>', file=f)
    display_html(html, raw=True)