In [7]:
import torch
import torch.nn as nn

# Efficient Natural Language Response Suggestion for Smart Reply

Link: [Paper](https://arxiv.org/pdf/1705.00652.pdf)

**Summary:**

This paper describes improvements to the initial Smart Reply model proposed by Kannan et al. The improvements are aimed specifically at reducing the computational complexity of training the Smart Reply system and reducing latency at inference time. This is achieved by using a feed-forward neural network to learn embeddings such that a new message and its suitable responses from the set of possible responses ($R$) have a high dot product. The authors keep two of the key components from the initial Smart Reply system, the triggering model and the diversity module based on the EXPANDER algorithm, and instead focus on improving response selection.

**Innovations:**

* Using a feed-forward network to score responses in place of a generative model to reduce computational cost.
    + N-gram embeddings are used to approximate sequences and capture basic semantic and word-ordering information.
* Multiple Negatives
    + Given a batch of $K$ (message, response) pairs, each message treats the other $K-1$ responses in the batch as negatives.
* Hierarchical Quantization
    + Gives further efficiency improvements when searching for the best responses in the candidate space.
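
As a rough illustration of the n-gram features mentioned in the list above, the sketch below extracts unigrams and bigrams from a tokenized message. The helper name `extract_ngrams`, the whitespace tokenization, and the choice of `max_n=2` are assumptions made for this example and are not taken from the paper. Each resulting n-gram would then be mapped to an id and embedded, with the embeddings summed to represent the message.

```python
def extract_ngrams(tokens, max_n=2):
    """Return all contiguous n-grams of length 1..max_n as space-joined strings."""
    ngrams = []
    for n in range(1, max_n + 1):
        # Slide a window of size n over the token list.
        for i in range(len(tokens) - n + 1):
            ngrams.append(" ".join(tokens[i:i + n]))
    return ngrams


tokens = "are you free tomorrow".split()
features = extract_ngrams(tokens, max_n=2)
print(features)
# ['are', 'you', 'free', 'tomorrow', 'are you', 'you free', 'free tomorrow']
```

The bigrams are what carry the limited word-ordering information: "you free" and "free you" produce different features even though the unigrams are the same.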
    
## Dot Product Model 

The authors describe a dot product scoring model in which $S(x, y)$ is factorized as a dot product between a vector $\mathbf{h}_x$ that depends only on $x$ and a vector $\mathbf{h}_y$ that depends only on $y$. This corresponds to Figure 3(b) in the paper, reproduced below.

![Dot Product Scoring Model](./www/dot_prod_score.png)

This can be implemented in PyTorch as shown below. Both towers share the same architecture, so the same class can simply be instantiated twice, once for the message and once for the response.

In [8]:
class dotProdModel(nn.Module):
    """One tower of the dot product model: summed embeddings, then three tanh layers."""

    def __init__(self, hidden_size1, hidden_size2, hidden_size3,
                 vocab_size, dropout, pretrained=False, weights=None,
                 emb_dim=None):
        """Initialization."""
        super().__init__()

        if pretrained:
            # Initialize the embedding table from a (vocab_size, emb_dim) tensor.
            self.embedding = nn.Embedding(weights.size(0), weights.size(1))
            self.embedding.weight.data.copy_(weights)
            emb_dim = weights.size(1)
        else:
            # emb_dim must be provided when no pretrained weights are given.
            self.embedding = nn.Embedding(vocab_size, emb_dim)

        self.linear1 = nn.Linear(emb_dim, hidden_size1)
        self.linear2 = nn.Linear(hidden_size1, hidden_size2)
        self.linear3 = nn.Linear(hidden_size2, hidden_size3)

        self.dropout = nn.Dropout(dropout)
        self.activation = nn.Tanh()

    def forward(self, x):
        """Forward pass; x holds token ids with the sequence dimension first."""
        # Sum the token embeddings over the sequence dimension (dim 0).
        h = torch.sum(self.embedding(x), dim=0)
        h = self.dropout(self.activation(self.linear1(h)))
        h = self.dropout(self.activation(self.linear2(h)))
        h = self.dropout(self.activation(self.linear3(h)))

        return h
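
To make the factorization $S(x, y) = \mathbf{h}_x^\top \mathbf{h}_y$ concrete, here is a minimal, self-contained sketch that scores a small batch with two separate towers. `TinyTower` is a hypothetical one-layer stand-in defined only for this example, and the vocabulary size, dimensions, and random token ids are arbitrary assumptions rather than settings from the paper.

```python
import torch
import torch.nn as nn


class TinyTower(nn.Module):
    """Hypothetical stand-in for one tower: summed embeddings, then one tanh layer."""

    def __init__(self, vocab_size, emb_dim, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.linear = nn.Linear(emb_dim, hidden_size)

    def forward(self, x):
        # x: (seq_len, batch) token ids; sum embeddings over the sequence dimension.
        h = self.embedding(x).sum(dim=0)
        return torch.tanh(self.linear(h))


torch.manual_seed(0)
# Two instances of the same architecture, each with its own parameters.
input_tower = TinyTower(vocab_size=100, emb_dim=16, hidden_size=8)
response_tower = TinyTower(vocab_size=100, emb_dim=16, hidden_size=8)

x = torch.randint(0, 100, (5, 4))  # 4 messages, 5 token ids each
y = torch.randint(0, 100, (6, 4))  # 4 responses, 6 token ids each

h_x = input_tower(x)     # shape (4, 8)
h_y = response_tower(y)  # shape (4, 8)

# scores[i, j] = S(x_i, y_j): every message scored against every response.
scores = h_x @ h_y.t()   # shape (4, 4)
print(scores.shape)
```

Because $\mathbf{h}_y$ does not depend on the message, the response vectors can be computed once ahead of time, and scoring a new message reduces to a single matrix-vector product.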


### Multiple Negatives and the Loss Function

For efficiency, a set of $K$ possible responses is used to approximate $P(y|x)$: one correct response and $K-1$ random negatives. For simplicity, the authors use the responses of the other examples in a stochastic gradient descent training batch as the negative responses. For a batch of size $K$, there are $K$ input emails $\textbf{x} = (x_1, ..., x_K)$ and their corresponding responses $\textbf{y} = (y_1, ..., y_K)$. Every reply $y_j$ is effectively treated as a negative candidate for $x_i$ if $i \neq j$. The $K-1$ negative examples for each $x$ are different at each pass through the data due to shuffling in stochastic gradient descent. The goal of training is to minimize the approximated mean negative log probability of the data. For a single batch this is:

$$
\mathcal{J}(\textbf{x}, \textbf{y}, \theta) = -\frac{1}{K}\sum_{i=1}^{K} \left[ S(x_i, y_i) - \log \sum_{j=1}^{K} e^{S(x_i, y_j)} \right]
$$

This is implemented in PyTorch as follows:

In [None]:
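class approxMeanNegativeLoss(nn.Module):
    """Approximate mean negative log probability with in-batch negatives."""

    def __init__(self):
        super().__init__()

    def forward(self, src_pos, trg_pos, batch_size=None):
        """Compute the loss for a batch of (message, response) encodings.

        src_pos and trg_pos are (K, d) tensors. batch_size is kept for
        compatibility but ignored; K is always read from src_pos.
        """
        # scores[i, j] = S(x_i, y_j); computed once and reused below.
        scores = torch.mm(src_pos, trg_pos.t())
        S_xi_yi = scores.diag()
        # logsumexp avoids the overflow that log(sum(exp(...))) can hit.
        log_sum_exp_S = torch.logsumexp(scores, dim=1)
        # Mean over the K examples in the batch.
        return -(S_xi_yi - log_sum_exp_S).mean()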
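
A useful sanity check is that the batch loss defined above is the standard softmax cross-entropy over the $K \times K$ score matrix, with the correct class for row $i$ being column $i$. The sketch below checks this numerically on random encodings; the sizes `K` and `d` and the random inputs are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
K, d = 4, 8                 # batch size and encoding size (arbitrary)
h_x = torch.randn(K, d)     # stand-in message encodings
h_y = torch.randn(K, d)     # stand-in response encodings

# S[i, j] = S(x_i, y_j)
S = h_x @ h_y.t()

# Loss written directly from the formula: diagonal score minus row-wise logsumexp.
loss_formula = -(S.diag() - torch.logsumexp(S, dim=1)).mean()

# Same quantity via cross-entropy, with target class i for row i.
loss_ce = F.cross_entropy(S, torch.arange(K))

print(torch.allclose(loss_formula, loss_ce, atol=1e-6))
```

In practice this means `F.cross_entropy` with `torch.arange(K)` targets can be used as a drop-in alternative to a hand-written loss module.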