# Attention-based Neural Machine Translation Models as Feature Functions in Moses

</br></br>
##### Marcin Junczys-Dowmunt, Tomasz Dwojak

**Neural Machine Translation by Jointly Learning to Align and Translate**

Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio

* http://arxiv.org/abs/1409.0473
* http://arxiv.org/pdf/1409.0473v6.pdf

And an open source implementation in GroundHog:

https://github.com/lisa-groundhog/GroundHog

## Implementation problems

* Standard Moses FF-interface unusable 
* Results in millions of queries to GPU
* **No. of GPU queries is proportional to no. of hypotheses**

## Possible Solution:

* Add step **before** hypothesis expansion that collects all hypotheses and possible extensions
* Pre-calculate probabilities for all possible extensions:
   * Assemble all state vectors for the hypotheses in a stack into one matrix (each duplicated by the number of extensions per hypothesis)
   * Perform **one step** per target word position
* Treat pre-calculated values like a static language model
* Query within normal FF-interface
* **No. of GPU queries is proportional to no. of stacks** 
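
The batching idea above can be sketched in NumPy. This is only an illustration under assumptions: the toy sizes, the random matrix `W_out` (standing in for one decoder step on the GPU), and the function names `score_one_by_one` and `score_batched` are hypothetical and are not taken from the Moses or GroundHog code.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy sizes: 6 hypotheses in a stack, state size 8, vocabulary of 20 words.
num_hyps, state_dim, vocab = 6, 8, 20
W_out = rng.normal(size=(state_dim, vocab))  # stand-in for one decoder step

def log_softmax(x):
    # Row-wise log-softmax, shifted by the row maximum for numerical stability.
    x = x - x.max(axis=1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=1, keepdims=True))

def score_one_by_one(states):
    # One "device query" per hypothesis: the cost grows with the number of hypotheses.
    queries, rows = 0, []
    for s in states:
        rows.append(log_softmax(s[None, :].dot(W_out))[0])
        queries += 1
    return np.vstack(rows), queries

def score_batched(states):
    # All state vectors of a stack assembled into one matrix: one query per stack.
    return log_softmax(states.dot(W_out)), 1

states = rng.normal(size=(num_hyps, state_dim))
slow, n_slow = score_one_by_one(states)
fast, n_fast = score_batched(states)

# The pre-calculated table can then be read like a static language model:
# fast[hypothesis_index, word_id] is the log-probability of that extension.
print(n_slow, n_fast, np.allclose(slow, fast))
```

Both variants produce the same table; only the number of round trips to the device differs.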

## First attempt (MTM2015): Embedding a Python interpreter in C++

* https://github.com/emjotde/mosesdecoder/blob/oldstuff/moses/FF/NMT/NMT_Wrapper.cpp#L122-L161
* Slow
* Unstable
* Would not scale to multiple GPUs or Multi-Threading

## But seems to work!

System trained on IWSLT 2015 data (en-de), evaluated on Test-2013.

</br>


### Stand-alone
* 21.5 (Vanilla Moses)
* 25.6 (our Moses setup)
* 25.8 (3-ensemble Groundhog NMT)

</br>
### Combinations:
* 27.2 (rescoring a 1000-best list)
* 28.3 (Moses with NMT-FF with Stack-Size=30)

# Re-implementation
## Python Numpy (to understand what's going on)

In [1]:
import numpy as np
data = np.load("/home/marcinj/Badania/best_nmt/search_model.npz")
for key in data:
    print(key, data[key].shape)

W_0_dec_approx_embdr (30001, 620)
b_0_dec_initializer_0 (1000,)
b_0_dec_hid_readout_0 (1000,)
W_0_dec_repr_readout (2000, 1000)
b_dec_deep_softmax (30001,)
D_dec_transition_0 (1000, 1)
B_dec_transition_0 (1000, 1000)
R_dec_transition_0 (1000, 1000)
W_back_enc_transition_0 (1000, 1000)
R_back_enc_transition_0 (1000, 1000)
R_enc_transition_0 (1000, 1000)
W_0_dec_update_embdr_0 (620, 1000)
W_0_dec_initializer_0 (1000, 1000)
W_0_dec_input_embdr_0 (620, 1000)
b_0_dec_input_embdr_0 (1000,)
G_dec_transition_0 (1000, 1000)
W_0_back_enc_reset_embdr_0 (620, 1000)
W_0_dec_dec_updater_0 (2000, 1000)
G_enc_transition_0 (1000, 1000)
W_0_enc_input_embdr_0 (620, 1000)
W_dec_transition_0 (1000, 1000)
W2_dec_deep_softmax (620, 30001)
W_enc_transition_0 (1000, 1000)
b_0_enc_approx_embdr (620,)
W_0_dec_dec_inputter_0 (2000, 1000)
W_0_enc_reset_embdr_0 (620, 1000)
b_0_back_enc_input_embdr_0 (1000,)
W_0_dec_reset_embdr_0 (620, 1000)
W_0_dec_prev_readout_0 (620, 1000)
W_0_dec_hid_readout_0 (1000, 1000)
W_0_d

# Encoder

## Common

* $\overline{E}$ - `W_0_enc_approx_embdr` - the same for both directions, shape ${m \times K_x}$, where $m = 620$ and $K_x = 30001$
* $\overline{b}_{\bar{E}}$ - `b_0_enc_approx_embdr`, shape $m \times 1$, bias for `W_0_enc_approx_embdr`

## Forward

* $\overrightarrow{W}$ - `W_0_enc_input_embdr_0` - weights for 0th hidden layer, left-to-right, shape $n\times m$, where $n=1000$
* $\overrightarrow{W}_z$ - `W_0_enc_update_embdr_0` - GRU update, shape $n\times m$
* $\overrightarrow{W}_r$ - `W_0_enc_reset_embdr_0` - GRU reset, shape $n\times m$
* $\overrightarrow{U}$ - `W_enc_transition_0`, shape $n \times n$
* $\overrightarrow{U}_z$ - `G_enc_transition_0`, shape $n \times n$
* $\overrightarrow{U}_r$ - `R_enc_transition_0`, shape $n \times n$
* $\overrightarrow{b}_{\overrightarrow{W}}$ - `b_0_enc_input_embdr_0`, shape $n \times 1$, bias for `W_0_enc_input_embdr_0`

## Backward
Analogous, with the `back` infix (e.g. `W_0_back_enc_input_embdr_0`)

## Calculations

Formulae taken from Bahdanau *et al.* (2014) with re-added biases:

$$
\renewcommand{\ora}[1]{\overrightarrow{#1}}
\renewcommand{\ola}[1]{\overleftarrow{#1}}
\ora{h}_i = \left\{
\begin{array}{ll}
(1 - \ora{z}_i) \circ \ora{h}_{i-1} + \ora{z}_i \circ \ora{\underline{h}}_i & \mathrm{, if } i > 0 \\
0 & \mathrm{, if } i = 0 
\end{array}
\right.
$$

where

$$
\begin{eqnarray}
\ora{\underline{h}}_i &=& \tanh\left(\ora{W}(\overline{E}x_i+\overline{b}) + \ora{b}_{\ora{W}} +\ora{U}\left[\ora{r}_i \circ \ora{h}_{i-1}\right]\right)\\
\ora{z}_i &=& \sigma\left(\ora{W}_z(\overline{E}x_i+\overline{b})+\ora{U}_z\ora{h}_{i-1}\right)\\
\ora{r}_i &=& \sigma\left(\ora{W}_r(\overline{E}x_i+\overline{b})+\ora{U}_r\ora{h}_{i-1}\right)\\
\end{eqnarray}
$$

The other direction works analogously; only the direction of the arrow changes. Implementation-wise, we reverse the input sequence, use the matrices for the other direction, and reverse the resulting states row-wise. Then we get:

$$
h_i = \left[
\begin{array}{c}
\ora{h}_i \\
\ola{h}_i
\end{array}
\right]
$$

# Decoder

### Embeddings

* $E$ - `W_0_dec_approx_embdr`, output embeddings, shape $m \times K_y$
* $b$ - `b_0_dec_approx_embdr`, bias for the embeddings, shape $m \times 1$

### RNN and GRU

* $W_s$ - `W_0_dec_initializer_0`, weight matrix used to calculate the initial state of the decoder
* $b_{W_s}$ - `b_0_dec_initializer_0`, bias for the initial state weight matrix
* $W$ - `W_0_dec_input_embdr_0` - shape $n\times m$
* $b_W$ - `b_0_dec_input_embdr_0`, bias
* $W_z$ - `W_0_dec_update_embdr_0` - GRU update, shape $n\times m$
* $W_r$ - `W_0_dec_reset_embdr_0` - GRU reset, shape $n\times m$
* $U$ - `W_dec_transition_0`, shape $n \times n$
* $U_z$ - `G_dec_transition_0`, shape $n \times n$
* $U_r$ - `R_dec_transition_0`, shape $n \times n$
* $C$ - `W_0_dec_dec_inputter_0`, shape $n \times 2n$
* $C_z$ - `W_0_dec_dec_updater_0`, shape $n \times 2n$
* $C_r$ - `W_0_dec_dec_reseter_0`, shape $n \times 2n$

### Alignment model

$n^\prime=1000$ is the number of neurons in the alignment model.

* $v_\alpha$ - `D_dec_transition_0`, shape $n^\prime \times 1$
* $W_\alpha$ - `B_dec_transition_0`, shape $n^\prime \times n$
* $U_\alpha$ - `A_dec_transition_0`, shape $n^\prime \times 2n$

### Softmax

$l = 500$, size of hidden softmax layer, and $W_o = W_o^{(2)}W_o^{(1)}$

* $W_{o}^{(1)}$ - `W1_dec_deep_softmax`, shape $m\times l$
* $W_{o}^{(2)}$ - `W2_dec_deep_softmax`, shape $K_y\times m$
* $b_{W_o}$ - `b_dec_deep_softmax`, bias, shape $K_y\times 1$.
* $U_o$ - `W_0_dec_hid_readout_0`, shape $2l \times 2l$
* $b_{U_o}$ - `b_0_dec_hid_readout_0`, shape $2l \times 1$
* $V_o$ - `W_0_dec_prev_readout_0`, shape $2l \times m$
* $C_o$ - `W_0_dec_repr_readout`, shape $2l \times 2n$

## Calculations

### GRU

$$
s_i = (1-z_i) \circ s_{i-1} + z_i \circ \tilde{s}_i \qquad s_0 = \tanh\left(W_s\ola{h}_1 + b_{W_s}\right)
$$

where

$$
\begin{eqnarray}
\tilde{s}_i &=& \tanh\left(W(Ey_i+b) + b_W + U \left[r_i \circ s_{i-1}\right] +Cc_i\right) \\
z_i &=& \sigma\left(W_z(Ey_i+b)+U_zs_{i-1} + C_zc_i \right)\\
r_i &=& \sigma\left(W_r(Ey_i+b)+U_rs_{i-1} + C_rc_i \right)\\
\end{eqnarray}
$$

### Attention Model

The context vector is calculated as a weighted sum of the encoder states:

$$
c_i = \sum_{j=1}^{T_x} \alpha_{ij}h_j
$$

where

$$ 
\begin{eqnarray}
\alpha_{ij} &=& \frac{\exp(e_{ij})}{\sum_{k=1}^{T_x}\exp(e_{ik})} \\ 
e_{ij} &=& v_\alpha^T \tanh\left(W_{\alpha}s_{i-1} + U_{\alpha}h_j\right) 
\end{eqnarray}
$$
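
The attention equations above can be checked on toy dimensions. The following is a minimal sketch, assuming random parameters and hypothetical small sizes; the matrices are stored transposed (row-vector convention), which is an assumption made here so that plain `dot` calls work.

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical toy sizes: T_x = 5 source positions, n = 4, n' = 3, batch of 2 hypotheses.
T_x, n, n_prime, batch = 5, 4, 3, 2
h = rng.normal(size=(T_x, 2 * n))        # encoder states h_j, one per row
s_prev = rng.normal(size=(batch, n))     # decoder states s_{i-1}, one per row
U_a = rng.normal(size=(2 * n, n_prime))  # stored transposed, so h.dot(U_a) works
W_a = rng.normal(size=(n, n_prime))      # stored transposed, so s_prev.dot(W_a) works
v_a = rng.normal(size=(n_prime,))

# e[b, j] = v_a^T tanh(W_a s_{i-1} + U_a h_j) for every hypothesis b and position j.
pre = s_prev.dot(W_a)[:, None, :] + h.dot(U_a)[None, :, :]  # (batch, T_x, n')
e = np.tanh(pre).dot(v_a)                                    # (batch, T_x)

# alpha: softmax over source positions; c: weighted sum of the encoder states.
shifted = e - e.max(axis=1, keepdims=True)
alpha = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
c = alpha.dot(h)                                             # (batch, 2n)

print(alpha.sum(axis=1), c.shape)
```

Each row of `alpha` is a distribution over source positions, so each context vector is a convex combination of the encoder states.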

### Deep Softmax
And deep softmax (here $W_o = W_o^{(2)}W_o^{(1)}$):

$$
p(y_i|s_{i-1},y_{i-1},c_i) = \textrm{softmax}\left(y_i^T\left(W_ot_i + b_{W_o}\right)\right) 
$$

where ($l = 500$)

$$
t_i = \left[\max \left\{ \tilde{t}_{i,2j-1},\tilde{t}_{i,2j} \right\}\right]_{j=1,\ldots,l}^T
$$

$$
\tilde{t}_i = U_os_{i-1}+b_{U_o}+V_o(Ey_{i-1} + b)+C_oc_i 
$$

where $y_0 = 0$
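
The maxout step that turns $\tilde{t}_i$ into $t_i$ can be illustrated on a small example. This is a sketch with made-up numbers and $2l = 6$ instead of 1000; it only shows how the pairs $(2j-1, 2j)$ map to 0-based even and odd columns.

```python
import numpy as np

# Hypothetical pre-activations for two hypotheses, with 2l = 6 (so l = 3).
t_tilde = np.array([[ 0.5, -1.0, 2.0, 2.5, -0.3, -0.1],
                    [-2.0,  0.0, 1.0, 1.0,  4.0,  3.0]])

# Pair up neighbouring entries (1,2), (3,4), (5,6) and keep the larger of each pair.
# In 0-based indexing these are the even columns against the odd columns.
t = np.maximum(t_tilde[:, 0::2], t_tilde[:, 1::2])

print(t)
```

The result has half as many columns as the input, one per pair.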

## Pure Numpy Implementation

In [7]:
#%%writefile bahdanau.py

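from __future__ import print_function  # must precede all other imports

import numpy as np

# Logistic sigmoid (kept under the name `logit` because the later cells call it that)
def logit(X):
    return 1.0 / (1.0 + np.exp(-X))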

def softmax(X, ax=0):
    expX = np.exp(X)
    expXsum = np.sum(expX, axis=ax)
    return (expX / expXsum)

def batchAndMask(sents):
    maxLength = max(len(s) for s in sents)
    sentsPadded = [np.pad(np.copy(s), (0, maxLength-len(s)), mode="constant") 
                   for s in sents]
    batch = np.vstack(sentsPadded)
    mask = batch != 0
    return batch, mask

In [8]:
#%%writefile -a bahdanau.py

class Encoder:
    class Embeddings:
        def __init__(self, data):
            self.E  = data["W_0_enc_approx_embdr"]
            self.EB = data["b_0_enc_approx_embdr"].T

        def Lookup(self, i):
            return self.E[i] + self.EB
    
    class RNN:
        def __init__(self, data):
            self.W   = data["W"]
            self.B   = data["B"]
            self.U   = data["U"]
            self.Wz  = data["Wz"]
            self.Uz  = data["Uz"]
            self.Wr  = data["Wr"]
            self.Ur  = data["Ur"]

        def InitializeState(self, batchSize=1):
            H0 = np.zeros((batchSize, 1000))
            return H0
        
        def GetNextState(self, embd, prevState):
            Zi  = logit(embd.dot(self.Wz) + prevState.dot(self.Uz))
            Ri  = logit(embd.dot(self.Wr) + prevState.dot(self.Ur))
            Hi_ = np.tanh(embd.dot(self.W) + self.B + (Ri * prevState).dot(self.U)) 
            Hi  = (1.0 - Zi) * prevState + Zi * Hi_
            return Hi
        
        def GetContext(self, embeddings):
            states = []
            prevState = self.InitializeState()
            for embd in embeddings:
                state = self.GetNextState(embd, prevState)
                states.append(state)
                prevState = state
            return states
    
    def __init__(self, data):
        self.embeddings = self.Embeddings(data)
        
        fW = dict()
        fW["W"] = data["W_0_enc_input_embdr_0"]
        fW["B"] = data["b_0_enc_input_embdr_0"].T
        fW["U"] = data["W_enc_transition_0"]
        fW["Wz"] = data["W_0_enc_update_embdr_0"]
        fW["Uz"] = data["G_enc_transition_0"]
        fW["Wr"] = data["W_0_enc_reset_embdr_0"]
        fW["Ur"] = data["R_enc_transition_0"]

        bW = dict()
        bW["W"] = data["W_0_back_enc_input_embdr_0"]
        bW["B"] = data["b_0_back_enc_input_embdr_0"].T
        bW["U"] = data["W_back_enc_transition_0"]
        bW["Wz"] = data["W_0_back_enc_update_embdr_0"]
        bW["Uz"] = data["G_back_enc_transition_0"]
        bW["Wr"] = data["W_0_back_enc_reset_embdr_0"]
        bW["Ur"] = data["R_back_enc_transition_0"]
        
        self.rnnForward  = self.RNN(fW)
        self.rnnBackward = self.RNN(bW)
        
    def GetContext(self, batch):
        batchSize, numSteps  = batch.shape
        sourceEmbeddings = [self.embeddings.Lookup(batch[:,i]) for i in range(numSteps)]
        statesForward  = self.rnnForward.GetContext(sourceEmbeddings)
        statesBackward = self.rnnBackward.GetContext(sourceEmbeddings[::-1])[::-1]
        states = np.hstack((np.vstack(statesForward),
                            np.vstack(statesBackward)))
        return states 

In [9]:
#%%writefile -a bahdanau.py

class Decoder:
    class Embeddings:
        def __init__(self, data):
            self.E  = data["W_0_dec_approx_embdr"]
            self.EB = data["b_0_dec_approx_embdr"]
            
        def Initialize(self, batchSize=1):
            return np.zeros((batchSize, self.E.shape[1]))
        
        def Lookup(self, i):
            return self.E[i] + self.EB
        
    class RNN:
        def __init__(self, data):
            self.Ws  = data["W_0_dec_initializer_0"]
            self.WsB = data["b_0_dec_initializer_0"].T

            self.W   = data["W_0_dec_input_embdr_0"]
            self.B   = data["b_0_dec_input_embdr_0"].T
            self.U   = data["W_dec_transition_0"]
            self.C   = data["W_0_dec_dec_inputter_0"]

            self.Wz  = data["W_0_dec_update_embdr_0"]
            self.Uz  = data["G_dec_transition_0"]
            self.Cz  = data["W_0_dec_dec_updater_0"]

            self.Wr  = data["W_0_dec_reset_embdr_0"]
            self.Ur  = data["R_dec_transition_0"]
            self.Cr  = data["W_0_dec_dec_reseter_0"]

        def InitializeState(self, sourceContext, batchSize=1):
            H1Backward = sourceContext[0,1000:].T
            S0 = np.tanh(H1Backward.dot(self.Ws) + self.WsB)
            return np.tile(S0, batchSize).reshape(batchSize, 1000)
        
        def GetNextState(self, embd, prevState, context):        
            Zi = logit(embd.dot(self.Wz) + prevState.dot(self.Uz) + context.dot(self.Cz))
            Ri = logit(embd.dot(self.Wr) + prevState.dot(self.Ur) + context.dot(self.Cr))
            Si_= np.tanh(embd.dot(self.W) + self.B
                          + (Ri * prevState).dot(self.U)
                          + context.dot(self.C))
            Si  = (1.0 - Zi) * prevState + Zi * Si_
            return Si
    
    class AlignmentModel:
        def __init__(self, data):
            self.Va  = data["D_dec_transition_0"].T
            self.Wa  = data["B_dec_transition_0"]
            self.Ua  = data["A_dec_transition_0"]
            
        def GetContext(self, sourceContext, prevState):
            a = sourceContext.dot(self.Ua)
            b = prevState.dot(self.Wa)
            c = a.reshape(1, a.shape[0], a.shape[1]) + b.reshape(b.shape[0], 1, b.shape[1])
            Ei = np.tensordot(self.Va, np.tanh(c).T, axes=[[1],[0]])
            Ai = softmax(Ei, ax=1)
            Ai = Ai.reshape(Ai.shape[1],Ai.shape[2])
            Ci = Ai.T.dot(sourceContext)
            return Ci
    
    class DeepSoftMax:
        def __init__(self, data):
            Wo1      = data["W1_dec_deep_softmax"]
            Wo2      = data["W2_dec_deep_softmax"] 
            self.Wo  = Wo1.dot(Wo2)
            self.WoB = data["b_dec_deep_softmax"].T
            self.Uo  = data["W_0_dec_hid_readout_0"]
            self.UoB = data["b_0_dec_hid_readout_0"].T
            self.Vo  = data["W_0_dec_prev_readout_0"]
            self.Co  = data["W_0_dec_repr_readout"]
            
        def GetProbs(self, prevState, prevEmbd, context):
            Ti = prevState.dot(self.Uo) + self.UoB + prevEmbd.dot(self.Vo) + context.dot(self.Co)
            maximum = np.maximum(Ti[:,::2], Ti[:,1::2])
            P = softmax((maximum.dot(self.Wo) + self.WoB).T)
            logP = np.log(P)
            return logP

    def __init__(self, data):
        self.embeddings     = self.Embeddings(data)
        self.rnn            = self.RNN(data)
        self.alignmentModel = self.AlignmentModel(data)
        self.deepSoftMax    = self.DeepSoftMax(data)
    
    def GetScores(self, batch, mask, sourceContext):
        states, probs = [], []
        batchSize, numSteps  = batch.shape
        
        previousState = self.rnn.InitializeState(sourceContext, batchSize)
        previousEmbedding = self.embeddings.Initialize(batchSize)
        
        for i in range(numSteps):
            wordBatch = batch[:,i]

            alignedSourceContext = self.alignmentModel.GetContext(sourceContext, previousState)

            allProbs = self.deepSoftMax.GetProbs(previousState, previousEmbedding, alignedSourceContext)
            
            currentEmbedding = self.embeddings.Lookup(wordBatch)
            currentState = self.rnn.GetNextState(currentEmbedding, previousState, alignedSourceContext)
            
            for column, wordId in enumerate(wordBatch):
                #print(wordId, allProbs[wordId, column])
                probs.append(allProbs[wordId, column]) 
            previousState, previousEmbedding = currentState, currentEmbedding
            
        probs = np.array(probs).reshape(numSteps, batchSize).T * mask
        return np.sum(probs, axis=1), probs #, states

In [10]:
#%%writefile -a bahdanau.py

data = np.load("/home/marcinj/Badania/best_nmt/search_model.npz")
encoder = Encoder(data)
decoder = Decoder(data)

In [17]:
#%%writefile -a bahdanau.py

# "thank you . <eol>"
sourceSentence, mask = batchAndMask([
        np.array([323, 22, 4, 30000])])

# "vielen dank . <eol>"
t1 = np.array([248, 333, 3, 30000])
batch, mask = batchAndMask([t1])


import time
start = time.time()

np.set_printoptions(precision=6, suppress=True)

sourceContext = encoder.GetContext(sourceSentence)
prob, probs = decoder.GetScores(batch, mask, sourceContext)

end = time.time()

print(probs, "\n")
print("Final: ", prob, "\n") 

print("Time: ", np.round(end - start, 4))

[[-0.460252 -0.002574 -0.033437 -0.000117]] 

Final:  [-0.49638] 

Time:  1.5669


# Re-implementation: 
## C++ and CUDA

<img style="margin:auto" src="moses-gpu.png"/>

## A little code

* https://github.com/emjotde/mosesdecoder/blob/nmt4/moses/FF/NMT/common/encoder.h
* https://github.com/emjotde/mosesdecoder/blob/nmt4/moses/FF/NMT/common/decoder.h
* https://github.com/emjotde/mosesdecoder/blob/nmt4/moses/FF/NMT/mblas/matrix.h

# moses.ini entry

    NeuralScoreFeature mode=rescore name=N0
      batch-size=1000 filtered-softmax=0
      devices=3 num-features=2 state-length=5
      model=/work/best_nmt/search_model.npz 
      source-vocab=/work/best_nmt/vocab/en_de.en.txt 
      target-vocab=/work/best_nmt/vocab.en_de.de.txt

## What kb-Mira thinks:

    LexicalReordering0= 0.0384 -0.0081 -0.0010 0.0672 0.0295 0.0369
    OpSequenceModel0= 0.0245 -0.0100 -0.0089 0.0136 -0.0626
    Distortion0= 0.0037
    LM0= 0.0309
    LM1= -0.0153
    LM2= -0.0106
    LM3= -0.0187
    LM4= 0.0263
    LM5= 0.0500
    LM6= 0.0038
    LM7= -0.0057
    WCLM0= 0.0160
    NeuralScoreFeature0= 0.1754 -0.0352
    WordPenalty0= -0.0753
    TranslationModel0= 0.0660 0.0238 0.0043 0.0028 -0.0043 0.0543 -0.0150 0.0221 0.0370 0.0010
    UnknownWordPenalty0= 1
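
To make the role of these weights concrete: Moses scores a hypothesis with a log-linear model, i.e. a weighted sum of feature values. The sketch below uses a subset of the weights listed above; all feature values are made up for illustration, and what the second `NeuralScoreFeature0` component measures is not stated here, so its value is a pure placeholder.

```python
import numpy as np

# Tuned weights copied from the kb-MIRA output above (subset of the features).
weights = {
    "Distortion0": [0.0037],
    "LM0": [0.0309],
    "NeuralScoreFeature0": [0.1754, -0.0352],
    "WordPenalty0": [-0.0753],
}

# Hypothetical feature values for a single hypothesis (invented for this sketch).
features = {
    "Distortion0": [-2.0],
    "LM0": [-35.2],
    "NeuralScoreFeature0": [-0.49638, 4.0],
    "WordPenalty0": [-4.0],
}

def model_score(weights, features):
    # Log-linear model: dot product of the weight vector and the feature vector.
    return sum(float(np.dot(weights[name], features[name])) for name in weights)

# Per-feature contributions make it visible which features dominate the total.
contributions = {name: float(np.dot(weights[name], features[name])) for name in weights}
print(contributions)
print(model_score(weights, features))
```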

## A different strategy: Stack Rescoring

* Process stack ignoring the NMT Feature Function. 
* Perform recombination and to-size-pruning of stack.
* Use NMT Feature to rescore surviving hypotheses on a per-stack basis.
* Approximate, but allows working with many more hypotheses. 
* Hypothesis recombination needs to be switched off during tuning (fine for testing).
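
The stack rescoring steps above can be sketched as follows. This is an illustration under assumptions, not the Moses implementation: `nmt_scores` is a hypothetical stand-in for one batched NMT query, `rescore_stack` is a made-up helper name, hypothesis recombination is left out, and the base scores are random.

```python
import numpy as np

rng = np.random.default_rng(2)

def nmt_scores(hyp_ids):
    # Hypothetical stand-in for one batched NMT query: one score per hypothesis.
    return -np.abs(np.sin(np.asarray(hyp_ids, dtype=float)))

def rescore_stack(base_scores, stack_size, nmt_weight):
    # 1. base_scores were computed while ignoring the NMT feature.
    # 2. Prune the stack to stack_size using those scores only.
    survivors = np.argsort(-base_scores)[:stack_size]
    # 3. Rescore only the survivors, with a single batched call per stack.
    total = base_scores[survivors] + nmt_weight * nmt_scores(survivors)
    # 4. Re-rank the survivors by the combined score.
    rerank = np.argsort(-total)
    return survivors[rerank], total[rerank]

base = rng.normal(size=50)  # 50 hypotheses expanded into one stack
ids, scores = rescore_stack(base, stack_size=10, nmt_weight=0.1754)
print(ids, scores)
```

The approximation is visible in step 2: a hypothesis that the NMT feature would have favoured can be pruned before the feature is ever consulted, which is why a large stack size helps.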

## More experiments

### Stand-alone
* 21.5 (Vanilla Moses)
* 25.6 (our Moses setup)
* 25.8 (3-ensemble Groundhog NMT)

</br>
### Stack-based pre-calculation:
* 27.2 (rescoring a 1000-best list)
* 28.3 (Moses with NMT-FF with stack=30)

</br>
### Stack rescoring:
* 28.9 (stack=2000 cube-pruning-pop-limit=5000 d=12)
* 29.4 (2-ensemble = log-linear combination of two FFs)

## Other languages

* Planning to use the feature for WMT 16
    * EN<->DE, EN<->RU, EN<->CS
    * No results yet
    * Using BPEs now for PBMT system (looks good)
    * Can maybe report some results next week