# Generation test of LK-C-Model
### imports

In [1]:
from blocks.model import Model
from conllutil import CoNLLData
from itertools import chain
%matplotlib inline
import matplotlib.pyplot as plt
from network import *
from numpy import array, hstack, load, save, vstack, zeros
from os import path
from pandas import factorize
from random import randint
from scipy.stats import pearsonr
from stoogeplot import hinton_diagram
from theano import function
from theano.tensor.sharedvar import SharedVariable
from theano.tensor import matrix, TensorType
from util import StateComputer

### constants

In [2]:
# Significance level: correlations whose p-value is not below this are discarded.
ALPHA = 0.05

# Trained model and its index -> token vocabulary.
MODEL_FILE = './models/hdt/hdt-ncs-eos-np-35-7-1.pkl'
IX_2_TOK_FILE = './data/hdt-ncs-eos-np-35-7-1_ix2tok.npy'

# Hamburg Dependency Treebank (CoNLL) and the folder for NumPy parameter dumps.
HDT_DIR = '../datasets/hdt/hamburg-dependency-treebank-conll/'
NP_FOLDER = './data/np'

### Build model

In [3]:
# The vocabulary file holds a dict wrapped in a 0-d object array (`.item()`
# unwraps it). Object arrays are pickled, and NumPy >= 1.16.3 refuses to load
# them unless allow_pickle=True is passed explicitly.
# SECURITY: unpickling can execute arbitrary code -- only load trusted files.
ix2tok = load(IX_2_TOK_FILE, allow_pickle=True).item()
nt = Network(NetworkType.LSTM, input_dim=len(ix2tok))
# NOTE(review): MODEL_FILE is a .pkl; presumably unpickled here -- trusted input only.
nt.set_parameters(MODEL_FILE)

### Building generator

In [4]:
model = Model(nt.generator.generate(n_steps=nt.x.shape[0], batch_size=nt.x.shape[1]))
param_dict = model.get_parameter_dict()

# Shared variables holding each layer's initial hidden state and cell state.
# The generation helpers below overwrite them while sampling.
init_state_0 = param_dict['/sequencegenerator/with_fake_attention/transition/layer#0.initial_state#0']
init_cells_0 = param_dict['/sequencegenerator/with_fake_attention/transition/layer#0.initial_cells']
init_state_1 = param_dict['/sequencegenerator/with_fake_attention/transition/layer#1.initial_state#1']
init_cells_1 = param_dict['/sequencegenerator/with_fake_attention/transition/layer#1.initial_cells']
init_state_2 = param_dict['/sequencegenerator/with_fake_attention/transition/layer#2.initial_state#2']
init_cells_2 = param_dict['/sequencegenerator/with_fake_attention/transition/layer#2.initial_cells']

# Snapshot of the values loaded from the model, keyed by layer index, so they
# can be written back after generation has modified the shared variables.
reset_values = {
    layer: (state_var.get_value(), cells_var.get_value())
    for layer, (state_var, cells_var) in enumerate([
        (init_state_0, init_cells_0),
        (init_state_1, init_cells_1),
        (init_state_2, init_cells_2),
    ])
}

gen_func = model.get_theano_function(allow_input_downcast=True)

In [5]:
# Inverse vocabulary (token -> index), built by pairing values with their keys.
tok2ix = dict(zip(ix2tok.values(), ix2tok.keys()))
sc = StateComputer(nt.cost_model, tok2ix)

# Before we continue: optionally save the model parameters in a lightweight NumPy format (the export cell below is commented out)

In [6]:
#save(path.join(NP_FOLDER, 'param_dict.npy'), array(param_dict))
#for k in param_dict:
#    save(path.join(NP_FOLDER, k.replace('/', '-')), param_dict[k].get_value())

### Generation procedure

In [7]:
def reset_generator():
    """Write the snapshotted initial states/cells back into the generator.

    Restores every layer's initial-state and initial-cells shared variable to
    the value stored in ``reset_values`` when the generator was built, undoing
    the per-step updates made during generation.

    Assigning the arrays to plain local names would only shadow the
    module-level shared variables and change nothing. The values must be
    pushed in via ``set_value``.
    """
    layer_vars = (
        (init_state_0, init_cells_0),
        (init_state_1, init_cells_1),
        (init_state_2, init_cells_2),
    )
    for layer, (state_var, cells_var) in enumerate(layer_vars):
        saved_state, saved_cells = reset_values[layer]
        state_var.set_value(saved_state)
        cells_var.set_value(saved_cells)
    
def init_zero():
    """Set every layer's initial hidden state and initial cell state to zeros.

    Each shared variable is zeroed with its *own* current shape and dtype,
    so layers of different widths are handled correctly.
    """
    # Not sure this is always a good idea: it overwrites the initial values
    # that were loaded from the model.
    layer_vars = (init_state_0, init_cells_0,
                  init_state_1, init_cells_1,
                  init_state_2, init_cells_2)
    for shared_var in layer_vars:
        current = shared_var.get_value()
        shared_var.set_value(zeros(current.shape, dtype=current.dtype))
    
def generate_sequence(start, reset_func, max_steps=None):
    """Sample a token sequence from the generator, beginning with ``start``.

    The generator is fed one token index at a time. After each step, the
    returned LSTM states and cells are written back into the initial-state
    shared variables, so the next call continues from the current state.

    Parameters
    ----------
    start : str
        First token; must be a key of ``tok2ix``.
    reset_func : callable
        Called with no arguments once generation ends, also if it raises.
        Used to restore the generator's initial states.
    max_steps : int or None, optional
        Upper bound on the number of generated tokens. ``None`` (the default)
        keeps sampling until ``'<EOS>'`` is produced, which is not guaranteed
        to happen.

    Returns
    -------
    str
        The tokens joined by single spaces, without a terminating ``'<EOS>'``.
    """
    layer_vars = (init_state_0, init_cells_0,
                  init_state_1, init_cells_1,
                  init_state_2, init_cells_2)
    seq = [start]
    ix = array([[tok2ix[start]]])
    n_steps = 0
    try:
        while seq[-1] != '<EOS>' and (max_steps is None or n_steps < max_steps):
            state_0, cells_0, state_1, cells_1, state_2, cells_2, ix, _costs = gen_func(ix)
            new_values = (state_0, cells_0, state_1, cells_1, state_2, cells_2)
            for shared_var, new_value in zip(layer_vars, new_values):
                # Keep the first time step of the first (only) batch entry.
                shared_var.set_value(new_value[0][0])
            seq.append(ix2tok[ix[0][0]])
            n_steps += 1
    finally:
        # Restore the initial states even if generation failed midway.
        reset_func()

    # Strip the end marker only when generation actually produced it.
    if seq[-1] == '<EOS>':
        seq = seq[:-1]
    return ' '.join(seq)


In [8]:
#print(generate_sequence('ein', reset_generator))  # good results 500 - 1000

In [11]:
# Preview the first two word sequences of the loaded treebank part.
# NOTE(review): `cd` is only created in the "Step 1" cell further down
# (In [10]); that cell must run before this one on a fresh kernel.
cd.wordsequences()[:2]

[['begleitet',
  'von',
  'marktgerüchten',
  'über',
  'den',
  'bevorstehenden',
  'konkurs',
  'von',
  'amazon',
  'setzt',
  'die',
  'aktie',
  'des',
  'online-händlers',
  'ihre',
  'talfahrt',
  'fort'],
 ['an',
  'der',
  'nasdaq',
  'rutschte',
  'das',
  'papier',
  'am',
  'gestrigen',
  'mittwoch',
  'kurz',
  'sogar',
  'unter',
  'die',
  'marke',
  'von',
  '10',
  'us-dollar',
  'in',
  'frankfurt',
  'pendelt',
  'sie',
  'am',
  'heutigen',
  'donnerstag',
  'zwischen',
  'und',
  'euro']]

# Check correlations with POS
## Step 1: Read all sentences from part A and store activations (currently only for the first two sentences)

__Comment:__ Only part A is used, since its annotations are hand-made and manually checked.

In [10]:
eos_ix = tok2ix['<EOS>']
# Lower-cased part A only; min_len/max_len presumably bound sentence length
# in tokens -- TODO confirm against CoNLLData.
cd = CoNLLData(HDT_DIR, ['part_A.conll'], tok2ix,
               word_transform=str.lower, lazy_loading=True,
               min_len=7, max_len=35)
# Each sentence as vocabulary indices, terminated by the <EOS> index.
sentences_ix = [[tok2ix[word] for word in seq] + [eos_ix]
                for seq in cd.wordsequences()]

In [14]:
# NOTE(review): index 2 of state_var_names is assumed to select the cell
# variable of interest -- confirm against StateComputer.
cell_name = sc.state_var_names[2]
# For testing purposes only the first two sequences are read; try again later
# with more. Row 0 of each per-sentence result is dropped before stacking.
activations = vstack([
    sc.read_single_sequence(sentence)[cell_name][1:, :]
    for sentence in (sentences_ix[0], sentences_ix[1])
])

### Correlate for each pos tag

In [None]:
# To keep computation low we only do this for a small set of tags; there is
# no time left to check the rest anyway.
# tagset = set(chain(*cd.possequences()))
tagset = {'ART', 'ADJA', 'NN'}

# POS tags of the two sentences whose activations were read above, flattened so
# that entry i is meant to match row i of `activations` (assumes
# read_single_sequence(...)[1:] yields one row per token -- TODO confirm).
# Computed once instead of once per (tag, cell) pair, since it depends on neither.
flat_tags = list(chain(*cd.possequences()[0:2]))

# NOTE: each `activation` below is one cell's activation concatenated over all
# read sentences; otherwise there wouldn't be enough data points for sentences
# with 7 <= length <= 35. This is not the cleanest imaginable approach, since
# the activations of different sequences are unrelated, but it still might
# provide some insights.
pos_corrs = {}
for tag in tagset:
    # 1 where the token carries `tag`, 0 otherwise.
    indicator = [1 if ptag == tag else 0 for ptag in flat_tags]
    if len(set(indicator)) < 2:
        # Constant indicator (tag absent or on every token): Pearson's r is
        # undefined (NaN), which the significance test below would map to 0.0
        # anyway, so skip the per-cell work and its warnings.
        pos_corrs[tag] = zeros(activations.shape[1])
        continue
    crlist = []
    for activation in activations.transpose():
        r, p_value = pearsonr(activation, indicator)
        # Keep only correlations that are significant at level ALPHA.
        crlist.append(r if p_value < ALPHA else 0.0)
    pos_corrs[tag] = array(crlist)

### Extract 10 highest correlations for POS-Tags ART, NN, ADJA, since they might be relevant for the DET relation

In [None]:
# Per tag: cell indices sorted by ascending |correlation|, so the strongest
# cells sit at the end of each index array.
top_corr_ix = {tag: abs(pos_corrs[tag]).argsort() for tag in tagset}

### Let's plot that for the tokens of the first two sentences (~50 tokens), 10 strongest cells only

__NOTE:__ The plots are aligned this time, i.e. a displayed activation occurred after reading the word given in the label below!

#### ART

In [None]:
# Hinton diagram of the tokens of the first two sentences against the 10 cells
# whose activation correlates most strongly (in |r|) with the ART tag.
f, ax = plt.subplots(figsize=(30, 50))
# TODO remove index, just debugging!
# NOTE(review): tpl[1] / tpl[4] are presumably the word form and POS columns
# of each CoNLL token row -- confirm.
xlabels = [' / '.join((tpl[1], tpl[4])) for tpl in chain(*cd.sequences()[:2])]
strongest_art_cells = top_corr_ix['ART'][-10:]
x = activations[:, strongest_art_cells]
h = hinton_diagram(x, ax=ax)
ax.set_xticks(range(len(xlabels)))
ax.set_xticklabels(xlabels, rotation=90, fontdict={'fontsize': 22})
f.subplots_adjust(bottom=.2)
