In [36]:
import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
from rouge import Rouge 

In [37]:
tf.enable_eager_execution()

def mmap(fn, elem):
    return list(map(fn, elem))

def getel(n, lst): 
    return mmap(lambda x: x[n], lst)

In [38]:
def load_pretrained_embeddings(embed_matrix, trainable=False):
    return keras.layers.Embedding(embed_matrix.shape[0],
                                  embed_matrix.shape[1],
                                  weights=[embed_matrix],
                                  trainable=trainable)

def get_new_embeddings(voc_len, embedding_dim):
    return keras.layers.Embedding(voc_len, embedding_dim)


In [39]:
class Encoder(keras.Model):
    def __init__(self, units):
        super(Encoder, self).__init__()
        self.units = units
        self.lstm = keras.layers.LSTM(self.units,
                                      return_sequences=True, 
                                      return_state=True)
        
    def call(self, x):
        sequences, state1, state2 = self.lstm(x)
        return sequences, state1, state2
    
class Attention(keras.Model):
    def __init__(self, w_units):
        super().__init__()
        self.W = keras.layers.Dense(w_units, use_bias=False)
        self.W1 = keras.layers.Dense(w_units, use_bias=False)
        self.v = keras.layers.Dense(1)
        self.tanh = keras.activations.tanh
        self.softmax = keras.activations.softmax
        
    def call(self, enc_hidden, dec_hidden):
        dec_hidden = tf.expand_dims(dec_hidden, 1)
        unnorm = self.v(
            self.tanh(self.W(enc_hidden) + self.W1(dec_hidden))
        )

        attention_weights = self.softmax(unnorm, axis=1)
        
        # Compute the context vector used to generate the decoder state
        c_vec = tf.reduce_sum(attention_weights * enc_hidden, axis=1)
        
        # Return the context vector and the pointer logits and the pointer probs
        return c_vec, tf.squeeze(attention_weights, axis=2)  
    
class Decoder(keras.Model):
    def __init__(self, units, output_size):
        super().__init__()
        self.units = units
        self.lstm = keras.layers.LSTM(self.units, return_state=True)
        self.output_layer = keras.layers.Dense(output_size, activation='softmax')
        
    def call(self, x, enc_out, prev_state):
        # Concatenate encoder output (or context vector) and the target/predicted embedding
        concatenated_inp = tf.concat([x, enc_out], axis=1)
        concatenated_inp = tf.expand_dims(concatenated_inp, 1)
        # Compute the hidden state h^d
        d, dec_h, dec_c =  self.lstm(concatenated_inp, initial_state=prev_state)
        
        # Decode using vocabulary
        flattened = tf.layers.flatten(d)
        decoded_probs = self.output_layer(flattened)
        
        # Return Decode hidden states and vocabulary logits over the vocabulary
        return d, dec_h, dec_c, decoded_probs

class PointerSwitch(keras.Model):
    def __init__(self, units):
        super().__init__()
        self.W1 = keras.layers.Dense(units,  use_bias=False)
        self.W2 = keras.layers.Dense(units, use_bias=False)
        self.v = keras.layers.Dense(1)
    
    def call(self, enc, c_vec):
        '''
            Compute switch probabilities from the context vector and
            the encoder last output state
        '''
        switch_prob = tf.keras.activations.sigmoid(
            self.v(
                self.W1(enc) + self.W2(c_vec)
            )
        )
        
        return switch_prob

In [110]:
r = Rouge()

def rouge_score(y, y_):
    true_tok = [' '.join(mmap(str,sent)) for sent in y]
    pred_tok = [' '.join(mmap(str,sent)) for sent in y_]
    scores = r.get_scores(true_tok, pred_tok)

    values = []
    for key in scores[0].keys():
        for sub_metric in ['f', 'p', 'r']:
            mean_score = np.mean(mmap(lambda x: x[key][sub_metric], scores))
            values.append(mean_score)
    return values

In [111]:
class PointerNetwork(keras.Model):
    def __init__(self,
                 enc_units,
                 dec_units, 
                 voc_size,
                 att_units, 
                 switch_units,
                 max_len, 
                 start_token,
                 end_token,
                 padding_char):
        super().__init__()
        self.encoder = Encoder(enc_units)
        self.decoder = Decoder(dec_units, voc_size)
        self.attention = Attention(att_units)
        self.pointer_switch = PointerSwitch(switch_units)
        self.embeddings = False
        self.max_len = max_len
        self.start_token = start_token
        self.end_token = end_token
        self.voc_size = voc_size
        self.padding_char = padding_char
        
        self.optimizer = tf.train.AdamOptimizer()
    
    def set_embeddings_layer(self, embeddings_layer):
        self.embeddings = embeddings_layer
    
    def predict_batch(self, X):
        assert self.embeddings, "Call self.set_embeddings_layer first"
        X = tf.convert_to_tensor(X)
        
        embed = self.embeddings(X)
        enc_states, h1, h2 = self.encoder(embed)
        input_tokens = tf.convert_to_tensor([self.start_token] * embed.shape[0])
        # put last encoder state as attention vec at start
        c_vec = h1
        outputs = []
        
        for _ in range(self.max_len):
            dec_input = self.embeddings(input_tokens)
            decoded_state, h1, h2, decoded_probs = self.decoder(dec_input, 
                                                                c_vec, 
                                                                [h1, h2])
            c_vec, pointer_probs = self.attention(enc_states, 
                                                     decoded_state)
            
            # Compute switch probability to decide where to extract the next
            # word token
            switch_probs = self.pointer_switch(h1, c_vec)
            # Decode based on switch probs            
            input_tokens = self.decode_next_word(switch_probs, 
                                                 decoded_probs,
                                                 X,
                                                 pointer_probs)
            outputs.append(input_tokens)
            
        return tf.transpose(tf.convert_to_tensor(outputs))
    
    def decode_next_word(self, switch_probs, decoded_probs, inputs, att_probs):
        sampled_probs = tf.random.uniform(switch_probs.shape, 0 , 1)
        tokens = []
        token = None

        for prob, sampled, decoded, inp, att_p in zip(switch_probs,
                                                      sampled_probs, 
                                                      decoded_probs,
                                                      inputs,
                                                      att_probs):
            if prob.numpy()[0] >= sampled.numpy()[0]:
                token = self.fixed_vocab_decode(decoded)
            else:
                token = self.pointer_greedy_search(att_p, inp)

            tokens.append(token)
            
        return tf.convert_to_tensor(tokens, dtype=tf.float32)
    
    def pointer_greedy_search(self, probs, inputs):
        return inputs[tf.argmax(probs)]
    
    def fixed_vocab_decode(self, decoded_probs):
        return tf.argmax(decoded_probs)
    
    def pointer_batch_loss(self, gen, y, d_prob, p_prob, s_prob):
        # Compute the mask to ignore the padding in the loss
        mask = 1-tf.cast(tf.equal(gen[:,None],
                                  tf.ones(gen[:,None].shape) * self.padding_char
                                 ), tf.float32)
        
        # Compute pointer loss across all values of y for the pointer and generated probs
        pointer_mat = (p_prob + (1 - s_prob)) * mask
        generator_mat = (d_prob + s_prob) * mask
        
        # Add the expected loss in terms of likelihood
        batch_loss = 0
        for i, g in enumerate(gen):
            # Add if the word was taken from the input
            if g == 0:
                batch_loss += pointer_mat[i, y[i]]
            # Add if the word was generated by the network
            else:
                batch_loss += generator_mat[i, y[i]]
            
        # Reduce to scalar, dont forget to include minus sign (its a loss not a likelihood)
        return -batch_loss
    
    def __train_batch(self, X, y, gen):
        assert self.embeddings, "Call self.load_embeddings first"

        X = tf.convert_to_tensor(X)
        y = tf.convert_to_tensor(y, dtype='int32')
        gen = tf.convert_to_tensor(gen, dtype='float32')

        enc_inp = self.embeddings(X)
        enc_states, h1, h2 = self.encoder(enc_inp)
        c_vec = h1
        input_tokens = y[:,0]
        loss = 0
        for t in range(1, y.shape[1]):
            # Get embeddings
            dec_input = self.embeddings(input_tokens)
            
            # Get decoder output
            decoded_state, h1, h2, decoded_probs = self.decoder(dec_input, c_vec, [h1, h2])
            
            # Get context vector for the next step, and pointer probabilities
            c_vec, pointer_probs = self.attention(enc_states, decoded_state)
            
            # Get switch probability (BS*1)
            switch_probs = self.pointer_switch(h1, c_vec)
                        
            # Is target generated or extracted from the input (BS * 1)
            batch_gen = tf.convert_to_tensor(gen[:, t])
            
            # Compute Pointer Network batch loss at timestep t
            loss += self.pointer_batch_loss(batch_gen, y[:, t], decoded_probs,
                                       pointer_probs, switch_probs)

            # Get next decoder input tokens
            input_tokens = y[:, t]
        
        # Dont forget to divide by summary lenght N, since we lose the /N component n by calling
        # N times softmax cross entropy
        loss = loss / int(y.shape[1]-1)
        self._loss = loss
        return loss
    
    def train_batch(self, X, y, gen):
        self.optimizer.minimize(lambda: self.__train_batch(X, y, gen))   
        return [self._loss]
    
    def evaluate(self, X, y, verbose=0):
        y_ = self.predict_batch(X)
        return rouge_score(y, y_)

In [112]:
enc_units = 128
dec_units = 128
voc_size = 300
att_units = 128 
switch_units = 128
max_len = 200
start_index_token = 0
end_index_token = 1
padding_char = -1
ptr = PointerNetwork(enc_units, 
                     dec_units, 
                     voc_size, 
                     att_units, 
                     switch_units, 
                     max_len, 
                     start_index_token, 
                     end_index_token,
                     padding_char)

ptr.set_embeddings_layer(load_pretrained_embeddings(np.ones((300,300))))

In [113]:
ptr.predict_batch(tf.convert_to_tensor(np.ones((2,10))))

<tf.Tensor: id=4864281, shape=(2, 200), dtype=float32, numpy=
array([[  1.,   1., 265.,   1.,   1., 265.,   1., 143.,   1.,   1., 265.,
        143.,   1., 265.,   1.,   1., 143., 265., 143.,   1.,   1., 265.,
        143.,   1.,   1.,   1.,   1., 143., 265.,   1.,   1., 143.,   1.,
        265.,   1.,   1.,   1.,   1., 143.,   1., 143.,   1., 265., 143.,
          1., 265.,   1., 143.,   1.,   1.,   1.,   1.,   1., 143., 265.,
        143.,   1.,   1.,   1., 143.,   1., 265., 143., 143.,   1.,   1.,
        265.,   1.,   1., 143.,   1., 265., 143., 143.,   1.,   1., 265.,
        143., 265.,   1.,   1., 143.,   1., 265.,   1., 143., 265.,   1.,
        143., 265., 202., 265.,   1., 143., 265., 202., 265.,   1.,   1.,
          1.,   1.,   1.,   1., 143., 265., 143., 143.,   1.,   1., 265.,
        143., 143.,   1., 265.,   1., 143.,   1.,   1., 265.,   1., 143.,
          1., 265.,   1.,   1.,   1., 143.,   1., 265., 143., 143., 265.,
        202.,   1., 265., 143.,   1., 265.,   1., 

In [114]:
X = [np.random.randint(0,300, size=20)] * 32
y = [[0, 1, 22, 44, 87,1]] * 32
gen = np.array([np.random.randint(0,2, size=6) for _ in range(32)])

In [115]:
for _ in range(1):
    print(ptr.train_batch(X, y, gen))

[<tf.Tensor: id=4871389, shape=(), dtype=float32, numpy=-16.064342>]


In [116]:
ptr.predict_batch(X)[0]

<tf.Tensor: id=5131056, shape=(200,), dtype=float32, numpy=
array([149., 192., 227., 243., 243., 168., 274., 274., 274., 274., 274.,
       168., 274., 233., 233., 160., 274., 274., 233., 160., 274., 233.,
       160., 233., 274., 274., 233., 274., 274., 168., 233., 233., 160.,
       160., 274., 233., 274., 233., 160., 274., 233., 160., 233., 160.,
       274., 274., 274., 274., 274., 168., 274., 233., 274., 274., 274.,
       233., 233., 160., 233., 274., 274., 274., 168., 233., 274., 233.,
       274., 233., 160., 274., 274., 274., 274., 168., 274., 233., 233.,
       160., 233., 274., 274., 233., 160., 274., 274., 233., 233., 160.,
       160.,  46., 233., 274., 274., 168., 274., 274., 233., 233., 160.,
       274., 274., 233., 274., 274., 274., 168., 274., 274., 274., 274.,
       274., 233., 274., 274., 233., 274., 274., 233.,  80.,  78., 274.,
        78., 274., 168., 233., 233., 160., 274., 233., 274., 233., 160.,
       233., 160., 232., 233., 233., 274., 274., 274., 168., 233

In [117]:
ptr.evaluate(X, y)

[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]