In [2]:
import torch
import torch.nn as nn

from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Get to the Point Summarization with Pointer Generators

A 2017 [paper](https://arxiv.org/pdf/1704.04368.pdf) ([code](https://github.com/abisee/pointer-generator)) which augments the attention-based sequence-to-sequence model for text summarization in two key ways:

1. Use of a hybrid pointer-generator network that can copy words from the source text via pointing, which aids accurate reproduction of information, while retaining the ability to produce novel words through the generator.
2. Use of coverage, which keeps track of what has been summarized, to discourage repetition.

The complete network can be seen below:

![Pointer Generator Network](./img/pointer_generator_network.png)

## Breaking down the model

The Pointer Generator model builds on top of a simple sequence-to-sequence model with attention. The tokens of the article, $w_i$, are fed one-by-one into the encoder (a single-layer bidirectional LSTM), producing a sequence of _encoder hidden states_ $h_i$.

### The Encoder Module

Below is a PyTorch implementation of this encoder module


In [3]:
class EncoderRNN(nn.Module):
    """Recurrent encoder: embeds source token ids and runs them through an RNN.

    Besides the per-step RNN outputs, the encoder also returns those outputs
    passed through a bias-free linear layer (``W_h``) and flattened over the
    batch and time dimensions.
    """

    def __init__(self, metadata, hidden_dim, num_layers=1, rnn_cell=nn.LSTM, bidirectional=True):
        """
        Build the embedding table, the RNN and the ``W_h`` projection.

        :param metadata: Dictionary whose ``['fields']['src'].vocab.vectors`` entry is the
            embedding matrix (one row per vocabulary entry)
        :param hidden_dim: Hidden size of the RNN, per direction
        :param num_layers: Number of stacked RNN layers
        :param rnn_cell: RNN class to instantiate (``nn.LSTM`` by default)
        :param bidirectional: bool - whether the RNN runs in both directions
        """
        super().__init__()
        self.metadata = metadata
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # The embedding table takes its shape from the supplied vectors and
        # starts out as a copy of them.
        src_vectors = self.metadata['fields']['src'].vocab.vectors
        vocab_rows, emb_dim = src_vectors.size(0), src_vectors.size(1)
        self.embedding = nn.Embedding(vocab_rows, emb_dim)
        self.embedding.weight.data.copy_(src_vectors)

        # 2 when bidirectional, 1 otherwise: the RNN output width is
        # hidden_dim times this factor.
        self.directional_multiplier = 1 + int(bidirectional)
        self.rnn = rnn_cell(
            emb_dim,
            hidden_dim,
            num_layers=self.num_layers,
            batch_first=True,
            bidirectional=bidirectional,
        )

        rnn_out_dim = hidden_dim * self.directional_multiplier
        self.W_h = nn.Linear(rnn_out_dim, rnn_out_dim, bias=False)

    def forward(self, inputs, seq_lens):
        """
        Make a forward pass through the encoder.

        :param inputs: Batch-first tensor of vocabulary ids for the (padded) input text
        :param seq_lens: Length of each input sequence, as accepted by
            ``pack_padded_sequence`` (called here with its default settings)
        :return:
            encoder_outputs: RNN outputs for every time step, batch first and contiguous
            encoder_feature: ``W_h`` applied to ``encoder_outputs`` flattened to
                (batch * time) rows
            hidden: final state exactly as returned by the RNN
        """
        token_vectors = self.embedding(inputs)

        # Pack so that the RNN skips the padded positions, then unpack again.
        packed_inputs = pack_padded_sequence(token_vectors, seq_lens, batch_first=True)
        packed_outputs, hidden = self.rnn(packed_inputs)
        padded_outputs, _ = pad_packed_sequence(packed_outputs, batch_first=True)
        encoder_outputs = padded_outputs.contiguous()

        feature_dim = self.directional_multiplier * self.hidden_dim
        encoder_feature = self.W_h(encoder_outputs.view(-1, feature_dim))

        return encoder_outputs, encoder_feature, hidden

### Bahdanau Attention Mechanism

Attention is implemented as in [Bahdanau et al](https://arxiv.org/pdf/1409.0473.pdf)

$$
e_i^t = v^T \tanh(W_h h_i + W_s s_t + b_{attn}) \\
\alpha^t = \mathrm{softmax}(e^t)
$$

Where,

* $v, W_h, W_s, b_{attn}$ are all learnable parameters
* $h_i$ is the sequence of encoder hidden states
* $s_t$ is the decoder hidden state
* $t$ corresponds to each step in the decoder

The attention distribution $\alpha^t$ is calculated at each decoder step.

In [4]:
class BahdanauAttention(nn.Module):
    """Additive (Bahdanau-style) attention with an optional coverage term."""
    def __init__(self, hidden_dim, is_coverage=False):
        """

        :param hidden_dim: Base hidden size; every projection in this module works on
            vectors of width ``2 * hidden_dim``
        :param is_coverage: bool - when True, a coverage feature is added to the
            attention features and the coverage vector is updated on every call
        """
        super(BahdanauAttention, self).__init__()
        self.is_coverage = is_coverage
        self.hidden_dim = hidden_dim

        if self.is_coverage:
            # Maps each scalar coverage value to a 2*hidden_dim feature (w_c in the paper).
            self.W_c = nn.Linear(1, hidden_dim * 2, bias=False)

        # W_s s_t + b_attn: the only attention projection that carries a bias.
        self.decode_proj = nn.Linear(self.hidden_dim * 2, self.hidden_dim * 2)
        self.v = nn.Linear(hidden_dim * 2, 1, bias=False)

    def forward(self, s_t_hat, encoder_outputs, encoder_feature, enc_padding_mask, coverage):
        """
        Calculation of the Attention probability distribution as outlined by Bahdanau et al.

        :param s_t_hat: Decoder state, B x 2*hidden_dim
        :param encoder_outputs: Encoder hidden states, B x t_k x 2*hidden_dim
        :param encoder_feature: Projected encoder states, (B * t_k) x 2*hidden_dim
        :param enc_padding_mask: B x t_k mask that is multiplied into the attention
            distribution (zero entries remove a source position)
        :param coverage: Coverage vector viewable as B x t_k; only read when
            ``is_coverage`` is True
        :return:
            c_t: context vector (weighted sum of the encoder hidden states), B x 2*hidden_dim
            attn_dist: the attention distribution over the source positions, B x t_k
            coverage: when ``is_coverage`` is True, the incoming coverage plus ``attn_dist``;
                otherwise the ``coverage`` argument unchanged
        """
        b, t_k, n = list(encoder_outputs.size())

        # Project the decoder state once and repeat it for every source position.
        dec_fea = self.decode_proj(s_t_hat)  # B x 2*hidden_dim
        dec_fea_expanded = dec_fea.unsqueeze(1).expand(b, t_k, n).contiguous()
        dec_fea_expanded = dec_fea_expanded.view(-1, n)  # B * t_k x 2*hidden_dim

        att_features = encoder_feature + dec_fea_expanded  # B * t_k x 2*hidden_dim
        if self.is_coverage:
            coverage_input = coverage.view(-1, 1)  # B * t_k x 1
            coverage_feature = self.W_c(coverage_input)  # B * t_k x 2*hidden_dim
            att_features = att_features + coverage_feature

        scores = self.v(torch.tanh(att_features)).view(-1, t_k)  # B x t_k

        # torch.softmax is used because the notebook never imports
        # torch.nn.functional as F. The mask follows the dtype of the scores
        # instead of a hard-coded float32.
        mask = enc_padding_mask.to(dtype=scores.dtype)
        attn_dist = torch.softmax(scores, dim=1) * mask  # B x t_k
        # Re-normalise so the unmasked positions sum to one again.
        normalization_factor = attn_dist.sum(1, keepdim=True)
        attn_dist = attn_dist / normalization_factor

        attn_dist = attn_dist.unsqueeze(1)  # B x 1 x t_k
        c_t = torch.bmm(attn_dist, encoder_outputs)  # B x 1 x n
        c_t = c_t.view(-1, self.hidden_dim * 2)  # B x 2*hidden_dim

        attn_dist = attn_dist.view(-1, t_k)  # B x t_k

        if self.is_coverage:
            # Coverage accumulates the attention paid to each source position so far.
            coverage = coverage.view(-1, t_k)
            coverage = coverage + attn_dist

        return c_t, attn_dist, coverage


The attention distribution is used to produce a weighted sum of the encoder hidden states, known as the context vector $h_t^*$:

$$
h_t^* = \sum_i \alpha_i^th_i
$$

This context vector, which can be seen as a fixed-size representation of what has been read from the source for this step, is concatenated with the decoder state $s_t$ and fed through two linear layers to produce the vocabulary distribution $P_{vocab}$:

$$
P_{vocab} = softmax(V^{\prime}(V[s_t, h_t^*] + b) + b^{\prime})
$$

where,

* $V^{\prime}, V, b^{\prime}, b$ are learnable parameters.

$P_{vocab}$ is the probability distribution over all words in the vocab. During training the loss for each timestep t is the negative log likelihood of the target word $w_t^*$ for that timestep:

$$
loss_t = -log P(w_t^*)
$$

and the overall loss for the whole sequence is:

$$
loss = \frac{1}{T}\sum_{t=0}^T loss_t
$$

### Pointer Generator Components

The generation probability $p_{gen} \in [0,1]$ for timestep $t$ is calculated from the context vector $h_t^*$, the decoder state $s_t$ and the decoder input $x_t$:

$$
p_{gen} = \sigma(w_{h^*}^T h_t^* + w_s^T s_t + w_x^T x_t + b_{ptr})
$$

where the vectors $w_{h^*}, w_s, w_x$ and the scalar $b_{ptr}$ are learnable parameters and $\sigma$ is the sigmoid function. $p_{gen}$ is used as a soft switch to decide between generating a word from $P_{vocab}$ or copying a word from the input sequence by sampling from the attention distribution $\alpha^t$. For each document, let the _extended vocabulary_ denote the union of the vocabulary and all the words appearing in the source document.

The following distribution can be obtained over the extended vocabulary. 

$$
P(w) = p_{gen}P_{vocab}(w) + (1 - p_{gen})\sum_{i:w_i=w}\alpha_i^t
$$

If $w$ is an out-of-vocabulary (OOV) word then $P_{vocab}(w)$ is always zero. Similarly, if $w$ is not in the source document then $\sum_{i:w_i=w}\alpha_i^t$ is zero. The ability to produce OOV words is one of the primary advantages of Pointer-Generator networks. The loss function is as described above, but with respect to the newly defined $P(w)$.

### Coverage Mechanism

Repetition is a common problem in generative seq2seq models. To solve this problem a coverage model is adopted: a coverage vector $c^t$ is maintained, which is the sum of the attention distributions over all previous decoder steps:

$$
c^t = \sum_{t^{\prime}=0}^{t-1}\alpha^{t^{\prime}}
$$

Intuitively, $c^t$ is an unnormalized distribution over the source document words which represents the degree of coverage those words have received so far. Note that $c^0$ is a zero vector, as none of the words have received coverage yet. The coverage vector is an extra input to the attention mechanism, changing the attention equation to:

$$
e_i^t = v^Ttanh(W_hh_i + W_ss_t + w_cc_i^t + b_{attn}) 
$$

where $w_c$ is a learnable parameter vector of the same length as $v$. 

In [None]:
class DecoderRNN(nn.Module):
    """Decoder Module for Pointer-Generator."""
    def __init__(self, metadata, hidden_dim, attention, vocab_size, rnn_cell=nn.LSTM, pointer_gen=True):
        """

        :param metadata: Dictionary whose ``['fields']['src'].vocab.vectors`` entry is the
            embedding matrix used to size and initialise the decoder embedding
        :param hidden_dim: Hidden size of the decoder RNN; the context vector is
            expected to be ``2 * hidden_dim`` wide
        :param attention: Attention module, called as
            ``attention(s_t_hat, encoder_outputs, encoder_feature, enc_padding_mask, coverage)``
            and returning ``(c_t, attn_dist, coverage)``
        :param vocab_size: Size of the output vocabulary distribution
        :param rnn_cell: RNN class to instantiate; ``forward`` unpacks its state as
            an ``(h, c)`` pair
        :param pointer_gen: bool - whether to mix the vocabulary distribution with the
            copy (attention) distribution
        """
        super(DecoderRNN, self).__init__()
        self.attention_network = attention
        self.metadata = metadata
        self.pointer_gen = pointer_gen
        self.hidden_dim = hidden_dim

        # The decoder embedding is sized from, and initialised with, the 'src' vectors.
        src_vectors = self.metadata['fields']['src'].vocab.vectors
        self.embedding = nn.Embedding(src_vectors.size(0), src_vectors.size(1))
        self.embedding.weight.data.copy_(src_vectors)
        emb_dim = self.embedding.weight.size(1)

        # Merges the previous context vector with the current input embedding.
        self.x_context = nn.Linear(hidden_dim * 2 + emb_dim, emb_dim)

        self.rnn = rnn_cell(emb_dim, hidden_dim, num_layers=1, batch_first=True, bidirectional=False)

        if self.pointer_gen:
            # Input is [c_t ; s_t_hat ; x] = 2*hidden + 2*hidden + emb.
            self.p_gen_linear = nn.Linear(self.hidden_dim * 4 + emb_dim, 1)

        # p_vocab: two stacked linear layers over [lstm_out ; c_t]
        self.out1 = nn.Linear(self.hidden_dim * 3, self.hidden_dim)
        self.out2 = nn.Linear(self.hidden_dim, vocab_size)

    def forward(self, y_t_1, s_t_1, encoder_outputs, encoder_feature, enc_padding_mask,
                c_t_1, extra_zeros, enc_batch_extend_vocab, coverage, step):
        """
        Run a single decoding step.

        :param y_t_1: Previous token ids, 1-D tensor of length B (its embedding is
            concatenated with ``c_t_1`` along dim 1)
        :param s_t_1: Previous decoder state, an ``(h, c)`` pair
        :param encoder_outputs: Passed through to the attention module
        :param encoder_feature: Passed through to the attention module
        :param enc_padding_mask: Passed through to the attention module
        :param c_t_1: Previous context vector, B x 2*hidden_dim
        :param extra_zeros: Tensor appended to the vocabulary distribution along dim 1
            (room for the extended vocabulary), or None
        :param enc_batch_extend_vocab: Index tensor (same shape as the attention
            distribution) giving, per source position, the slot of the final
            distribution that receives its copy probability
        :param coverage: Coverage vector handed to the attention module
        :param step: Index of the current decoding step
        :return:
            final_dist: output distribution (extended vocabulary when ``pointer_gen``)
            s_t: new decoder state
            c_t: new context vector
            attn_dist: attention distribution of this step
            p_gen: generation probability, B x 1, or None when ``pointer_gen`` is False
            coverage: updated coverage vector
        """

        if not self.training and step == 0:
            # First inference step: run attention on the incoming state so that
            # coverage already includes it.
            # NOTE(review): the context vector computed here is discarded and
            # c_t_1 is still used below -- confirm this is intended.
            h_decoder, c_decoder = s_t_1
            s_t_hat = torch.cat((h_decoder.view(-1, self.hidden_dim),
                                 c_decoder.view(-1, self.hidden_dim)), 1)  # B x 2*hidden_dim
            _, _, coverage = self.attention_network(s_t_hat, encoder_outputs, encoder_feature,
                                                    enc_padding_mask, coverage)

        y_t_1_embd = self.embedding(y_t_1)
        x = self.x_context(torch.cat((c_t_1, y_t_1_embd), 1))
        lstm_out, s_t = self.rnn(x.unsqueeze(1), s_t_1)

        h_decoder, c_decoder = s_t
        s_t_hat = torch.cat((h_decoder.view(-1, self.hidden_dim),
                             c_decoder.view(-1, self.hidden_dim)), 1)  # B x 2*hidden_dim
        c_t, attn_dist, coverage_next = self.attention_network(s_t_hat, encoder_outputs, encoder_feature,
                                                               enc_padding_mask, coverage)

        # At the first inference step coverage was already advanced above, so
        # the second update is skipped there.
        if self.training or step > 0:
            coverage = coverage_next

        p_gen = None
        if self.pointer_gen:
            p_gen_input = torch.cat((c_t, s_t_hat, x), 1)  # B x (2*2*hidden_dim + emb_dim)
            p_gen = self.p_gen_linear(p_gen_input)
            p_gen = torch.sigmoid(p_gen)

        output = torch.cat((lstm_out.view(-1, self.hidden_dim), c_t), 1)  # B x hidden_dim * 3
        # No non-linearity between the two projections.
        output = self.out1(output)  # B x hidden_dim
        output = self.out2(output)  # B x vocab_size
        # torch.softmax is used because the notebook never imports
        # torch.nn.functional as F.
        vocab_dist = torch.softmax(output, dim=1)

        if self.pointer_gen:
            vocab_dist_ = p_gen * vocab_dist
            attn_dist_ = (1 - p_gen) * attn_dist

            if extra_zeros is not None:
                vocab_dist_ = torch.cat([vocab_dist_, extra_zeros], 1)

            # Add each source position's copy probability onto the slot of the
            # word it holds.
            final_dist = vocab_dist_.scatter_add(1, enc_batch_extend_vocab, attn_dist_)
        else:
            final_dist = vocab_dist

        return final_dist, s_t, c_t, attn_dist, p_gen, coverage
